1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
17 #include "tensorflow/cc/ops/standard_ops.h"
18 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
19 #include "tensorflow/core/framework/tensor_description.pb.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
22 #include "tensorflow/core/grappler/costs/utils.h"
23 #include "tensorflow/core/grappler/costs/virtual_placer.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace tensorflow {
28 namespace grappler {
29
30 // Class for testing virtual scheduler.
31 class TestVirtualScheduler : public VirtualScheduler {
32 public:
33 TestVirtualScheduler(const bool use_static_shapes,
34 const bool use_aggressive_shape_inference,
35 Cluster* cluster)
36 : VirtualScheduler(use_static_shapes, use_aggressive_shape_inference,
37 cluster, &ready_node_manager_) {
38 enable_mem_usage_tracking();
39 }
40
41 FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
42 FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
43 FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
44 FRIEND_TEST(VirtualSchedulerTest, Variable);
45 FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
46
47 protected:
48 FirstReadyManager ready_node_manager_;
49 };
50
51 class VirtualSchedulerTest : public ::testing::Test {
52 protected:
53 VirtualSchedulerTest() {
54 // node1_ to node6_ on kCPU0, with time_ready in reverse order.
55 NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_);
56 NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_);
57 NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_);
58 NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_);
59 NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_);
60 NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_);
61
62 // Initializes cluster_ and scheduler_.
63 std::unordered_map<string, DeviceProperties> devices;
64
65 // Set some dummy CPU properties
66 DeviceProperties cpu_device = GetDummyCPUDevice();
67
68 // IMPORTANT: The device properties are never actually used in this test case
69 // since force_cpu_type defaults to "Haswell".
70 devices[kCPU0] = cpu_device;
71 devices[kCPU1] = cpu_device;
72 cluster_ = absl::make_unique<VirtualCluster>(devices);
73 scheduler_ = absl::make_unique<TestVirtualScheduler>(
74 /*use_static_shapes=*/true,
75 /*use_aggressive_shape_inference=*/true, cluster_.get());
76 }
77
78 NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
79 std::unordered_map<const NodeDef*, NodeState> node_states_;
80
81 // Device names:
82 const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
83 const string kCPU1 = "/job:localhost/replica:0/task:0/cpu:1";
84 const string kChannelFrom0To1 = "Channel from CPU0 to CPU1";
85 const string kChannelFrom1To0 = "Channel from CPU1 to CPU0";
86 // Op names:
87 const string kSend = "_Send";
88 const string kRecv = "_Recv";
89 const string kConv2D = "Conv2D";
90
91 DeviceProperties GetDummyCPUDevice() {
92 // Create a CPU device with 2 cores, 4 GHz frequency, and 2 GB/s memory bandwidth:
93 // - 8 GFLOPS
94 // - 2 GB/s
95 DeviceProperties cpu_device;
96 cpu_device.set_type("CPU");
97 cpu_device.set_frequency(4000);
98 cpu_device.set_num_cores(2);
99 cpu_device.set_bandwidth(2000000);
100 return cpu_device;
101 }
102
103 void NodeSetUp(const string& name, const string& op_name,
104 const string& device_name, const uint64 time_ready,
105 NodeDef* node) {
106 node->set_name(name);
107 node->set_op(op_name);
108 node->set_device(device_name);
109
110 node_states_[node] = NodeState();
111 node_states_[node].time_ready = time_ready;
112 node_states_[node].device_name = device_name;
113 }
114
115 // Three Conv2Ds, with only two of them in the fetch list.
116 void CreateGrapplerItemWithConv2Ds() {
117 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
118 auto x = ops::RandomUniform(
119 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
120 auto y = ops::RandomUniform(
121 s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
122 auto z = ops::RandomUniform(
123 s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
124 auto f = ops::RandomUniform(
125 s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
126 std::vector<int> strides = {1, 1, 1, 1};
127 auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
128 auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
129 auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
130 GraphDef def;
131 TF_CHECK_OK(s.ToGraphDef(&def));
132
133 grappler_item_.reset(new GrapplerItem);
134 grappler_item_->id = "test_conv2d_graph";
135 grappler_item_->graph = def;
136 grappler_item_->fetch = {"c0", "c1"};
137
138 dependency_["c0"] = {"x", "f"};
139 dependency_["c1"] = {"y", "f"};
140 }
141
142 // A Conv2D with a variable.
143 void CreateGrapplerItemWithConv2DAndVariable() {
144 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
145 auto x = ops::RandomUniform(
146 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
147 auto f = ops::Variable(s.WithOpName("f"),
148 {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
149 std::vector<int> strides = {1, 1, 1, 1};
150 auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
151 GraphDef def;
152 TF_CHECK_OK(s.ToGraphDef(&def));
153
154 grappler_item_.reset(new GrapplerItem);
155 grappler_item_->id = "test_conv2d_var_graph";
156 grappler_item_->graph = def;
157 grappler_item_->fetch = {"y"};
158
159 dependency_["y"] = {"x", "f"};
160 }
161
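// A chain of four MatMuls (ab, abc, abcd, abcde) over 3200x3200 RandomUniform tensors.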
162 void CreateGrapplerItemWithMatmulChain() {
163 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
164 // Add control dependencies so the test does not depend on a specific
165 // ready-node manager and the scheduling order remains consistent.
166 auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT);
167 auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a),
168 {3200, 3200}, DT_FLOAT);
169 auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b),
170 {3200, 3200}, DT_FLOAT);
171 auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c),
172 {3200, 3200}, DT_FLOAT);
173 auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d),
174 {3200, 3200}, DT_FLOAT);
175
176 auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b);
177 auto abc = ops::MatMul(s.WithOpName("abc"), ab, c);
178 auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d);
179 auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e);
180
181 GraphDef def;
182 TF_CHECK_OK(s.ToGraphDef(&def));
183
184 grappler_item_.reset(new GrapplerItem);
185 grappler_item_->id = "test_matmul_sequence_graph";
186 grappler_item_->graph = def;
187 grappler_item_->fetch = {"abcde"};
188
189 dependency_["ab"] = {"a", "b"};
190 dependency_["abc"] = {"ab", "c"};
191 dependency_["abcd"] = {"abc", "d"};
192 dependency_["abcde"] = {"abcd", "e"};
193 }
194
195 // AddN that takes four 10x10x10x10 tensors.
196 void CreateGrapplerItemWithAddN() {
197 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
198 auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT);
199 auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT);
200 auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT);
201 auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT);
202 OutputList input_tensors = {x, y, z, w};
203 auto out = ops::AddN(s.WithOpName("out"), input_tensors);
204 GraphDef def;
205 TF_CHECK_OK(s.ToGraphDef(&def));
206
207 grappler_item_.reset(new GrapplerItem);
208 grappler_item_->id = "test_addn_graph";
209 grappler_item_->graph = def;
210 grappler_item_->fetch = {"out"};
211
212 dependency_["out"] = {"x", "y", "z", "w"};
213 }
214
215 // Graph with some placeholder feed nodes that are not in the fetch fan-in.
216 void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() {
217 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
218 auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT);
219 auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT);
220
221 GraphDef def;
222 TF_CHECK_OK(s.ToGraphDef(&def));
223
224 grappler_item_.reset(new GrapplerItem);
225 grappler_item_->id = "test_extra_placeholders";
226 grappler_item_->graph = def;
227 grappler_item_->fetch = {"x"};
228
229 // Grappler Item Builder puts all placeholder nodes into the feed
230 // list by default.
231 grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}};
232 }
233
234 // NoOp that takes 7 NoOps as control dependencies.
235 void CreateGrapplerItemWithControlDependency() {
236 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
237 std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
238 std::vector<Operation> input_tensors;
239 for (const auto& input : input_noop_names) {
240 auto x = ops::NoOp(s.WithOpName(input));
241 input_tensors.push_back(x.operation);
242 }
243 auto out =
244 ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out"));
245 GraphDef def;
246 TF_CHECK_OK(s.ToGraphDef(&def));
247
248 grappler_item_.reset(new GrapplerItem);
249 grappler_item_->id = "test_control_dependency_graph";
250 grappler_item_->graph = def;
251 grappler_item_->fetch = {"out"};
252
253 dependency_["out"] = input_noop_names;
254 }
255
256 // FusedBatchNorm (an op with multiple outputs) with multiple consumers
257 // (including a control dependency).
258 void CreateGrapplerItemWithBatchNorm() {
259 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
260 auto x = ops::RandomUniform(
261 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
262 auto scale =
263 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
264 auto offset =
265 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
266 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
267 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
268
269 auto batch_norm = ops::FusedBatchNorm(
270 s.WithOpName("bn"), x, scale, offset, mean, var,
271 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
272 auto y = batch_norm.y;
273 auto batch_mean = batch_norm.batch_mean;
274 auto batch_var = batch_norm.batch_variance;
275
276 auto z1 = ops::Add(s.WithOpName("z1"), x, y);
277 auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var);
278 auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var);
279 std::vector<Operation> input_tensors = {
280 batch_mean.op(),
281 z1.z.op(),
282 z2.z.op(),
283 z3.z.op(),
284 };
285 auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4"));
286
287 GraphDef def;
288 TF_CHECK_OK(s.ToGraphDef(&def));
289
290 grappler_item_.reset(new GrapplerItem);
291 grappler_item_->id = "test_complex_dependency_graph";
292 grappler_item_->graph = def;
293 grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
294
295 dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
296 dependency_["z1"] = {"x", "bn"};
297 dependency_["z2"] = {"bn"};
298 dependency_["z3"] = {"bn"};
299 dependency_["z4"] = {"bn"};
300 }
301
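// Const -> _Send -> _Recv on a single device, defined directly as a GraphDef in text format.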
302 void CreateGrapplerItemWithSendRecv() {
303 const string gdef_ascii = R"EOF(
304 node {
305 name: "Const"
306 op: "Const"
307 device: "/job:localhost/replica:0/task:0/device:CPU:0"
308 attr {
309 key: "dtype"
310 value {
311 type: DT_FLOAT
312 }
313 }
314 attr {
315 key: "value"
316 value {
317 tensor {
318 dtype: DT_FLOAT
319 tensor_shape {
320 }
321 float_val: 3.1415
322 }
323 }
324 }
325 }
326 node {
327 name: "Send"
328 op: "_Send"
329 input: "Const"
330 device: "/job:localhost/replica:0/task:0/device:CPU:0"
331 attr {
332 key: "T"
333 value {
334 type: DT_FLOAT
335 }
336 }
337 attr {
338 key: "client_terminated"
339 value {
340 b: false
341 }
342 }
343 attr {
344 key: "recv_device"
345 value {
346 s: "/job:localhost/replica:0/task:0/device:CPU:0"
347 }
348 }
349 attr {
350 key: "send_device"
351 value {
352 s: "/job:localhost/replica:0/task:0/device:CPU:0"
353 }
354 }
355 attr {
356 key: "send_device_incarnation"
357 value {
358 i: 0
359 }
360 }
361 attr {
362 key: "tensor_name"
363 value {
364 s: "test"
365 }
366 }
367 }
368 node {
369 name: "Recv"
370 op: "_Recv"
371 device: "/job:localhost/replica:0/task:0/device:CPU:0"
372 attr {
373 key: "client_terminated"
374 value {
375 b: false
376 }
377 }
378 attr {
379 key: "recv_device"
380 value {
381 s: "/job:localhost/replica:0/task:0/device:CPU:0"
382 }
383 }
384 attr {
385 key: "send_device"
386 value {
387 s: "/job:localhost/replica:0/task:0/device:CPU:0"
388 }
389 }
390 attr {
391 key: "send_device_incarnation"
392 value {
393 i: 0
394 }
395 }
396 attr {
397 key: "tensor_name"
398 value {
399 s: "test"
400 }
401 }
402 attr {
403 key: "tensor_type"
404 value {
405 type: DT_FLOAT
406 }
407 }
408 }
409 library {
410 }
411 versions {
412 producer: 24
413 }
414 )EOF";
415
416 grappler_item_.reset(new GrapplerItem);
417 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
418 &grappler_item_->graph));
419 grappler_item_->id = "test_graph";
420 grappler_item_->fetch = {"Recv"};
421 }
422
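// A _Recv node with no matching _Send.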
423 void CreateGrapplerItemWithRecvWithoutSend() {
424 const string gdef_ascii = R"EOF(
425 node {
426 name: "Recv"
427 op: "_Recv"
428 device: "/job:localhost/replica:0/task:0/device:CPU:0"
429 attr {
430 key: "client_terminated"
431 value {
432 b: false
433 }
434 }
435 attr {
436 key: "recv_device"
437 value {
438 s: "/job:localhost/replica:0/task:0/device:CPU:0"
439 }
440 }
441 attr {
442 key: "send_device"
443 value {
444 s: "/job:localhost/replica:0/task:0/device:CPU:0"
445 }
446 }
447 attr {
448 key: "send_device_incarnation"
449 value {
450 i: 0
451 }
452 }
453 attr {
454 key: "tensor_name"
455 value {
456 s: "test"
457 }
458 }
459 attr {
460 key: "tensor_type"
461 value {
462 type: DT_FLOAT
463 }
464 }
465 }
466 library {
467 }
468 versions {
469 producer: 24
470 }
471 )EOF";
472
473 grappler_item_.reset(new GrapplerItem);
474 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
475 &grappler_item_->graph));
476 grappler_item_->id = "test_graph";
477 grappler_item_->fetch = {"Recv"};
478 }
479
480 // A simple while loop.
481 void CreateGrapplerItemWithLoop() {
482 // Test graph produced in python using:
483 /*
484 with tf.Graph().as_default():
485 i0 = tf.constant(0)
486 m0 = tf.ones([2, 2])
487 c = lambda i, m: i < 10
488 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
489 r = tf.while_loop(
490 c, b, loop_vars=[i0, m0],
491 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
492 with open('/tmp/graph.pbtxt', 'w') as f:
493 f.write(str(tf.get_default_graph().as_graph_def()))
494 */
495 const string gdef_ascii = R"EOF(
496 node {
497 name: "Const"
498 op: "Const"
499 attr {
500 key: "dtype"
501 value {
502 type: DT_INT32
503 }
504 }
505 attr {
506 key: "value"
507 value {
508 tensor {
509 dtype: DT_INT32
510 tensor_shape {
511 }
512 int_val: 0
513 }
514 }
515 }
516 }
517 node {
518 name: "ones"
519 op: "Const"
520 attr {
521 key: "dtype"
522 value {
523 type: DT_FLOAT
524 }
525 }
526 attr {
527 key: "value"
528 value {
529 tensor {
530 dtype: DT_FLOAT
531 tensor_shape {
532 dim {
533 size: 2
534 }
535 dim {
536 size: 2
537 }
538 }
539 float_val: 1.0
540 }
541 }
542 }
543 }
544 node {
545 name: "while/Enter"
546 op: "Enter"
547 input: "Const"
548 attr {
549 key: "T"
550 value {
551 type: DT_INT32
552 }
553 }
554 attr {
555 key: "frame_name"
556 value {
557 s: "while/while/"
558 }
559 }
560 attr {
561 key: "is_constant"
562 value {
563 b: false
564 }
565 }
566 attr {
567 key: "parallel_iterations"
568 value {
569 i: 10
570 }
571 }
572 }
573 node {
574 name: "while/Enter_1"
575 op: "Enter"
576 input: "ones"
577 attr {
578 key: "T"
579 value {
580 type: DT_FLOAT
581 }
582 }
583 attr {
584 key: "frame_name"
585 value {
586 s: "while/while/"
587 }
588 }
589 attr {
590 key: "is_constant"
591 value {
592 b: false
593 }
594 }
595 attr {
596 key: "parallel_iterations"
597 value {
598 i: 10
599 }
600 }
601 }
602 node {
603 name: "while/Merge"
604 op: "Merge"
605 input: "while/Enter"
606 input: "while/NextIteration"
607 attr {
608 key: "N"
609 value {
610 i: 2
611 }
612 }
613 attr {
614 key: "T"
615 value {
616 type: DT_INT32
617 }
618 }
619 }
620 node {
621 name: "while/Merge_1"
622 op: "Merge"
623 input: "while/Enter_1"
624 input: "while/NextIteration_1"
625 attr {
626 key: "N"
627 value {
628 i: 2
629 }
630 }
631 attr {
632 key: "T"
633 value {
634 type: DT_FLOAT
635 }
636 }
637 }
638 node {
639 name: "while/Less/y"
640 op: "Const"
641 input: "^while/Merge"
642 attr {
643 key: "dtype"
644 value {
645 type: DT_INT32
646 }
647 }
648 attr {
649 key: "value"
650 value {
651 tensor {
652 dtype: DT_INT32
653 tensor_shape {
654 }
655 int_val: 10
656 }
657 }
658 }
659 }
660 node {
661 name: "while/Less"
662 op: "Less"
663 input: "while/Merge"
664 input: "while/Less/y"
665 attr {
666 key: "T"
667 value {
668 type: DT_INT32
669 }
670 }
671 }
672 node {
673 name: "while/LoopCond"
674 op: "LoopCond"
675 input: "while/Less"
676 }
677 node {
678 name: "while/Switch"
679 op: "Switch"
680 input: "while/Merge"
681 input: "while/LoopCond"
682 attr {
683 key: "T"
684 value {
685 type: DT_INT32
686 }
687 }
688 attr {
689 key: "_class"
690 value {
691 list {
692 s: "loc:@while/Merge"
693 }
694 }
695 }
696 }
697 node {
698 name: "while/Switch_1"
699 op: "Switch"
700 input: "while/Merge_1"
701 input: "while/LoopCond"
702 attr {
703 key: "T"
704 value {
705 type: DT_FLOAT
706 }
707 }
708 attr {
709 key: "_class"
710 value {
711 list {
712 s: "loc:@while/Merge_1"
713 }
714 }
715 }
716 }
717 node {
718 name: "while/Identity"
719 op: "Identity"
720 input: "while/Switch:1"
721 attr {
722 key: "T"
723 value {
724 type: DT_INT32
725 }
726 }
727 }
728 node {
729 name: "while/Identity_1"
730 op: "Identity"
731 input: "while/Switch_1:1"
732 attr {
733 key: "T"
734 value {
735 type: DT_FLOAT
736 }
737 }
738 }
739 node {
740 name: "while/add/y"
741 op: "Const"
742 input: "^while/Identity"
743 attr {
744 key: "dtype"
745 value {
746 type: DT_INT32
747 }
748 }
749 attr {
750 key: "value"
751 value {
752 tensor {
753 dtype: DT_INT32
754 tensor_shape {
755 }
756 int_val: 1
757 }
758 }
759 }
760 }
761 node {
762 name: "while/add"
763 op: "Add"
764 input: "while/Identity"
765 input: "while/add/y"
766 attr {
767 key: "T"
768 value {
769 type: DT_INT32
770 }
771 }
772 }
773 node {
774 name: "while/concat/axis"
775 op: "Const"
776 input: "^while/Identity"
777 attr {
778 key: "dtype"
779 value {
780 type: DT_INT32
781 }
782 }
783 attr {
784 key: "value"
785 value {
786 tensor {
787 dtype: DT_INT32
788 tensor_shape {
789 }
790 int_val: 0
791 }
792 }
793 }
794 }
795 node {
796 name: "while/concat"
797 op: "ConcatV2"
798 input: "while/Identity_1"
799 input: "while/Identity_1"
800 input: "while/concat/axis"
801 attr {
802 key: "N"
803 value {
804 i: 2
805 }
806 }
807 attr {
808 key: "T"
809 value {
810 type: DT_FLOAT
811 }
812 }
813 attr {
814 key: "Tidx"
815 value {
816 type: DT_INT32
817 }
818 }
819 }
820 node {
821 name: "while/NextIteration"
822 op: "NextIteration"
823 input: "while/add"
824 attr {
825 key: "T"
826 value {
827 type: DT_INT32
828 }
829 }
830 }
831 node {
832 name: "while/NextIteration_1"
833 op: "NextIteration"
834 input: "while/concat"
835 attr {
836 key: "T"
837 value {
838 type: DT_FLOAT
839 }
840 }
841 }
842 node {
843 name: "while/Exit"
844 op: "Exit"
845 input: "while/Switch"
846 attr {
847 key: "T"
848 value {
849 type: DT_INT32
850 }
851 }
852 }
853 node {
854 name: "while/Exit_1"
855 op: "Exit"
856 input: "while/Switch_1"
857 attr {
858 key: "T"
859 value {
860 type: DT_FLOAT
861 }
862 }
863 }
864 versions {
865 producer: 21
866 }
867 )EOF";
868
869 grappler_item_.reset(new GrapplerItem);
870 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
871 &grappler_item_->graph));
872 grappler_item_->id = "test_graph";
873 grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
874 }
875
876 // The same simple while loop, annotated with _execution_count attrs and
876 // Switch _output_slot_vector attrs.
877 void CreateGrapplerItemWithLoopAnnotated() {
878 // Test graph produced in python using:
879 /*
880 with tf.Graph().as_default():
881 i0 = tf.constant(0)
882 m0 = tf.ones([2, 2])
883 c = lambda i, m: i < 10
884 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
885 r = tf.while_loop(
886 c, b, loop_vars=[i0, m0],
887 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
888 with open('/tmp/graph.pbtxt', 'w') as f:
889 f.write(str(tf.get_default_graph().as_graph_def()))
890 */
891 const string gdef_ascii = R"EOF(
892 node {
893 name: "Const"
894 op: "Const"
895 attr {
896 key: "dtype"
897 value {
898 type: DT_INT32
899 }
900 }
901 attr {
902 key: "value"
903 value {
904 tensor {
905 dtype: DT_INT32
906 tensor_shape {
907 }
908 int_val: 0
909 }
910 }
911 }
912 attr {
913 key: "_execution_count"
914 value {
915 i: 1
916 }
917 }
918 }
919 node {
920 name: "ones"
921 op: "Const"
922 attr {
923 key: "dtype"
924 value {
925 type: DT_FLOAT
926 }
927 }
928 attr {
929 key: "value"
930 value {
931 tensor {
932 dtype: DT_FLOAT
933 tensor_shape {
934 dim {
935 size: 2
936 }
937 dim {
938 size: 2
939 }
940 }
941 float_val: 1.0
942 }
943 }
944 }
945 attr {
946 key: "_execution_count"
947 value {
948 i: 1
949 }
950 }
951 }
952 node {
953 name: "while/Enter"
954 op: "Enter"
955 input: "Const"
956 attr {
957 key: "T"
958 value {
959 type: DT_INT32
960 }
961 }
962 attr {
963 key: "frame_name"
964 value {
965 s: "while/while/"
966 }
967 }
968 attr {
969 key: "is_constant"
970 value {
971 b: false
972 }
973 }
974 attr {
975 key: "parallel_iterations"
976 value {
977 i: 10
978 }
979 }
980 attr {
981 key: "_execution_count"
982 value {
983 i: 1
984 }
985 }
986 }
987 node {
988 name: "while/Enter_1"
989 op: "Enter"
990 input: "ones"
991 attr {
992 key: "T"
993 value {
994 type: DT_FLOAT
995 }
996 }
997 attr {
998 key: "frame_name"
999 value {
1000 s: "while/while/"
1001 }
1002 }
1003 attr {
1004 key: "is_constant"
1005 value {
1006 b: false
1007 }
1008 }
1009 attr {
1010 key: "parallel_iterations"
1011 value {
1012 i: 10
1013 }
1014 }
1015 attr {
1016 key: "_execution_count"
1017 value {
1018 i: 1
1019 }
1020 }
1021 }
1022 node {
1023 name: "while/Merge"
1024 op: "Merge"
1025 input: "while/Enter"
1026 input: "while/NextIteration"
1027 attr {
1028 key: "N"
1029 value {
1030 i: 2
1031 }
1032 }
1033 attr {
1034 key: "T"
1035 value {
1036 type: DT_INT32
1037 }
1038 }
1039 attr {
1040 key: "_execution_count"
1041 value {
1042 i: 10
1043 }
1044 }
1045 }
1046 node {
1047 name: "while/Merge_1"
1048 op: "Merge"
1049 input: "while/Enter_1"
1050 input: "while/NextIteration_1"
1051 attr {
1052 key: "N"
1053 value {
1054 i: 2
1055 }
1056 }
1057 attr {
1058 key: "T"
1059 value {
1060 type: DT_FLOAT
1061 }
1062 }
1063 attr {
1064 key: "_execution_count"
1065 value {
1066 i: 10
1067 }
1068 }
1069 }
1070 node {
1071 name: "while/Less/y"
1072 op: "Const"
1073 input: "^while/Merge"
1074 attr {
1075 key: "dtype"
1076 value {
1077 type: DT_INT32
1078 }
1079 }
1080 attr {
1081 key: "value"
1082 value {
1083 tensor {
1084 dtype: DT_INT32
1085 tensor_shape {
1086 }
1087 int_val: 10
1088 }
1089 }
1090 }
1091 attr {
1092 key: "_execution_count"
1093 value {
1094 i: 10
1095 }
1096 }
1097 }
1098 node {
1099 name: "while/Less"
1100 op: "Less"
1101 input: "while/Merge"
1102 input: "while/Less/y"
1103 attr {
1104 key: "T"
1105 value {
1106 type: DT_INT32
1107 }
1108 }
1109 attr {
1110 key: "_execution_count"
1111 value {
1112 i: 10
1113 }
1114 }
1115 }
1116 node {
1117 name: "while/LoopCond"
1118 op: "LoopCond"
1119 input: "while/Less"
1120 attr {
1121 key: "_execution_count"
1122 value {
1123 i: 10
1124 }
1125 }
1126 }
1127 node {
1128 name: "while/Switch"
1129 op: "Switch"
1130 input: "while/Merge"
1131 input: "while/LoopCond"
1132 attr {
1133 key: "T"
1134 value {
1135 type: DT_INT32
1136 }
1137 }
1138 attr {
1139 key: "_class"
1140 value {
1141 list {
1142 s: "loc:@while/Merge"
1143 }
1144 }
1145 }
1146 attr {
1147 key: "_execution_count"
1148 value {
1149 i: 11
1150 }
1151 }
1152 attr {
1153 key: "_output_slot_vector"
1154 value {
1155 list {
1156 i: 1
1157 i: 1
1158 i: 1
1159 i: 1
1160 i: 1
1161 i: 1
1162 i: 1
1163 i: 1
1164 i: 1
1165 i: 1
1166 i: 0
1167 }
1168 }
1169 }
1170 }
1171 node {
1172 name: "while/Switch_1"
1173 op: "Switch"
1174 input: "while/Merge_1"
1175 input: "while/LoopCond"
1176 attr {
1177 key: "T"
1178 value {
1179 type: DT_FLOAT
1180 }
1181 }
1182 attr {
1183 key: "_class"
1184 value {
1185 list {
1186 s: "loc:@while/Merge_1"
1187 }
1188 }
1189 }
1190 attr {
1191 key: "_execution_count"
1192 value {
1193 i: 11
1194 }
1195 }
1196 attr {
1197 key: "_output_slot_vector"
1198 value {
1199 list {
1200 i: 1
1201 i: 1
1202 i: 1
1203 i: 1
1204 i: 1
1205 i: 1
1206 i: 1
1207 i: 1
1208 i: 1
1209 i: 1
1210 i: 0
1211 }
1212 }
1213 }
1214 }
1215 node {
1216 name: "while/Identity"
1217 op: "Identity"
1218 input: "while/Switch:1"
1219 attr {
1220 key: "T"
1221 value {
1222 type: DT_INT32
1223 }
1224 }
1225 attr {
1226 key: "_execution_count"
1227 value {
1228 i: 10
1229 }
1230 }
1231 }
1232 node {
1233 name: "while/Identity_1"
1234 op: "Identity"
1235 input: "while/Switch_1:1"
1236 attr {
1237 key: "T"
1238 value {
1239 type: DT_FLOAT
1240 }
1241 }
1242 attr {
1243 key: "_execution_count"
1244 value {
1245 i: 10
1246 }
1247 }
1248 }
1249 node {
1250 name: "while/add/y"
1251 op: "Const"
1252 input: "^while/Identity"
1253 attr {
1254 key: "dtype"
1255 value {
1256 type: DT_INT32
1257 }
1258 }
1259 attr {
1260 key: "value"
1261 value {
1262 tensor {
1263 dtype: DT_INT32
1264 tensor_shape {
1265 }
1266 int_val: 1
1267 }
1268 }
1269 }
1270 attr {
1271 key: "_execution_count"
1272 value {
1273 i: 10
1274 }
1275 }
1276 }
1277 node {
1278 name: "while/add"
1279 op: "Add"
1280 input: "while/Identity"
1281 input: "while/add/y"
1282 attr {
1283 key: "T"
1284 value {
1285 type: DT_INT32
1286 }
1287 }
1288 attr {
1289 key: "_execution_count"
1290 value {
1291 i: 10
1292 }
1293 }
1294 }
1295 node {
1296 name: "while/concat/axis"
1297 op: "Const"
1298 input: "^while/Identity"
1299 attr {
1300 key: "dtype"
1301 value {
1302 type: DT_INT32
1303 }
1304 }
1305 attr {
1306 key: "value"
1307 value {
1308 tensor {
1309 dtype: DT_INT32
1310 tensor_shape {
1311 }
1312 int_val: 0
1313 }
1314 }
1315 }
1316 attr {
1317 key: "_execution_count"
1318 value {
1319 i: 10
1320 }
1321 }
1322 }
1323 node {
1324 name: "while/concat"
1325 op: "ConcatV2"
1326 input: "while/Identity_1"
1327 input: "while/Identity_1"
1328 input: "while/concat/axis"
1329 attr {
1330 key: "N"
1331 value {
1332 i: 2
1333 }
1334 }
1335 attr {
1336 key: "T"
1337 value {
1338 type: DT_FLOAT
1339 }
1340 }
1341 attr {
1342 key: "Tidx"
1343 value {
1344 type: DT_INT32
1345 }
1346 }
1347 attr {
1348 key: "_execution_count"
1349 value {
1350 i: 10
1351 }
1352 }
1353 }
1354 node {
1355 name: "while/NextIteration"
1356 op: "NextIteration"
1357 input: "while/add"
1358 attr {
1359 key: "T"
1360 value {
1361 type: DT_INT32
1362 }
1363 }
1364 attr {
1365 key: "_execution_count"
1366 value {
1367 i: 10
1368 }
1369 }
1370 }
1371 node {
1372 name: "while/NextIteration_1"
1373 op: "NextIteration"
1374 input: "while/concat"
1375 attr {
1376 key: "T"
1377 value {
1378 type: DT_FLOAT
1379 }
1380 }
1381 attr {
1382 key: "_execution_count"
1383 value {
1384 i: 10
1385 }
1386 }
1387 }
1388 node {
1389 name: "while/Exit"
1390 op: "Exit"
1391 input: "while/Switch"
1392 attr {
1393 key: "T"
1394 value {
1395 type: DT_INT32
1396 }
1397 }
1398 attr {
1399 key: "_execution_count"
1400 value {
1401 i: 1
1402 }
1403 }
1404 }
1405 node {
1406 name: "while/Exit_1"
1407 op: "Exit"
1408 input: "while/Switch_1"
1409 attr {
1410 key: "T"
1411 value {
1412 type: DT_FLOAT
1413 }
1414 }
1415 attr {
1416 key: "_execution_count"
1417 value {
1418 i: 1
1419 }
1420 }
1421 }
1422 versions {
1423 producer: 21
1424 }
1425 )EOF";
1426
1427 grappler_item_.reset(new GrapplerItem);
1428 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1429 &grappler_item_->graph));
1430 grappler_item_->id = "test_graph";
1431 grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
1432 }
1433
1434 // A simple condition graph.
1435 void CreateGrapplerItemWithCondition() {
1436 // Handcrafted test graph: a/Less -> Switch -> First/Second -> Merge.
1437 const string gdef_ascii = R"EOF(
1438 node {
1439 name: "a"
1440 op: "Const"
1441 attr {
1442 key: "dtype"
1443 value {
1444 type: DT_FLOAT
1445 }
1446 }
1447 attr {
1448 key: "value"
1449 value {
1450 tensor {
1451 dtype: DT_FLOAT
1452 tensor_shape {
1453 }
1454 float_val: 2.0
1455 }
1456 }
1457 }
1458 }
1459 node {
1460 name: "Less"
1461 op: "Const"
1462 attr {
1463 key: "dtype"
1464 value {
1465 type: DT_BOOL
1466 }
1467 }
1468 attr {
1469 key: "value"
1470 value {
1471 tensor {
1472 dtype: DT_BOOL
1473 tensor_shape {
1474 }
1475 tensor_content: "\001"
1476 }
1477 }
1478 }
1479 }
1480 node {
1481 name: "Switch"
1482 op: "Switch"
1483 input: "a"
1484 input: "Less"
1485 attr {
1486 key: "T"
1487 value {
1488 type: DT_FLOAT
1489 }
1490 }
1491 }
1492 node {
1493 name: "First"
1494 op: "Identity"
1495 input: "Switch"
1496 attr {
1497 key: "T"
1498 value {
1499 type: DT_FLOAT
1500 }
1501 }
1502 }
1503 node {
1504 name: "Second"
1505 op: "Identity"
1506 input: "Switch:1"
1507 attr {
1508 key: "T"
1509 value {
1510 type: DT_FLOAT
1511 }
1512 }
1513 }
1514 node {
1515 name: "Merge"
1516 op: "Merge"
1517 input: "First"
1518 input: "Second"
1519 attr {
1520 key: "N"
1521 value {
1522 i: 2
1523 }
1524 }
1525 attr {
1526 key: "T"
1527 value {
1528 type: DT_FLOAT
1529 }
1530 }
1531 }
1532 versions {
1533 producer: 27
1534 })EOF";
1535
1536 grappler_item_.reset(new GrapplerItem);
1537 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1538 &grappler_item_->graph));
1539 grappler_item_->id = "test_graph";
1540 grappler_item_->fetch = {"Merge"};
1541 }
1542
1543 // Graph with inter-device transfers: a FusedBatchNorm (an op with multiple
1543 // output ports) on kCPU0 whose outputs are consumed on kCPU1.
1544 void CreateGrapplerItemWithInterDeviceTransfers() {
1545 tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
1546
1547 // Create a FusedBatchNorm op that has multiple output ports.
1548 auto x = ops::RandomUniform(
1549 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
1550 auto scale =
1551 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
1552 auto offset =
1553 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
1554 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
1555 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
1556
1557 auto batch_norm = ops::FusedBatchNorm(
1558 s.WithOpName("bn"), x, scale, offset, mean, var,
1559 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
1560 auto y = batch_norm.y;
1561 auto batch_mean = batch_norm.batch_mean;
1562 auto batch_var = batch_norm.batch_variance;
1563 // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
1564 auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
1565 auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
1566 // batch_mean1 and batch_var1 take different output ports, so each will
1567 // initiate Send/Recv.
1568 auto batch_mean1 = ops::Identity(
1569 s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
1570 auto batch_var1 =
1571 ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
1572 // This is a control dependency.
1573 auto control_dep = ops::NoOp(s.WithOpName("control_dep")
1574 .WithControlDependencies(y)
1575 .WithDevice(kCPU1));
1576
1577 GraphDef def;
1578 TF_CHECK_OK(s.ToGraphDef(&def));
1579
1580 grappler_item_.reset(new GrapplerItem);
1581 grappler_item_->id = "test_conv2d_graph";
1582 grappler_item_->graph = def;
1583 grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
1584 "control_dep"};
1585
1586 dependency_["bn"] = {"x", "mean", "var"};
1587 dependency_["y1"] = {"bn"};
1588 dependency_["y2"] = {"bn"};
1589 dependency_["batch_mean1"] = {"bn"};
1590 dependency_["batch_var1"] = {"bn"};
1591 dependency_["control_dep"] = {"bn"};
1592 }
1593
1594 // Call this after creating grappler_item_ and setting up dependency_.
1595 void InitScheduler() { TF_ASSERT_OK(scheduler_->Init(grappler_item_.get())); }
1596
1597 // Returns a cost estimate based on the op type.
1598 Costs SimplePredictCosts(const OpContext& op_context) const {
1599 Costs c;
1600 int64 exec_cost = 0;
1601 if (op_context.op_info.op() == "MatMul") {
1602 exec_cost = 2000000000;
1603 } else if (op_context.op_info.op() == "RandomUniform") {
1604 exec_cost = 1000000000;
1605 } else {
1606 exec_cost = 1000;
1607 }
1608 c.execution_time = Costs::NanoSeconds(exec_cost);
1609 return c;
1610 }
1611
1612 // Call this after InitScheduler(). The scheduler stops after executing
1613 // target_node.
1614 std::unordered_map<string, OpContext> RunScheduler(
1615 const string& target_node) {
1616 std::unordered_map<string, OpContext> ops_executed;
1617 bool more_nodes = true;
1618 do {
1619 OpContext op_context = scheduler_->GetCurrNode();
1620 ops_executed[op_context.name] = op_context;
1621 std::cout << op_context.name << std::endl;
1622
1623 Costs node_costs = SimplePredictCosts(op_context);
1624
1625 // Check scheduling order.
1626 auto it = dependency_.find(op_context.name);
1627 if (it != dependency_.end()) {
1628 for (const auto& preceding_node : it->second) {
1629 EXPECT_GT(ops_executed.count(preceding_node), 0);
1630 }
1631 }
1632 more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
1633
1634 if (op_context.name == target_node) {
1635 // Scheduler has the state after executing the target node.
1636 break;
1637 }
1638 } while (more_nodes);
1639 return ops_executed;
1640 }
1641
1642 // Helper method for validating a vector.
1643 template <typename T>
1644 void ExpectVectorEq(const std::vector<T>& expected,
1645 const std::vector<T>& test_elements) {
1646 // Set of expected elements for an easy comparison.
1647 std::set<T> expected_set(expected.begin(), expected.end());
1648 for (const auto& element : test_elements) {
1649 EXPECT_GT(expected_set.count(element), 0);
1650 }
1651 EXPECT_EQ(expected.size(), test_elements.size());
1652 }
1653
1654 // Helper method that checks the names of nodes.
1655 void ValidateNodeDefs(const std::vector<string>& expected,
1656 const std::vector<const NodeDef*>& node_defs) {
1657 std::vector<string> node_names;
1658 std::transform(node_defs.begin(), node_defs.end(),
1659 std::back_inserter(node_names),
1660 [](const NodeDef* node) { return node->name(); });
1661 ExpectVectorEq(expected, node_names);
1662 }
1663
1664 // Helper method for validating a set.
1665 template <typename T>
1666 void ExpectSetEq(const std::set<T>& expected,
1667 const std::set<T>& test_elements) {
1668 for (const auto& element : test_elements) {
1669 EXPECT_GT(expected.count(element), 0);
1670 }
1671 EXPECT_EQ(expected.size(), test_elements.size());
1672 }
1673
1674 // Helper method that checks name-port pairs.
1675 void ValidateMemoryUsageSnapshot(
1676 const std::vector<string>& expected_names, const int port_num_expected,
1677 const std::unordered_set<std::pair<const NodeDef*, int>,
1678 DeviceState::NodePairHash>& mem_usage_snapshot) {
1679 std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
1680 std::transform(
1681 mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
1682 std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
1683 [](const std::pair<const NodeDef*, int>& node_port) {
1684 return std::make_pair(node_port.first->name(), node_port.second);
1685 });
1686 std::set<std::pair<string, int>> expected;
1687 std::transform(expected_names.begin(), expected_names.end(),
1688 std::inserter(expected, expected.begin()),
1689 [port_num_expected](const string& name) {
1690 return std::make_pair(name, port_num_expected);
1691 });
1692 ExpectSetEq(expected, nodes_at_peak_mem_usage);
1693 }
1694
1695 // Helper method for checking node dependencies.
1696 void ValidateDependencyChain(
1697 const std::unordered_map<string, int64>& start_times,
1698 const std::vector<string>& nodes_in_dependency_order) {
1699 int64 prev_node_time = -1;
1700 for (const auto& node : nodes_in_dependency_order) {
1701 int64 curr_node_time = start_times.at(node);
1702 EXPECT_GE(curr_node_time, prev_node_time);
1703 prev_node_time = curr_node_time;
1704 }
1705 }
1706
1707 // cluster_ and scheduler_ are initialized in the c'tor.
1708 std::unique_ptr<VirtualCluster> cluster_;
1709 std::unique_ptr<TestVirtualScheduler> scheduler_;
1710
1711 // grappler_item_ will be initialized differently for each test case.
1712 std::unique_ptr<GrapplerItem> grappler_item_;
1713 // Node name -> its preceding nodes map for testing scheduling order.
1714 std::unordered_map<string, std::vector<string>> dependency_;
1715
1716 // Shared params for Conv2D related graphs:
1717 const int batch_size_ = 4;
1718 const int width_ = 10;
1719 const int height_ = 10;
1720 const int depth_in_ = 8;
1721 const int kernel_ = 3;
1722 const int depth_out_ = 16;
1723 };
1724
1725 // Test that FIFOManager correctly returns the current node with only 1 node.
1726 TEST_F(VirtualSchedulerTest, GetSingleNodeFIFOManager) {
1727 // Init.
1728 FIFOManager manager = FIFOManager();
1729
1730 // Add the node to FIFOManager.
1731 manager.AddNode(&node1_);
1732 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1733 }
1734
1735 // Test that FIFOManager removes the only node contained within.
1736 TEST_F(VirtualSchedulerTest, RemoveSingleNodeFIFOManager) {
1737 // Init.
1738 FIFOManager manager = FIFOManager();
1739
1740 // Add the node to FIFOManager.
1741 manager.AddNode(&node1_);
1742
1743 // Remove the only node in FIFOManager.
1744 manager.RemoveCurrNode();
1745 EXPECT_TRUE(manager.Empty());
1746 }
1747
1748 // Test that FIFOManager can remove multiple nodes and return the current node
1749 // in the right order
1750 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFIFOManager) {
1751 // Init.
1752 FIFOManager manager = FIFOManager();
1753
1754 // Add the nodes to FIFOManager.
1755 manager.AddNode(&node1_);
1756 manager.AddNode(&node2_);
1757 manager.AddNode(&node3_);
1758 manager.AddNode(&node4_);
1759
1760 // Keep checking current node while removing nodes from manager.
1761 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1762 manager.RemoveCurrNode();
1763 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1764 manager.RemoveCurrNode();
1765 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
1766 manager.RemoveCurrNode();
1767 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
1768 manager.RemoveCurrNode();
1769 EXPECT_TRUE(manager.Empty());
1770 }
1771
1772 // Test that FIFOManager can remove multiple nodes and add more nodes, still
1773 // returning the current node in the right order
1774 TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleFIFOManager) {
1775 // Init.
1776 FIFOManager manager = FIFOManager();
1777
1778 // Add the nodes to FIFOManager.
1779 manager.AddNode(&node1_);
1780 manager.AddNode(&node2_);
1781 manager.AddNode(&node3_);
1782 manager.AddNode(&node4_);
1783
1784 // Keep checking current node as nodes are removed and added.
1785 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1786 manager.RemoveCurrNode();
1787 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1788 manager.AddNode(&node5_);
1789 // GetCurrNode() should return the same node even if some nodes are added,
1790 // until RemoveCurrNode() is called.
1791 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1792 manager.RemoveCurrNode();
1793 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
1794 manager.RemoveCurrNode();
1795 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
1796 manager.AddNode(&node6_);
1797 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
1798 manager.RemoveCurrNode();
1799 EXPECT_EQ("Node5", manager.GetCurrNode()->name());
1800 manager.RemoveCurrNode();
1801 EXPECT_EQ("Node6", manager.GetCurrNode()->name());
1802 manager.RemoveCurrNode();
1803 EXPECT_TRUE(manager.Empty());
1804 }
1805
1806 // Test that LIFOManager correctly returns the current node with only 1 node.
1807 TEST_F(VirtualSchedulerTest, GetSingleNodeLIFOManager) {
1808 // Init.
1809 LIFOManager manager = LIFOManager();
1810
1811 // Add the node to LIFOManager.
1812 manager.AddNode(&node1_);
1813 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1814 }
1815
1816 // Test that LIFOManager removes the only node contained within.
1817 TEST_F(VirtualSchedulerTest, RemoveSingleNodeLIFOManager) {
1818 // Init.
1819 LIFOManager manager = LIFOManager();
1820
1821 // Add the node to LIFOManager.
1822 manager.AddNode(&node1_);
1823
1824 // Remove the only node in LIFOManager.
1825 manager.RemoveCurrNode();
1826 EXPECT_TRUE(manager.Empty());
1827 }
1828
1829 // Test that LIFOManager can remove multiple nodes and return the current node
1830 // in the right order
1831 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleLIFOManager) {
1832 // Init.
1833 LIFOManager manager = LIFOManager();
1834
1835 // Add the nodes to LIFOManager.
1836 manager.AddNode(&node1_);
1837 manager.AddNode(&node2_);
1838 manager.AddNode(&node3_);
1839 manager.AddNode(&node4_);
1840
1841 // Keep checking current node while removing nodes from manager.
1842 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
1843 manager.RemoveCurrNode();
1844 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
1845 manager.RemoveCurrNode();
1846 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1847 manager.RemoveCurrNode();
1848 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1849 manager.RemoveCurrNode();
1850 EXPECT_TRUE(manager.Empty());
1851 }
1852
1853 // Test that LIFOManager can remove multiple nodes (must be removing the current
1854 // node) and add more nodes, still returning the current node in the right order
1855 TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleLIFOManager) {
1856 // Init.
1857 LIFOManager manager = LIFOManager();
1858
1859 // Add the nodes to LIFOManager.
1860 manager.AddNode(&node1_);
1861 manager.AddNode(&node2_);
1862 manager.AddNode(&node3_);
1863 manager.AddNode(&node4_);
1864
1865 // Keep checking current node as nodes are removed and added.
1866 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
1867 manager.RemoveCurrNode();
1868 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
1869 manager.AddNode(&node5_);
1870 // GetCurrNode() should return the same node even if some nodes are added,
1871 // until RemoveCurrNode() is called.
1872 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
1873 manager.RemoveCurrNode();
1874 EXPECT_EQ("Node5", manager.GetCurrNode()->name());
1875 manager.RemoveCurrNode();
1876 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1877 manager.AddNode(&node6_);
1878 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1879 manager.RemoveCurrNode();
1880 EXPECT_EQ("Node6", manager.GetCurrNode()->name());
1881 manager.RemoveCurrNode();
1882 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1883 manager.RemoveCurrNode();
1884 EXPECT_TRUE(manager.Empty());
1885 }
1886
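// Test that FirstReadyManager correctly returns the current node with only 1 node.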
1887 TEST_F(VirtualSchedulerTest, GetSingleNodeFirstReadyManager) {
1888 FirstReadyManager manager;
1889 manager.Init(&node_states_);
1890
1891 manager.AddNode(&node1_);
1892 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1893 }
1894
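// Test that FirstReadyManager removes the only node contained within.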
1895 TEST_F(VirtualSchedulerTest, RemoveSingleNodeFirstReadyManager) {
1896 FirstReadyManager manager;
1897 manager.Init(&node_states_);
1898 manager.AddNode(&node1_);
1899 manager.RemoveCurrNode();
1900 EXPECT_TRUE(manager.Empty());
1901 }
1902
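// Test that FirstReadyManager returns nodes in increasing time_ready order,
// regardless of the order in which they were added.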
1903 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFirstReadyManager) {
1904 FirstReadyManager manager;
1905 manager.Init(&node_states_);
1906 // Insert nodes in some random order.
1907 manager.AddNode(&node2_);
1908 manager.AddNode(&node1_);
1909 manager.AddNode(&node4_);
1910 manager.AddNode(&node5_);
1911 manager.AddNode(&node3_);
1912 manager.AddNode(&node6_);
1913
1914 // In whatever order we insert nodes, we get the same order based on nodes'
1915 // time_ready.
1916 EXPECT_EQ("Node6", manager.GetCurrNode()->name());
1917 manager.RemoveCurrNode();
1918 EXPECT_EQ("Node5", manager.GetCurrNode()->name());
1919 manager.RemoveCurrNode();
1920 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
1921 manager.RemoveCurrNode();
1922 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
1923 manager.RemoveCurrNode();
1924 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1925 manager.RemoveCurrNode();
1926 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1927 manager.RemoveCurrNode();
1928 EXPECT_TRUE(manager.Empty());
1929 }
1930
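// Test that GetCurrNode() keeps returning the same node until RemoveCurrNode()
// is called, even when nodes with smaller time_ready are added.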
1931 TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) {
1932 FirstReadyManager manager;
1933 manager.Init(&node_states_);
1934 // Insert nodes in some random order.
1935 manager.AddNode(&node2_);
1936 manager.AddNode(&node1_);
1937 manager.AddNode(&node4_);
1938 manager.AddNode(&node5_);
1939 manager.AddNode(&node3_);
1940 manager.AddNode(&node6_);
1941
1942 // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode()
1943 // should return it.
1944 EXPECT_EQ("Node6", manager.GetCurrNode()->name());
1945 // Now insert a few other nodes, but their time_ready's are even smaller than
1946 // that of Node6. Before calling RemoveCurrNode(), GetCurrNode() should return
1947 // the same node, Node6, in this case.
1948
1949 NodeDef node7;
1950 NodeDef node8;
1951 NodeDef node9;
1952 NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7);
1953 NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8);
1954 NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9);
1955
1956 manager.AddNode(&node7);
1957 EXPECT_EQ("Node6", manager.GetCurrNode()->name());
1958
1959 manager.AddNode(&node8);
1960 EXPECT_EQ("Node6", manager.GetCurrNode()->name());
1961
1962 manager.RemoveCurrNode();
1963 // Now Node6 is removed, and GetCurrNode() will return Node8.
1964 EXPECT_EQ("Node8", manager.GetCurrNode()->name());
1965
1966 // Again, AddNode shouldn't change GetCurrNode().
1967 manager.AddNode(&node9);
1968 EXPECT_EQ("Node8", manager.GetCurrNode()->name());
1969
1970 manager.RemoveCurrNode();
1971 EXPECT_EQ("Node9", manager.GetCurrNode()->name());
1972 manager.RemoveCurrNode();
1973 EXPECT_EQ("Node7", manager.GetCurrNode()->name());
1974 manager.RemoveCurrNode();
1975 EXPECT_EQ("Node5", manager.GetCurrNode()->name());
1976 manager.RemoveCurrNode();
1977 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
1978 manager.RemoveCurrNode();
1979 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
1980 manager.RemoveCurrNode();
1981 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
1982 manager.RemoveCurrNode();
1983 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
1984 manager.RemoveCurrNode();
1985 EXPECT_TRUE(manager.Empty());
1986 }
1987
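// Test that FirstReadyManager breaks ties on equal time_ready deterministically,
// independent of the order in which nodes were added.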
1988 TEST_F(VirtualSchedulerTest, DeterminismInFirstReadyManager) {
1989 FirstReadyManager manager1;
1990 manager1.Init(&node_states_);
1991 FirstReadyManager manager2;
1992 manager2.Init(&node_states_);
1993
1994 // 6 nodes with same time_ready.
1995 NodeDef node7;
1996 NodeDef node8;
1997 NodeDef node9;
1998 NodeDef node10;
1999 NodeDef node11;
2000 NodeDef node12;
2001 NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
2002 NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8);
2003 NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9);
2004 NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10);
2005 NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11);
2006 NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12);
2007
2008 // Add the above 6 nodes to manager1.
2009 manager1.AddNode(&node7);
2010 manager1.AddNode(&node8);
2011 manager1.AddNode(&node9);
2012 manager1.AddNode(&node10);
2013 manager1.AddNode(&node11);
2014 manager1.AddNode(&node12);
2015
2016 // Add the above 6 nodes to manager2, but in a different order.
2017 manager2.AddNode(&node8);
2018 manager2.AddNode(&node11);
2019 manager2.AddNode(&node9);
2020 manager2.AddNode(&node10);
2021 manager2.AddNode(&node7);
2022 manager2.AddNode(&node12);
2023
2024 // Expect both managers return the same nodes for deterministic node
2025 // scheduling.
2026 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2027 manager1.RemoveCurrNode();
2028 manager2.RemoveCurrNode();
2029
2030 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2031 manager1.RemoveCurrNode();
2032 manager2.RemoveCurrNode();
2033
2034 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2035 manager1.RemoveCurrNode();
2036 manager2.RemoveCurrNode();
2037
2038 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2039 manager1.RemoveCurrNode();
2040 manager2.RemoveCurrNode();
2041
2042 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2043 manager1.RemoveCurrNode();
2044 manager2.RemoveCurrNode();
2045
2046 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2047 manager1.RemoveCurrNode();
2048 manager2.RemoveCurrNode();
2049
2050 EXPECT_TRUE(manager1.Empty());
2051 EXPECT_TRUE(manager2.Empty());
2052 }
2053
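// Test that CompositeNodeManager removes the only node contained within.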
2054 TEST_F(VirtualSchedulerTest, RemoveSingleNodeCompositeNodeManager) {
2055 CompositeNodeManager manager;
2056 manager.Init(&node_states_);
2057 manager.AddNode(&node1_);
2058 manager.RemoveCurrNode();
2059 EXPECT_TRUE(manager.Empty());
2060 }
2061
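// Test that CompositeNodeManager removes the only node contained within
// (duplicate of the previous test).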
2062 TEST_F(VirtualSchedulerTest, RemoveSingleNodeComopsiteNodeManager) {
2063 CompositeNodeManager manager;
2064 manager.Init(&node_states_);
2065
2066 manager.AddNode(&node1_);
2067 manager.RemoveCurrNode();
2068 EXPECT_TRUE(manager.Empty());
2069 }
2070
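// Test that CompositeNodeManager behaves like LIFOManager when all nodes are
// ordinary ops on a single device.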
2071 TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleComopsiteNodeManager) {
2072 CompositeNodeManager manager;
2073 manager.Init(&node_states_);
2074
2075 // Add the nodes to CompositeNodeManager.
2076 manager.AddNode(&node1_);
2077 manager.AddNode(&node2_);
2078 manager.AddNode(&node3_);
2079 manager.AddNode(&node4_);
2080
2081 // Keep checking current node as nodes are removed and added.
2082 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
2083 manager.RemoveCurrNode();
2084 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
2085 manager.AddNode(&node5_);
2086 // GetCurrNode() should return the same node even if some nodes are added,
2087 // until RemoveCurrNode() is called.
2088 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
2089 manager.RemoveCurrNode();
2090 EXPECT_EQ("Node5", manager.GetCurrNode()->name());
2091 manager.RemoveCurrNode();
2092 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
2093 manager.AddNode(&node6_);
2094 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
2095 manager.RemoveCurrNode();
2096 EXPECT_EQ("Node6", manager.GetCurrNode()->name());
2097 manager.RemoveCurrNode();
2098 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
2099 manager.RemoveCurrNode();
2100 EXPECT_TRUE(manager.Empty());
2101 }
2102
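// Test CompositeNodeManager with ops on two devices plus _Send/_Recv nodes:
// the candidates are the last-added op on each device and the _Send/_Recv
// nodes, and the one with the earliest time_ready is scheduled next.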
2103 TEST_F(VirtualSchedulerTest, MultiDeviceSendRecvComopsiteNodeManager) {
2104 CompositeNodeManager manager;
2105 manager.Init(&node_states_);
2106 // Additional nodes on kCPU1
2107 NodeDef node7;
2108 NodeDef node8;
2109 NodeDef node9;
2110 NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7);
2111 NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8);
2112 NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9);
2113
2114 // Send and Recv nodes.
2115 NodeDef send1;
2116 NodeDef send2;
2117 NodeDef recv1;
2118 NodeDef recv2;
2119 NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1);
2120 NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2);
2121 NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1);
2122 NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2);
2123
2124 // Insert nodes.
2125 manager.AddNode(&node1_);
2126 manager.AddNode(&node2_);
2127 manager.AddNode(&node3_);
2128 manager.AddNode(&node4_);
2129 manager.AddNode(&node5_);
2130 manager.AddNode(&node6_);
2131 manager.AddNode(&node7);
2132 manager.AddNode(&node8);
2133 manager.AddNode(&node9);
2134 manager.AddNode(&send1);
2135 manager.AddNode(&send2);
2136 manager.AddNode(&recv1);
2137 manager.AddNode(&recv2);
2138
2139 // On kCPU0 the last-added node is node6_; on kCPU1 it is node9. The manager
2140 // picks the candidate with the earliest time_ready among node6_, node9,
2141 // Send1, Send2, Recv1, and Recv2.
2142 EXPECT_EQ("Node6", manager.GetCurrNode()->name());
2143 manager.RemoveCurrNode();
2144 // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node
2145 // among node5_, node9, Send1, Send2, Recv1, and Recv2.
2146 EXPECT_EQ("Node5", manager.GetCurrNode()->name());
2147 manager.RemoveCurrNode();
2148 // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2.
2149 EXPECT_EQ("Send1", manager.GetCurrNode()->name());
2150 manager.RemoveCurrNode();
2151 // Next, choose among node4_, node9, Send2, Recv1, and Recv2.
2152 EXPECT_EQ("Recv1", manager.GetCurrNode()->name());
2153 manager.RemoveCurrNode();
2154 // Next, choose among node4_, node9, Send2, and Recv2.
2155 EXPECT_EQ("Recv2", manager.GetCurrNode()->name());
2156 manager.RemoveCurrNode();
2157 // Next, choose among node4_, node9, and Send2.
2158 EXPECT_EQ("Send2", manager.GetCurrNode()->name());
2159 manager.RemoveCurrNode();
2160 // Next, choose between node4_, node9.
2161 EXPECT_EQ("Node4", manager.GetCurrNode()->name());
2162 manager.RemoveCurrNode();
2163 // Next, choose between node3_, node9.
2164 EXPECT_EQ("Node9", manager.GetCurrNode()->name());
2165 manager.RemoveCurrNode();
2166 // Next, choose between node3_, node8.
2167 EXPECT_EQ("Node8", manager.GetCurrNode()->name());
2168 manager.RemoveCurrNode();
2169 // Next, choose between node3_, node7.
2170 EXPECT_EQ("Node7", manager.GetCurrNode()->name());
2171 manager.RemoveCurrNode();
2172 // Then, just the nodes on kCPU0 remain -- LIFO order.
2173 EXPECT_EQ("Node3", manager.GetCurrNode()->name());
2174 manager.RemoveCurrNode();
2175 EXPECT_EQ("Node2", manager.GetCurrNode()->name());
2176 manager.RemoveCurrNode();
2177 EXPECT_EQ("Node1", manager.GetCurrNode()->name());
2178 manager.RemoveCurrNode();
2179 EXPECT_TRUE(manager.Empty());
2180 }
2181
2182 TEST_F(VirtualSchedulerTest, DeterminismInCompositeNodeManager) {
2183 CompositeNodeManager manager;
2184 manager.Init(&node_states_);
2185 CompositeNodeManager manager2;
2186 manager2.Init(&node_states_);
2187
2188 // Six additional nodes; see the time_ready values below.
2189 NodeDef node7;
2190 NodeDef node8;
2191 NodeDef node9;
2192 NodeDef node10;
2193 NodeDef node11;
2194 NodeDef node12;
2195 NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
2196 NodeSetUp("Node8", kSend, kCPU0, 1000, &node8);
2197 NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9);
2198 NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10);
2199 NodeSetUp("Node11", kRecv, kCPU0, 999, &node11);
2200 NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12);
2201
2202 // Add Nodes 7 to 9 to manager.
2203 manager.AddNode(&node7);
2204 manager.AddNode(&node8);
2205 manager.AddNode(&node9);
2206
2207 // When the candidate nodes have the same time_ready, the manager should
2208 // return _Send first, then _Recv, and then the other ops.
2209 EXPECT_EQ("Node8", manager.GetCurrNode()->name());
2210 EXPECT_EQ(kSend, manager.GetCurrNode()->op());
2211 manager.RemoveCurrNode();
2212 EXPECT_EQ("Node9", manager.GetCurrNode()->name());
2213 EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
2214 manager.RemoveCurrNode();
2215 EXPECT_EQ("Node7", manager.GetCurrNode()->name());
2216 EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
2217 manager.RemoveCurrNode();
2218 EXPECT_TRUE(manager.Empty());
2219
2220 // Add Nodes 7 to 9 to manager, but in a different order.
2221 manager.AddNode(&node9);
2222 manager.AddNode(&node8);
2223 manager.AddNode(&node7);
2224
2225 // Expect same order (_Send, _Recv, and the other op), regardless of Add
2226 // order.
2227 EXPECT_EQ("Node8", manager.GetCurrNode()->name());
2228 EXPECT_EQ(kSend, manager.GetCurrNode()->op());
2229 manager.RemoveCurrNode();
2230 EXPECT_EQ("Node9", manager.GetCurrNode()->name());
2231 EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
2232 manager.RemoveCurrNode();
2233 EXPECT_EQ("Node7", manager.GetCurrNode()->name());
2234 EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
2235 manager.RemoveCurrNode();
2236 EXPECT_TRUE(manager.Empty());
2237
2238 // Conv2D's time_ready < Send's time_ready; Expect Conv2D first.
2239 manager.AddNode(&node8);
2240 manager.AddNode(&node10);
2241 EXPECT_EQ("Node10", manager.GetCurrNode()->name());
2242 EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
2243 manager.RemoveCurrNode();
2244 EXPECT_EQ("Node8", manager.GetCurrNode()->name());
2245 EXPECT_EQ(kSend, manager.GetCurrNode()->op());
2246 manager.RemoveCurrNode();
2247 EXPECT_TRUE(manager.Empty());
2248
2249 // Recv's time_ready < Send's time_ready; Expect Recv first.
2250 manager.AddNode(&node11);
2251 manager.AddNode(&node8);
2252 EXPECT_EQ("Node11", manager.GetCurrNode()->name());
2253 EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
2254 manager.RemoveCurrNode();
2255 EXPECT_EQ("Node8", manager.GetCurrNode()->name());
2256 EXPECT_EQ(kSend, manager.GetCurrNode()->op());
2257 manager.RemoveCurrNode();
2258 EXPECT_TRUE(manager.Empty());
2259
2260 // Node7 and Node12 are normal ops with the same time_ready, placed on
2261 // different devices. These two nodes are added to manager and manager2 in
2262 // different orders; expect GetCurrNode() to return them in the same order.
2263 manager.AddNode(&node7);
2264 manager.AddNode(&node12);
2265
2266 manager2.AddNode(&node12);
2267 manager2.AddNode(&node7);
2268
2269 EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2270 manager.RemoveCurrNode();
2271 manager2.RemoveCurrNode();
2272 EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2273 manager.RemoveCurrNode();
2274 manager2.RemoveCurrNode();
2275 EXPECT_TRUE(manager.Empty());
2276 }
2277
2278 // Create a small graph, run cost prediction on it, and make sure the costs
2279 // from the summary match the hand-calculated costs.
2280 TEST_F(VirtualSchedulerTest, SummaryCostTest) {
2281 // Run matmul test.
2282 CreateGrapplerItemWithMatmulChain();
2283 InitScheduler();
2284 auto ops_executed = RunScheduler("");
2285 Costs c = scheduler_->Summary();
2286
2287 // RandomUniform - 5 * 1s
2288 // Matmuls - 4 * 2s = 8s
2289 // Misc - 5 * 1us
2290 // Total: 13000005
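// i.e., 5 * 1,000,000us + 4 * 2,000,000us + 5 * 1us = 13,000,005us.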
2291 EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2292 EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2293 EXPECT_FALSE(c.inaccurate);
2294 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2295 }
2296
2297 // Like the above SummaryCostTest, but makes sure the stepstats timeline is
2298 // correct.
2299 TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
2300 // Run matmul test.
2301 CreateGrapplerItemWithMatmulChain();
2302 InitScheduler();
2303 auto ops_executed = RunScheduler("");
2304 RunMetadata metadata;
2305 Costs c = scheduler_->Summary(&metadata);
2306 StepStats stepstats = metadata.step_stats();
2307 EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2308 EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2309 EXPECT_FALSE(c.inaccurate);
2310 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2311
2312 // Should only be 1 device!
2313 EXPECT_EQ(1, stepstats.dev_stats().size());
2314
2315 // Create a map of op name -> start and end times (micros).
2316 std::map<string, std::pair<int64, int64>> start_end_times;
2317 for (const auto& device_step_stats : stepstats.dev_stats()) {
2318 for (const auto& stats : device_step_stats.node_stats()) {
2319 int64 start = stats.all_start_micros();
2320 int64 end = start + stats.all_end_rel_micros();
2321 start_end_times[stats.node_name()] = std::pair<int64, int64>(start, end);
2322
2323 // Make sure that the output properties are correct for
2324 // MatMul and RandomUniform operations.
2325 // We only check for dtype, and shape (excluding alloc)
2326 // since alloc is not set by the virtual scheduler.
2327 if (stats.timeline_label() == "MatMul" ||
2328 stats.timeline_label() == "RandomUniform") {
2329 EXPECT_EQ(1, stats.output().size());
2330 for (const auto& output : stats.output()) {
2331 EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype());
2332 EXPECT_EQ(2, output.tensor_description().shape().dim().size());
2333 for (const auto& dim : output.tensor_description().shape().dim()) {
2334 EXPECT_EQ(3200, dim.size());
2335 }
2336 }
2337 }
2338 }
2339 }
2340
2341 // The base start_time is the time to compute RandomUniforms
2342 int64 cur_time = static_cast<int64>(5000005);
2343 // The increment is the execution time of one matmul. See
2344 // CreateGrapplerItemWithMatmulChain for details.
2345 int64 increment = static_cast<int64>(2000000);
2346 auto op_names = {"ab", "abc", "abcd", "abcde"};
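// Expected timeline (micros) implied by the base time and increment above:
//   ab:    [ 5000005,  7000005)
//   abc:   [ 7000005,  9000005)
//   abcd:  [ 9000005, 11000005)
//   abcde: [11000005, 13000005)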
2347 for (const auto& op_name : op_names) {
2348 int64 actual_start = start_end_times[op_name].first;
2349 int64 actual_end = start_end_times[op_name].second;
2350 int64 expected_start = cur_time;
2351 int64 expected_end = cur_time + increment;
2352 EXPECT_EQ(expected_start, actual_start);
2353 EXPECT_EQ(expected_end, actual_end);
2354 cur_time += increment;
2355 }
2356 }
2357
2358 TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
2359 // Init.
2360 CreateGrapplerItemWithConv2Ds();
2361 InitScheduler();
2362
2363 // Run the scheduler.
2364 auto ops_executed = RunScheduler(""); // Run all the nodes.
2365
2366 // A Const and a RandomUniform for each of x, y, and f (6 ops), plus c0 and
2367 // c1, for 8 ops total; c2 and z shouldn't be executed.
2368 EXPECT_EQ(8, ops_executed.size());
2369
2370 // x, y, f, c0, and c1 should be in the ops executed.
2371 EXPECT_GT(ops_executed.count("x"), 0);
2372 EXPECT_GT(ops_executed.count("y"), 0);
2373 EXPECT_GT(ops_executed.count("f"), 0);
2374 EXPECT_GT(ops_executed.count("c0"), 0);
2375 EXPECT_GT(ops_executed.count("c1"), 0);
2376
2377 // z and c2 shouldn't be part of it.
2378 EXPECT_EQ(ops_executed.count("z"), 0);
2379 EXPECT_EQ(ops_executed.count("c2"), 0);
2380
2381 // Check input / output properties.
2382 EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size());
2383 EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size());
2384 EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size());
2385 EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
2386 EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
2387 }
2388
2389 TEST_F(VirtualSchedulerTest, MemoryUsage) {
2390 // Init.
2391 CreateGrapplerItemWithAddN();
2392 InitScheduler();
2393
2394 // Run the scheduler.
2395 RunScheduler("");
2396
2397 const auto* device_states = scheduler_->GetDeviceStates();
2398 const auto& cpu_state = device_states->at(kCPU0);
2399
2400 // The out node (AddN) consumes 4 tensors, each of shape 10x10x10x10, so the
2401 // peak memory usage while executing the out node is 4 x one input tensor.
2402 int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
2403 const std::vector<string> expected_names = {"x", "y", "z", "w"};
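// With 4-byte float elements, one_input_node_size = 4 * 10^4 = 40,000 bytes,
// so the expected peak is 4 inputs * 40,000 = 160,000 bytes.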
2404 EXPECT_EQ(expected_names.size() * one_input_node_size,
2405 cpu_state.max_memory_usage);
2406 ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
2407 cpu_state.mem_usage_snapshot_at_peak);
2408 }
2409
2410 TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) {
2411 CreateGrapplerItemWithUnnecessaryPlaceholderNodes();
2412 InitScheduler();
2413
2414 // Test that scheduler can run graphs with extra unnecessary feed nodes.
2415 auto ops_executed = RunScheduler("");
2416 ASSERT_EQ(1, ops_executed.size());
2417 ASSERT_EQ(ops_executed.count("x"), 1);
2418 }
2419
2420 TEST_F(VirtualSchedulerTest, ControlDependency) {
2421 // Init.
2422 CreateGrapplerItemWithControlDependency();
2423 InitScheduler();
2424
2425 // Run the scheduler.
2426 RunScheduler("");
2427
2428 const auto* device_states = scheduler_->GetDeviceStates();
2429 const auto& cpu_state = device_states->at(kCPU0);
2430
2431 // The graph has a NoOp that takes control dependencies from 7 NoOps. The
2432 // peak memory usage occurs while executing that final NoOp.
2433 int64 one_input_node_size = 4; // control dependency
2434 const std::vector<string> expected_names = {"x", "y", "z", "w",
2435 "u", "v", "t"};
2436 EXPECT_EQ(expected_names.size() * one_input_node_size,
2437 cpu_state.max_memory_usage);
2438 ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
2439 cpu_state.mem_usage_snapshot_at_peak);
2440 }
2441
2442 TEST_F(VirtualSchedulerTest, ComplexDependency) {
2443 // Init.
2444 CreateGrapplerItemWithBatchNorm();
2445 InitScheduler();
2446
2447 // Run the scheduler.
2448 RunScheduler("bn");
2449
2450 const auto& device_states = scheduler_->GetDeviceStates();
2451 const auto& cpu_state = device_states->at(kCPU0);
2452
2453 // The graph is
2454 // bn = FusedBatchNorm(x, scale, offset, mean, var)
2455 // z1 = bn.y + x
2456 // z2 = bn.var + bn.var
2457 // z3 = bn.var + bn.var
2458 // z4 = control dependency from bn.
2459 // Note that bn.mean doesn't have any consumer.
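// RunScheduler("bn") presumably stops once "bn" has been processed; at that
// point x, bn.y, bn.var, and bn's control output are still in memory, which
// is what expected_size below accounts for (4 bytes per element).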
2460 const int x_size = batch_size_ * width_ * height_ * depth_in_;
2461 int64 expected_size =
2462 4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
2463 1 /* control dependency */);
2464 EXPECT_EQ(expected_size, cpu_state.memory_usage);
2465
2466 // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0.
2467 std::set<std::pair<string, int>> nodes_in_memory;
2468 std::transform(
2469 cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
2470 std::inserter(nodes_in_memory, nodes_in_memory.begin()),
2471 [](const std::pair<const NodeDef*, int>& node_port) {
2472 return std::make_pair(node_port.first->name(), node_port.second);
2473 });
2474 std::set<std::pair<string, int>> expected = {
2475 std::make_pair("bn", -1),
2476 std::make_pair("bn", 0),
2477 std::make_pair("bn", 2),
2478 std::make_pair("x", 0),
2479 };
2480 ExpectSetEq(expected, nodes_in_memory);
2481
2482 const auto* node_states = scheduler_->GetNodeStates();
2483 const NodeState* bn_node = nullptr;
2484 const NodeState* x_node = nullptr;
2485 for (const auto& nodedef_node_state : *node_states) {
2486 const NodeDef* node = nodedef_node_state.first;
2487 const NodeState& node_state = nodedef_node_state.second;
2488 if (node->name() == "bn") {
2489 bn_node = &node_state;
2490 }
2491 if (node->name() == "x") {
2492 x_node = &node_state;
2493 }
2494 }
2495 CHECK_NOTNULL(bn_node);
2496 CHECK_NOTNULL(x_node);
2497
2498 ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
2499 ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
2500 ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
2501 // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
2502 ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
2503 }
2504
2505 TEST_F(VirtualSchedulerTest, Variable) {
2506 // Init.
2507 CreateGrapplerItemWithConv2DAndVariable();
2508 InitScheduler();
2509
2510 // Run the scheduler.
2511 RunScheduler("");
2512
2513 const auto* device_states = scheduler_->GetDeviceStates();
2514 const auto& cpu_state = device_states->at(kCPU0);
2515
2516 // There is one Conv2D that takes x and f, but f is a Variable, so it should
2517 // be tracked in persistent_nodes rather than in the peak-usage snapshot.
2519 ValidateMemoryUsageSnapshot({"f"}, 0 /* port_num_expected */,
2520 cpu_state.persistent_nodes);
2521 // Only x in peak memory usage snapshot.
2522 ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */,
2523 cpu_state.mem_usage_snapshot_at_peak);
2524 }
2525
2526 TEST_F(VirtualSchedulerTest, WhileLoop) {
2527 // Init.
2528 CreateGrapplerItemWithLoop();
2529 InitScheduler();
2530
2531 // Run the scheduler.
2532 RunScheduler("");
2533
2534 // Check the timeline
2535 RunMetadata metadata;
2536 scheduler_->Summary(&metadata);
2537
2538 // Nodes in topological order:
2539 // * const, ones
2540 // * while/Enter, while/Enter_1
2541 // * while/Merge, while/Merge_1
2542 // * while/Less/y
2543 // * while/Less
2544 // * while/LoopCond
2545 // * while/Switch, while/Switch_1
2546 // * while/Identity, while/Identity_1, while/Exit, while/Exit_1
2547 // * while/add/y, while/concat/axis
2548 // * while/add, while/concat
2549 // * while/NextIteration, while/NextIteration_1
2550
2551 int num_next_iteration = 0;
2552 int num_next_iteration_1 = 0;
2553 int num_exit = 0;
2554 int num_exit_1 = 0;
2555 int64 next_iter_start_micro;
2556 int64 next_iter_1_start_micro;
2557 int64 exit_start_micro;
2558 int64 exit_1_start_micro;
2559
2560 std::unordered_map<string, int64> start_times;
2561 for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2562 for (const auto& stats : device_step_stats.node_stats()) {
2563 start_times[stats.node_name()] = stats.all_start_micros();
2564 if (stats.node_name() == "while/NextIteration") {
2565 ++num_next_iteration;
2566 next_iter_start_micro = stats.all_start_micros();
2567 } else if (stats.node_name() == "while/NextIteration_1") {
2568 ++num_next_iteration_1;
2569 next_iter_1_start_micro = stats.all_start_micros();
2570 } else if (stats.node_name() == "while/Exit") {
2571 ++num_exit;
2572 exit_start_micro = stats.all_start_micros();
2573 } else if (stats.node_name() == "while/Exit_1") {
2574 ++num_exit_1;
2575 exit_1_start_micro = stats.all_start_micros();
2576 }
2577 }
2578 }
2579
2580 // Make sure we went through the body of the loop once, and that the output
2581 // of the loop was scheduled as well.
2582 EXPECT_EQ(1, num_next_iteration);
2583 EXPECT_EQ(1, num_next_iteration_1);
2584 EXPECT_EQ(1, num_exit);
2585 EXPECT_EQ(1, num_exit_1);
2586
2587 // Start times of while/NextIteration and while/NextIteration_1 should be
2588 // different, so should be those of while/Exit and while/Exit_1.
2589 EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
2590 EXPECT_NE(exit_start_micro, exit_1_start_micro);
2591
2592 // Check dependency among the nodes; no matter what scheduling mechanism we
2593 // use, the scheduled ops should follow these dependency chains.
2594 // Note that currently, VirtualScheduler executes while/Merge twice; hence,
2595 // we're not testing dependency chains related to while/Merge.
2596 // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the
2597 // order of Enter, Merge, ...loop condition ..., ... loop body ...,
2598 // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency
2599 // chaining test w/ Merge nodes.
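// ValidateDependencyChain(start_times, {...}) is assumed to verify that the
// recorded start times are non-decreasing along the given chain of nodes.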
2600 ValidateDependencyChain(
2601 start_times,
2602 {"Const", "while/Enter", // "while/Merge",
2603 "while/Less/y", "while/Less", "while/LoopCond", "while/Switch",
2604 "while/Identity", "while/add/y", "while/add", "while/NextIteration"});
2605 // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"});
2606 ValidateDependencyChain(start_times,
2607 {"ones", "while/Enter_1", // "while/Merge_1",
2608 "while/Switch_1", "while/Identity_1", "while/concat",
2609 "while/NextIteration_1"});
2610 ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"});
2611 ValidateDependencyChain(
2612 start_times, {"while/Identity", "while/concat/axis", "while/concat"});
2613 ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
2614 ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
2615 }
2616
2617 TEST_F(VirtualSchedulerTest, AnnotatedWhileLoop) {
2618 {
2619 // Init.
2620 CreateGrapplerItemWithLoop();
2621 InitScheduler();
2622
2623 // Runs the scheduler.
2624 RunScheduler("");
2625 Costs c = scheduler_->Summary();
2626
2627 EXPECT_EQ(23, c.execution_time.asMicroSeconds().count());
2628 // Both while/Merge and while/Merge_1 are scheduled twice.
2629 EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
2630 EXPECT_FALSE(c.inaccurate);
2631 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2632 }
2633
2634 {
2635 // Init.
2636 CreateGrapplerItemWithLoopAnnotated();
2637 InitScheduler();
2638
2639 // Runs the scheduler.
2640 RunScheduler("");
2641 Costs c = scheduler_->Summary();
2642
2643 // The cost of Merge is accumulated twice for each of its execution_count
2644 // executions, but since Merge's cost is minimal, we keep this behavior here.
2645 EXPECT_EQ(178, c.execution_time.asMicroSeconds().count());
2646 // Both while/Merge and while/Merge_1 are scheduled twice.
2647 EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
2648 EXPECT_FALSE(c.inaccurate);
2649 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2650 }
2651 }
2652
2653 TEST_F(VirtualSchedulerTest, Condition) {
2654 // Without annotation.
2655 {
2656 // Inits.
2657 CreateGrapplerItemWithCondition();
2658 InitScheduler();
2659
2660 // Runs the scheduler.
2661 RunScheduler("");
2662 RunMetadata metadata;
2663 Costs c = scheduler_->Summary(&metadata);
2664
2665 // Nodes in topological order: a/Less, Switch, First/Second, Merge.
2666 int num_a = 0;
2667 int num_less = 0;
2668 int num_switch = 0;
2669 int num_first = 0;
2670 int num_second = 0;
2671 int num_merge = 0;
2672
2673 for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2674 for (const auto& stats : device_step_stats.node_stats()) {
2675 if (stats.node_name() == "a") {
2676 ++num_a;
2677 } else if (stats.node_name() == "Less") {
2678 ++num_less;
2679 } else if (stats.node_name() == "Switch") {
2680 ++num_switch;
2681 } else if (stats.node_name() == "First") {
2682 ++num_first;
2683 } else if (stats.node_name() == "Second") {
2684 ++num_second;
2685 } else if (stats.node_name() == "Merge") {
2686 ++num_merge;
2687 }
2688 }
2689 }
2690
2691 EXPECT_EQ(1, num_a);
2692 EXPECT_EQ(1, num_less);
2693 EXPECT_EQ(1, num_switch);
2694 EXPECT_EQ(1, num_first);
2695 EXPECT_EQ(1, num_second);
2696 EXPECT_EQ(2, num_merge);
2697
2698 EXPECT_EQ(7, c.execution_time.asMicroSeconds().count());
2699 // Merge is executed twice.
2700 EXPECT_EQ(grappler_item_->graph.node_size() + 1, c.num_ops_total);
2701 EXPECT_FALSE(c.inaccurate);
2702 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2703 }
2704
2705 // With annotation.
2706 {
2707 // Inits.
2708 CreateGrapplerItemWithCondition();
2709
2710 // Annotates the Switch node.
2711 for (auto& node : *grappler_item_->graph.mutable_node()) {
2712 if (node.name() == "Switch") {
2713 AttrValue attr_output_info;
2714 // Marks output slot 0 as the only taken output so that Second isn't executed.
2715 (*attr_output_info.mutable_list()).add_i(0);
2716 AddNodeAttr(kOutputSlots, attr_output_info, &node);
2717 }
2718 }
2719
2720 InitScheduler();
2721
2722 // Runs the scheduler.
2723 RunScheduler("");
2724 RunMetadata metadata;
2725 Costs c = scheduler_->Summary(&metadata);
2726
2727 // Nodes in topological order: a/Less, Switch, Merge
2728 int num_a = 0;
2729 int num_less = 0;
2730 int num_switch = 0;
2731 int num_first = 0;
2732 int num_second = 0;
2733 int num_merge = 0;
2734
2735 for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2736 for (const auto& stats : device_step_stats.node_stats()) {
2737 if (stats.node_name() == "a") {
2738 ++num_a;
2739 } else if (stats.node_name() == "Less") {
2740 ++num_less;
2741 } else if (stats.node_name() == "Switch") {
2742 ++num_switch;
2743 } else if (stats.node_name() == "First") {
2744 ++num_first;
2745 } else if (stats.node_name() == "Second") {
2746 ++num_second;
2747 } else if (stats.node_name() == "Merge") {
2748 ++num_merge;
2749 }
2750 }
2751 }
2752
2753 EXPECT_EQ(1, num_a);
2754 EXPECT_EQ(1, num_less);
2755 EXPECT_EQ(1, num_switch);
2756 EXPECT_EQ(1, num_first);
2757 EXPECT_EQ(0, num_second);
2758 EXPECT_EQ(1, num_merge);
2759
2760 EXPECT_EQ(5, c.execution_time.asMicroSeconds().count());
2761 // Second is not executed.
2762 EXPECT_EQ(grappler_item_->graph.node_size() - 1, c.num_ops_total);
2763 EXPECT_FALSE(c.inaccurate);
2764 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2765 }
2766 }
2767
2768 TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
2769 // Init.
2770 CreateGrapplerItemWithInterDeviceTransfers();
2771 InitScheduler();
2772
2773 // Run the scheduler.
2774 auto ops_executed = RunScheduler("");
2775
2776 // Helper lambda to extract port num from _Send and _Recv op name.
2777 auto get_port_num = [](const string& name) -> int {
2778 if (name.find("bn_0") != string::npos) {
2779 return 0;
2780 } else if (name.find("bn_1") != string::npos) {
2781 return 1;
2782 } else if (name.find("bn_2") != string::npos) {
2783 return 2;
2784 } else if (name.find("bn_minus1") != string::npos) {
2785 return -1;
2786 }
2787 return -999;
2788 };
2789
2790 // Reorganize ops_executed for further testing.
2791 std::unordered_map<string, int> op_count;
2792 std::unordered_map<int, string> recv_op_names;
2793 std::unordered_map<int, string> send_op_names;
2794 for (const auto& x : ops_executed) {
2795 const auto& name = x.first;
2796 const auto& node_info = x.second;
2797 const auto& op = node_info.op_info.op();
2798 if (op == kRecv) {
2799 recv_op_names[get_port_num(name)] = name;
2800 } else if (op == kSend) {
2801 send_op_names[get_port_num(name)] = name;
2802 }
2803 op_count[op]++;
2804 }
2805
2806 // Same number of _Send and _Recv.
2807 EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv));
2808
2809 // Expect 4 _Send and 4 _Recv ops: ports 0, 1, and 2, plus the control dependency.
2810 EXPECT_EQ(op_count.at(kRecv), 4);
2811 EXPECT_EQ(op_count.at(kSend), 4);
2812
2813 // Helper lambda for extracting output Tensor size.
2814 auto get_output_size = [this, ops_executed](const string& name) -> int64 {
2815 const auto& output_properties_ = ops_executed.at(name).op_info.outputs();
2816 std::vector<OpInfo::TensorProperties> output_properties;
2817 for (const auto& output_property : output_properties_) {
2818 output_properties.push_back(output_property);
2819 }
2820 return CalculateOutputSize(output_properties, 0);
2821 };
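// CalculateOutputSize() (from grappler costs utils) is expected to return the
// size in bytes of the given output port, so the checks below compare the
// number of bytes transferred across devices.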
2822
2823 // Validate transfer size.
2824 // BatchNorm output y is a 4-D tensor: batch x width x height x depth_in.
2825 int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
2826 EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
2827 EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
2828 // Mean and var are 1-D vectors of size depth_in_.
2829 EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
2830 EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
2831 EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
2832 EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
2833 // Control dependency size is 4B.
2834 EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
2835 EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
2836 }
2837
2838 TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
2839 // Init.
2840 CreateGrapplerItemWithSendRecv();
2841 InitScheduler();
2842
2843 // Run the scheduler.
2844 auto ops_executed = RunScheduler("");
2845
2846 EXPECT_GT(ops_executed.count("Const"), 0);
2847 EXPECT_GT(ops_executed.count("Send"), 0);
2848 EXPECT_GT(ops_executed.count("Recv"), 0);
2849 }
2850
2851 TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
2852 // Init.
2853 CreateGrapplerItemWithSendRecv();
2854 // Change Recv node's device so that Send and Recv are placed on different
2855 // devices.
2856 auto& graph = grappler_item_->graph;
2857 const string recv_device = kCPU1;
2858 for (int i = 0; i < graph.node_size(); i++) {
2859 auto* node = graph.mutable_node(i);
2860 if (node->name() == "Recv") {
2861 node->set_device(recv_device);
2862 auto* attr = node->mutable_attr();
2863 (*attr)["recv_device"].set_s(recv_device);
2864 } else if (node->name() == "Send") {
2865 auto* attr = node->mutable_attr();
2866 (*attr)["recv_device"].set_s(recv_device);
2867 }
2868 }
2869 InitScheduler();
2870
2871 // Run the scheduler.
2872 auto ops_executed = RunScheduler("");
2873
2874 // Expect Const, Send, Recv, and the _Send and _Recv ops created by the
2874 // VirtualScheduler.
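// The names below follow the scheduler's naming scheme for the _Send/_Recv
// pair it inserts for a cross-device edge, roughly
// "Send_<node>_<port>_from_<src_device>_to_<dst_device>" and
// "Recv_<node>_<port>_on_<dst_device>", with ':' in device names replaced
// by '_'.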
2875 EXPECT_GT(ops_executed.count("Const"), 0);
2876 EXPECT_GT(ops_executed.count("Send"), 0);
2877 EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/"
2878 "task_0/cpu_0_to_/job_localhost"
2879 "/replica_0/task_0/cpu_1"),
2880 0);
2881 EXPECT_GT(ops_executed.count(
2882 "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
2883 0);
2884 EXPECT_GT(ops_executed.count("Recv"), 0);
2885 }
2886
2887 TEST_F(VirtualSchedulerTest, GraphWithOnlyRecv) {
2888 // Init.
2889 CreateGrapplerItemWithRecvWithoutSend();
2890 InitScheduler();
2891
2892 // Run the scheduler.
2893 auto ops_executed = RunScheduler("");
2894
2895 // A Recv without a matching Send is treated as an initially ready node.
2896 EXPECT_GT(ops_executed.count("Recv"), 0);
2897 }
2898
2899 } // end namespace grappler
2900 } // end namespace tensorflow
2901