1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/graph/collective_order.h"
16
17 #include <gmock/gmock.h>
18 #include "tensorflow/core/framework/node_def_builder.h"
19 #include "tensorflow/core/graph/graph_def_builder.h"
20 #include "tensorflow/core/graph/graph_def_builder_util.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23
24 namespace tensorflow {
25 namespace {
26
27 using ::testing::UnorderedElementsAreArray;
28
29 REGISTER_OP("TestParams").Output("o: float");
30
31 // Verifies that the list of collective nodes in `graph` matches
32 // `expected_collective_nodes`, and that the list of control edges between these
33 // collective nodes matches `expected_collective_control_edges`.
VerifyGraph(const Graph & graph,const std::vector<string> & expected_collective_nodes,const std::vector<std::pair<string,string>> & expected_collective_control_edges)34 void VerifyGraph(const Graph& graph,
35 const std::vector<string>& expected_collective_nodes,
36 const std::vector<std::pair<string, string>>&
37 expected_collective_control_edges) {
38 std::vector<string> actual_collective_nodes;
39 std::vector<std::pair<string, string>> actual_collective_control_edges;
40 for (const Node* src : graph.nodes()) {
41 if (!src->IsCollective()) {
42 continue;
43 }
44 actual_collective_nodes.push_back(src->name());
45 for (const Edge* edge : src->out_edges()) {
46 VLOG(2) << "collective edge " << edge->src()->name() << " -> "
47 << edge->dst()->name();
48 // Add all control edges found except those to `_SINK`.
49 if (!edge->IsControlEdge() || edge->dst()->name() == "_SINK") {
50 continue;
51 }
52 actual_collective_control_edges.emplace_back(src->name(),
53 edge->dst()->name());
54 }
55 }
56 EXPECT_THAT(actual_collective_nodes,
57 UnorderedElementsAreArray(expected_collective_nodes));
58 EXPECT_THAT(actual_collective_control_edges,
59 UnorderedElementsAreArray(expected_collective_control_edges));
60 }
61
62 // Verifies that the `wait_for` attribute on collective nodes matches
63 // `wait_for_map`.
VerifyAttrs(const Graph & graph,const std::unordered_map<string,std::vector<int32>> wait_for_map)64 void VerifyAttrs(
65 const Graph& graph,
66 const std::unordered_map<string, std::vector<int32>> wait_for_map) {
67 for (const Node* node : graph.nodes()) {
68 if (node->IsCollective() ||
69 wait_for_map.find(node->name()) == wait_for_map.end()) {
70 continue;
71 }
72 std::vector<int32> wait_for_actual;
73 TF_EXPECT_OK(GetNodeAttr(node->attrs(), "wait_for", &wait_for_actual));
74 auto wait_for_expected = wait_for_map.at(node->name());
75 EXPECT_THAT(wait_for_actual, UnorderedElementsAreArray(wait_for_expected));
76 }
77 }
78
CollectiveReduceNode(GraphDefBuilder * builder,Node * input,const string & name,const string & device,int instance_key)79 Node* CollectiveReduceNode(GraphDefBuilder* builder, Node* input,
80 const string& name, const string& device,
81 int instance_key) {
82 Node* collective_node =
83 ops::UnaryOp("CollectiveReduce", input,
84 builder->opts()
85 .WithName(name)
86 .WithDevice(device)
87 .WithAttr("T", DT_FLOAT)
88 .WithAttr("group_size", 2)
89 .WithAttr("group_key", 1)
90 .WithAttr("instance_key", instance_key)
91 .WithAttr("merge_op", "Add")
92 .WithAttr("final_op", "Id")
93 .WithAttr("subdiv_offsets", {1}));
94 return collective_node;
95 }
96
97 // Initialize the following graph:
98 //
99 // (cpu0) (cpu1)
100 // a b
101 // | |
102 // c1 c1
103 // | |
104 // id id
105 // / \ / \
106 // c2 c3 c2 c3
107 //
108 // Here ci denotes a collective node with `instance_key` i. `a` and `b` are
109 // inputs, `id` is identity node.
InitGraph()110 std::unique_ptr<Graph> InitGraph() {
111 GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
112 const string dev0 = "/job:localhost/replica:0/task:0/device:CPU:0";
113 const string dev1 = "/job:localhost/replica:0/task:0/device:CPU:1";
114 Node* a = ops::SourceOp("TestParams",
115 builder.opts().WithName("a").WithDevice(dev0));
116 Node* b = ops::SourceOp("TestParams",
117 builder.opts().WithName("b").WithDevice(dev1));
118 Node* c1_0 = CollectiveReduceNode(&builder, a, "c1_0", dev0, 1);
119 Node* c1_1 = CollectiveReduceNode(&builder, b, "c1_1", dev1, 1);
120 Node* id0 = ops::UnaryOp(
121 "Identity", c1_0,
122 builder.opts().WithName("id0").WithDevice(dev0).WithAttr("T", DT_FLOAT));
123 Node* id1 = ops::UnaryOp(
124 "Identity", c1_1,
125 builder.opts().WithName("id1").WithDevice(dev1).WithAttr("T", DT_FLOAT));
126 CollectiveReduceNode(&builder, id0, "c2_0", dev0, 2);
127 CollectiveReduceNode(&builder, id1, "c2_1", dev1, 2);
128 CollectiveReduceNode(&builder, id0, "c3_0", dev0, 3);
129 CollectiveReduceNode(&builder, id1, "c3_1", dev1, 3);
130
131 std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
132 Status s = GraphDefBuilderToGraph(builder, graph.get());
133 if (!s.ok()) {
134 LOG(FATAL) << "Error building graph " << s;
135 }
136 return graph;
137 }
138
139 // Tests that in the graph created by `InitGraph`, exactly 2 control edges are
140 // added after calling `OrderCollectives`: c3_0 -> c2_0 and c3_1 -> c2_1.
TEST(CollectiveOrderTest,SimpleOrder)141 TEST(CollectiveOrderTest, SimpleOrder) {
142 std::unique_ptr<Graph> graph = InitGraph();
143 TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
144 VerifyGraph(*graph, {"c1_0", "c1_1", "c2_0", "c2_1", "c3_0", "c3_1"},
145 {{"c3_0", "c2_0"}, {"c3_1", "c2_1"}});
146 }
147
// Same graph as `SimpleOrder`, but in `kAttrs` mode the ordering is expressed
// through `wait_for` attributes: c2_0 and c2_1 each wait for instance key 3.
TEST(CollectiveOrderTest, SimpleOrderAttr) {
  auto graph = InitGraph();
  TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
  const std::unordered_map<string, std::vector<int32>> expected_wait_for = {
      {"c2_0", {3}}, {"c2_1", {3}}};
  VerifyAttrs(*graph, expected_wait_for);
}
153
154 // Initialize the following graph:
155 //
156 // a
157 // |
158 // c1
159 // / \
160 // c4 id
161 // / \
162 // c2 c3
163 //
164 // Here ci denotes a collective node with `instance_key` i. `a` is an input,
165 // `id` is identity node.
InitGraph2()166 std::unique_ptr<Graph> InitGraph2() {
167 GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
168 const string dev0 = "/job:localhost/replica:0/task:0/device:CPU:0";
169 Node* a = ops::SourceOp("TestParams",
170 builder.opts().WithName("a").WithDevice(dev0));
171 Node* c1 = CollectiveReduceNode(&builder, a, "c1", dev0, 1);
172 CollectiveReduceNode(&builder, c1, "c4", dev0, 4);
173 Node* id = ops::UnaryOp(
174 "Identity", c1,
175 builder.opts().WithName("id").WithDevice(dev0).WithAttr("T", DT_FLOAT));
176 CollectiveReduceNode(&builder, id, "c2", dev0, 2);
177 CollectiveReduceNode(&builder, id, "c3", dev0, 3);
178
179 std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
180 Status s = GraphDefBuilderToGraph(builder, graph.get());
181 if (!s.ok()) {
182 LOG(FATAL) << "Error building graph " << s;
183 }
184 return graph;
185 }
186
187 // Tests that in the graph created by `InitGraph2`, we add the following control
188 // edges after calling `OrderCollectives`: c4 -> c3, c3 -> c2. c4->c2 is
189 // pruned because it follows from the other two edges.
TEST(CollectiveOrderTest,SimpleOrder2)190 TEST(CollectiveOrderTest, SimpleOrder2) {
191 std::unique_ptr<Graph> graph = InitGraph2();
192 TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
193 VerifyGraph(*graph, {"c1", "c2", "c3", "c4"}, {{"c4", "c3"}, {"c3", "c2"}});
194 }
195
196 // Initialize the following graph:
197 //
198 // w x y z
199 // | | | |
200 // c1 c2 c3 c4
201 //
InitGraphForPruning()202 std::unique_ptr<Graph> InitGraphForPruning() {
203 GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
204 const string dev0 = "/job:localhost/replica:0/task:0/device:CPU:0";
205 Node* w = ops::SourceOp("TestParams",
206 builder.opts().WithName("w").WithDevice(dev0));
207 Node* x = ops::SourceOp("TestParams",
208 builder.opts().WithName("x").WithDevice(dev0));
209 Node* y = ops::SourceOp("TestParams",
210 builder.opts().WithName("y").WithDevice(dev0));
211 Node* z = ops::SourceOp("TestParams",
212 builder.opts().WithName("z").WithDevice(dev0));
213 CollectiveReduceNode(&builder, w, "c1", dev0, 1);
214 CollectiveReduceNode(&builder, x, "c2", dev0, 2);
215 CollectiveReduceNode(&builder, y, "c3", dev0, 3);
216 CollectiveReduceNode(&builder, z, "c4", dev0, 4);
217
218 std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
219 Status s = GraphDefBuilderToGraph(builder, graph.get());
220 if (!s.ok()) {
221 LOG(FATAL) << "Error building graph " << s;
222 }
223 return graph;
224 }
225
226 // Tests that in the graph created by `InitGraphForPruning`, we only add c4 ->
227 // c3, c3 -> c2, c2 -> c1, and other edges are pruned away.
TEST(CollectiveOrderTest,Pruning)228 TEST(CollectiveOrderTest, Pruning) {
229 std::unique_ptr<Graph> graph = InitGraphForPruning();
230 TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
231 VerifyAttrs(*graph, {{"c3", {4}}, {"c2", {3}}, {"c1", {2}}});
232 }
233
234 } // namespace
235 } // namespace tensorflow
236