/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"

#include <queue>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace {

26 // A helper for rewriting nodes assigned to a virtual composite device.
27 class ReplicateHelper {
28 public:
29 // Initialize replicated nodes with nullptr.
InitializeNode(const Node * node,int num_allowed_devices)30 Status InitializeNode(const Node* node, int num_allowed_devices) {
31 if (replicated_nodes_map_.find(node) != replicated_nodes_map_.end()) {
32 return errors::InvalidArgument("Node ", node->name(),
33 " has been replicated.");
34 }
35 std::vector<Node*> replicated_nodes(num_allowed_devices, nullptr);
36 replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
37 return Status::OK();
38 }
39
40 // Replicate the given node to an allowed device.
ReplicateNode(const Node * node,const std::vector<string> & allowed_devices,int allowed_device_index,Graph * graph)41 Status ReplicateNode(const Node* node,
42 const std::vector<string>& allowed_devices,
43 int allowed_device_index, Graph* graph) {
44 auto& replicated_nodes = replicated_nodes_map_.at(node);
45 if (replicated_nodes[allowed_device_index] != nullptr) {
46 return Status::OK();
47 }
48 const auto& device = allowed_devices.at(allowed_device_index);
49 NodeDef node_def = node->def();
50 const string suffix = strings::StrCat("/R", allowed_device_index);
51 node_def.set_name(graph->NewName(strings::StrCat(node_def.name(), suffix)));
52 Status status;
53 Node* replicated_node = graph->AddNode(node_def, &status);
54 TF_RETURN_IF_ERROR(status);
55 replicated_node->set_assigned_device_name(device);
56 if (replicated_node->IsArg()) {
57 replicated_node->AddAttr("sub_index", allowed_device_index);
58 }
59 replicated_nodes[allowed_device_index] = replicated_node;
60 return Status::OK();
61 }
62
63 // Replace an edge (a regular device -> composite device) with
64 // N edges (a regular device -> allowed devices).
ReplicateFromRegularDeviceToCompositeDevice(const Edge * edge,Graph * graph) const65 void ReplicateFromRegularDeviceToCompositeDevice(const Edge* edge,
66 Graph* graph) const {
67 Node* src = edge->src();
68 const std::vector<Node*>& dst_replicated_nodes =
69 replicated_nodes_map_.at(edge->dst());
70 for (Node* dst : dst_replicated_nodes) {
71 // Skip a replicated dst node without any consumer.
72 if (dst == nullptr) {
73 continue;
74 }
75 graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
76 }
77 }
78
79 // Replace an edge (composite device -> composite device) with
80 // N edges (allowed devices -> allowed devices).
ReplicateFromCompositeDeviceToCompositeDevice(const Edge * edge,const std::vector<string> & allowed_devices,Graph * graph)81 Status ReplicateFromCompositeDeviceToCompositeDevice(
82 const Edge* edge, const std::vector<string>& allowed_devices,
83 Graph* graph) {
84 const std::vector<Node*>& src_replicated_nodes =
85 replicated_nodes_map_.at(edge->src());
86 const std::vector<Node*>& dst_replicated_nodes =
87 replicated_nodes_map_.at(edge->dst());
88 if (src_replicated_nodes.size() != dst_replicated_nodes.size()) {
89 return errors::InvalidArgument(
90 "Nodes assigned to the same composite device should have the "
91 "same number of replicated nodes. Found an edge from node ",
92 edge->src()->name(), " (", src_replicated_nodes.size(),
93 " replicated nodes) to node ", edge->dst()->name(), " (",
94 dst_replicated_nodes.size(), " replicated nodes).");
95 }
96 for (int i = 0; i < src_replicated_nodes.size(); ++i) {
97 Node* dst = dst_replicated_nodes.at(i);
98 // Skip a replicated dst node without any consumer.
99 if (dst == nullptr) {
100 continue;
101 }
102 TF_RETURN_IF_ERROR(ReplicateNode(edge->src(), allowed_devices, i, graph));
103 graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
104 edge->dst_input());
105 }
106 return Status::OK();
107 }
108
109 // Data edge: replace an edge (composite device -> a regular device) with
110 // one edge (one allowed device -> a regular device).
111 // Control edge: replace an edge (composite device -> a regular device) with
112 // N edges (allowed devices -> a regular device).
ReplicateFromCompositeDeviceToRegularDevice(const Edge * edge,const std::vector<string> & allowed_devices,Graph * graph)113 Status ReplicateFromCompositeDeviceToRegularDevice(
114 const Edge* edge, const std::vector<string>& allowed_devices,
115 Graph* graph) {
116 const std::vector<Node*>& src_replicated_nodes =
117 replicated_nodes_map_.at(edge->src());
118 Node* dst = edge->dst();
119 const string& dst_device = dst->assigned_device_name();
120 bool found_src_node = false;
121 for (int i = 0; i < allowed_devices.size(); ++i) {
122 if (allowed_devices.at(i) == dst_device) {
123 TF_RETURN_IF_ERROR(
124 ReplicateNode(edge->src(), allowed_devices, i, graph));
125 graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
126 edge->dst_input());
127 found_src_node = true;
128 break;
129 }
130 }
131 if (!found_src_node) {
132 for (int i = 0; i < allowed_devices.size(); ++i) {
133 TF_RETURN_IF_ERROR(
134 ReplicateNode(edge->src(), allowed_devices, i, graph));
135 }
136 if (edge->IsControlEdge()) {
137 for (Node* replicated_node : src_replicated_nodes) {
138 // Duplication check in `Graph::AddControlEdge` is expensive for the
139 // dst node with a lot of input edges. Here each (src, dst) pair
140 // will only occur once so it is safe to skip the duplication check.
141 graph->AddControlEdge(replicated_node, dst,
142 /*allow_duplicates=*/true);
143 }
144 return Status::OK();
145 }
146 if (edge->src()->type_string() == "_Arg") {
147 // This happens when the dst node runs on a host CPU and
148 // captures a function with an arg node assigned to the same
149 // composite device (e.g. ScanDataset).
150 // For this case, we insert a PackOp between replicated nodes and the
151 // dst node. The dst node is responsible for unpacking the packed
152 // tensor.
153 // Add '/Packed' as a substring to the name of the new node, which
154 // could be helpful when debugging the graph.
155 NodeDefBuilder pack_builder(
156 graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
157 "Pack");
158 const int num_replicas = src_replicated_nodes.size();
159 pack_builder.Attr("N", num_replicas);
160 const DataType dtype = edge->src()->output_type(edge->src_output());
161 pack_builder.Attr("T", dtype);
162 std::vector<NodeDefBuilder::NodeOut> inputs;
163 inputs.reserve(src_replicated_nodes.size());
164 for (Node* replicated_node : src_replicated_nodes) {
165 inputs.emplace_back(NodeDefBuilder::NodeOut{
166 replicated_node->name(), edge->src_output(), dtype});
167 }
168 pack_builder.Input(inputs);
169 NodeDef pack_def;
170 TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
171 Status status;
172 Node* pack_node = graph->AddNode(pack_def, &status);
173 TF_RETURN_IF_ERROR(status);
174 pack_node->set_assigned_device_name(dst->assigned_device_name());
175 for (int i = 0; i < src_replicated_nodes.size(); ++i) {
176 graph->AddEdge(src_replicated_nodes[i], edge->src_output(), pack_node,
177 i);
178 }
179 graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
180 } else {
181 return errors::InvalidArgument(
182 "Dst node should be assigned to an allowed device. Found an "
183 "edge from node ",
184 edge->src()->name(), " assigned to ",
185 edge->src()->assigned_device_name(), " to node ", dst->name(),
186 " assigned to ", dst_device);
187 }
188 }
189 return Status::OK();
190 }
191
192 private:
193 // Map from original nodes to corresponding replicated nodes.
194 absl::flat_hash_map<const Node*, std::vector<Node*>> replicated_nodes_map_;
195 };
197 // Replicate the nodes in cluster_nodes and update edges.
ReplicateNodesAndEdges(const std::vector<string> & allowed_devices,absl::flat_hash_map<Node *,int> * cluster_nodes,ReplicateHelper * helper,Graph * graph)198 Status ReplicateNodesAndEdges(const std::vector<string>& allowed_devices,
199 absl::flat_hash_map<Node*, int>* cluster_nodes,
200 ReplicateHelper* helper, Graph* graph) {
201 // Contains nodes in cluster_nodes whose out nodes are all on physical
202 // devices.
203 std::queue<Node*> nodes_ready_to_delete;
204 for (auto& pair : *cluster_nodes) {
205 Node* node = pair.first;
206 for (const Edge* edge : node->out_edges()) {
207 Node* dst = edge->dst();
208 if (dst->assigned_device_name() != node->assigned_device_name()) {
209 // The dst node is assigned to a different device.
210 TF_RETURN_IF_ERROR(helper->ReplicateFromCompositeDeviceToRegularDevice(
211 edge, allowed_devices, graph));
212 --pair.second;
213 }
214 }
215 // Node is ready to delete when all its comsumer nodes are assigned to a
216 // physical device.
217 if (cluster_nodes->at(node) == 0) {
218 nodes_ready_to_delete.push(node);
219 }
220 }
221
222 while (!nodes_ready_to_delete.empty()) {
223 Node* node = nodes_ready_to_delete.front();
224 nodes_ready_to_delete.pop();
225
226 // Update input edges.
227 for (const Edge* edge : node->in_edges()) {
228 Node* src = edge->src();
229 if (src->assigned_device_name() != node->assigned_device_name()) {
230 // The source node is assigned to a different device.
231 helper->ReplicateFromRegularDeviceToCompositeDevice(edge, graph);
232 } else {
233 // The source node is assigned to the same composite device.
234 TF_RETURN_IF_ERROR(
235 helper->ReplicateFromCompositeDeviceToCompositeDevice(
236 edge, allowed_devices, graph));
237 if (--(*cluster_nodes)[src] == 0) {
238 nodes_ready_to_delete.push(src);
239 }
240 }
241 }
242
243 // Remove the original node.
244 cluster_nodes->erase(node);
245 graph->RemoveNode(node);
246 }
247 return Status::OK();
248 }

}  // namespace

