/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph_partition.h"

#include <unordered_map>
#include <utility>

#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/control_flow_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/random_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/equal_graph_def.h"

namespace tensorflow {

// from graph_partition.cc
extern Status TopologicalSortNodesWithTimePriority(
    const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes,
    std::unordered_map<const NodeDef*, int64>* node_to_start_time_out);

namespace {

using ops::_Recv;
using ops::_Send;
using ops::Const;
using ops::Identity;
using ops::LoopCond;
using ops::NextIteration;

const char gpu_device[] = "/job:a/replica:0/task:0/device:GPU:0";

// Node-to-partition function: group nodes by their assigned device.
string SplitByDevice(const Node* node) { return node->assigned_device_name(); }

// Derives a device from a node's name: names starting with 'G' map to the
// GPU device; any other first letter is an index into CPU devices
// ('A' -> cpu:0, 'B' -> cpu:1, ...).
string DeviceName(const Node* node) {
  char first = node->name()[0];
  if (first == 'G') {
    return gpu_device;
  } else {
    const string cpu_prefix = "/job:a/replica:0/task:0/cpu:";
    int index = first - 'A';
    return strings::StrCat(cpu_prefix, index);
  }
}

// Partitions `graph_def` by device and fills `partitions`, keyed by device
// name.
void Partition(const GraphDef& graph_def,
               std::unordered_map<string, GraphDef>* partitions) {
  Graph g(OpRegistry::Global());
  GraphConstructorOptions opts;
  TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g));

  // Assigns devices to each node. Uses the first letter of the node name as
  // the device index if no device is specified.
  for (Node* node : g.nodes()) {
    string device_name = !node->requested_device().empty()
                             ? node->requested_device()
                             : DeviceName(node);
    node->set_assigned_device_name(device_name);
  }

  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;
  popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); };
  popts.get_incarnation = [](const string& name) {
    return (name[0] - 'A') + 100;
  };
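  // For the _Send/_Recv pairs the partitioner creates, `name` is a full
  // device name beginning with '/', so the incarnation works out to
  // ('/' - 'A') + 100 == 82 -- the value the expected graphs in the tests
  // below use. The SetIncarnation test instead supplies a literal
  // send_device of "A", which maps to 100.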
  Status s = Partition(popts, &g, partitions);
  CHECK(s.ok()) << s;

  // Check versions.
  EXPECT_EQ(graph_def.versions().producer(), TF_GRAPH_DEF_VERSION);
  // Partitions must inherit the versions of the original graph.
  for (auto& it : *partitions) {
    EXPECT_EQ(graph_def.versions().producer(), it.second.versions().producer());
    EXPECT_EQ(graph_def.versions().min_consumer(),
              it.second.versions().min_consumer());
  }
}

// Checks the partitions of a graph containing a loop: every _Recv must be
// gated by a control input, and each partition must contain a control loop
// (Enter/Merge/Switch/NextIteration nodes whose names start with "_cloop").
void CheckLoopConstruction(const GraphDef& graph_def) {
  std::unordered_map<string, GraphDef> partitions;
  Partition(graph_def, &partitions);
  for (const auto& kv : partitions) {
    const GraphDef& gdef = kv.second;
    bool has_control_enter = false;
    bool has_control_merge = false;
    bool has_control_switch = false;
    bool has_control_next = false;
    for (const NodeDef& ndef : gdef.node()) {
      // _Recv nodes must have a control input.
      if (ndef.op() == "_Recv") {
        bool has_control = false;
        for (const string& input_name : ndef.input()) {
          if (str_util::StartsWith(input_name, "^")) {
            has_control = true;
            break;
          }
        }
        EXPECT_TRUE(has_control);
      }
      // Must have a control loop.
      if (str_util::StartsWith(ndef.name(), "_cloop")) {
        if (ndef.op() == "Enter") {
          has_control_enter = true;
        }
        if (ndef.op() == "Merge") {
          has_control_merge = true;
        }
        if (ndef.op() == "Switch") {
          has_control_switch = true;
        }
        if (ndef.op() == "NextIteration") {
          has_control_next = true;
        }
      }
    }
    EXPECT_TRUE(has_control_enter);
    EXPECT_TRUE(has_control_merge);
    EXPECT_TRUE(has_control_switch);
    EXPECT_TRUE(has_control_next);
  }
}

REGISTER_OP("FloatInput")
    .Output("o: float")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("BoolInput")
    .Output("o: bool")
    .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("Combine")
    .Input("a: float")
    .Input("b: float")
    .Output("o: float")
    .SetShapeFn(shape_inference::UnknownShape);

// Builds a node of `op_type` directly with NodeBuilder; the test ops
// registered above have no generated C++ op wrappers.
Output ConstructOp(const Scope& scope, const string& op_type,
                   const gtl::ArraySlice<Input>& inputs) {
  if (!scope.ok()) return Output();
  const string unique_name = scope.GetUniqueNameForOp(op_type);
  auto builder =
      NodeBuilder(unique_name, op_type, scope.graph()->op_registry());
  for (auto const& input : inputs) {
    builder.Input(ops::NodeOut(input.node(), input.index()));
  }
  scope.UpdateBuilder(&builder);
  Node* ret;
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  if (!scope.ok()) return Output();
  scope.UpdateStatus(scope.DoShapeInference(ret));
  if (!scope.ok()) return Output();
  return Output(ret);
}

Output FloatInput(const Scope& scope) {
  return ConstructOp(scope, "FloatInput", {});
}

Output BoolInput(const Scope& scope) {
  return ConstructOp(scope, "BoolInput", {});
}

Output Combine(const Scope& scope, Input a, Input b) {
  return ConstructOp(scope, "Combine", {std::move(a), std::move(b)});
}

class GraphPartitionTest : public ::testing::Test {
 protected:
  GraphPartitionTest()
      : in_(Scope::NewRootScope().ExitOnError()),
        scope_a_(Scope::NewRootScope().ExitOnError().WithDevice(
            "/job:a/replica:0/task:0/cpu:0")),
        scope_b_(Scope::NewRootScope().ExitOnError().WithDevice(
            "/job:a/replica:0/task:0/cpu:1")) {}

  const GraphDef& ToGraphDef() {
    TF_EXPECT_OK(in_.ToGraphDef(&in_graph_def_));
    return in_graph_def_;
  }

  void ExpectMatchA() {
    GraphDef graph_def;
    TF_EXPECT_OK(scope_a_.ToGraphDef(&graph_def));
    string a = "/job:a/replica:0/task:0/cpu:0";
    TF_EXPECT_GRAPH_EQ(graph_def, partitions_[a]);
  }

  void ExpectMatchB() {
    GraphDef graph_def;
    TF_EXPECT_OK(scope_b_.ToGraphDef(&graph_def));
    string b = "/job:a/replica:0/task:0/cpu:1";
    TF_EXPECT_GRAPH_EQ(graph_def, partitions_[b]);
  }

  void ExpectFunctions(const FunctionDefLibrary& library,
                       const std::set<string>& expected_names) {
    std::set<string> actual_names;
    for (const FunctionDef& fdef : library.function()) {
      actual_names.insert(fdef.signature().name());
    }
    EXPECT_EQ(actual_names, expected_names);
  }

  Scope in_;  // Scope for building the input graph.
  GraphDef in_graph_def_;
  Scope scope_a_;  // Scope for the expected partition on cpu:0.
  Scope scope_b_;  // Scope for the expected partition on cpu:1.
  std::unordered_map<string, GraphDef> partitions_;
};

TEST_F(GraphPartitionTest, SingleDevice) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  Combine(in_.WithOpName("A2"), a1, a1);

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(1, partitions_.size());

  a1 = FloatInput(scope_a_.WithOpName("A1"));
  Combine(scope_a_.WithOpName("A2"), a1, a1);
  ExpectMatchA();
}

TEST_F(GraphPartitionTest, CrossDeviceData) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  Combine(in_.WithOpName("B2"), a1, b1);

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

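  // Expected partition on cpu:0: A1 plus a _Send that ships A1's output to
  // cpu:1 under tensor name "edge_1_A1" (incarnation 82; see Partition()
  // above).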
  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";
  a1 = FloatInput(scope_a_.WithOpName("A1"));
  _Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b);
  ExpectMatchA();

  b1 = FloatInput(scope_b_.WithOpName("B1"));
  auto recv =
      _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
  Combine(scope_b_.WithOpName("B2"), recv, b1);
  ExpectMatchB();
}

TEST_F(GraphPartitionTest, CrossDeviceControl) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1);

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

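  // A cross-device control edge is lowered to data: the sender gates a dummy
  // Const on a control edge from A1 and _Sends it; the receiver _Recvs it
  // into an Identity, and the original control dependency is re-pointed at
  // that Identity.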
  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";
  a1 = FloatInput(scope_a_.WithOpName("A1"));
  auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
  _Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b);
  ExpectMatchA();

  auto recv =
      _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b);
  auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
  b1 = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1);
  ExpectMatchB();
}

TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  Combine(in_.WithOpName("B2"), a1, b1);
  Combine(in_.WithOpName("B3"), a1, a1);

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

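  // B2 and B3 both consume A1, but only one _Send/_Recv pair is created;
  // every consumer on cpu:1 reuses the single recv output.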
  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";
  a1 = FloatInput(scope_a_.WithOpName("A1"));
  _Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b);
  ExpectMatchA();

  auto recv =
      _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
  b1 = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2"), recv, b1);
  Combine(scope_b_.WithOpName("B3"), recv, recv);
  ExpectMatchB();
}

TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  Combine(in_.WithOpName("B2").WithControlDependencies(a1), b1, b1);
  FloatInput(in_.WithOpName("B3").WithControlDependencies(a1));

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

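  // As in CrossDeviceControl, the control edge is carried by one dummy
  // Const/_Send and one _Recv/Identity pair, shared by both control
  // consumers (B2 and B3).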
  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";
  a1 = FloatInput(scope_a_.WithOpName("A1"));
  auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
  _Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b);
  ExpectMatchA();

  auto recv =
      _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b);
  auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
  b1 = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1);
  FloatInput(scope_b_.WithOpName("B3").WithControlDependencies(id));
  ExpectMatchB();
}

TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  Combine(in_.WithOpName("B2"), a1, b1);
  FloatInput(in_.WithOpName("B3").WithControlDependencies(a1));

  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

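  // A1 is used both as data (by B2) and as control (by B3), so partition a
  // ends up with two _Send nodes: one for A1's tensor and one for the dummy
  // control Const.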
  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";
  a1 = FloatInput(scope_a_.WithOpName("A1"));
  _Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b);
  auto c = Const(scope_a_.WithOpName("A1/_2").WithControlDependencies(a1), {});
  // NOTE: Send 0 A1/_1 -> A1/_2 is not necessarily needed. We could
  // use A1/_0 -> A1/_4 as the control as a minor optimization.
  _Send(scope_a_.WithOpName("A1/_3"), c, "edge_3_A1", a, 82, b);
  ExpectMatchA();

  auto recv1 =
      _Recv(scope_b_.WithOpName("A1/_4"), DT_FLOAT, "edge_3_A1", a, 82, b);
  auto id1 = Identity(scope_b_.WithOpName("A1/_5"), recv1);
  auto recv2 =
      _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
  b1 = FloatInput(scope_b_.WithOpName("B1"));
  Combine(scope_b_.WithOpName("B2"), recv2, b1);
  FloatInput(scope_b_.WithOpName("B3").WithControlDependencies(id1));
  ExpectMatchB();
}

TEST_F(GraphPartitionTest, CrossDeviceLoopSimple) {
  auto a1 = BoolInput(in_.WithOpName("A1"));
  auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("A2"), a1, "foo");
  auto a3 = ::tensorflow::ops::Merge(in_.WithOpName("A3"),
                                     {a2, Input("A5", 0, DT_BOOL)})
                .output;
  LoopCond(in_.WithOpName("A4"), a3);
  auto b1 = Identity(in_.WithOpName("B1"), a3);
  NextIteration(in_.WithOpName("A5"), b1);

  CheckLoopConstruction(ToGraphDef());
}

TEST_F(GraphPartitionTest, CrossDeviceLoopSimple1) {
  auto a1 = BoolInput(in_.WithOpName("A1"));
  auto a2 = ::tensorflow::ops::internal::Enter(in_.WithOpName("B2"), a1, "foo");
  auto a3 = ::tensorflow::ops::Merge(in_.WithOpName("A3"),
                                     {a2, Input("B5", 0, DT_BOOL)})
                .output;
  LoopCond(in_.WithOpName("A4"), a3);
  auto b1 = Identity(in_.WithOpName("B1"), a3);
  NextIteration(in_.WithOpName("B5"), b1);

  std::unordered_map<string, GraphDef> partitions;
  Partition(ToGraphDef(), &partitions);
  for (const auto& kv : partitions) {
    const GraphDef& gdef = kv.second;
    for (const NodeDef& ndef : gdef.node()) {
      if (ndef.name() == "A3") {
        // A3, B2, and B5 are on the same device.
        EXPECT_EQ(ndef.input(0), "B2");
        EXPECT_EQ(ndef.input(1), "B5");
      }
    }
  }
}

TEST_F(GraphPartitionTest, CrossDeviceLoopFull) {
  Scope cpu0 = in_.WithDevice("/job:a/replica:0/task:0/cpu:0");
  auto p1 = ops::Placeholder(cpu0, DT_INT32);
  auto p2 = ops::Placeholder(cpu0, DT_INT32);
  OutputList outputs;
  // while i1 < 10: i1 += i2
  TF_ASSERT_OK(ops::BuildWhileLoop(
      cpu0, {p1, p2},
      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
        *output = ops::Less(s, inputs[0], 10);
        return s.status();
      },
      [](const Scope& s, const std::vector<Output>& inputs,
         std::vector<Output>* outputs) {
        Scope cpu1 = s.WithDevice("/job:a/replica:0/task:0/cpu:1");
        outputs->push_back(ops::AddN(cpu1, {inputs[0], inputs[1]}));
        outputs->push_back(inputs[1]);
        return s.status();
      },
      "test_loop", &outputs));
  CheckLoopConstruction(ToGraphDef());
}

TEST_F(GraphPartitionTest, PartitionIncompleteGraph) {
  NodeDef ndef;
  Graph g(OpRegistry::Global());
  // Invalid graph since the Combine node requires an input.
  bool parsed = protobuf::TextFormat::ParseFromString(
      R"EOF(
      name: "N"
      op: "Combine"
      )EOF",
      &ndef);
  ASSERT_TRUE(parsed);
  Status status;
  g.AddNode(ndef, &status);
  TF_ASSERT_OK(status);

  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;
  popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); };
  popts.get_incarnation = [](const string&) { return 1; };

  std::unordered_map<string, GraphDef> partitions;
  status = Partition(popts, &g, &partitions);
  // Partitioning should fail, but not crash like it did before the
  // changes that accompanied the addition of this test.
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code()) << status;
}

TEST_F(GraphPartitionTest, Functions) {
  FunctionDefLibrary fdef_lib;
  *fdef_lib.add_function() = test::function::XTimesTwo();
  *fdef_lib.add_function() = test::function::XTimesFour();
  TF_ASSERT_OK(in_.graph()->AddFunctionLibrary(fdef_lib));

  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
  auto a1 = FloatInput(in_.WithOpName("A1"));
  auto b1 = FloatInput(in_.WithOpName("B1"));
  ConstructOp(in_.WithOpName("A2"), "XTimesTwo", {a1});
  ConstructOp(in_.WithOpName("B2"), "XTimesFour", {b1});

  // The `Partition()` helper function uses the first letter of the op name
  // ('A' or 'B') to choose a device for each node.
  Partition(ToGraphDef(), &partitions_);
  EXPECT_EQ(2, partitions_.size());

  // Test that partition graphs inherit the function library from the
  // original graph.
  string a = "/job:a/replica:0/task:0/cpu:0";
  string b = "/job:a/replica:0/task:0/cpu:1";

  // Node "A2" is placed in part `a`, and uses only "XTimesTwo".
  ExpectFunctions(partitions_[a].library(), {"XTimesTwo"});
  // Node "B2" is placed in part `b`; it uses "XTimesFour" directly, and
  // "XTimesTwo" indirectly in the body of "XTimesFour".
  ExpectFunctions(partitions_[b].library(), {"XTimesTwo", "XTimesFour"});
}

TEST_F(GraphPartitionTest, SetIncarnation) {
  GraphDef gdef;
  const char* const kSendRecvAttrs = R"proto(
    attr { key: 'T' value { type: DT_FLOAT } }
    attr { key: 'client_terminated' value { b: false } }
    attr { key: 'recv_device' value { s: 'B' } }
    attr { key: 'send_device' value { s: 'A' } }
    attr { key: 'send_device_incarnation' value { i: 0 } }
    attr { key: 'tensor_name' value { s: 'test' } }
  )proto";
  CHECK(protobuf::TextFormat::ParseFromString(
      strings::StrCat(
          "node { name: 'A/Pi' op: 'Const' ",
          "  attr { key: 'dtype' value { type: DT_FLOAT } } ",
          "  attr { key: 'value' value { tensor { ",
          "    dtype: DT_FLOAT tensor_shape {} float_val: 3.14 } } } }",
          "node { name: 'A' op: '_Send' input: 'A/Pi' ", kSendRecvAttrs, "}",
          "node { name: 'B' op: '_Recv' ", kSendRecvAttrs,
          "  attr { key: 'tensor_type' value { type:DT_FLOAT}}}"),
      &gdef));
  gdef.mutable_versions()->set_producer(TF_GRAPH_DEF_VERSION);
  Partition(gdef, &partitions_);
  EXPECT_EQ(2, partitions_.size());

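  // get_incarnation (set in Partition() above) maps the literal send_device
  // "A" to ('A' - 'A') + 100 == 100, so partitioning must overwrite the
  // stale incarnation of 0 on both the _Send and the _Recv.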
  for (const auto& kv : partitions_) {
    const GraphDef& gdef = kv.second;
    for (const NodeDef& ndef : gdef.node()) {
      if (ndef.name() == "A" || ndef.name() == "B") {
        int64 val;
        TF_CHECK_OK(GetNodeAttr(ndef, "send_device_incarnation", &val));
        EXPECT_EQ(val, 100);  // Send device is "A".
      }
    }
  }
}

TEST(TopologicalSortNodesWithTimePriorityTest, NoDependencies) {
  // Create placeholders in an order that does not match their names:
  // (i + 2001) % 20 == (i + 1) % 20, which rotates the indices to
  // 1, 2, ..., 19, 0.
  Scope root = Scope::NewRootScope().ExitOnError();
  std::vector<int> indexes;
  for (int i = 0; i < 20; ++i) {
    indexes.push_back((i + 2001) % 20);
  }
  std::vector<ops::Placeholder> placeholders;
  for (int i : indexes) {
    placeholders.emplace_back(root.WithOpName(strings::StrCat("p", i)),
                              DT_FLOAT);
    placeholders.back().node()->AddAttr("_start_time", i + 1);
  }

  GraphDef gdef;
  TF_EXPECT_OK(root.ToGraphDef(&gdef));

  std::vector<std::pair<const NodeDef*, int64>> nodes;
  std::unordered_map<const NodeDef*, int64> node_to_start_time;
  TF_CHECK_OK(
      TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time));
  ASSERT_EQ(nodes.size(), 20);
  // With no dependencies, the sort falls back to pure start-time order.
  for (int i = 0; i < nodes.size(); ++i) {
    EXPECT_EQ(strings::StrCat("p", i), nodes[i].first->name());
    EXPECT_EQ(i + 1, nodes[i].second);
  }
}

TEST(TopologicalSortNodesWithTimePriority, Dependencies) {
  // Create placeholders in an order that does not match their names (same
  // rotation as above).
  Scope root = Scope::NewRootScope().ExitOnError();
  std::vector<int> indexes;
  std::vector<ops::Placeholder> placeholders_in_order;
  const int num_leaves = 20;
  for (int i = 0; i < num_leaves; ++i) {
    indexes.push_back((i + 2001) % num_leaves);
    placeholders_in_order.emplace_back(root.WithOpName(strings::StrCat("p", i)),
                                       DT_FLOAT);
    placeholders_in_order.back().node()->AddAttr("_start_time", i + 1);
  }
  std::vector<ops::Placeholder> placeholders;
  for (int i : indexes) {
    placeholders.push_back(placeholders_in_order[i]);
  }

  // Create ops that depend on the placeholders. These get start times in
  // descending order (e.g., the op that depends on the first placeholder
  // runs last).
  std::vector<ops::Square> squares;
  for (int i : indexes) {
    squares.emplace_back(root.WithOpName(strings::StrCat("s", i)),
                         placeholders[i]);
    squares.back().node()->AddAttr("_start_time", 50 - (i + 1));
  }

  // Create an AddN that sums all the squares.
  std::vector<Input> inputs;
  for (const auto& s : squares) inputs.push_back(s);
  ops::AddN addn = ops::AddN(root.WithOpName("addn"),
                             tensorflow::gtl::ArraySlice<Input>(inputs));
  // The addn's start time (1) is earlier than those of the nodes it depends
  // on, but dependency ordering still places it last in the list, with an
  // adjusted start time.
  addn.node()->AddAttr("_start_time", 1);

  GraphDef gdef;
  TF_EXPECT_OK(root.ToGraphDef(&gdef));

  std::vector<std::pair<const NodeDef*, int64>> nodes;
  std::unordered_map<const NodeDef*, int64> node_to_start_time;
  TF_CHECK_OK(
      TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time));
  ASSERT_EQ(1 + squares.size() + placeholders.size(), nodes.size());
  for (int i = 0; i < placeholders.size(); ++i) {
    const NodeDef* node = nodes[i].first;
    EXPECT_EQ(strings::StrCat("p", i), node->name());
    EXPECT_EQ(i + 1, nodes[i].second);
    EXPECT_EQ(i + 1, node_to_start_time[node]);
  }
  for (int i = 0; i < squares.size(); ++i) {
    int node_index = placeholders.size() + i;
    int square_index = num_leaves - 1 - i;
    const NodeDef* node = nodes[node_index].first;
    EXPECT_EQ(strings::StrCat("s", square_index), node->name());
    EXPECT_EQ(50 - (square_index + 1), nodes[node_index].second);
    EXPECT_EQ(50 - (square_index + 1), node_to_start_time[node]);
  }
  // The addn's start time is bumped past those of its inputs (max 49), so it
  // sorts last with start time 50.
  EXPECT_EQ("addn", nodes.back().first->name());
  EXPECT_EQ(50, nodes.back().second);
  EXPECT_EQ(50, node_to_start_time[nodes.back().first]);
}

TEST(TopologicalSortNodesWithTimePriority, WhileLoop) {
  using namespace ::tensorflow::ops;            // NOLINT(build/namespaces)
  using namespace ::tensorflow::ops::internal;  // NOLINT(build/namespaces)

  // Create placeholders.
  Scope root = Scope::NewRootScope().ExitOnError();
  std::vector<int> indexes;
  std::vector<Placeholder> placeholders_in_order;
  const int num_leaves = 20;
  for (int i = 0; i < num_leaves; ++i) {
    indexes.push_back((i + 2001) % num_leaves);
    placeholders_in_order.emplace_back(root.WithOpName(strings::StrCat("p", i)),
                                       DT_FLOAT);
    placeholders_in_order.back().node()->AddAttr("_start_time", i + 1);
  }
  std::vector<Placeholder> placeholders;
  placeholders.reserve(indexes.size());
  for (int i : indexes) {
    placeholders.push_back(placeholders_in_order[i]);
  }

  // Add a while loop above each placeholder.
  std::vector<Exit> while_exits;
  const int nodes_per_loop = 8;
  for (int i : indexes) {
    Scope scope = root.NewSubScope(strings::StrCat("while", i));
    auto dummy = Placeholder(scope, DT_FLOAT);

    Enter enter(scope, placeholders[i], strings::StrCat("frame", i));
    Merge merge(scope, std::initializer_list<Input>{enter, dummy});
    auto cv = Const(scope.WithControlDependencies({merge.output}), false);
    LoopCond loop_cond(scope, cv);
    Switch switch_node(scope, merge.output, loop_cond);
    Identity identity(scope, switch_node.output_true);
    NextIteration next_iteration(scope, identity);
    while_exits.emplace_back(scope.WithOpName("exit"),
                             switch_node.output_false);

    // Complete the loop by removing the dummy node and attaching
    // NextIteration to that input of the merge node.
    scope.graph()->RemoveNode(dummy.node());
    scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);

    int base_start_time = i * 10 + 100;
    for (const auto& op : std::initializer_list<Output>{
             enter, merge.output, cv, loop_cond, switch_node.output_false,
             identity, next_iteration, while_exits.back()}) {
      op.node()->AddAttr("_start_time", base_start_time++);
    }
  }

  // Create ops that depend on the loop exits.
  std::vector<Square> squares;
  squares.reserve(indexes.size());
  for (int i : indexes) {
    squares.emplace_back(root.WithOpName(strings::StrCat("s", i)),
                         while_exits[i]);
    squares.back().node()->AddAttr("_start_time", 500 - (i + 1));
  }

  GraphDef gdef;
  TF_EXPECT_OK(root.ToGraphDef(&gdef));

  // Run the sort. The while-loop nodes also appear in the output <nodes>;
  // each loop contributes nodes_per_loop entries, led by its Enter node.
  std::vector<std::pair<const NodeDef*, int64>> nodes;
  std::unordered_map<const NodeDef*, int64> node_to_start_time;
  TF_CHECK_OK(
      TopologicalSortNodesWithTimePriority(&gdef, &nodes, &node_to_start_time));
  ASSERT_LT(while_exits.size() + squares.size() + placeholders.size(),
            nodes.size());
  int node_index = 0;
  for (int i = 0; i < placeholders.size(); ++i, ++node_index) {
    const NodeDef* node = nodes[i].first;
    EXPECT_EQ(strings::StrCat("p", i), node->name());
    EXPECT_EQ(i + 1, nodes[i].second);
    EXPECT_EQ(i + 1, node_to_start_time[node]);
  }
  for (int i = 0; i < while_exits.size(); ++i, node_index += nodes_per_loop) {
    const NodeDef* node = nodes[node_index].first;
    EXPECT_EQ(strings::StrCat("while", i, "/Enter"), node->name());
    EXPECT_EQ(100 + i * 10, nodes[node_index].second);
    EXPECT_EQ(100 + i * 10, node_to_start_time[node]);
  }
  for (int i = 0; i < squares.size(); ++i, ++node_index) {
    int square_index = num_leaves - 1 - i;
    const NodeDef* node = nodes[node_index].first;
    EXPECT_EQ(strings::StrCat("s", square_index), node->name());
    EXPECT_EQ(500 - (square_index + 1), nodes[node_index].second);
    EXPECT_EQ(500 - (square_index + 1), node_to_start_time[node]);
  }
}

}  // namespace
}  // namespace tensorflow