/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/strings/str_util.h"
#if GOOGLE_CUDA

#include <forward_list>
#include <vector>

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/node_builder.h"

namespace tensorflow {
namespace {
// Replaces a NcclReduce node with a _NcclReduceRecv that reuses one input on
// the same device, and adds one _NcclReduceSend for each remaining input.
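// For example, a two-device reduce (hypothetical placement for illustration)
//   in_0 (gpu:0), in_1 (gpu:1) -> NcclReduce (gpu:0) -> out
// is rewritten to
//   in_1 (gpu:1) -> _NcclReduceSend (gpu:1)
//   in_0 (gpu:0) -> _NcclReduceRecv (gpu:0) -> out
// plus a control edge _NcclReduceSend -> out so the send is executed.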
Status ReplaceReduce(Graph* graph, Node* node) {
  string reduction;
  TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "reduction", &reduction));
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
  int num_devices = node->num_inputs();
  string shared_name = node->name();
  auto make_builder = [&](StringPiece op_name, StringPiece suffix) {
    return NodeBuilder(strings::StrCat(shared_name, suffix), op_name)
        .Attr("reduction", reduction)
        .Attr("num_devices", num_devices)
        .Attr("shared_name", shared_name)
        .Attr("T", dtype);
  };
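  // Collect the control inputs of the original node; they are re-attached to
  // the receive node and to every send node created below.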
  std::vector<Node*> control_inputs;
  for (const auto& edge : node->in_edges()) {
    if (edge->IsControlEdge()) {
      control_inputs.push_back(edge->src());
    }
  }
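  // Record the destination of every out edge now, since the original node
  // (and with it its edge list) is removed before the outputs are rewired.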
  std::vector<NodeBuilder::NodeOut> out_nodes;
  for (const auto& edge : node->out_edges()) {
    out_nodes.emplace_back(edge->dst(), edge->dst_input());
  }
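  // The receive node is placed on the device of the original reduce node and
  // produces the reduction result there.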
  int recv_dev = node->assigned_device_name_index();
  NodeBuilder recv_builder =
      make_builder("_NcclReduceRecv", "Recv").ControlInputs(control_inputs);
  bool recv_input_set = false;
  int send_counter = 0;
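  // Feed exactly one input on the receive device directly into the receive
  // node; every other input goes through its own send node on its device.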
  for (const auto& edge : node->in_edges()) {
    Node* src_node = edge->src();
    if (edge->IsControlEdge()) {
      continue;
    }
    int send_dev = src_node->assigned_device_name_index();
    if (!recv_input_set && send_dev == recv_dev) {
      recv_builder.Input(src_node);
      recv_input_set = true;
      continue;
    }
    auto send_builder = make_builder("_NcclReduceSend",
                                     strings::StrCat("Send_", ++send_counter))
                            .Input(src_node)
                            .ControlInputs(control_inputs);
    Node* send_node = nullptr;
    TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node));
    send_node->set_assigned_device_name_index(send_dev);
    // Send nodes don't have any outputs and therefore have no data
    // dependencies to the outputs of the graph. We add control edges from
    // each send node to the consumers of the reduce output so that these
    // 'dangling' nodes are run.
    // TODO(b/67027412): Avoid these cross-device control edges.
    for (const auto& out_node : out_nodes) {
      graph->AddControlEdge(send_node, out_node.node);
    }
  }
  if (!recv_input_set) {
    return errors::InvalidArgument(
        "No input tensor uses the same device as the NcclReduce op");
  }
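  // Finalize the receive node and splice it in where the original reduce
  // node used to be, preserving control and data out edges.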
  Node* recv_node = nullptr;
  TF_RETURN_IF_ERROR(recv_builder.Finalize(graph, &recv_node));
  recv_node->set_assigned_device_name_index(recv_dev);
  graph->RemoveNode(node);
  for (const auto& out_node : out_nodes) {
    if (out_node.index == Graph::kControlSlot) {
      graph->AddControlEdge(recv_node, out_node.node);
    } else {
      graph->AddEdge(recv_node, 0, out_node.node, out_node.index);
    }
  }
  return OkStatus();
}

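// Builds a 1-D DT_INT32 tensor proto holding the dimension sizes of `shape`,
// i.e. the value a Shape op would produce for a tensor of that shape.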
TensorProto TensorFromShape(const TensorShapeProto& shape) {
  TensorProto result;
  result.set_dtype(DT_INT32);
  for (const auto& dim : shape.dim()) {
    result.add_int_val(dim.size());
  }
  result.mutable_tensor_shape()->add_dim()->set_size(shape.dim_size());
  return result;
}

// Replaces a NcclBroadcast node with a _NcclBroadcastSend, connects the input
// directly to all outputs on the same device, and adds one _NcclBroadcastRecv
// for each other output device.
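// For example, a broadcast consumed on two devices (hypothetical placement
// for illustration)
//   in (gpu:0) -> NcclBroadcast (gpu:0) -> out_0 (gpu:0), out_1 (gpu:1)
// is rewritten to
//   in (gpu:0) -> _NcclBroadcastSend (gpu:0)
//   in (gpu:0) -> out_0 (gpu:0)   (plus a control edge from the send)
//   shape -> _NcclBroadcastRecv (gpu:1) -> out_1 (gpu:1)
// where `shape` is a per-device Const if statically known, otherwise a Shape
// op on the send device.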
Status ReplaceBroadcast(Graph* graph, Node* node) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
  int send_dev = node->assigned_device_name_index();
  int num_devices = 0;  // Number of distinct devices, incremented below.
  std::vector<int> recv_index_map;  // Map device name index to stable index.

  // Map device name index to nodes that take the broadcast as input.
  std::vector<std::forward_list<NodeBuilder::NodeOut>> out_nodes_map;
  for (const auto& edge : node->out_edges()) {
    int dst_dev = edge->IsControlEdge()
                      ? send_dev
                      : edge->dst()->assigned_device_name_index();
    if (out_nodes_map.size() <= dst_dev) {
      out_nodes_map.resize(dst_dev + 1);
      recv_index_map.resize(dst_dev + 1);
    }
    auto it = out_nodes_map.begin() + dst_dev;
    if (it->empty()) {
      recv_index_map[dst_dev] = num_devices;
      ++num_devices;
    }
    it->emplace_front(NodeBuilder::NodeOut(edge->dst(), edge->dst_input()));
  }

  if (num_devices <= 1) {
    // Only one participating device, skip NCCL op.
    const Edge* in_edge = nullptr;
    TF_RETURN_IF_ERROR(node->input_edge(0, &in_edge));
    Node* in_node = in_edge->src();
    int in_index = in_edge->src_output();
    graph->RemoveNode(node);
    for (const auto& out_nodes : out_nodes_map) {
      for (const auto& out_node : out_nodes) {
        if (out_node.index == Graph::kControlSlot) {
          graph->AddControlEdge(in_node, out_node.node);
        } else {
          graph->AddEdge(in_node, in_index, out_node.node, out_node.index);
        }
      }
    }
    return OkStatus();
  }

  string shared_name = node->name();
  auto make_builder = [&](StringPiece op_name, StringPiece suffix) {
    return NodeBuilder(strings::StrCat(shared_name, suffix), op_name)
        .Attr("num_devices", num_devices)
        .Attr("shared_name", shared_name)
        .Attr("T", dtype);
  };

  // Create broadcast send node and replace the original broadcast node.
  NodeBuilder::NodeOut in_node;
  NodeBuilder send_builder = make_builder("_NcclBroadcastSend", "Send");
  for (const auto& edge : node->in_edges()) {
    if (edge->IsControlEdge()) {
      send_builder.ControlInput(edge->src());
    } else {
      in_node = NodeBuilder::NodeOut(edge->src(), edge->src_output());
      send_builder.Input(in_node);
    }
  }
  Node* send_node = nullptr;
  TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node));
  send_node->set_assigned_device_name_index(send_dev);

  TensorShapeProto shape_proto;
  TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "shape", &shape_proto));

  // Delete the original node before reconnecting to outputs.
  graph->RemoveNode(node);

  // Connect all outputs on the device of broadcast send.
  for (const auto& out_node : out_nodes_map[send_dev]) {
    if (out_node.index == Graph::kControlSlot) {
      graph->AddControlEdge(send_node, out_node.node);
    } else {
      graph->AddEdge(in_node.node, in_node.index, out_node.node,
                     out_node.index);
      // Add control edge so send node is run.
      graph->AddControlEdge(send_node, out_node.node);
    }
  }
  out_nodes_map[send_dev].clear();

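  // The receive ops need the broadcast's shape as an input. If the shape is
  // statically known, each receiving device gets its own Const holding it;
  // otherwise a single Shape op on the send device computes it at runtime
  // and feeds all receivers via cross-device edges.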
  TensorProto tensor_proto = TensorFromShape(shape_proto);
  bool is_fully_defined = PartialTensorShape(shape_proto).IsFullyDefined();
  string shape_name = strings::StrCat(in_node.node->name(), "/Shape");
  Node* shape_node = nullptr;
  if (!is_fully_defined) {
    NodeBuilder shape_builder(shape_name, "Shape");
    shape_builder.Input(in_node).Attr("out_type", DT_INT32).Attr("T", dtype);
    TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node));
    shape_node->set_assigned_device_name_index(send_dev);
  }

  // For all other devices, create a broadcast receive and connect outputs.
  for (int recv_dev = 0; recv_dev < out_nodes_map.size(); ++recv_dev) {
    if (out_nodes_map[recv_dev].empty()) {
      continue;
    }
    int recv_index = recv_index_map[recv_dev];
    if (is_fully_defined) {
      // If the shape is fully defined, define one const node per device.
      NodeBuilder shape_builder(strings::StrCat(shape_name, recv_index),
                                "Const");
      shape_builder.Attr("value", tensor_proto).Attr("dtype", DT_INT32);
      TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node));
      shape_node->set_assigned_device_name_index(recv_dev);
    }
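    // The receive node consumes the shape (the per-device Const, or the
    // Shape op's output from the send device) and produces the broadcast
    // value on this device.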
    Node* recv_node;
    TF_RETURN_IF_ERROR(
        make_builder("_NcclBroadcastRecv",
                     strings::StrCat("Recv_", recv_index))
            .Input(shape_node)
            .Finalize(graph, &recv_node));
    recv_node->set_assigned_device_name_index(recv_dev);
    for (const auto& out_node : out_nodes_map[recv_dev]) {
      graph->AddEdge(recv_node, 0, out_node.node, out_node.index);
    }
  }

  return OkStatus();
}

// Replaces occurrences of NcclReduce and NcclBroadcast with their
// _Nccl...Send/_Nccl...Recv counterparts and removes cross-device data
// dependencies between them.
class NcclReplacePass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    if (options.graph == nullptr) {
      return OkStatus();
    }
    Graph* graph = options.graph->get();
    if (graph == nullptr) {
      return errors::Internal(
          "NCCL replacement should happen before partitioning and a "
          "graph should be available.");
    }
    // Find reduction and broadcast ops and replace them with Send/Recv ops.
    for (Node* node : graph->op_nodes()) {
      StringPiece type = node->type_string();
      if (!absl::StartsWith(type, "Nccl")) {
        continue;
      }
      if (type == "NcclReduce") {
        TF_RETURN_IF_ERROR(ReplaceReduce(graph, node));
      }
      if (type == "NcclBroadcast") {
        TF_RETURN_IF_ERROR(ReplaceBroadcast(graph, node));
      }
    }
    return OkStatus();
  }
};
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0,
                      NcclReplacePass);

}  // namespace
}  // namespace tensorflow

#endif  // GOOGLE_CUDA