ReplicatePerReplicaNodesInFunctionGraph(const absl::flat_hash_map<string,const std::vector<string> * > & composite_devices,Graph * graph)252 Status ReplicatePerReplicaNodesInFunctionGraph(
253 const absl::flat_hash_map<string, const std::vector<string>*>&
254 composite_devices,
255 Graph* graph) {
256 std::set<string> composite_device_names;
257 for (const auto& it : composite_devices) {
258 composite_device_names.insert(it.first);
259 }
260 // Map from a composite device to a cluster of nodes assigned to the
261 // composite device and the numbers of their out edges to process.
262 absl::flat_hash_map<string, absl::flat_hash_map<Node*, int>>
263 composite_device_to_cluster_nodes;
264 for (Node* n : graph->op_nodes()) {
265 if (composite_device_names.find(n->assigned_device_name()) !=
266 composite_device_names.end()) {
267 // TODO(b/145922293): Validate that an _Arg node assigned to a
268 // CompositeDevice should have an attribute indicating that the _Arg node
269 // represents a packed input.
270 composite_device_to_cluster_nodes[n->assigned_device_name()].emplace(
271 n, n->out_edges().size());
272 }
273 }
274
275 for (auto& it : composite_device_to_cluster_nodes) {
276 const std::vector<string>& allowed_devices =
277 *composite_devices.at(it.first);
278 if (allowed_devices.empty()) {
279 return errors::InvalidArgument("No allowed device of composite device: ",
280 it.first);
281 }
282 absl::flat_hash_map<Node*, int>& cluster_nodes = it.second;
283 if (allowed_devices.size() == 1) {
284 // Reuse the original nodes if there is only one allowed device.
285 for (const auto& pair : it.second) {
286 Node* n = pair.first;
287 n->set_assigned_device_name(allowed_devices.at(0));
288 if (n->IsArg()) {
289 n->AddAttr("sub_index", 0);
290 }
291 }
292 continue;
293 }
294 ReplicateHelper helper;
295 for (const auto& pair : cluster_nodes) {
296 TF_RETURN_IF_ERROR(
297 helper.InitializeNode(pair.first, allowed_devices.size()));
298 }
299
300 TF_RETURN_IF_ERROR(ReplicateNodesAndEdges(allowed_devices, &cluster_nodes,
301 &helper, graph));
302
303 if (!cluster_nodes.empty()) {
304 return errors::InvalidArgument(
305 "There are still ", cluster_nodes.size(),
306 " nodes on CompositiveDevice ",
307 cluster_nodes.begin()->first->assigned_device_name());
308 }
309 }
310 return Status::OK();
311 }

}  // namespace tensorflow