1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
16
17 #include <queue>
18
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/platform/errors.h"
22
23 namespace tensorflow {
24 namespace {
25
26 // A helper for rewriting nodes assigned to a virtual composite device.
27 class ReplicateHelper {
28 public:
29 // Initialize replicated nodes with nullptr.
InitializeNode(const Node * node,int num_allowed_devices)30 Status InitializeNode(const Node* node, int num_allowed_devices) {
31 if (replicated_nodes_map_.find(node) != replicated_nodes_map_.end()) {
32 return errors::InvalidArgument("Node ", node->name(),
33 " has been replicated.");
34 }
35 std::vector<Node*> replicated_nodes(num_allowed_devices, nullptr);
36 replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
37 return OkStatus();
38 }
39
40 // Replicate the given node to an allowed device.
ReplicateNode(const Node * node,const std::vector<string> & allowed_devices,int allowed_device_index,Graph * graph)41 Status ReplicateNode(const Node* node,
42 const std::vector<string>& allowed_devices,
43 int allowed_device_index, Graph* graph) {
44 auto& replicated_nodes = replicated_nodes_map_.at(node);
45 if (replicated_nodes[allowed_device_index] != nullptr) {
46 return OkStatus();
47 }
48 const auto& device = allowed_devices.at(allowed_device_index);
49 NodeDef node_def = node->def();
50 const string suffix = strings::StrCat("/R", allowed_device_index);
51 node_def.set_name(graph->NewName(strings::StrCat(node_def.name(), suffix)));
52 TF_ASSIGN_OR_RETURN(Node * replicated_node, graph->AddNode(node_def));
53 replicated_node->set_assigned_device_name(device);
54 if (replicated_node->IsArg()) {
55 replicated_node->AddAttr("sub_index", allowed_device_index);
56 }
57 replicated_nodes[allowed_device_index] = replicated_node;
58 return OkStatus();
59 }
60
61 // Replace an edge (a regular device -> composite device) with
62 // N edges (a regular device -> allowed devices).
ReplicateFromRegularDeviceToCompositeDevice(const Edge * edge,Graph * graph) const63 void ReplicateFromRegularDeviceToCompositeDevice(const Edge* edge,
64 Graph* graph) const {
65 Node* src = edge->src();
66 const std::vector<Node*>& dst_replicated_nodes =
67 replicated_nodes_map_.at(edge->dst());
68 for (Node* dst : dst_replicated_nodes) {
69 // Skip a replicated dst node without any consumer.
70 if (dst == nullptr) {
71 continue;
72 }
73 graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
74 }
75 }
76
77 // Replace an edge (composite device -> composite device) with
78 // N edges (allowed devices -> allowed devices).
ReplicateFromCompositeDeviceToCompositeDevice(const Edge * edge,const std::vector<string> & allowed_devices,Graph * graph)79 Status ReplicateFromCompositeDeviceToCompositeDevice(
80 const Edge* edge, const std::vector<string>& allowed_devices,
81 Graph* graph) {
82 const std::vector<Node*>& src_replicated_nodes =
83 replicated_nodes_map_.at(edge->src());
84 const std::vector<Node*>& dst_replicated_nodes =
85 replicated_nodes_map_.at(edge->dst());
86 if (src_replicated_nodes.size() != dst_replicated_nodes.size()) {
87 return errors::InvalidArgument(
88 "Nodes assigned to the same composite device should have the "
89 "same number of replicated nodes. Found an edge from node ",
90 edge->src()->name(), " (", src_replicated_nodes.size(),
91 " replicated nodes) to node ", edge->dst()->name(), " (",
92 dst_replicated_nodes.size(), " replicated nodes).");
93 }
94 for (int i = 0; i < src_replicated_nodes.size(); ++i) {
95 Node* dst = dst_replicated_nodes.at(i);
96 // Skip a replicated dst node without any consumer.
97 if (dst == nullptr) {
98 continue;
99 }
100 TF_RETURN_IF_ERROR(ReplicateNode(edge->src(), allowed_devices, i, graph));
101 graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
102 edge->dst_input());
103 }
104 return OkStatus();
105 }
106
107 // Data edge: replace an edge (composite device -> a regular device) with
108 // one edge (one allowed device -> a regular device).
109 // Control edge: replace an edge (composite device -> a regular device) with
110 // N edges (allowed devices -> a regular device).
ReplicateFromCompositeDeviceToRegularDevice(const Edge * edge,const std::vector<string> & allowed_devices,Graph * graph)111 Status ReplicateFromCompositeDeviceToRegularDevice(
112 const Edge* edge, const std::vector<string>& allowed_devices,
113 Graph* graph) {
114 const std::vector<Node*>& src_replicated_nodes =
115 replicated_nodes_map_.at(edge->src());
116 Node* dst = edge->dst();
117 const string& dst_device = dst->assigned_device_name();
118 bool found_src_node = false;
119 for (int i = 0; i < allowed_devices.size(); ++i) {
120 if (allowed_devices.at(i) == dst_device) {
121 TF_RETURN_IF_ERROR(
122 ReplicateNode(edge->src(), allowed_devices, i, graph));
123 graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
124 edge->dst_input());
125 found_src_node = true;
126 break;
127 }
128 }
129 if (!found_src_node) {
130 for (int i = 0; i < allowed_devices.size(); ++i) {
131 TF_RETURN_IF_ERROR(
132 ReplicateNode(edge->src(), allowed_devices, i, graph));
133 }
134 if (edge->IsControlEdge()) {
135 for (Node* replicated_node : src_replicated_nodes) {
136 // Duplication check in `Graph::AddControlEdge` is expensive for the
137 // dst node with a lot of input edges. Here each (src, dst) pair
138 // will only occur once so it is safe to skip the duplication check.
139 graph->AddControlEdge(replicated_node, dst,
140 /*allow_duplicates=*/true);
141 }
142 return OkStatus();
143 }
144 if (edge->src()->type_string() == "_Arg") {
145 // This happens when the dst node runs on a host CPU and
146 // captures a function with an arg node assigned to the same
147 // composite device (e.g. ScanDataset).
148 // For this case, we insert a PackOp between replicated nodes and the
149 // dst node. The dst node is responsible for unpacking the packed
150 // tensor.
151 // Add '/Packed' as a substring to the name of the new node, which
152 // could be helpful when debugging the graph.
153 NodeDefBuilder pack_builder(
154 graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
155 "Pack");
156 const int num_replicas = src_replicated_nodes.size();
157 pack_builder.Attr("N", num_replicas);
158 const DataType dtype = edge->src()->output_type(edge->src_output());
159 pack_builder.Attr("T", dtype);
160 std::vector<NodeDefBuilder::NodeOut> inputs;
161 inputs.reserve(src_replicated_nodes.size());
162 for (Node* replicated_node : src_replicated_nodes) {
163 inputs.emplace_back(NodeDefBuilder::NodeOut{
164 replicated_node->name(), edge->src_output(), dtype});
165 }
166 pack_builder.Input(inputs);
167 NodeDef pack_def;
168 TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
169 TF_ASSIGN_OR_RETURN(Node * pack_node, graph->AddNode(pack_def));
170 pack_node->set_assigned_device_name(dst->assigned_device_name());
171 for (int i = 0; i < src_replicated_nodes.size(); ++i) {
172 graph->AddEdge(src_replicated_nodes[i], edge->src_output(), pack_node,
173 i);
174 }
175 graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
176 } else {
177 return errors::InvalidArgument(
178 "Dst node should be assigned to an allowed device. Found an "
179 "edge from node ",
180 edge->src()->name(), " assigned to ",
181 edge->src()->assigned_device_name(), " to node ", dst->name(),
182 " assigned to ", dst_device);
183 }
184 }
185 return OkStatus();
186 }
187
188 private:
189 // Map from original nodes to corresponding replicated nodes.
190 absl::flat_hash_map<const Node*, std::vector<Node*>> replicated_nodes_map_;
191 };
192
193 // Replicate the nodes in cluster_nodes and update edges.
ReplicateNodesAndEdges(const std::vector<string> & allowed_devices,absl::flat_hash_map<Node *,int> * cluster_nodes,ReplicateHelper * helper,Graph * graph)194 Status ReplicateNodesAndEdges(const std::vector<string>& allowed_devices,
195 absl::flat_hash_map<Node*, int>* cluster_nodes,
196 ReplicateHelper* helper, Graph* graph) {
197 // Contains nodes in cluster_nodes whose out nodes are all on physical
198 // devices.
199 std::queue<Node*> nodes_ready_to_delete;
200 for (auto& pair : *cluster_nodes) {
201 Node* node = pair.first;
202 for (const Edge* edge : node->out_edges()) {
203 Node* dst = edge->dst();
204 if (dst->assigned_device_name() != node->assigned_device_name()) {
205 // The dst node is assigned to a different device.
206 TF_RETURN_IF_ERROR(helper->ReplicateFromCompositeDeviceToRegularDevice(
207 edge, allowed_devices, graph));
208 --pair.second;
209 }
210 }
211 // Node is ready to delete when all its consumer nodes are assigned to a
212 // physical device.
213 if (cluster_nodes->at(node) == 0) {
214 nodes_ready_to_delete.push(node);
215 }
216 }
217
218 while (!nodes_ready_to_delete.empty()) {
219 Node* node = nodes_ready_to_delete.front();
220 nodes_ready_to_delete.pop();
221
222 // Update input edges.
223 for (const Edge* edge : node->in_edges()) {
224 Node* src = edge->src();
225 if (src->assigned_device_name() != node->assigned_device_name()) {
226 // The source node is assigned to a different device.
227 helper->ReplicateFromRegularDeviceToCompositeDevice(edge, graph);
228 } else {
229 // The source node is assigned to the same composite device.
230 TF_RETURN_IF_ERROR(
231 helper->ReplicateFromCompositeDeviceToCompositeDevice(
232 edge, allowed_devices, graph));
233 if (--(*cluster_nodes)[src] == 0) {
234 nodes_ready_to_delete.push(src);
235 }
236 }
237 }
238
239 // Remove the original node.
240 cluster_nodes->erase(node);
241 graph->RemoveNode(node);
242 }
243 return OkStatus();
244 }
245
246 } // namespace
247
ReplicatePerReplicaNodesInFunctionGraph(const absl::flat_hash_map<string,const std::vector<string> * > & composite_devices,Graph * graph)248 Status ReplicatePerReplicaNodesInFunctionGraph(
249 const absl::flat_hash_map<string, const std::vector<string>*>&
250 composite_devices,
251 Graph* graph) {
252 std::set<string> composite_device_names;
253 for (const auto& it : composite_devices) {
254 composite_device_names.insert(it.first);
255 }
256 // Map from a composite device to a cluster of nodes assigned to the
257 // composite device and the numbers of their out edges to process.
258 absl::flat_hash_map<string, absl::flat_hash_map<Node*, int>>
259 composite_device_to_cluster_nodes;
260 for (Node* n : graph->op_nodes()) {
261 if (composite_device_names.find(n->assigned_device_name()) !=
262 composite_device_names.end()) {
263 // TODO(b/145922293): Validate that an _Arg node assigned to a
264 // CompositeDevice should have an attribute indicating that the _Arg node
265 // represents a packed input.
266 composite_device_to_cluster_nodes[n->assigned_device_name()].emplace(
267 n, n->out_edges().size());
268 }
269 }
270
271 for (auto& it : composite_device_to_cluster_nodes) {
272 const std::vector<string>& allowed_devices =
273 *composite_devices.at(it.first);
274 if (allowed_devices.empty()) {
275 return errors::InvalidArgument("No allowed device of composite device: ",
276 it.first);
277 }
278 absl::flat_hash_map<Node*, int>& cluster_nodes = it.second;
279 if (allowed_devices.size() == 1) {
280 // Reuse the original nodes if there is only one allowed device.
281 for (const auto& pair : it.second) {
282 Node* n = pair.first;
283 n->set_assigned_device_name(allowed_devices.at(0));
284 if (n->IsArg()) {
285 n->AddAttr("sub_index", 0);
286 }
287 }
288 continue;
289 }
290 ReplicateHelper helper;
291 for (const auto& pair : cluster_nodes) {
292 TF_RETURN_IF_ERROR(
293 helper.InitializeNode(pair.first, allowed_devices.size()));
294 }
295
296 TF_RETURN_IF_ERROR(ReplicateNodesAndEdges(allowed_devices, &cluster_nodes,
297 &helper, graph));
298
299 if (!cluster_nodes.empty()) {
300 return errors::InvalidArgument(
301 "There are still ", cluster_nodes.size(),
302 " nodes on CompositiveDevice ",
303 cluster_nodes.begin()->first->assigned_device_name());
304 }
305 }
306 return OkStatus();
307 }
308
309 } // namespace tensorflow
310