• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/graph/collective_order.h"
16 
17 #include <gmock/gmock.h>
18 #include "tensorflow/core/framework/node_def_builder.h"
19 #include "tensorflow/core/graph/graph_def_builder.h"
20 #include "tensorflow/core/graph/graph_def_builder_util.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23 
24 namespace tensorflow {
25 namespace {
26 
27 using ::testing::UnorderedElementsAreArray;
28 
29 REGISTER_OP("TestParams").Output("o: float");
30 
31 // Verifies that the list of collective nodes in `graph` matches
32 // `expected_collective_nodes`, and that the list of control edges between these
33 // collective nodes matches `expected_collective_control_edges`.
VerifyGraph(const Graph & graph,const std::vector<string> & expected_collective_nodes,const std::vector<std::pair<string,string>> & expected_collective_control_edges)34 void VerifyGraph(const Graph& graph,
35                  const std::vector<string>& expected_collective_nodes,
36                  const std::vector<std::pair<string, string>>&
37                      expected_collective_control_edges) {
38   std::vector<string> actual_collective_nodes;
39   std::vector<std::pair<string, string>> actual_collective_control_edges;
40   for (const Node* src : graph.nodes()) {
41     if (!src->IsCollective()) {
42       continue;
43     }
44     actual_collective_nodes.push_back(src->name());
45     for (const Edge* edge : src->out_edges()) {
46       VLOG(2) << "collective edge " << edge->src()->name() << " -> "
47               << edge->dst()->name();
48       // Add all control edges found except those to `_SINK`.
49       if (!edge->IsControlEdge() || edge->dst()->name() == "_SINK") {
50         continue;
51       }
52       actual_collective_control_edges.emplace_back(src->name(),
53                                                    edge->dst()->name());
54     }
55   }
56   EXPECT_THAT(actual_collective_nodes,
57               UnorderedElementsAreArray(expected_collective_nodes));
58   EXPECT_THAT(actual_collective_control_edges,
59               UnorderedElementsAreArray(expected_collective_control_edges));
60 }
61 
62 // Verifies that the `wait_for` attribute on collective nodes matches
63 // `wait_for_map`.
VerifyAttrs(const Graph & graph,const std::unordered_map<string,std::vector<int32>> wait_for_map)64 void VerifyAttrs(
65     const Graph& graph,
66     const std::unordered_map<string, std::vector<int32>> wait_for_map) {
67   for (const Node* node : graph.nodes()) {
68     if (node->IsCollective() ||
69         wait_for_map.find(node->name()) == wait_for_map.end()) {
70       continue;
71     }
72     std::vector<int32> wait_for_actual;
73     TF_EXPECT_OK(GetNodeAttr(node->attrs(), "wait_for", &wait_for_actual));
74     auto wait_for_expected = wait_for_map.at(node->name());
75     EXPECT_THAT(wait_for_actual, UnorderedElementsAreArray(wait_for_expected));
76   }
77 }
78 
CollectiveReduceNode(GraphDefBuilder * builder,Node * input,const string & name,const string & device,int instance_key)79 Node* CollectiveReduceNode(GraphDefBuilder* builder, Node* input,
80                            const string& name, const string& device,
81                            int instance_key) {
82   Node* collective_node =
83       ops::UnaryOp("CollectiveReduce", input,
84                    builder->opts()
85                        .WithName(name)
86                        .WithDevice(device)
87                        .WithAttr("T", DT_FLOAT)
88                        .WithAttr("group_size", 2)
89                        .WithAttr("group_key", 1)
90                        .WithAttr("instance_key", instance_key)
91                        .WithAttr("merge_op", "Add")
92                        .WithAttr("final_op", "Id")
93                        .WithAttr("subdiv_offsets", {1}));
94   return collective_node;
95 }
96 
97 // Initialize the following graph:
98 //
99 //       (cpu0) (cpu1)
100 //         a      b
101 //         |      |
102 //         c1     c1
103 //         |      |
104 //         id     id
105 //        /  \   /  \
106 //       c2  c3 c2  c3
107 //
108 // Here ci denotes a collective node with `instance_key` i.  `a` and `b` are
109 // inputs, `id` is identity node.
InitGraph()110 std::unique_ptr<Graph> InitGraph() {
111   GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
112   const string dev0 = "/job:localhost/replica:0/task:0/device:CPU:0";
113   const string dev1 = "/job:localhost/replica:0/task:0/device:CPU:1";
114   Node* a = ops::SourceOp("TestParams",
115                           builder.opts().WithName("a").WithDevice(dev0));
116   Node* b = ops::SourceOp("TestParams",
117                           builder.opts().WithName("b").WithDevice(dev1));
118   Node* c1_0 = CollectiveReduceNode(&builder, a, "c1_0", dev0, 1);
119   Node* c1_1 = CollectiveReduceNode(&builder, b, "c1_1", dev1, 1);
120   Node* id0 = ops::UnaryOp(
121       "Identity", c1_0,
122       builder.opts().WithName("id0").WithDevice(dev0).WithAttr("T", DT_FLOAT));
123   Node* id1 = ops::UnaryOp(
124       "Identity", c1_1,
125       builder.opts().WithName("id1").WithDevice(dev1).WithAttr("T", DT_FLOAT));
126   CollectiveReduceNode(&builder, id0, "c2_0", dev0, 2);
127   CollectiveReduceNode(&builder, id1, "c2_1", dev1, 2);
128   CollectiveReduceNode(&builder, id0, "c3_0", dev0, 3);
129   CollectiveReduceNode(&builder, id1, "c3_1", dev1, 3);
130 
131   std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
132   Status s = GraphDefBuilderToGraph(builder, graph.get());
133   if (!s.ok()) {
134     LOG(FATAL) << "Error building graph " << s;
135   }
136   return graph;
137 }
138 
139 // Tests that in the graph created by `InitGraph`, exactly 2 control edges are
140 // added after calling `OrderCollectives`: c3_0 -> c2_0 and c3_1 -> c2_1.
TEST(CollectiveOrderTest,SimpleOrder)141 TEST(CollectiveOrderTest, SimpleOrder) {
142   std::unique_ptr<Graph> graph = InitGraph();
143   TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
144   VerifyGraph(*graph, {"c1_0", "c1_1", "c2_0", "c2_1", "c3_0", "c3_1"},
145               {{"c3_0", "c2_0"}, {"c3_1", "c2_1"}});
146 }
147 
// Orders the `InitGraph` graph with `kAttrs`. c2_0 and c2_1 must each carry a
// `wait_for` attribute listing exactly instance key 3.
TEST(CollectiveOrderTest, SimpleOrderAttr) {
  auto graph = InitGraph();
  TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
  const std::unordered_map<string, std::vector<int32>> expected_wait_for = {
      {"c2_0", {3}}, {"c2_1", {3}}};
  VerifyAttrs(*graph, expected_wait_for);
}
153 
154 // Initialize the following graph:
155 //
156 //         a
157 //         |
158 //         c1
159 //        /  \
160 //       c4  id
161 //          /  \
162 //         c2  c3
163 //
164 // Here ci denotes a collective node with `instance_key` i.  `a` is an input,
165 // `id` is identity node.
InitGraph2()166 std::unique_ptr<Graph> InitGraph2() {
167   GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
168   const string dev0 = "/job:localhost/replica:0/task:0/device:CPU:0";
169   Node* a = ops::SourceOp("TestParams",
170                           builder.opts().WithName("a").WithDevice(dev0));
171   Node* c1 = CollectiveReduceNode(&builder, a, "c1", dev0, 1);
172   CollectiveReduceNode(&builder, c1, "c4", dev0, 4);
173   Node* id = ops::UnaryOp(
174       "Identity", c1,
175       builder.opts().WithName("id").WithDevice(dev0).WithAttr("T", DT_FLOAT));
176   CollectiveReduceNode(&builder, id, "c2", dev0, 2);
177   CollectiveReduceNode(&builder, id, "c3", dev0, 3);
178 
179   std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
180   Status s = GraphDefBuilderToGraph(builder, graph.get());
181   if (!s.ok()) {
182     LOG(FATAL) << "Error building graph " << s;
183   }
184   return graph;
185 }
186 
187 // Tests that in the graph created by `InitGraph2`, we add the following control
188 // edges after calling `OrderCollectives`: c4 -> c3, c3 -> c2.  c4->c2 is
189 // pruned because it follows from the other two edges.
TEST(CollectiveOrderTest,SimpleOrder2)190 TEST(CollectiveOrderTest, SimpleOrder2) {
191   std::unique_ptr<Graph> graph = InitGraph2();
192   TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kEdges));
193   VerifyGraph(*graph, {"c1", "c2", "c3", "c4"}, {{"c4", "c3"}, {"c3", "c2"}});
194 }
195 
196 // Initialize the following graph:
197 //
198 //         w   x   y   z
199 //         |   |   |   |
200 //         c1  c2  c3  c4
201 //
InitGraphForPruning()202 std::unique_ptr<Graph> InitGraphForPruning() {
203   GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
204   const string dev0 = "/job:localhost/replica:0/task:0/device:CPU:0";
205   Node* w = ops::SourceOp("TestParams",
206                           builder.opts().WithName("w").WithDevice(dev0));
207   Node* x = ops::SourceOp("TestParams",
208                           builder.opts().WithName("x").WithDevice(dev0));
209   Node* y = ops::SourceOp("TestParams",
210                           builder.opts().WithName("y").WithDevice(dev0));
211   Node* z = ops::SourceOp("TestParams",
212                           builder.opts().WithName("z").WithDevice(dev0));
213   CollectiveReduceNode(&builder, w, "c1", dev0, 1);
214   CollectiveReduceNode(&builder, x, "c2", dev0, 2);
215   CollectiveReduceNode(&builder, y, "c3", dev0, 3);
216   CollectiveReduceNode(&builder, z, "c4", dev0, 4);
217 
218   std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
219   Status s = GraphDefBuilderToGraph(builder, graph.get());
220   if (!s.ok()) {
221     LOG(FATAL) << "Error building graph " << s;
222   }
223   return graph;
224 }
225 
226 // Tests that in the graph created by `InitGraphForPruning`, we only add c4 ->
227 // c3, c3 -> c2, c2 -> c1, and other edges are pruned away.
TEST(CollectiveOrderTest,Pruning)228 TEST(CollectiveOrderTest, Pruning) {
229   std::unique_ptr<Graph> graph = InitGraphForPruning();
230   TF_EXPECT_OK(OrderCollectives(graph.get(), GraphCollectiveOrder::kAttrs));
231   VerifyAttrs(*graph, {{"c3", {4}}, {"c2", {3}}, {"c1", {2}}});
232 }
233 
234 }  // namespace
235 }  // namespace tensorflow
236