• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/hash_utils.h"
17 
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/function.pb.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/framework/variant.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/test_benchmark.h"
27 #include "tensorflow/core/protobuf/config.pb.h"
28 #include "tensorflow/core/protobuf/error_codes.pb.h"
29 #include "tensorflow/core/util/work_sharder.h"
30 
31 namespace tensorflow {
32 namespace data {
33 namespace {
34 using ::testing::ContainsRegex;
35 
36 class DatasetHashUtilsTest : public ::testing::Test {
37  protected:
GetHash(const FunctionDefLibrary & library,const FunctionDef & fn)38   uint64 GetHash(const FunctionDefLibrary& library, const FunctionDef& fn) {
39     // Construct a node with a function as an attr.
40     GraphDef graph_def;
41     *graph_def.mutable_library() = library;
42     NodeDef* node = graph_def.add_node();
43     node->set_op("RemoteCall");
44     NameAttrList func;
45     func.set_name(fn.signature().name());
46     AddNodeAttr("f", func, node);
47     uint64 hash = 0;
48     TF_CHECK_OK(HashNode(graph_def, *node, &hash));
49     return hash;
50   }
51 
CheckEqual(const FunctionDefLibrary & library,const FunctionDef & fn1,const FunctionDef & fn2)52   Status CheckEqual(const FunctionDefLibrary& library, const FunctionDef& fn1,
53                     const FunctionDef& fn2) {
54     // Construct nodes with a function as an attr.
55     GraphDef graph_def;
56     *graph_def.mutable_library() = library;
57 
58     NodeDef* node1 = graph_def.add_node();
59     node1->set_name("RemoteCall");
60     node1->set_op("RemoteCall");
61     NameAttrList func1;
62     func1.set_name(fn1.signature().name());
63     AddNodeAttr("f", func1, node1);
64 
65     NodeDef* node2 = graph_def.add_node();
66     node1->set_name("RemoteCall2");
67     node2->set_op("RemoteCall");
68     NameAttrList func2;
69     func2.set_name(fn2.signature().name());
70     AddNodeAttr("f", func2, node2);
71 
72     return CheckSubgraphsEqual(graph_def, node1, graph_def, node2);
73   }
74 
GetHash(const GraphDef & graph,const NodeDef & node)75   uint64 GetHash(const GraphDef& graph, const NodeDef& node) {
76     uint64 hash = 0;
77     TF_CHECK_OK(HashNode(graph, node, &hash));
78     return hash;
79   }
80 
GetHash(const Tensor & tensor)81   uint64 GetHash(const Tensor& tensor) {
82     uint64 hash = 0;
83     TF_CHECK_OK(HashTensor(tensor, &hash));
84     return hash;
85   }
86 };
87 
TEST_F(DatasetHashUtilsTest, HashFunctionSameFunctionDifferentNames) {
  FunctionDefLibrary library;

  // `AddAndMul` and `AddAndMul2` have identical bodies. They differ only in
  // the function name and in the name of their single input argument.
  FunctionDef& original = *library.add_function();
  original = FunctionDefHelper::Create(
      "AddAndMul", {"i: float"}, {"o: float"}, {},
      {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
       {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "ret:z:0"}},
      /*control_ret_def=*/{{"must_execute", "add"}});

  FunctionDef& renamed = *library.add_function();
  renamed = FunctionDefHelper::Create(
      "AddAndMul2", {"input: float"}, {"o: float"}, {},
      {{{"add"}, "Add", {"input", "input"}, {{"T", DT_FLOAT}}},
       {{"ret"}, "Mul", {"input", "input"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "ret:z:0"}},
      /*control_ret_def=*/{{"must_execute", "add"}});

  // Neither the hash nor the structural comparison should see the renaming.
  const uint64 original_hash = GetHash(library, original);
  const uint64 renamed_hash = GetHash(library, renamed);
  EXPECT_EQ(original_hash, renamed_hash);
  TF_EXPECT_OK(CheckEqual(library, original, renamed));
}
110 
TEST_F(DatasetHashUtilsTest, HashFunctionDifferentFunctions) {
  FunctionDefLibrary library;

  FunctionDef& add_then_mul = *library.add_function();
  add_then_mul = FunctionDefHelper::Create(
      "AddAndMul", {"i: float"}, {"o: float"}, {},
      {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
       {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "ret:z:0"}},
      /*control_ret_def=*/{{"must_execute", "add"}});

  // Same shape as above, except the "ret" node uses "Add" instead of "Mul".
  FunctionDef& add_then_add = *library.add_function();
  add_then_add = FunctionDefHelper::Create(
      "AddAndAdd", {"i: float"}, {"o: float"}, {},
      {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
       {{"ret"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "ret:z:0"}},
      /*control_ret_def=*/{{"must_execute", "add"}});

  EXPECT_NE(GetHash(library, add_then_mul), GetHash(library, add_then_add));

  const Status status = CheckEqual(library, add_then_mul, add_then_add);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), ContainsRegex("Add"));
}
136 
TEST_F(DatasetHashUtilsTest, HashFunctionDifferentInternalNodeNames) {
  FunctionDefLibrary library;

  // Computes (i + j) * k, with the multiply node named "ret".
  FunctionDef& first = *library.add_function();
  first = FunctionDefHelper::Create(
      "AddAndMul", {"i: float", "j: float", "k: float"}, {"o: float"}, {},
      {{{"add"}, "Add", {"i", "j"}, {{"T", DT_FLOAT}}},
       {{"ret"}, "Mul", {"add:z:0", "k"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "ret:z:0"}},
      /*control_ret_def=*/{{"must_execute", "ret"}});

  // Computes (a + b) * c. Arguments are renamed and the multiply node is
  // named "mul" instead of "ret".
  FunctionDef& second = *library.add_function();
  second = FunctionDefHelper::Create(
      "AddAndMul2", {"a: float", "b: float", "c: float"}, {"o: float"}, {},
      {{{"add"}, "Add", {"a", "b"}, {{"T", DT_FLOAT}}},
       {{"mul"}, "Mul", {"add:z:0", "c"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "mul:z:0"}},
      /*control_ret_def=*/{{"must_execute", "mul"}});

  EXPECT_EQ(GetHash(library, first), GetHash(library, second));
  TF_EXPECT_OK(CheckEqual(library, first, second));
}
159 
TEST_F(DatasetHashUtilsTest, HashGraphWithMultipleCycles) {
  // Populates `g` with a graph containing two cycles that share node B:
  //   A -> C -> B -> A   and   B -> E -> D -> B
  // ("X -> Y" means X is an input of Y). Output node O consumes A and D.
  // Returns a pointer to O, which is owned by `g`.
  auto build_graph = [](GraphDef* g) -> NodeDef* {
    NodeDef* output_node = g->add_node();
    TF_CHECK_OK(NodeDefBuilder("O", "Add")
                    .Input("A", 0, DT_FLOAT)
                    .Input("D", 0, DT_FLOAT)
                    .Finalize(output_node));
    TF_CHECK_OK(NodeDefBuilder("A", "Abs")
                    .Input("B", 0, DT_FLOAT)
                    .Finalize(g->add_node()));
    TF_CHECK_OK(NodeDefBuilder("B", "Add")
                    .Input("C", 0, DT_FLOAT)
                    .Input("D", 0, DT_FLOAT)
                    .Finalize(g->add_node()));
    TF_CHECK_OK(NodeDefBuilder("C", "Ceil")
                    .Input("A", 0, DT_FLOAT)
                    .Finalize(g->add_node()));
    TF_CHECK_OK(NodeDefBuilder("D", "Cos")
                    .Input("E", 0, DT_FLOAT)
                    .Finalize(g->add_node()));
    TF_CHECK_OK(NodeDefBuilder("E", "Floor")
                    .Input("B", 0, DT_FLOAT)
                    .Finalize(g->add_node()));
    return output_node;
  };

  // Rebuild and hash the identical graph many times. Every hash must match
  // the first, i.e. hashing a cyclic graph is deterministic.
  uint64 expected_hash = 0;
  for (int i = 0; i < 1000; ++i) {
    GraphDef g;
    NodeDef* output_node = build_graph(&g);
    const uint64 hash = GetHash(g, *output_node);
    // The first iteration is the reference. Using 0 as an "unset" sentinel
    // would be wrong, because 0 is a valid hash value.
    if (i == 0) {
      expected_hash = hash;
    } else {
      EXPECT_EQ(hash, expected_hash);
    }
  }
}
193 
TEST_F(DatasetHashUtilsTest, HashNodeSameGraphDifferentNames) {
  GraphDef gd;

  // Appends a "Const" node on CPU:0 whose "value" attr is `value`.
  auto add_const = [&gd](const std::string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  // Appends an "Add" node on CPU:0 consuming output 0 of `lhs` and `rhs`.
  auto add_sum = [&gd](const std::string& name, const NodeDef* lhs,
                       const NodeDef* rhs) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Add")
                    .Device("CPU:0")
                    .Input(lhs->name(), 0, DT_INT32)
                    .Input(rhs->name(), 0, DT_INT32)
                    .Finalize(node));
    return node;
  };

  // Two structurally identical subgraphs, 1 + 2, with unrelated node names.
  NodeDef* a_one = add_const("graph_1/node_1", 1);
  NodeDef* a_two = add_const("graph_1/node_2", 2);
  NodeDef* a_sum = add_sum("graph_1/node_3", a_one, a_two);

  NodeDef* b_one = add_const("graph_3/node_7", 1);
  NodeDef* b_two = add_const("graph_4/node_9", 2);
  NodeDef* b_sum = add_sum("graph_5/node_11", b_one, b_two);

  EXPECT_EQ(GetHash(gd, *a_sum), GetHash(gd, *b_sum));
  TF_EXPECT_OK(CheckSubgraphsEqual(gd, a_sum, gd, b_sum));
}
240 
TEST_F(DatasetHashUtilsTest, HashNodeDifferentGraphs) {
  GraphDef gd;

  // Appends a "Const" node on CPU:0 whose "value" attr is `value`.
  auto add_const = [&gd](const std::string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  // Appends a binary `op` node on CPU:0 consuming output 0 of `lhs`, `rhs`.
  auto add_binary = [&gd](const std::string& name, const std::string& op,
                          const NodeDef* lhs, const NodeDef* rhs) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, op)
                    .Device("CPU:0")
                    .Input(lhs->name(), 0, DT_INT32)
                    .Input(rhs->name(), 0, DT_INT32)
                    .Finalize(node));
    return node;
  };

  NodeDef* one = add_const("graph_1/node_1", 1);
  NodeDef* two = add_const("graph_1/node_2", 2);
  NodeDef* sum = add_binary("graph_1/node_3", "Add", one, two);
  NodeDef* product = add_binary("graph_1/node_4", "Mul", one, two);

  // Same inputs, but the op differs, so both the hash and the comparison
  // must tell the two nodes apart.
  EXPECT_NE(GetHash(gd, *sum), GetHash(gd, *product));

  const Status status = CheckSubgraphsEqual(gd, sum, gd, product);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), ContainsRegex("Add"));
  EXPECT_THAT(status.error_message(), ContainsRegex("Mul"));
}
279 
TEST_F(DatasetHashUtilsTest,HashSameGraphDifferentSeeds)280 TEST_F(DatasetHashUtilsTest, HashSameGraphDifferentSeeds) {
281   GraphDef gd;
282 
283   NodeDef* n1 = gd.add_node();
284   TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
285                   .Attr("value", 1)
286                   .Device("CPU:0")
287                   .Finalize(n1));
288 
289   NodeDef* seed = gd.add_node();
290   TF_CHECK_OK(NodeDefBuilder("graph_1/seed", "Const")
291                   .Attr("value", 123)
292                   .Device("CPU:0")
293                   .Finalize(seed));
294 
295   NodeDef* seed2 = gd.add_node();
296   TF_CHECK_OK(NodeDefBuilder("graph_1/seed2", "Const")
297                   .Attr("value", 456)
298                   .Device("CPU:0")
299                   .Finalize(seed2));
300 
301   NodeDef* range_ds = gd.add_node();
302   TF_CHECK_OK(NodeDefBuilder("graph_1/range", "RangeDataset")
303                   .Input(n1->name(), 0, DT_INT64)
304                   .Input(n1->name(), 0, DT_INT64)
305                   .Input(n1->name(), 0, DT_INT64)
306                   .Device("CPU:0")
307                   .Finalize(range_ds));
308 
309   NodeDef* shuffle_ds = gd.add_node();
310   TF_CHECK_OK(NodeDefBuilder("graph_1/shuffle", "ShuffleDataset")
311                   .Input(range_ds->name(), 0, DT_VARIANT)
312                   .Input(n1->name(), 0, DT_INT64)
313                   .Input(seed->name(), 0, DT_INT64)
314                   .Input(seed2->name(), 0, DT_INT64)
315                   .Device("CPU:0")
316                   .Finalize(shuffle_ds));
317 
318   NodeDef* different_seed = gd.add_node();
319   TF_CHECK_OK(NodeDefBuilder("graph_1/different_seed", "Const")
320                   .Attr("value", 789)
321                   .Device("CPU:0")
322                   .Finalize(different_seed));
323   NodeDef* different_seed2 = gd.add_node();
324   TF_CHECK_OK(NodeDefBuilder("graph_1/different_seed2", "Const")
325                   .Attr("value", 654)
326                   .Device("CPU:0")
327                   .Finalize(different_seed2));
328 
329   NodeDef* range_ds_2 = gd.add_node();
330   TF_CHECK_OK(NodeDefBuilder("graph_1/range_2", "RangeDataset")
331                   .Input(n1->name(), 0, DT_INT64)
332                   .Input(n1->name(), 0, DT_INT64)
333                   .Input(n1->name(), 0, DT_INT64)
334                   .Device("CPU:0")
335                   .Finalize(range_ds_2));
336 
337   NodeDef* shuffle_ds_2 = gd.add_node();
338   TF_CHECK_OK(NodeDefBuilder("graph_1/shuffle_2", "ShuffleDataset")
339                   .Input(range_ds_2->name(), 0, DT_VARIANT)
340                   .Input(n1->name(), 0, DT_INT64)
341                   .Input(different_seed->name(), 0, DT_INT64)
342                   .Input(different_seed2->name(), 0, DT_INT64)
343                   .Device("CPU:0")
344                   .Finalize(shuffle_ds_2));
345 
346   uint64 hash1 = GetHash(gd, *shuffle_ds);
347   uint64 hash2 = GetHash(gd, *shuffle_ds_2);
348   EXPECT_EQ(hash1, hash2);
349   TF_EXPECT_OK(CheckSubgraphsEqual(gd, shuffle_ds, gd, shuffle_ds_2));
350 }
351 
TEST_F(DatasetHashUtilsTest, HashNodeSameGraphDifferentColocationNames) {
  GraphDef gd;

  // First subgraph: node_1 + node_2, where node_1 carries a "_class"
  // (colocation) attr naming graph_1/node_2.
  NodeDef* n1 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Attr("_class", {"graph_1/node_2"})
                  .Device("CPU:0")
                  .Finalize(n1));

  NodeDef* n2 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "Const")
                  .Attr("value", 2)
                  .Device("CPU:0")
                  .Finalize(n2));

  NodeDef* n3 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_3", "Add")
                  .Device("CPU:0")
                  .Input(n1->name(), 0, DT_INT32)
                  .Input(n2->name(), 0, DT_INT32)
                  .Finalize(n3));

  // Second subgraph: structurally identical, but with different node names
  // and a different "_class" value on the first operand.
  NodeDef* n4 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_3/node_7", "Const")
                  .Attr("value", 1)
                  .Attr("_class", {"graph_3/node_9"})
                  .Device("CPU:0")
                  .Finalize(n4));

  NodeDef* n5 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_4/node_9", "Const")
                  .Attr("value", 2)
                  .Device("CPU:0")
                  .Finalize(n5));

  // n6 must consume n4/n5. Consuming n1/n2 would make both subgraphs share
  // the same operands, and the differing colocation would never be compared.
  NodeDef* n6 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_5/node_11", "Add")
                  .Device("CPU:0")
                  .Input(n4->name(), 0, DT_INT32)
                  .Input(n5->name(), 0, DT_INT32)
                  .Finalize(n6));

  uint64 hash1 = GetHash(gd, *n3);
  uint64 hash2 = GetHash(gd, *n6);

  EXPECT_EQ(hash1, hash2);
  TF_EXPECT_OK(CheckSubgraphsEqual(gd, n3, gd, n6));
}
401 
TEST_F(DatasetHashUtilsTest, HashNodeReversedOrder) {
  GraphDef gd;

  NodeDef* n1 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(n1));

  NodeDef* n2 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "Const")
                  .Attr("value", 2)
                  .Device("CPU:0")
                  .Finalize(n2));

  // n3 = n1 + n2.
  NodeDef* n3 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_3", "Add")
                  .Device("CPU:0")
                  .Input(n1->name(), 0, DT_INT32)
                  .Input(n2->name(), 0, DT_INT32)
                  .Finalize(n3));

  // n4 = n2 + n1: the same operands as n3, in the opposite order.
  NodeDef* n4 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_4", "Add")
                  .Device("CPU:0")
                  .Input(n2->name(), 0, DT_INT32)
                  .Input(n1->name(), 0, DT_INT32)
                  .Finalize(n4));

  uint64 hash1 = GetHash(gd, *n3);
  uint64 hash2 = GetHash(gd, *n4);
  // We expect different hashes because n4's inputs are n3's inputs in
  // swapped order.
  EXPECT_NE(hash1, hash2);
  // Comparing input-by-input pairs Const(1) with Const(2), so the mismatch
  // surfaces as differing attr values.
  Status s = CheckSubgraphsEqual(gd, n3, gd, n4);
  EXPECT_NE(s.code(), error::OK);
  EXPECT_THAT(s.error_message(), ContainsRegex("AttrValues are different"));
}
439 
TEST_F(DatasetHashUtilsTest, HashNodeInputPortChanged) {
  GraphDef gd;

  // Appends a "Const" node on CPU:0 whose "value" attr is `value`.
  auto add_const = [&gd](const std::string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  // Appends an "Add" node on CPU:0 reading output `lhs_port` of `lhs` and
  // output `rhs_port` of `rhs`.
  auto add_sum = [&gd](const std::string& name, const NodeDef* lhs,
                       int lhs_port, const NodeDef* rhs, int rhs_port) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Add")
                    .Device("CPU:0")
                    .Input(lhs->name(), lhs_port, DT_INT32)
                    .Input(rhs->name(), rhs_port, DT_INT32)
                    .Finalize(node));
    return node;
  };

  NodeDef* one = add_const("graph_1/node_1", 1);
  NodeDef* two = add_const("graph_1/node_2", 2);
  NodeDef* from_port_zero = add_sum("graph_1/node_3", one, 0, two, 0);
  NodeDef* from_other_ports = add_sum("graph_1/node_4", one, 1, two, 2);

  // Same producer nodes, but the output ports that are read have changed,
  // so the hashes must differ and the comparison must fail on the inputs.
  EXPECT_NE(GetHash(gd, *from_port_zero), GetHash(gd, *from_other_ports));

  const Status status =
      CheckSubgraphsEqual(gd, from_port_zero, gd, from_other_ports);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), ContainsRegex("Node inputs"));
}
478 
TEST_F(DatasetHashUtilsTest,HashNodeSameFunctionDifferentNames)479 TEST_F(DatasetHashUtilsTest, HashNodeSameFunctionDifferentNames) {
480   GraphDef gd;
481   FunctionDefLibrary* fl1 = gd.mutable_library();
482 
483   FunctionDef* f1 = fl1->add_function();
484   *f1 = FunctionDefHelper::Create(
485       "AddAndMul", {"i: float"}, {"o: float"}, {},
486       {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
487        {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
488       /*ret_def=*/{{"o", "ret:z:0"}},
489       /*control_ret_def=*/{{"must_execute", "add"}});
490 
491   FunctionDef* f2 = fl1->add_function();
492   *f2 = FunctionDefHelper::Create(
493       "AddAndMul2", {"input: float"}, {"o: float"}, {},
494       {{{"add"}, "Add", {"input", "input"}, {{"T", DT_FLOAT}}},
495        {{"ret"}, "Mul", {"input", "input"}, {{"T", DT_FLOAT}}}},
496       /*ret_def=*/{{"o", "ret:z:0"}},
497       /*control_ret_def=*/{{"must_execute", "add"}});
498 
499   AttrValue a1;
500   NameAttrList* nal1 = a1.mutable_func();
501   nal1->set_name("AddAndMul");
502 
503   NodeDef* n1 = gd.add_node();
504   TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
505                   .Attr("value", 1)
506                   .Device("CPU:0")
507                   .Finalize(n1));
508 
509   std::vector<NodeDefBuilder::NodeOut> func_inputs;
510   func_inputs.emplace_back(n1->name(), 0, DT_FLOAT);
511   func_inputs.emplace_back(n1->name(), 0, DT_FLOAT);
512 
513   NodeDef* n2 = gd.add_node();
514   TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "For")
515                   .Input(n1->name(), 0, DT_INT32)
516                   .Input(n1->name(), 0, DT_INT32)
517                   .Input(n1->name(), 0, DT_INT32)
518                   .Input(func_inputs)
519                   .Attr("body", a1)
520                   .Device("CPU:0")
521                   .Finalize(n2));
522 
523   NodeDef* n3 = gd.add_node();
524   AttrValue a2;
525   NameAttrList* nal2 = a2.mutable_func();
526   nal2->set_name("AddAndMul2");
527 
528   TF_CHECK_OK(NodeDefBuilder("graph_1/node_3", "For")
529                   .Input(n1->name(), 0, DT_INT32)
530                   .Input(n1->name(), 0, DT_INT32)
531                   .Input(n1->name(), 0, DT_INT32)
532                   .Input(func_inputs)
533                   .Attr("body", a2)
534                   .Device("CPU:0")
535                   .Finalize(n3));
536 
537   uint64 hash1 = GetHash(gd, *n2);
538   uint64 hash2 = GetHash(gd, *n3);
539   EXPECT_EQ(hash1, hash2);
540   TF_EXPECT_OK(CheckSubgraphsEqual(gd, n2, gd, n3));
541 }
542 
TEST_F(DatasetHashUtilsTest,HashNodeSameFunctionListsDifferentNames)543 TEST_F(DatasetHashUtilsTest, HashNodeSameFunctionListsDifferentNames) {
544   GraphDef gd;
545   FunctionDefLibrary* fl1 = gd.mutable_library();
546 
547   FunctionDef* f1 = fl1->add_function();
548   *f1 = FunctionDefHelper::Create(
549       "AddAndMul", {"i: float"}, {"o: float"}, {},
550       {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
551        {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
552       /*ret_def=*/{{"o", "ret:z:0"}},
553       /*control_ret_def=*/{{"must_execute", "add"}});
554 
555   FunctionDef* f2 = fl1->add_function();
556   *f2 = FunctionDefHelper::Create(
557       "AddAndMul2", {"input: float"}, {"o: float"}, {},
558       {{{"add"}, "Add", {"input", "input"}, {{"T", DT_FLOAT}}},
559        {{"ret"}, "Mul", {"input", "input"}, {{"T", DT_FLOAT}}}},
560       /*ret_def=*/{{"o", "ret:z:0"}},
561       /*control_ret_def=*/{{"must_execute", "add"}});
562 
563   AttrValue a1;
564   AttrValue_ListValue* list1 = a1.mutable_list();
565   NameAttrList* nal1 = list1->add_func();
566   nal1->set_name("AddAndMul");
567 
568   NodeDef* n1 = gd.add_node();
569   TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
570                   .Attr("value", 1)
571                   .Device("CPU:0")
572                   .Finalize(n1));
573 
574   std::vector<NodeDefBuilder::NodeOut> func_inputs;
575   func_inputs.emplace_back(n1->name(), 0, DT_FLOAT);
576   func_inputs.emplace_back(n1->name(), 0, DT_FLOAT);
577 
578   NodeDef* n2 = gd.add_node();
579   TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "For")
580                   .Input(n1->name(), 0, DT_INT32)
581                   .Input(n1->name(), 0, DT_INT32)
582                   .Input(n1->name(), 0, DT_INT32)
583                   .Input(func_inputs)
584                   .Attr("body", a1)
585                   .Device("CPU:0")
586                   .Finalize(n2));
587 
588   NodeDef* n3 = gd.add_node();
589   AttrValue a2;
590   AttrValue_ListValue* list2 = a2.mutable_list();
591   NameAttrList* nal2 = list2->add_func();
592   nal2->set_name("AddAndMul2");
593 
594   TF_CHECK_OK(NodeDefBuilder("graph_1/node_3", "For")
595                   .Input(n1->name(), 0, DT_INT32)
596                   .Input(n1->name(), 0, DT_INT32)
597                   .Input(n1->name(), 0, DT_INT32)
598                   .Input(func_inputs)
599                   .Attr("body", a2)
600                   .Device("CPU:0")
601                   .Finalize(n3));
602 
603   uint64 hash1 = GetHash(gd, *n2);
604   uint64 hash2 = GetHash(gd, *n3);
605   EXPECT_EQ(hash1, hash2);
606   TF_EXPECT_OK(CheckSubgraphsEqual(gd, n2, gd, n3));
607 }
608 
TEST_F(DatasetHashUtilsTest, HashNodeSameFunctionsOps) {
  GraphDef gd;
  FunctionDefLibrary* library = gd.mutable_library();

  // Two functions with identical bodies registered under different names.
  *library->add_function() = FunctionDefHelper::Create(
      "AddAndMul", {"i: float"}, {"o: float"}, {},
      {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
       {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "ret:z:0"}},
      /*control_ret_def=*/{{"must_execute", "add"}});
  *library->add_function() = FunctionDefHelper::Create(
      "AddAndMul2", {"i: float"}, {"o: float"}, {},
      {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
       {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "ret:z:0"}},
      /*control_ret_def=*/{{"must_execute", "add"}});

  // Op registry that lets NodeDefBuilder use the library functions as ops.
  FunctionLibraryDefinition flib(OpRegistry::Global(), gd.library());

  NodeDef* input = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(input));

  // Appends a node whose op is the library function `fn_name`, fed by
  // output 0 of `input`.
  auto add_call = [&](const std::string& name, const std::string& fn_name) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, fn_name, &flib)
                    .Input(input->name(), 0, DT_FLOAT)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  NodeDef* call_a = add_call("graph_1/node_2", "AddAndMul");
  NodeDef* call_b = add_call("graph_1/node_3", "AddAndMul2");

  EXPECT_EQ(GetHash(gd, *call_a), GetHash(gd, *call_b));
  TF_EXPECT_OK(CheckSubgraphsEqual(gd, call_a, gd, call_b));
}
656 
TEST_F(DatasetHashUtilsTest, HashNodeDifferentFunctionsOps) {
  GraphDef gd;
  FunctionDefLibrary* library = gd.mutable_library();

  // The two functions differ only in their control return: "add" vs "ret".
  *library->add_function() = FunctionDefHelper::Create(
      "AddAndMul", {"i: float"}, {"o: float"}, {},
      {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
       {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "ret:z:0"}},
      /*control_ret_def=*/{{"must_execute", "add"}});
  *library->add_function() = FunctionDefHelper::Create(
      "AddAndMul2", {"i: float"}, {"o: float"}, {},
      {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
       {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "ret:z:0"}},
      /*control_ret_def=*/{{"must_execute", "ret"}});

  // Op registry that lets NodeDefBuilder use the library functions as ops.
  FunctionLibraryDefinition flib(OpRegistry::Global(), gd.library());

  NodeDef* input = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(input));

  // Appends a node whose op is the library function `fn_name`, fed by
  // output 0 of `input`.
  auto add_call = [&](const std::string& name, const std::string& fn_name) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, fn_name, &flib)
                    .Input(input->name(), 0, DT_FLOAT)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };

  NodeDef* call_a = add_call("graph_1/node_2", "AddAndMul");
  NodeDef* call_b = add_call("graph_1/node_3", "AddAndMul2");

  EXPECT_NE(GetHash(gd, *call_a), GetHash(gd, *call_b));

  const Status status = CheckSubgraphsEqual(gd, call_a, gd, call_b);
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      ContainsRegex("Functions AddAndMul and AddAndMul2 are not the same"));
}
708 
TEST_F(DatasetHashUtilsTest,HashNodeDifferentFunctions)709 TEST_F(DatasetHashUtilsTest, HashNodeDifferentFunctions) {
710   GraphDef gd;
711 
712   FunctionDefLibrary* fl1 = gd.mutable_library();
713   FunctionDef* f1 = fl1->add_function();
714 
715   FunctionDef func = FunctionDefHelper::Create(
716       "AddAndMul", {"i: float"}, {"o: float"}, {},
717       {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
718        {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
719       /*ret_def=*/{{"o", "ret:z:0"}},
720       /*control_ret_def=*/{{"must_execute", "add"}});
721   *f1 = func;
722 
723   FunctionDef* f2 = fl1->add_function();
724   func = FunctionDefHelper::Create(
725       "AddAndMul2", {"i: float"}, {"o: float"}, {},
726       {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
727        {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
728       /*ret_def=*/{{"o", "ret:z:0"}},
729       /*control_ret_def=*/{{"must_execute", "ret"}});
730   *f2 = func;
731 
732   AttrValue a1;
733   NameAttrList* nal1 = a1.mutable_func();
734   nal1->set_name("AddAndMul");
735 
736   NodeDef* n1 = gd.add_node();
737   TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
738                   .Attr("value", 1)
739                   .Device("CPU:0")
740                   .Finalize(n1));
741 
742   std::vector<NodeDefBuilder::NodeOut> func_inputs;
743   func_inputs.emplace_back(n1->name(), 0, DT_FLOAT);
744   func_inputs.emplace_back(n1->name(), 0, DT_FLOAT);
745 
746   NodeDef* n2 = gd.add_node();
747   TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "For")
748                   .Input(n1->name(), 0, DT_INT32)
749                   .Input(n1->name(), 0, DT_INT32)
750                   .Input(n1->name(), 0, DT_INT32)
751                   .Input(func_inputs)
752                   .Attr("body", a1)
753                   .Device("CPU:0")
754                   .Finalize(n2));
755 
756   NodeDef* n3 = gd.add_node();
757   AttrValue a2;
758   NameAttrList* nal2 = a2.mutable_func();
759   nal2->set_name("AddAndMul2");
760 
761   TF_CHECK_OK(NodeDefBuilder("graph_1/node_3", "For")
762                   .Input(n1->name(), 0, DT_INT32)
763                   .Input(n1->name(), 0, DT_INT32)
764                   .Input(n1->name(), 0, DT_INT32)
765                   .Input(func_inputs)
766                   .Attr("body", a2)
767                   .Device("CPU:0")
768                   .Finalize(n3));
769 
770   uint64 hash1 = GetHash(gd, *n2);
771   uint64 hash2 = GetHash(gd, *n3);
772   EXPECT_NE(hash1, hash2);
773   Status s = CheckSubgraphsEqual(gd, n2, gd, n3);
774   EXPECT_NE(s.code(), error::OK);
775   EXPECT_THAT(
776       s.error_message(),
777       ContainsRegex("Functions AddAndMul and AddAndMul2 are not the same"));
778 }
779 
TEST_F(DatasetHashUtilsTest,HashNodeDifferentFunctionLists)780 TEST_F(DatasetHashUtilsTest, HashNodeDifferentFunctionLists) {
781   GraphDef gd;
782 
783   FunctionDefLibrary* fl1 = gd.mutable_library();
784   FunctionDef* f1 = fl1->add_function();
785 
786   FunctionDef func = FunctionDefHelper::Create(
787       "AddAndMul", {"i: float"}, {"o: float"}, {},
788       {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
789        {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
790       /*ret_def=*/{{"o", "ret:z:0"}},
791       /*control_ret_def=*/{{"must_execute", "add"}});
792   *f1 = func;
793 
794   FunctionDef* f2 = fl1->add_function();
795   func = FunctionDefHelper::Create(
796       "AddAndMul2", {"i: float"}, {"o: float"}, {},
797       {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
798        {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
799       /*ret_def=*/{{"o", "ret:z:0"}},
800       /*control_ret_def=*/{{"must_execute", "ret"}});
801   *f2 = func;
802 
803   AttrValue a1;
804   AttrValue_ListValue* list1 = a1.mutable_list();
805   NameAttrList* nal1 = list1->add_func();
806   nal1->set_name("AddAndMul");
807 
808   NodeDef* n1 = gd.add_node();
809   TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
810                   .Attr("value", 1)
811                   .Device("CPU:0")
812                   .Finalize(n1));
813 
814   std::vector<NodeDefBuilder::NodeOut> func_inputs;
815   func_inputs.emplace_back(n1->name(), 0, DT_FLOAT);
816   func_inputs.emplace_back(n1->name(), 0, DT_FLOAT);
817 
818   NodeDef* n2 = gd.add_node();
819   TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "For")
820                   .Input(n1->name(), 0, DT_INT32)
821                   .Input(n1->name(), 0, DT_INT32)
822                   .Input(n1->name(), 0, DT_INT32)
823                   .Input(func_inputs)
824                   .Attr("body", a1)
825                   .Device("CPU:0")
826                   .Finalize(n2));
827 
828   NodeDef* n3 = gd.add_node();
829   AttrValue a2;
830   AttrValue_ListValue* list2 = a2.mutable_list();
831   NameAttrList* nal2 = list2->add_func();
832   nal2->set_name("AddAndMul2");
833 
834   TF_CHECK_OK(NodeDefBuilder("graph_1/node_3", "For")
835                   .Input(n1->name(), 0, DT_INT32)
836                   .Input(n1->name(), 0, DT_INT32)
837                   .Input(n1->name(), 0, DT_INT32)
838                   .Input(func_inputs)
839                   .Attr("body", a2)
840                   .Device("CPU:0")
841                   .Finalize(n3));
842 
843   uint64 hash1 = GetHash(gd, *n2);
844   uint64 hash2 = GetHash(gd, *n3);
845   EXPECT_NE(hash1, hash2);
846   Status s = CheckSubgraphsEqual(gd, n2, gd, n3);
847   EXPECT_NE(s.code(), error::OK);
848   EXPECT_THAT(
849       s.error_message(),
850       ContainsRegex("Functions AddAndMul and AddAndMul2 are not the same"));
851 }
852 
// Two Identity nodes with the same data input but different control inputs
// must hash differently and fail the subgraph equality check.
TEST_F(DatasetHashUtilsTest, HashNodeDifferentControlInputs) {
  GraphDef gd;

  // Appends a CPU Const node with the given name and integer value.
  auto add_const = [&gd](const string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };
  NodeDef* n1 = add_const("graph_1/node_1", 1);
  NodeDef* n2 = add_const("graph_1/node_2", 2);
  NodeDef* n3 = add_const("graph_1/node_3", 10);

  // Appends an Identity of node_1 carrying one control dependency on `ctrl`.
  auto add_identity = [&](const string& name, const NodeDef& ctrl) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Identity")
                    .Device("CPU:0")
                    .Input(n1->name(), 0, DT_INT32)
                    .ControlInput(ctrl.name())
                    .Finalize(node));
    return node;
  };
  // node_4 waits on node_2 while node_5 waits on node_3.
  NodeDef* n4 = add_identity("graph_1/node_4", *n2);
  NodeDef* n5 = add_identity("graph_1/node_5", *n3);

  EXPECT_NE(GetHash(gd, *n4), GetHash(gd, *n5));
  Status s = CheckSubgraphsEqual(gd, n4, gd, n5);
  EXPECT_NE(s.code(), error::OK);
  EXPECT_THAT(s.error_message(),
              ContainsRegex("Control dependencies are different"));
}
897 
// Listing the same control inputs in a different order must affect neither the
// hash nor the subgraph equality check.
TEST_F(DatasetHashUtilsTest, HashNodeControlInputDifferentOrdering) {
  GraphDef gd;

  // Appends a CPU Const node with the given name and integer value.
  auto add_const = [&gd](const string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };
  NodeDef* n1 = add_const("graph_1/node_1", 1);
  NodeDef* n2 = add_const("graph_1/node_2", 2);
  NodeDef* n3 = add_const("graph_1/node_3", 10);

  // Appends an Identity of node_1 with control inputs `first`, then `second`.
  auto add_identity = [&](const string& name, const NodeDef& first,
                          const NodeDef& second) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Identity")
                    .Device("CPU:0")
                    .Input(n1->name(), 0, DT_INT32)
                    .ControlInput(first.name())
                    .ControlInput(second.name())
                    .Finalize(node));
    return node;
  };
  NodeDef* n4 = add_identity("graph_1/node_4", *n2, *n3);
  NodeDef* n5 = add_identity("graph_1/node_5", *n3, *n2);

  EXPECT_EQ(GetHash(gd, *n4), GetHash(gd, *n5));
  TF_EXPECT_OK(CheckSubgraphsEqual(gd, n4, gd, n5));
}
940 
// node_1's hash must not change when a downstream consumer of it (node_3) is
// rewritten from Add to Mul, since node_3 is outside node_1's subgraph.
TEST_F(DatasetHashUtilsTest, HashNodeDifferentGraphSamePartialGraph) {
  GraphDef gd;

  // Appends a CPU Const node with the given name and integer value.
  auto add_const = [&gd](const string& name, int value) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .Finalize(node));
    return node;
  };
  NodeDef* n1 = add_const("graph_1/node_1", 1);
  NodeDef* n2 = add_const("graph_1/node_2", 2);
  NodeDef* n3 = gd.add_node();

  // Writes a CPU binary op `op(node_1, node_2)` named node_3 into `n3`.
  auto write_binary_op = [&](const string& op) {
    TF_CHECK_OK(NodeDefBuilder("graph_1/node_3", op)
                    .Device("CPU:0")
                    .Input(n1->name(), 0, DT_INT32)
                    .Input(n2->name(), 0, DT_INT32)
                    .Finalize(n3));
  };

  write_binary_op("Add");
  const uint64 hash_with_add = GetHash(gd, *n1);

  n3->Clear();
  write_binary_op("Mul");
  const uint64 hash_with_mul = GetHash(gd, *n1);

  EXPECT_EQ(hash_with_add, hash_with_mul);
}
977 
// Node i carries a control dependency on every node 0..i-1, so the last node
// transitively reaches a dense web of control edges. The only check is that
// hashing the last node completes (i.e. the test does not time out).
TEST_F(DatasetHashUtilsTest, HashNodeWithManyControlDependencies) {
  constexpr int kNumNodes = 1000;
  auto node_name = [](int idx) { return absl::StrCat("graph_1/node_", idx); };

  GraphDef gd;
  NodeDef* last = nullptr;
  for (int i = 0; i < kNumNodes; ++i) {
    last = gd.add_node();
    NodeDefBuilder builder(node_name(i), "Const");
    builder.Attr("value", 1).Device("CPU:0");
    for (int dep = 0; dep < i; ++dep) {
      builder.ControlInput(node_name(dep));
    }
    TF_CHECK_OK(builder.Finalize(last));
  }

  GetHash(gd, *last);
}
996 
// Hashing must terminate when a function's body contains a `For` whose body is
// the function itself, and the calling node lists itself as a control input.
// Success is simply that GetHash returns without timing out or exhausting the
// stack.
TEST_F(DatasetHashUtilsTest, HashFunctionsWithControlDependencyLoop) {
  GraphDef gd;

  AttrValue body;
  body.mutable_func()->set_name("AddAndMul");
  const std::pair<string, FunctionDefHelper::AttrValueWrapper> self_body_attr =
      {"body", FunctionDefHelper::AttrValueWrapper(body.func())};

  *gd.mutable_library()->add_function() = FunctionDefHelper::Create(
      /*function_name=*/"AddAndMul",
      /*in_def=*/{"i: float", "j: int32"},
      /*out_def=*/{"o: float"},
      /*attr_def=*/{},
      /*node_def=*/
      {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}, {"ret"}},
       // The `For` body refers back to AddAndMul itself.
       {{"for"},
        "For",
        {"j", "j", "j"},
        {self_body_attr, {"T", DT_FLOAT}},
        {"ret"}},
       {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/{{"o", "ret:z:0"}},
      /*control_ret_def=*/{{"must_execute", "add"}});

  NodeDef* n1 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const")
                  .Attr("value", 1)
                  .Device("CPU:0")
                  .Finalize(n1));

  const std::vector<NodeDefBuilder::NodeOut> loop_vars(
      2, NodeDefBuilder::NodeOut(n1->name(), 0, DT_FLOAT));

  NodeDef* n2 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "For")
                  .Input(n1->name(), 0, DT_INT32)
                  .Input(n1->name(), 0, DT_INT32)
                  .Input(n1->name(), 0, DT_INT32)
                  .Input(loop_vars)
                  // Deliberate self-loop: node_2 is its own control input.
                  .ControlInput("graph_1/node_2")
                  .Attr("body", body)
                  .Device("CPU:0")
                  .Finalize(n2));

  GetHash(gd, *n2);
}
1049 
// Two Const nodes control-depend on each other and both feed an Add that also
// control-depends on them. Hashing the Add must terminate; returning without a
// timeout or stack exhaustion is the success criterion.
TEST_F(DatasetHashUtilsTest, HashNodeWithControlDependencyLoop) {
  GraphDef gd;

  // Appends a CPU Const node with one control dependency on `ctrl`.
  auto add_const = [&gd](const string& name, int value, const string& ctrl) {
    NodeDef* node = gd.add_node();
    TF_CHECK_OK(NodeDefBuilder(name, "Const")
                    .Attr("value", value)
                    .Device("CPU:0")
                    .ControlInput(ctrl)
                    .Finalize(node));
    return node;
  };
  NodeDef* n1 = add_const("graph_1/node_1", 1, "graph_1/node_2");
  NodeDef* n2 = add_const("graph_1/node_2", 2, "graph_1/node_1");

  NodeDef* n3 = gd.add_node();
  TF_CHECK_OK(NodeDefBuilder("graph_1/node_3", "Add")
                  .Device("CPU:0")
                  .Input(n1->name(), 0, DT_INT32)
                  .Input(n2->name(), 0, DT_INT32)
                  .ControlInput(n1->name())
                  .ControlInput(n2->name())
                  .Finalize(n3));

  GetHash(gd, *n3);
}
1080 
// Two structurally identical graphs containing a control-dependency cycle, but
// using different node names, must produce equal hashes for their Add nodes.
TEST_F(DatasetHashUtilsTest, HashNodeWithControlDependencyLoopDifferentNames) {
  // Builds Const(1) <-> Const(2) (mutual control deps) feeding an Add that also
  // control-depends on both, named graph_1/node_<first_id> .. <first_id + 2>.
  // Returns the Add node.
  auto build_cyclic_graph = [](GraphDef* graph, int first_id) {
    const string lhs_name = absl::StrCat("graph_1/node_", first_id);
    const string rhs_name = absl::StrCat("graph_1/node_", first_id + 1);
    const string add_name = absl::StrCat("graph_1/node_", first_id + 2);

    NodeDef* lhs = graph->add_node();
    TF_CHECK_OK(NodeDefBuilder(lhs_name, "Const")
                    .Attr("value", 1)
                    .Device("CPU:0")
                    .ControlInput(rhs_name)
                    .Finalize(lhs));

    NodeDef* rhs = graph->add_node();
    TF_CHECK_OK(NodeDefBuilder(rhs_name, "Const")
                    .Attr("value", 2)
                    .Device("CPU:0")
                    .ControlInput(lhs_name)
                    .Finalize(rhs));

    NodeDef* add = graph->add_node();
    TF_CHECK_OK(NodeDefBuilder(add_name, "Add")
                    .Device("CPU:0")
                    .Input(lhs_name, 0, DT_INT32)
                    .Input(rhs_name, 0, DT_INT32)
                    .ControlInput(lhs_name)
                    .ControlInput(rhs_name)
                    .Finalize(add));
    return add;
  };

  GraphDef gd1;
  const NodeDef* n3 = build_cyclic_graph(&gd1, 1);
  GraphDef gd2;
  const NodeDef* n6 = build_cyclic_graph(&gd2, 4);

  EXPECT_EQ(GetHash(gd1, *n3), GetHash(gd2, *n6));
}
1134 
// Equal int32 tensors (scalar and rank-1) must hash equally; tensors differing
// in any element must hash differently.
TEST_F(DatasetHashUtilsTest, HashInt32Tensor) {
  Tensor forty_two(42);
  Tensor forty_two_again(42);
  Tensor forty_three(43);
  EXPECT_EQ(GetHash(forty_two), GetHash(forty_two_again));
  EXPECT_NE(GetHash(forty_two), GetHash(forty_three));

  // Builds a length-2 int32 vector tensor holding {first, second}.
  auto make_pair_tensor = [](int32 first, int32 second) {
    Tensor t(DT_INT32, TensorShape({2}));
    auto values = t.vec<int32>();
    values(0) = first;
    values(1) = second;
    return t;
  };
  Tensor v1 = make_pair_tensor(0, 1);
  Tensor v2 = make_pair_tensor(0, 1);
  Tensor v3 = make_pair_tensor(0, 2);

  EXPECT_EQ(GetHash(v1), GetHash(v2));
  EXPECT_NE(GetHash(v1), GetHash(v3));
}
1156 
// Equal string tensors (scalar and rank-1) must hash equally; tensors differing
// in any element must hash differently.
TEST_F(DatasetHashUtilsTest, HashStringTensor) {
  Tensor hello("hello");
  Tensor hello_again("hello");
  Tensor world("world");
  EXPECT_EQ(GetHash(hello), GetHash(hello_again));
  EXPECT_NE(GetHash(hello), GetHash(world));

  // Builds a length-2 string vector tensor holding {first, second}.
  auto make_pair_tensor = [](const char* first, const char* second) {
    Tensor t(DT_STRING, TensorShape({2}));
    auto values = t.vec<tstring>();
    values(0) = first;
    values(1) = second;
    return t;
  };
  Tensor v1 = make_pair_tensor("hello", "world");
  Tensor v2 = make_pair_tensor("hello", "world");
  Tensor v3 = make_pair_tensor("hello", "universe");

  EXPECT_EQ(GetHash(v1), GetHash(v2));
  EXPECT_NE(GetHash(v1), GetHash(v3));
}
1178 
1179 // Benchmark that simulates a shallow and wide graph.
BM_ParallelFunctionCallsGraph(benchmark::State & state)1180 static void BM_ParallelFunctionCallsGraph(benchmark::State& state) {
1181   GraphDef graph_def;
1182   FunctionDefLibrary* fl = graph_def.mutable_library();
1183 
1184   FunctionDef* fd = fl->add_function();
1185   *fd = FunctionDefHelper::Create(
1186       "AddAndMul", {"i: float"}, {"o: float"}, {},
1187       {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
1188        {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
1189       /*ret_def=*/{{"o", "ret:z:0"}},
1190       /*control_ret_def=*/{{"must_execute", "add"}});
1191 
1192   NodeDef* input = graph_def.add_node();
1193   input->set_name("InputPlaceholder");
1194   input->set_op("Placeholder");
1195   AddNodeAttr("dtype", DT_FLOAT, input);
1196 
1197   // Equivalent of a `tf.group()`.
1198   NodeDef* target = graph_def.add_node();
1199   target->set_name("Target");
1200   target->set_op("NoOp");
1201 
1202   // Create 100 parallel PartitionedCalls that all depend on input. Generate a
1203   // NodeDef that has close to similar attributes that TensorFlow will generate.
1204   ConfigProto config_pb;
1205   config_pb.mutable_device_count()->insert({"CPU", 1});
1206   config_pb.mutable_device_count()->insert({"GPU", 1});
1207   config_pb.set_allow_soft_placement(true);
1208   for (int i = 0; i < 100; ++i) {
1209     NodeDef* node = graph_def.add_node();
1210     node->set_name(absl::StrCat("PartitionedCall_", i));
1211     node->set_op("PartitionedCall");
1212     *node->add_input() = input->name();
1213     AddNodeAttr("Tin", DT_FLOAT, node);
1214     AddNodeAttr("Tout", DT_FLOAT, node);
1215     AddNodeAttr("config", "", node);
1216     AddNodeAttr("config_proto", config_pb.SerializeAsString(), node);
1217     NameAttrList func;
1218     func.set_name(fd->signature().name());
1219     AddNodeAttr("f", func, node);
1220     *target->add_input() = absl::StrCat("^", node->name());
1221   }
1222 
1223   uint64 hash_value;
1224   for (auto _ : state) {
1225     CHECK(HashNode(graph_def, *target, &hash_value).ok());
1226   }
1227 }
1228 BENCHMARK(BM_ParallelFunctionCallsGraph);
1229 
1230 // Benchmark that simulates a narrow and deep graph.
BM_ChainedFunctionCallsGraph(benchmark::State & state)1231 static void BM_ChainedFunctionCallsGraph(benchmark::State& state) {
1232   GraphDef graph_def;
1233   FunctionDefLibrary* fl = graph_def.mutable_library();
1234 
1235   FunctionDef* fd = fl->add_function();
1236   *fd = FunctionDefHelper::Create(
1237       "AddAndMul", {"i: float"}, {"o: float"}, {},
1238       {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
1239        {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
1240       /*ret_def=*/{{"o", "ret:z:0"}},
1241       /*control_ret_def=*/{{"must_execute", "add"}});
1242 
1243   NodeDef* input = graph_def.add_node();
1244   input->set_name("InputPlaceholder");
1245   input->set_op("Placeholder");
1246   AddNodeAttr("dtype", DT_FLOAT, input);
1247 
1248   // Create 100 chained PartitionedCalls, each depending on the previous.
1249   // Generate a NodeDef that has close to similar attributes that TensorFlow
1250   // will generate.
1251   ConfigProto config_pb;
1252   config_pb.mutable_device_count()->insert({"CPU", 1});
1253   config_pb.mutable_device_count()->insert({"GPU", 1});
1254   config_pb.set_allow_soft_placement(true);
1255   for (int i = 0; i < 100; ++i) {
1256     NodeDef* node = graph_def.add_node();
1257     node->set_name(absl::StrCat("PartitionedCall_", i));
1258     node->set_op("PartitionedCall");
1259     if (i > 0) {
1260       *node->add_input() = absl::StrCat("PartitionedCall_", i - 1);
1261     } else {
1262       *node->add_input() = input->name();
1263     }
1264     AddNodeAttr("Tin", DT_FLOAT, node);
1265     AddNodeAttr("Tout", DT_FLOAT, node);
1266     AddNodeAttr("config", "", node);
1267     AddNodeAttr("config_proto", config_pb.SerializeAsString(), node);
1268     NameAttrList func;
1269     func.set_name(fd->signature().name());
1270     AddNodeAttr("f", func, node);
1271   }
1272 
1273   const NodeDef& target = graph_def.node(graph_def.node_size() - 1);
1274 
1275   uint64 hash_value;
1276   for (auto _ : state) {
1277     CHECK(HashNode(graph_def, target, &hash_value).ok());
1278   }
1279 }
1280 BENCHMARK(BM_ChainedFunctionCallsGraph);
1281 
1282 // Benchmark that simulates many nested function calls.
BM_ComposedFunctionCallsGraph(benchmark::State & state)1283 static void BM_ComposedFunctionCallsGraph(benchmark::State& state) {
1284   GraphDef graph_def;
1285   FunctionDefLibrary* fl = graph_def.mutable_library();
1286 
1287   // AddAndMul will be the last function, all others will be calls up to this.
1288   FunctionDef* fd = fl->add_function();
1289   *fd = FunctionDefHelper::Create(
1290       "AddAndMul", {"i: float"}, {"o: float"}, {},
1291       {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
1292        {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
1293       /*ret_def=*/{{"o", "ret:z:0"}},
1294       /*control_ret_def=*/{{"must_execute", "add"}});
1295 
1296   ConfigProto config_pb;
1297   config_pb.mutable_device_count()->insert({"CPU", 1});
1298   config_pb.mutable_device_count()->insert({"GPU", 1});
1299   config_pb.set_allow_soft_placement(true);
1300   for (int i = 0; i < 99; ++i) {
1301     // Get the name fo the previous function
1302     NameAttrList func;
1303     func.set_name(fd->signature().name());
1304 
1305     FunctionDef* fd = fl->add_function();
1306     *fd = FunctionDefHelper::Create(
1307         /*function_name=*/absl::StrCat("F_", i),
1308         /*in_def=*/{"i: float"},
1309         /*out_def=*/{"o: float"},
1310         /*attr_def=*/{},
1311         /*node_def=*/
1312         {
1313             {
1314                 {"inner_call"},
1315                 "PartitionedCall",
1316                 {"i"},
1317                 {{"Ti", DT_FLOAT},
1318                  {"Tout", DT_FLOAT},
1319                  {"config", ""},
1320                  {"config_proto", config_pb.SerializeAsString()},
1321                  {"f", func}},
1322             },
1323         },
1324         /*ret_def=*/{{"o", "inner_call:o:0"}},
1325         /*control_ret_def=*/{{"must_execute", "inner_call"}});
1326   }
1327 
1328   NodeDef* input = graph_def.add_node();
1329   input->set_name("InputPlaceholder");
1330   input->set_op("Placeholder");
1331   AddNodeAttr("dtype", DT_FLOAT, input);
1332 
1333   // Create call to the outer most function.
1334   NodeDef* node = graph_def.add_node();
1335   node->set_name("PartitionedCall_start");
1336   node->set_op("PartitionedCall");
1337   *node->add_input() = input->name();
1338   AddNodeAttr("Tin", DT_FLOAT, node);
1339   AddNodeAttr("Tout", DT_FLOAT, node);
1340   AddNodeAttr("config", "", node);
1341   AddNodeAttr("config_proto", config_pb.SerializeAsString(), node);
1342   NameAttrList func;
1343   func.set_name(fd->signature().name());
1344   AddNodeAttr("f", func, node);
1345 
1346   const NodeDef& target = graph_def.node(graph_def.node_size() - 1);
1347 
1348   uint64 hash_value;
1349   for (auto _ : state) {
1350     CHECK(HashNode(graph_def, target, &hash_value).ok());
1351   }
1352 }
1353 BENCHMARK(BM_ComposedFunctionCallsGraph);
1354 
1355 }  // namespace
1356 }  // namespace data
1357 }  // namespace tensorflow
1358