/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/strings/str_util.h"
#if GOOGLE_CUDA

#include <forward_list>
#include <vector>

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/node_builder.h"

namespace tensorflow {
namespace {

// Replaces a NcclReduce node with a _NcclReduceRecv that reuses one input on
// the same device, and adds one _NcclReduceSend for each remaining input.
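// For example, a hypothetical two-GPU reduction
//   in0 (gpu:0) --\
//                  NcclReduce (gpu:0) --> consumers
//   in1 (gpu:1) --/
// becomes roughly
//   in0 (gpu:0) --> _NcclReduceRecv (gpu:0) --> consumers
//   in1 (gpu:1) --> _NcclReduceSend (gpu:1)  (control edge to consumers)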
Status ReplaceReduce(Graph* graph, Node* node) {
  string reduction;
  TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "reduction", &reduction));
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
  int num_devices = node->num_inputs();
  string shared_name = node->name();
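  // All replacement ops carry the same shared_name and num_devices attrs,
  // which is what lets the NCCL send/recv kernels find each other and form a
  // single collective at run time.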
  auto make_builder = [&](StringPiece op_name, StringPiece suffix) {
    return NodeBuilder(strings::StrCat(shared_name, suffix), op_name)
        .Attr("reduction", reduction)
        .Attr("num_devices", num_devices)
        .Attr("shared_name", shared_name)
        .Attr("T", dtype);
  };
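  // Remember the original node's control inputs and all of its consumers so
  // they can be reattached to the replacement send/recv nodes below.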
  std::vector<Node*> control_inputs;
  for (const auto& edge : node->in_edges()) {
    if (edge->IsControlEdge()) {
      control_inputs.push_back(edge->src());
    }
  }
  std::vector<NodeBuilder::NodeOut> out_nodes;
  for (const auto& edge : node->out_edges()) {
    out_nodes.emplace_back(edge->dst(), edge->dst_input());
  }
  int recv_dev = node->assigned_device_name_index();
  NodeBuilder recv_builder =
      make_builder("_NcclReduceRecv", "Recv").ControlInputs(control_inputs);
  bool recv_input_set = false;
  int send_counter = 0;
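  // The first input that already lives on the recv device feeds the
  // _NcclReduceRecv directly; every other input gets its own _NcclReduceSend
  // placed on that input's device.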
  for (const auto& edge : node->in_edges()) {
    Node* src_node = edge->src();
    if (edge->IsControlEdge()) {
      continue;
    }
    int send_dev = src_node->assigned_device_name_index();
    if (!recv_input_set && send_dev == recv_dev) {
      recv_builder.Input(src_node);
      recv_input_set = true;
      continue;
    }
    auto send_builder = make_builder("_NcclReduceSend",
                                     strings::StrCat("Send_", ++send_counter))
                            .Input(src_node)
                            .ControlInputs(control_inputs);
    Node* send_node = nullptr;
    TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node));
    send_node->set_assigned_device_name_index(send_dev);
    // Send nodes don't have any outputs and therefore no data dependencies on
    // the rest of the graph. We add control edges from each send node to the
    // consumers of the result so that these 'dangling' nodes are run.
    // TODO(b/67027412): Avoid these cross-device control edges.
    for (const auto& out_node : out_nodes) {
      graph->AddControlEdge(send_node, out_node.node);
    }
  }
  if (!recv_input_set) {
    return errors::InvalidArgument(
        "No input tensor uses the same device as the NcclReduce op");
  }
  Node* recv_node = nullptr;
  TF_RETURN_IF_ERROR(recv_builder.Finalize(graph, &recv_node));
  recv_node->set_assigned_device_name_index(recv_dev);
  graph->RemoveNode(node);
  for (const auto& out_node : out_nodes) {
    if (out_node.index == Graph::kControlSlot) {
      graph->AddControlEdge(recv_node, out_node.node);
    } else {
      graph->AddEdge(recv_node, 0, out_node.node, out_node.index);
    }
  }
  return OkStatus();
}

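// Encodes a TensorShapeProto as a 1-D DT_INT32 tensor of dimension sizes;
// e.g. a fully defined shape [2, 3] becomes an int32 tensor of shape [2]
// holding the values {2, 3}.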
TensorProto TensorFromShape(const TensorShapeProto& shape) {
  TensorProto result;
  result.set_dtype(DT_INT32);
  for (const auto& dim : shape.dim()) {
    result.add_int_val(dim.size());
  }
  result.mutable_tensor_shape()->add_dim()->set_size(shape.dim_size());
  return result;
}

// Replaces a NcclBroadcast node with a _NcclBroadcastSend, connects the input
// directly to all outputs on the same device, and adds one _NcclBroadcastRecv
// for each other output device.
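// For example, a hypothetical broadcast from gpu:0 to consumers on gpu:0 and
// gpu:1
//   in (gpu:0) --> NcclBroadcast (gpu:0) --> out0 (gpu:0), out1 (gpu:1)
// becomes roughly
//   in (gpu:0) --> out0 (gpu:0)
//       \--------> _NcclBroadcastSend (gpu:0)
//   shape ------> _NcclBroadcastRecv (gpu:1) --> out1 (gpu:1)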
Status ReplaceBroadcast(Graph* graph, Node* node) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype));
  int send_dev = node->assigned_device_name_index();
  int num_devices = 0;  // Number of distinct devices, incremented below.
  std::vector<int> recv_index_map;  // Map device name index to stable index.

  // Map device name index to nodes that take the broadcast as input.
  std::vector<std::forward_list<NodeBuilder::NodeOut>> out_nodes_map;
  for (const auto& edge : node->out_edges()) {
    int dst_dev = edge->IsControlEdge()
                      ? send_dev
                      : edge->dst()->assigned_device_name_index();
    if (out_nodes_map.size() <= dst_dev) {
      out_nodes_map.resize(dst_dev + 1);
      recv_index_map.resize(dst_dev + 1);
    }
    auto it = out_nodes_map.begin() + dst_dev;
    if (it->empty()) {
      recv_index_map[dst_dev] = num_devices;
      ++num_devices;
    }
    it->emplace_front(NodeBuilder::NodeOut(edge->dst(), edge->dst_input()));
  }

  if (num_devices <= 1) {
    // Only one participating device, skip NCCL op.
    const Edge* in_edge = nullptr;
    TF_RETURN_IF_ERROR(node->input_edge(0, &in_edge));
    Node* in_node = in_edge->src();
    int in_index = in_edge->src_output();
    graph->RemoveNode(node);
    for (const auto& out_nodes : out_nodes_map) {
      for (const auto& out_node : out_nodes) {
        if (out_node.index == Graph::kControlSlot) {
          graph->AddControlEdge(in_node, out_node.node);
        } else {
          graph->AddEdge(in_node, in_index, out_node.node, out_node.index);
        }
      }
    }
    return OkStatus();
  }

  string shared_name = node->name();
  auto make_builder = [&](StringPiece op_name, StringPiece suffix) {
    return NodeBuilder(strings::StrCat(shared_name, suffix), op_name)
        .Attr("num_devices", num_devices)
        .Attr("shared_name", shared_name)
        .Attr("T", dtype);
  };

  // Create broadcast send node and replace the original broadcast node.
  NodeBuilder::NodeOut in_node;
  NodeBuilder send_builder = make_builder("_NcclBroadcastSend", "Send");
  for (const auto& edge : node->in_edges()) {
    if (edge->IsControlEdge()) {
      send_builder.ControlInput(edge->src());
    } else {
      in_node = NodeBuilder::NodeOut(edge->src(), edge->src_output());
      send_builder.Input(in_node);
    }
  }
  Node* send_node = nullptr;
  TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node));
  send_node->set_assigned_device_name_index(send_dev);

  TensorShapeProto shape_proto;
  TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "shape", &shape_proto));

  // Delete the original node before reconnecting to outputs.
  graph->RemoveNode(node);

  // Connect all outputs on the device of broadcast send.
  for (const auto& out_node : out_nodes_map[send_dev]) {
    if (out_node.index == Graph::kControlSlot) {
      graph->AddControlEdge(send_node, out_node.node);
    } else {
      graph->AddEdge(in_node.node, in_node.index, out_node.node,
                     out_node.index);
      // Add control edge so send node is run.
      graph->AddControlEdge(send_node, out_node.node);
    }
  }
  out_nodes_map[send_dev].clear();

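  // Receivers get no data edge from the sender, so each _NcclBroadcastRecv
  // takes the output shape as an explicit input: a per-device Const if the
  // shape is fully defined, otherwise a Shape op on the send device.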
  TensorProto tensor_proto = TensorFromShape(shape_proto);
  bool is_fully_defined = TensorShape(shape_proto).IsFullyDefined();
  string shape_name = strings::StrCat(in_node.node->name(), "/Shape");
  Node* shape_node = nullptr;
  if (!is_fully_defined) {
    NodeBuilder shape_builder(shape_name, "Shape");
    shape_builder.Input(in_node).Attr("out_type", DT_INT32).Attr("T", dtype);
    TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node));
    shape_node->set_assigned_device_name_index(send_dev);
  }

  // For all other devices, create a broadcast receive and connect outputs.
  for (int recv_dev = 0; recv_dev < out_nodes_map.size(); ++recv_dev) {
    if (out_nodes_map[recv_dev].empty()) {
      continue;
    }
    int recv_index = recv_index_map[recv_dev];
    if (is_fully_defined) {
      // If the shape is fully defined, define one const node per device.
      NodeBuilder shape_builder(strings::StrCat(shape_name, recv_index),
                                "Const");
      shape_builder.Attr("value", tensor_proto).Attr("dtype", DT_INT32);
      TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node));
      shape_node->set_assigned_device_name_index(recv_dev);
    }
    Node* recv_node;
    TF_RETURN_IF_ERROR(
        make_builder("_NcclBroadcastRecv", strings::StrCat("Recv_", recv_index))
            .Input(shape_node)
            .Finalize(graph, &recv_node));
    recv_node->set_assigned_device_name_index(recv_dev);
    for (const auto& out_node : out_nodes_map[recv_dev]) {
      graph->AddEdge(recv_node, 0, out_node.node, out_node.index);
    }
  }

  return OkStatus();
}

// Replaces occurrences of NcclReduce and NcclBroadcast with their
// _Nccl{Reduce, Broadcast}{Send, Recv} counterparts and removes data
// dependencies between them.
class NcclReplacePass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    if (options.graph == nullptr) {
      return OkStatus();
    }
    Graph* graph = options.graph->get();
    if (graph == nullptr) {
      return errors::Internal(
          "NCCL replacement should happen before partitioning and a "
          "graph should be available.");
    }
    // Find reduction and broadcast ops and replace them with Send/Recv ops.
    for (Node* node : graph->op_nodes()) {
      StringPiece type = node->type_string();
      if (!absl::StartsWith(type, "Nccl")) {
        continue;
      }
      if (type == "NcclReduce") {
        TF_RETURN_IF_ERROR(ReplaceReduce(graph, node));
      }
      if (type == "NcclBroadcast") {
        TF_RETURN_IF_ERROR(ReplaceBroadcast(graph, node));
      }
    }
    return OkStatus();
  }
};
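// Registered at POST_PLACEMENT so that devices have already been assigned to
// nodes (the rewrite relies on assigned_device_name_index) but the graph has
// not yet been partitioned across devices.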
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0,
                      NcclReplacePass);

}  // namespace
}  // namespace tensorflow

#endif  // GOOGLE_CUDA