• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
17 #include "tensorflow/cc/ops/standard_ops.h"
18 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
19 #include "tensorflow/core/framework/tensor_description.pb.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
22 #include "tensorflow/core/grappler/costs/utils.h"
23 #include "tensorflow/core/grappler/costs/virtual_placer.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 
30 // Class for testing virtual scheduler.
31 class TestVirtualScheduler : public VirtualScheduler {
32  public:
TestVirtualScheduler(const bool use_static_shapes,const bool use_aggressive_shape_inference,Cluster * cluster)33   TestVirtualScheduler(const bool use_static_shapes,
34                        const bool use_aggressive_shape_inference,
35                        Cluster* cluster)
36       : VirtualScheduler(use_static_shapes, use_aggressive_shape_inference,
37                          cluster, &ready_node_manager_) {
38     enable_mem_usage_tracking();
39   }
40 
41   FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
42   FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
43   FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
44   FRIEND_TEST(VirtualSchedulerTest, Variable);
45   FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
46 
47  protected:
48   FirstReadyManager ready_node_manager_;
49 };
50 
51 class VirtualSchedulerTest : public ::testing::Test {
52  protected:
VirtualSchedulerTest()53   VirtualSchedulerTest() {
54     // node1_ to node6_ on kCPU0, with time_ready in reverse_order.
55     NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_);
56     NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_);
57     NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_);
58     NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_);
59     NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_);
60     NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_);
61 
62     // Initializes cluster_ and scheduler_.
63     std::unordered_map<string, DeviceProperties> devices;
64 
65     // Set some dummy CPU properties
66     DeviceProperties cpu_device = GetDummyCPUDevice();
67 
68     // IMPORTANT: Device is not actually ever used in the test case since
69     // force_cpu_type is defaulted to "Haswell"
70     devices[kCPU0] = cpu_device;
71     devices[kCPU1] = cpu_device;
72     cluster_ = absl::make_unique<VirtualCluster>(devices);
73     scheduler_ = absl::make_unique<TestVirtualScheduler>(
74         /*use_static_shapes=*/true,
75         /*use_aggressive_shape_inference=*/true, cluster_.get());
76   }
77 
78   NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
79   std::unordered_map<const NodeDef*, NodeState> node_states_;
80 
81   // Device names:
82   const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0";
83   const string kCPU1 = "/job:localhost/replica:0/task:0/cpu:1";
84   const string kChannelFrom0To1 = "Channel from CPU0 to CPU1";
85   const string kChannelFrom1To0 = "Channel from CPU1 to CPU0";
86   // Op names:
87   const string kSend = "_Send";
88   const string kRecv = "_Recv";
89   const string kConv2D = "Conv2D";
90 
GetDummyCPUDevice()91   DeviceProperties GetDummyCPUDevice() {
92     // Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth.
93     // - 8 Gflops
94     // - 2 GB/s
95     DeviceProperties cpu_device;
96     cpu_device.set_type("CPU");
97     cpu_device.set_frequency(4000);
98     cpu_device.set_num_cores(2);
99     cpu_device.set_bandwidth(2000000);
100     return cpu_device;
101   }
102 
NodeSetUp(const string & name,const string & op_name,const string & device_name,const uint64 time_ready,NodeDef * node)103   void NodeSetUp(const string& name, const string& op_name,
104                  const string& device_name, const uint64 time_ready,
105                  NodeDef* node) {
106     node->set_name(name);
107     node->set_op(op_name);
108     node->set_device(device_name);
109 
110     node_states_[node] = NodeState();
111     node_states_[node].time_ready = time_ready;
112     node_states_[node].device_name = device_name;
113   }
114 
115   // Three Conv2Ds with only two in fetch nodes.
CreateGrapplerItemWithConv2Ds()116   void CreateGrapplerItemWithConv2Ds() {
117     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
118     auto x = ops::RandomUniform(
119         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
120     auto y = ops::RandomUniform(
121         s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
122     auto z = ops::RandomUniform(
123         s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
124     auto f = ops::RandomUniform(
125         s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
126     std::vector<int> strides = {1, 1, 1, 1};
127     auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
128     auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
129     auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
130     GraphDef def;
131     TF_CHECK_OK(s.ToGraphDef(&def));
132 
133     grappler_item_.reset(new GrapplerItem);
134     grappler_item_->id = "test_conv2d_graph";
135     grappler_item_->graph = def;
136     grappler_item_->fetch = {"c0", "c1"};
137 
138     dependency_["c0"] = {"x", "f"};
139     dependency_["c1"] = {"y", "f"};
140   }
141 
142   // A Conv2D with a variable.
CreateGrapplerItemWithConv2DAndVariable()143   void CreateGrapplerItemWithConv2DAndVariable() {
144     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
145     auto x = ops::RandomUniform(
146         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
147     auto f = ops::Variable(s.WithOpName("f"),
148                            {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
149     std::vector<int> strides = {1, 1, 1, 1};
150     auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
151     GraphDef def;
152     TF_CHECK_OK(s.ToGraphDef(&def));
153 
154     grappler_item_.reset(new GrapplerItem);
155     grappler_item_->id = "test_conv2d_var_graph";
156     grappler_item_->graph = def;
157     grappler_item_->fetch = {"y"};
158 
159     dependency_["y"] = {"x", "f"};
160   }
161 
CreateGrapplerItemWithMatmulChain()162   void CreateGrapplerItemWithMatmulChain() {
163     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
164     // Add control dependencies to ensure tests do not rely on specific
165     // manager and the order remains consistent for the test.
166     auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT);
167     auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a),
168                                 {3200, 3200}, DT_FLOAT);
169     auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b),
170                                 {3200, 3200}, DT_FLOAT);
171     auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c),
172                                 {3200, 3200}, DT_FLOAT);
173     auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d),
174                                 {3200, 3200}, DT_FLOAT);
175 
176     auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b);
177     auto abc = ops::MatMul(s.WithOpName("abc"), ab, c);
178     auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d);
179     auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e);
180 
181     GraphDef def;
182     TF_CHECK_OK(s.ToGraphDef(&def));
183 
184     grappler_item_.reset(new GrapplerItem);
185     grappler_item_->id = "test_matmul_sequence_graph";
186     grappler_item_->graph = def;
187     grappler_item_->fetch = {"abcde"};
188 
189     dependency_["ab"] = {"a", "b"};
190     dependency_["abc"] = {"ab", "c"};
191     dependency_["abcd"] = {"abc", "d"};
192     dependency_["abcde"] = {"abcd", "e"};
193   }
194 
195   // AddN that takes 4 tensors with 10x10x10x10.
CreateGrapplerItemWithAddN()196   void CreateGrapplerItemWithAddN() {
197     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
198     auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT);
199     auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT);
200     auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT);
201     auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT);
202     OutputList input_tensors = {x, y, z, w};
203     auto out = ops::AddN(s.WithOpName("out"), input_tensors);
204     GraphDef def;
205     TF_CHECK_OK(s.ToGraphDef(&def));
206 
207     grappler_item_.reset(new GrapplerItem);
208     grappler_item_->id = "test_addn_graph";
209     grappler_item_->graph = def;
210     grappler_item_->fetch = {"out"};
211 
212     dependency_["out"] = {"x", "y", "z", "w"};
213   }
214 
215   // Graph with some placeholder feed nodes that are not in the fetch fan-in.
CreateGrapplerItemWithUnnecessaryPlaceholderNodes()216   void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() {
217     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
218     auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT);
219     auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT);
220 
221     GraphDef def;
222     TF_CHECK_OK(s.ToGraphDef(&def));
223 
224     grappler_item_.reset(new GrapplerItem);
225     grappler_item_->id = "test_extra_placeholders";
226     grappler_item_->graph = def;
227     grappler_item_->fetch = {"x"};
228 
229     // Grappler Item Builder puts all placeholder nodes into the feed
230     // list by default.
231     grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}};
232   }
233 
234   // NoOp that takes 7 NoOps as control dependency.
CreateGrapplerItemWithControlDependency()235   void CreateGrapplerItemWithControlDependency() {
236     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
237     std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
238     std::vector<Operation> input_tensors;
239     for (const auto& input : input_noop_names) {
240       auto x = ops::NoOp(s.WithOpName(input));
241       input_tensors.push_back(x.operation);
242     }
243     auto out =
244         ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out"));
245     GraphDef def;
246     TF_CHECK_OK(s.ToGraphDef(&def));
247 
248     grappler_item_.reset(new GrapplerItem);
249     grappler_item_->id = "test_control_dependency_graph";
250     grappler_item_->graph = def;
251     grappler_item_->fetch = {"out"};
252 
253     dependency_["out"] = input_noop_names;
254   }
255 
256   // FusedBN [an op with multiple outputs] with multiple consumers (including
257   // control dependency).
CreateGrapplerItemWithBatchNorm()258   void CreateGrapplerItemWithBatchNorm() {
259     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
260     auto x = ops::RandomUniform(
261         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
262     auto scale =
263         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
264     auto offset =
265         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
266     auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
267     auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
268 
269     auto batch_norm = ops::FusedBatchNorm(
270         s.WithOpName("bn"), x, scale, offset, mean, var,
271         ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
272     auto y = batch_norm.y;
273     auto batch_mean = batch_norm.batch_mean;
274     auto batch_var = batch_norm.batch_variance;
275 
276     auto z1 = ops::Add(s.WithOpName("z1"), x, y);
277     auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var);
278     auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var);
279     std::vector<Operation> input_tensors = {
280         batch_mean.op(),
281         z1.z.op(),
282         z2.z.op(),
283         z3.z.op(),
284     };
285     auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4"));
286 
287     GraphDef def;
288     TF_CHECK_OK(s.ToGraphDef(&def));
289 
290     grappler_item_.reset(new GrapplerItem);
291     grappler_item_->id = "test_complex_dependency_graph";
292     grappler_item_->graph = def;
293     grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
294 
295     dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
296     dependency_["z1"] = {"x", "bn"};
297     dependency_["z2"] = {"bn"};
298     dependency_["z3"] = {"bn"};
299     dependency_["z4"] = {"bn"};
300   }
301 
CreateGrapplerItemWithSendRecv()302   void CreateGrapplerItemWithSendRecv() {
303     const string gdef_ascii = R"EOF(
304 node {
305   name: "Const"
306   op: "Const"
307   device: "/job:localhost/replica:0/task:0/device:CPU:0"
308   attr {
309     key: "dtype"
310     value {
311       type: DT_FLOAT
312     }
313   }
314   attr {
315     key: "value"
316     value {
317       tensor {
318         dtype: DT_FLOAT
319         tensor_shape {
320         }
321         float_val: 3.1415
322       }
323     }
324   }
325 }
326 node {
327   name: "Send"
328   op: "_Send"
329   input: "Const"
330   device: "/job:localhost/replica:0/task:0/device:CPU:0"
331   attr {
332     key: "T"
333     value {
334       type: DT_FLOAT
335     }
336   }
337   attr {
338     key: "client_terminated"
339     value {
340       b: false
341     }
342   }
343   attr {
344     key: "recv_device"
345     value {
346       s: "/job:localhost/replica:0/task:0/device:CPU:0"
347     }
348   }
349   attr {
350     key: "send_device"
351     value {
352       s: "/job:localhost/replica:0/task:0/device:CPU:0"
353     }
354   }
355   attr {
356     key: "send_device_incarnation"
357     value {
358       i: 0
359     }
360   }
361   attr {
362     key: "tensor_name"
363     value {
364       s: "test"
365     }
366   }
367 }
368 node {
369   name: "Recv"
370   op: "_Recv"
371   device: "/job:localhost/replica:0/task:0/device:CPU:0"
372   attr {
373     key: "client_terminated"
374     value {
375       b: false
376     }
377   }
378   attr {
379     key: "recv_device"
380     value {
381       s: "/job:localhost/replica:0/task:0/device:CPU:0"
382     }
383   }
384   attr {
385     key: "send_device"
386     value {
387       s: "/job:localhost/replica:0/task:0/device:CPU:0"
388     }
389   }
390   attr {
391     key: "send_device_incarnation"
392     value {
393       i: 0
394     }
395   }
396   attr {
397     key: "tensor_name"
398     value {
399       s: "test"
400     }
401   }
402   attr {
403     key: "tensor_type"
404     value {
405       type: DT_FLOAT
406     }
407   }
408 }
409 library {
410 }
411 versions {
412   producer: 24
413 }
414     )EOF";
415 
416     grappler_item_.reset(new GrapplerItem);
417     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
418                                                 &grappler_item_->graph));
419     grappler_item_->id = "test_graph";
420     grappler_item_->fetch = {"Recv"};
421   }
422 
CreateGrapplerItemWithRecvWithoutSend()423   void CreateGrapplerItemWithRecvWithoutSend() {
424     const string gdef_ascii = R"EOF(
425 node {
426   name: "Recv"
427   op: "_Recv"
428   device: "/job:localhost/replica:0/task:0/device:CPU:0"
429   attr {
430     key: "client_terminated"
431     value {
432       b: false
433     }
434   }
435   attr {
436     key: "recv_device"
437     value {
438       s: "/job:localhost/replica:0/task:0/device:CPU:0"
439     }
440   }
441   attr {
442     key: "send_device"
443     value {
444       s: "/job:localhost/replica:0/task:0/device:CPU:0"
445     }
446   }
447   attr {
448     key: "send_device_incarnation"
449     value {
450       i: 0
451     }
452   }
453   attr {
454     key: "tensor_name"
455     value {
456       s: "test"
457     }
458   }
459   attr {
460     key: "tensor_type"
461     value {
462       type: DT_FLOAT
463     }
464   }
465 }
466 library {
467 }
468 versions {
469   producer: 24
470 }
471     )EOF";
472 
473     grappler_item_.reset(new GrapplerItem);
474     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
475                                                 &grappler_item_->graph));
476     grappler_item_->id = "test_graph";
477     grappler_item_->fetch = {"Recv"};
478   }
479 
480   // A simple while loop
CreateGrapplerItemWithLoop()481   void CreateGrapplerItemWithLoop() {
482     // Test graph produced in python using:
483     /*
484       with tf.Graph().as_default():
485       i0 = tf.constant(0)
486       m0 = tf.ones([2, 2])
487       c = lambda i, m: i < 10
488       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
489       r = tf.while_loop(
490       c, b, loop_vars=[i0, m0],
491       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
492       with open('/tmp/graph.pbtxt', 'w') as f:
493       f.write(str(tf.get_default_graph().as_graph_def()))
494     */
495     const string gdef_ascii = R"EOF(
496 node {
497   name: "Const"
498   op: "Const"
499   attr {
500     key: "dtype"
501     value {
502       type: DT_INT32
503     }
504   }
505   attr {
506     key: "value"
507     value {
508       tensor {
509         dtype: DT_INT32
510         tensor_shape {
511         }
512         int_val: 0
513       }
514     }
515   }
516 }
517 node {
518   name: "ones"
519   op: "Const"
520   attr {
521     key: "dtype"
522     value {
523       type: DT_FLOAT
524     }
525   }
526   attr {
527     key: "value"
528     value {
529       tensor {
530         dtype: DT_FLOAT
531         tensor_shape {
532           dim {
533             size: 2
534           }
535           dim {
536             size: 2
537           }
538         }
539         float_val: 1.0
540       }
541     }
542   }
543 }
544 node {
545   name: "while/Enter"
546   op: "Enter"
547   input: "Const"
548   attr {
549     key: "T"
550     value {
551       type: DT_INT32
552     }
553   }
554   attr {
555     key: "frame_name"
556     value {
557       s: "while/while/"
558     }
559   }
560   attr {
561     key: "is_constant"
562     value {
563       b: false
564     }
565   }
566   attr {
567     key: "parallel_iterations"
568     value {
569       i: 10
570     }
571   }
572 }
573 node {
574   name: "while/Enter_1"
575   op: "Enter"
576   input: "ones"
577   attr {
578     key: "T"
579     value {
580       type: DT_FLOAT
581     }
582   }
583   attr {
584     key: "frame_name"
585     value {
586       s: "while/while/"
587     }
588   }
589   attr {
590     key: "is_constant"
591     value {
592       b: false
593     }
594   }
595   attr {
596     key: "parallel_iterations"
597     value {
598       i: 10
599     }
600   }
601 }
602 node {
603   name: "while/Merge"
604   op: "Merge"
605   input: "while/Enter"
606   input: "while/NextIteration"
607   attr {
608     key: "N"
609     value {
610       i: 2
611     }
612   }
613   attr {
614     key: "T"
615     value {
616       type: DT_INT32
617     }
618   }
619 }
620 node {
621   name: "while/Merge_1"
622   op: "Merge"
623   input: "while/Enter_1"
624   input: "while/NextIteration_1"
625   attr {
626     key: "N"
627     value {
628       i: 2
629     }
630   }
631   attr {
632     key: "T"
633     value {
634       type: DT_FLOAT
635     }
636   }
637 }
638 node {
639   name: "while/Less/y"
640   op: "Const"
641   input: "^while/Merge"
642   attr {
643     key: "dtype"
644     value {
645       type: DT_INT32
646     }
647   }
648   attr {
649     key: "value"
650     value {
651       tensor {
652         dtype: DT_INT32
653         tensor_shape {
654         }
655         int_val: 10
656       }
657     }
658   }
659 }
660 node {
661   name: "while/Less"
662   op: "Less"
663   input: "while/Merge"
664   input: "while/Less/y"
665   attr {
666     key: "T"
667     value {
668       type: DT_INT32
669     }
670   }
671 }
672 node {
673   name: "while/LoopCond"
674   op: "LoopCond"
675   input: "while/Less"
676 }
677 node {
678   name: "while/Switch"
679   op: "Switch"
680   input: "while/Merge"
681   input: "while/LoopCond"
682   attr {
683     key: "T"
684     value {
685       type: DT_INT32
686     }
687   }
688   attr {
689     key: "_class"
690     value {
691       list {
692         s: "loc:@while/Merge"
693       }
694     }
695   }
696 }
697 node {
698   name: "while/Switch_1"
699   op: "Switch"
700   input: "while/Merge_1"
701   input: "while/LoopCond"
702   attr {
703     key: "T"
704     value {
705       type: DT_FLOAT
706     }
707   }
708   attr {
709     key: "_class"
710     value {
711       list {
712         s: "loc:@while/Merge_1"
713       }
714     }
715   }
716 }
717 node {
718   name: "while/Identity"
719   op: "Identity"
720   input: "while/Switch:1"
721   attr {
722     key: "T"
723     value {
724       type: DT_INT32
725     }
726   }
727 }
728 node {
729   name: "while/Identity_1"
730   op: "Identity"
731   input: "while/Switch_1:1"
732   attr {
733     key: "T"
734     value {
735       type: DT_FLOAT
736     }
737   }
738 }
739 node {
740   name: "while/add/y"
741   op: "Const"
742   input: "^while/Identity"
743   attr {
744     key: "dtype"
745     value {
746       type: DT_INT32
747     }
748   }
749   attr {
750     key: "value"
751     value {
752       tensor {
753         dtype: DT_INT32
754         tensor_shape {
755         }
756         int_val: 1
757       }
758     }
759   }
760 }
761 node {
762   name: "while/add"
763   op: "Add"
764   input: "while/Identity"
765   input: "while/add/y"
766   attr {
767     key: "T"
768     value {
769       type: DT_INT32
770     }
771   }
772 }
773 node {
774   name: "while/concat/axis"
775   op: "Const"
776   input: "^while/Identity"
777   attr {
778     key: "dtype"
779     value {
780       type: DT_INT32
781     }
782   }
783   attr {
784     key: "value"
785     value {
786       tensor {
787         dtype: DT_INT32
788         tensor_shape {
789         }
790         int_val: 0
791       }
792     }
793   }
794 }
795 node {
796   name: "while/concat"
797   op: "ConcatV2"
798   input: "while/Identity_1"
799   input: "while/Identity_1"
800   input: "while/concat/axis"
801   attr {
802     key: "N"
803     value {
804       i: 2
805     }
806   }
807   attr {
808     key: "T"
809     value {
810       type: DT_FLOAT
811     }
812   }
813   attr {
814     key: "Tidx"
815     value {
816       type: DT_INT32
817     }
818   }
819 }
820 node {
821   name: "while/NextIteration"
822   op: "NextIteration"
823   input: "while/add"
824   attr {
825     key: "T"
826     value {
827       type: DT_INT32
828     }
829   }
830 }
831 node {
832   name: "while/NextIteration_1"
833   op: "NextIteration"
834   input: "while/concat"
835   attr {
836     key: "T"
837     value {
838       type: DT_FLOAT
839     }
840   }
841 }
842 node {
843   name: "while/Exit"
844   op: "Exit"
845   input: "while/Switch"
846   attr {
847     key: "T"
848     value {
849       type: DT_INT32
850     }
851   }
852 }
853 node {
854   name: "while/Exit_1"
855   op: "Exit"
856   input: "while/Switch_1"
857   attr {
858     key: "T"
859     value {
860       type: DT_FLOAT
861     }
862   }
863 }
864 versions {
865   producer: 21
866 }
867   )EOF";
868 
869     grappler_item_.reset(new GrapplerItem);
870     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
871                                                 &grappler_item_->graph));
872     grappler_item_->id = "test_graph";
873     grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
874   }
875 
  // A simple while loop whose nodes are annotated with "_execution_count";
  // the Switch nodes additionally carry an "_output_slot_vector" annotation.
