1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
17
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/tensor_description.pb.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
22 #include "tensorflow/core/grappler/costs/utils.h"
23 #include "tensorflow/core/grappler/costs/virtual_placer.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace tensorflow {
28 namespace grappler {
29 namespace {
30
31 // Device names:
32 constexpr char kCPU0[] = "/job:localhost/replica:0/task:0/cpu:0";
33 constexpr char kCPU1[] = "/job:localhost/replica:0/task:0/cpu:1";
34 constexpr char kChannelFrom0To1[] = "Channel from CPU0 to CPU1";
35 constexpr char kChannelFrom1To0[] = "Channel from CPU1 to CPU0";
36 // Op names:
37 constexpr char kConv2D[] = "Conv2D";
38 constexpr char kSend[] = "_Send";
39 constexpr char kRecv[] = "_Recv";
40
41 class ReadyNodeManagerTest : public ::testing::Test {
42 protected:
  ReadyNodeManagerTest() {
    // node1_ to node6_ on kCPU0, with time_ready in reverse order.
45 NodeSetUp("Node1", kConv2D, kCPU0, 6000, &node1_);
46 NodeSetUp("Node2", kConv2D, kCPU0, 5000, &node2_);
47 NodeSetUp("Node3", kConv2D, kCPU0, 4000, &node3_);
48 NodeSetUp("Node4", kConv2D, kCPU0, 3000, &node4_);
49 NodeSetUp("Node5", kConv2D, kCPU0, 2000, &node5_);
50 NodeSetUp("Node6", kConv2D, kCPU0, 1000, &node6_);
51 }
52
  void NodeSetUp(const string& name, const string& op_name,
54 const string& device_name, const uint64 time_ready,
55 NodeDef* node) {
56 node->set_name(name);
57 node->set_op(op_name);
58 node->set_device(device_name);
59
60 node_states_[node] = NodeState();
61 node_states_[node].time_ready = time_ready;
62 node_states_[node].device_name = device_name;
63 }
64
65 NodeDef node1_, node2_, node3_, node4_, node5_, node6_;
66 std::unordered_map<const NodeDef*, NodeState> node_states_;
67 };
68
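// Common ReadyNodeManager usage exercised by the tests below (a sketch of the
// interface as used here, not of its implementation):
//   manager.AddNode(&node);     // makes a node available for scheduling
//   manager.GetCurrNode();      // peeks at the next node to schedule
//   manager.RemoveCurrNode();   // pops the current node once it is "executed"
//   manager.Empty();            // true when no ready nodes remain
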
69 // Tests that FIFOManager correctly returns the current node with only 1 node.
TEST_F(ReadyNodeManagerTest, GetSingleNodeFIFOManager) {
71 FIFOManager manager = FIFOManager();
72 manager.AddNode(&node1_);
73 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
74 }
75
76 // Tests that FIFOManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFIFOManager) {
78 FIFOManager manager = FIFOManager();
79 manager.AddNode(&node1_);
80
81 // Removes the only node in FIFOManager.
82 manager.RemoveCurrNode();
83 EXPECT_TRUE(manager.Empty());
84 }
85
86 // Tests that FIFOManager can remove multiple nodes and returns the current node
87 // in the right order.
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFIFOManager) {
89 FIFOManager manager = FIFOManager();
90 manager.AddNode(&node1_);
91 manager.AddNode(&node2_);
92 manager.AddNode(&node3_);
93 manager.AddNode(&node4_);
94
95 // Keeps checking current node while removing nodes from manager.
96 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
97 manager.RemoveCurrNode();
98 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
99 manager.RemoveCurrNode();
100 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
101 manager.RemoveCurrNode();
102 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
103 manager.RemoveCurrNode();
104 EXPECT_TRUE(manager.Empty());
105 }
106
107 // Tests that FIFOManager can remove multiple nodes and add more nodes, still
108 // returning the current node in the right order.
TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleFIFOManager) {
110 FIFOManager manager = FIFOManager();
111 manager.AddNode(&node1_);
112 manager.AddNode(&node2_);
113 manager.AddNode(&node3_);
114 manager.AddNode(&node4_);
115
116 // Keeps checking current node as nodes are removed and added.
117 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
118 manager.RemoveCurrNode();
119 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
120 manager.AddNode(&node5_);
121 // GetCurrNode() should return the same node even if some nodes are added,
122 // until RemoveCurrNode() is called.
123 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
124 manager.RemoveCurrNode();
125 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
126 manager.RemoveCurrNode();
127 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
128 manager.AddNode(&node6_);
129 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
130 manager.RemoveCurrNode();
131 EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
132 manager.RemoveCurrNode();
133 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
134 manager.RemoveCurrNode();
135 EXPECT_TRUE(manager.Empty());
136 }
137
138 // Tests that LIFOManager correctly returns the current node with only 1 node.
TEST_F(ReadyNodeManagerTest, GetSingleNodeLIFOManager) {
140 LIFOManager manager = LIFOManager();
141 manager.AddNode(&node1_);
142 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
143 }
144
145 // Tests that LIFOManager removes the only node contained within.
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeLIFOManager) {
147 LIFOManager manager = LIFOManager();
148 manager.AddNode(&node1_);
149
150 // Removes the only node in LIFOManager.
151 manager.RemoveCurrNode();
152 EXPECT_TRUE(manager.Empty());
153 }
154
155 // Tests that LIFOManager can remove multiple nodes and returns the current node
156 // in the right order.
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleLIFOManager) {
158 LIFOManager manager = LIFOManager();
159 manager.AddNode(&node1_);
160 manager.AddNode(&node2_);
161 manager.AddNode(&node3_);
162 manager.AddNode(&node4_);
163
164 // Keeps checking current node while removing nodes from manager.
165 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
166 manager.RemoveCurrNode();
167 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
168 manager.RemoveCurrNode();
169 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
170 manager.RemoveCurrNode();
171 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
172 manager.RemoveCurrNode();
173 EXPECT_TRUE(manager.Empty());
174 }
175
176 // Tests that LIFOManager can remove multiple nodes (must be removing the
177 // current node) and add more nodes, still returning the current node in the
178 // right order.
TEST_F(ReadyNodeManagerTest, AddAndRemoveMultipleLIFOManager) {
180 LIFOManager manager = LIFOManager();
181 manager.AddNode(&node1_);
182 manager.AddNode(&node2_);
183 manager.AddNode(&node3_);
184 manager.AddNode(&node4_);
185
186 // Keeps checking current node as nodes are removed and added.
187 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
188 manager.RemoveCurrNode();
189 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
190 manager.AddNode(&node5_);
191 // GetCurrNode() should return the same node even if some nodes are added,
192 // until RemoveCurrNode() is called.
193 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
194 manager.RemoveCurrNode();
195 EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
196 manager.RemoveCurrNode();
197 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
198 manager.AddNode(&node6_);
199 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
200 manager.RemoveCurrNode();
201 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
202 manager.RemoveCurrNode();
203 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
204 manager.RemoveCurrNode();
205 EXPECT_TRUE(manager.Empty());
206 }
207
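// Tests that FirstReadyManager correctly returns the current node with only 1
// node.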
TEST_F(ReadyNodeManagerTest, GetSingleNodeFirstReadyManager) {
209 FirstReadyManager manager;
210 TF_EXPECT_OK(manager.Init(&node_states_));
211 manager.AddNode(&node1_);
212 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
213 }
214
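// Tests that FirstReadyManager removes the only node contained within.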
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeFirstReadyManager) {
216 FirstReadyManager manager;
217 TF_EXPECT_OK(manager.Init(&node_states_));
218 manager.AddNode(&node1_);
219 manager.RemoveCurrNode();
220 EXPECT_TRUE(manager.Empty());
221 }
222
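// Tests that FirstReadyManager returns nodes in increasing order of
// time_ready, regardless of the order in which they were added.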
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleFirstReadyManager) {
224 FirstReadyManager manager;
225 TF_EXPECT_OK(manager.Init(&node_states_));
226 // Insert nodes in some random order.
227 manager.AddNode(&node2_);
228 manager.AddNode(&node1_);
229 manager.AddNode(&node4_);
230 manager.AddNode(&node5_);
231 manager.AddNode(&node3_);
232 manager.AddNode(&node6_);
233
234 // In whatever order we insert nodes, we get the same order based on nodes'
235 // time_ready.
236 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
237 manager.RemoveCurrNode();
238 EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
239 manager.RemoveCurrNode();
240 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
241 manager.RemoveCurrNode();
242 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
243 manager.RemoveCurrNode();
244 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
245 manager.RemoveCurrNode();
246 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
247 manager.RemoveCurrNode();
248 EXPECT_TRUE(manager.Empty());
249 }
250
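// Tests that GetCurrNode() keeps returning the same node until
// RemoveCurrNode() is called, even when nodes with earlier time_ready are
// added in between.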
TEST_F(ReadyNodeManagerTest, GetCurrNodeFirstReadyManager) {
252 FirstReadyManager manager;
253 TF_EXPECT_OK(manager.Init(&node_states_));
254
255 // Inserts nodes in some random order.
256 manager.AddNode(&node2_);
257 manager.AddNode(&node1_);
258 manager.AddNode(&node4_);
259 manager.AddNode(&node5_);
260 manager.AddNode(&node3_);
261 manager.AddNode(&node6_);
262
263 // Among these nodes, node6 has the smallest time_ready, hence, GetCurrNode()
264 // should return it.
  EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
266
  // Now inserts a few other nodes whose time_ready values are even smaller
  // than that of Node6. Before RemoveCurrNode() is called, GetCurrNode()
  // should keep returning the same node, Node6 in this case.
270 NodeDef node7;
271 NodeDef node8;
272 NodeDef node9;
273 NodeSetUp("Node7", kConv2D, kCPU0, 5, &node7);
274 NodeSetUp("Node8", kConv2D, kCPU0, 4, &node8);
275 NodeSetUp("Node9", kConv2D, kCPU0, 3, &node9);
276
277 manager.AddNode(&node7);
278 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
279
280 manager.AddNode(&node8);
281 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
282
283 manager.RemoveCurrNode();
284 // Now Node6 is removed, and GetCurrNode() will return Node8.
285 EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
286
287 // Again, AddNode shouldn't change GetCurrNode().
288 manager.AddNode(&node9);
289 EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
290
291 manager.RemoveCurrNode();
292 EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
293 manager.RemoveCurrNode();
294 EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
295 manager.RemoveCurrNode();
296 EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
297 manager.RemoveCurrNode();
298 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
299 manager.RemoveCurrNode();
300 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
301 manager.RemoveCurrNode();
302 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
303 manager.RemoveCurrNode();
304 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
305 manager.RemoveCurrNode();
306 EXPECT_TRUE(manager.Empty());
307 }
308
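// Tests that FirstReadyManager breaks ties on time_ready deterministically,
// independent of the order in which nodes are added.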
TEST_F(ReadyNodeManagerTest, DeterminismInFirstReadyManager) {
310 FirstReadyManager manager1;
311 TF_EXPECT_OK(manager1.Init(&node_states_));
312 FirstReadyManager manager2;
313 TF_EXPECT_OK(manager2.Init(&node_states_));
314
315 // 6 nodes with same time_ready.
316 NodeDef node7;
317 NodeDef node8;
318 NodeDef node9;
319 NodeDef node10;
320 NodeDef node11;
321 NodeDef node12;
322 NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
323 NodeSetUp("Node8", kConv2D, kCPU0, 1000, &node8);
324 NodeSetUp("Node9", kConv2D, kCPU0, 1000, &node9);
325 NodeSetUp("Node10", kConv2D, kCPU0, 1000, &node10);
326 NodeSetUp("Node11", kConv2D, kCPU0, 1000, &node11);
327 NodeSetUp("Node12", kConv2D, kCPU0, 1000, &node12);
328
329 // Adds the above 6 nodes to manager1.
330 manager1.AddNode(&node7);
331 manager1.AddNode(&node8);
332 manager1.AddNode(&node9);
333 manager1.AddNode(&node10);
334 manager1.AddNode(&node11);
335 manager1.AddNode(&node12);
336
337 // Adds the above 6 nodes to manager2, but in a different order.
338 manager2.AddNode(&node8);
339 manager2.AddNode(&node11);
340 manager2.AddNode(&node9);
341 manager2.AddNode(&node10);
342 manager2.AddNode(&node7);
343 manager2.AddNode(&node12);
344
345 // Expects both managers return the same nodes for deterministic node
346 // scheduling.
347 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
348 manager1.RemoveCurrNode();
349 manager2.RemoveCurrNode();
350
351 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
352 manager1.RemoveCurrNode();
353 manager2.RemoveCurrNode();
354
355 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
356 manager1.RemoveCurrNode();
357 manager2.RemoveCurrNode();
358
359 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
360 manager1.RemoveCurrNode();
361 manager2.RemoveCurrNode();
362
363 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
364 manager1.RemoveCurrNode();
365 manager2.RemoveCurrNode();
366
367 EXPECT_EQ(manager1.GetCurrNode()->name(), manager2.GetCurrNode()->name());
368 manager1.RemoveCurrNode();
369 manager2.RemoveCurrNode();
370
371 EXPECT_TRUE(manager1.Empty());
372 EXPECT_TRUE(manager2.Empty());
373 }
374
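// Tests that PriorityReadyManager schedules nodes according to the given
// priorities, falling back to time_ready to break ties.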
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultiplePriorityReadyManager) {
376 PriorityReadyManager manager;
377 TF_EXPECT_OK(manager.Init(&node_states_));
378
379 // Sets up node priorities.
380 std::unordered_map<string, int> node_priority = {
381 {"Node1", 1}, {"Node2", 2}, {"Node3", 2}, {"Node4", 4}, {"Node5", 5}};
382 TF_EXPECT_OK(manager.SetPriority(node_priority));
383
384 // Inserts nodes in some random order.
385 manager.AddNode(&node3_);
386 manager.AddNode(&node1_);
387 manager.AddNode(&node4_);
388 manager.AddNode(&node5_);
389 manager.AddNode(&node2_);
390 manager.AddNode(&node6_);
391
392 // Expects nodes scheduled based on priority.
  // Node6 is not in the priority map, so it defaults to the highest-precedence
  // priority and is scheduled first.
394 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
395 manager.RemoveCurrNode();
396 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
397 manager.RemoveCurrNode();
  // Nodes 2 and 3 have equal priority, so they fall back to time_ready order
  // (earliest first).
399 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
400 manager.RemoveCurrNode();
401 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
402 manager.RemoveCurrNode();
403 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
404 manager.RemoveCurrNode();
405 EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
406 manager.RemoveCurrNode();
407 EXPECT_TRUE(manager.Empty());
408 }
409
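// Tests that CompositeNodeManager removes the only node contained within.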
TEST_F(ReadyNodeManagerTest, RemoveSingleNodeCompositeNodeManager) {
411 CompositeNodeManager manager;
412 TF_EXPECT_OK(manager.Init(&node_states_));
413 manager.AddNode(&node1_);
414 manager.RemoveCurrNode();
415 EXPECT_TRUE(manager.Empty());
416 }
417
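// Tests that, with nodes on a single device, CompositeNodeManager returns the
// current node in the same (LIFO) order as nodes are removed and added.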
TEST_F(ReadyNodeManagerTest, GetAndRemoveMultipleCompositeNodeManager) {
419 CompositeNodeManager manager;
420 TF_EXPECT_OK(manager.Init(&node_states_));
421 manager.AddNode(&node1_);
422 manager.AddNode(&node2_);
423 manager.AddNode(&node3_);
424 manager.AddNode(&node4_);
425
426 // Keeps checking current node as nodes are removed and added.
427 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
428 manager.RemoveCurrNode();
429 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
430 manager.AddNode(&node5_);
431 // GetCurrNode() should return the same node even if some nodes are added,
432 // until RemoveCurrNode() is called.
433 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
434 manager.RemoveCurrNode();
435 EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
436 manager.RemoveCurrNode();
437 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
438 manager.AddNode(&node6_);
439 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
440 manager.RemoveCurrNode();
441 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
442 manager.RemoveCurrNode();
443 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
444 manager.RemoveCurrNode();
445 EXPECT_TRUE(manager.Empty());
446 }
447
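// Tests CompositeNodeManager with nodes on two devices plus _Send and _Recv
// nodes; among the per-device candidates and the Send/Recv nodes, the one with
// the earliest time_ready is scheduled next.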
TEST_F(ReadyNodeManagerTest, MultiDeviceSendRecvCompositeNodeManager) {
449 CompositeNodeManager manager;
450 TF_EXPECT_OK(manager.Init(&node_states_));
451 // Additional nodes on kCPU1.
452 NodeDef node7;
453 NodeDef node8;
454 NodeDef node9;
455 NodeSetUp("Node7", kConv2D, kCPU1, 1001, &node7);
456 NodeSetUp("Node8", kConv2D, kCPU1, 2001, &node8);
457 NodeSetUp("Node9", kConv2D, kCPU1, 3001, &node9);
458
459 // Send and Recv nodes.
460 NodeDef send1;
461 NodeDef send2;
462 NodeDef recv1;
463 NodeDef recv2;
464 NodeSetUp("Send1", kSend, kChannelFrom0To1, 2002, &send1);
465 NodeSetUp("Send2", kSend, kChannelFrom1To0, 2005, &send2);
466 NodeSetUp("Recv1", kRecv, kCPU0, 2003, &recv1);
467 NodeSetUp("Recv2", kRecv, kCPU1, 2004, &recv2);
468
469 // Inserts nodes.
470 manager.AddNode(&node1_);
471 manager.AddNode(&node2_);
472 manager.AddNode(&node3_);
473 manager.AddNode(&node4_);
474 manager.AddNode(&node5_);
475 manager.AddNode(&node6_);
476 manager.AddNode(&node7);
477 manager.AddNode(&node8);
478 manager.AddNode(&node9);
479 manager.AddNode(&send1);
480 manager.AddNode(&send2);
481 manager.AddNode(&recv1);
482 manager.AddNode(&recv2);
483
  // On kCPU0 the last node added is node6_; on kCPU1 the last one is node9;
  // so choose the one with the earliest time_ready among node6_, node9,
  // Send1, Send2, Recv1, and Recv2.
487 EXPECT_EQ(manager.GetCurrNode()->name(), "Node6");
488 manager.RemoveCurrNode();
489 // Then, the next one on kCPU0 is node5_; choose the earliest time_ready node
490 // among node5_, node9, Send1, Send2, Recv1, and Recv2.
491 EXPECT_EQ(manager.GetCurrNode()->name(), "Node5");
492 manager.RemoveCurrNode();
493 // Next, choose among node4_, node9, Send1, Send2, Recv1, and Recv2.
494 EXPECT_EQ(manager.GetCurrNode()->name(), "Send1");
495 manager.RemoveCurrNode();
  // Next, choose among node4_, node9, Send2, Recv1, and Recv2.
497 EXPECT_EQ(manager.GetCurrNode()->name(), "Recv1");
498 manager.RemoveCurrNode();
499 // Next, choose among node4_, node9, Send2, and Recv2.
500 EXPECT_EQ(manager.GetCurrNode()->name(), "Recv2");
501 manager.RemoveCurrNode();
502 // Next, choose among node4_, node9, and Send2.
503 EXPECT_EQ(manager.GetCurrNode()->name(), "Send2");
504 manager.RemoveCurrNode();
505 // Next, choose between node4_, node9.
506 EXPECT_EQ(manager.GetCurrNode()->name(), "Node4");
507 manager.RemoveCurrNode();
508 // Next, choose between node3_, node9.
509 EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
510 manager.RemoveCurrNode();
511 // Next, choose between node3_, node8.
512 EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
513 manager.RemoveCurrNode();
514 // Next, choose between node3_, node7.
515 EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
516 manager.RemoveCurrNode();
  // Then, just the nodes left on kCPU0 -- LIFO.
518 EXPECT_EQ(manager.GetCurrNode()->name(), "Node3");
519 manager.RemoveCurrNode();
520 EXPECT_EQ(manager.GetCurrNode()->name(), "Node2");
521 manager.RemoveCurrNode();
522 EXPECT_EQ(manager.GetCurrNode()->name(), "Node1");
523 manager.RemoveCurrNode();
524 EXPECT_TRUE(manager.Empty());
525 }
526
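// Tests that CompositeNodeManager prefers _Send over _Recv over other ops when
// time_ready ties, and schedules deterministically regardless of the order in
// which nodes are added.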
TEST_F(ReadyNodeManagerTest, DeterminismInCompositeNodeManager) {
528 CompositeNodeManager manager;
529 TF_EXPECT_OK(manager.Init(&node_states_));
530 CompositeNodeManager manager2;
531 TF_EXPECT_OK(manager2.Init(&node_states_));
532
533 // 6 nodes with same time_ready.
534 NodeDef node7;
535 NodeDef node8;
536 NodeDef node9;
537 NodeDef node10;
538 NodeDef node11;
539 NodeDef node12;
540 NodeSetUp("Node7", kConv2D, kCPU0, 1000, &node7);
541 NodeSetUp("Node8", kSend, kCPU0, 1000, &node8);
542 NodeSetUp("Node9", kRecv, kCPU0, 1000, &node9);
543 NodeSetUp("Node10", kConv2D, kCPU0, 999, &node10);
544 NodeSetUp("Node11", kRecv, kCPU0, 999, &node11);
545 NodeSetUp("Node12", kConv2D, kCPU1, 1000, &node12);
546
547 // Adds Nodes 7 to 9 to manager.
548 manager.AddNode(&node7);
549 manager.AddNode(&node8);
550 manager.AddNode(&node9);
551
  // It should return _Send first, then _Recv, then the other op, when the
  // candidate nodes have the same time_ready.
554 EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
555 EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
556 manager.RemoveCurrNode();
557 EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
558 EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
559 manager.RemoveCurrNode();
560 EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
561 EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
562 manager.RemoveCurrNode();
563 EXPECT_TRUE(manager.Empty());
564
565 // Adds Nodes 7 to 9 to manager, but in a different order.
566 manager.AddNode(&node9);
567 manager.AddNode(&node8);
568 manager.AddNode(&node7);
569
570 // Expects same order (_Send, _Recv, and the other op), regardless of Add
571 // order.
572 EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
573 EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
574 manager.RemoveCurrNode();
575 EXPECT_EQ(manager.GetCurrNode()->name(), "Node9");
576 EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
577 manager.RemoveCurrNode();
578 EXPECT_EQ(manager.GetCurrNode()->name(), "Node7");
579 EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
580 manager.RemoveCurrNode();
581 EXPECT_TRUE(manager.Empty());
582
583 // Conv2D's time_ready < Send's time_ready; Expects Conv2D first.
584 manager.AddNode(&node8);
585 manager.AddNode(&node10);
586 EXPECT_EQ(manager.GetCurrNode()->name(), "Node10");
587 EXPECT_EQ(manager.GetCurrNode()->op(), kConv2D);
588 manager.RemoveCurrNode();
589 EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
590 EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
591 manager.RemoveCurrNode();
592 EXPECT_TRUE(manager.Empty());
593
  // Recv's time_ready < Send's time_ready; Expects Recv first.
595 manager.AddNode(&node11);
596 manager.AddNode(&node8);
597 EXPECT_EQ(manager.GetCurrNode()->name(), "Node11");
598 EXPECT_EQ(manager.GetCurrNode()->op(), kRecv);
599 manager.RemoveCurrNode();
600 EXPECT_EQ(manager.GetCurrNode()->name(), "Node8");
601 EXPECT_EQ(manager.GetCurrNode()->op(), kSend);
602 manager.RemoveCurrNode();
603 EXPECT_TRUE(manager.Empty());
604
605 // Node7 and 12 are normal ops with the same time_ready, placed on different
606 // devices. These two nodes are added to manager and manager2, but in
607 // different orders; Expects GetCurrNode() returns the nodes in the same
608 // order.
609 manager.AddNode(&node7);
610 manager.AddNode(&node12);
611
612 manager2.AddNode(&node12);
613 manager2.AddNode(&node7);
614
615 EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
616 manager.RemoveCurrNode();
617 manager2.RemoveCurrNode();
618 EXPECT_EQ(manager.GetCurrNode()->name(), manager2.GetCurrNode()->name());
619 manager.RemoveCurrNode();
620 manager2.RemoveCurrNode();
621 EXPECT_TRUE(manager.Empty());
622 }
623
624 // Class for testing virtual scheduler.
625 class TestVirtualScheduler : public VirtualScheduler {
626 public:
  TestVirtualScheduler(const bool use_static_shapes,
628 const bool use_aggressive_shape_inference,
629 ReadyNodeManager* ready_node_manager, Cluster* cluster)
630 : VirtualScheduler(
631 use_static_shapes, use_aggressive_shape_inference, cluster,
632 ready_node_manager,
633 absl::make_unique<VirtualPlacer>(cluster->GetDevices())) {
634 enable_mem_usage_tracking();
635 }
636
637 FRIEND_TEST(VirtualSchedulerTest, MemoryUsage);
638 FRIEND_TEST(VirtualSchedulerTest, ControlDependency);
639 FRIEND_TEST(VirtualSchedulerTest, ComplexDependency);
640 FRIEND_TEST(VirtualSchedulerTest, Variable);
641 FRIEND_TEST(VirtualSchedulerTest, InterDeviceTransfer);
642 };
643
644 class VirtualSchedulerTest : public ::testing::Test {
645 protected:
  VirtualSchedulerTest() {
647 // Initializes cluster_ and scheduler_.
648 std::unordered_map<string, DeviceProperties> devices;
649
    // Sets some dummy CPU properties.
651 DeviceProperties cpu_device = GetDummyCPUDevice();
652
    // IMPORTANT: The device is never actually used in this test case, since
    // force_cpu_type is defaulted to "Haswell".
655 devices[kCPU0] = cpu_device;
656 devices[kCPU1] = cpu_device;
657 cluster_ = absl::make_unique<VirtualCluster>(devices);
658 scheduler_ = absl::make_unique<TestVirtualScheduler>(
659 /*use_static_shapes=*/true,
660 /*use_aggressive_shape_inference=*/true, &first_ready_manager_,
661 cluster_.get());
662 }
663
  DeviceProperties GetDummyCPUDevice() {
    // Creates a CPU with 2 cores, 4 GHz frequency, 2 GB/s memory bandwidth:
    // - 8 GFLOPS
    // - 2 GB/s
668 DeviceProperties cpu_device;
669 cpu_device.set_type("CPU");
670 cpu_device.set_frequency(4000);
671 cpu_device.set_num_cores(2);
672 cpu_device.set_bandwidth(2000000);
673 return cpu_device;
674 }
675
676 // Three Conv2Ds with only two in fetch nodes.
  void CreateGrapplerItemWithConv2Ds() {
678 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
679 auto x = ops::RandomUniform(
680 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
681 auto y = ops::RandomUniform(
682 s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
683 auto z = ops::RandomUniform(
684 s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
685 auto f = ops::RandomUniform(
686 s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
687 std::vector<int> strides = {1, 1, 1, 1};
688 auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
689 auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
690 auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
691
692 grappler_item_ = absl::make_unique<GrapplerItem>();
693 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
694 grappler_item_->id = "test_conv2d_graph";
695 grappler_item_->fetch = {"c0", "c1"};
696
697 dependency_["c0"] = {"x", "f"};
698 dependency_["c1"] = {"y", "f"};
699 }
700
701 // A Conv2D with a variable.
  void CreateGrapplerItemWithConv2DAndVariable() {
703 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
704 auto x = ops::RandomUniform(
705 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
706 auto f = ops::Variable(s.WithOpName("f"),
707 {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
708 std::vector<int> strides = {1, 1, 1, 1};
709 auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
710
711 grappler_item_ = absl::make_unique<GrapplerItem>();
712 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
713 grappler_item_->id = "test_conv2d_var_graph";
714
715 grappler_item_->fetch = {"y"};
716
717 dependency_["y"] = {"x", "f"};
718 }
719
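  // A chain of MatMuls (ab, abc, abcd, abcde) over 3200x3200 tensors.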
  void CreateGrapplerItemWithMatmulChain() {
721 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
    // Adds control dependencies to ensure the tests do not rely on a specific
    // ready node manager and that the execution order remains deterministic.
724 auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT);
725 auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a),
726 {3200, 3200}, DT_FLOAT);
727 auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b),
728 {3200, 3200}, DT_FLOAT);
729 auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c),
730 {3200, 3200}, DT_FLOAT);
731 auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d),
732 {3200, 3200}, DT_FLOAT);
733
734 auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b);
735 auto abc = ops::MatMul(s.WithOpName("abc"), ab, c);
736 auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d);
737 auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e);
738
739 grappler_item_ = absl::make_unique<GrapplerItem>();
740 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
741 grappler_item_->id = "test_matmul_sequence_graph";
742 grappler_item_->fetch = {"abcde"};
743
744 dependency_["ab"] = {"a", "b"};
745 dependency_["abc"] = {"ab", "c"};
746 dependency_["abcd"] = {"abc", "d"};
747 dependency_["abcde"] = {"abcd", "e"};
748 }
749
  // AddN that takes four tensors of shape 10x10x10x10.
  void CreateGrapplerItemWithAddN() {
752 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
753 auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT);
754 auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT);
755 auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT);
756 auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT);
757 OutputList input_tensors = {x, y, z, w};
758 auto out = ops::AddN(s.WithOpName("out"), input_tensors);
759
760 grappler_item_ = absl::make_unique<GrapplerItem>();
761 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
762 grappler_item_->id = "test_addn_graph";
763 grappler_item_->fetch = {"out"};
764
765 dependency_["out"] = {"x", "y", "z", "w"};
766 }
767
768 // Graph with some placeholder feed nodes that are not in the fetch fan-in.
  void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() {
770 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
771 auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT);
772 auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT);
773
774 grappler_item_ = absl::make_unique<GrapplerItem>();
775 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
776
777 grappler_item_->id = "test_extra_placeholders";
778 grappler_item_->fetch = {"x"};
779
780 // Grappler Item Builder puts all placeholder nodes into the feed
781 // list by default.
782 grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}};
783 }
784
  // NoOp that takes 7 NoOps as control dependencies.
  void CreateGrapplerItemWithControlDependency() {
787 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
788 std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
789 std::vector<Operation> input_tensors;
790 for (const auto& input : input_noop_names) {
791 auto x = ops::NoOp(s.WithOpName(input));
792 input_tensors.push_back(x.operation);
793 }
794 auto out =
795 ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out"));
796
797 grappler_item_ = absl::make_unique<GrapplerItem>();
798 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
799
800 grappler_item_->id = "test_control_dependency_graph";
801 grappler_item_->fetch = {"out"};
802
803 dependency_["out"] = input_noop_names;
804 }
805
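  // y = Add(x, x); fetch = Identity(y).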
  void CreateGrapplerItemWithAddFromOneTensor() {
807 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
808 auto x = tensorflow::ops::RandomUniform(
809 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
810
811 auto y = tensorflow::ops::Add(s.WithOpName("y"), x, x);
812 Output fetch = ops::Identity(s.WithOpName("fetch"), y);
813
814 grappler_item_ = absl::make_unique<GrapplerItem>();
815 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
816
817 grappler_item_->id = "test_add_from_one_tensor";
818 grappler_item_->fetch = {"fetch"};
819
820 dependency_["fetch"] = {"y"};
821 dependency_["y"] = {"x"};
822 }
823
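  // Graph with a Switch/Merge construct in the fetch fan-in.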
  void CreateGrapplerItemWithSwitchMergeInput() {
825 // sw = Switch(x, pred)
    // a = Add(sw:1, b)
827 // m = Merge(sw:0, a)
828 // y = Add(m, z)
829
830 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
831 auto x = ops::RandomUniform(
832 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
833 auto pred = ops::Const(s.WithOpName("pred"), false, {});
834 auto sw = ops::Switch(s.WithOpName("switch"), x, pred);
835 auto b = ops::RandomUniform(
836 s.WithOpName("b"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
837 auto a = ops::Add(s.WithOpName("a"), sw.output_true, b);
838 auto m = ops::Merge(s.WithOpName("m"), {sw.output_false, a.z});
839 auto z = ops::RandomUniform(
840 s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
841 auto y = ops::Add(s.WithOpName("y"), m.output, z);
842
843 grappler_item_ = absl::make_unique<GrapplerItem>();
844 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
845
846 grappler_item_->id = "test_add_merge_switch";
847 grappler_item_->fetch = {"y"};
848
849 dependency_["y"] = {"m", "z"};
850 }
851
852 // FusedBN [an op with multiple outputs] with multiple consumers (including
853 // control dependency).
  void CreateGrapplerItemWithBatchNorm() {
855 Scope s = Scope::NewRootScope().WithDevice(kCPU0);
856 auto x = ops::RandomUniform(
857 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
858 auto scale =
859 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
860 auto offset =
861 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
862 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
863 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
864
865 auto batch_norm = ops::FusedBatchNorm(
866 s.WithOpName("bn"), x, scale, offset, mean, var,
867 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
868 auto y = batch_norm.y;
869 auto batch_mean = batch_norm.batch_mean;
870 auto batch_var = batch_norm.batch_variance;
871
872 auto z1 = ops::Add(s.WithOpName("z1"), x, y);
873 auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var);
874 auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var);
875 std::vector<Operation> input_tensors = {
876 batch_mean.op(),
877 z1.z.op(),
878 z2.z.op(),
879 z3.z.op(),
880 };
881 auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4"));
882
883 grappler_item_ = absl::make_unique<GrapplerItem>();
884 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
885
886 grappler_item_->id = "test_complex_dependency_graph";
887 grappler_item_->fetch = {"z1", "z2", "z3", "z4"};
888
889 dependency_["bn"] = {"x", "scale", "offset", "mean", "var"};
890 dependency_["z1"] = {"x", "bn"};
891 dependency_["z2"] = {"bn"};
892 dependency_["z3"] = {"bn"};
893 dependency_["z4"] = {"bn"};
894 }
895
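  // Graph with explicit _Send and _Recv nodes, given as a text-format
  // GraphDef.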
  void CreateGrapplerItemWithSendRecv() {
897 const string gdef_ascii = R"EOF(
898 node {
899 name: "Const"
900 op: "Const"
901 device: "/job:localhost/replica:0/task:0/device:CPU:0"
902 attr {
903 key: "dtype"
904 value {
905 type: DT_FLOAT
906 }
907 }
908 attr {
909 key: "value"
910 value {
911 tensor {
912 dtype: DT_FLOAT
913 tensor_shape {
914 }
915 float_val: 3.1415
916 }
917 }
918 }
919 }
920 node {
921 name: "Send"
922 op: "_Send"
923 input: "Const"
924 device: "/job:localhost/replica:0/task:0/device:CPU:0"
925 attr {
926 key: "T"
927 value {
928 type: DT_FLOAT
929 }
930 }
931 attr {
932 key: "client_terminated"
933 value {
934 b: false
935 }
936 }
937 attr {
938 key: "recv_device"
939 value {
940 s: "/job:localhost/replica:0/task:0/device:CPU:0"
941 }
942 }
943 attr {
944 key: "send_device"
945 value {
946 s: "/job:localhost/replica:0/task:0/device:CPU:0"
947 }
948 }
949 attr {
950 key: "send_device_incarnation"
951 value {
952 i: 0
953 }
954 }
955 attr {
956 key: "tensor_name"
957 value {
958 s: "test"
959 }
960 }
961 }
962 node {
963 name: "Recv"
964 op: "_Recv"
965 device: "/job:localhost/replica:0/task:0/device:CPU:0"
966 attr {
967 key: "client_terminated"
968 value {
969 b: false
970 }
971 }
972 attr {
973 key: "recv_device"
974 value {
975 s: "/job:localhost/replica:0/task:0/device:CPU:0"
976 }
977 }
978 attr {
979 key: "send_device"
980 value {
981 s: "/job:localhost/replica:0/task:0/device:CPU:0"
982 }
983 }
984 attr {
985 key: "send_device_incarnation"
986 value {
987 i: 0
988 }
989 }
990 attr {
991 key: "tensor_name"
992 value {
993 s: "test"
994 }
995 }
996 attr {
997 key: "tensor_type"
998 value {
999 type: DT_FLOAT
1000 }
1001 }
1002 }
1003 library {
1004 }
1005 versions {
1006 producer: 24
1007 }
1008 )EOF";
1009
1010 grappler_item_ = absl::make_unique<GrapplerItem>();
1011
1012 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1013 &grappler_item_->graph));
1014 grappler_item_->id = "test_graph";
1015 grappler_item_->fetch = {"Recv"};
1016 }
1017
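  // Graph with a _Recv node that has no matching _Send.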
  void CreateGrapplerItemWithRecvWithoutSend() {
1019 const string gdef_ascii = R"EOF(
1020 node {
1021 name: "Recv"
1022 op: "_Recv"
1023 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1024 attr {
1025 key: "client_terminated"
1026 value {
1027 b: false
1028 }
1029 }
1030 attr {
1031 key: "recv_device"
1032 value {
1033 s: "/job:localhost/replica:0/task:0/device:CPU:0"
1034 }
1035 }
1036 attr {
1037 key: "send_device"
1038 value {
1039 s: "/job:localhost/replica:0/task:0/device:CPU:0"
1040 }
1041 }
1042 attr {
1043 key: "send_device_incarnation"
1044 value {
1045 i: 0
1046 }
1047 }
1048 attr {
1049 key: "tensor_name"
1050 value {
1051 s: "test"
1052 }
1053 }
1054 attr {
1055 key: "tensor_type"
1056 value {
1057 type: DT_FLOAT
1058 }
1059 }
1060 }
1061 library {
1062 }
1063 versions {
1064 producer: 24
1065 }
1066 )EOF";
1067
1068 grappler_item_ = absl::make_unique<GrapplerItem>();
1069 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1070 &grappler_item_->graph));
1071 grappler_item_->id = "test_graph";
1072 grappler_item_->fetch = {"Recv"};
1073 }
1074
1075 // A simple while loop
  void CreateGrapplerItemWithLoop() {
1077 // Test graph produced in python using:
1078 /*
1079 with tf.Graph().as_default():
1080 i0 = tf.constant(0)
1081 m0 = tf.ones([2, 2])
1082 c = lambda i, m: i < 10
1083 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
1084 r = tf.while_loop(
1085 c, b, loop_vars=[i0, m0],
1086 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
1087 with open('/tmp/graph.pbtxt', 'w') as f:
1088 f.write(str(tf.get_default_graph().as_graph_def()))
1089 */
1090 const string gdef_ascii = R"EOF(
1091 node {
1092 name: "Const"
1093 op: "Const"
1094 attr {
1095 key: "dtype"
1096 value {
1097 type: DT_INT32
1098 }
1099 }
1100 attr {
1101 key: "value"
1102 value {
1103 tensor {
1104 dtype: DT_INT32
1105 tensor_shape {
1106 }
1107 int_val: 0
1108 }
1109 }
1110 }
1111 }
1112 node {
1113 name: "ones"
1114 op: "Const"
1115 attr {
1116 key: "dtype"
1117 value {
1118 type: DT_FLOAT
1119 }
1120 }
1121 attr {
1122 key: "value"
1123 value {
1124 tensor {
1125 dtype: DT_FLOAT
1126 tensor_shape {
1127 dim {
1128 size: 2
1129 }
1130 dim {
1131 size: 2
1132 }
1133 }
1134 float_val: 1.0
1135 }
1136 }
1137 }
1138 }
1139 node {
1140 name: "while/Enter"
1141 op: "Enter"
1142 input: "Const"
1143 attr {
1144 key: "T"
1145 value {
1146 type: DT_INT32
1147 }
1148 }
1149 attr {
1150 key: "frame_name"
1151 value {
1152 s: "while/while/"
1153 }
1154 }
1155 attr {
1156 key: "is_constant"
1157 value {
1158 b: false
1159 }
1160 }
1161 attr {
1162 key: "parallel_iterations"
1163 value {
1164 i: 10
1165 }
1166 }
1167 }
1168 node {
1169 name: "while/Enter_1"
1170 op: "Enter"
1171 input: "ones"
1172 attr {
1173 key: "T"
1174 value {
1175 type: DT_FLOAT
1176 }
1177 }
1178 attr {
1179 key: "frame_name"
1180 value {
1181 s: "while/while/"
1182 }
1183 }
1184 attr {
1185 key: "is_constant"
1186 value {
1187 b: false
1188 }
1189 }
1190 attr {
1191 key: "parallel_iterations"
1192 value {
1193 i: 10
1194 }
1195 }
1196 }
1197 node {
1198 name: "while/Merge"
1199 op: "Merge"
1200 input: "while/Enter"
1201 input: "while/NextIteration"
1202 attr {
1203 key: "N"
1204 value {
1205 i: 2
1206 }
1207 }
1208 attr {
1209 key: "T"
1210 value {
1211 type: DT_INT32
1212 }
1213 }
1214 }
1215 node {
1216 name: "while/Merge_1"
1217 op: "Merge"
1218 input: "while/Enter_1"
1219 input: "while/NextIteration_1"
1220 attr {
1221 key: "N"
1222 value {
1223 i: 2
1224 }
1225 }
1226 attr {
1227 key: "T"
1228 value {
1229 type: DT_FLOAT
1230 }
1231 }
1232 }
1233 node {
1234 name: "while/Less/y"
1235 op: "Const"
1236 input: "^while/Merge"
1237 attr {
1238 key: "dtype"
1239 value {
1240 type: DT_INT32
1241 }
1242 }
1243 attr {
1244 key: "value"
1245 value {
1246 tensor {
1247 dtype: DT_INT32
1248 tensor_shape {
1249 }
1250 int_val: 10
1251 }
1252 }
1253 }
1254 }
1255 node {
1256 name: "while/Less"
1257 op: "Less"
1258 input: "while/Merge"
1259 input: "while/Less/y"
1260 attr {
1261 key: "T"
1262 value {
1263 type: DT_INT32
1264 }
1265 }
1266 }
1267 node {
1268 name: "while/LoopCond"
1269 op: "LoopCond"
1270 input: "while/Less"
1271 }
1272 node {
1273 name: "while/Switch"
1274 op: "Switch"
1275 input: "while/Merge"
1276 input: "while/LoopCond"
1277 attr {
1278 key: "T"
1279 value {
1280 type: DT_INT32
1281 }
1282 }
1283 attr {
1284 key: "_class"
1285 value {
1286 list {
1287 s: "loc:@while/Merge"
1288 }
1289 }
1290 }
1291 }
1292 node {
1293 name: "while/Switch_1"
1294 op: "Switch"
1295 input: "while/Merge_1"
1296 input: "while/LoopCond"
1297 attr {
1298 key: "T"
1299 value {
1300 type: DT_FLOAT
1301 }
1302 }
1303 attr {
1304 key: "_class"
1305 value {
1306 list {
1307 s: "loc:@while/Merge_1"
1308 }
1309 }
1310 }
1311 }
1312 node {
1313 name: "while/Identity"
1314 op: "Identity"
1315 input: "while/Switch:1"
1316 attr {
1317 key: "T"
1318 value {
1319 type: DT_INT32
1320 }
1321 }
1322 }
1323 node {
1324 name: "while/Identity_1"
1325 op: "Identity"
1326 input: "while/Switch_1:1"
1327 attr {
1328 key: "T"
1329 value {
1330 type: DT_FLOAT
1331 }
1332 }
1333 }
1334 node {
1335 name: "while/add/y"
1336 op: "Const"
1337 input: "^while/Identity"
1338 attr {
1339 key: "dtype"
1340 value {
1341 type: DT_INT32
1342 }
1343 }
1344 attr {
1345 key: "value"
1346 value {
1347 tensor {
1348 dtype: DT_INT32
1349 tensor_shape {
1350 }
1351 int_val: 1
1352 }
1353 }
1354 }
1355 }
1356 node {
1357 name: "while/add"
1358 op: "Add"
1359 input: "while/Identity"
1360 input: "while/add/y"
1361 attr {
1362 key: "T"
1363 value {
1364 type: DT_INT32
1365 }
1366 }
1367 }
1368 node {
1369 name: "while/concat/axis"
1370 op: "Const"
1371 input: "^while/Identity"
1372 attr {
1373 key: "dtype"
1374 value {
1375 type: DT_INT32
1376 }
1377 }
1378 attr {
1379 key: "value"
1380 value {
1381 tensor {
1382 dtype: DT_INT32
1383 tensor_shape {
1384 }
1385 int_val: 0
1386 }
1387 }
1388 }
1389 }
1390 node {
1391 name: "while/concat"
1392 op: "ConcatV2"
1393 input: "while/Identity_1"
1394 input: "while/Identity_1"
1395 input: "while/concat/axis"
1396 attr {
1397 key: "N"
1398 value {
1399 i: 2
1400 }
1401 }
1402 attr {
1403 key: "T"
1404 value {
1405 type: DT_FLOAT
1406 }
1407 }
1408 attr {
1409 key: "Tidx"
1410 value {
1411 type: DT_INT32
1412 }
1413 }
1414 }
1415 node {
1416 name: "while/NextIteration"
1417 op: "NextIteration"
1418 input: "while/add"
1419 attr {
1420 key: "T"
1421 value {
1422 type: DT_INT32
1423 }
1424 }
1425 }
1426 node {
1427 name: "while/NextIteration_1"
1428 op: "NextIteration"
1429 input: "while/concat"
1430 attr {
1431 key: "T"
1432 value {
1433 type: DT_FLOAT
1434 }
1435 }
1436 }
1437 node {
1438 name: "while/Exit"
1439 op: "Exit"
1440 input: "while/Switch"
1441 attr {
1442 key: "T"
1443 value {
1444 type: DT_INT32
1445 }
1446 }
1447 }
1448 node {
1449 name: "while/Exit_1"
1450 op: "Exit"
1451 input: "while/Switch_1"
1452 attr {
1453 key: "T"
1454 value {
1455 type: DT_FLOAT
1456 }
1457 }
1458 }
1459 versions {
1460 producer: 21
1461 }
1462 )EOF";
1463
1464 grappler_item_ = absl::make_unique<GrapplerItem>();
1465 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
1466 &grappler_item_->graph));
1467 grappler_item_->id = "test_graph";
1468 grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
1469 }
1470
  // The same simple while loop as above, annotated with _execution_count
  // attrs and Switch _output_slot_vector attrs.
  void CreateGrapplerItemWithLoopAnnotated() {
1473 // Test graph produced in python using:
1474 /*
1475 with tf.Graph().as_default():
1476 i0 = tf.constant(0)
1477 m0 = tf.ones([2, 2])
1478 c = lambda i, m: i < 10
1479 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
1480 r = tf.while_loop(
1481 c, b, loop_vars=[i0, m0],
1482 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
1483 with open('/tmp/graph.pbtxt', 'w') as f:
1484 f.write(str(tf.get_default_graph().as_graph_def()))
1485 */
1486 const string gdef_ascii = R"EOF(
1487 node {
1488 name: "Const"
1489 op: "Const"
1490 attr {
1491 key: "dtype"
1492 value {
1493 type: DT_INT32
1494 }
1495 }
1496 attr {
1497 key: "value"
1498 value {
1499 tensor {
1500 dtype: DT_INT32
1501 tensor_shape {
1502 }
1503 int_val: 0
1504 }
1505 }
1506 }
1507 attr {
1508 key: "_execution_count"
1509 value {
1510 i: 1
1511 }
1512 }
1513 }
1514 node {
1515 name: "ones"
1516 op: "Const"
1517 attr {
1518 key: "dtype"
1519 value {
1520 type: DT_FLOAT
1521 }
1522 }
1523 attr {
1524 key: "value"
1525 value {
1526 tensor {
1527 dtype: DT_FLOAT
1528 tensor_shape {
1529 dim {
1530 size: 2
1531 }
1532 dim {
1533 size: 2
1534 }
1535 }
1536 float_val: 1.0
1537 }
1538 }
1539 }
1540 attr {
1541 key: "_execution_count"
1542 value {
1543 i: 1
1544 }
1545 }
1546 }
1547 node {
1548 name: "while/Enter"
1549 op: "Enter"
1550 input: "Const"
1551 attr {
1552 key: "T"
1553 value {
1554 type: DT_INT32
1555 }
1556 }
1557 attr {
1558 key: "frame_name"
1559 value {
1560 s: "while/while/"
1561 }
1562 }
1563 attr {
1564 key: "is_constant"
1565 value {
1566 b: false
1567 }
1568 }
1569 attr {
1570 key: "parallel_iterations"
1571 value {
1572 i: 10
1573 }
1574 }
1575 attr {
1576 key: "_execution_count"
1577 value {
1578 i: 1
1579 }
1580 }
1581 }
1582 node {
1583 name: "while/Enter_1"
1584 op: "Enter"
1585 input: "ones"
1586 attr {
1587 key: "T"
1588 value {
1589 type: DT_FLOAT
1590 }
1591 }
1592 attr {
1593 key: "frame_name"
1594 value {
1595 s: "while/while/"
1596 }
1597 }
1598 attr {
1599 key: "is_constant"
1600 value {
1601 b: false
1602 }
1603 }
1604 attr {
1605 key: "parallel_iterations"
1606 value {
1607 i: 10
1608 }
1609 }
1610 attr {
1611 key: "_execution_count"
1612 value {
1613 i: 1
1614 }
1615 }
1616 }
1617 node {
1618 name: "while/Merge"
1619 op: "Merge"
1620 input: "while/Enter"
1621 input: "while/NextIteration"
1622 attr {
1623 key: "N"
1624 value {
1625 i: 2
1626 }
1627 }
1628 attr {
1629 key: "T"
1630 value {
1631 type: DT_INT32
1632 }
1633 }
1634 attr {
1635 key: "_execution_count"
1636 value {
1637 i: 10
1638 }
1639 }
1640 }
1641 node {
1642 name: "while/Merge_1"
1643 op: "Merge"
1644 input: "while/Enter_1"
1645 input: "while/NextIteration_1"
1646 attr {
1647 key: "N"
1648 value {
1649 i: 2
1650 }
1651 }
1652 attr {
1653 key: "T"
1654 value {
1655 type: DT_FLOAT
1656 }
1657 }
1658 attr {
1659 key: "_execution_count"
1660 value {
1661 i: 10
1662 }
1663 }
1664 }
1665 node {
1666 name: "while/Less/y"
1667 op: "Const"
1668 input: "^while/Merge"
1669 attr {
1670 key: "dtype"
1671 value {
1672 type: DT_INT32
1673 }
1674 }
1675 attr {
1676 key: "value"
1677 value {
1678 tensor {
1679 dtype: DT_INT32
1680 tensor_shape {
1681 }
1682 int_val: 10
1683 }
1684 }
1685 }
1686 attr {
1687 key: "_execution_count"
1688 value {
1689 i: 10
1690 }
1691 }
1692 }
1693 node {
1694 name: "while/Less"
1695 op: "Less"
1696 input: "while/Merge"
1697 input: "while/Less/y"
1698 attr {
1699 key: "T"
1700 value {
1701 type: DT_INT32
1702 }
1703 }
1704 attr {
1705 key: "_execution_count"
1706 value {
1707 i: 10
1708 }
1709 }
1710 }
1711 node {
1712 name: "while/LoopCond"
1713 op: "LoopCond"
1714 input: "while/Less"
1715 attr {
1716 key: "_execution_count"
1717 value {
1718 i: 10
1719 }
1720 }
1721 }
1722 node {
1723 name: "while/Switch"
1724 op: "Switch"
1725 input: "while/Merge"
1726 input: "while/LoopCond"
1727 attr {
1728 key: "T"
1729 value {
1730 type: DT_INT32
1731 }
1732 }
1733 attr {
1734 key: "_class"
1735 value {
1736 list {
1737 s: "loc:@while/Merge"
1738 }
1739 }
1740 }
1741 attr {
1742 key: "_execution_count"
1743 value {
1744 i: 11
1745 }
1746 }
1747 attr {
1748 key: "_output_slot_vector"
1749 value {
1750 list {
1751 i: 1
1752 i: 1
1753 i: 1
1754 i: 1
1755 i: 1
1756 i: 1
1757 i: 1
1758 i: 1
1759 i: 1
1760 i: 1
1761 i: 0
1762 }
1763 }
1764 }
1765 }
1766 node {
1767 name: "while/Switch_1"
1768 op: "Switch"
1769 input: "while/Merge_1"
1770 input: "while/LoopCond"
1771 attr {
1772 key: "T"
1773 value {
1774 type: DT_FLOAT
1775 }
1776 }
1777 attr {
1778 key: "_class"
1779 value {
1780 list {
1781 s: "loc:@while/Merge_1"
1782 }
1783 }
1784 }
1785 attr {
1786 key: "_execution_count"
1787 value {
1788 i: 11
1789 }
1790 }
1791 attr {
1792 key: "_output_slot_vector"
1793 value {
1794 list {
1795 i: 1
1796 i: 1
1797 i: 1
1798 i: 1
1799 i: 1
1800 i: 1
1801 i: 1
1802 i: 1
1803 i: 1
1804 i: 1
1805 i: 0
1806 }
1807 }
1808 }
1809 }
1810 node {
1811 name: "while/Identity"
1812 op: "Identity"
1813 input: "while/Switch:1"
1814 attr {
1815 key: "T"
1816 value {
1817 type: DT_INT32
1818 }
1819 }
1820 attr {
1821 key: "_execution_count"
1822 value {
1823 i: 10
1824 }
1825 }
1826 }
1827 node {
1828 name: "while/Identity_1"
1829 op: "Identity"
1830 input: "while/Switch_1:1"
1831 attr {
1832 key: "T"
1833 value {
1834 type: DT_FLOAT
1835 }
1836 }
1837 attr {
1838 key: "_execution_count"
1839 value {
1840 i: 10
1841 }
1842 }
1843 }
1844 node {
1845 name: "while/add/y"
1846 op: "Const"
1847 input: "^while/Identity"
1848 attr {
1849 key: "dtype"
1850 value {
1851 type: DT_INT32
1852 }
1853 }
1854 attr {
1855 key: "value"
1856 value {
1857 tensor {
1858 dtype: DT_INT32
1859 tensor_shape {
1860 }
1861 int_val: 1
1862 }
1863 }
1864 }
1865 attr {
1866 key: "_execution_count"
1867 value {
1868 i: 10
1869 }
1870 }
1871 }
1872 node {
1873 name: "while/add"
1874 op: "Add"
1875 input: "while/Identity"
1876 input: "while/add/y"
1877 attr {
1878 key: "T"
1879 value {
1880 type: DT_INT32
1881 }
1882 }
1883 attr {
1884 key: "_execution_count"
1885 value {
1886 i: 10
1887 }
1888 }
1889 }
1890 node {
1891 name: "while/concat/axis"
1892 op: "Const"
1893 input: "^while/Identity"
1894 attr {
1895 key: "dtype"
1896 value {
1897 type: DT_INT32
1898 }
1899 }
1900 attr {
1901 key: "value"
1902 value {
1903 tensor {
1904 dtype: DT_INT32
1905 tensor_shape {
1906 }
1907 int_val: 0
1908 }
1909 }
1910 }
1911 attr {
1912 key: "_execution_count"
1913 value {
1914 i: 10
1915 }
1916 }
1917 }
1918 node {
1919 name: "while/concat"
1920 op: "ConcatV2"
1921 input: "while/Identity_1"
1922 input: "while/Identity_1"
1923 input: "while/concat/axis"
1924 attr {
1925 key: "N"
1926 value {
1927 i: 2
1928 }
1929 }
1930 attr {
1931 key: "T"
1932 value {
1933 type: DT_FLOAT
1934 }
1935 }
1936 attr {
1937 key: "Tidx"
1938 value {
1939 type: DT_INT32
1940 }
1941 }
1942 attr {
1943 key: "_execution_count"
1944 value {
1945 i: 10
1946 }
1947 }
1948 }
1949 node {
1950 name: "while/NextIteration"
1951 op: "NextIteration"
1952 input: "while/add"
1953 attr {
1954 key: "T"
1955 value {
1956 type: DT_INT32
1957 }
1958 }
1959 attr {
1960 key: "_execution_count"
1961 value {
1962 i: 10
1963 }
1964 }
1965 }
1966 node {
1967 name: "while/NextIteration_1"
1968 op: "NextIteration"
1969 input: "while/concat"
1970 attr {
1971 key: "T"
1972 value {
1973 type: DT_FLOAT
1974 }
1975 }
1976 attr {
1977 key: "_execution_count"
1978 value {
1979 i: 10
1980 }
1981 }
1982 }
1983 node {
1984 name: "while/Exit"
1985 op: "Exit"
1986 input: "while/Switch"
1987 attr {
1988 key: "T"
1989 value {
1990 type: DT_INT32
1991 }
1992 }
1993 attr {
1994 key: "_execution_count"
1995 value {
1996 i: 1
1997 }
1998 }
1999 }
2000 node {
2001 name: "while/Exit_1"
2002 op: "Exit"
2003 input: "while/Switch_1"
2004 attr {
2005 key: "T"
2006 value {
2007 type: DT_FLOAT
2008 }
2009 }
2010 attr {
2011 key: "_execution_count"
2012 value {
2013 i: 1
2014 }
2015 }
2016 }
2017 versions {
2018 producer: 21
2019 }
2020 )EOF";
2021
    grappler_item_ = absl::make_unique<GrapplerItem>();
2023 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
2024 &grappler_item_->graph));
2025 grappler_item_->id = "test_graph";
2026 grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
2027 }
2028
2029 // A simple condition graph.
  void CreateGrapplerItemWithCondition() {
2031 // Handcrafted test graph: a/Less -> Switch -> First/Second -> Merge.
2032 const string gdef_ascii = R"EOF(
2033 node {
2034 name: "a"
2035 op: "Const"
2036 attr {
2037 key: "dtype"
2038 value {
2039 type: DT_FLOAT
2040 }
2041 }
2042 attr {
2043 key: "value"
2044 value {
2045 tensor {
2046 dtype: DT_FLOAT
2047 tensor_shape {
2048 }
2049 float_val: 2.0
2050 }
2051 }
2052 }
2053 }
2054 node {
2055 name: "Less"
2056 op: "Const"
2057 attr {
2058 key: "dtype"
2059 value {
2060 type: DT_BOOL
2061 }
2062 }
2063 attr {
2064 key: "value"
2065 value {
2066 tensor {
2067 dtype: DT_BOOL
2068 tensor_shape {
2069 }
2070 tensor_content: "\001"
2071 }
2072 }
2073 }
2074 }
2075 node {
2076 name: "Switch"
2077 op: "Switch"
2078 input: "a"
2079 input: "Less"
2080 attr {
2081 key: "T"
2082 value {
2083 type: DT_FLOAT
2084 }
2085 }
2086 }
2087 node {
2088 name: "First"
2089 op: "Identity"
2090 input: "Switch"
2091 attr {
2092 key: "T"
2093 value {
2094 type: DT_FLOAT
2095 }
2096 }
2097 }
2098 node {
2099 name: "Second"
2100 op: "Identity"
2101 input: "Switch:1"
2102 attr {
2103 key: "T"
2104 value {
2105 type: DT_FLOAT
2106 }
2107 }
2108 }
2109 node {
2110 name: "Merge"
2111 op: "Merge"
2112 input: "First"
2113 input: "Second"
2114 attr {
2115 key: "N"
2116 value {
2117 i: 2
2118 }
2119 }
2120 attr {
2121 key: "T"
2122 value {
2123 type: DT_FLOAT
2124 }
2125 }
2126 }
2127 versions {
2128 producer: 27
2129 })EOF";
2130
2131 grappler_item_ = absl::make_unique<GrapplerItem>();
2132 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
2133 &grappler_item_->graph));
2134 grappler_item_->id = "test_graph";
2135 grappler_item_->fetch = {"Merge"};
2136 }
2137
  // Graph with inter-device transfers: a FusedBatchNorm op (which has multiple
  // output ports) on kCPU0, with consumers on kCPU1.
  void CreateGrapplerItemWithInterDeviceTransfers() {
2140 tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
2141
2142 // Create a FusedBatchNorm op that has multiple output ports.
2143 auto x = ops::RandomUniform(
2144 s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
2145 auto scale =
2146 ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
2147 auto offset =
2148 ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
2149 auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
2150 auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
2151
2152 auto batch_norm = ops::FusedBatchNorm(
2153 s.WithOpName("bn"), x, scale, offset, mean, var,
2154 ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
2155 auto y = batch_norm.y;
2156 auto batch_mean = batch_norm.batch_mean;
2157 auto batch_var = batch_norm.batch_variance;
2158 // y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
2159 auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
2160 auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
2161 // batch_mean1 and batch_var1 take different output ports, so each will
2162 // initiate Send/Recv.
2163 auto batch_mean1 = ops::Identity(
2164 s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
2165 auto batch_var1 =
2166 ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
2167     // This takes a control dependency from y.
2168 auto control_dep = ops::NoOp(s.WithOpName("control_dep")
2169 .WithControlDependencies(y)
2170 .WithDevice(kCPU1));
2171
2172 grappler_item_ = absl::make_unique<GrapplerItem>();
2173 TF_CHECK_OK(s.ToGraphDef(&grappler_item_->graph));
2174 grappler_item_->id = "test_conv2d_graph";
2175 grappler_item_->fetch = {"y1", "y2", "batch_mean1", "batch_var1",
2176 "control_dep"};
2177
2178 dependency_["bn"] = {"x", "mean", "var"};
2179 dependency_["y1"] = {"bn"};
2180 dependency_["y2"] = {"bn"};
2181 dependency_["batch_mean1"] = {"bn"};
2182 dependency_["batch_var1"] = {"bn"};
2183 dependency_["control_dep"] = {"bn"};
2184 }
2185
2186 // Call this after creating grappler_item_ and setting up dependency_.
2187   void InitScheduler() { TF_ASSERT_OK(scheduler_->Init(grappler_item_.get())); }
2188
2189   // Returns a cost that depends only on the op type.
2190   Costs SimplePredictCosts(const OpContext& op_context) const {
2191 Costs c;
2192 int64 exec_cost = 0;
2193 if (op_context.op_info.op() == "MatMul") {
2194 exec_cost = 2000000000;
2195 } else if (op_context.op_info.op() == "RandomUniform") {
2196 exec_cost = 1000000000;
2197 } else {
2198 exec_cost = 1000;
2199 }
2200 c.execution_time = Costs::NanoSeconds(exec_cost);
2201 return c;
2202 }
2203
2204   // Call this after InitScheduler(). The scheduler stops after executing
2205   // target_node.
2206   std::unordered_map<string, OpContext> RunScheduler(
2207 const string& target_node) {
2208 std::unordered_map<string, OpContext> ops_executed;
2209 bool more_nodes = true;
2210 do {
2211 OpContext op_context = scheduler_->GetCurrNode();
2212 ops_executed[op_context.name] = op_context;
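      // Log the scheduling order; handy when the dependency-order checks below fail.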
2213 std::cout << op_context.name << std::endl;
2214
2215 Costs node_costs = SimplePredictCosts(op_context);
2216
2217 // Check scheduling order.
2218 auto it = dependency_.find(op_context.name);
2219 if (it != dependency_.end()) {
2220 for (const auto& preceding_node : it->second) {
2221 EXPECT_GT(ops_executed.count(preceding_node), 0);
2222 }
2223 }
2224 more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs);
2225
2226 if (op_context.name == target_node) {
2227 // Scheduler has the state after executing the target node.
2228 break;
2229 }
2230 } while (more_nodes);
2231 return ops_executed;
2232 }
2233
2234   // Helper method for validating a vector: test_elements must contain the
  // same elements as expected, in any order.
2235 template <typename T>
2236   void ExpectVectorEq(const std::vector<T>& expected,
2237 const std::vector<T>& test_elements) {
2238 // Set of expected elements for an easy comparison.
2239 std::set<T> expected_set(expected.begin(), expected.end());
2240 for (const auto& element : test_elements) {
2241 EXPECT_GT(expected_set.count(element), 0);
2242 }
2243 EXPECT_EQ(expected.size(), test_elements.size());
2244 }
2245
2246   // Helper method that checks that the given NodeDefs have the expected names.
2247   void ValidateNodeDefs(const std::vector<string>& expected,
2248 const std::vector<const NodeDef*>& node_defs) {
2249 std::vector<string> node_names;
2250 std::transform(node_defs.begin(), node_defs.end(),
2251 std::back_inserter(node_names),
2252 [](const NodeDef* node) { return node->name(); });
2253 ExpectVectorEq(expected, node_names);
2254 }
2255
2256 // Helper method for validating a set.
2257 template <typename T>
2258   void ExpectSetEq(const std::set<T>& expected,
2259 const std::set<T>& test_elements) {
2260 for (const auto& element : test_elements) {
2261 EXPECT_GT(expected.count(element), 0);
2262 }
2263 EXPECT_EQ(expected.size(), test_elements.size());
2264 }
2265
2266 // Helper method for validating an unordered map.
2267 template <typename T, typename U>
2268   void ExpectUnorderedMapEq(const std::unordered_map<T, U>& expected,
2269 const std::unordered_map<T, U>& test_map) {
2270 EXPECT_EQ(expected.size(), test_map.size());
2271 for (const auto& key_val : expected) {
2272 EXPECT_GT(test_map.count(key_val.first), 0);
2273 EXPECT_EQ(test_map.at(key_val.first), key_val.second);
2274 }
2275 }
2276
2277   // Helper method that checks (node name, output port) pairs in a memory
  // usage snapshot.
2278   void ValidateMemoryUsageSnapshot(
2279 const std::vector<string>& expected_names, const int port_num_expected,
2280 const std::unordered_set<std::pair<const NodeDef*, int>,
2281 DeviceState::NodePairHash>& mem_usage_snapshot) {
2282 std::set<std::pair<string, int>> nodes_at_peak_mem_usage;
2283 std::transform(
2284 mem_usage_snapshot.begin(), mem_usage_snapshot.end(),
2285 std::inserter(nodes_at_peak_mem_usage, nodes_at_peak_mem_usage.begin()),
2286 [](const std::pair<const NodeDef*, int>& node_port) {
2287 return std::make_pair(node_port.first->name(), node_port.second);
2288 });
2289 std::set<std::pair<string, int>> expected;
2290 std::transform(expected_names.begin(), expected_names.end(),
2291 std::inserter(expected, expected.begin()),
2292 [port_num_expected](const string& name) {
2293 return std::make_pair(name, port_num_expected);
2294 });
2295 ExpectSetEq(expected, nodes_at_peak_mem_usage);
2296 }
2297
2298   // Helper method for checking node dependencies: start times must be
  // non-decreasing along the given dependency chain.
2299   void ValidateDependencyChain(
2300 const std::unordered_map<string, int64>& start_times,
2301 const std::vector<string>& nodes_in_dependency_order) {
2302 int64 prev_node_time = -1;
2303 for (const auto& node : nodes_in_dependency_order) {
2304 int64 curr_node_time = start_times.at(node);
2305 EXPECT_GE(curr_node_time, prev_node_time);
2306 prev_node_time = curr_node_time;
2307 }
2308 }
2309
2310 // cluster_ and scheduler_ are initialized in the c'tor.
2311 std::unique_ptr<VirtualCluster> cluster_;
2312 std::unique_ptr<TestVirtualScheduler> scheduler_;
2313 FirstReadyManager first_ready_manager_;
2314 CompositeNodeManager composite_node_manager_;
2315
2316 // grappler_item_ will be initialized differently for each test case.
2317 std::unique_ptr<GrapplerItem> grappler_item_;
2318 // Node name -> its preceding nodes map for testing scheduling order.
2319 std::unordered_map<string, std::vector<string>> dependency_;
2320
2321 // Shared params for Conv2D related graphs:
2322 const int batch_size_ = 4;
2323 const int width_ = 10;
2324 const int height_ = 10;
2325 const int depth_in_ = 8;
2326 const int kernel_ = 3;
2327 const int depth_out_ = 16;
2328 };
2329
2330 // Create small graph, run predict costs on it, make sure the costs from the
2331 // summary match the hand-calculated costs.
2332 TEST_F(VirtualSchedulerTest, SummaryCostTest) {
2333 // Run matmul test.
2334 CreateGrapplerItemWithMatmulChain();
2335 InitScheduler();
2336 auto ops_executed = RunScheduler("");
2337 Costs c = scheduler_->Summary();
2338
2339   // RandomUniform - 5 * 1s = 5s
2340   // MatMul - 4 * 2s = 8s
2341   // Misc - 5 * 1us = 5us
2342   // Total: 13,000,005 us
2343 EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2344 EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2345 EXPECT_FALSE(c.inaccurate);
2346 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2347 }
2348
2349 // Like the above SummaryCostTest, but makes sure the stepstats timeline is
2350 // correct.
2351 TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
2352 // Run matmul test.
2353 CreateGrapplerItemWithMatmulChain();
2354 InitScheduler();
2355 auto ops_executed = RunScheduler("");
2356 RunMetadata metadata;
2357 Costs c = scheduler_->Summary(&metadata);
2358 StepStats stepstats = metadata.step_stats();
2359 EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
2360 EXPECT_EQ(grappler_item_->graph.node_size(), c.num_ops_total);
2361 EXPECT_FALSE(c.inaccurate);
2362 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2363
2364 // Should only be 1 device!
2365 EXPECT_EQ(1, stepstats.dev_stats().size());
2366
2367 // Create a map of op name -> start and end times (micros).
2368 std::map<string, std::pair<int64, int64>> start_end_times;
2369 for (const auto& device_step_stats : stepstats.dev_stats()) {
2370 for (const auto& stats : device_step_stats.node_stats()) {
2371 int64 start = stats.all_start_micros();
2372 int64 end = start + stats.all_end_rel_micros();
2373 start_end_times[stats.node_name()] = std::pair<int64, int64>(start, end);
2374
2375 // Make sure that the output properties are correct for
2376 // MatMul and RandomUniform operations.
2377 // We only check for dtype, and shape (excluding alloc)
2378 // since alloc is not set by the virtual scheduler.
2379 if (stats.timeline_label() == "MatMul" ||
2380 stats.timeline_label() == "RandomUniform") {
2381 EXPECT_EQ(1, stats.output().size());
2382 for (const auto& output : stats.output()) {
2383 EXPECT_EQ(DT_FLOAT, output.tensor_description().dtype());
2384 EXPECT_EQ(2, output.tensor_description().shape().dim().size());
2385 for (const auto& dim : output.tensor_description().shape().dim()) {
2386 EXPECT_EQ(3200, dim.size());
2387 }
2388 }
2389 }
2390 }
2391 }
2392
2393   // The base start time is the time to compute the RandomUniform ops.
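  // 5,000,005 us = five 1-second RandomUniform ops plus 5 us of 1-us ops
  // scheduled before the first MatMul (see SimplePredictCosts()).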
2394 int64 cur_time = static_cast<int64>(5000005);
2395 // The increment is the execution time of one matmul. See
2396 // CreateGrapplerItemWithMatmulChain for details.
2397 int64 increment = static_cast<int64>(2000000);
2398 auto op_names = {"ab", "abc", "abcd", "abcde"};
2399 for (const auto& op_name : op_names) {
2400 int64 actual_start = start_end_times[op_name].first;
2401 int64 actual_end = start_end_times[op_name].second;
2402 int64 expected_start = cur_time;
2403 int64 expected_end = cur_time + increment;
2404 EXPECT_EQ(expected_start, actual_start);
2405 EXPECT_EQ(expected_end, actual_end);
2406 cur_time += increment;
2407 }
2408 }
2409
2410 TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) {
2411 // Init.
2412 CreateGrapplerItemWithConv2Ds();
2413 InitScheduler();
2414
2415 // Run the scheduler.
2416 auto ops_executed = RunScheduler(""); // Run all the nodes.
2417
2418   // A Const and a RandomUniform for each of x, y, and f (6 nodes), plus c0
2419   // and c1: 8 nodes in total. c2 and z shouldn't be executed.
2420 EXPECT_EQ(8, ops_executed.size());
2421
2422 // x, y, f, c0, and c1 should be in the ops executed.
2423 EXPECT_GT(ops_executed.count("x"), 0);
2424 EXPECT_GT(ops_executed.count("y"), 0);
2425 EXPECT_GT(ops_executed.count("f"), 0);
2426 EXPECT_GT(ops_executed.count("c0"), 0);
2427 EXPECT_GT(ops_executed.count("c1"), 0);
2428
2429 // z and c2 shouldn't be part of it.
2430 EXPECT_EQ(ops_executed.count("z"), 0);
2431 EXPECT_EQ(ops_executed.count("c2"), 0);
2432
2433 // Check input / output properties.
2434 EXPECT_EQ(1, ops_executed["x"].op_info.outputs_size());
2435 EXPECT_EQ(1, ops_executed["y"].op_info.outputs_size());
2436 EXPECT_EQ(1, ops_executed["f"].op_info.outputs_size());
2437 EXPECT_EQ(2, ops_executed["c0"].op_info.inputs_size());
2438 EXPECT_EQ(2, ops_executed["c1"].op_info.inputs_size());
2439 }
2440
2441 TEST_F(VirtualSchedulerTest, MemoryUsage) {
2442 // Init.
2443 CreateGrapplerItemWithAddN();
2444 InitScheduler();
2445
2446 // Run the scheduler.
2447 RunScheduler("");
2448
2449 const auto* device_states = scheduler_->GetDeviceStates();
2450 const auto& cpu_state = device_states->at(kCPU0);
2451
2452   // The out node adds 4 tensors, each of shape 10x10x10x10, so the peak
2453   // memory usage is 4x the input tensor size while executing the out node.
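  // Each input is 10 * 10 * 10 * 10 floats = 40,000 bytes, so the peak is
  // 4 * 40,000 = 160,000 bytes (also checked via GetPeakMemoryUsage() below).
  // The 64 bytes of persistent memory presumably come from the int32 shape
  // Consts feeding the four RandomUniform ops.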
2454 int64 one_input_node_size = 4 * 10 * 10 * 10 * 10;
2455 const std::vector<string> expected_names = {"x", "y", "z", "w"};
2456 EXPECT_EQ(expected_names.size() * one_input_node_size,
2457 cpu_state.max_memory_usage);
2458 ValidateMemoryUsageSnapshot(expected_names, 0 /* port_num_expected */,
2459 cpu_state.mem_usage_snapshot_at_peak);
2460 ExpectUnorderedMapEq(
2461 {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 64)},
2462 scheduler_->GetPersistentMemoryUsage());
2463 ExpectUnorderedMapEq(
2464 {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 160000)},
2465 scheduler_->GetPeakMemoryUsage());
2466 }
2467
2468 TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) {
2469 CreateGrapplerItemWithUnnecessaryPlaceholderNodes();
2470 InitScheduler();
2471
2472 // Test that scheduler can run graphs with extra unnecessary feed nodes.
2473 auto ops_executed = RunScheduler("");
2474 ASSERT_EQ(1, ops_executed.size());
2475 ASSERT_EQ(ops_executed.count("x"), 1);
2476 }
2477
2478 TEST_F(VirtualSchedulerTest, ControlDependency) {
2479 // Init.
2480 CreateGrapplerItemWithControlDependency();
2481 InitScheduler();
2482
2483 // Run the scheduler.
2484 RunScheduler("");
2485
2486 const auto* device_states = scheduler_->GetDeviceStates();
2487 const auto& cpu_state = device_states->at(kCPU0);
2488
2489   // The graph has a NoOp that takes control dependencies from 7 NoOps. The
2490   // peak memory usage is when executing the final NoOp.
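  // Each control-dependency edge is modeled as a 4-byte tensor on port -1, so
  // the peak is 7 * 4 = 28 bytes (also checked via GetPeakMemoryUsage() below).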
2491 int64 one_input_node_size = 4; // control dependency
2492 const std::vector<string> expected_names = {"x", "y", "z", "w",
2493 "u", "v", "t"};
2494 EXPECT_EQ(expected_names.size() * one_input_node_size,
2495 cpu_state.max_memory_usage);
2496 ValidateMemoryUsageSnapshot(expected_names, -1 /* port_num_expected */,
2497 cpu_state.mem_usage_snapshot_at_peak);
2498 ExpectUnorderedMapEq(
2499 {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 0)},
2500 scheduler_->GetPersistentMemoryUsage());
2501 ExpectUnorderedMapEq(
2502 {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 28)},
2503 scheduler_->GetPeakMemoryUsage());
2504 }
2505
2506 TEST_F(VirtualSchedulerTest, ComplexDependency) {
2507 // Init.
2508 CreateGrapplerItemWithBatchNorm();
2509 InitScheduler();
2510
2511 // Run the scheduler.
2512 RunScheduler("bn");
2513
2514   const auto* device_states = scheduler_->GetDeviceStates();
2515 const auto& cpu_state = device_states->at(kCPU0);
2516
2517 // The graph is
2518 // bn = FusedBatchNorm(x, scale, offset, mean, var)
2519 // z1 = bn.y + x
2520 // z2 = bn.var + bn.var
2521 // z3 = bn.var + bn.var
2522 // z4 = control dependency from bn.
2523 // Note that bn.mean doesn't have any consumer.
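  // x is 4 x 10 x 10 x 8 = 3,200 floats, so the expected usage is
  // 4 * (2 * 3,200 + 8 + 1) = 25,636 bytes.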
2524 const int x_size = batch_size_ * width_ * height_ * depth_in_;
2525 int64 expected_size =
2526 4 * (2 * x_size /* x and bn.y */ + depth_in_ /* bn.var */ +
2527 1 /* control dependency */);
2528 EXPECT_EQ(expected_size, cpu_state.memory_usage);
2529
2530 // Nodes currently in memory: bn's port -1, 0, and 2, and x's port 0.
2531 std::set<std::pair<string, int>> nodes_in_memory;
2532 std::transform(
2533 cpu_state.nodes_in_memory.begin(), cpu_state.nodes_in_memory.end(),
2534 std::inserter(nodes_in_memory, nodes_in_memory.begin()),
2535 [](const std::pair<const NodeDef*, int>& node_port) {
2536 return std::make_pair(node_port.first->name(), node_port.second);
2537 });
2538 std::set<std::pair<string, int>> expected = {
2539 std::make_pair("bn", -1),
2540 std::make_pair("bn", 0),
2541 std::make_pair("bn", 2),
2542 std::make_pair("x", 0),
2543 };
2544 ExpectSetEq(expected, nodes_in_memory);
2545
2546 const auto* node_states = scheduler_->GetNodeStates();
2547 const NodeState* bn_node = nullptr;
2548 const NodeState* x_node = nullptr;
2549 for (const auto& nodedef_node_state : *node_states) {
2550 const NodeDef* node = nodedef_node_state.first;
2551 const NodeState& node_state = nodedef_node_state.second;
2552 if (node->name() == "bn") {
2553 bn_node = &node_state;
2554 }
2555 if (node->name() == "x") {
2556 x_node = &node_state;
2557 }
2558 }
2559 CHECK_NOTNULL(bn_node);
2560 CHECK_NOTNULL(x_node);
2561
2562 ValidateNodeDefs({"bn", "z1"}, x_node->outputs.at(0));
2563 ValidateNodeDefs({"z4"}, bn_node->outputs.at(-1));
2564 ValidateNodeDefs({"z1"}, bn_node->outputs.at(0));
2565 // z2 and z3 are bn.var + bn.var, so they appear twice in bn's output port 2.
2566 ValidateNodeDefs({"z2", "z3", "z2", "z3"}, bn_node->outputs.at(2));
2567 }
2568
2569 TEST_F(VirtualSchedulerTest, Variable) {
2570 // Init.
2571 CreateGrapplerItemWithConv2DAndVariable();
2572 InitScheduler();
2573
2574 // Run the scheduler.
2575 RunScheduler("");
2576
2577 const auto* device_states = scheduler_->GetDeviceStates();
2578 const auto& cpu_state = device_states->at(kCPU0);
2579
2580   // There is one Conv2D that takes x and f, but f is a Variable, so it is
2581   // counted among the persistent nodes.
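  // Assuming f is a kernel_ x kernel_ x depth_in_ x depth_out_ filter, the
  // persistent usage is 3*3*8*16 floats = 4,608 bytes plus 16 bytes for
  // Const/Const; the peak is x: 4*10*10*8 floats = 12,800 bytes.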
2582 ValidateMemoryUsageSnapshot({"f", "Const/Const"}, /*port_num_expected=*/0,
2583 cpu_state.persistent_nodes);
2584 // Only x in peak memory usage snapshot.
2585 ValidateMemoryUsageSnapshot({"x"}, /*port_num_expected=*/0,
2586 cpu_state.mem_usage_snapshot_at_peak);
2587 ExpectUnorderedMapEq(
2588 {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 4624)},
2589 scheduler_->GetPersistentMemoryUsage());
2590 ExpectUnorderedMapEq(
2591 {std::make_pair("/job:localhost/replica:0/task:0/cpu:0", 12800)},
2592 scheduler_->GetPeakMemoryUsage());
2593 }
2594
2595 TEST_F(VirtualSchedulerTest, WhileLoop) {
2596 // Init.
2597 CreateGrapplerItemWithLoop();
2598 InitScheduler();
2599
2600 // Run the scheduler.
2601 RunScheduler("");
2602
2603 // Check the timeline
2604 RunMetadata metadata;
2605 scheduler_->Summary(&metadata);
2606
2607 // Nodes in topological order:
2608 // * const, ones
2609 // * while/Enter, while/Enter_1
2610 // * while/Merge, while/Merge_1
2611 // * while/Less/y
2612 // * while/Less
2613 // * while/LoopCond
2614 // * while/Switch, while/Switch_1
2615 // * while/Identity, while/Identity_1, while/Exit, while/Exit_1
2616 // * while/add/y, while/concat/axis
2617 // * while/add, while/concat
2618 // * while/NextIteration, while/NextIteration_1
2619
2620 int num_next_iteration = 0;
2621 int num_next_iteration_1 = 0;
2622 int num_exit = 0;
2623 int num_exit_1 = 0;
2624 int64 next_iter_start_micro;
2625 int64 next_iter_1_start_micro;
2626 int64 exit_start_micro;
2627 int64 exit_1_start_micro;
2628
2629 std::unordered_map<string, int64> start_times;
2630 for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2631 for (const auto& stats : device_step_stats.node_stats()) {
2632 start_times[stats.node_name()] = stats.all_start_micros();
2633 if (stats.node_name() == "while/NextIteration") {
2634 ++num_next_iteration;
2635 next_iter_start_micro = stats.all_start_micros();
2636 } else if (stats.node_name() == "while/NextIteration_1") {
2637 ++num_next_iteration_1;
2638 next_iter_1_start_micro = stats.all_start_micros();
2639 } else if (stats.node_name() == "while/Exit") {
2640 ++num_exit;
2641 exit_start_micro = stats.all_start_micros();
2642 } else if (stats.node_name() == "while/Exit_1") {
2643 ++num_exit_1;
2644 exit_1_start_micro = stats.all_start_micros();
2645 }
2646 }
2647 }
2648
2649   // Make sure we went through the body of the loop once, and that the output
2650   // of the loop was scheduled as well.
2651 EXPECT_EQ(1, num_next_iteration);
2652 EXPECT_EQ(1, num_next_iteration_1);
2653 EXPECT_EQ(1, num_exit);
2654 EXPECT_EQ(1, num_exit_1);
2655
2656   // Start times of while/NextIteration and while/NextIteration_1 should be
2657   // different, as should those of while/Exit and while/Exit_1.
2658 EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
2659 EXPECT_NE(exit_start_micro, exit_1_start_micro);
2660
2661 // Check dependency among the nodes; no matter what scheduling mechanism we
2662 // use, the scheduled ops should follow these dependency chains.
2663 // Note that currently, VirtualScheduler executes while/Merge twice; hence,
2664 // we're not testing dependency chains related to while/Merge.
2665 // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the
2666 // order of Enter, Merge, ...loop condition ..., ... loop body ...,
2667 // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency
2668 // chaining test w/ Merge nodes.
2669 ValidateDependencyChain(
2670 start_times,
2671 {"Const", "while/Enter", // "while/Merge",
2672 "while/Less/y", "while/Less", "while/LoopCond", "while/Switch",
2673 "while/Identity", "while/add/y", "while/add", "while/NextIteration"});
2674 // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"});
2675 ValidateDependencyChain(start_times,
2676 {"ones", "while/Enter_1", // "while/Merge_1",
2677 "while/Switch_1", "while/Identity_1", "while/concat",
2678 "while/NextIteration_1"});
2679 ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"});
2680 ValidateDependencyChain(
2681 start_times, {"while/Identity", "while/concat/axis", "while/concat"});
2682 ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
2683 ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
2684 }
2685
2686 TEST_F(VirtualSchedulerTest, AnnotatedWhileLoop) {
2687 {
2688 // Init.
2689 CreateGrapplerItemWithLoop();
2690 InitScheduler();
2691
2692 // Runs the scheduler.
2693 RunScheduler("");
2694 Costs c = scheduler_->Summary();
2695
2696 EXPECT_EQ(23, c.execution_time.asMicroSeconds().count());
2697 // Both while/Merge and while/Merge_1 are scheduled twice.
2698 EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
2699 EXPECT_FALSE(c.inaccurate);
2700 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2701 }
2702
2703 {
2704 // Init.
2705 CreateGrapplerItemWithLoopAnnotated();
2706 InitScheduler();
2707
2708 // Runs the scheduler.
2709 RunScheduler("");
2710 Costs c = scheduler_->Summary();
2711
2712     // The cost of Merge is accumulated twice per execution_count, but since
2713     // Merge's cost is minimal, we keep this behavior here.
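    // The annotated graph carries _execution_count attributes, so body-node
    // costs are presumably scaled by the annotated iteration count, raising the
    // total from 23 us (above) to 178 us.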
2714 EXPECT_EQ(178, c.execution_time.asMicroSeconds().count());
2715 // Both while/Merge and while/Merge_1 are scheduled twice.
2716 EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
2717 EXPECT_FALSE(c.inaccurate);
2718 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2719 }
2720 }
2721
2722 TEST_F(VirtualSchedulerTest, Condition) {
2723 // Without annotation.
2724 {
2725 // Inits.
2726 CreateGrapplerItemWithCondition();
2727 InitScheduler();
2728
2729 // Runs the scheduler.
2730 RunScheduler("");
2731 RunMetadata metadata;
2732 Costs c = scheduler_->Summary(&metadata);
2733
2734 // Nodes in topological order: a/Less, Switch, First/Second, Merge.
2735 int num_a = 0;
2736 int num_less = 0;
2737 int num_switch = 0;
2738 int num_first = 0;
2739 int num_second = 0;
2740 int num_merge = 0;
2741
2742 for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2743 for (const auto& stats : device_step_stats.node_stats()) {
2744 if (stats.node_name() == "a") {
2745 ++num_a;
2746 } else if (stats.node_name() == "Less") {
2747 ++num_less;
2748 } else if (stats.node_name() == "Switch") {
2749 ++num_switch;
2750 } else if (stats.node_name() == "First") {
2751 ++num_first;
2752 } else if (stats.node_name() == "Second") {
2753 ++num_second;
2754 } else if (stats.node_name() == "Merge") {
2755 ++num_merge;
2756 }
2757 }
2758 }
2759
2760 EXPECT_EQ(1, num_a);
2761 EXPECT_EQ(1, num_less);
2762 EXPECT_EQ(1, num_switch);
2763 EXPECT_EQ(1, num_first);
2764 EXPECT_EQ(1, num_second);
2765 EXPECT_EQ(2, num_merge);
2766
2767 EXPECT_EQ(7, c.execution_time.asMicroSeconds().count());
2768 // Merge is executed twice.
2769 EXPECT_EQ(grappler_item_->graph.node_size() + 1, c.num_ops_total);
2770 EXPECT_FALSE(c.inaccurate);
2771 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2772 }
2773
2774 // With annotation.
2775 {
2776 // Inits.
2777 CreateGrapplerItemWithCondition();
2778
2779 // Annotates the Switch node.
2780 for (auto& node : *grappler_item_->graph.mutable_node()) {
2781 if (node.name() == "Switch") {
2782 AttrValue attr_output_info;
2783         // Adds output slot 0 only, so that Second (fed from slot 1) isn't executed.
2784 (*attr_output_info.mutable_list()).add_i(0);
2785 AddNodeAttr(kOutputSlots, attr_output_info, &node);
2786 }
2787 }
2788
2789 InitScheduler();
2790
2791 // Runs the scheduler.
2792 RunScheduler("");
2793 RunMetadata metadata;
2794 Costs c = scheduler_->Summary(&metadata);
2795
2796 // Nodes in topological order: a/Less, Switch, Merge
2797 int num_a = 0;
2798 int num_less = 0;
2799 int num_switch = 0;
2800 int num_first = 0;
2801 int num_second = 0;
2802 int num_merge = 0;
2803
2804 for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
2805 for (const auto& stats : device_step_stats.node_stats()) {
2806 if (stats.node_name() == "a") {
2807 ++num_a;
2808 } else if (stats.node_name() == "Less") {
2809 ++num_less;
2810 } else if (stats.node_name() == "Switch") {
2811 ++num_switch;
2812 } else if (stats.node_name() == "First") {
2813 ++num_first;
2814 } else if (stats.node_name() == "Second") {
2815 ++num_second;
2816 } else if (stats.node_name() == "Merge") {
2817 ++num_merge;
2818 }
2819 }
2820 }
2821
2822 EXPECT_EQ(1, num_a);
2823 EXPECT_EQ(1, num_less);
2824 EXPECT_EQ(1, num_switch);
2825 EXPECT_EQ(1, num_first);
2826 EXPECT_EQ(0, num_second);
2827 EXPECT_EQ(1, num_merge);
2828
2829 EXPECT_EQ(5, c.execution_time.asMicroSeconds().count());
2830 // Second is not executed.
2831 EXPECT_EQ(grappler_item_->graph.node_size() - 1, c.num_ops_total);
2832 EXPECT_FALSE(c.inaccurate);
2833 EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
2834 }
2835 }
2836
2837 TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
2838 // Init.
2839 CreateGrapplerItemWithInterDeviceTransfers();
2840 InitScheduler();
2841
2842 // Run the scheduler.
2843 auto ops_executed = RunScheduler("");
2844
2845 // Helper lambda to extract port num from _Send and _Recv op name.
2846 auto get_port_num = [](const string& name) -> int {
2847 if (name.find("bn_0") != string::npos) {
2848 return 0;
2849 } else if (name.find("bn_1") != string::npos) {
2850 return 1;
2851 } else if (name.find("bn_2") != string::npos) {
2852 return 2;
2853 } else if (name.find("bn_minus1") != string::npos) {
2854 return -1;
2855 }
2856 return -999;
2857 };
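  // The scheduler-generated _Send/_Recv names embed the source node and output
  // port (bn_0, bn_1, bn_2), with port -1 spelled "minus1" for the control edge.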
2858
2859 // Reorganize ops_executed for further testing.
2860 std::unordered_map<string, int> op_count;
2861 std::unordered_map<int, string> recv_op_names;
2862 std::unordered_map<int, string> send_op_names;
2863 for (const auto& x : ops_executed) {
2864 const auto& name = x.first;
2865 const auto& node_info = x.second;
2866 const auto& op = node_info.op_info.op();
2867 if (op == kRecv) {
2868 recv_op_names[get_port_num(name)] = name;
2869 } else if (op == kSend) {
2870 send_op_names[get_port_num(name)] = name;
2871 }
2872 op_count[op]++;
2873 }
2874
2875 // Same number of _Send and _Recv.
2876 EXPECT_EQ(op_count.at(kSend), op_count.at(kRecv));
2877
2878   // Expect 4 Sends and 4 Recvs: ports 0, 1, and 2, plus the control dependency.
2879 EXPECT_EQ(op_count.at(kRecv), 4);
2880 EXPECT_EQ(op_count.at(kSend), 4);
2881
2882 // Helper lambda for extracting output Tensor size.
2883 auto get_output_size = [this, ops_executed](const string& name) -> int64 {
2884 const auto& output_properties_ = ops_executed.at(name).op_info.outputs();
2885 std::vector<OpInfo::TensorProperties> output_properties;
2886 for (const auto& output_property : output_properties_) {
2887 output_properties.push_back(output_property);
2888 }
2889 return CalculateOutputSize(output_properties, 0);
2890 };
2891
2892 // Validate transfer size.
2893   // Batchnorm output y is a 4-D tensor: batch x width x height x depth.
2894 int input_size = 4 * batch_size_ * width_ * height_ * depth_in_;
2895 EXPECT_EQ(get_output_size(recv_op_names[0]), input_size);
2896 EXPECT_EQ(get_output_size(send_op_names[0]), input_size);
2897   // Mean and variance are 1-D vectors of size depth_in_.
2898 EXPECT_EQ(get_output_size(recv_op_names[1]), 4 * depth_in_);
2899 EXPECT_EQ(get_output_size(send_op_names[1]), 4 * depth_in_);
2900 EXPECT_EQ(get_output_size(recv_op_names[2]), 4 * depth_in_);
2901 EXPECT_EQ(get_output_size(send_op_names[2]), 4 * depth_in_);
2902 // Control dependency size is 4B.
2903 EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
2904 EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
2905 }
2906
2907 TEST_F(VirtualSchedulerTest, GraphWithSendRecv) {
2908 // Init.
2909 CreateGrapplerItemWithSendRecv();
2910 InitScheduler();
2911
2912 // Run the scheduler.
2913 auto ops_executed = RunScheduler("");
2914
2915 EXPECT_GT(ops_executed.count("Const"), 0);
2916 EXPECT_GT(ops_executed.count("Send"), 0);
2917 EXPECT_GT(ops_executed.count("Recv"), 0);
2918 }
2919
2920 TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) {
2921 // Init.
2922 CreateGrapplerItemWithSendRecv();
2923 // Change Recv node's device so that Send and Recv are placed on different
2924 // devices.
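  // With Send and Recv on different devices, the VirtualScheduler inserts its
  // own _Send/_Recv pair for the cross-device edge; the generated op names
  // checked below encode the transferred tensor and the devices involved.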
2925 auto& graph = grappler_item_->graph;
2926 const string recv_device = kCPU1;
2927 for (int i = 0; i < graph.node_size(); i++) {
2928 auto* node = graph.mutable_node(i);
2929 if (node->name() == "Recv") {
2930 node->set_device(recv_device);
2931 auto* attr = node->mutable_attr();
2932 (*attr)["recv_device"].set_s(recv_device);
2933 } else if (node->name() == "Send") {
2934 auto* attr = node->mutable_attr();
2935 (*attr)["recv_device"].set_s(recv_device);
2936 }
2937 }
2938 InitScheduler();
2939
2940 // Run the scheduler.
2941 auto ops_executed = RunScheduler("");
2942
2943   // Expect Const, Send, Recv, and the VirtualScheduler-created Send and Recv ops.
2944 EXPECT_GT(ops_executed.count("Const"), 0);
2945 EXPECT_GT(ops_executed.count("Send"), 0);
2946 EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/"
2947 "task_0/cpu_0_to_/job_localhost"
2948 "/replica_0/task_0/cpu_1"),
2949 0);
2950 EXPECT_GT(ops_executed.count(
2951 "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"),
2952 0);
2953 EXPECT_GT(ops_executed.count("Recv"), 0);
2954 }
2955
2956 TEST_F(VirtualSchedulerTest, GraphWithOnlyRecv) {
2957 // Init.
2958 CreateGrapplerItemWithRecvWithoutSend();
2959 InitScheduler();
2960
2961 // Run the scheduler.
2962 auto ops_executed = RunScheduler("");
2963
2964   // A Recv without a matching Send is treated as an initially ready node.
2965 EXPECT_GT(ops_executed.count("Recv"), 0);
2966 }
2967
2968 TEST_F(VirtualSchedulerTest, AddMergeSwitch) {
2969   // Override scheduler_ with CompositeNodeManager.
2970 scheduler_ = absl::make_unique<TestVirtualScheduler>(
2971 /*use_static_shapes=*/true,
2972 /*use_aggressive_shape_inference=*/true, &composite_node_manager_,
2973 cluster_.get());
2974 CreateGrapplerItemWithSwitchMergeInput();
2975 InitScheduler();
2976
2977 // pred --+ z --+
2978 // | |
2979 // V V
2980 // x -> Switch --------> Merge ---> Add --> y
2981 // | ^
2982 // | |
2983 // +-----> Add -----+
2984 // ^
2985 // |
2986 // b --------------+
2987
2988   // Run the scheduler. Without annotation, the current VirtualScheduler
2989   // triggers both outputs of Switch; Merge then becomes ready as soon as one
2990   // of its inputs is ready. With a plain num_inputs_ready counter, the final
2991   // Add could become ready before z runs, i.e. z could be skipped entirely.
2992   // CompositeNodeManager is needed to exercise this case.
2993 auto ops_executed = RunScheduler("");
2994
2995 EXPECT_GT(ops_executed.count("z"), 0);
2996 }
2997
2998 TEST_F(VirtualSchedulerTest, AddFromOneTensor) {
2999 CreateGrapplerItemWithAddFromOneTensor();
3000 InitScheduler();
3001
3002 // x -+----> Add --> y
3003 // | ^
3004 // | |
3005 // +-------+
3006
3007 // Run the scheduler.
3008 auto ops_executed = RunScheduler("");
3009 EXPECT_GT(ops_executed.count("y"), 0);
3010 EXPECT_GT(ops_executed.count("x"), 0);
3011 }
3012
3013 } // namespace
3014 }  // namespace grappler
3015 }  // namespace tensorflow
3016