1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/graph.h"
17
18 #include <set>
19 #include <unordered_map>
20 #include <vector>
21
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/common_runtime/graph_constructor.h"
24 #include "tensorflow/core/framework/function_testlib.h"
25 #include "tensorflow/core/graph/benchmark_testlib.h"
26 #include "tensorflow/core/graph/node_builder.h"
27 #include "tensorflow/core/kernels/ops_util.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/random/simple_philox.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/test_benchmark.h"
35
36 namespace tensorflow {
37 namespace {
38
39 REGISTER_OP("OneInput").Input("x: float");
40
41 REGISTER_OP("OneOutput").Output("y: float");
42
43 REGISTER_OP("OneInputTwoOutputs")
44 .Input("x: float")
45 .Output("y: float")
46 .Output("z: float");
47
48 REGISTER_OP("TwoInputsOneOutput")
49 .Input("x: float")
50 .Input("y: float")
51 .Output("z: float");
52
53 class GraphTest : public ::testing::Test {
54 protected:
GraphTest()55 GraphTest() : graph_(OpRegistry::Global()) {}
~GraphTest()56 ~GraphTest() override {}
57
VerifyNodes(Node * node,const std::vector<Node * > & expected_in,const std::vector<Node * > & expected_out)58 static void VerifyNodes(Node* node, const std::vector<Node*>& expected_in,
59 const std::vector<Node*>& expected_out) {
60 std::vector<Node*> in;
61 for (const Edge* e : node->in_edges()) {
62 in.push_back(e->src());
63 }
64 EXPECT_EQ(Stringify(expected_in), Stringify(in));
65
66 std::vector<Node*> out;
67 for (const Edge* e : node->out_edges()) {
68 out.push_back(e->dst());
69 }
70 EXPECT_EQ(Stringify(expected_out), Stringify(out));
71 }
72
VerifyGraphStats()73 void VerifyGraphStats() {
74 int nodes = 0;
75 for (const Node* n : graph_.nodes()) {
76 VLOG(1) << n->id();
77 ++nodes;
78 }
79 EXPECT_EQ(nodes, graph_.num_nodes());
80 int edges = 0;
81 for (const Edge* e : graph_.edges()) {
82 VLOG(1) << e->id();
83 ++edges;
84 }
85 EXPECT_EQ(edges, graph_.num_edges());
86 }
87
AddNodeWithName(const string & name)88 Node* AddNodeWithName(const string& name) {
89 Node* node;
90 TF_CHECK_OK(NodeBuilder(name, "NoOp").Finalize(&graph_, &node));
91 return node;
92 }
93
FromNodeDef(const string & name,const string & node_type,int num_inputs)94 Node* FromNodeDef(const string& name, const string& node_type,
95 int num_inputs) {
96 auto builder = NodeDefBuilder(name, node_type);
97 for (int i = 0; i < num_inputs; ++i) {
98 builder = builder.Input(strings::StrCat("node_", i), i, DT_FLOAT);
99 }
100
101 NodeDef node_def;
102 TF_CHECK_OK(builder.Finalize(&node_def));
103
104 Status s;
105 Node* node = graph_.AddNode(node_def, &s);
106 TF_CHECK_OK(s);
107 return node;
108 }
109
FromGraphDef(const string & gdef_ascii)110 void FromGraphDef(const string& gdef_ascii) {
111 GraphDef gdef;
112 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef));
113 GraphConstructorOptions opts;
114 TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef, &graph_));
115 }
116
FindNode(const string & name)117 Node* FindNode(const string& name) {
118 for (Node* node : graph_.nodes()) {
119 if (node->name() == name) return node;
120 }
121 LOG(FATAL) << name;
122 }
123
ControlEdgeExistsInGraphOrNodeDef(const Node * src,const Node * dst)124 bool ControlEdgeExistsInGraphOrNodeDef(const Node* src, const Node* dst) {
125 for (const Edge* e : dst->in_edges()) {
126 if (e->IsControlEdge() && e->src() == src &&
127 e->src_output() == Graph::kControlSlot &&
128 e->dst_input() == Graph::kControlSlot) {
129 return true;
130 }
131 }
132 std::string control_edge_name = strings::StrCat("^", src->name());
133 for (int i = 0; i < dst->def().input_size(); ++i) {
134 if (dst->def().input(i) == control_edge_name) {
135 return true;
136 }
137 }
138 return false;
139 }
140
141 Graph graph_;
142
143 private:
144 // Convert a list of nodes to a sorted list of strings so failure messages
145 // are readable.
Stringify(const std::vector<Node * > & nodes)146 static std::vector<string> Stringify(const std::vector<Node*>& nodes) {
147 std::vector<string> result;
148 result.reserve(nodes.size());
149 for (Node* n : nodes) {
150 result.push_back(n->DebugString());
151 }
152 std::sort(result.begin(), result.end());
153 return result;
154 }
155 };
156
// A freshly constructed graph holds exactly _SOURCE and _SINK, joined by a
// single source->sink edge.
TEST_F(GraphTest, Constructor) {
  Node* const source = graph_.source_node();
  Node* const sink = graph_.sink_node();
  EXPECT_NE(source, nullptr);
  EXPECT_NE(sink, nullptr);

  VerifyNodes(source, {}, {sink});
  VerifyNodes(sink, {source}, {});

  EXPECT_EQ(2, graph_.num_node_ids());
  VerifyGraphStats();
}
167
// Removing a node leaves num_node_ids() unchanged, and the freed id is not
// handed out again to the next node.
TEST_F(GraphTest, RemoveThenAdd) {
  AddNodeWithName("A");
  Node* const removed = AddNodeWithName("B");
  const int removed_id = removed->id();
  AddNodeWithName("C");
  // _SOURCE and _SINK already occupy two ids; A, B and C bring it to 5.
  EXPECT_EQ(5, graph_.num_node_ids());

  graph_.RemoveNode(removed);
  EXPECT_EQ(5, graph_.num_node_ids());

  Node* const added = AddNodeWithName("D");
  EXPECT_NE(removed_id, added->id());  // Ids should not be reused.
  EXPECT_EQ(6, graph_.num_node_ids());
  VerifyGraphStats();
}
181
// Checks in/out neighbour bookkeeping as edges and nodes are removed.
TEST_F(GraphTest, InNodesAndOutNodes) {
  Node* const source = graph_.source_node();
  Node* const sink = graph_.sink_node();

  Node* const a = FromNodeDef("A", "OneOutput", 0);
  Node* const b = AddNodeWithName("B");
  Node* const c = FromNodeDef("C", "OneInput", 1);
  graph_.RemoveNode(b);
  Node* const d = AddNodeWithName("D");

  // Wiring: source -> a -> c, with a and c both feeding sink.
  const Edge* const source_to_a = graph_.AddControlEdge(source, a);
  graph_.AddControlEdge(a, sink);
  graph_.AddEdge(a, 0, c, 0);
  graph_.AddControlEdge(c, sink);

  EXPECT_EQ("A", a->name());
  VerifyNodes(a, {source}, {c, sink});

  EXPECT_EQ("C", c->name());
  VerifyNodes(c, {a}, {sink});

  // D was never connected.
  EXPECT_EQ("D", d->name());
  VerifyNodes(d, {}, {});

  VerifyNodes(source, {}, {a, sink});
  VerifyNodes(sink, {a, c, source}, {});

  // Dropping source->a detaches a from the source.
  graph_.RemoveEdge(source_to_a);
  VerifyNodes(a, {}, {c, sink});
  VerifyNodes(source, {}, {sink});

  // Removing c takes its incident edges with it.
  graph_.RemoveNode(c);
  VerifyNodes(a, {}, {sink});
  VerifyNodes(sink, {a, source}, {});

  EXPECT_EQ(6, graph_.num_node_ids());
  EXPECT_EQ(5, graph_.num_edge_ids());
  VerifyGraphStats();
}
217
// Exercises index-based input lookup (input_node / input_edge / input_edges),
// including a self edge, out-of-range indices, and lookups after the source
// node or an edge has been removed.
TEST_F(GraphTest, NodeByIndex) {
  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* c = FromNodeDef("C", "OneInput", 1);
  graph_.AddEdge(a, 0, c, 0);

  // Ask for 'a' from 'c' by index.
  const Node* a_copy;
  TF_ASSERT_OK(c->input_node(0, &a_copy));
  EXPECT_EQ(a, a_copy);

  const Edge* e;
  TF_ASSERT_OK(c->input_edge(0, &e));
  EXPECT_EQ(0, e->dst_input());
  EXPECT_EQ(a, e->src());
  EXPECT_EQ(c, e->dst());
  EXPECT_EQ(0, e->src_output());

  Node* t = FromNodeDef("T", "TwoInputsOneOutput", 2);
  graph_.AddEdge(a, 0, t, 0);
  // Weird self edge: t's output feeds its own second input.
  graph_.AddEdge(t, 0, t, 1);

  const Node* t_0;
  const Node* t_1;
  TF_ASSERT_OK(t->input_node(0, &t_0));
  EXPECT_EQ(a, t_0);
  TF_ASSERT_OK(t->input_node(1, &t_1));
  EXPECT_EQ(t, t_1);

  TF_ASSERT_OK(t->input_edge(1, &e));
  EXPECT_EQ(1, e->dst_input());
  EXPECT_EQ(t, e->src());

  std::vector<const Edge*> t_input_edges;
  TF_ASSERT_OK(t->input_edges(&t_input_edges));
  ASSERT_EQ(2, t_input_edges.size());
  EXPECT_EQ(a, t_input_edges[0]->src());
  EXPECT_EQ(e, t_input_edges[1]);

  // Check out of bounds access
  EXPECT_FALSE(c->input_node(1, &a_copy).ok());
  EXPECT_FALSE(c->input_node(-1, &a_copy).ok());

  graph_.RemoveNode(a);

  // 'c's input_node entry should be invalidated.
  Status s = c->input_node(0, &a_copy);
  EXPECT_FALSE(s.ok());

  // Add two new nodes.
  Node* a_new = FromNodeDef("A_new", "OneOutput", 0);
  Node* b_new = FromNodeDef("B_new", "OneOutput", 0);

  // Connect one up to c.
  graph_.AddEdge(a_new, 0, c, 0);
  const Edge* a_new_c_edge;
  TF_ASSERT_OK(c->input_edge(0, &a_new_c_edge));
  EXPECT_EQ(a_new, a_new_c_edge->src());

  // Connect a second edge into the same input slot. Keep the pointer that
  // AddEdge returns: with two edges on input 0, the edge input_edge(0)
  // reports is not pinned by this test, so it cannot identify b_new's edge.
  const Edge* b_new_c_edge = graph_.AddEdge(b_new, 0, c, 0);
  ASSERT_NE(nullptr, b_new_c_edge);
  const Edge* either_edge;
  TF_ASSERT_OK(c->input_edge(0, &either_edge));

  // Now remove the old one
  graph_.RemoveEdge(a_new_c_edge);

  // Check that the second edge is the one that can still be retrieved.
  const Edge* remaining_edge;
  TF_ASSERT_OK(c->input_edge(0, &remaining_edge));
  EXPECT_EQ(b_new_c_edge, remaining_edge);
  EXPECT_EQ(b_new, remaining_edge->src());

  std::vector<const Edge*> c_input_edges;
  TF_ASSERT_OK(c->input_edges(&c_input_edges));
  ASSERT_EQ(1, c_input_edges.size());
  EXPECT_EQ(b_new_c_edge, c_input_edges[0]);
}
292
// Id-based lookup and range iteration must both visit exactly the live nodes,
// even when removals have left holes in the id space.
TEST_F(GraphTest, NodeIteration) {
  // Set up the graph with some holes due to removals.
  Node* const a = FromNodeDef("A", "OneOutput", 0);
  Node* const b = AddNodeWithName("B");
  Node* const c = FromNodeDef("C", "OneInput", 1);
  graph_.RemoveNode(b);
  Node* const d = AddNodeWithName("D");
  const Edge* const source_to_a = graph_.AddControlEdge(graph_.source_node(), a);
  graph_.AddControlEdge(a, graph_.sink_node());
  graph_.AddEdge(a, 0, c, 0);
  graph_.AddControlEdge(c, graph_.sink_node());
  graph_.RemoveEdge(source_to_a);
  graph_.RemoveNode(c);

  // DebugStrings of every node that should still be present.
  const std::set<string> expected = {
      graph_.source_node()->DebugString(), a->DebugString(),
      d->DebugString(), graph_.sink_node()->DebugString()};

  // Walk every id; removed ids map to nullptr and are skipped.
  std::set<string> seen_by_id;
  for (int id = 0; id < graph_.num_node_ids(); ++id) {
    if (const Node* node = graph_.FindNodeId(id)) {
      seen_by_id.insert(node->DebugString());
    }
  }
  EXPECT_EQ(expected, seen_by_id);

  // The nodes() range must agree.
  std::set<string> seen_by_range;
  for (const Node* node : graph_.nodes()) {
    seen_by_range.insert(node->DebugString());
  }
  EXPECT_EQ(expected, seen_by_range);
  VerifyGraphStats();
}
332
CheckType(Node * node,bool b)333 static void CheckType(Node* node, bool b) {
334 EXPECT_TRUE(b) << node->DebugString();
335 // Make sure none of the other IsFoo() methods return true.
336 int count = 0;
337 if (node->IsSource()) count++;
338 if (node->IsSink()) count++;
339 if (node->IsOp()) count++;
340 EXPECT_EQ(1, count) << node->DebugString();
341 }
342
// Source, sink and an ordinary op node each report exactly one node kind.
TEST_F(GraphTest, Type) {
  Node* const noop = AddNodeWithName("A");
  Node* const source = graph_.source_node();
  Node* const sink = graph_.sink_node();
  CheckType(source, source->IsSource());
  CheckType(sink, sink->IsSink());
  CheckType(noop, noop->IsOp());
  VerifyGraphStats();
}
350
// Attributes added with AddAttr() are carried over by CopyNode(), while
// attributes added to the original afterwards do not appear on the copy.
TEST_F(GraphTest, AddAttr) {
  Node* const original = AddNodeWithName("A");
  original->AddAttr("_a", "new_attr");

  string value;
  TF_EXPECT_OK(GetNodeAttr(original->attrs(), "_a", &value));
  EXPECT_EQ("new_attr", value);

  Node* const copy = graph_.CopyNode(original);

  original->AddAttr("_b", "new_attr_2");

  TF_EXPECT_OK(GetNodeAttr(original->attrs(), "_a", &value));
  EXPECT_EQ("new_attr", value);
  TF_EXPECT_OK(GetNodeAttr(original->attrs(), "_b", &value));
  EXPECT_EQ("new_attr_2", value);

  TF_EXPECT_OK(GetNodeAttr(copy->attrs(), "_a", &value));
  EXPECT_EQ("new_attr", value);
  EXPECT_FALSE(GetNodeAttr(copy->attrs(), "_b", &value).ok());
}
373
374 // Convert edge iteration results into a sorted string.
EdgeIter(const Graph & g)375 static string EdgeIter(const Graph& g) {
376 std::vector<std::pair<int, int> > edges;
377 for (const Edge* e : g.edges()) {
378 edges.push_back(std::make_pair(e->src()->id(), e->dst()->id()));
379 }
380 std::sort(edges.begin(), edges.end());
381 string result;
382 for (auto& p : edges) {
383 strings::StrAppend(&result, p.first, "->", p.second, ";");
384 }
385 return result;
386 }
387
// Edge iteration reflects data, control and self edges. Ids here: _SOURCE=0,
// _SINK=1, A=2, B=3.
TEST_F(GraphTest, EdgeIteration) {
  // A fresh graph has only the _SOURCE -> _SINK edge.
  EXPECT_EQ("0->1;", EdgeIter(graph_));

  Node* const a = FromNodeDef("A", "OneInputTwoOutputs", 1);
  Node* const b = FromNodeDef("B", "OneInput", 1);
  // Adding unconnected nodes adds no edges.
  EXPECT_EQ("0->1;", EdgeIter(graph_));

  graph_.AddEdge(a, 0, b, 0);
  EXPECT_EQ("0->1;2->3;", EdgeIter(graph_));

  Node* const source = graph_.source_node();
  Node* const sink = graph_.sink_node();
  graph_.AddControlEdge(source, a);
  graph_.AddControlEdge(b, sink);
  EXPECT_EQ("0->1;0->2;2->3;3->1;", EdgeIter(graph_));

  // Self edge from A's second output back into its own input.
  graph_.AddEdge(a, 1, a, 0);
  EXPECT_EQ("0->1;0->2;2->2;2->3;3->1;", EdgeIter(graph_));
  VerifyGraphStats();
}
406
// NewName() returns distinct names that each keep the requested prefix.
TEST_F(GraphTest, NewName) {
  string a1 = graph_.NewName("A");
  string a2 = graph_.NewName("A");
  string b1 = graph_.NewName("B");
  EXPECT_NE(a1, a2);
  EXPECT_NE(a1, b1);
  EXPECT_NE(a2, b1);
  // Check the prefix of every generated name, not just the first one.
  EXPECT_TRUE(absl::StartsWith(a1, "A")) << a1;
  EXPECT_TRUE(absl::StartsWith(a2, "A")) << a2;
  EXPECT_TRUE(absl::StartsWith(b1, "B")) << b1;
}
416
// IsValidNode() rejects a null node and nodes that belong to another graph,
// whether their id is out of range here or collides with a local node's id.
TEST_F(GraphTest, IsValidNode) {
  // One op node in graph_ (after _SOURCE and _SINK).
  Node* own_node;
  TF_CHECK_OK(NodeBuilder("g1_node1", "NoOp").Finalize(&graph_, &own_node));

  // Two op nodes in a separate graph.
  Graph other_graph(OpRegistry::Global());
  Node* foreign_same_id;
  Node* foreign_high_id;
  TF_CHECK_OK(
      NodeBuilder("g2_node1", "NoOp").Finalize(&other_graph, &foreign_same_id));
  TF_CHECK_OK(
      NodeBuilder("g2_node2", "NoOp").Finalize(&other_graph, &foreign_high_id));

  // nullptr
  Status status = graph_.IsValidNode(nullptr);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("Node is null"), status.error_message());

  // Foreign node whose id is beyond graph_'s id range.
  status = graph_.IsValidNode(foreign_high_id);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("node id 3 is >= than number of nodes in graph 3"),
            status.error_message());

  // Foreign node whose id is in range but maps to a different local node.
  status = graph_.IsValidNode(foreign_same_id);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
  EXPECT_EQ(string("Node with id 2 is different from the passed in node. "
                   "Does it belong to a different graph?"),
            status.error_message());
}
447
// AddControlEdge() creates a control edge and mirrors it into the
// destination's NodeDef as "^<src>", except for edges from _SOURCE. A
// redundant control edge is refused (nullptr) unless allow_duplicates is set.
TEST_F(GraphTest, AddControlEdge) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");
  Node* a = FindNode("A");
  Node* b = FindNode("B");
  Node* c = FindNode("C");

  // Add a control edge.
  const Edge* edge = graph_.AddControlEdge(c, a);
  ASSERT_NE(nullptr, edge);
  // Check newly-created edge.
  EXPECT_EQ(c, edge->src());
  EXPECT_EQ(Graph::kControlSlot, edge->src_output());
  EXPECT_EQ(a, edge->dst());
  EXPECT_EQ(Graph::kControlSlot, edge->dst_input());
  // Check A's NodeDef.
  ASSERT_EQ(1, a->def().input_size());
  EXPECT_EQ("^C", a->def().input(0));

  // Can add control edge redundant with data edge.
  edge = graph_.AddControlEdge(a, b);
  EXPECT_NE(nullptr, edge);
  ASSERT_EQ(2, b->def().input_size());
  EXPECT_EQ("A:0", b->def().input(0));
  EXPECT_EQ("^A", b->def().input(1));

  // Doesn't add edge redundant with control edge.
  edge = graph_.AddControlEdge(a, b);
  EXPECT_EQ(nullptr, edge);
  EXPECT_EQ(2, b->def().input_size());

  // Can add redundant control edge with allow_duplicates.
  edge = graph_.AddControlEdge(a, b, /*allow_duplicates=*/true);
  EXPECT_NE(nullptr, edge);
  // allow_duplicates causes the NodeDef not to be updated.
  ASSERT_EQ(2, b->def().input_size());
  EXPECT_EQ("A:0", b->def().input(0));
  EXPECT_EQ("^A", b->def().input(1));

  // Add control edge from source.
  edge = graph_.AddControlEdge(graph_.source_node(), b);
  EXPECT_NE(nullptr, edge);
  // Check that we don't include source input in the NodeDef.
  EXPECT_EQ(2, b->def().input_size());
  // Doesn't add redundant edge.
  edge = graph_.AddControlEdge(graph_.source_node(), b);
  EXPECT_EQ(nullptr, edge);
  EXPECT_EQ(2, b->def().input_size());
}
499
// RemoveControlEdge() drops the edge from both the graph and the
// destination's NodeDef, leaving other control edges intact.
TEST_F(GraphTest, RemoveControlEdge) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");
  Node* a = FindNode("A");
  Node* b = FindNode("B");
  Node* c = FindNode("C");

  // Add two control edges: C -> A and A -> B.
  const Edge* edge_1 = graph_.AddControlEdge(c, a);
  const Edge* edge_2 = graph_.AddControlEdge(a, b);
  ASSERT_NE(nullptr, edge_1);
  ASSERT_NE(nullptr, edge_2);

  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  graph_.RemoveControlEdge(edge_1);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_TRUE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  graph_.RemoveControlEdge(edge_2);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(a, b));

  // Test removing a duplicate control edge.
  // Note that unless allow_duplicates is true, the duplicate edge
  // will not be added. That's why we expect edge_4 to be a null
  // pointer. We are not testing with allow_duplicates set to true,
  // as that is a highly unlikely use case that does not make much
  // sense.
  const Edge* edge_3 = graph_.AddControlEdge(c, a);
  const Edge* edge_4 = graph_.AddControlEdge(c, a);
  ASSERT_NE(nullptr, edge_3);
  ASSERT_EQ(nullptr, edge_4);

  graph_.RemoveControlEdge(edge_3);
  ASSERT_FALSE(ControlEdgeExistsInGraphOrNodeDef(c, a));
}
540
// UpdateEdge(new_src, new_src_index, dst, dst_index) rewires an existing data
// input of `dst` to come from `new_src`, and rejects out-of-range slots.
TEST_F(GraphTest, UpdateEdge) {
  // Build a little graph. Ids: _SOURCE=0, _SINK=1, A=2, B=3, C=4, D=5.
  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* b = FromNodeDef("B", "OneInputTwoOutputs", 1);
  Node* c = FromNodeDef("C", "OneInputTwoOutputs", 1);
  Node* d = FromNodeDef("D", "OneInput", 1);

  graph_.AddControlEdge(graph_.source_node(), a);
  graph_.AddControlEdge(a, graph_.sink_node());
  graph_.AddEdge(a, 0, c, 0);

  graph_.AddControlEdge(c, graph_.sink_node());
  graph_.AddEdge(c, 0, b, 0);
  graph_.AddEdge(c, 1, d, 0);

  // Initial edge connections
  EXPECT_EQ("0->1;0->2;2->1;2->4;4->1;4->3;4->5;", EdgeIter(graph_));

  // Rewire B's input 0 from C:0 to A:0: edge A->B (2->3) appears and
  // C->B (4->3) no longer does.
  TF_EXPECT_OK(graph_.UpdateEdge(a, 0, b, 0));
  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;4->1;4->5;", EdgeIter(graph_));

  // Rewire D's input 0 from C:1 to A:0 (A's output 0 again).
  TF_EXPECT_OK(graph_.UpdateEdge(a, 0, d, 0));
  EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;2->5;4->1;", EdgeIter(graph_));

  // A has a single output, so using its output 1 as the source is rejected.
  Status s = graph_.UpdateEdge(a, 1, d, 0);
  EXPECT_FALSE(s.ok());
  EXPECT_EQ(
      "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1",
      s.error_message());

  // A has no inputs at all, so using A's input 0 as the destination is
  // rejected.
  s = graph_.UpdateEdge(c, 0, a, 0);
  EXPECT_FALSE(s.ok());
  EXPECT_EQ(
      "Node 'A' (type: 'OneOutput', num of inputs: 0) does not have input 0",
      s.error_message());
}
583
// input_edges() fails with INVALID_ARGUMENT while a data input is unwired.
// Once every input is wired, it returns one edge per input, indexed by
// dst_input.
TEST_F(GraphTest, InputEdges) {
  Node* a = FromNodeDef("A", "OneOutput", 0);
  Node* b = FromNodeDef("B", "TwoInputsOneOutput", 2);
  graph_.AddEdge(a, 0, b, 0);
  std::vector<const Edge*> edges;
  EXPECT_EQ(error::INVALID_ARGUMENT, b->input_edges(&edges).code());

  graph_.AddEdge(a, 0, b, 1);
  TF_ASSERT_OK(b->input_edges(&edges));
  // Verify the result contents, not just the status.
  ASSERT_EQ(2, edges.size());
  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(a, edges[i]->src());
    EXPECT_EQ(b, edges[i]->dst());
    EXPECT_EQ(i, edges[i]->dst_input());
  }
}
593
// AddFunctionLibrary() accepts new functions and gradients, ignores exact
// duplicates, and rejects name clashes with different definitions or ops.
TEST_F(GraphTest, AddFunctionLibrary) {
  const auto expect_both_functions_found = [this] {
    EXPECT_TRUE(graph_.flib_def().Find("XTimesTwo") != nullptr);
    EXPECT_TRUE(graph_.flib_def().Find("XTimesFour") != nullptr);
  };

  // Basic functionality.
  FunctionDefLibrary library;
  *library.add_function() = test::function::XTimesTwo();
  *library.add_function() = test::function::XTimesFour();
  TF_EXPECT_OK(graph_.AddFunctionLibrary(library));
  expect_both_functions_found();

  // Re-adding identical functions is silently ignored.
  TF_EXPECT_OK(graph_.AddFunctionLibrary(library));
  expect_both_functions_found();

  // Same name but a different body is an error.
  FunctionDefLibrary bad_library = library;
  *bad_library.mutable_function(0)->add_node_def() =
      bad_library.function(0).node_def(0);
  Status status = graph_.AddFunctionLibrary(bad_library);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot add function 'XTimesTwo' because a different function with "
            "the same name already exists.");

  // A function may not reuse the name of a registered op.
  bad_library = library;
  bad_library.mutable_function(0)->mutable_signature()->set_name("Add");
  status = graph_.AddFunctionLibrary(bad_library);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot add function 'Add' because an op with the same name "
            "already exists.");

  // Attaching a gradient to an existing function is fine, even when the
  // gradient function itself is undefined.
  GradientDef* gradient = library.add_gradient();
  gradient->set_function_name("XTimesTwo");
  gradient->set_gradient_func("Undefined");
  TF_EXPECT_OK(graph_.AddFunctionLibrary(library));
  EXPECT_EQ(graph_.flib_def().FindGradient("XTimesTwo"), "Undefined");

  // Re-adding the same gradient is silently ignored.
  TF_EXPECT_OK(graph_.AddFunctionLibrary(library));
  EXPECT_EQ(graph_.flib_def().FindGradient("XTimesTwo"), "Undefined");

  // A different gradient for the same function is an error.
  bad_library = library;
  bad_library.mutable_gradient(0)->set_gradient_func("Undefined2");
  status = graph_.AddFunctionLibrary(bad_library);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Cannot assign gradient function 'Undefined2' to 'XTimesTwo' "
            "because it already has gradient function 'Undefined'");
}
647
// BuildNodeNameIndex() maps every node name, including _SOURCE and _SINK, to
// its node.
TEST_F(GraphTest, BuildNodeNameIndex) {
  FromGraphDef(
      "node { name: 'A' op: 'OneOutput' }"
      "node { name: 'B' op: 'OneInputTwoOutputs' input: [ 'A:0' ] }"
      "node { name: 'C' op: 'NoOp' } ");

  const auto node_name_index = graph_.BuildNodeNameIndex();
  EXPECT_EQ(5, node_name_index.size());

  const std::vector<string> node_names{"_SOURCE", "_SINK", "A", "B", "C"};
  for (const string& node_name : node_names) {
    // Reuse the find() result: operator[] would silently insert a null entry
    // for a missing name and keep going after a non-fatal failure.
    const auto it = node_name_index.find(node_name);
    ASSERT_NE(it, node_name_index.end()) << node_name;
    EXPECT_EQ(FindNode(node_name), it->second);
  }
}
663
// Clear() removes every op node, leaving only _SOURCE and _SINK.
TEST_F(GraphTest, Clear) {
  const int num_nodes = 10;
  const int num_edges_per_node = 2;
  const GraphDef graph_def =
      test::CreateGraphDef(num_nodes, num_edges_per_node);
  const auto registry = OpRegistry::Global();
  GraphConstructorOptions opts;
  Graph graph(registry);
  TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
  // Guard against a vacuous pass: the import must have added op nodes for
  // Clear() to remove.
  ASSERT_GT(graph.num_nodes(), 2);

  graph.Clear();
  EXPECT_EQ(2, graph.num_nodes());
  EXPECT_NE(nullptr, graph.source_node());
  EXPECT_NE(nullptr, graph.sink_node());
}
676
BM_InEdgeIteration(::testing::benchmark::State & state)677 void BM_InEdgeIteration(::testing::benchmark::State& state) {
678 const int num_nodes = state.range(0);
679 const int num_edges_per_node = state.range(1);
680 const GraphDef graph_def =
681 test::CreateGraphDef(num_nodes, num_edges_per_node);
682 Graph graph(OpRegistry::Global());
683 GraphConstructorOptions opts;
684 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
685
686 int64_t sum = 0;
687 for (auto s : state) {
688 for (const Node* node : graph.nodes()) {
689 for (auto e : node->in_edges()) {
690 sum += e->id();
691 }
692 }
693 }
694 VLOG(1) << sum;
695 }
696 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 2);
697 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 2);
698 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 2);
699 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 2);
700 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 2);
701 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 4);
702 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 4);
703 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 4);
704 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 4);
705 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 4);
706 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 8);
707 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 8);
708 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 8);
709 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 8);
710 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 8);
711 BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 16);
712 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 16);
713 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 16);
714 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 16);
715 BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 16);
716
BM_GraphCreation(::testing::benchmark::State & state)717 void BM_GraphCreation(::testing::benchmark::State& state) {
718 const int num_nodes = state.range(0);
719 const int num_edges_per_node = state.range(1);
720 const GraphDef graph_def =
721 test::CreateGraphDef(num_nodes, num_edges_per_node);
722 const auto registry = OpRegistry::Global();
723 GraphConstructorOptions opts;
724 // Warmup step.
725 Graph graph(registry);
726 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
727 int64_t sum = 0;
728 for (auto s : state) {
729 Graph graph(registry);
730 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
731 sum += graph.num_node_ids();
732 }
733 VLOG(1) << sum;
734 }
735 BENCHMARK(BM_GraphCreation)->ArgPair(10, 2);
736 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 2);
737 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 2);
738 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 2);
739 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 2);
740 BENCHMARK(BM_GraphCreation)->ArgPair(10, 4);
741 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 4);
742 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 4);
743 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 4);
744 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 4);
745 BENCHMARK(BM_GraphCreation)->ArgPair(10, 8);
746 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 8);
747 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 8);
748 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 8);
749 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 8);
750 BENCHMARK(BM_GraphCreation)->ArgPair(10, 16);
751 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 16);
752 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 16);
753 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 16);
754 BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 16);
755
BM_ToGraphDef(::testing::benchmark::State & state)756 void BM_ToGraphDef(::testing::benchmark::State& state) {
757 const int num_nodes = state.range(0);
758 const int num_edges_per_node = state.range(1);
759 const GraphDef graph_def =
760 test::CreateGraphDef(num_nodes, num_edges_per_node);
761 const auto registry = OpRegistry::Global();
762 GraphConstructorOptions opts;
763 // Warmup step.
764 Graph graph(registry);
765 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
766 int64_t sum = 0;
767 for (auto s : state) {
768 GraphDef graph_def;
769 graph.ToGraphDef(&graph_def);
770 sum += graph_def.node_size();
771 }
772 VLOG(1) << sum;
773 }
774 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 2);
775 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 2);
776 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 2);
777 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 2);
778 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 2);
779 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 4);
780 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 4);
781 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 4);
782 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 4);
783 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 4);
784 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 8);
785 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 8);
786 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 8);
787 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 8);
788 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 8);
789 BENCHMARK(BM_ToGraphDef)->ArgPair(10, 16);
790 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 16);
791 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16);
792 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16);
793 BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16);
794
BM_RemoveNode(::testing::benchmark::State & state)795 void BM_RemoveNode(::testing::benchmark::State& state) {
796 const int num_nodes = state.range(0);
797 const int num_edges_per_node = state.range(1);
798 const GraphDef graph_def =
799 test::CreateGraphDef(num_nodes, num_edges_per_node);
800 const auto registry = OpRegistry::Global();
801 GraphConstructorOptions opts;
802 for (auto s : state) {
803 state.PauseTiming();
804 Graph graph(registry);
805 TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
806 state.ResumeTiming();
807 for (Node* n : graph.op_nodes()) {
808 graph.RemoveNode(n);
809 }
810 }
811 }
812 BENCHMARK(BM_RemoveNode)->ArgPair(10, 2);
813 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 2);
814 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 2);
815 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 2);
816 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 2);
817 BENCHMARK(BM_RemoveNode)->ArgPair(10, 4);
818 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 4);
819 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 4);
820 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 4);
821 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 4);
822 BENCHMARK(BM_RemoveNode)->ArgPair(10, 8);
823 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 8);
824 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 8);
825 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 8);
826 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 8);
827 BENCHMARK(BM_RemoveNode)->ArgPair(10, 16);
828 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 16);
829 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 16);
830 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 16);
831 BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 16);
832
833 } // namespace
834 } // namespace tensorflow
835