1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
17 
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/tensor_description.pb.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
22 #include "tensorflow/core/grappler/costs/utils.h"
23 #include "tensorflow/core/grappler/costs/virtual_placer.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 namespace {
30 
31 // Device names:
32 constexpr char kCPU0[] = "/job:localhost/replica:0/task:0/cpu:0";
33 constexpr char kCPU1[] = "/job:localhost/replica:0/task:0/cpu:1";
34 constexpr char kChannelFrom0To1[] = "Channel from CPU0 to CPU1";
35 constexpr char kChannelFrom1To0[] = "Channel from CPU1 to CPU0";
36 // Op names:
37 constexpr char kConv2D[] = "Conv2D";
38 constexpr char kSend[] = "_Send";
39 constexpr char kRecv[] = "_Recv";
40 
41 class ReadyNodeManagerTest : public ::testing::Test {
42  protected:
43   ReadyNodeManagerTest() {
44     // node1_ to node6_ on kCPU0, with time_ready in reverse order.
45     NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_);
46     NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_);
47     NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_);
48     NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_);
49     NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_);
50     NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_);
51   }
52 
53   void NodeSetUp(const string& name, const string& op_name,
54                  const string& device_name, const uint64 time_ready,
55                  NodeDef* node) {
56     node->set_name(name);
57     node->set_op(op_name);
58     node->set_device(device_name);
59 
60     node_states_[node] = NodeState();
61     node_states_[node].time_ready = time_ready;
62     node_states_[node].device_name = device_name;
63   }
64 
65   NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
66   std::unordered_map<const NodeDef*, NodeState> node_states_;
67 };
68 
69 // Tests that FIFOManager correctly returns the current node with only 1 node.
70 TEST_F(ReadyNodeManagerTest, GetSingleNodeFIFOManager) {
71   FIFOManager manager = FIFOManager();
72   manager.AddNode(&node1_);
73   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
74 }
75 
76 // Tests that FIFOManager removes the only node contained within.
77 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFIFOManager) {
78   FIFOManager manager = FIFOManager();
79   manager.AddNode(&node1_);
80 
81   // Removes the only node in FIFOManager.
82   manager.RemoveCurrNode();
83   EXPECT_TRUE(manager.Empty());
84 }
85 
86 // Tests that FIFOManager can remove multiple nodes and returns the current node
87 // in the right order.
88 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFIFOManager) {
89   FIFOManager manager = FIFOManager();
90   manager.AddNode(&node1_);
91   manager.AddNode(&node2_);
92   manager.AddNode(&node3_);
93   manager.AddNode(&node4_);
94 
95   // Keeps checking current node while removing nodes from manager.
96   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
97   manager.RemoveCurrNode();
98   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
99   manager.RemoveCurrNode();
100   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
101   manager.RemoveCurrNode();
102   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
103   manager.RemoveCurrNode();
104   EXPECT_TRUE(manager.Empty());
105 }
106 
107 // Tests that FIFOManager can remove multiple nodes and add more nodes, still
108 // returning the current node in the right order.
109 TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleFIFOManager) {
110   FIFOManager manager = FIFOManager();
111   manager.AddNode(&node1_);
112   manager.AddNode(&node2_);
113   manager.AddNode(&node3_);
114   manager.AddNode(&node4_);
115 
116   // Keeps checking current node as nodes are removed and added.
117   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
118   manager.RemoveCurrNode();
119   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
120   manager.AddNode(&node5_);
121   // GetCurrNode() should return the same node even if some nodes are added,
122   // until RemoveCurrNode() is called.
123   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
124   manager.RemoveCurrNode();
125   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
126   manager.RemoveCurrNode();
127   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
128   manager.AddNode(&node6_);
129   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
130   manager.RemoveCurrNode();
131   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
132   manager.RemoveCurrNode();
133   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
134   manager.RemoveCurrNode();
135   EXPECT_TRUE(manager.Empty());
136 }
137 
138 // Tests that LIFOManager correctly returns the current node with only 1 node.
139 TEST_F(ReadyNodeManagerTest, GetSingleNodeLIFOManager) {
140   LIFOManager manager = LIFOManager();
141   manager.AddNode(&node1_);
142   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
143 }
144 
145 // Tests that LIFOManager removes the only node contained within.
146 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeLIFOManager) {
147   LIFOManager manager = LIFOManager();
148   manager.AddNode(&node1_);
149 
150   // Removes the only node in LIFOManager.
151   manager.RemoveCurrNode();
152   EXPECT_TRUE(manager.Empty());
153 }
154 
155 // Tests that LIFOManager can remove multiple nodes and returns the current node
156 // in the right order.
157 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleLIFOManager) {
158   LIFOManager manager = LIFOManager();
159   manager.AddNode(&node1_);
160   manager.AddNode(&node2_);
161   manager.AddNode(&node3_);
162   manager.AddNode(&node4_);
163 
164   // Keeps checking current node while removing nodes from manager.
165   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
166   manager.RemoveCurrNode();
167   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
168   manager.RemoveCurrNode();
169   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
170   manager.RemoveCurrNode();
171   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
172   manager.RemoveCurrNode();
173   EXPECT_TRUE(manager.Empty());
174 }
175 
176 // Tests that LIFOManager can remove multiple nodes (must be removing the
177 // current node) and add more nodes, still returning the current node in the
178 // right order.
179 TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleLIFOManager) {
180   LIFOManager manager = LIFOManager();
181   manager.AddNode(&node1_);
182   manager.AddNode(&node2_);
183   manager.AddNode(&node3_);
184   manager.AddNode(&node4_);
185 
186   // Keeps checking current node as nodes are removed and added.
187   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
188   manager.RemoveCurrNode();
189   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
190   manager.AddNode(&node5_);
191   // GetCurrNode() should return the same node even if some nodes are added,
192   // until RemoveCurrNode() is called.
193   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
194   manager.RemoveCurrNode();
195   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
196   manager.RemoveCurrNode();
197   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
198   manager.AddNode(&node6_);
199   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
200   manager.RemoveCurrNode();
201   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
202   manager.RemoveCurrNode();
203   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
204   manager.RemoveCurrNode();
205   EXPECT_TRUE(manager.Empty());
206 }
207 
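// Tests that FirstReadyManager correctly returns the single node it contains.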
208 TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) {
209   FirstReadyManager manager;
210   TF_EXPECT_OK(manager.Init(&node_states_));
211   manager.AddNode(&node1_);
212   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
213 }
214 
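// Tests that FirstReadyManager removes the only node contained within.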
215 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
216   FirstReadyManager manager;
217   TF_EXPECT_OK(manager.Init(&node_states_));
218   manager.AddNode(&node1_);
219   manager.RemoveCurrNode();
220   EXPECT_TRUE(manager.Empty());
221 }
222 
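// Tests that FirstReadyManager returns nodes in order of increasing time_ready,
// regardless of the order in which they were added.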
223 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
224   FirstReadyManager manager;
225   TF_EXPECT_OK(manager.Init(&node_states_));
226   // Inserts nodes in some random order.
227   manager.AddNode(&node2_);
228   manager.AddNode(&node1_);
229   manager.AddNode(&node4_);
230   manager.AddNode(&node5_);
231   manager.AddNode(&node3_);
232   manager.AddNode(&node6_);
233 
234   // Regardless of the order in which nodes are inserted, they are returned in
235   // order of increasing time_ready.
236   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
237   manager.RemoveCurrNode();
238   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
239   manager.RemoveCurrNode();
240   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
241   manager.RemoveCurrNode();
242   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
243   manager.RemoveCurrNode();
244   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
245   manager.RemoveCurrNode();
246   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
247   manager.RemoveCurrNode();
248   EXPECT_TRUE(manager.Empty());
249 }
250 
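// Tests that FirstReadyManager keeps returning the same current node until
// RemoveCurrNode() is called, even when nodes with smaller time_ready are added.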
251 TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
252   FirstReadyManager manager;
253   TF_EXPECT_OK(manager.Init(&node_states_));
254 
255   // Inserts nodes in some random order.
256   manager.AddNode(&node2_);
257   manager.AddNode(&node1_);
258   manager.AddNode(&node4_);
259   manager.AddNode(&node5_);
260   manager.AddNode(&node3_);
261   manager.AddNode(&node6_);
262 
263   // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode()
264   // should return it.
265   EXPECT_EQ("Node6", manager.GetCurrNode()->name());
266 
267   // Now inserts a few other nodes, but their time_ready values are even smaller than
268   // that of Node6. Before calling RemoveCurrNode(), GetCurrNode() should return
269   // the same node, Node6, in this case.
270   NodeDef node7;
271   NodeDef node8;
272   NodeDef node9;
273   NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7);
274   NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8);
275   NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9);
276 
277   manager.AddNode(&node7);
278   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
279 
280   manager.AddNode(&node8);
281   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
282 
283   manager.RemoveCurrNode();
284   // Now Node6 is removed, and GetCurrNode() will return Node8.
285   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
286 
287   // Again, AddNode shouldn't change GetCurrNode().
288   manager.AddNode(&node9);
289   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
290 
291   manager.RemoveCurrNode();
292   EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
293   manager.RemoveCurrNode();
294   EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
295   manager.RemoveCurrNode();
296   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
297   manager.RemoveCurrNode();
298   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
299   manager.RemoveCurrNode();
300   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
301   manager.RemoveCurrNode();
302   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
303   manager.RemoveCurrNode();
304   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
305   manager.RemoveCurrNode();
306   EXPECT_TRUE(manager.Empty());
307 }
308 
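// Tests that FirstReadyManager breaks time_ready ties deterministically,
// independent of the order in which nodes are added.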
309 TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
310   FirstReadyManager manager1;
311   TF_EXPECT_OK(manager1.Init(&node_states_));
312   FirstReadyManager manager2;
313   TF_EXPECT_OK(manager2.Init(&node_states_));
314 
315   // 6 nodes with same time_ready.
316   NodeDef node7;
317   NodeDef node8;
318   NodeDef node9;
319   NodeDef node10;
320   NodeDef node11;
321   NodeDef node12;
322   NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
323   NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8);
324   NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9);
325   NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10);
326   NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11);
327   NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12);
328 
329   // Adds the above 6 nodes to manager1.
330   manager1.AddNode(&node7);
331   manager1.AddNode(&node8);
332   manager1.AddNode(&node9);
333   manager1.AddNode(&node10);
334   manager1.AddNode(&node11);
335   manager1.AddNode(&node12);
336 
337   // Adds the above 6 nodes to manager2, but in a different order.
338   manager2.AddNode(&node8);
339   manager2.AddNode(&node11);
340   manager2.AddNode(&node9);
341   manager2.AddNode(&node10);
342   manager2.AddNode(&node7);
343   manager2.AddNode(&node12);
344 
345   // Expects both managers return the same nodes for deterministic node
346   // scheduling.
347   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
348   manager1.RemoveCurrNode();
349   manager2.RemoveCurrNode();
350 
351   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
352   manager1.RemoveCurrNode();
353   manager2.RemoveCurrNode();
354 
355   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
356   manager1.RemoveCurrNode();
357   manager2.RemoveCurrNode();
358 
359   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
360   manager1.RemoveCurrNode();
361   manager2.RemoveCurrNode();
362 
363   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
364   manager1.RemoveCurrNode();
365   manager2.RemoveCurrNode();
366 
367   EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
368   manager1.RemoveCurrNode();
369   manager2.RemoveCurrNode();
370 
371   EXPECT_TRUE(manager1.Empty());
372   EXPECT_TRUE(manager2.Empty());
373 }
374 
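// Tests that PriorityReadyManager returns nodes according to the priorities set
// via SetPriority(), breaking ties by time_ready.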
375 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultiplePriorityReadyManager) {
376   PriorityReadyManager manager;
377   TF_EXPECT_OK(manager.Init(&node_states_));
378 
379   // Sets up node priorities.
380   std::unordered_map<string, int> node_priority = {
381       {"Node1", 1}, {"Node2", 2}, {"Node3", 2}, {"Node4", 4}, {"Node5", 5}};
382   TF_EXPECT_OK(manager.SetPriority(node_priority));
383 
384   // Inserts nodes in some random order.
385   manager.AddNode(&node3_);
386   manager.AddNode(&node1_);
387   manager.AddNode(&node4_);
388   manager.AddNode(&node5_);
389   manager.AddNode(&node2_);
390   manager.AddNode(&node6_);
391 
392   // Expects nodes scheduled based on priority.
393   // Node6 is not in the priority map, so it defaults to the lowest priority value.
394   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
395   manager.RemoveCurrNode();
396   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
397   manager.RemoveCurrNode();
398   // Nodes 2 and 3 have equal priority, so they are scheduled in first-ready order.
399   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
400   manager.RemoveCurrNode();
401   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
402   manager.RemoveCurrNode();
403   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
404   manager.RemoveCurrNode();
405   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
406   manager.RemoveCurrNode();
407   EXPECT_TRUE(manager.Empty());
408 }
409 
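// Tests that CompositeNodeManager removes the only node contained within.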
410 TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
411   CompositeNodeManager manager;
412   TF_EXPECT_OK(manager.Init(&node_states_));
413   manager.AddNode(&node1_);
414   manager.RemoveCurrNode();
415   EXPECT_TRUE(manager.Empty());
416 }
417 
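// Tests that CompositeNodeManager returns nodes in LIFO order when all nodes
// are on the same device.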
418 TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleCompositeNodeManager) {
419   CompositeNodeManager manager;
420   TF_EXPECT_OK(manager.Init(&node_states_));
421   manager.AddNode(&node1_);
422   manager.AddNode(&node2_);
423   manager.AddNode(&node3_);
424   manager.AddNode(&node4_);
425 
426   // Keeps checking current node as nodes are removed and added.
427   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
428   manager.RemoveCurrNode();
429   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
430   manager.AddNode(&node5_);
431   // GetCurrNode() should return the same node even if some nodes are added,
432   // until RemoveCurrNode() is called.
433   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
434   manager.RemoveCurrNode();
435   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
436   manager.RemoveCurrNode();
437   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
438   manager.AddNode(&node6_);
439   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
440   manager.RemoveCurrNode();
441   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
442   manager.RemoveCurrNode();
443   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
444   manager.RemoveCurrNode();
445   EXPECT_TRUE(manager.Empty());
446 }
447 
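// Tests CompositeNodeManager with nodes on two devices plus _Send and _Recv
// nodes; candidates are the last-added node per device and all _Send/_Recv
// nodes, and the one with the earliest time_ready is returned.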
448 TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvCompositeNodeManager) {
449   CompositeNodeManager manager;
450   TF_EXPECT_OK(manager.Init(&node_states_));
451   // Additional nodes on kCPU1.
452   NodeDef node7;
453   NodeDef node8;
454   NodeDef node9;
455   NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7);
456   NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8);
457   NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9);
458 
459   // Send and Recv nodes.
460   NodeDef send1;
461   NodeDef send2;
462   NodeDef recv1;
463   NodeDef recv2;
464   NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1);
465   NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2);
466   NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1);
467   NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2);
468 
469   // Inserts nodes.
470   manager.AddNode(&node1_);
471   manager.AddNode(&node2_);
472   manager.AddNode(&node3_);
473   manager.AddNode(&node4_);
474   manager.AddNode(&node5_);
475   manager.AddNode(&node6_);
476   manager.AddNode(&node7);
477   manager.AddNode(&node8);
478   manager.AddNode(&node9);
479   manager.AddNode(&send1);
480   manager.AddNode(&send2);
481   manager.AddNode(&recv1);
482   manager.AddNode(&recv2);
483 
484   // On kCPU0, the last node added is node6_; on kCPU1, it is node9. So choose
485   // the one with the earliest time_ready among node6_, node9,
486   // Send1, Send2, Recv1, and Recv2.
487   EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
488   manager.RemoveCurrNode();
489   // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node
490   // among node5_, node9, Send1, Send2, Recv1, and Recv2.
491   EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
492   manager.RemoveCurrNode();
493   // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2.
494   EXPECT_EQ(manager.GetCurrNode()->name(), "Send1");
495   manager.RemoveCurrNode();
496   // Next, choose among node4_, node9, Send2, Recv1, and Recv2.
497   EXPECT_EQ(manager.GetCurrNode()->name(), "Recv1");
498   manager.RemoveCurrNode();
499   // Next, choose among node4_, node9, Send2, and Recv2.
500   EXPECT_EQ(manager.GetCurrNode()->name(), "Recv2");
501   manager.RemoveCurrNode();
502   // Next, choose among node4_, node9, and Send2.
503   EXPECT_EQ(manager.GetCurrNode()->name(), "Send2");
504   manager.RemoveCurrNode();
505   // Next, choose between node4_, node9.
506   EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
507   manager.RemoveCurrNode();
508   // Next, choose between node3_, node9.
509   EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
510   manager.RemoveCurrNode();
511   // Next, choose between node3_, node8.
512   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
513   manager.RemoveCurrNode();
514   // Next, choose between node3_, node7.
515   EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
516   manager.RemoveCurrNode();
517   // Then, only the nodes on kCPU0 remain; they are returned in LIFO order.
518   EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
519   manager.RemoveCurrNode();
520   EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
521   manager.RemoveCurrNode();
522   EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
523   manager.RemoveCurrNode();
524   EXPECT_TRUE(manager.Empty());
525 }
526 
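// Tests that CompositeNodeManager prefers _Send over _Recv over other ops on
// time_ready ties, and is deterministic regardless of insertion order.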
527 TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) {
528   CompositeNodeManager manager;
529   TF_EXPECT_OK(manager.Init(&node_states_));
530   CompositeNodeManager manager2;
531   TF_EXPECT_OK(manager2.Init(&node_states_));
532 
533   // 6 nodes with same time_ready.
534   NodeDef node7;
535   NodeDef node8;
536   NodeDef node9;
537   NodeDef node10;
538   NodeDef node11;
539   NodeDef node12;
540   NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
541   NodeSetUp("Node8", kSend, kCPU0, 1000, &node8);
542   NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9);
543   NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10);
544   NodeSetUp("Node11", kRecv, kCPU0, 999, &node11);
545   NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12);
546 
547   // Adds Nodes 7 to 9 to manager.
548   manager.AddNode(&node7);
549   manager.AddNode(&node8);
550   manager.AddNode(&node9);
551 
552   // It should return nodes in the order _Send, _Recv, then other ops, when the
553   // candidate nodes have the same time_ready.
554   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
555   EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
556   manager.RemoveCurrNode();
557   EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
558   EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
559   manager.RemoveCurrNode();
560   EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
561   EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
562   manager.RemoveCurrNode();
563   EXPECT_TRUE(manager.Empty());
564 
565   // Adds Nodes 7 to 9 to manager, but in a different order.
566   manager.AddNode(&node9);
567   manager.AddNode(&node8);
568   manager.AddNode(&node7);
569 
570   // Expects the same order (_Send, _Recv, then other ops), regardless of Add
571   // order.
572   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
573   EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
574   manager.RemoveCurrNode();
575   EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
576   EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
577   manager.RemoveCurrNode();
578   EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
579   EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
580   manager.RemoveCurrNode();
581   EXPECT_TRUE(manager.Empty());
582 
583   // Conv2D's time_ready < Send's time_ready; expects Conv2D first.
584   manager.AddNode(&node8);
585   manager.AddNode(&node10);
586   EXPECT_EQ(manager.GetCurrNode()->name(), "Node10");
587   EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
588   manager.RemoveCurrNode();
589   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
590   EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
591   manager.RemoveCurrNode();
592   EXPECT_TRUE(manager.Empty());
593 
594   // Recv's time_ready < Send's time_ready; expects Recv first.
595   manager.AddNode(&node11);
596   manager.AddNode(&node8);
597   EXPECT_EQ(manager.GetCurrNode()->name(), "Node11");
598   EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
599   manager.RemoveCurrNode();
600   EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
601   EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
602   manager.RemoveCurrNode();
603   EXPECT_TRUE(manager.Empty());
604 
605   // Nodes 7 and 12 are normal ops with the same time_ready, placed on different
606   // devices. These two nodes are added to manager and manager2, but in
607   // different orders; GetCurrNode() is expected to return the nodes in the same
608   // order.
609   manager.AddNode(&node7);
610   manager.AddNode(&node12);
611 
612   manager2.AddNode(&node12);
613   manager2.AddNode(&node7);
614 
615   EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
616   manager.RemoveCurrNode();
617   manager2.RemoveCurrNode();
618   EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
619   manager.RemoveCurrNode();
620   manager2.RemoveCurrNode();
621   EXPECT_TRUE(manager.Empty());
622 }
623 
624 // Class for testing virtual scheduler.
625 class TestVirtualScheduler : public VirtualScheduler {
626  public:
627   TestVirtualScheduler(const bool use_static_shapes,
628                        const bool use_aggressive_shape_inference,
629                        ReadyNodeManager* ready_node_manager, Cluster* cluster)
630       : VirtualScheduler(
631             use_static_shapes, use_aggressive_shape_inference, cluster,
632             ready_node_manager,
633             absl::make_unique<VirtualPlacer>(cluster->GetDevices())) {
634     enable_mem_usage_tracking();
635   }
636 
637   FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
638   FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
639   FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
640   FRIEND_TEST(VirtualSchedulerTest, Variable);
641   FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
642 };
643 
644 class VirtualSchedulerTest : public ::testing::Test {
645  protected:
646   VirtualSchedulerTest() {
647     // Initializes cluster_ and scheduler_.
648     std::unordered_map<string, DeviceProperties> devices;
649 
650     // Sets some dummy CPU properties.
651     DeviceProperties cpu_device = GetDummyCPUDevice();
652 
653     // IMPORTANT: The device is never actually used in the test cases, since
654     // force_cpu_type defaults to "Haswell".
655     devices[kCPU0] = cpu_device;
656     devices[kCPU1] = cpu_device;
657     cluster_ = absl::make_unique<VirtualCluster>(devices);
658     scheduler_ = absl::make_unique<TestVirtualScheduler>(
659         /*use_static_shapes=*/true,
660         /*use_aggressive_shape_inference=*/true, &first_ready_manager_,
661         cluster_.get());
662   }
663 
664   DeviceProperties GetDummyCPUDevice() {
665     // Creates a CPU with 2 cores, 4 GHz frequency, and 2 GB/s memory bandwidth.
666     // - 8 Gflops
667     // - 2 GB/s
668     DeviceProperties cpu_device;
669     cpu_device.set_type("CPU");
670     cpu_device.set_frequency(4000);
671     cpu_device.set_num_cores(2);
672     cpu_device.set_bandwidth(2000000);
673     return cpu_device;
674   }
675 
676   // Three Conv2Ds, with only two of them in the fetch list.
677   void CreateGrapplerItemWithConv2Ds() {
678     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
679     auto x = ops::RandomUniform(
680         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
681     auto y = ops::RandomUniform(
682         s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
683     auto z = ops::RandomUniform(
684         s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
685     auto f = ops::RandomUniform(
686         s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
687     std::vector<int> strides = {1, 1, 1, 1};
688     auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
689     auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
690     auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
691 
692     grappler_item_ = absl::make_unique<GrapplerItem>();
693     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
694     grappler_item_->id = "test_conv2d_graph";
695     grappler_item_->fetch = {"c0", "c1"};
696 
697     dependency_["c0"] = {"x", "f"};
698     dependency_["c1"] = {"y", "f"};
699   }
700 
701   // A Conv2D with a variable.
702   void CreateGrapplerItemWithConv2DAndVariable() {
703     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
704     auto x = ops::RandomUniform(
705         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
706     auto f = ops::Variable(s.WithOpName("f"),
707                            {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
708     std::vector<int> strides = {1, 1, 1, 1};
709     auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
710 
711     grappler_item_ = absl::make_unique<GrapplerItem>();
712     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
713     grappler_item_->id = "test_conv2d_var_graph";
714 
715     grappler_item_->fetch = {"y"};
716 
717     dependency_["y"] = {"x", "f"};
718   }
719 
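  // A chain of MatMuls; control dependencies fix the order in which the random
  // inputs are generated.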
720   void CreateGrapplerItemWithMatmulChain() {
721     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
722     // Adds control dependencies to ensure that the test does not rely on a
723     // specific ready node manager and that the order remains consistent.
724     auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT);
725     auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a),
726                                 {3200, 3200}, DT_FLOAT);
727     auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b),
728                                 {3200, 3200}, DT_FLOAT);
729     auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c),
730                                 {3200, 3200}, DT_FLOAT);
731     auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d),
732                                 {3200, 3200}, DT_FLOAT);
733 
734     auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b);
735     auto abc = ops::MatMul(s.WithOpName("abc"), ab, c);
736     auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d);
737     auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e);
738 
739     grappler_item_ = absl::make_unique<GrapplerItem>();
740     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
741     grappler_item_->id = "test_matmul_sequence_graph";
742     grappler_item_->fetch = {"abcde"};
743 
744     dependency_["ab"] = {"a", "b"};
745     dependency_["abc"] = {"ab", "c"};
746     dependency_["abcd"] = {"abc", "d"};
747     dependency_["abcde"] = {"abcd", "e"};
748   }
749 
750   // AddN that takes 4 tensors of shape 10x10x10x10.
751   void CreateGrapplerItemWithAddN() {
752     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
753     auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT);
754     auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT);
755     auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT);
756     auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT);
757     OutputList input_tensors = {x, y, z, w};
758     auto out = ops::AddN(s.WithOpName("out"), input_tensors);
759 
760     grappler_item_ = absl::make_unique<GrapplerItem>();
761     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
762     grappler_item_->id = "test_addn_graph";
763     grappler_item_->fetch = {"out"};
764 
765     dependency_["out"] = {"x", "y", "z", "w"};
766   }
767 
768   // Graph with some placeholder feed nodes that are not in the fetch fan-in.
769   void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() {
770     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
771     auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT);
772     auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT);
773 
774     grappler_item_ = absl::make_unique<GrapplerItem>();
775     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
776 
777     grappler_item_->id = "test_extra_placeholders";
778     grappler_item_->fetch = {"x"};
779 
780     // Grappler Item Builder puts all placeholder nodes into the feed
781     // list by default.
782     grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}};
783   }
784 
785   // NoOp that takes 7 NoOps as control dependencies.
786   void CreateGrapplerItemWithControlDependency() {
787     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
788     std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
789     std::vector<Operation> input_tensors;
790     for (const auto& input : input_noop_names) {
791       auto x = ops::NoOp(s.WithOpName(input));
792       input_tensors.push_back(x.operation);
793     }
794     auto out =
795         ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out"));
796 
797     grappler_item_ = absl::make_unique<GrapplerItem>();
798     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
799 
800     grappler_item_->id = "test_control_dependency_graph";
801     grappler_item_->fetch = {"out"};
802 
803     dependency_["out"] = input_noop_names;
804   }
805 
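  // An Add that takes the same tensor for both inputs, followed by an Identity
  // fetch node.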
806   void CreateGrapplerItemWithAddFromOneTensor() {
807     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
808     auto x = tensorflow::ops::RandomUniform(
809         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
810 
811     auto y = tensorflow::ops::Add(s.WithOpName("y"), x, x);
812     Output fetch = ops::Identity(s.WithOpName("fetch"), y);
813 
814     grappler_item_ = absl::make_unique<GrapplerItem>();
815     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
816 
817     grappler_item_->id = "test_add_from_one_tensor";
818     grappler_item_->fetch = {"fetch"};
819 
820     dependency_["fetch"] = {"y"};
821     dependency_["y"] = {"x"};
822   }
823 
824   void CreateGrapplerItemWithSwitchMergeInput() {
825     // sw = Switch(x, pred)
826     // a = Add(sw:1, b)
827     // m = Merge(sw:0, a)
828     // y = Add(m, z)
829 
830     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
831     auto x = ops::RandomUniform(
832         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
833     auto pred = ops::Const(s.WithOpName("pred"), false, {});
834     auto sw = ops::Switch(s.WithOpName("switch"), x, pred);
835     auto b = ops::RandomUniform(
836         s.WithOpName("b"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
837     auto a = ops::Add(s.WithOpName("a"), sw.output_true, b);
838     auto m = ops::Merge(s.WithOpName("m"), {sw.output_false, a.z});
839     auto z = ops::RandomUniform(
840         s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
841     auto y = ops::Add(s.WithOpName("y"), m.output, z);
842 
843     grappler_item_ = absl::make_unique<GrapplerItem>();
844     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
845 
846     grappler_item_->id = "test_add_merge_switch";
847     grappler_item_->fetch = {"y"};
848 
849     dependency_["y"] = {"m", "z"};
850   }
851 
852   // FusedBatchNorm (an op with multiple outputs) with multiple consumers
853   // (including a control dependency).
854   void CreateGrapplerItemWithBatchNorm() {
855     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
856     auto x = ops::RandomUniform(
857         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
858     auto scale =
859         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
860     auto offset =
861         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
862     auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
863     auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
864 
865     auto batch_norm = ops::FusedBatchNorm(
866         s.WithOpName("bn"), x, scale, offset, mean, var,
867         ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
868     auto y = batch_norm.y;
869     auto batch_mean = batch_norm.batch_mean;
870     auto batch_var = batch_norm.batch_variance;
871 
872     auto z1 = ops::Add(s.WithOpName("z1"), x, y);
873     auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var);
874     auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var);
875     std::vector<Operation> input_tensors = {
876         batch_mean.op(),
877         z1.z.op(),
878         z2.z.op(),
879         z3.z.op(),
880     };
881     auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4"));
882 
883     grappler_item_ = absl::make_unique<GrapplerItem>();
884     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
885 
886     grappler_item_->id = "test_complex_dependency_graph";
887     grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
888 
889     dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
890     dependency_["z1"] = {"x", "bn"};
891     dependency_["z2"] = {"bn"};
892     dependency_["z3"] = {"bn"};
893     dependency_["z4"] = {"bn"};
894   }
895 
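  // A scalar Const connected to a _Send/_Recv pair on the same device.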
896   void CreateGrapplerItemWithSendRecv() {
897     const string gdef_ascii = R"EOF(
898 node {
899   name: "Const"
900   op: "Const"
901   device: "/job:localhost/replica:0/task:0/device:CPU:0"
902   attr {
903     key: "dtype"
904     value {
905       type: DT_FLOAT
906     }
907   }
908   attr {
909     key: "value"
910     value {
911       tensor {
912         dtype: DT_FLOAT
913         tensor_shape {
914         }
915         float_val: 3.1415
916       }
917     }
918   }
919 }
920 node {
921   name: "Send"
922   op: "_Send"
923   input: "Const"
924   device: "/job:localhost/replica:0/task:0/device:CPU:0"
925   attr {
926     key: "T"
927     value {
928       type: DT_FLOAT
929     }
930   }
931   attr {
932     key: "client_terminated"
933     value {
934       b: false
935     }
936   }
937   attr {
938     key: "recv_device"
939     value {
940       s: "/job:localhost/replica:0/task:0/device:CPU:0"
941     }
942   }
943   attr {
944     key: "send_device"
945     value {
946       s: "/job:localhost/replica:0/task:0/device:CPU:0"
947     }
948   }
949   attr {
950     key: "send_device_incarnation"
951     value {
952       i: 0
953     }
954   }
955   attr {
956     key: "tensor_name"
957     value {
958       s: "test"
959     }
960   }
961 }
962 node {
963   name: "Recv"
964   op: "_Recv"
965   device: "/job:localhost/replica:0/task:0/device:CPU:0"
966   attr {
967     key: "client_terminated"
968     value {
969       b: false
970     }
971   }
972   attr {
973     key: "recv_device"
974     value {
975       s: "/job:localhost/replica:0/task:0/device:CPU:0"
976     }
977   }
978   attr {
979     key: "send_device"
980     value {
981       s: "/job:localhost/replica:0/task:0/device:CPU:0"
982     }
983   }
984   attr {
985     key: "send_device_incarnation"
986     value {
987       i: 0
988     }
989   }
990   attr {
991     key: "tensor_name"
992     value {
993       s: "test"
994     }
995   }
996   attr {
997     key: "tensor_type"
998     value {
999       type: DT_FLOAT
1000     }
1001   }
1002 }
1003 library {
1004 }
1005 versions {
1006   producer: 24
1007 }
1008     )EOF";
1009 
1010     grappler_item_ = absl::make_unique<GrapplerItem>();
1011 
1012     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1013                                                 &grappler_item_->graph));
1014     grappler_item_->id = "test_graph";
1015     grappler_item_->fetch = {"Recv"};
1016   }
1017 
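  // A _Recv node with no matching _Send.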
1018   void CreateGrapplerItemWithRecvWithoutSend() {
1019     const string gdef_ascii = R"EOF(
1020 node {
1021   name: "Recv"
1022   op: "_Recv"
1023   device: "/job:localhost/replica:0/task:0/device:CPU:0"
1024   attr {
1025     key: "client_terminated"
1026     value {
1027       b: false
1028     }
1029   }
1030   attr {
1031     key: "recv_device"
1032     value {
1033       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1034     }
1035   }
1036   attr {
1037     key: "send_device"
1038     value {
1039       s: "/job:localhost/replica:0/task:0/device:CPU:0"
1040     }
1041   }
1042   attr {
1043     key: "send_device_incarnation"
1044     value {
1045       i: 0
1046     }
1047   }
1048   attr {
1049     key: "tensor_name"
1050     value {
1051       s: "test"
1052     }
1053   }
1054   attr {
1055     key: "tensor_type"
1056     value {
1057       type: DT_FLOAT
1058     }
1059   }
1060 }
1061 library {
1062 }
1063 versions {
1064   producer: 24
1065 }
1066     )EOF";
1067 
1068     grappler_item_ = absl::make_unique<GrapplerItem>();
1069     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1070                                                 &grappler_item_->graph));
1071     grappler_item_->id = "test_graph";
1072     grappler_item_->fetch = {"Recv"};
1073   }
1074 
1075   // A simple while loop.
1076   void CreateGrapplerItemWithLoop() {
1077     // Test graph produced in python using:
1078     /*
1079       with tf.Graph().as_default():
1080       i0 = tf.constant(0)
1081       m0 = tf.ones([2, 2])
1082       c = lambda i, m: i < 10
1083       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
1084       r = tf.while_loop(
1085       c, b, loop_vars=[i0, m0],
1086       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
1087       with open('/tmp/graph.pbtxt', 'w') as f:
1088       f.write(str(tf.get_default_graph().as_graph_def()))
1089     */
1090     const string gdef_ascii = R"EOF(
1091 node {
1092   name: "Const"
1093   op: "Const"
1094   attr {
1095     key: "dtype"
1096     value {
1097       type: DT_INT32
1098     }
1099   }
1100   attr {
1101     key: "value"
1102     value {
1103       tensor {
1104         dtype: DT_INT32
1105         tensor_shape {
1106         }
1107         int_val: 0
1108       }
1109     }
1110   }
1111 }
1112 node {
1113   name: "ones"
1114   op: "Const"
1115   attr {
1116     key: "dtype"
1117     value {
1118       type: DT_FLOAT
1119     }
1120   }
1121   attr {
1122     key: "value"
1123     value {
1124       tensor {
1125         dtype: DT_FLOAT
1126         tensor_shape {
1127           dim {
1128             size: 2
1129           }
1130           dim {
1131             size: 2
1132           }
1133         }
1134         float_val: 1.0
1135       }
1136     }
1137   }
1138 }
1139 node {
1140   name: "while/Enter"
1141   op: "Enter"
1142   input: "Const"
1143   attr {
1144     key: "T"
1145     value {
1146       type: DT_INT32
1147     }
1148   }
1149   attr {
1150     key: "frame_name"
1151     value {
1152       s: "while/while/"
1153     }
1154   }
1155   attr {
1156     key: "is_constant"
1157     value {
1158       b: false
1159     }
1160   }
1161   attr {
1162     key: "parallel_iterations"
1163     value {
1164       i: 10
1165     }
1166   }
1167 }
1168 node {
1169   name: "while/Enter_1"
1170   op: "Enter"
1171   input: "ones"
1172   attr {
1173     key: "T"
1174     value {
1175       type: DT_FLOAT
1176     }
1177   }
1178   attr {
1179     key: "frame_name"
1180     value {
1181       s: "while/while/"
1182     }
1183   }
1184   attr {
1185     key: "is_constant"
1186     value {
1187       b: false
1188     }
1189   }
1190   attr {
1191     key: "parallel_iterations"
1192     value {
1193       i: 10
1194     }
1195   }
1196 }
1197 node {
1198   name: "while/Merge"
1199   op: "Merge"
1200   input: "while/Enter"
1201   input: "while/NextIteration"
1202   attr {
1203     key: "N"
1204     value {
1205       i: 2
1206     }
1207   }
1208   attr {
1209     key: "T"
1210     value {
1211       type: DT_INT32
1212     }
1213   }
1214 }
1215 node {
1216   name: "while/Merge_1"
1217   op: "Merge"
1218   input: "while/Enter_1"
1219   input: "while/NextIteration_1"
1220   attr {
1221     key: "N"
1222     value {
1223       i: 2
1224     }
1225   }
1226   attr {
1227     key: "T"
1228     value {
1229       type: DT_FLOAT
1230     }
1231   }
1232 }
1233 node {
1234   name: "while/Less/y"
1235   op: "Const"
1236   input: "^while/Merge"
1237   attr {
1238     key: "dtype"
1239     value {
1240       type: DT_INT32
1241     }
1242   }
1243   attr {
1244     key: "value"
1245     value {
1246       tensor {
1247         dtype: DT_INT32
1248         tensor_shape {
1249         }
1250         int_val: 10
1251       }
1252     }
1253   }
1254 }
1255 node {
1256   name: "while/Less"
1257   op: "Less"
1258   input: "while/Merge"
1259   input: "while/Less/y"
1260   attr {
1261     key: "T"
1262     value {
1263       type: DT_INT32
1264     }
1265   }
1266 }
1267 node {
1268   name: "while/LoopCond"
1269   op: "LoopCond"
1270   input: "while/Less"
1271 }
1272 node {
1273   name: "while/Switch"
1274   op: "Switch"
1275   input: "while/Merge"
1276   input: "while/LoopCond"
1277   attr {
1278     key: "T"
1279     value {
1280       type: DT_INT32
1281     }
1282   }
1283   attr {
1284     key: "_class"
1285     value {
1286       list {
1287         s: "loc:@while/Merge"
1288       }
1289     }
1290   }
1291 }
1292 node {
1293   name: "while/Switch_1"
1294   op: "Switch"
1295   input: "while/Merge_1"
1296   input: "while/LoopCond"
1297   attr {
1298     key: "T"
1299     value {
1300       type: DT_FLOAT
1301     }
1302   }
1303   attr {
1304     key: "_class"
1305     value {
1306       list {
1307         s: "loc:@while/Merge_1"
1308       }
1309     }
1310   }
1311 }
1312 node {
1313   name: "while/Identity"
1314   op: "Identity"
1315   input: "while/Switch:1"
1316   attr {
1317     key: "T"
1318     value {
1319       type: DT_INT32
1320     }
1321   }
1322 }
1323 node {
1324   name: "while/Identity_1"
1325   op: "Identity"
1326   input: "while/Switch_1:1"
1327   attr {
1328     key: "T"
1329     value {
1330       type: DT_FLOAT
1331     }
1332   }
1333 }
1334 node {
1335   name: "while/add/y"
1336   op: "Const"
1337   input: "^while/Identity"
1338   attr {
1339     key: "dtype"
1340     value {
1341       type: DT_INT32
1342     }
1343   }
1344   attr {
1345     key: "value"
1346     value {
1347       tensor {
1348         dtype: DT_INT32
1349         tensor_shape {
1350         }
1351         int_val: 1
1352       }
1353     }
1354   }
1355 }
1356 node {
1357   name: "while/add"
1358   op: "Add"
1359   input: "while/Identity"
1360   input: "while/add/y"
1361   attr {
1362     key: "T"
1363     value {
1364       type: DT_INT32
1365     }
1366   }
1367 }
1368 node {
1369   name: "while/concat/axis"
1370   op: "Const"
1371   input: "^while/Identity"
1372   attr {
1373     key: "dtype"
1374     value {
1375       type: DT_INT32
1376     }
1377   }
1378   attr {
1379     key: "value"
1380     value {
1381       tensor {
1382         dtype: DT_INT32
1383         tensor_shape {
1384         }
1385         int_val: 0
1386       }
1387     }
1388   }
1389 }
1390 node {
1391   name: "while/concat"
1392   op: "ConcatV2"
1393   input: "while/Identity_1"
1394   input: "while/Identity_1"
1395   input: "while/concat/axis"
1396   attr {
1397     key: "N"
1398     value {
1399       i: 2
1400     }
1401   }
1402   attr {
1403     key: "T"
1404     value {
1405       type: DT_FLOAT
1406     }
1407   }
1408   attr {
1409     key: "Tidx"
1410     value {
1411       type: DT_INT32
1412     }
1413   }
1414 }
1415 node {
1416   name: "while/NextIteration"
1417   op: "NextIteration"
1418   input: "while/add"
1419   attr {
1420     key: "T"
1421     value {
1422       type: DT_INT32
1423     }
1424   }
1425 }
1426 node {
1427   name: "while/NextIteration_1"
1428   op: "NextIteration"
1429   input: "while/concat"
1430   attr {
1431     key: "T"
1432     value {
1433       type: DT_FLOAT
1434     }
1435   }
1436 }
1437 node {
1438   name: "while/Exit"
1439   op: "Exit"
1440   input: "while/Switch"
1441   attr {
1442     key: "T"
1443     value {
1444       type: DT_INT32
1445     }
1446   }
1447 }
1448 node {
1449   name: "while/Exit_1"
1450   op: "Exit"
1451   input: "while/Switch_1"
1452   attr {
1453     key: "T"
1454     value {
1455       type: DT_FLOAT
1456     }
1457   }
1458 }
1459 versions {
1460   producer: 21
1461 }
1462   )EOF";
1463 
1464     grappler_item_ = absl::make_unique<GrapplerItem>();
1465     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1466                                                 &grappler_item_->graph));
1467     grappler_item_->id = "test_graph";
1468     grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
1469   }
1470 
1471   // The same while loop, annotated with _execution_count and _output_slot_vector.
1472   void CreateGrapplerItemWithLoopAnnotated() {
1473     // Test graph produced in python using:
1474     /*
1475       with tf.Graph().as_default():
1476       i0 = tf.constant(0)
1477       m0 = tf.ones([2, 2])
1478       c = lambda i, m: i < 10
1479       b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
1480       r = tf.while_loop(
1481       c, b, loop_vars=[i0, m0],
1482       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
1483       with open('/tmp/graph.pbtxt', 'w') as f:
1484       f.write(str(tf.get_default_graph().as_graph_def()))
1485     */
1486     const string gdef_ascii = R"EOF(
1487 node {
1488   name: "Const"
1489   op: "Const"
1490   attr {
1491     key: "dtype"
1492     value {
1493       type: DT_INT32
1494     }
1495   }
1496   attr {
1497     key: "value"
1498     value {
1499       tensor {
1500         dtype: DT_INT32
1501         tensor_shape {
1502         }
1503         int_val: 0
1504       }
1505     }
1506   }
1507   attr {
1508     key: "_execution_count"
1509     value {
1510       i: 1
1511     }
1512   }
1513 }
1514 node {
1515   name: "ones"
1516   op: "Const"
1517   attr {
1518     key: "dtype"
1519     value {
1520       type: DT_FLOAT
1521     }
1522   }
1523   attr {
1524     key: "value"
1525     value {
1526       tensor {
1527         dtype: DT_FLOAT
1528         tensor_shape {
1529           dim {
1530             size: 2
1531           }
1532           dim {
1533             size: 2
1534           }
1535         }
1536         float_val: 1.0
1537       }
1538     }
1539   }
1540   attr {
1541     key: "_execution_count"
1542     value {
1543       i: 1
1544     }
1545   }
1546 }
1547 node {
1548   name: "while/Enter"
1549   op: "Enter"
1550   input: "Const"
1551   attr {
1552     key: "T"
1553     value {
1554       type: DT_INT32
1555     }
1556   }
1557   attr {
1558     key: "frame_name"
1559     value {
1560       s: "while/while/"
1561     }
1562   }
1563   attr {
1564     key: "is_constant"
1565     value {
1566       b: false
1567     }
1568   }
1569   attr {
1570     key: "parallel_iterations"
1571     value {
1572       i: 10
1573     }
1574   }
1575   attr {
1576     key: "_execution_count"
1577     value {
1578       i: 1
1579     }
1580   }
1581 }
1582 node {
1583   name: "while/Enter_1"
1584   op: "Enter"
1585   input: "ones"
1586   attr {
1587     key: "T"
1588     value {
1589       type: DT_FLOAT
1590     }
1591   }
1592   attr {
1593     key: "frame_name"
1594     value {
1595       s: "while/while/"
1596     }
1597   }
1598   attr {
1599     key: "is_constant"
1600     value {
1601       b: false
1602     }
1603   }
1604   attr {
1605     key: "parallel_iterations"
1606     value {
1607       i: 10
1608     }
1609   }
1610   attr {
1611     key: "_execution_count"
1612     value {
1613       i: 1
1614     }
1615   }
1616 }
1617 node {
1618   name: "while/Merge"
1619   op: "Merge"
1620   input: "while/Enter"
1621   input: "while/NextIteration"
1622   attr {
1623     key: "N"
1624     value {
1625       i: 2
1626     }
1627   }
1628   attr {
1629     key: "T"
1630     value {
1631       type: DT_INT32
1632     }
1633   }
1634   attr {
1635     key: "_execution_count"
1636     value {
1637       i: 10
1638     }
1639   }
1640 }
1641 node {
1642   name: "while/Merge_1"
1643   op: "Merge"
1644   input: "while/Enter_1"
1645   input: "while/NextIteration_1"
1646   attr {
1647     key: "N"
1648     value {
1649       i: 2
1650     }
1651   }
1652   attr {
1653     key: "T"
1654     value {
1655       type: DT_FLOAT
1656     }
1657   }
1658   attr {
1659     key: "_execution_count"
1660     value {
1661       i: 10
1662     }
1663   }
1664 }
1665 node {
1666   name: "while/Less/y"
1667   op: "Const"
1668   input: "^while/Merge"
1669   attr {
1670     key: "dtype"
1671     value {
1672       type: DT_INT32
1673     }
1674   }
1675   attr {
1676     key: "value"
1677     value {
1678       tensor {
1679         dtype: DT_INT32
1680         tensor_shape {
1681         }
1682         int_val: 10
1683       }
1684     }
1685   }
1686   attr {
1687     key: "_execution_count"
1688     value {
1689       i: 10
1690     }
1691   }
1692 }
1693 node {
1694   name: "while/Less"
1695   op: "Less"
1696   input: "while/Merge"
1697   input: "while/Less/y"
1698   attr {
1699     key: "T"
1700     value {
1701       type: DT_INT32
1702     }
1703   }
1704   attr {
1705     key: "_execution_count"
1706     value {
1707       i: 10
1708     }
1709   }
1710 }
1711 node {
1712   name: "while/LoopCond"
1713   op: "LoopCond"
1714   input: "while/Less"
1715   attr {
1716     key: "_execution_count"
1717     value {
1718       i: 10
1719     }
1720   }
1721 }
1722 node {
1723   name: "while/Switch"
1724   op: "Switch"
1725   input: "while/Merge"
1726   input: "while/LoopCond"
1727   attr {
1728     key: "T"
1729     value {
1730       type: DT_INT32
1731     }
1732   }
1733   attr {
1734     key: "_class"
1735     value {
1736       list {
1737         s: "loc:@while/Merge"
1738       }
1739     }
1740   }
1741   attr {
1742     key: "_execution_count"
1743     value {
1744       i: 11
1745     }
1746   }
1747   attr {
1748     key: "_output_slot_vector"
1749     value {
1750       list {
1751         i: 1
1752         i: 1
1753         i: 1
1754         i: 1
1755         i: 1
1756         i: 1
1757         i: 1
1758         i: 1
1759         i: 1
1760         i: 1
1761         i: 0
1762       }
1763     }
1764   }
1765 }
1766 node {
1767   name: "while/Switch_1"
1768   op: "Switch"
1769   input: "while/Merge_1"
1770   input: "while/LoopCond"
1771   attr {
1772     key: "T"
1773     value {
1774       type: DT_FLOAT
1775     }
1776   }
1777   attr {
1778     key: "_class"
1779     value {
1780       list {
1781         s: "loc:@while/Merge_1"
1782       }
1783     }
1784   }
1785   attr {
1786     key: "_execution_count"
1787     value {
1788       i: 11
1789     }
1790   }
1791   attr {
1792     key: "_output_slot_vector"
1793     value {
1794       list {
1795         i: 1
1796         i: 1
1797         i: 1
1798         i: 1
1799         i: 1
1800         i: 1
1801         i: 1
1802         i: 1
1803         i: 1
1804         i: 1
1805         i: 0
1806       }
1807     }
1808   }
1809 }
1810 node {
1811   name: "while/Identity"
1812   op: "Identity"
1813   input: "while/Switch:1"
1814   attr {
1815     key: "T"
1816     value {
1817       type: DT_INT32
1818     }
1819   }
1820   attr {
1821     key: "_execution_count"
1822     value {
1823       i: 10
1824     }
1825   }
1826 }
1827 node {
1828   name: "while/Identity_1"
1829   op: "Identity"
1830   input: "while/Switch_1:1"
1831   attr {
1832     key: "T"
1833     value {
1834       type: DT_FLOAT
1835     }
1836   }
1837   attr {
1838     key: "_execution_count"
1839     value {
1840       i: 10
1841     }
1842   }
1843 }
1844 node {
1845   name: "while/add/y"
1846   op: "Const"
1847   input: "^while/Identity"
1848   attr {
1849     key: "dtype"
1850     value {
1851       type: DT_INT32
1852     }
1853   }
1854   attr {
1855     key: "value"
1856     value {
1857       tensor {
1858         dtype: DT_INT32
1859         tensor_shape {
1860         }
1861         int_val: 1
1862       }
1863     }
1864   }
1865   attr {
1866     key: "_execution_count"
1867     value {
1868       i: 10
1869     }
1870   }
1871 }
1872 node {
1873   name: "while/add"
1874   op: "Add"
1875   input: "while/Identity"
1876   input: "while/add/y"
1877   attr {
1878     key: "T"
1879     value {
1880       type: DT_INT32
1881     }
1882   }
1883   attr {
1884     key: "_execution_count"
1885     value {
1886       i: 10
1887     }
1888   }
1889 }
1890 node {
1891   name: "while/concat/axis"
1892   op: "Const"
1893   input: "^while/Identity"
1894   attr {
1895     key: "dtype"
1896     value {
1897       type: DT_INT32
1898     }
1899   }
1900   attr {
1901     key: "value"
1902     value {
1903       tensor {
1904         dtype: DT_INT32
1905         tensor_shape {
1906         }
1907         int_val: 0
1908       }
1909     }
1910   }
1911   attr {
1912     key: "_execution_count"
1913     value {
1914       i: 10
1915     }
1916   }
1917 }
1918 node {
1919   name: "while/concat"
1920   op: "ConcatV2"
1921   input: "while/Identity_1"
1922   input: "while/Identity_1"
1923   input: "while/concat/axis"
1924   attr {
1925     key: "N"
1926     value {
1927       i: 2
1928     }
1929   }
1930   attr {
1931     key: "T"
1932     value {
1933       type: DT_FLOAT
1934     }
1935   }
1936   attr {
1937     key: "Tidx"
1938     value {
1939       type: DT_INT32
1940     }
1941   }
1942   attr {
1943     key: "_execution_count"
1944     value {
1945       i: 10
1946     }
1947   }
1948 }
1949 node {
1950   name: "while/NextIteration"
1951   op: "NextIteration"
1952   input: "while/add"
1953   attr {
1954     key: "T"
1955     value {
1956       type: DT_INT32
1957     }
1958   }
1959   attr {
1960     key: "_execution_count"
1961     value {
1962       i: 10
1963     }
1964   }
1965 }
1966 node {
1967   name: "while/NextIteration_1"
1968   op: "NextIteration"
1969   input: "while/concat"
1970   attr {
1971     key: "T"
1972     value {
1973       type: DT_FLOAT
1974     }
1975   }
1976   attr {
1977     key: "_execution_count"
1978     value {
1979       i: 10
1980     }
1981   }
1982 }
1983 node {
1984   name: "while/Exit"
1985   op: "Exit"
1986   input: "while/Switch"
1987   attr {
1988     key: "T"
1989     value {
1990       type: DT_INT32
1991     }
1992   }
1993   attr {
1994     key: "_execution_count"
1995     value {
1996       i: 1
1997     }
1998   }
1999 }
2000 node {
2001   name: "while/Exit_1"
2002   op: "Exit"
2003   input: "while/Switch_1"
2004   attr {
2005     key: "T"
2006     value {
2007       type: DT_FLOAT
2008     }
2009   }
2010   attr {
2011     key: "_execution_count"
2012     value {
2013       i: 1
2014     }
2015   }
2016 }
2017 versions {
2018   producer: 21
2019 }
2020   )EOF";
2021 
2022     grappler_item_ = absl::make_unique<GrapplerItem>();
2023     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
2024                                                 &grappler_item_->graph));
2025     grappler_item_->id = "test_graph";
2026     grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
2027   }
2028 
2029   // A simple condition graph.
2030   void CreateGrapplerItemWithCondition() {
2031     // Handcrafted test graph: a/Less -> Switch -> First/Second -> Merge.
2032     const string gdef_ascii = R"EOF(
2033 node {
2034   name: "a"
2035   op: "Const"
2036   attr {
2037     key: "dtype"
2038     value {
2039       type: DT_FLOAT
2040     }
2041   }
2042   attr {
2043     key: "value"
2044     value {
2045       tensor {
2046         dtype: DT_FLOAT
2047         tensor_shape {
2048         }
2049         float_val: 2.0
2050       }
2051     }
2052   }
2053 }
2054 node {
2055   name: "Less"
2056   op: "Const"
2057   attr {
2058     key: "dtype"
2059     value {
2060       type: DT_BOOL
2061     }
2062   }
2063   attr {
2064     key: "value"
2065     value {
2066       tensor {
2067         dtype: DT_BOOL
2068         tensor_shape {
2069         }
2070         tensor_content: "\001"
2071       }
2072     }
2073   }
2074 }
2075 node {
2076   name: "Switch"
2077   op: "Switch"
2078   input: "a"
2079   input: "Less"
2080   attr {
2081     key: "T"
2082     value {
2083       type: DT_FLOAT
2084     }
2085   }
2086 }
2087 node {
2088   name: "First"
2089   op: "Identity"
2090   input: "Switch"
2091   attr {
2092     key: "T"
2093     value {
2094       type: DT_FLOAT
2095     }
2096   }
2097 }
2098 node {
2099   name: "Second"
2100   op: "Identity"
2101   input: "Switch:1"
2102   attr {
2103     key: "T"
2104     value {
2105       type: DT_FLOAT
2106     }
2107   }
2108 }
2109 node {
2110   name: "Merge"
2111   op: "Merge"
2112   input: "First"
2113   input: "Second"
2114   attr {
2115     key: "N"
2116     value {
2117       i: 2
2118     }
2119   }
2120   attr {
2121     key: "T"
2122     value {
2123       type: DT_FLOAT
2124     }
2125   }
2126 }
2127 versions {
2128   producer: 27
2129 })EOF";
2130 
2131     grappler_item_ = absl::make_unique<GrapplerItem>();
2132     CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
2133                                                 &grappler_item_->graph));
2134     grappler_item_->id = "test_graph";
2135     grappler_item_->fetch = {"Merge"};
2136   }
2137 
2138   // Creates a graph where FusedBatchNorm outputs are consumed on another device.
2139   void CreateGrapplerItemWithInterDeviceTransfers() {
2140     tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
2141 
2142     // Create a FusedBatchNorm op that has multiple output ports.
2143     auto x = ops::RandomUniform(
2144         s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
2145     auto scale =
2146         ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
2147     auto offset =
2148         ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
2149     auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
2150     auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
2151 
2152     auto batch_norm = ops::FusedBatchNorm(
2153         s.WithOpName("bn"), x, scale, offset, mean, var,
2154         ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
2155     auto y = batch_norm.y;
2156     auto batch_mean = batch_norm.batch_mean;
2157     auto batch_var = batch_norm.batch_variance;
2158     // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
2159     auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
2160     auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
2161     // batch_mean1 and batch_var1 take different output ports, so each will
2162     // initiate Send/Recv.
2163     auto batch_mean1 = ops::Identity(
2164         s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
2165     auto batch_var1 =
2166         ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
2167     // This is a control dependency.
2168     auto control_dep = ops::NoOp(s.WithOpName("control_dep")
2169                                      .WithControlDependencies(y)
2170                                      .WithDevice(kCPU1));
2171 
2172     grappler_item_ = absl::make_unique<GrapplerItem>();
2173     TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
2174     grappler_item_->id = "test_conv2d_graph";
2175     grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
2176                              "control_dep"};
2177 
2178     dependency_["bn"] = {"x", "mean", "var"};
2179     dependency_["y1"] = {"bn"};
2180     dependency_["y2"] = {"bn"};
2181     dependency_["batch_mean1"] = {"bn"};
2182     dependency_["batch_var1"] = {"bn"};
2183     dependency_["control_dep"] = {"bn"};
2184   }
2185 
2186   // Call this after creating grappler_item_ and setting up dependency_.
2187   void InitScheduler() { TF_ASSERT_OK(scheduler_->Init(grappler_item_.get())); }
2188 
2189   // Returns a cost estimate based on the op type.
2190   Costs SimplePredictCosts(const OpContext& op_context) const {
2191     Costs c;
2192     int64 exec_cost = 0;
2193     if (op_context.op_info.op() == "MatMul") {
2194       exec_cost = 2000000000;
2195     } else if (op_context.op_info.op() == "RandomUniform") {
2196       exec_cost = 1000000000;
2197     } else {
2198       exec_cost = 1000;
2199     }
2200     c.execution_time = Costs::NanoSeconds(exec_cost);
2201     return c;
2202   }
2203 
2204   // Call this after InitScheduler(). The scheduler stops after executing
2205   // target_node.
2206   std::unordered_map<string, OpContext> RunScheduler(
2207       const string& target_node) {
2208     std::unordered_map<string, OpContext> ops_executed;
2209     bool more_nodes = true;
2210     do {
2211       OpContext op_context = scheduler_->GetCurrNode();
2212       ops_executed[op_context.name] = op_context;
2213       std::cout << op_context.name << std::endl;
2214 
2215       Costs node_costs = SimplePredictCosts(op_context);
2216 
2217       // Check scheduling order.
2218       auto it = dependency_.find(op_context.name);
2219       if (it != dependency_.end()) {
2220         for (const auto& preceding_node : it->second) {
2221           EXPECT_GT(ops_executed.count(preceding_node), 0);
2222         }
2223       }
2224       more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
2225 
2226       if (op_context.name == target_node) {
2227         // Scheduler has the state after executing the target node.
2228         break;
2229       }
2230     } while (more_nodes);
2231     return ops_executed;
2232   }
2233 
2234   // Helper method for validating a vector.
2235   template <typename T>
2236   void ExpectVectorEq(const std::vector<T>& expected,
2237                       const std::vector<T>& test_elements) {
2238     // Set of expected elements for an easy comparison.
2239     std::set<T> expected_set(expected.begin(), expected.end());
2240     for (const auto& element : test_elements) {
2241       EXPECT_GT(expected_set.count(element), 0);
2242     }
2243     EXPECT_EQ(expected.size(), test_elements.size());
2244   }
2245 
2246   // Helper method that checks the name of nodes.
2247   void ValidateNodeDefs(const std::vector<string>& expected,
2248                         const std::vector<const NodeDef*>& node_defs) {
2249     std::vector<string> node_names;
2250     std::transform(node_defs.begin(), node_defs.end(),
2251                    std::back_inserter(node_names),
2252                    [](const NodeDef* node) { return node->name(); });
2253     ExpectVectorEq(expected, node_names);
2254   }
2255 
2256   // Helper method for validating a set.
2257   template <typename T>
2258   void ExpectSetEq(const std::set<T>& expected,
2259                    const std::set<T>& test_elements) {
2260     for (const auto& element : test_elements) {
2261       EXPECT_GT(expected.count(element), 0);
2262     }
2263     EXPECT_EQ(expected.size(), test_elements.size());
2264   }
2265 
2266   // Helper method for validating an unordered map.
2267   template <typename T, typename U>
2268   void ExpectUnorderedMapEq(const std::unordered_map<T, U>& expected,
2269                             const std::unordered_map<T, U>& test_map) {
2270     EXPECT_EQ(expected.size(), test_map.size());
2271     for (const auto& key_val : expected) {
2272       EXPECT_GT(test_map.count(key_val.first), 0);
2273       EXPECT_EQ(test_map.at(key_val.first), key_val.second);
2274     }
2275   }
2276 
2277   // Helper method that checks name-port pairs.
2278   void ValidateMemoryUsageSnapshot(
2279       const std::vector<string>& expected_names, const int port_num_expected,
2280       const std::unordered_set<std::pair<const NodeDef*, int>,
2281                                DeviceState::NodePairHash>& mem_usage_snapshot) {
2282     std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
2283     std::transform(
2284         mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
2285         std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
2286         [](const std::pair<const NodeDef*, int>& node_port) {
2287           return std::make_pair(node_port.first->name(), node_port.second);
2288         });
2289     std::set<std::pair<string, int>> expected;
2290     std::transform(expected_names.begin(), expected_names.end(),
2291                    std::inserter(expected, expected.begin()),
2292                    [port_num_expected](const string& name) {
2293                      return std::make_pair(name, port_num_expected);
2294                    });
2295     ExpectSetEq(expected, nodes_at_peak_mem_usage);
2296   }
2297 
2298   // Helper method for checking node dependencies.
2299   void ValidateDependencyChain(
2300       const std::unordered_map<string, int64>& start_times,
2301       const std::vector<string>& nodes_in_dependency_order) {
2302     int64 prev_node_time = -1;
2303     for (const auto& node : nodes_in_dependency_order) {
2304       int64 curr_node_time = start_times.at(node);
2305       EXPECT_GE(curr_node_time, prev_node_time);
2306       prev_node_time = curr_node_time;
2307     }
2308   }
2309 
2310   // cluster_ and scheduler_ are initialized in the c'tor.
2311   std::unique_ptr<VirtualCluster> cluster_;
2312   std::unique_ptr<TestVirtualScheduler> scheduler_;
2313   FirstReadyManager first_ready_manager_;
2314   CompositeNodeManager composite_node_manager_;
2315 
2316   // grappler_item_ will be initialized differently for each test case.
2317   std::unique_ptr<GrapplerItem> grappler_item_;
2318   // Node name -> its preceding nodes map for testing scheduling order.
2319   std::unordered_map<string, std::vector<string>> dependency_;
2320 
2321   // Shared params for Conv2D related graphs:
2322   const int batch_size_ = 4;
2323   const int width_ = 10;
2324   const int height_ = 10;
2325   const int depth_in_ = 8;
2326   const int kernel_ = 3;
2327   const int depth_out_ = 16;
2328 };
2329 
2330 // Create small graph, run predict costs on it, make sure the costs from the
2331 // summary match the hand-calculated costs.
2332 TEST_F(VirtualSchedulerTest, SummaryCostTest) {
2333   // Run matmul test.
2334   CreateGrapplerItemWithMatmulChain();
2335   InitScheduler();
2336   auto ops_executed = RunScheduler("");
2337   Costs c = scheduler_->Summary();
2338 
2339   // RandomUniform - 5 ops * 1s  = 5s
2340   // MatMul        - 4 ops * 2s  = 8s
2341   // Misc          - 5 ops * 1us = 5us
2342   // Total: 5s + 8s + 5us = 13,000,005 us
2343   EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2344   EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2345   EXPECT_FALSE(c.inaccurate);
2346   EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2347 }
2348 
2349 // Like the above SummaryCostTest, but makes sure the stepstats timeline is
2350 // correct.
2351 TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
2352   // Run matmul test.
2353   CreateGrapplerItemWithMatmulChain();
2354   InitScheduler();
2355   auto ops_executed = RunScheduler("");
2356   RunMetadata metadata;
2357   Costs c = scheduler_->Summary(&metadata);
2358   StepStats stepstats = metadata.step_stats();
2359   EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2360   EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2361   EXPECT_FALSE(c.inaccurate);
2362   EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2363 
2364   // Should only be 1 device!
2365   EXPECT_EQ(1, stepstats.dev_stats().size());
2366 
2367   // Create a map of op name -> start and end times (micros).
2368   std::map<string, std::pair<int64, int64>> start_end_times;
2369   for (const auto& device_step_stats : stepstats.dev_stats()) {
2370     for (const auto& stats : device_step_stats.node_stats()) {
2371       int64 start = stats.all_start_micros();
2372       int64 end = start + stats.all_end_rel_micros();
2373       start_end_times[stats.node_name()] = std::pair<int64, int64>(start, end);
2374 
2375       // Make sure that the output properties are correct for
2376       // MatMul and RandomUniform operations.
2377       // We only check for dtype, and shape (excluding alloc)
2378       // since alloc is not set by the virtual scheduler.
2379       if (stats.timeline_label() == "MatMul" ||
2380           stats.timeline_label() == "RandomUniform") {
2381         EXPECT_EQ(1, stats.output().size());
2382         for (const auto& output : stats.output()) {
2383           EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype());
2384           EXPECT_EQ(2, output.tensor_description().shape().dim().size());
2385           for (const auto& dim : output.tensor_description().shape().dim()) {
2386             EXPECT_EQ(3200, dim.size());
2387           }
2388         }
2389       }
2390     }
2391   }
2392 
2393   // The base start time is the time to compute the RandomUniforms.
2394   int64 cur_time = static_cast<int64>(5000005);
2395   // The increment is the execution time of one matmul. See
2396   // CreateGrapplerItemWithMatmulChain for details.
2397   int64 increment = static_cast<int64>(2000000);
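  // 5,000,005 us = 5 RandomUniform ops * 1s plus the 5 misc ops at 1us each
  // (see the cost breakdown in SummaryCostTest), all serialized on one CPU.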
2398   auto op_names = {"ab", "abc", "abcd", "abcde"};
2399   for (const auto& op_name : op_names) {
2400     int64 actual_start = start_end_times[op_name].first;
2401     int64 actual_end = start_end_times[op_name].second;
2402     int64 expected_start = cur_time;
2403     int64 expected_end = cur_time + increment;
2404     EXPECT_EQ(expected_start, actual_start);
2405     EXPECT_EQ(expected_end, actual_end);
2406     cur_time += increment;
2407   }
2408 }
2409 
2410 TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
2411   // Init.
2412   CreateGrapplerItemWithConv2Ds();
2413   InitScheduler();
2414 
2415   // Run the scheduler.
2416   auto ops_executed = RunScheduler("");  // Run all the nodes.
2417 
2418   // A Const and a RandomUniform for each of x, y, and f, plus c0 and c1.
2419   // c2 and z shouldn't be executed.
2420   EXPECT_EQ(8, ops_executed.size());
2421 
2422   // x, y, f, c0, and c1 should be in the ops executed.
2423   EXPECT_GT(ops_executed.count("x"), 0);
2424   EXPECT_GT(ops_executed.count("y"), 0);
2425   EXPECT_GT(ops_executed.count("f"), 0);
2426   EXPECT_GT(ops_executed.count("c0"), 0);
2427   EXPECT_GT(ops_executed.count("c1"), 0);
2428 
2429   // z and c2 shouldn't be part of it.
2430   EXPECT_EQ(ops_executed.count("z"), 0);
2431   EXPECT_EQ(ops_executed.count("c2"), 0);
2432 
2433   // Check input / output properties.
2434   EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size());
2435   EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size());
2436   EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size());
2437   EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
2438   EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
2439 }
2440 
2441 TEST_F(VirtualSchedulerTest, MemoryUsage) {
2442   // Init.
2443   CreateGrapplerItemWithAddN();
2444   InitScheduler();
2445 
2446   // Run the scheduler.
2447   RunScheduler("");
2448 
2449   const auto* device_states = scheduler_->GetDeviceStates();
2450   const auto& cpu_state = device_states->at(kCPU0);
2451 
2452   // The out node adds 4 input tensors, each of shape 10x10x10x10 floats, so the
2453   // peak memory usage is 4x the size of one input tensor while executing out.
2454   int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
2455   const std::vector<string> expected_names = {"x", "y", "z", "w"};
2456   EXPECT_EQ(expected_names.size() * one_input_node_size,
2457             cpu_state.max_memory_usage);
2458   ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
2459                               cpu_state.mem_usage_snapshot_at_peak);
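  // Peak memory: 4 inputs * (10*10*10*10 floats * 4 bytes) = 160,000 bytes.
  // The 64 persistent bytes presumably come from the RandomUniform shape Consts
  // (4 Consts * 4 int32 elements * 4 bytes), which are treated as persistent
  // nodes.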
2460   ExpectUnorderedMapEq(
2461       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 64)},
2462       scheduler_->GetPersistentMemoryUsage());
2463   ExpectUnorderedMapEq(
2464       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 160000)},
2465       scheduler_->GetPeakMemoryUsage());
2466 }
2467 
2468 TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) {
2469   CreateGrapplerItemWithUnnecessaryPlaceholderNodes();
2470   InitScheduler();
2471 
2472   // Test that scheduler can run graphs with extra unnecessary feed nodes.
2473   auto ops_executed = RunScheduler("");
2474   ASSERT_EQ(1, ops_executed.size());
2475   ASSERT_EQ(ops_executed.count("x"), 1);
2476 }
2477 
2478 TEST_F(VirtualSchedulerTest, ControlDependency) {
2479   // Init.
2480   CreateGrapplerItemWithControlDependency();
2481   InitScheduler();
2482 
2483   // Run the scheduler.
2484   RunScheduler("");
2485 
2486   const auto* device_states = scheduler_->GetDeviceStates();
2487   const auto& cpu_state = device_states->at(kCPU0);
2488 
2489   // The graph has a NoOp that takes control dependencies from 7 NoOps. The peak
2490   // memory usage occurs while executing the final NoOp.
2491   int64 one_input_node_size = 4;  // control dependency
2492   const std::vector<string> expected_names = {"x", "y", "z", "w",
2493                                               "u", "v", "t"};
2494   EXPECT_EQ(expected_names.size() * one_input_node_size,
2495             cpu_state.max_memory_usage);
2496   ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
2497                               cpu_state.mem_usage_snapshot_at_peak);
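  // Peak memory: 7 control inputs * 4 bytes = 28 bytes; port -1 denotes a
  // control dependency, and nothing is persistent in this graph.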
2498   ExpectUnorderedMapEq(
2499       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 0)},
2500       scheduler_->GetPersistentMemoryUsage());
2501   ExpectUnorderedMapEq(
2502       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 28)},
2503       scheduler_->GetPeakMemoryUsage());
2504 }
2505 
2506 TEST_F(VirtualSchedulerTest, ComplexDependency) {
2507   // Init.
2508   CreateGrapplerItemWithBatchNorm();
2509   InitScheduler();
2510 
2511   // Run the scheduler.
2512   RunScheduler("bn");
2513 
2514   const auto& device_states = scheduler_->GetDeviceStates();
2515   const auto& cpu_state = device_states->at(kCPU0);
2516 
2517   // The graph is
2518   //  bn = FusedBatchNorm(x, scale, offset, mean, var)
2519   //  z1 = bn.y + x
2520   //  z2 = bn.var + bn.var
2521   //  z3 = bn.var + bn.var
2522   //  z4 = control dependency from bn.
2523   //  Note that bn.mean doesn't have any consumer.
2524   const int x_size = batch_size_ * width_ * height_ * depth_in_;
2525   int64 expected_size =
2526       4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
2527            1 /* control dependency */);
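  // With batch_size_=4, width_=height_=10, depth_in_=8: x_size = 3200 elements,
  // so expected_size = 4 bytes * (2 * 3200 + 8 + 1) = 25,636 bytes.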
2528   EXPECT_EQ(expected_size, cpu_state.memory_usage);
2529 
2530   // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0.
2531   std::set<std::pair<string, int>> nodes_in_memory;
2532   std::transform(
2533       cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
2534       std::inserter(nodes_in_memory, nodes_in_memory.begin()),
2535       [](const std::pair<const NodeDef*, int>& node_port) {
2536         return std::make_pair(node_port.first->name(), node_port.second);
2537       });
2538   std::set<std::pair<string, int>> expected = {
2539       std::make_pair("bn", -1),
2540       std::make_pair("bn", 0),
2541       std::make_pair("bn", 2),
2542       std::make_pair("x", 0),
2543   };
2544   ExpectSetEq(expected, nodes_in_memory);
2545 
2546   const auto* node_states = scheduler_->GetNodeStates();
2547   const NodeState* bn_node = nullptr;
2548   const NodeState* x_node = nullptr;
2549   for (const auto& nodedef_node_state : *node_states) {
2550     const NodeDef* node = nodedef_node_state.first;
2551     const NodeState& node_state = nodedef_node_state.second;
2552     if (node->name() == "bn") {
2553       bn_node = &node_state;
2554     }
2555     if (node->name() == "x") {
2556       x_node = &node_state;
2557     }
2558   }
2559   CHECK_NOTNULL(bn_node);
2560   CHECK_NOTNULL(x_node);
2561 
2562   ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
2563   ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
2564   ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
2565   // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
2566   ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
2567 }
2568 
2569 TEST_F(VirtualSchedulerTest, Variable) {
2570   // Init.
2571   CreateGrapplerItemWithConv2DAndVariable();
2572   InitScheduler();
2573 
2574   // Run the scheduler.
2575   RunScheduler("");
2576 
2577   const auto* device_states = scheduler_->GetDeviceStates();
2578   const auto& cpu_state = device_states->at(kCPU0);
2579 
2580   // There is one Conv2D that takes x and f, but f is a variable, so it should
2581   // be among the persistent nodes.
2582   ValidateMemoryUsageSnapshot({"f", "Const/Const"}, /*port_num_expected=*/0,
2583                               cpu_state.persistent_nodes);
2584   // Only x in peak memory usage snapshot.
2585   ValidateMemoryUsageSnapshot({"x"}, /*port_num_expected=*/0,
2586                               cpu_state.mem_usage_snapshot_at_peak);
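  // The 4,624 persistent bytes presumably correspond to the 3x3x8x16 float
  // variable f (4,608 bytes) plus a small shape Const (16 bytes); the
  // 12,800-byte peak matches a 4x10x10x8 float input x.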
2587   ExpectUnorderedMapEq(
2588       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 4624)},
2589       scheduler_->GetPersistentMemoryUsage());
2590   ExpectUnorderedMapEq(
2591       {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 12800)},
2592       scheduler_->GetPeakMemoryUsage());
2593 }
2594 
2595 TEST_F(VirtualSchedulerTest, WhileLoop) {
2596   // Init.
2597   CreateGrapplerItemWithLoop();
2598   InitScheduler();
2599 
2600   // Run the scheduler.
2601   RunScheduler("");
2602 
2603   // Check the timeline
2604   RunMetadata metadata;
2605   scheduler_->Summary(&metadata);
2606 
2607   // Nodes in topological order:
2608   // * const, ones
2609   // * while/Enter, while/Enter_1
2610   // * while/Merge, while/Merge_1
2611   // * while/Less/y
2612   // * while/Less
2613   // * while/LoopCond
2614   // * while/Switch, while/Switch_1
2615   // * while/Identity, while/Identity_1, while/Exit, while/Exit_1
2616   // * while/add/y, while/concat/axis
2617   // * while/add, while/concat
2618   // * while/NextIteration, while/NextIteration_1
2619 
2620   int num_next_iteration = 0;
2621   int num_next_iteration_1 = 0;
2622   int num_exit = 0;
2623   int num_exit_1 = 0;
2624   int64 next_iter_start_micro = 0;
2625   int64 next_iter_1_start_micro = 0;
2626   int64 exit_start_micro = 0;
2627   int64 exit_1_start_micro = 0;
2628 
2629   std::unordered_map<string, int64> start_times;
2630   for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2631     for (const auto& stats : device_step_stats.node_stats()) {
2632       start_times[stats.node_name()] = stats.all_start_micros();
2633       if (stats.node_name() == "while/NextIteration") {
2634         ++num_next_iteration;
2635         next_iter_start_micro = stats.all_start_micros();
2636       } else if (stats.node_name() == "while/NextIteration_1") {
2637         ++num_next_iteration_1;
2638         next_iter_1_start_micro = stats.all_start_micros();
2639       } else if (stats.node_name() == "while/Exit") {
2640         ++num_exit;
2641         exit_start_micro = stats.all_start_micros();
2642       } else if (stats.node_name() == "while/Exit_1") {
2643         ++num_exit_1;
2644         exit_1_start_micro = stats.all_start_micros();
2645       }
2646     }
2647   }
2648 
2649   // Make sure we went through the body of the loop once, and that the output of
2650   // the loop was scheduled as well.
2651   EXPECT_EQ(1, num_next_iteration);
2652   EXPECT_EQ(1, num_next_iteration_1);
2653   EXPECT_EQ(1, num_exit);
2654   EXPECT_EQ(1, num_exit_1);
2655 
2656   // Start times of while/NextIteration and while/NextIteration_1 should be
2657   // different, as should those of while/Exit and while/Exit_1.
2658   EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
2659   EXPECT_NE(exit_start_micro, exit_1_start_micro);
2660 
2661   // Check dependency among the nodes; no matter what scheduling mechanism we
2662   // use, the scheduled ops should follow these dependency chains.
2663   // Note that currently, VirtualScheduler executes while/Merge twice; hence,
2664   // we're not testing dependency chains related to while/Merge.
2665   // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the
2666   // order of Enter, Merge, ...loop condition ..., ... loop body ...,
2667   // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency
2668   // chaining test w/ Merge nodes.
2669   ValidateDependencyChain(
2670       start_times,
2671       {"Const", "while/Enter",  // "while/Merge",
2672        "while/Less/y", "while/Less", "while/LoopCond", "while/Switch",
2673        "while/Identity", "while/add/y", "while/add", "while/NextIteration"});
2674   // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"});
2675   ValidateDependencyChain(start_times,
2676                           {"ones", "while/Enter_1",  // "while/Merge_1",
2677                            "while/Switch_1", "while/Identity_1", "while/concat",
2678                            "while/NextIteration_1"});
2679   ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"});
2680   ValidateDependencyChain(
2681       start_times, {"while/Identity", "while/concat/axis", "while/concat"});
2682   ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
2683   ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
2684 }
2685 
2686 TEST_F(VirtualSchedulerTest, AnnotatedWhileLoop) {
2687   {
2688     // Init.
2689     CreateGrapplerItemWithLoop();
2690     InitScheduler();
2691 
2692     // Runs the scheduler.
2693     RunScheduler("");
2694     Costs c = scheduler_->Summary();
2695 
2696     EXPECT_EQ(23, c.execution_time.asMicroSeconds().count());
2697     // Both while/Merge and while/Merge_1 are scheduled twice.
2698     EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
2699     EXPECT_FALSE(c.inaccurate);
2700     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2701   }
2702 
2703   {
2704     // Init.
2705     CreateGrapplerItemWithLoopAnnotated();
2706     InitScheduler();
2707 
2708     // Runs the scheduler.
2709     RunScheduler("");
2710     Costs c = scheduler_->Summary();
2711 
2712     // The cost of each Merge is accumulated execution_count times on each of its
2713     // two schedulings, but since Merge's cost is minimal, we keep this behavior.
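    // A rough breakdown, with every op costing 1us in SimplePredictCosts:
    // 13 loop-body nodes * 10 + 2 Switch nodes * 11 + 2 Exit nodes * 1 +
    // 4 one-time nodes (Const, ones, Enter, Enter_1) + 2 * 10 extra from the
    // twice-scheduled Merge nodes = 178 us (the one-time counts are assumed,
    // consistent with the total).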
2714     EXPECT_EQ(178, c.execution_time.asMicroSeconds().count());
2715     // Both while/Merge and while/Merge_1 are scheduled twice.
2716     EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
2717     EXPECT_FALSE(c.inaccurate);
2718     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2719   }
2720 }
2721 
2722 TEST_F(VirtualSchedulerTest, Condition) {
2723   // Without annotation.
2724   {
2725     // Inits.
2726     CreateGrapplerItemWithCondition();
2727     InitScheduler();
2728 
2729     // Runs the scheduler.
2730     RunScheduler("");
2731     RunMetadata metadata;
2732     Costs c = scheduler_->Summary(&metadata);
2733 
2734     // Nodes in topological order: a/Less, Switch, First/Second, Merge.
2735     int num_a = 0;
2736     int num_less = 0;
2737     int num_switch = 0;
2738     int num_first = 0;
2739     int num_second = 0;
2740     int num_merge = 0;
2741 
2742     for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2743       for (const auto& stats : device_step_stats.node_stats()) {
2744         if (stats.node_name() == "a") {
2745           ++num_a;
2746         } else if (stats.node_name() == "Less") {
2747           ++num_less;
2748         } else if (stats.node_name() == "Switch") {
2749           ++num_switch;
2750         } else if (stats.node_name() == "First") {
2751           ++num_first;
2752         } else if (stats.node_name() == "Second") {
2753           ++num_second;
2754         } else if (stats.node_name() == "Merge") {
2755           ++num_merge;
2756         }
2757       }
2758     }
2759 
2760     EXPECT_EQ(1, num_a);
2761     EXPECT_EQ(1, num_less);
2762     EXPECT_EQ(1, num_switch);
2763     EXPECT_EQ(1, num_first);
2764     EXPECT_EQ(1, num_second);
2765     EXPECT_EQ(2, num_merge);
2766 
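    // 6 graph nodes plus one extra Merge scheduling, at 1us each = 7 us.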
2767     EXPECT_EQ(7, c.execution_time.asMicroSeconds().count());
2768     // Merge is executed twice.
2769     EXPECT_EQ(grappler_item_->graph.node_size() + 1, c.num_ops_total);
2770     EXPECT_FALSE(c.inaccurate);
2771     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2772   }
2773 
2774   // With annotation.
2775   {
2776     // Inits.
2777     CreateGrapplerItemWithCondition();
2778 
2779     // Annotates the Switch node.
2780     for (auto& node : *grappler_item_->graph.mutable_node()) {
2781       if (node.name() == "Switch") {
2782         AttrValue attr_output_info;
2783         // Adds one output slot 0 so that Second shouldn't be executed.
2784         (*attr_output_info.mutable_list()).add_i(0);
2785         AddNodeAttr(kOutputSlots, attr_output_info, &node);
2786       }
2787     }
2788 
2789     InitScheduler();
2790 
2791     // Runs the scheduler.
2792     RunScheduler("");
2793     RunMetadata metadata;
2794     Costs c = scheduler_->Summary(&metadata);
2795 
2796     // Nodes in topological order: a/Less, Switch, Merge
2797     int num_a = 0;
2798     int num_less = 0;
2799     int num_switch = 0;
2800     int num_first = 0;
2801     int num_second = 0;
2802     int num_merge = 0;
2803 
2804     for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2805       for (const auto& stats : device_step_stats.node_stats()) {
2806         if (stats.node_name() == "a") {
2807           ++num_a;
2808         } else if (stats.node_name() == "Less") {
2809           ++num_less;
2810         } else if (stats.node_name() == "Switch") {
2811           ++num_switch;
2812         } else if (stats.node_name() == "First") {
2813           ++num_first;
2814         } else if (stats.node_name() == "Second") {
2815           ++num_second;
2816         } else if (stats.node_name() == "Merge") {
2817           ++num_merge;
2818         }
2819       }
2820     }
2821 
2822     EXPECT_EQ(1, num_a);
2823     EXPECT_EQ(1, num_less);
2824     EXPECT_EQ(1, num_switch);
2825     EXPECT_EQ(1, num_first);
2826     EXPECT_EQ(0, num_second);
2827     EXPECT_EQ(1, num_merge);
2828 
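    // With Second pruned by the annotation, 5 ops run at 1us each = 5 us.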
2829     EXPECT_EQ(5, c.execution_time.asMicroSeconds().count());
2830     // Second is not executed.
2831     EXPECT_EQ(grappler_item_->graph.node_size() - 1, c.num_ops_total);
2832     EXPECT_FALSE(c.inaccurate);
2833     EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2834   }
2835 }
2836 
2837 TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
2838   // Init.
2839   CreateGrapplerItemWithInterDeviceTransfers();
2840   InitScheduler();
2841 
2842   // Run the scheduler.
2843   auto ops_executed = RunScheduler("");
2844 
2845   // Helper lambda to extract port num from _Send and _Recv op name.
2846   auto get_port_num = [](const string& name) -> int {
2847     if (name.find("bn_0") != string::npos) {
2848       return 0;
2849     } else if (name.find("bn_1") != string::npos) {
2850       return 1;
2851     } else if (name.find("bn_2") != string::npos) {
2852       return 2;
2853     } else if (name.find("bn_minus1") != string::npos) {
2854       return -1;
2855     }
2856     return -999;
2857   };
2858 
2859   // Reorganize ops_executed for further testing.
2860   std::unordered_map<string, int> op_count;
2861   std::unordered_map<int, string> recv_op_names;
2862   std::unordered_map<int, string> send_op_names;
2863   for (const auto& x : ops_executed) {
2864     const auto& name = x.first;
2865     const auto& node_info = x.second;
2866     const auto& op = node_info.op_info.op();
2867     if (op == kRecv) {
2868       recv_op_names[get_port_num(name)] = name;
2869     } else if (op == kSend) {
2870       send_op_names[get_port_num(name)] = name;
2871     }
2872     op_count[op]++;
2873   }
2874 
2875   // Same number of _Send and _Recv.
2876   EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv));
2877 
2878   // Expect 4 Sends and 4 Recvs: ports 0, 1, and 2, plus the control dependency.
2879   EXPECT_EQ(op_count.at(kRecv), 4);
2880   EXPECT_EQ(op_count.at(kSend), 4);
2881 
2882   // Helper lambda for extracting output Tensor size.
2883   auto get_output_size = [this, ops_executed](const string& name) -> int64 {
2884     const auto& output_properties_ = ops_executed.at(name).op_info.outputs();
2885     std::vector<OpInfo::TensorProperties> output_properties;
2886     for (const auto& output_property : output_properties_) {
2887       output_properties.push_back(output_property);
2888     }
2889     return CalculateOutputSize(output_properties, 0);
2890   };
2891 
2892   // Validate transfer size.
2893   // FusedBatchNorm output y is a 4-D tensor: batch x width x height x depth.
2894   int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
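  // = 4 bytes * (4 * 10 * 10 * 8) = 12,800 bytes.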
2895   EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
2896   EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
2897   // Mean and variance are 1-D vectors of size depth_in_.
2898   EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
2899   EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
2900   EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
2901   EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
2902   // Control dependency size is 4B.
2903   EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
2904   EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
2905 }
2906 
2907 TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
2908   // Init.
2909   CreateGrapplerItemWithSendRecv();
2910   InitScheduler();
2911 
2912   // Run the scheduler.
2913   auto ops_executed = RunScheduler("");
2914 
2915   EXPECT_GT(ops_executed.count("Const"), 0);
2916   EXPECT_GT(ops_executed.count("Send"), 0);
2917   EXPECT_GT(ops_executed.count("Recv"), 0);
2918 }
2919 
2920 TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
2921   // Init.
2922   CreateGrapplerItemWithSendRecv();
2923   // Change Recv node's device so that Send and Recv are placed on different
2924   // devices.
2925   auto& graph = grappler_item_->graph;
2926   const string recv_device = kCPU1;
2927   for (int i = 0; i < graph.node_size(); i++) {
2928     auto* node = graph.mutable_node(i);
2929     if (node->name() == "Recv") {
2930       node->set_device(recv_device);
2931       auto* attr = node->mutable_attr();
2932       (*attr)["recv_device"].set_s(recv_device);
2933     } else if (node->name() == "Send") {
2934       auto* attr = node->mutable_attr();
2935       (*attr)["recv_device"].set_s(recv_device);
2936     }
2937   }
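  // Since Send is on cpu:0 and Recv on cpu:1, the scheduler inserts its own
  // _Send/_Recv pair across the device boundary; their generated names encode
  // the transferred tensor and the source/destination devices (checked below).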
2938   InitScheduler();
2939 
2940   // Run the scheduler.
2941   auto ops_executed = RunScheduler("");
2942 
2943   // Expect Const, Send, Recv, and the _Send/_Recv ops created by the scheduler.
2944   EXPECT_GT(ops_executed.count("Const"), 0);
2945   EXPECT_GT(ops_executed.count("Send"), 0);
2946   EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/"
2947                                "task_0/cpu_0_to_/job_localhost"
2948                                "/replica_0/task_0/cpu_1"),
2949             0);
2950   EXPECT_GT(ops_executed.count(
2951                 "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
2952             0);
2953   EXPECT_GT(ops_executed.count("Recv"), 0);
2954 }
2955 
2956 TEST_F(VirtualSchedulerTest, GraphWithOnlyRecv) {
2957   // Init.
2958   CreateGrapplerItemWithRecvWithoutSend();
2959   InitScheduler();
2960 
2961   // Run the scheduler.
2962   auto ops_executed = RunScheduler("");
2963 
2964   // A Recv without a matching Send is treated as an initially ready node.
2965   EXPECT_GT(ops_executed.count("Recv"), 0);
2966 }
2967 
2968 TEST_F(VirtualSchedulerTest, AddMergeSwitch) {
2969   // Override scheduler_ with one that uses CompositeNodeManager.
2970   scheduler_ = absl::make_unique<TestVirtualScheduler>(
2971       /*use_static_shapes=*/true,
2972       /*use_aggressive_shape_inference=*/true, &composite_node_manager_,
2973       cluster_.get());
2974   CreateGrapplerItemWithSwitchMergeInput();
2975   InitScheduler();
2976 
2977   // pred --+                      z --+
2978   //        |                          |
2979   //        V                          V
2980   // x -> Switch --------> Merge ---> Add --> y
2981   //        |                ^
2982   //        |                |
2983   //        +-----> Add -----+
2984   //                 ^
2985   //                 |
2986   // b --------------+
2987 
2988   // Run the scheduler. Without annotation, the current VirtualScheduler
2989   // triggers both outputs of Switch, and Merge becomes ready as soon as one of
2990   // its inputs is ready. If the final Add only used a num_inputs_ready counter,
2991   // it could become ready before z is scheduled, making it possible to skip z.
2992   // (CompositeNodeManager is needed to test this case.)
2993   auto ops_executed = RunScheduler("");
2994 
2995   EXPECT_GT(ops_executed.count("z"), 0);
2996 }
2997 
2998 TEST_F(VirtualSchedulerTest, AddFromOneTensor) {
2999   CreateGrapplerItemWithAddFromOneTensor();
3000   InitScheduler();
3001 
3002   // x -+----> Add --> y
3003   //    |       ^
3004   //    |       |
3005   //    +-------+
3006 
3007   // Run the scheduler.
3008   auto ops_executed = RunScheduler("");
3009   EXPECT_GT(ops_executed.count("y"), 0);
3010   EXPECT_GT(ops_executed.count("x"), 0);
3011 }
3012 
3013 }  // namespace
3014 }  // namespace grappler
3015 }  // namespace tensorflow
3016