• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
17 
18 #include "tensorflow/core/framework/function_testlib.h"
19 #include "tensorflow/core/graph/node_builder.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21 #include "tensorflow/core/platform/test.h"
22 
23 namespace tensorflow {
24 namespace grappler {
25 namespace graph_utils {
26 namespace {
27 
// Checks GetFirstElementIndexWithPredicate on a small vector: the index of the
// first matching element is returned, and -1 is returned when nothing matches.
TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) {
  const std::vector<int> values = {1, 2, 3, 4, 5, 6};

  // 3 is the first multiple of three and sits at index 2.
  const auto is_multiple_of_three = [](int value) { return value % 3 == 0; };
  const auto match_index =
      GetFirstElementIndexWithPredicate(is_multiple_of_three, values);
  EXPECT_EQ(match_index, 2);

  // No element is a multiple of seven.
  const auto is_multiple_of_seven = [](int value) { return value % 7 == 0; };
  const auto missing_index =
      GetFirstElementIndexWithPredicate(is_multiple_of_seven, values);
  EXPECT_EQ(missing_index, -1);
}
39 
// Adds a bool scalar constant and checks that the node is present in the graph
// and that its "value" attr holds the expected bool.
TEST(GraphUtilsTest, AddScalarConstNodeBool) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* node = AddScalarConstNode<bool>(true, &graph);
  const auto in_graph = ContainsGraphNodeWithName(node->name(), *graph.graph());
  EXPECT_TRUE(in_graph);

  const auto& value_tensor = node->attr().at("value").tensor();
  EXPECT_EQ(value_tensor.bool_val(0), true);
}
47 
// Adds a double scalar constant and checks that the node is present in the
// graph and that its "value" attr holds the expected double.
TEST(GraphUtilsTest, AddScalarConstNodeDouble) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName(double_node->name(), *graph.graph()));
  // Compare at double precision; EXPECT_FLOAT_EQ would narrow both operands to
  // float and could mask a precision loss in the stored value.
  EXPECT_DOUBLE_EQ(double_node->attr().at("value").tensor().double_val(0),
                   3.14);
}
55 
// Adds a float scalar constant and checks that the node is present in the
// graph and that its "value" attr holds the expected float.
TEST(GraphUtilsTest, AddScalarConstNodeFloat) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* node = AddScalarConstNode<float>(3.14, &graph);
  const auto in_graph = ContainsGraphNodeWithName(node->name(), *graph.graph());
  EXPECT_TRUE(in_graph);

  const auto& value_tensor = node->attr().at("value").tensor();
  EXPECT_FLOAT_EQ(value_tensor.float_val(0), 3.14);
}
63 
// Adds an int scalar constant and checks that the node is present in the graph
// and that its "value" attr holds the expected int.
TEST(GraphUtilsTest, AddScalarConstNodeInt) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* node = AddScalarConstNode<int>(42, &graph);
  const auto in_graph = ContainsGraphNodeWithName(node->name(), *graph.graph());
  EXPECT_TRUE(in_graph);

  const auto& value_tensor = node->attr().at("value").tensor();
  EXPECT_EQ(value_tensor.int_val(0), 42);
}
71 
// Adds an int64 scalar constant and checks that the node is present in the
// graph and that its "value" attr holds the expected int64.
TEST(GraphUtilsTest, AddScalarConstNodeInt64) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* node = AddScalarConstNode<int64>(42, &graph);
  const auto in_graph = ContainsGraphNodeWithName(node->name(), *graph.graph());
  EXPECT_TRUE(in_graph);

  const auto& value_tensor = node->attr().at("value").tensor();
  EXPECT_EQ(value_tensor.int64_val(0), 42);
}
79 
// Adds a string scalar constant and checks that the node is present in the
// graph and that its "value" attr holds the expected string.
TEST(GraphUtilsTest, AddScalarConstNodeString) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* node = AddScalarConstNode<StringPiece>("hello", &graph);
  const auto in_graph = ContainsGraphNodeWithName(node->name(), *graph.graph());
  EXPECT_TRUE(in_graph);

  const auto& value_tensor = node->attr().at("value").tensor();
  EXPECT_EQ(value_tensor.string_val(0), "hello");
}
87 
// Round-trips an int64 scalar constant: the value written by
// AddScalarConstNode must be read back by GetScalarConstNodeValue.
TEST(GraphUtilsTest, GetScalarConstNodeInt64) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* int64_node = AddScalarConstNode<int64>(128, &graph);
  // Initialized so a failed read never leaves an indeterminate value behind.
  int64 result = 0;
  // Fatal on failure: the status message is reported and `result` is not
  // inspected when the read did not succeed.
  TF_ASSERT_OK(GetScalarConstNodeValue<int64>(*int64_node, &result));
  EXPECT_EQ(result, 128);
}
96 
// Round-trips a bool scalar constant: the value written by AddScalarConstNode
// must be read back by GetScalarConstNodeValue.
TEST(GraphUtilsTest, GetScalarConstNodeBool) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);
  // Initialized to the opposite of the expected value so a read that silently
  // does nothing cannot pass the check below.
  bool result = false;
  // Fatal on failure: the status message is reported and `result` is not
  // inspected when the read did not succeed.
  TF_ASSERT_OK(GetScalarConstNodeValue<bool>(*bool_node, &result));
  EXPECT_TRUE(result);
}
105 
// Reading a scalar constant from a Placeholder node must fail with a message
// that names the offending node and op.
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithNonConst) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* placeholder = AddScalarPlaceholder(DT_INT64, &graph);

  int64 unused_value;
  const Status status =
      GetScalarConstNodeValue<int64>(*placeholder, &unused_value);

  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Placeholder is not a Const node. Op: Placeholder");
}
116 
// Reading an int64 constant as bool must fail with a message that reports the
// requested and the actual type.
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithType) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* const_node = AddScalarConstNode<int64>(128, &graph);

  bool unused_value;
  const Status status =
      GetScalarConstNodeValue<bool>(*const_node, &unused_value);

  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Const should have type bool but has type: int64");
}
127 
// Reading a scalar from a Const node whose tensor has shape [1] must fail with
// a message that reports the non-scalar shape.
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithVector) {
  // Hand-build a Const node holding a one-element int64 vector.
  NodeDef const_node;
  const_node.set_name("Const");
  const_node.set_op("Const");

  auto& attrs = *const_node.mutable_attr();
  attrs["dtype"].set_type(DT_INT64);
  auto* value_tensor = attrs["value"].mutable_tensor();
  value_tensor->set_dtype(DT_INT64);
  value_tensor->mutable_tensor_shape()->mutable_dim()->Add()->set_size(1);
  value_tensor->add_int64_val(128);

  int64 unused_value;
  const Status status =
      GetScalarConstNodeValue<int64>(const_node, &unused_value);

  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Const should be a scalar but has shape: [1]");
}
145 
// Exercises Compare on two graphs: equal when both are empty, unequal once
// only one has nodes, and equal again after the nodes are copied across.
TEST(GraphUtilsTest, Compare) {
  GraphDef lhs_def;
  MutableGraphView lhs_view(&lhs_def);
  GraphDef rhs_def;
  MutableGraphView rhs_view(&rhs_def);

  // Two empty graphs.
  EXPECT_TRUE(Compare(lhs_def, rhs_def));

  // Populate only the left-hand graph.
  AddNode("A", "OpA", {}, {}, &lhs_view);
  AddNode("B", "OpB", {"A"}, {}, &lhs_view);
  EXPECT_FALSE(Compare(lhs_def, rhs_def));

  // Mirror the nodes into the right-hand graph.
  rhs_def.mutable_node()->CopyFrom(lhs_def.node());
  EXPECT_TRUE(Compare(lhs_def, rhs_def));
}
161 
// ContainsGraphNodeWithName must track a node being added to and then deleted
// from the graph.
TEST(GraphUtilsTest, ContainsGraphNodeWithName) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  EXPECT_FALSE(ContainsGraphNodeWithName("A", *graph.graph()));

  AddNode("A", "OpA", {}, {}, &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.graph()));

  // TF_EXPECT_OK reports the status message if the deletion fails.
  TF_EXPECT_OK(graph.DeleteNodes({"A"}));
  EXPECT_FALSE(ContainsGraphNodeWithName("A", *graph.graph()));
}
173 
// ContainsGraphFunctionWithName must report false for an empty library and
// true for a function once it has been added and named.
TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
  FunctionDefLibrary library;
  EXPECT_FALSE(ContainsGraphFunctionWithName("new_function", library));

  FunctionDef* added = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, added);

  // Look the function up under whatever name it was actually assigned.
  const auto& assigned_name = added->signature().name();
  EXPECT_TRUE(ContainsGraphFunctionWithName(assigned_name, library));
}
183 
// ContainsNodeWithOp must track a node with the given op being added to and
// then deleted from the graph.
TEST(GraphUtilsTest, ContainsNodeWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  EXPECT_FALSE(ContainsNodeWithOp("OpA", *graph.graph()));

  AddNode("A", "OpA", {}, {}, &graph);
  EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.graph()));

  // TF_EXPECT_OK reports the status message if the deletion fails.
  TF_EXPECT_OK(graph.DeleteNodes({"A"}));
  EXPECT_FALSE(ContainsNodeWithOp("OpA", *graph.graph()));
}
195 
// FindGraphNodeWithName must return -1 while the node is absent and a
// different value while it is present.
TEST(GraphUtilsTest, FindGraphNodeWithName) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1);

  AddNode("A", "OpA", {}, {}, &graph);
  EXPECT_NE(FindGraphNodeWithName("A", *graph.graph()), -1);

  // TF_EXPECT_OK reports the status message if the deletion fails.
  TF_EXPECT_OK(graph.DeleteNodes({"A"}));
  EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1);
}
207 
// FindGraphFunctionWithName must return -1 for an empty library and a
// different value for a function once it has been added and named.
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
  FunctionDefLibrary library;
  EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1);

  FunctionDef* added = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, added);

  // Look the function up under whatever name it was actually assigned.
  const auto& assigned_name = added->signature().name();
  EXPECT_NE(FindGraphFunctionWithName(assigned_name, library), -1);
}
217 
// FindGraphNodeWithOp must return -1 when no node has the op and the index of
// the first matching node otherwise; also checks the index of "A2" after "B"
// has been deleted.
TEST(GraphUtilsTest, FindGraphNodeWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);

  AddNode("A", "OpA", {}, {}, &graph);
  AddNode("B", "OpB", {"A"}, {}, &graph);
  AddNode("A2", "OpA", {"A"}, {}, &graph);
  // Two nodes use "OpA"; the first one ("A") is at index 0.
  EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), 0);

  // TF_EXPECT_OK reports the status message if the deletion fails.
  TF_EXPECT_OK(graph.DeleteNodes({"B"}));
  EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.graph()), -1);
  EXPECT_EQ(FindGraphNodeWithName("A2", *graph.graph()), 1);
}
232 
// FindAllGraphNodesWithOp must return the indices of every node with the given
// op, both before and after one of the matching nodes is deleted.
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);

  AddNode("A", "OpA", {}, {}, &graph);
  AddNode("B", "OpB", {"A"}, {}, &graph);
  AddNode("A2", "OpA", {"B"}, {}, &graph);
  std::vector<int> result_indices =
      FindAllGraphNodesWithOp("OpA", *graph.graph());
  // Fatal size check: the element accesses below are only meaningful (and
  // only safe) when the expected number of matches was found.
  ASSERT_EQ(result_indices.size(), 2u);
  EXPECT_EQ(result_indices.at(0), 0);
  EXPECT_EQ(result_indices.at(1), 2);

  // TF_EXPECT_OK reports the status message if the deletion fails.
  TF_EXPECT_OK(graph.DeleteNodes({"A2"}));
  std::vector<int> result_indices_new =
      FindAllGraphNodesWithOp("OpA", *graph.graph());
  ASSERT_EQ(result_indices_new.size(), 1u);
  EXPECT_EQ(result_indices_new.at(0), 0);
}
253 
// Nodes added with an empty name must receive distinct names, including a
// node added after an earlier one has been deleted.
TEST(GraphUtilsTest, SetUniqueGraphNodeName) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* node1 = AddNode("", "A", {}, {}, &graph);
  NodeDef* node2 = AddNode("", "A", {}, {}, &graph);
  EXPECT_NE(node1->name(), node2->name());

  // TF_EXPECT_OK reports the status message if the deletion fails.
  TF_EXPECT_OK(graph.DeleteNodes({node1->name()}));
  NodeDef* node3 = AddNode("", "A", {}, {}, &graph);
  EXPECT_NE(node2->name(), node3->name());
}
266 
// Two functions that request the same prefix through
// SetUniqueGraphFunctionName must end up with different names.
TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
  FunctionDefLibrary library;

  FunctionDef* first = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, first);

  FunctionDef* second = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, second);

  const auto& first_name = first->signature().name();
  const auto& second_name = second->signature().name();
  EXPECT_NE(first_name, second_name);
}
277 
// GetInputNode must return the producer of a node's input, and nullptr for a
// node that has no inputs.
TEST(GraphUtilsTest, GetInputNode) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* producer = AddNode("", "A", {}, {}, &graph);
  NodeDef* consumer = AddNode("", "A", {producer->name()}, {}, &graph);

  EXPECT_EQ(GetInputNode(*consumer, graph), producer);
  EXPECT_EQ(GetInputNode(*producer, graph), nullptr);
}
288 
// GetInputNode with an explicit index must return the matching input's
// producer, and nullptr when the index is past the last input.
TEST(GraphUtilsTest, GetIthInputNode) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* first_input = AddNode("", "A", {}, {}, &graph);
  NodeDef* second_input = AddNode("", "A", {}, {}, &graph);
  NodeDef* consumer = AddNode(
      "", "A", {first_input->name(), second_input->name()}, {}, &graph);

  // Without an index the first input is returned.
  EXPECT_EQ(GetInputNode(*consumer, graph), first_input);
  EXPECT_EQ(GetInputNode(*consumer, graph, 1), second_input);
  EXPECT_EQ(GetInputNode(*consumer, graph, 0), first_input);
  // Index 2 is out of range for a two-input node.
  EXPECT_EQ(GetInputNode(*consumer, graph, 2), nullptr);
  EXPECT_EQ(GetInputNode(*first_input, graph), nullptr);
}
303 
// Builds a graph with two nodes named "Const" plus one named "Const_1", then
// checks that EnsureNodeNamesUnique leaves all three with distinct names.
TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
  Graph g(OpRegistry::Global());

  // Initialized so a failed build can never leave a dangling pointer behind.
  Node* const_0 = nullptr;
  Node* const_1 = nullptr;
  Node* const_2 = nullptr;

  // Arbitrary const
  Tensor tensor(DT_INT32, {});
  tensor.scalar<int32>()() = 5;

  // Node construction is asserted (fatal): the name checks at the end
  // dereference these pointers.
  for (Node** node : {&const_0, &const_1}) {
    TF_ASSERT_OK(NodeBuilder("Const", "Const")
                     .Attr("value", tensor)
                     .Attr("dtype", DT_INT32)
                     .Finalize(&g, node));
  }
  // Make sure generated name doesn't clash with existing name either
  TF_ASSERT_OK(NodeBuilder("Const_1", "Const")
                   .Attr("value", tensor)
                   .Attr("dtype", DT_INT32)
                   .Finalize(&g, &const_2));

  TF_EXPECT_OK(EnsureNodeNamesUnique(&g));
  EXPECT_NE(const_0->name(), const_1->name());
  EXPECT_NE(const_1->name(), const_2->name());
  EXPECT_NE(const_0->name(), const_2->name());
}
330 
// With exactly one fetch configured, GetFetchNode must succeed and return the
// node named by that fetch.
TEST(GraphUtilsTest, TestGetFetchNode) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  NodeDef* node1 = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* node2 = AddNode("node2", "Identity", {node1->name()}, {}, &graph);
  NodeDef* node3 = AddNode("node3", "Identity", {node2->name()}, {}, &graph);
  item.fetch.push_back(node3->name());

  // Initialized and asserted (fatal) so a failed lookup cannot lead to a
  // dereference of an indeterminate or null pointer below.
  NodeDef* sink_node = nullptr;
  TF_ASSERT_OK(GetFetchNode(graph, item, &sink_node));
  ASSERT_NE(sink_node, nullptr);
  EXPECT_EQ(sink_node->name(), node3->name());
}
344 
// With two fetches configured, GetFetchNode must return a non-OK status.
TEST(GraphUtilsTest, TestFindSinkNodeMultipleFetches) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  // Linear chain: head -> middle -> tail.
  NodeDef* head = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* middle = AddNode("node2", "Identity", {head->name()}, {}, &graph);
  NodeDef* tail = AddNode("node3", "Identity", {middle->name()}, {}, &graph);

  item.fetch.push_back(middle->name());
  item.fetch.push_back(tail->name());

  NodeDef* sink_node;
  const Status status = GetFetchNode(graph, item, &sink_node);
  EXPECT_FALSE(status.ok());
}
359 
// With no fetches configured, GetFetchNode must return a non-OK status.
TEST(GraphUtilsTest, TestFindSinkNodeNoFetches) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  // Linear chain of three Identity nodes; item.fetch is left empty.
  NodeDef* head = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* middle = AddNode("node2", "Identity", {head->name()}, {}, &graph);
  AddNode("node3", "Identity", {middle->name()}, {}, &graph);

  NodeDef* sink_node;
  const Status status = GetFetchNode(graph, item, &sink_node);
  EXPECT_FALSE(status.ok());
}
372 
373 }  // namespace
374 }  // namespace graph_utils
375 }  // namespace grappler
376 }  // namespace tensorflow
377