1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
17
18 #include "tensorflow/core/framework/function_testlib.h"
19 #include "tensorflow/core/graph/node_builder.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 #include "tensorflow/core/platform/test.h"
22
23 namespace tensorflow {
24 namespace grappler {
25 namespace graph_utils {
26 namespace {
27
// Checks that GetFirstElementIndexWithPredicate yields the index of the first
// element satisfying the predicate, and -1 when no element matches.
TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) {
  std::vector<int> values({1, 2, 3, 4, 5, 6});

  // The first multiple of three (3) sits at index 2.
  const auto divisible_by_three = [](int value) { return value % 3 == 0; };
  auto index = GetFirstElementIndexWithPredicate(divisible_by_three, values);
  EXPECT_EQ(index, 2);

  // None of the values is a multiple of seven.
  const auto divisible_by_seven = [](int value) { return value % 7 == 0; };
  index = GetFirstElementIndexWithPredicate(divisible_by_seven, values);
  EXPECT_EQ(index, -1);
}
39
// A bool scalar constant added through AddScalarConstNode must be present in
// the graph and expose its value through the tensor's bool_val field.
TEST(GraphUtilsTest, AddScalarConstNodeBool) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* const_node = AddScalarConstNode<bool>(true, &graph);

  EXPECT_TRUE(ContainsGraphNodeWithName(const_node->name(), *graph.graph()));
  const auto& value_tensor = const_node->attr().at("value").tensor();
  EXPECT_EQ(value_tensor.bool_val(0), true);
}
47
// A double scalar constant added through AddScalarConstNode must be present in
// the graph and expose its value through the tensor's double_val field.
TEST(GraphUtilsTest, AddScalarConstNodeDouble) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName(double_node->name(), *graph.graph()));
  // Compare at double precision: EXPECT_FLOAT_EQ would narrow both operands to
  // float and could hide a loss of precision in the stored value.
  EXPECT_DOUBLE_EQ(double_node->attr().at("value").tensor().double_val(0),
                   3.14);
}
55
// A float scalar constant added through AddScalarConstNode must be present in
// the graph and expose its value through the tensor's float_val field.
TEST(GraphUtilsTest, AddScalarConstNodeFloat) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* const_node = AddScalarConstNode<float>(3.14, &graph);

  EXPECT_TRUE(ContainsGraphNodeWithName(const_node->name(), *graph.graph()));
  const auto& value_tensor = const_node->attr().at("value").tensor();
  EXPECT_FLOAT_EQ(value_tensor.float_val(0), 3.14);
}
63
// An int scalar constant added through AddScalarConstNode must be present in
// the graph and expose its value through the tensor's int_val field.
TEST(GraphUtilsTest, AddScalarConstNodeInt) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* const_node = AddScalarConstNode<int>(42, &graph);

  EXPECT_TRUE(ContainsGraphNodeWithName(const_node->name(), *graph.graph()));
  const auto& value_tensor = const_node->attr().at("value").tensor();
  EXPECT_EQ(value_tensor.int_val(0), 42);
}
71
// An int64 scalar constant added through AddScalarConstNode must be present in
// the graph and expose its value through the tensor's int64_val field.
TEST(GraphUtilsTest, AddScalarConstNodeInt64) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* const_node = AddScalarConstNode<int64>(42, &graph);

  EXPECT_TRUE(ContainsGraphNodeWithName(const_node->name(), *graph.graph()));
  const auto& value_tensor = const_node->attr().at("value").tensor();
  EXPECT_EQ(value_tensor.int64_val(0), 42);
}
79
// A string scalar constant added through AddScalarConstNode must be present in
// the graph and expose its value through the tensor's string_val field.
TEST(GraphUtilsTest, AddScalarConstNodeString) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* const_node = AddScalarConstNode<StringPiece>("hello", &graph);

  EXPECT_TRUE(ContainsGraphNodeWithName(const_node->name(), *graph.graph()));
  const auto& value_tensor = const_node->attr().at("value").tensor();
  EXPECT_EQ(value_tensor.string_val(0), "hello");
}
87
// Round-trips an int64 scalar: the value written by AddScalarConstNode must be
// read back unchanged by GetScalarConstNodeValue.
TEST(GraphUtilsTest, GetScalarConstNodeInt64) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* int64_node = AddScalarConstNode<int64>(128, &graph);
  // Start from a value that differs from the expected one so a silent no-op
  // cannot make the test pass.
  int64 result = 0;
  // Fatal assertion: without a successful read there is nothing meaningful to
  // compare, and the status message is reported on failure.
  TF_ASSERT_OK(GetScalarConstNodeValue<int64>(*int64_node, &result));
  EXPECT_EQ(result, 128);
}
96
// Round-trips a bool scalar: the value written by AddScalarConstNode must be
// read back unchanged by GetScalarConstNodeValue.
TEST(GraphUtilsTest, GetScalarConstNodeBool) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);
  // Start from the opposite of the expected value so a silent no-op cannot
  // make the test pass.
  bool result = false;
  // Fatal assertion: without a successful read there is nothing meaningful to
  // compare, and the status message is reported on failure.
  TF_ASSERT_OK(GetScalarConstNodeValue<bool>(*bool_node, &result));
  EXPECT_EQ(result, true);
}
105
// Reading a scalar constant from a Placeholder node must fail with a message
// naming the offending node and its op.
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithNonConst) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* placeholder = AddScalarPlaceholder(DT_INT64, &graph);

  int64 unused_value;
  const Status status =
      GetScalarConstNodeValue<int64>(*placeholder, &unused_value);

  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Placeholder is not a Const node. Op: Placeholder");
}
116
// Reading an int64 constant as a bool must fail with a type-mismatch message.
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithType) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* const_node = AddScalarConstNode<int64>(128, &graph);

  bool unused_value;
  const Status status =
      GetScalarConstNodeValue<bool>(*const_node, &unused_value);

  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Const should have type bool but has type: int64");
}
127
// Reading a scalar from a Const node whose tensor has shape [1] must fail with
// a message reporting the non-scalar shape.
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithVector) {
  // Hand-build a Const node holding a one-element int64 vector.
  NodeDef vector_const;
  vector_const.set_name("Const");
  vector_const.set_op("Const");

  auto& attrs = *vector_const.mutable_attr();
  attrs["dtype"].set_type(DT_INT64);
  auto* value = attrs["value"].mutable_tensor();
  value->set_dtype(DT_INT64);
  value->mutable_tensor_shape()->mutable_dim()->Add()->set_size(1);
  value->add_int64_val(128);

  int64 unused_value;
  const Status status =
      GetScalarConstNodeValue<int64>(vector_const, &unused_value);

  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Const should be a scalar but has shape: [1]");
}
145
// Compare must report equality for two empty graphs, inequality once only one
// of them has nodes, and equality again after the nodes are copied over.
TEST(GraphUtilsTest, Compare) {
  GraphDef lhs_def;
  MutableGraphView lhs_view(&lhs_def);
  GraphDef rhs_def;
  MutableGraphView rhs_view(&rhs_def);

  // Both graphs are empty.
  EXPECT_TRUE(Compare(lhs_def, rhs_def));

  // Only the left-hand graph gains nodes.
  AddNode("A", "OpA", {}, {}, &lhs_view);
  AddNode("B", "OpB", {"A"}, {}, &lhs_view);
  EXPECT_FALSE(Compare(lhs_def, rhs_def));

  // Mirror the left-hand nodes into the right-hand graph.
  rhs_def.mutable_node()->CopyFrom(lhs_def.node());
  EXPECT_TRUE(Compare(lhs_def, rhs_def));
}
161
// ContainsGraphNodeWithName must track a node through its lifetime: absent in
// an empty graph, present after AddNode, absent again after DeleteNodes.
TEST(GraphUtilsTest, ContainsGraphNodeWithName) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  EXPECT_FALSE(ContainsGraphNodeWithName("A", *graph.graph()));

  AddNode("A", "OpA", {}, {}, &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.graph()));

  // TF_EXPECT_OK prints the status message if the deletion fails.
  TF_EXPECT_OK(graph.DeleteNodes({"A"}));
  EXPECT_FALSE(ContainsGraphNodeWithName("A", *graph.graph()));
}
173
// ContainsGraphFunctionWithName must be false for an empty library and true
// for the name assigned to a newly added function.
TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
  FunctionDefLibrary library;
  EXPECT_FALSE(ContainsGraphFunctionWithName("new_function", library));

  FunctionDef* added = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, added);

  // Look up whatever name the helper actually assigned.
  const auto& assigned_name = added->signature().name();
  EXPECT_TRUE(ContainsGraphFunctionWithName(assigned_name, library));
}
183
// ContainsNodeWithOp must track an op through a node's lifetime: absent in an
// empty graph, present after AddNode, absent again after DeleteNodes.
TEST(GraphUtilsTest, ContainsNodeWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  EXPECT_FALSE(ContainsNodeWithOp("OpA", *graph.graph()));

  AddNode("A", "OpA", {}, {}, &graph);
  EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.graph()));

  // TF_EXPECT_OK prints the status message if the deletion fails.
  TF_EXPECT_OK(graph.DeleteNodes({"A"}));
  EXPECT_FALSE(ContainsNodeWithOp("OpA", *graph.graph()));
}
195
// FindGraphNodeWithName must return -1 while the node is missing and some
// other index while it exists.
TEST(GraphUtilsTest, FindGraphNodeWithName) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  auto index = FindGraphNodeWithName("A", *graph.graph());
  EXPECT_EQ(index, -1);

  AddNode("A", "OpA", {}, {}, &graph);
  index = FindGraphNodeWithName("A", *graph.graph());
  EXPECT_NE(index, -1);

  const Status delete_status = graph.DeleteNodes({"A"});
  EXPECT_TRUE(delete_status.ok());
  index = FindGraphNodeWithName("A", *graph.graph());
  EXPECT_EQ(index, -1);
}
207
// FindGraphFunctionWithName must return -1 for an empty library and a
// different index for the name assigned to a newly added function.
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
  FunctionDefLibrary library;
  EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1);

  FunctionDef* added = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, added);

  // Look up whatever name the helper actually assigned.
  const auto& assigned_name = added->signature().name();
  EXPECT_NE(FindGraphFunctionWithName(assigned_name, library), -1);
}
217
// FindGraphNodeWithOp must return -1 when no node has the op and the index of
// the first match otherwise; also checks A2's index once B is deleted.
TEST(GraphUtilsTest, FindGraphNodeWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  auto op_index = FindGraphNodeWithOp("OpA", *graph.graph());
  EXPECT_EQ(op_index, -1);

  // Two nodes share "OpA"; the one added first is expected at index 0.
  AddNode("A", "OpA", {}, {}, &graph);
  AddNode("B", "OpB", {"A"}, {}, &graph);
  AddNode("A2", "OpA", {"A"}, {}, &graph);
  op_index = FindGraphNodeWithOp("OpA", *graph.graph());
  EXPECT_EQ(op_index, 0);

  const Status delete_status = graph.DeleteNodes({"B"});
  EXPECT_TRUE(delete_status.ok());
  op_index = FindGraphNodeWithOp("OpB", *graph.graph());
  EXPECT_EQ(op_index, -1);
  const auto a2_index = FindGraphNodeWithName("A2", *graph.graph());
  EXPECT_EQ(a2_index, 1);
}
232
// FindAllGraphNodesWithOp must return the indices of every node with the
// requested op, and reflect node deletions.
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);
  // Also exercise the function under test on the empty graph.
  EXPECT_TRUE(FindAllGraphNodesWithOp("OpA", *graph.graph()).empty());

  AddNode("A", "OpA", {}, {}, &graph);
  AddNode("B", "OpB", {"A"}, {}, &graph);
  AddNode("A2", "OpA", {"B"}, {}, &graph);
  std::vector<int> result_indices =
      FindAllGraphNodesWithOp("OpA", *graph.graph());
  // Fatal size check: the element accesses below would otherwise go out of
  // range if the result has the wrong length.
  ASSERT_EQ(result_indices.size(), 2u);
  EXPECT_EQ(result_indices.at(0), 0);
  EXPECT_EQ(result_indices.at(1), 2);

  TF_EXPECT_OK(graph.DeleteNodes({"A2"}));
  std::vector<int> result_indices_new =
      FindAllGraphNodesWithOp("OpA", *graph.graph());
  ASSERT_EQ(result_indices_new.size(), 1u);
  EXPECT_EQ(result_indices_new.at(0), 0);
}
253
// Nodes added with an empty name must end up with distinct names, including a
// node added after an earlier one has been deleted.
TEST(GraphUtilsTest, SetUniqueGraphNodeName) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* first = AddNode("", "A", {}, {}, &graph);
  NodeDef* second = AddNode("", "A", {}, {}, &graph);
  EXPECT_NE(first->name(), second->name());

  // Delete the first node, then add another unnamed one.
  const Status delete_status = graph.DeleteNodes({first->name()});
  EXPECT_TRUE(delete_status.ok());
  NodeDef* third = AddNode("", "A", {}, {}, &graph);
  EXPECT_NE(second->name(), third->name());
}
266
// Two functions that request the same prefix must be given different names.
TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
  FunctionDefLibrary library;

  FunctionDef* first = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, first);

  FunctionDef* second = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, second);

  const auto& first_name = first->signature().name();
  const auto& second_name = second->signature().name();
  EXPECT_NE(first_name, second_name);
}
277
// GetInputNode must return the producer of a node's input, and nullptr for a
// node that has no inputs.
TEST(GraphUtilsTest, GetInputNode) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* producer = AddNode("", "A", {}, {}, &graph);
  NodeDef* consumer = AddNode("", "A", {producer->name()}, {}, &graph);

  NodeDef* consumer_input = GetInputNode(*consumer, graph);
  EXPECT_EQ(consumer_input, producer);

  NodeDef* producer_input = GetInputNode(*producer, graph);
  EXPECT_EQ(producer_input, nullptr);
}
288
// GetInputNode with an explicit index must return the matching input, fall
// back to the first input when the index is omitted, and return nullptr for an
// index past the last input or for a node without inputs.
TEST(GraphUtilsTest, GetIthInputNode) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* first_input = AddNode("", "A", {}, {}, &graph);
  NodeDef* second_input = AddNode("", "A", {}, {}, &graph);
  NodeDef* consumer = AddNode(
      "", "A", {first_input->name(), second_input->name()}, {}, &graph);

  EXPECT_EQ(GetInputNode(*consumer, graph), first_input);
  EXPECT_EQ(GetInputNode(*consumer, graph, 1), second_input);
  EXPECT_EQ(GetInputNode(*consumer, graph, 0), first_input);
  // Index 2 is one past the consumer's two inputs.
  EXPECT_EQ(GetInputNode(*consumer, graph, 2), nullptr);
  EXPECT_EQ(GetInputNode(*first_input, graph), nullptr);
}
303
// EnsureNodeNamesUnique must leave every node in the graph with a distinct
// name, even when a candidate replacement name ("Const_1") is already taken.
TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
  Graph g(OpRegistry::Global());

  // Initialised so a failed Finalize can never leave a garbage pointer behind.
  Node* const_0 = nullptr;
  Node* const_1 = nullptr;
  Node* const_2 = nullptr;

  // Arbitrary const
  Tensor tensor(DT_INT32, {});
  tensor.scalar<int32>()() = 5;

  // Two nodes that deliberately share the name "Const".
  for (Node** node : {&const_0, &const_1}) {
    // Fatal assertion: the node pointers are dereferenced further down.
    TF_ASSERT_OK(NodeBuilder("Const", "Const")
                     .Attr("value", tensor)
                     .Attr("dtype", DT_INT32)
                     .Finalize(&g, node));
  }
  // Make sure generated name doesn't clash with existing name either
  TF_ASSERT_OK(NodeBuilder("Const_1", "Const")
                   .Attr("value", tensor)
                   .Attr("dtype", DT_INT32)
                   .Finalize(&g, &const_2));

  TF_EXPECT_OK(EnsureNodeNamesUnique(&g));
  EXPECT_NE(const_0->name(), const_1->name());
  EXPECT_NE(const_1->name(), const_2->name());
  EXPECT_NE(const_0->name(), const_2->name());
}
330
// With exactly one fetch registered on the item, GetFetchNode must succeed and
// return the node with that name.
TEST(GraphUtilsTest, TestGetFetchNode) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  // Linear chain node1 -> node2 -> node3; only node3 is fetched.
  NodeDef* node1 = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* node2 = AddNode("node2", "Identity", {node1->name()}, {}, &graph);
  NodeDef* node3 = AddNode("node3", "Identity", {node2->name()}, {}, &graph);
  item.fetch.push_back(node3->name());

  NodeDef* sink_node = nullptr;
  // Fatal assertions: sink_node is dereferenced below, so stop here if the
  // lookup failed or produced no node.
  TF_ASSERT_OK(GetFetchNode(graph, item, &sink_node));
  ASSERT_NE(sink_node, nullptr);
  EXPECT_EQ(sink_node->name(), node3->name());
}
344
// GetFetchNode must report an error when the item lists more than one fetch.
TEST(GraphUtilsTest, TestFindSinkNodeMultipleFetches) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  // Linear chain head -> middle -> tail; both middle and tail are fetched.
  NodeDef* head = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* middle = AddNode("node2", "Identity", {head->name()}, {}, &graph);
  NodeDef* tail = AddNode("node3", "Identity", {middle->name()}, {}, &graph);
  item.fetch.push_back(middle->name());
  item.fetch.push_back(tail->name());

  NodeDef* sink_node;
  const Status status = GetFetchNode(graph, item, &sink_node);
  EXPECT_FALSE(status.ok());
}
359
// GetFetchNode must report an error when the item lists no fetches at all.
TEST(GraphUtilsTest, TestFindSinkNodeNoFetches) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  // Linear chain of three Identity nodes; item.fetch is left empty.
  NodeDef* head = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* middle = AddNode("node2", "Identity", {head->name()}, {}, &graph);
  AddNode("node3", "Identity", {middle->name()}, {}, &graph);

  NodeDef* sink_node;
  const Status status = GetFetchNode(graph, item, &sink_node);
  EXPECT_FALSE(status.ok());
}
372
373 } // namespace
374 } // namespace graph_utils
375 } // namespace grappler
376 } // namespace tensorflow
377