CreateGrapplerItemWithLoopAnnotated()877   void CreateGrapplerItemWithLoopAnnotated() {
878     // Test graph produced in python using:
879     /*
880       with tf.Graph().as_default():
881       i0 = tf.constant(0)
882       m0 = tf.ones([2, 2])
883       c = lambda i, m: i < 10
884       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
885       r = tf.while_loop(
886       c, b, loop_vars=[i0, m0],
887       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
888       with open('/tmp/graph.pbtxt', 'w') as f:
889       f.write(str(tf.get_default_graph().as_graph_def()))
890     */
891     const string gdef_ascii = R"EOF(
892 node {
893   name: "Const"
894   op: "Const"
895   attr {
896     key: "dtype"
897     value {
898       type: DT_INT32
899     }
900   }
901   attr {
902     key: "value"
903     value {
904       tensor {
905         dtype: DT_INT32
906         tensor_shape {
907         }
908         int_val: 0
909       }
910     }
911   }
912   attr {
913     key: "_execution_count"
914     value {
915       i: 1
916     }
917   }
918 }
919 node {
920   name: "ones"
921   op: "Const"
922   attr {
923     key: "dtype"
924     value {
925       type: DT_FLOAT
926     }
927   }
928   attr {
929     key: "value"
930     value {
931       tensor {
932         dtype: DT_FLOAT
933         tensor_shape {
934           dim {
935             size: 2
936           }
937           dim {
938             size: 2
939           }
940         }
941         float_val: 1.0
942       }
943     }
944   }
945   attr {
946     key: "_execution_count"
947     value {
948       i: 1
949     }
950   }
951 }
952 node {
953   name: "while/Enter"
954   op: "Enter"
955   input: "Const"
956   attr {
957     key: "T"
958     value {
959       type: DT_INT32
960     }
961   }
962   attr {
963     key: "frame_name"
964     value {
965       s: "while/while/"
966     }
967   }
968   attr {
969     key: "is_constant"
970     value {
971       b: false
972     }
973   }
974   attr {
975     key: "parallel_iterations"
976     value {
977       i: 10
978     }
979   }
980   attr {
981     key: "_execution_count"
982     value {
983       i: 1
984     }
985   }
986 }
987 node {
988   name: "while/Enter_1"
989   op: "Enter"
990   input: "ones"
991   attr {
992     key: "T"
993     value {
994       type: DT_FLOAT
995     }
996   }
997   attr {
998     key: "frame_name"
999     value {
1000       s: "while/while/"
1001     }
1002   }
1003   attr {
1004     key: "is_constant"
1005     value {
1006       b: false
1007     }
1008   }
1009   attr {
1010     key: "parallel_iterations"
1011     value {
1012       i: 10
1013     }
1014   }
1015   attr {
1016     key: "_execution_count"
1017     value {
1018       i: 1
1019     }
1020   }
1021 }
1022 node {
1023   name: "while/Merge"
1024   op: "Merge"
1025   input: "while/Enter"
1026   input: "while/NextIteration"
1027   attr {
1028     key: "N"
1029     value {
1030       i: 2
1031     }
1032   }
1033   attr {
1034     key: "T"
1035     value {
1036       type: DT_INT32
1037     }
1038   }
1039   attr {
1040     key: "_execution_count"
1041     value {
1042       i: 10
1043     }
1044   }
1045 }
1046 node {
1047   name: "while/Merge_1"
1048   op: "Merge"
1049   input: "while/Enter_1"
1050   input: "while/NextIteration_1"
1051   attr {
1052     key: "N"
1053     value {
1054       i: 2
1055     }
1056   }
1057   attr {
1058     key: "T"
1059     value {
1060       type: DT_FLOAT
1061     }
1062   }
1063   attr {
1064     key: "_execution_count"
1065     value {
1066       i: 10
1067     }
1068   }
1069 }
1070 node {
1071   name: "while/Less/y"
1072   op: "Const"
1073   input: "^while/Merge"
1074   attr {
1075     key: "dtype"
1076     value {
1077       type: DT_INT32
1078     }
1079   }
1080   attr {
1081     key: "value"
1082     value {
1083       tensor {
1084         dtype: DT_INT32
1085         tensor_shape {
1086         }
1087         int_val: 10
1088       }
1089     }
1090   }
1091   attr {
1092     key: "_execution_count"
1093     value {
1094       i: 10
1095     }
1096   }
1097 }
1098 node {
1099   name: "while/Less"
1100   op: "Less"
1101   input: "while/Merge"
1102   input: "while/Less/y"
1103   attr {
1104     key: "T"
1105     value {
1106       type: DT_INT32
1107     }
1108   }
1109   attr {
1110     key: "_execution_count"
1111     value {
1112       i: 10
1113     }
1114   }
1115 }
1116 node {
1117   name: "while/LoopCond"
1118   op: "LoopCond"
1119   input: "while/Less"
1120   attr {
1121     key: "_execution_count"
1122     value {
1123       i: 10
1124     }
1125   }
1126 }
1127 node {
1128   name: "while/Switch"
1129   op: "Switch"
1130   input: "while/Merge"
1131   input: "while/LoopCond"
1132   attr {
1133     key: "T"
1134     value {
1135       type: DT_INT32
1136     }
1137   }
1138   attr {
1139     key: "_class"
1140     value {
1141       list {
1142         s: "loc:@while/Merge"
1143       }
1144     }
1145   }
1146   attr {
1147     key: "_execution_count"
1148     value {
1149       i: 11
1150     }
1151   }
1152   attr {
1153     key: "_output_slot_vector"
1154     value {
1155       list {
1156         i: 1
1157         i: 1
1158         i: 1
1159         i: 1
1160         i: 1
1161         i: 1
1162         i: 1
1163         i: 1
1164         i: 1
1165         i: 1
1166         i: 0
1167       }
1168     }
1169   }
1170 }
1171 node {
1172   name: "while/Switch_1"
1173   op: "Switch"
1174   input: "while/Merge_1"
1175   input: "while/LoopCond"
1176   attr {
1177     key: "T"
1178     value {
1179       type: DT_FLOAT
1180     }
1181   }
1182   attr {
1183     key: "_class"
1184     value {
1185       list {
1186         s: "loc:@while/Merge_1"
1187       }
1188     }
1189   }
1190   attr {
1191     key: "_execution_count"
1192     value {
1193       i: 11
1194     }
1195   }
1196   attr {
1197     key: "_output_slot_vector"
1198     value {
1199       list {
1200         i: 1
1201         i: 1
1202         i: 1
1203         i: 1
1204         i: 1
1205         i: 1
1206         i: 1
1207         i: 1
1208         i: 1
1209         i: 1
1210         i: 0
1211       }
1212     }
1213   }
1214 }
1215 node {
1216   name: "while/Identity"
1217   op: "Identity"
1218   input: "while/Switch:1"
1219   attr {
1220     key: "T"
1221     value {
1222       type: DT_INT32
1223     }
1224   }
1225   attr {
1226     key: "_execution_count"
1227     value {
1228       i: 10
1229     }
1230   }
1231 }
1232 node {
1233   name: "while/Identity_1"
1234   op: "Identity"
1235   input: "while/Switch_1:1"
1236   attr {
1237     key: "T"
1238     value {
1239       type: DT_FLOAT
1240     }
1241   }
1242   attr {
1243     key: "_execution_count"
1244     value {
1245       i: 10
1246     }
1247   }
1248 }
1249 node {
1250   name: "while/add/y"
1251   op: "Const"
1252   input: "^while/Identity"
1253   attr {
1254     key: "dtype"
1255     value {
1256       type: DT_INT32
1257     }
1258   }
1259   attr {
1260     key: "value"
1261     value {
1262       tensor {
1263         dtype: DT_INT32
1264         tensor_shape {
1265         }
1266         int_val: 1
1267       }
1268     }
1269   }
1270   attr {
1271     key: "_execution_count"
1272     value {
1273       i: 10
1274     }
1275   }
1276 }
1277 node {
1278   name: "while/add"
1279   op: "Add"
1280   input: "while/Identity"
1281   input: "while/add/y"
1282   attr {
1283     key: "T"
1284     value {
1285       type: DT_INT32
1286     }
1287   }
1288   attr {
1289     key: "_execution_count"
1290     value {
1291       i: 10
1292     }
1293   }
1294 }
1295 node {
1296   name: "while/concat/axis"
1297   op: "Const"
1298   input: "^while/Identity"
1299   attr {
1300     key: "dtype"
1301     value {
1302       type: DT_INT32
1303     }
1304   }
1305   attr {
1306     key: "value"
1307     value {
1308       tensor {
1309         dtype: DT_INT32
1310         tensor_shape {
1311         }
1312         int_val: 0
1313       }
1314     }
1315   }
1316   attr {
1317     key: "_execution_count"
1318     value {
1319       i: 10
1320     }
1321   }
1322 }
1323 node {
1324   name: "while/concat"
1325   op: "ConcatV2"
1326   input: "while/Identity_1"
1327   input: "while/Identity_1"
1328   input: "while/concat/axis"
1329   attr {
1330     key: "N"
1331     value {
1332       i: 2
1333     }
1334   }
1335   attr {
1336     key: "T"
1337     value {
1338       type: DT_FLOAT
1339     }
1340   }
1341   attr {
1342     key: "Tidx"
1343     value {
1344       type: DT_INT32
1345     }
1346   }
1347   attr {
1348     key: "_execution_count"
1349     value {
1350       i: 10
1351     }
1352   }
1353 }
1354 node {
1355   name: "while/NextIteration"
1356   op: "NextIteration"
1357   input: "while/add"
1358   attr {
1359     key: "T"
1360     value {
1361       type: DT_INT32
1362     }
1363   }
1364   attr {
1365     key: "_execution_count"
1366     value {
1367       i: 10
1368     }
1369   }
1370 }
1371 node {
1372   name: "while/NextIteration_1"
1373   op: "NextIteration"
1374   input: "while/concat"
1375   attr {
1376     key: "T"
1377     value {
1378       type: DT_FLOAT
1379     }
1380   }
1381   attr {
1382     key: "_execution_count"
1383     value {
1384       i: 10
1385     }
1386   }
1387 }
1388 node {
1389   name: "while/Exit"
1390   op: "Exit"
1391   input: "while/Switch"
1392   attr {
1393     key: "T"
1394     value {
1395       type: DT_INT32
1396     }
1397   }
1398   attr {
1399     key: "_execution_count"
1400     value {
1401       i: 1
1402     }
1403   }
1404 }
1405 node {
1406   name: "while/Exit_1"
1407   op: "Exit"
1408   input: "while/Switch_1"
1409   attr {
1410     key: "T"
1411     value {
1412       type: DT_FLOAT
1413     }
1414   }
1415   attr {
1416     key: "_execution_count"
1417     value {
1418       i: 1
1419     }
1420   }
1421 }
1422 versions {
1423   producer: 21
1424 }
1425   )EOF";
1426 
1427     grappler_item_.reset(new GrapplerItem);
1428     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1429                                                 &grappler_item_->graph));
1430     grappler_item_->id = "test_graph";
1431     grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
1432   }
1433 
1434   // A simple condition graph.
CreateGrapplerItemWithCondition()1435   void CreateGrapplerItemWithCondition() {
1436     // Handcrafted test graph: a/Less -> Switch -> First/Second -> Merge.
1437     const string gdef_ascii = R"EOF(
1438 node {
1439   name: "a"
1440   op: "Const"
1441   attr {
1442     key: "dtype"
1443     value {
1444       type: DT_FLOAT
1445     }
1446   }
1447   attr {
1448     key: "value"
1449     value {
1450       tensor {
1451         dtype: DT_FLOAT
1452         tensor_shape {
1453         }
1454         float_val: 2.0
1455       }
1456     }
1457   }
1458 }
1459 node {
1460   name: "Less"
1461   op: "Const"
1462   attr {
1463     key: "dtype"
1464     value {
1465       type: DT_BOOL
1466     }
1467   }
1468   attr {
1469     key: "value"
1470     value {
1471       tensor {
1472         dtype: DT_BOOL
1473         tensor_shape {
1474         }
1475         tensor_content: "\001"
1476       }
1477     }
1478   }
1479 }
1480 node {
1481   name: "Switch"
1482   op: "Switch"
1483   input: "a"
1484   input: "Less"
1485   attr {
1486     key: "T"
1487     value {
1488       type: DT_FLOAT
1489     }
1490   }
1491 }
1492 node {
1493   name: "First"
1494   op: "Identity"
1495   input: "Switch"
1496   attr {
1497     key: "T"
1498     value {
1499       type: DT_FLOAT
1500     }
1501   }
1502 }
1503 node {
1504   name: "Second"
1505   op: "Identity"
1506   input: "Switch:1"
1507   attr {
1508     key: "T"
1509     value {
1510       type: DT_FLOAT
1511     }
1512   }
1513 }
1514 node {
1515   name: "Merge"
1516   op: "Merge"
1517   input: "First"
1518   input: "Second"
1519   attr {
1520     key: "N"
1521     value {
1522       i: 2
1523     }
1524   }
1525   attr {
1526     key: "T"
1527     value {
1528       type: DT_FLOAT
1529     }
1530   }
1531 }
1532 versions {
1533   producer: 27
1534 })EOF";
1535 
1536     grappler_item_.reset(new GrapplerItem);
1537     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1538                                                 &grappler_item_->graph));
1539     grappler_item_->id = "test_graph";
1540     grappler_item_->fetch = {"Merge"};
1541   }
1542 
1543   // A graph in which a multi-output FusedBatchNorm op on kCPU0 feeds consumers placed on kCPU1.
CreateGrapplerItemWithInterDeviceTransfers()1544   void CreateGrapplerItemWithInterDeviceTransfers() {
1545     tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
1546 
1547     // Create a FusedBatchNorm op that has multiple output ports.
1548     auto x = ops::RandomUniform(
1549         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
1550     auto scale =
1551         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
1552     auto offset =
1553         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
1554     auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
1555     auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
1556 
1557     auto batch_norm = ops::FusedBatchNorm(
1558         s.WithOpName("bn"), x, scale, offset, mean, var,
1559         ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
1560     auto y = batch_norm.y;
1561     auto batch_mean = batch_norm.batch_mean;
1562     auto batch_var = batch_norm.batch_variance;
1563     // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
1564     auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
1565     auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
1566     // batch_mean1 and batch_var1 take different output ports, so each will
1567     // initiate Send/Recv.
1568     auto batch_mean1 = ops::Identity(
1569         s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
1570     auto batch_var1 =
1571         ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
1572     // This is control dependency.
1573     auto control_dep = ops::NoOp(s.WithOpName("control_dep")
1574                                      .WithControlDependencies(y)
1575                                      .WithDevice(kCPU1));
1576 
1577     GraphDef def;
1578     TF_CHECK_OK(s.ToGraphDef(&def));
1579 
1580     grappler_item_.reset(new GrapplerItem);
1581     grappler_item_->id = "test_conv2d_graph";
1582     grappler_item_->graph = def;
1583     grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
1584                              "control_dep"};
1585 
1586     dependency_["bn"] = {"x", "mean", "var"};
1587     dependency_["y1"] = {"bn"};
1588     dependency_["y2"] = {"bn"};
1589     dependency_["batch_mean1"] = {"bn"};
1590     dependency_["batch_var1"] = {"bn"};
1591     dependency_["control_dep"] = {"bn"};
1592   }
1593 
1594   // Call this after creating grappler_item_ and setting up dependency_.
InitScheduler()1595   void InitScheduler() { TF_ASSERT_OK(scheduler_->Init(grappler_item_.get())); }
1596 
1597   // Returns cost based on op.
SimplePredictCosts(const OpContext & op_context) const1598   Costs SimplePredictCosts(const OpContext& op_context) const {
1599     Costs c;
1600     int64 exec_cost = 0;
1601     if (op_context.op_info.op() == "MatMul") {
1602       exec_cost = 2000000000;
1603     } else if (op_context.op_info.op() == "RandomUniform") {
1604       exec_cost = 1000000000;
1605     } else {
1606       exec_cost = 1000;
1607     }
1608     c.execution_time = Costs::NanoSeconds(exec_cost);
1609     return c;
1610   }
1611 
1612   // Call this after init scheduler_. Scheduler stops after executing
1613   // target_node.
RunScheduler(const string & target_node)1614   std::unordered_map<string, OpContext> RunScheduler(
1615       const string& target_node) {
1616     std::unordered_map<string, OpContext> ops_executed;
1617     bool more_nodes = true;
1618     do {
1619       OpContext op_context = scheduler_->GetCurrNode();
1620       ops_executed[op_context.name] = op_context;
1621       std::cout << op_context.name << std::endl;
1622 
1623       Costs node_costs = SimplePredictCosts(op_context);
1624 
1625       // Check scheduling order.
1626       auto it = dependency_.find(op_context.name);
1627       if (it != dependency_.end()) {
1628         for (const auto& preceding_node : it->second) {
1629           EXPECT_GT(ops_executed.count(preceding_node), 0);
1630         }
1631       }
1632       more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
1633 
1634       if (op_context.name == target_node) {
1635         // Scheduler has the state after executing the target node.
1636         break;
1637       }
1638     } while (more_nodes);
1639     return ops_executed;
1640   }
1641 
1642   // Helper method for validating a vector.
1643   template <typename T>
ExpectVectorEq(const std::vector<T> & expected,const std::vector<T> & test_elements)1644   void ExpectVectorEq(const std::vector<T>& expected,
1645                       const std::vector<T>& test_elements) {
1646     // Set of expected elements for an easy comparison.
1647     std::set<T> expected_set(expected.begin(), expected.end());
1648     for (const auto& element : test_elements) {
1649       EXPECT_GT(expected_set.count(element), 0);
1650     }
1651     EXPECT_EQ(expected.size(), test_elements.size());
1652   }
1653 
1654   // Helper method that checks the name of nodes.
ValidateNodeDefs(const std::vector<string> & expected,const std::vector<const NodeDef * > & node_defs)1655   void ValidateNodeDefs(const std::vector<string>& expected,
1656                         const std::vector<const NodeDef*>& node_defs) {
1657     std::vector<string> node_names;
1658     std::transform(node_defs.begin(), node_defs.end(),
1659                    std::back_inserter(node_names),
1660                    [](const NodeDef* node) { return node->name(); });
1661     ExpectVectorEq(expected, node_names);
1662   }
1663 
1664   // Helper method for validating a set.
1665   template <typename T>
ExpectSetEq(const std::set<T> & expected,const std::set<T> & test_elements)1666   void ExpectSetEq(const std::set<T>& expected,
1667                    const std::set<T>& test_elements) {
1668     for (const auto& element : test_elements) {
1669       EXPECT_GT(expected.count(element), 0);
1670     }
1671     EXPECT_EQ(expected.size(), test_elements.size());
1672   }
1673 
1674   // Helper method that checks name - port pairs.
ValidateMemoryUsageSnapshot(const std::vector<string> & expected_names,const int port_num_expected,const std::unordered_set<std::pair<const NodeDef *,int>,DeviceState::NodePairHash> & mem_usage_snapshot)1675   void ValidateMemoryUsageSnapshot(
1676       const std::vector<string>& expected_names, const int port_num_expected,
1677       const std::unordered_set<std::pair<const NodeDef*, int>,
1678                                DeviceState::NodePairHash>& mem_usage_snapshot) {
1679     std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
1680     std::transform(
1681         mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
1682         std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
1683         [](const std::pair<const NodeDef*, int>& node_port) {
1684           return std::make_pair(node_port.first->name(), node_port.second);
1685         });
1686     std::set<std::pair<string, int>> expected;
1687     std::transform(expected_names.begin(), expected_names.end(),
1688                    std::inserter(expected, expected.begin()),
1689                    [port_num_expected](const string& name) {
1690                      return std::make_pair(name, port_num_expected);
1691                    });
1692     ExpectSetEq(expected, nodes_at_peak_mem_usage);
1693   }
1694 
1695   // Helper method for checking nodes dependency.
ValidateDependencyChain(const std::unordered_map<string,int64> & start_times,const std::vector<string> & nodes_in_dependency_order)1696   void ValidateDependencyChain(
1697       const std::unordered_map<string, int64>& start_times,
1698       const std::vector<string>& nodes_in_dependency_order) {
1699     int64 prev_node_time = -1;
1700     for (const auto& node : nodes_in_dependency_order) {
1701       int64 curr_node_time = start_times.at(node);
1702       EXPECT_GE(curr_node_time, prev_node_time);
1703       prev_node_time = curr_node_time;
1704     }
1705   }
1706 
1707   // cluster_ and scheduler_ are initialized in the c'tor.
1708   std::unique_ptr<VirtualCluster> cluster_;
  // Scheduler under test; InitScheduler() and RunScheduler() drive it.
1709   std::unique_ptr<TestVirtualScheduler> scheduler_;
1710 
1711   // grappler_item_ will be initialized differently for each test case
  // (each CreateGrapplerItem* helper resets it to a freshly built item).
1712   std::unique_ptr<GrapplerItem> grappler_item_;
1713   // Node name -> its preceding nodes map for testing scheduling order.
  // RunScheduler() expects every listed predecessor to have executed before
  // the node itself is executed.
1714   std::unordered_map<string, std::vector<string>> dependency_;
1715 
1716   // Shared params for Conv2D related graphs:
  // (batch_size_, width_, height_ and depth_in_ also give the input shape of
  // the FusedBatchNorm graph in CreateGrapplerItemWithInterDeviceTransfers().)
1717   const int batch_size_ = 4;
1718   const int width_ = 10;
1719   const int height_ = 10;
1720   const int depth_in_ = 8;
1721   const int kernel_ = 3;
1722   const int depth_out_ = 16;
1723 };
1724 
1725 // Test that FIFOManager correctly returns the current node with only 1 node.
TEST_F(VirtualSchedulerTest, GetSingleNodeFIFOManager) {
  // With exactly one node queued, that node must be the current one.
  FIFOManager manager{};
  manager.AddNode(&node1_);

  const auto* current = manager.GetCurrNode();
  EXPECT_EQ("Node1", current->name());
}
1734 
1735 // Test that FIFOManager removes the only node contained within.
TEST_F(VirtualSchedulerTest, RemoveSingleNodeFIFOManager) {
  // Queue one node and pop it again; the manager must end up empty.
  FIFOManager manager{};
  manager.AddNode(&node1_);
  manager.RemoveCurrNode();

  const bool drained = manager.Empty();
  EXPECT_TRUE(drained);
}
1747 
1748 // Test that FIFOManager can remove multiple nodes and returns the current node
1749 // in the right order
TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFIFOManager) {
  FIFOManager manager{};

  // Queue Node1 through Node4, in that order.
  for (auto* node : {&node1_, &node2_, &node3_, &node4_}) {
    manager.AddNode(node);
  }

  // Nodes must come back out in insertion order.
  for (const char* expected_name : {"Node1", "Node2", "Node3", "Node4"}) {
    EXPECT_EQ(expected_name, manager.GetCurrNode()->name());
    manager.RemoveCurrNode();
  }
  EXPECT_TRUE(manager.Empty());
}
1771 
1772 // Test that FIFOManager can remove multiple nodes and add more nodes, still
1773 // returning the current node in the right order
TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleFIFOManager) {
  FIFOManager manager{};

  // Asserts the name of the manager's current node without removing it.
  const auto expect_current = [&manager](const char* expected_name) {
    EXPECT_EQ(expected_name, manager.GetCurrNode()->name());
  };

  // Start with Node1 through Node4 queued in order.
  for (auto* node : {&node1_, &node2_, &node3_, &node4_}) {
    manager.AddNode(node);
  }

  expect_current("Node1");
  manager.RemoveCurrNode();
  expect_current("Node2");

  // Adding a node must not change the current node; only RemoveCurrNode()
  // moves on to the next one.
  manager.AddNode(&node5_);
  expect_current("Node2");
  manager.RemoveCurrNode();
  expect_current("Node3");
  manager.RemoveCurrNode();
  expect_current("Node4");

  manager.AddNode(&node6_);
  expect_current("Node4");
  manager.RemoveCurrNode();

  // The late additions come out last, in the order they were added.
  expect_current("Node5");
  manager.RemoveCurrNode();
  expect_current("Node6");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
1805 
1806 // Test that LIFOManager correctly returns the current node with only 1 node.
TEST_F(VirtualSchedulerTest, GetSingleNodeLIFOManager) {
  // With exactly one node stacked, that node must be the current one.
  LIFOManager manager{};
  manager.AddNode(&node1_);

  const auto* current = manager.GetCurrNode();
  EXPECT_EQ("Node1", current->name());
}
1815 
1816 // Test that LIFOManager removes the only node contained within.
TEST_F(VirtualSchedulerTest, RemoveSingleNodeLIFOManager) {
  // Stack one node and pop it again; the manager must end up empty.
  LIFOManager manager{};
  manager.AddNode(&node1_);
  manager.RemoveCurrNode();

  const bool drained = manager.Empty();
  EXPECT_TRUE(drained);
}
1828 
1829 // Test that LIFOManager can remove multiple nodes and returns the current node
1830 // in the right order
TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleLIFOManager) {
  LIFOManager manager{};

  // Stack Node1 through Node4, in that order.
  for (auto* node : {&node1_, &node2_, &node3_, &node4_}) {
    manager.AddNode(node);
  }

  // Nodes must come back out in reverse insertion order.
  for (const char* expected_name : {"Node4", "Node3", "Node2", "Node1"}) {
    EXPECT_EQ(expected_name, manager.GetCurrNode()->name());
    manager.RemoveCurrNode();
  }
  EXPECT_TRUE(manager.Empty());
}
1852 
1853 // Test that LIFOManager can remove multiple nodes (must be removing the current
1854 // node) and add more nodes, still returning the current node in the right order
TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleLIFOManager) {
  LIFOManager manager{};

  // Asserts the name of the manager's current node without removing it.
  const auto expect_current = [&manager](const char* expected_name) {
    EXPECT_EQ(expected_name, manager.GetCurrNode()->name());
  };

  // Start with Node1 through Node4 stacked in order.
  for (auto* node : {&node1_, &node2_, &node3_, &node4_}) {
    manager.AddNode(node);
  }

  expect_current("Node4");
  manager.RemoveCurrNode();
  expect_current("Node3");

  // Adding a node must not change the current node; only RemoveCurrNode()
  // moves on to the next one.
  manager.AddNode(&node5_);
  expect_current("Node3");
  manager.RemoveCurrNode();

  // The most recently added node is served next.
  expect_current("Node5");
  manager.RemoveCurrNode();
  expect_current("Node2");

  manager.AddNode(&node6_);
  expect_current("Node2");
  manager.RemoveCurrNode();
  expect_current("Node6");
  manager.RemoveCurrNode();
  expect_current("Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
1886 
TEST_F(VirtualSchedulerTest, GetSingleNodeFirstReadyManager) {
  // With exactly one node added, that node must be the current one.
  FirstReadyManager manager;
  manager.Init(&node_states_);
  manager.AddNode(&node1_);

  const auto* current = manager.GetCurrNode();
  EXPECT_EQ("Node1", current->name());
}
1894 
TEST_F(VirtualSchedulerTest, RemoveSingleNodeFirstReadyManager) {
  // Add one node and remove it again; the manager must end up empty.
  FirstReadyManager manager;
  manager.Init(&node_states_);

  manager.AddNode(&node1_);
  manager.RemoveCurrNode();

  const bool drained = manager.Empty();
  EXPECT_TRUE(drained);
}
1902 
TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFirstReadyManager) {
  FirstReadyManager manager;
  manager.Init(&node_states_);

  // Deliberately shuffled insertion order.
  for (auto* node : {&node2_, &node1_, &node4_, &node5_, &node3_, &node6_}) {
    manager.AddNode(node);
  }

  // The removal order is governed by the nodes' time_ready rather than by the
  // insertion order above.
  for (const char* expected_name :
       {"Node6", "Node5", "Node4", "Node3", "Node2", "Node1"}) {
    EXPECT_EQ(expected_name, manager.GetCurrNode()->name());
    manager.RemoveCurrNode();
  }
  EXPECT_TRUE(manager.Empty());
}
1930 
TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) {
  FirstReadyManager manager;
  manager.Init(&node_states_);

  // Asserts the name of the manager's current node without removing it.
  const auto expect_current = [&manager](const char* expected_name) {
    EXPECT_EQ(expected_name, manager.GetCurrNode()->name());
  };

  // Deliberately shuffled insertion order.
  for (auto* node : {&node2_, &node1_, &node4_, &node5_, &node3_, &node6_}) {
    manager.AddNode(node);
  }

  // Of the fixture nodes, Node6 has the smallest time_ready, so it is the
  // current node.
  expect_current("Node6");

  // Now insert a few more nodes whose time_ready values are even smaller than
  // Node6's. Until RemoveCurrNode() is called, GetCurrNode() must keep
  // returning Node6.
  NodeDef node7;
  NodeDef node8;
  NodeDef node9;
  NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7);
  NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8);
  NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9);

  manager.AddNode(&node7);
  expect_current("Node6");

  manager.AddNode(&node8);
  expect_current("Node6");

  // Once Node6 is removed, Node8 becomes the current node.
  manager.RemoveCurrNode();
  expect_current("Node8");

  // Again, AddNode() alone must not change the current node.
  manager.AddNode(&node9);
  expect_current("Node8");
  manager.RemoveCurrNode();

  // Drain the rest and check the order in which the nodes are served.
  for (const char* expected_name :
       {"Node9", "Node7", "Node5", "Node4", "Node3", "Node2", "Node1"}) {
    expect_current(expected_name);
    manager.RemoveCurrNode();
  }
  EXPECT_TRUE(manager.Empty());
}
1987 
TEST_F(VirtualSchedulerTest,DeterminismInFirstReadyManager)1988 TEST_F(VirtualSchedulerTest, DeterminismInFirstReadyManager) {
1989   FirstReadyManager manager1;
1990   manager1.Init(&node_states_);
1991   FirstReadyManager manager2;
1992   manager2.Init(&node_states_);
1993 
1994   // 6 nodes with same time_ready.
1995   NodeDef node7;
1996   NodeDef node8;
1997   NodeDef node9;
1998   NodeDef node10;
1999   NodeDef node11;
2000   NodeDef node12;
2001   NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
2002   NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8);
2003   NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9);
2004   NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10);
2005   NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11);
2006   NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12);
2007 
2008   // Add the above 6 nodes to manager1.
2009   manager1.AddNode(&node7);
2010   manager1.AddNode(&node8);
2011   manager1.AddNode(&node9);
2012   manager1.AddNode(&node10);
2013   manager1.AddNode(&node11);
2014   manager1.AddNode(&node12);
2015 
2016   // Add the above 6 nodes to manager2, but in a different order.
2017   manager2.AddNode(&node8);
2018   manager2.AddNode(&node11);
2019   manager2.AddNode(&node9);
2020   manager2.AddNode(&node10);
2021   manager2.AddNode(&node7);
2022   manager2.AddNode(&node12);
2023 
2024   // Expect both managers return the same nodes for deterministic node
2025   // scheduling.
2026   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2027   manager1.RemoveCurrNode();
2028   manager2.RemoveCurrNode();
2029 
2030   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2031   manager1.RemoveCurrNode();
2032   manager2.RemoveCurrNode();
2033 
2034   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2035   manager1.RemoveCurrNode();
2036   manager2.RemoveCurrNode();
2037 
2038   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2039   manager1.RemoveCurrNode();
2040   manager2.RemoveCurrNode();
2041 
2042   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2043   manager1.RemoveCurrNode();
2044   manager2.RemoveCurrNode();
2045 
2046   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2047   manager1.RemoveCurrNode();
2048   manager2.RemoveCurrNode();
2049 
2050   EXPECT_TRUE(manager1.Empty());
2051   EXPECT_TRUE(manager2.Empty());
2052 }
2053 
TEST_F(VirtualSchedulerTest, RemoveSingleNodeCompositeNodeManager) {
  // Single-node round trip through CompositeNodeManager. The GetCurrNode()
  // assertion mirrors the GetSingleNode* tests of the FIFO, LIFO and
  // FirstReady managers, so this test is no longer a statement-for-statement
  // copy of RemoveSingleNodeComopsiteNodeManager.
  CompositeNodeManager manager;
  manager.Init(&node_states_);
  manager.AddNode(&node1_);

  // The only queued node must be the current one ...
  EXPECT_EQ("Node1", manager.GetCurrNode()->name());

  // ... and removing it must leave the manager empty.
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
2061 
TEST_F(VirtualSchedulerTest, RemoveSingleNodeComopsiteNodeManager) {
  // Add one node and remove it again; the manager must end up empty.
  CompositeNodeManager manager;
  manager.Init(&node_states_);

  manager.AddNode(&node1_);
  manager.RemoveCurrNode();

  const bool drained = manager.Empty();
  EXPECT_TRUE(drained);
}
2070 
TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleComopsiteNodeManager) {
  CompositeNodeManager manager;
  manager.Init(&node_states_);

  // Asserts the name of the manager's current node without removing it.
  const auto expect_current = [&manager](const char* expected_name) {
    EXPECT_EQ(expected_name, manager.GetCurrNode()->name());
  };

  // Add Node1 through Node4 to the CompositeNodeManager.
  for (auto* node : {&node1_, &node2_, &node3_, &node4_}) {
    manager.AddNode(node);
  }

  // The expected order matches AddAndRemoveMultipleLIFOManager.
  expect_current("Node4");
  manager.RemoveCurrNode();
  expect_current("Node3");

  // Adding a node must not change the current node; only RemoveCurrNode()
  // moves on to the next one.
  manager.AddNode(&node5_);
  expect_current("Node3");
  manager.RemoveCurrNode();
  expect_current("Node5");
  manager.RemoveCurrNode();
  expect_current("Node2");

  manager.AddNode(&node6_);
  expect_current("Node2");
  manager.RemoveCurrNode();
  expect_current("Node6");
  manager.RemoveCurrNode();
  expect_current("Node1");
  manager.RemoveCurrNode();
  EXPECT_TRUE(manager.Empty());
}
2102 
TEST_F(VirtualSchedulerTest,MultiDeviceSendRecvComopsiteNodeManager)2103 TEST_F(VirtualSchedulerTest, MultiDeviceSendRecvComopsiteNodeManager) {
2104   CompositeNodeManager manager;
2105   manager.Init(&node_states_);
2106   // Additional nodes on kCPU1
2107   NodeDef node7;
2108   NodeDef node8;
2109   NodeDef node9;
2110   NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7);
2111   NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8);
2112   NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9);
2113 
2114   // Send and Recv nodes.
2115   NodeDef send1;
2116   NodeDef send2;
2117   NodeDef recv1;
2118   NodeDef recv2;
2119   NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1);
2120   NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2);
2121   NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1);
2122   NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2);
2123 
2124   // Insert nodes.
2125   manager.AddNode(&node1_);
2126   manager.AddNode(&node2_);
2127   manager.AddNode(&node3_);
2128   manager.AddNode(&node4_);
2129   manager.AddNode(&node5_);
2130   manager.AddNode(&node6_);
2131   manager.AddNode(&node7);
2132   manager.AddNode(&node8);
2133   manager.AddNode(&node9);
2134   manager.AddNode(&send1);
2135   manager.AddNode(&send2);
2136   manager.AddNode(&recv1);
2137   manager.AddNode(&recv2);
2138 
2139   // on kCPU0; last one is node6_, on kCPU1: last one is node9;
2140   // so choose one that has earliest time_ready among node6_, node9,
2141   // Send1, Send2, Recv1, and Recv2.
2142   EXPECT_EQ("Node6", manager.GetCurrNode()->name());
2143   manager.RemoveCurrNode();
2144   // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node
2145   // among node5_, node9, Send1, Send2, Recv1, and Recv2.
2146   EXPECT_EQ("Node5", manager.GetCurrNode()->name());
2147   manager.RemoveCurrNode();
2148   // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2.
2149   EXPECT_EQ("Send1", manager.GetCurrNode()->name());
2150   manager.RemoveCurrNode();
2151   // Next, choose among node4_, node9, Sen2, Recv1, and Recv2.
2152   EXPECT_EQ("Recv1", manager.GetCurrNode()->name());
2153   manager.RemoveCurrNode();
2154   // Next, choose among node4_, node9, Send2, and Recv2.
2155   EXPECT_EQ("Recv2", manager.GetCurrNode()->name());
2156   manager.RemoveCurrNode();
2157   // Next, choose among node4_, node9, and Send2.
2158   EXPECT_EQ("Send2", manager.GetCurrNode()->name());
2159   manager.RemoveCurrNode();
2160   // Next, choose between node4_, node9.
2161   EXPECT_EQ("Node4", manager.GetCurrNode()->name());
2162   manager.RemoveCurrNode();
2163   // Next, choose between node3_, node9.
2164   EXPECT_EQ("Node9", manager.GetCurrNode()->name());
2165   manager.RemoveCurrNode();
2166   // Next, choose between node3_, node8.
2167   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
2168   manager.RemoveCurrNode();
2169   // Next, choose between node3_, node7.
2170   EXPECT_EQ("Node7", manager.GetCurrNode()->name());
2171   manager.RemoveCurrNode();
2172   // Then, just the nodes on kCPU1 -- LIFO.
2173   EXPECT_EQ("Node3", manager.GetCurrNode()->name());
2174   manager.RemoveCurrNode();
2175   EXPECT_EQ("Node2", manager.GetCurrNode()->name());
2176   manager.RemoveCurrNode();
2177   EXPECT_EQ("Node1", manager.GetCurrNode()->name());
2178   manager.RemoveCurrNode();
2179   EXPECT_TRUE(manager.Empty());
2180 }
2181 
TEST_F(VirtualSchedulerTest,DeterminismInCompositeNodeManager)2182 TEST_F(VirtualSchedulerTest, DeterminismInCompositeNodeManager) {
2183   CompositeNodeManager manager;
2184   manager.Init(&node_states_);
2185   CompositeNodeManager manager2;
2186   manager2.Init(&node_states_);
2187 
2188   // 6 nodes with same time_ready.
2189   NodeDef node7;
2190   NodeDef node8;
2191   NodeDef node9;
2192   NodeDef node10;
2193   NodeDef node11;
2194   NodeDef node12;
2195   NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
2196   NodeSetUp("Node8", kSend, kCPU0, 1000, &node8);
2197   NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9);
2198   NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10);
2199   NodeSetUp("Node11", kRecv, kCPU0, 999, &node11);
2200   NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12);
2201 
2202   // Add Nodes 7 to 9 to manager.
2203   manager.AddNode(&node7);
2204   manager.AddNode(&node8);
2205   manager.AddNode(&node9);
2206 
2207   // It should return _Send, Recv, and the other op order, when the candidate
2208   // nodes have same time_ready.
2209   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
2210   EXPECT_EQ(kSend, manager.GetCurrNode()->op());
2211   manager.RemoveCurrNode();
2212   EXPECT_EQ("Node9", manager.GetCurrNode()->name());
2213   EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
2214   manager.RemoveCurrNode();
2215   EXPECT_EQ("Node7", manager.GetCurrNode()->name());
2216   EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
2217   manager.RemoveCurrNode();
2218   EXPECT_TRUE(manager.Empty());
2219 
2220   // Add Nodes 7 to 9 to manager, but in a different order.
2221   manager.AddNode(&node9);
2222   manager.AddNode(&node8);
2223   manager.AddNode(&node7);
2224 
2225   // Expect same order (_Send, _Recv, and the other op), regardless of Add
2226   // order.
2227   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
2228   EXPECT_EQ(kSend, manager.GetCurrNode()->op());
2229   manager.RemoveCurrNode();
2230   EXPECT_EQ("Node9", manager.GetCurrNode()->name());
2231   EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
2232   manager.RemoveCurrNode();
2233   EXPECT_EQ("Node7", manager.GetCurrNode()->name());
2234   EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
2235   manager.RemoveCurrNode();
2236   EXPECT_TRUE(manager.Empty());
2237 
2238   // Conv2D's time_ready < Send's time_ready; Expect Conv2D first.
2239   manager.AddNode(&node8);
2240   manager.AddNode(&node10);
2241   EXPECT_EQ("Node10", manager.GetCurrNode()->name());
2242   EXPECT_EQ(kConv2D, manager.GetCurrNode()->op());
2243   manager.RemoveCurrNode();
2244   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
2245   EXPECT_EQ(kSend, manager.GetCurrNode()->op());
2246   manager.RemoveCurrNode();
2247   EXPECT_TRUE(manager.Empty());
2248 
2249   // Recv's time_ready < Send' time_ready; Expect Recv first.
2250   manager.AddNode(&node11);
2251   manager.AddNode(&node8);
2252   EXPECT_EQ("Node11", manager.GetCurrNode()->name());
2253   EXPECT_EQ(kRecv, manager.GetCurrNode()->op());
2254   manager.RemoveCurrNode();
2255   EXPECT_EQ("Node8", manager.GetCurrNode()->name());
2256   EXPECT_EQ(kSend, manager.GetCurrNode()->op());
2257   manager.RemoveCurrNode();
2258   EXPECT_TRUE(manager.Empty());
2259 
2260   // Node7 and 12 are normal ops with the same time_ready, placed on different
2261   // devices. These two nodes are added to manager and manager2, but in
2262   // different orders; Expect GetCurrNode() returns the nodes in the same order.
2263   manager.AddNode(&node7);
2264   manager.AddNode(&node12);
2265 
2266   manager2.AddNode(&node12);
2267   manager2.AddNode(&node7);
2268 
2269   EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2270   manager.RemoveCurrNode();
2271   manager2.RemoveCurrNode();
2272   EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
2273   manager.RemoveCurrNode();
2274   manager2.RemoveCurrNode();
2275   EXPECT_TRUE(manager.Empty());
2276 }
2277 
2278 // Create small graph, run predict costs on it, make sure the costs from the
2279 // summary match the hand-calculated costs.
TEST_F(VirtualSchedulerTest,SummaryCostTest)2280 TEST_F(VirtualSchedulerTest, SummaryCostTest) {
2281   // Run matmul test.
2282   CreateGrapplerItemWithMatmulChain();
2283   InitScheduler();
2284   auto ops_executed = RunScheduler("");
2285   Costs c = scheduler_->Summary();
2286 
2287   // RandomUniform - 5 * 1s
2288   // Matmuls - 4 * 2s = 8
2289   // Misc - 5 * 1us
2290   // Total: 13000005
2291   EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2292   EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2293   EXPECT_FALSE(c.inaccurate);
2294   EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2295 }
2296 
2297 // Like the above SummaryCostTest, but makes sure the stepstats timeline is
2298 // correct.
TEST_F(VirtualSchedulerTest,SummaryCostStepStatsTest)2299 TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
2300   // Run matmul test.
2301   CreateGrapplerItemWithMatmulChain();
2302   InitScheduler();
2303   auto ops_executed = RunScheduler("");
2304   RunMetadata metadata;
2305   Costs c = scheduler_->Summary(&metadata);
2306   StepStats stepstats = metadata.step_stats();
2307   EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2308   EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2309   EXPECT_FALSE(c.inaccurate);
2310   EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2311 
2312   // Should only be 1 device!
2313   EXPECT_EQ(1, stepstats.dev_stats().size());
2314 
2315   // Create a map of op name -> start and end times (micros).
2316   std::map<string, std::pair<int64, int64>> start_end_times;
2317   for (const auto& device_step_stats : stepstats.dev_stats()) {
2318     for (const auto& stats : device_step_stats.node_stats()) {
2319       int64 start = stats.all_start_micros();
2320       int64 end = start + stats.all_end_rel_micros();
2321       start_end_times[stats.node_name()] = std::pair<int64, int64>(start, end);
2322 
2323       // Make sure that the output properties are correct for
2324       // MatMul and RandomUniform operations.
2325       // We only check for dtype, and shape (excluding alloc)
2326       // since alloc is not set by the virtual scheduler.
2327       if (stats.timeline_label() == "MatMul" ||
2328           stats.timeline_label() == "RandomUniform") {
2329         EXPECT_EQ(1, stats.output().size());
2330         for (const auto& output : stats.output()) {
2331           EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype());
2332           EXPECT_EQ(2, output.tensor_description().shape().dim().size());
2333           for (const auto& dim : output.tensor_description().shape().dim()) {
2334             EXPECT_EQ(3200, dim.size());
2335           }
2336         }
2337       }
2338     }
2339   }
2340 
2341   // The base start_time is the time to compute RandomUniforms
2342   int64 cur_time = static_cast<int64>(5000005);
2343   // The increment is the execution time of one matmul. See
2344   // CreateGrapplerItemWithMatmulChain for details.
2345   int64 increment = static_cast<int64>(2000000);
2346   auto op_names = {"ab", "abc", "abcd", "abcde"};
2347   for (const auto& op_name : op_names) {
2348     int64 actual_start = start_end_times[op_name].first;
2349     int64 actual_end = start_end_times[op_name].second;
2350     int64 expected_start = cur_time;
2351     int64 expected_end = cur_time + increment;
2352     EXPECT_EQ(expected_start, actual_start);
2353     EXPECT_EQ(expected_end, actual_end);
2354     cur_time += increment;
2355   }
2356 }
2357 
// Schedules the Conv2D graph end to end and checks which ops were executed
// and how many inputs/outputs were recorded for them.
TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
  // Build the item and the scheduler.
  CreateGrapplerItemWithConv2Ds();
  InitScheduler();

  // An empty target node means: run every node.
  auto executed = RunScheduler("");

  // Expected: [const and rand] * (x, y, f), plus c0 and c1. Neither c2 nor z
  // should have been executed.
  EXPECT_EQ(8, executed.size());

  // Ops that must have run.
  EXPECT_GT(executed.count("x"), 0);
  EXPECT_GT(executed.count("y"), 0);
  EXPECT_GT(executed.count("f"), 0);
  EXPECT_GT(executed.count("c0"), 0);
  EXPECT_GT(executed.count("c1"), 0);

  // Ops that must not have run.
  EXPECT_EQ(executed.count("z"), 0);
  EXPECT_EQ(executed.count("c2"), 0);

  // Input / output properties: the three sources each expose one output ...
  EXPECT_EQ(1, executed["x"].op_info.outputs_size());
  EXPECT_EQ(1, executed["y"].op_info.outputs_size());
  EXPECT_EQ(1, executed["f"].op_info.outputs_size());
  // ... and each convolution consumes two inputs.
  EXPECT_EQ(2, executed["c0"].op_info.inputs_size());
  EXPECT_EQ(2, executed["c1"].op_info.inputs_size());
}
2388 
// Runs the AddN graph and checks the peak memory usage recorded for kCPU0
// together with the snapshot of tensors alive at that peak.
TEST_F(VirtualSchedulerTest, MemoryUsage) {
  // Build the item and the scheduler, then run everything.
  CreateGrapplerItemWithAddN();
  InitScheduler();
  RunScheduler("");

  const auto* all_devices = scheduler_->GetDeviceStates();
  const auto& cpu0 = all_devices->at(kCPU0);

  // The out node adds 4 tensors of shape 10x10x10x10, so memory peaks at
  // 4 x the size of one input tensor while the out node executes.
  const std::vector<string> inputs_at_peak = {"x", "y", "z", "w"};
  int64 bytes_per_input = 4 * 10 * 10 * 10 * 10;
  EXPECT_EQ(inputs_at_peak.size() * bytes_per_input, cpu0.max_memory_usage);
  ValidateMemoryUsageSnapshot(inputs_at_peak, 0 /* port_num_expected */,
                              cpu0.mem_usage_snapshot_at_peak);
}
2409 
// The scheduler must cope with graphs that carry extra, unnecessary feed
// (placeholder) nodes: only "x" is expected to be executed.
TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) {
  CreateGrapplerItemWithUnnecessaryPlaceholderNodes();
  InitScheduler();

  auto executed = RunScheduler("");
  // Exactly one op ran ...
  ASSERT_EQ(1, executed.size());
  // ... and that op is "x".
  ASSERT_EQ(executed.count("x"), 1);
}
2419 
// Runs the control-dependency graph and checks that the peak memory usage on
// kCPU0 accounts for every control input of the final NoOp.
TEST_F(VirtualSchedulerTest, ControlDependency) {
  // Build the item and the scheduler, then run everything.
  CreateGrapplerItemWithControlDependency();
  InitScheduler();
  RunScheduler("");

  const auto* all_devices = scheduler_->GetDeviceStates();
  const auto& cpu0 = all_devices->at(kCPU0);

  // A single NoOp takes control dependencies from 7 other NoOps; memory usage
  // peaks while that final NoOp executes.
  const std::vector<string> control_inputs = {"x", "y", "z", "w",
                                              "u", "v", "t"};
  int64 bytes_per_control_input = 4;  // control dependency
  EXPECT_EQ(control_inputs.size() * bytes_per_control_input,
            cpu0.max_memory_usage);
  ValidateMemoryUsageSnapshot(control_inputs, -1 /* port_num_expected */,
                              cpu0.mem_usage_snapshot_at_peak);
}
2441 
TEST_F(VirtualSchedulerTest,ComplexDependency)2442 TEST_F(VirtualSchedulerTest, ComplexDependency) {
2443   // Init.
2444   CreateGrapplerItemWithBatchNorm();
2445   InitScheduler();
2446 
2447   // Run the scheduler.
2448   RunScheduler("bn");
2449 
2450   const auto& device_states = scheduler_->GetDeviceStates();
2451   const auto& cpu_state = device_states->at(kCPU0);
2452 
2453   // The graph is
2454   //  bn = FusedBatchNorm(x, scale, offset, mean, var)
2455   //  z1 = bn.y + x
2456   //  z2 = bn.var + bn.var
2457   //  z3 = bn.var + bn.var
2458   //  z4 = control dependency from bn.
2459   //  Note that bn.mean doesn't have any consumer.
2460   const int x_size = batch_size_ * width_ * height_ * depth_in_;
2461   int64 expected_size =
2462       4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
2463            1 /* control dependency */);
2464   EXPECT_EQ(expected_size, cpu_state.memory_usage);
2465 
2466   // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0.
2467   std::set<std::pair<string, int>> nodes_in_memory;
2468   std::transform(
2469       cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
2470       std::inserter(nodes_in_memory, nodes_in_memory.begin()),
2471       [](const std::pair<const NodeDef*, int>& node_port) {
2472         return std::make_pair(node_port.first->name(), node_port.second);
2473       });
2474   std::set<std::pair<string, int>> expected = {
2475       std::make_pair("bn", -1),
2476       std::make_pair("bn", 0),
2477       std::make_pair("bn", 2),
2478       std::make_pair("x", 0),
2479   };
2480   ExpectSetEq(expected, nodes_in_memory);
2481 
2482   const auto* node_states = scheduler_->GetNodeStates();
2483   const NodeState* bn_node = nullptr;
2484   const NodeState* x_node = nullptr;
2485   for (const auto& nodedef_node_state : *node_states) {
2486     const NodeDef* node = nodedef_node_state.first;
2487     const NodeState& node_state = nodedef_node_state.second;
2488     if (node->name() == "bn") {
2489       bn_node = &node_state;
2490     }
2491     if (node->name() == "x") {
2492       x_node = &node_state;
2493     }
2494   }
2495   CHECK_NOTNULL(bn_node);
2496   CHECK_NOTNULL(x_node);
2497 
2498   ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
2499   ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
2500   ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
2501   // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
2502   ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
2503 }
2504 
// Runs a Conv2D graph whose filter is a variable and checks that the variable
// is tracked as a persistent node while only the regular input shows up in
// the peak memory snapshot.
TEST_F(VirtualSchedulerTest, Variable) {
  // Build the item and the scheduler, then run everything.
  CreateGrapplerItemWithConv2DAndVariable();
  InitScheduler();
  RunScheduler("");

  const auto* all_devices = scheduler_->GetDeviceStates();
  const auto& cpu0 = all_devices->at(kCPU0);

  // The single Conv2D takes x and f. Since f is a variable, it is expected
  // among the persistent nodes ...
  ValidateMemoryUsageSnapshot({"f"}, 0 /* port_num_expected */,
                              cpu0.persistent_nodes);
  // ... which leaves x as the only entry in the peak memory usage snapshot.
  ValidateMemoryUsageSnapshot({"x"}, 0 /* port_num_expected */,
                              cpu0.mem_usage_snapshot_at_peak);
}
2525 
TEST_F(VirtualSchedulerTest,WhileLoop)2526 TEST_F(VirtualSchedulerTest, WhileLoop) {
2527   // Init.
2528   CreateGrapplerItemWithLoop();
2529   InitScheduler();
2530 
2531   // Run the scheduler.
2532   RunScheduler("");
2533 
2534   // Check the timeline
2535   RunMetadata metadata;
2536   scheduler_->Summary(&metadata);
2537 
2538   // Nodes in topological order:
2539   // * const, ones
2540   // * while/Enter, while/Enter_1
2541   // * while/Merge, while/Merge_1
2542   // * while/Less/y
2543   // * while/Less
2544   // * while/LoopCond
2545   // * while/Switch, while/Switch_1
2546   // * while/Identity, while/Identity_1, while/Exit, while/Exit_1
2547   // * while/add/y, while/concat/axis
2548   // * while/add, while/concat
2549   // * while/NextIteration, while/NextIteration_1
2550 
2551   int num_next_iteration = 0;
2552   int num_next_iteration_1 = 0;
2553   int num_exit = 0;
2554   int num_exit_1 = 0;
2555   int64 next_iter_start_micro;
2556   int64 next_iter_1_start_micro;
2557   int64 exit_start_micro;
2558   int64 exit_1_start_micro;
2559 
2560   std::unordered_map<string, int64> start_times;
2561   for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2562     for (const auto& stats : device_step_stats.node_stats()) {
2563       start_times[stats.node_name()] = stats.all_start_micros();
2564       if (stats.node_name() == "while/NextIteration") {
2565         ++num_next_iteration;
2566         next_iter_start_micro = stats.all_start_micros();
2567       } else if (stats.node_name() == "while/NextIteration_1") {
2568         ++num_next_iteration_1;
2569         next_iter_1_start_micro = stats.all_start_micros();
2570       } else if (stats.node_name() == "while/Exit") {
2571         ++num_exit;
2572         exit_start_micro = stats.all_start_micros();
2573       } else if (stats.node_name() == "while/Exit_1") {
2574         ++num_exit_1;
2575         exit_1_start_micro = stats.all_start_micros();
2576       }
2577     }
2578   }
2579 
2580   // Make sure we went though the body of the loop once, and that the output of
2581   // the loop was scheduled as well.
2582   EXPECT_EQ(1, num_next_iteration);
2583   EXPECT_EQ(1, num_next_iteration_1);
2584   EXPECT_EQ(1, num_exit);
2585   EXPECT_EQ(1, num_exit_1);
2586 
2587   // Start times of while/NextIteration and while/NextIteration_1 should be
2588   // different, so should be those of while/Exit and while/Exit_1.
2589   EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
2590   EXPECT_NE(exit_start_micro, exit_1_start_micro);
2591 
2592   // Check dependency among the nodes; no matter what scheduling mechanism we
2593   // use, the scheduled ops should follow these dependency chains.
2594   // Note that currently, VirtualScheduler executes while/Merge twice; hence,
2595   // we're not testing dependency chains related to while/Merge.
2596   // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the
2597   // order of Enter, Merge, ...loop condition ..., ... loop body ...,
2598   // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency
2599   // chaining test w/ Merge nodes.
2600   ValidateDependencyChain(
2601       start_times,
2602       {"Const", "while/Enter",  // "while/Merge",
2603        "while/Less/y", "while/Less", "while/LoopCond", "while/Switch",
2604        "while/Identity", "while/add/y", "while/add", "while/NextIteration"});
2605   // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"});
2606   ValidateDependencyChain(start_times,
2607                           {"ones", "while/Enter_1",  // "while/Merge_1",
2608                            "while/Switch_1", "while/Identity_1", "while/concat",
2609                            "while/NextIteration_1"});
2610   ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"});
2611   ValidateDependencyChain(
2612       start_times, {"while/Identity", "while/concat/axis", "while/concat"});
2613   ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
2614   ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
2615 }
2616 
// Compares the summary costs of the plain while-loop graph with those of the
// annotated variant of the same graph.
TEST_F(VirtualSchedulerTest, AnnotatedWhileLoop) {
  // Schedules the grappler item that is currently set up and verifies the
  // resulting summary against the expected total execution time.
  const auto schedule_and_check = [this](int expected_micros) {
    InitScheduler();

    // Runs the scheduler.
    RunScheduler("");
    Costs summary = scheduler_->Summary();

    EXPECT_EQ(expected_micros,
              summary.execution_time.asMicroSeconds().count());
    // Both while/Merge and while/Merge_1 are scheduled twice.
    EXPECT_EQ(grappler_item_->graph.node_size() + 2, summary.num_ops_total);
    EXPECT_FALSE(summary.inaccurate);
    EXPECT_EQ(0, summary.num_ops_with_unknown_shapes);
  };

  // Plain loop.
  CreateGrapplerItemWithLoop();
  schedule_and_check(23);

  // Annotated loop. The costs for Merge is accumulated twice for
  // execution_count times, but since Merge's cost is minimal, we keep this
  // behavior here.
  CreateGrapplerItemWithLoopAnnotated();
  schedule_and_check(178);
}
2652 
TEST_F(VirtualSchedulerTest,Condition)2653 TEST_F(VirtualSchedulerTest, Condition) {
2654   // Without annotation.
2655   {
2656     // Inits.
2657     CreateGrapplerItemWithCondition();
2658     InitScheduler();
2659 
2660     // Runs the scheduler.
2661     RunScheduler("");
2662     RunMetadata metadata;
2663     Costs c = scheduler_->Summary(&metadata);
2664 
2665     // Nodes in topological order: a/Less, Switch, First/Second, Merge.
2666     int num_a = 0;
2667     int num_less = 0;
2668     int num_switch = 0;
2669     int num_first = 0;
2670     int num_second = 0;
2671     int num_merge = 0;
2672 
2673     for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2674       for (const auto& stats : device_step_stats.node_stats()) {
2675         if (stats.node_name() == "a") {
2676           ++num_a;
2677         } else if (stats.node_name() == "Less") {
2678           ++num_less;
2679         } else if (stats.node_name() == "Switch") {
2680           ++num_switch;
2681         } else if (stats.node_name() == "First") {
2682           ++num_first;
2683         } else if (stats.node_name() == "Second") {
2684           ++num_second;
2685         } else if (stats.node_name() == "Merge") {
2686           ++num_merge;
2687         }
2688       }
2689     }
2690 
2691     EXPECT_EQ(1, num_a);
2692     EXPECT_EQ(1, num_less);
2693     EXPECT_EQ(1, num_switch);
2694     EXPECT_EQ(1, num_first);
2695     EXPECT_EQ(1, num_second);
2696     EXPECT_EQ(2, num_merge);
2697 
2698     EXPECT_EQ(7, c.execution_time.asMicroSeconds().count());
2699     // Merge is executed twice.
2700     EXPECT_EQ(grappler_item_->graph.node_size() + 1, c.num_ops_total);
2701     EXPECT_FALSE(c.inaccurate);
2702     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2703   }
2704 
2705   // With annotation.
2706   {
2707     // Inits.
2708     CreateGrapplerItemWithCondition();
2709 
2710     // Annotates the Switch node.
2711     for (auto& node : *grappler_item_->graph.mutable_node()) {
2712       if (node.name() == "Switch") {
2713         AttrValue attr_output_info;
2714         // Adds one output slot 0 so that Second shouldn't be executed.
2715         (*attr_output_info.mutable_list()).add_i(0);
2716         AddNodeAttr(kOutputSlots, attr_output_info, &node);
2717       }
2718     }
2719 
2720     InitScheduler();
2721 
2722     // Runs the scheduler.
2723     RunScheduler("");
2724     RunMetadata metadata;
2725     Costs c = scheduler_->Summary(&metadata);
2726 
2727     // Nodes in topological order: a/Less, Switch, Merge
2728     int num_a = 0;
2729     int num_less = 0;
2730     int num_switch = 0;
2731     int num_first = 0;
2732     int num_second = 0;
2733     int num_merge = 0;
2734 
2735     for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2736       for (const auto& stats : device_step_stats.node_stats()) {
2737         if (stats.node_name() == "a") {
2738           ++num_a;
2739         } else if (stats.node_name() == "Less") {
2740           ++num_less;
2741         } else if (stats.node_name() == "Switch") {
2742           ++num_switch;
2743         } else if (stats.node_name() == "First") {
2744           ++num_first;
2745         } else if (stats.node_name() == "Second") {
2746           ++num_second;
2747         } else if (stats.node_name() == "Merge") {
2748           ++num_merge;
2749         }
2750       }
2751     }
2752 
2753     EXPECT_EQ(1, num_a);
2754     EXPECT_EQ(1, num_less);
2755     EXPECT_EQ(1, num_switch);
2756     EXPECT_EQ(1, num_first);
2757     EXPECT_EQ(0, num_second);
2758     EXPECT_EQ(1, num_merge);
2759 
2760     EXPECT_EQ(5, c.execution_time.asMicroSeconds().count());
2761     // Second is not executed.
2762     EXPECT_EQ(grappler_item_->graph.node_size() - 1, c.num_ops_total);
2763     EXPECT_FALSE(c.inaccurate);
2764     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2765   }
2766 }
2767 
TEST_F(VirtualSchedulerTest,InterDeviceTransfer)2768 TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
2769   // Init.
2770   CreateGrapplerItemWithInterDeviceTransfers();
2771   InitScheduler();
2772 
2773   // Run the scheduler.
2774   auto ops_executed = RunScheduler("");
2775 
2776   // Helper lambda to extract port num from _Send and _Recv op name.
2777   auto get_port_num = [](const string& name) -> int {
2778     if (name.find("bn_0") != string::npos) {
2779       return 0;
2780     } else if (name.find("bn_1") != string::npos) {
2781       return 1;
2782     } else if (name.find("bn_2") != string::npos) {
2783       return 2;
2784     } else if (name.find("bn_minus1") != string::npos) {
2785       return -1;
2786     }
2787     return -999;
2788   };
2789 
2790   // Reorganize ops_executed for further testing.
2791   std::unordered_map<string, int> op_count;
2792   std::unordered_map<int, string> recv_op_names;
2793   std::unordered_map<int, string> send_op_names;
2794   for (const auto& x : ops_executed) {
2795     const auto& name = x.first;
2796     const auto& node_info = x.second;
2797     const auto& op = node_info.op_info.op();
2798     if (op == kRecv) {
2799       recv_op_names[get_port_num(name)] = name;
2800     } else if (op == kSend) {
2801       send_op_names[get_port_num(name)] = name;
2802     }
2803     op_count[op]++;
2804   }
2805 
2806   // Same number of _Send and _Recv.
2807   EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv));
2808 
2809   // Expect 4 Send and Recvs each: port 0, 1, and, 2, and control dependency.
2810   EXPECT_EQ(op_count.at(kRecv), 4);
2811   EXPECT_EQ(op_count.at(kSend), 4);
2812 
2813   // Helper lambda for extracting output Tensor size.
2814   auto get_output_size = [this, ops_executed](const string& name) -> int64 {
2815     const auto& output_properties_ = ops_executed.at(name).op_info.outputs();
2816     std::vector<OpInfo::TensorProperties> output_properties;
2817     for (const auto& output_property : output_properties_) {
2818       output_properties.push_back(output_property);
2819     }
2820     return CalculateOutputSize(output_properties, 0);
2821   };
2822 
2823   // Validate transfer size.
2824   // Batchnorm output y is 4D vector: batch x width x width x depth.
2825   int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
2826   EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
2827   EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
2828   // Mean and vars are 1-D vector with size depth_in_.
2829   EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
2830   EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
2831   EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
2832   EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
2833   // Control dependency size is 4B.
2834   EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
2835   EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
2836 }
2837 
// A graph that already contains explicit Send and Recv nodes must have its
// Const, Send, and Recv nodes all executed.
TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
  // Build the item and the scheduler.
  CreateGrapplerItemWithSendRecv();
  InitScheduler();

  // Schedule every node.
  auto executed = RunScheduler("");

  EXPECT_GT(executed.count("Const"), 0);
  EXPECT_GT(executed.count("Send"), 0);
  EXPECT_GT(executed.count("Recv"), 0);
}
2850 
// Same graph as GraphWithSendRecv, but with the Recv node moved to kCPU1 so
// that Send and Recv live on different devices; the scheduler is then
// expected to execute extra transfer ops in addition to the graph's own.
TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
  // Build the item.
  CreateGrapplerItemWithSendRecv();

  // Re-place Recv on another device and point the "recv_device" attribute of
  // both Send and Recv at it.
  const string recv_device = kCPU1;
  for (auto& node : *grappler_item_->graph.mutable_node()) {
    const bool is_recv = node.name() == "Recv";
    const bool is_send = !is_recv && node.name() == "Send";
    if (!is_recv && !is_send) continue;
    if (is_recv) {
      node.set_device(recv_device);
    }
    (*node.mutable_attr())["recv_device"].set_s(recv_device);
  }
  InitScheduler();

  // Schedule every node.
  auto executed = RunScheduler("");

  // Expect Const, Send, Recv, and VirtualScheduler created Send and Recv ops.
  EXPECT_GT(executed.count("Const"), 0);
  EXPECT_GT(executed.count("Send"), 0);
  EXPECT_GT(executed.count("Send_Send_0_from_/job_localhost/replica_0/"
                           "task_0/cpu_0_to_/job_localhost"
                           "/replica_0/task_0/cpu_1"),
            0);
  EXPECT_GT(
      executed.count("Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
      0);
  EXPECT_GT(executed.count("Recv"), 0);
}
2886 
// A Recv node with no matching Send must still be executed.
// NOTE(review): the test name misspells "With"; it is left unchanged because
// the name is the externally visible test identifier.
TEST_F(VirtualSchedulerTest, GraphWihtOnlyRecv) {
  // Build the item and the scheduler.
  CreateGrapplerItemWithRecvWithoutSend();
  InitScheduler();

  // Schedule every node.
  auto executed = RunScheduler("");

  // Recv without Send will be treated as initially ready node.
  EXPECT_GT(executed.count("Recv"), 0);
}
2898 
2899 }  // end namespace grappler
2900 }  // end namespace tensorflow
2901