• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
16 
17 #include <queue>
18 
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/platform/errors.h"
22 
23 namespace tensorflow {
24 namespace {
25 
26 // A helper for rewriting nodes assigned to a virtual composite device.
27 class ReplicateHelper {
28  public:
29   // Initialize replicated nodes with nullptr.
InitializeNode(const Node * node,int num_allowed_devices)30   Status InitializeNode(const Node* node, int num_allowed_devices) {
31     if (replicated_nodes_map_.find(node) != replicated_nodes_map_.end()) {
32       return errors::InvalidArgument("Node ", node->name(),
33                                      " has been replicated.");
34     }
35     std::vector<Node*> replicated_nodes(num_allowed_devices, nullptr);
36     replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
37     return Status::OK();
38   }
39 
40   // Replicate the given node to an allowed device.
ReplicateNode(const Node * node,const std::vector<string> & allowed_devices,int allowed_device_index,Graph * graph)41   Status ReplicateNode(const Node* node,
42                        const std::vector<string>& allowed_devices,
43                        int allowed_device_index, Graph* graph) {
44     auto& replicated_nodes = replicated_nodes_map_.at(node);
45     if (replicated_nodes[allowed_device_index] != nullptr) {
46       return Status::OK();
47     }
48     const auto& device = allowed_devices.at(allowed_device_index);
49     NodeDef node_def = node->def();
50     const string suffix = strings::StrCat("/R", allowed_device_index);
51     node_def.set_name(graph->NewName(strings::StrCat(node_def.name(), suffix)));
52     Status status;
53     Node* replicated_node = graph->AddNode(node_def, &status);
54     TF_RETURN_IF_ERROR(status);
55     replicated_node->set_assigned_device_name(device);
56     if (replicated_node->IsArg()) {
57       replicated_node->AddAttr("sub_index", allowed_device_index);
58     }
59     replicated_nodes[allowed_device_index] = replicated_node;
60     return Status::OK();
61   }
62 
63   // Replace an edge (a regular device -> composite device) with
64   // N edges (a regular device -> allowed devices).
ReplicateFromRegularDeviceToCompositeDevice(const Edge * edge,Graph * graph) const65   void ReplicateFromRegularDeviceToCompositeDevice(const Edge* edge,
66                                                    Graph* graph) const {
67     Node* src = edge->src();
68     const std::vector<Node*>& dst_replicated_nodes =
69         replicated_nodes_map_.at(edge->dst());
70     for (Node* dst : dst_replicated_nodes) {
71       // Skip a replicated dst node without any consumer.
72       if (dst == nullptr) {
73         continue;
74       }
75       graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
76     }
77   }
78 
79   // Replace an edge (composite device -> composite device) with
80   // N edges (allowed devices -> allowed devices).
ReplicateFromCompositeDeviceToCompositeDevice(const Edge * edge,const std::vector<string> & allowed_devices,Graph * graph)81   Status ReplicateFromCompositeDeviceToCompositeDevice(
82       const Edge* edge, const std::vector<string>& allowed_devices,
83       Graph* graph) {
84     const std::vector<Node*>& src_replicated_nodes =
85         replicated_nodes_map_.at(edge->src());
86     const std::vector<Node*>& dst_replicated_nodes =
87         replicated_nodes_map_.at(edge->dst());
88     if (src_replicated_nodes.size() != dst_replicated_nodes.size()) {
89       return errors::InvalidArgument(
90           "Nodes assigned to the same composite device should have the "
91           "same number of replicated nodes. Found an edge from node ",
92           edge->src()->name(), " (", src_replicated_nodes.size(),
93           " replicated nodes) to node ", edge->dst()->name(), " (",
94           dst_replicated_nodes.size(), " replicated nodes).");
95     }
96     for (int i = 0; i < src_replicated_nodes.size(); ++i) {
97       Node* dst = dst_replicated_nodes.at(i);
98       // Skip a replicated dst node without any consumer.
99       if (dst == nullptr) {
100         continue;
101       }
102       TF_RETURN_IF_ERROR(ReplicateNode(edge->src(), allowed_devices, i, graph));
103       graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
104                      edge->dst_input());
105     }
106     return Status::OK();
107   }
108 
109   // Data edge: replace an edge (composite device -> a regular device) with
110   // one edge (one allowed device -> a regular device).
111   // Control edge: replace an edge (composite device -> a regular device) with
112   // N edges (allowed devices -> a regular device).
ReplicateFromCompositeDeviceToRegularDevice(const Edge * edge,const std::vector<string> & allowed_devices,Graph * graph)113   Status ReplicateFromCompositeDeviceToRegularDevice(
114       const Edge* edge, const std::vector<string>& allowed_devices,
115       Graph* graph) {
116     const std::vector<Node*>& src_replicated_nodes =
117         replicated_nodes_map_.at(edge->src());
118     Node* dst = edge->dst();
119     const string& dst_device = dst->assigned_device_name();
120     bool found_src_node = false;
121     for (int i = 0; i < allowed_devices.size(); ++i) {
122       if (allowed_devices.at(i) == dst_device) {
123         TF_RETURN_IF_ERROR(
124             ReplicateNode(edge->src(), allowed_devices, i, graph));
125         graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
126                        edge->dst_input());
127         found_src_node = true;
128         break;
129       }
130     }
131     if (!found_src_node) {
132       for (int i = 0; i < allowed_devices.size(); ++i) {
133         TF_RETURN_IF_ERROR(
134             ReplicateNode(edge->src(), allowed_devices, i, graph));
135       }
136       if (edge->IsControlEdge()) {
137         for (Node* replicated_node : src_replicated_nodes) {
138           // Duplication check in `Graph::AddControlEdge` is expensive for the
139           // dst node with a lot of input edges. Here each (src, dst) pair
140           // will only occur once so it is safe to skip the duplication check.
141           graph->AddControlEdge(replicated_node, dst,
142                                 /*allow_duplicates=*/true);
143         }
144         return Status::OK();
145       }
146       if (edge->src()->type_string() == "_Arg") {
147         // This happens when the dst node runs on a host CPU and
148         // captures a function with an arg node assigned to the same
149         // composite device (e.g. ScanDataset).
150         // For this case, we insert a PackOp between replicated nodes and the
151         // dst node. The dst node is responsible for unpacking the packed
152         // tensor.
153         // Add '/Packed' as a substring to the name of the new node, which
154         // could be helpful when debugging the graph.
155         NodeDefBuilder pack_builder(
156             graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
157             "Pack");
158         const int num_replicas = src_replicated_nodes.size();
159         pack_builder.Attr("N", num_replicas);
160         const DataType dtype = edge->src()->output_type(edge->src_output());
161         pack_builder.Attr("T", dtype);
162         std::vector<NodeDefBuilder::NodeOut> inputs;
163         inputs.reserve(src_replicated_nodes.size());
164         for (Node* replicated_node : src_replicated_nodes) {
165           inputs.emplace_back(NodeDefBuilder::NodeOut{
166               replicated_node->name(), edge->src_output(), dtype});
167         }
168         pack_builder.Input(inputs);
169         NodeDef pack_def;
170         TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
171         Status status;
172         Node* pack_node = graph->AddNode(pack_def, &status);
173         TF_RETURN_IF_ERROR(status);
174         pack_node->set_assigned_device_name(dst->assigned_device_name());
175         for (int i = 0; i < src_replicated_nodes.size(); ++i) {
176           graph->AddEdge(src_replicated_nodes[i], edge->src_output(), pack_node,
177                          i);
178         }
179         graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
180       } else {
181         return errors::InvalidArgument(
182             "Dst node should be assigned to an allowed device. Found an "
183             "edge from node ",
184             edge->src()->name(), " assigned to ",
185             edge->src()->assigned_device_name(), " to node ", dst->name(),
186             " assigned to ", dst_device);
187       }
188     }
189     return Status::OK();
190   }
191 
192  private:
193   // Map from original nodes to corresponding replicated nodes.
194   absl::flat_hash_map<const Node*, std::vector<Node*>> replicated_nodes_map_;
195 };
196 
197 // Replicate the nodes in cluster_nodes and update edges.
ReplicateNodesAndEdges(const std::vector<string> & allowed_devices,absl::flat_hash_map<Node *,int> * cluster_nodes,ReplicateHelper * helper,Graph * graph)198 Status ReplicateNodesAndEdges(const std::vector<string>& allowed_devices,
199                               absl::flat_hash_map<Node*, int>* cluster_nodes,
200                               ReplicateHelper* helper, Graph* graph) {
201   // Contains nodes in cluster_nodes whose out nodes are all on physical
202   // devices.
203   std::queue<Node*> nodes_ready_to_delete;
204   for (auto& pair : *cluster_nodes) {
205     Node* node = pair.first;
206     for (const Edge* edge : node->out_edges()) {
207       Node* dst = edge->dst();
208       if (dst->assigned_device_name() != node->assigned_device_name()) {
209         // The dst node is assigned to a different device.
210         TF_RETURN_IF_ERROR(helper->ReplicateFromCompositeDeviceToRegularDevice(
211             edge, allowed_devices, graph));
212         --pair.second;
213       }
214     }
215     // Node is ready to delete when all its comsumer nodes are assigned to a
216     // physical device.
217     if (cluster_nodes->at(node) == 0) {
218       nodes_ready_to_delete.push(node);
219     }
220   }
221 
222   while (!nodes_ready_to_delete.empty()) {
223     Node* node = nodes_ready_to_delete.front();
224     nodes_ready_to_delete.pop();
225 
226     // Update input edges.
227     for (const Edge* edge : node->in_edges()) {
228       Node* src = edge->src();
229       if (src->assigned_device_name() != node->assigned_device_name()) {
230         // The source node is assigned to a different device.
231         helper->ReplicateFromRegularDeviceToCompositeDevice(edge, graph);
232       } else {
233         // The source node is assigned to the same composite device.
234         TF_RETURN_IF_ERROR(
235             helper->ReplicateFromCompositeDeviceToCompositeDevice(
236                 edge, allowed_devices, graph));
237         if (--(*cluster_nodes)[src] == 0) {
238           nodes_ready_to_delete.push(src);
239         }
240       }
241     }
242 
243     // Remove the original node.
244     cluster_nodes->erase(node);
245     graph->RemoveNode(node);
246   }
247   return Status::OK();
248 }
249 
250 }  // namespace
251 
ReplicatePerReplicaNodesInFunctionGraph(const absl::flat_hash_map<string,const std::vector<string> * > & composite_devices,Graph * graph)252 Status ReplicatePerReplicaNodesInFunctionGraph(
253     const absl::flat_hash_map<string, const std::vector<string>*>&
254         composite_devices,
255     Graph* graph) {
256   std::set<string> composite_device_names;
257   for (const auto& it : composite_devices) {
258     composite_device_names.insert(it.first);
259   }
260   // Map from a composite device to a cluster of nodes assigned to the
261   // composite device and the numbers of their out edges to process.
262   absl::flat_hash_map<string, absl::flat_hash_map<Node*, int>>
263       composite_device_to_cluster_nodes;
264   for (Node* n : graph->op_nodes()) {
265     if (composite_device_names.find(n->assigned_device_name()) !=
266         composite_device_names.end()) {
267       // TODO(b/145922293): Validate that an _Arg node assigned to a
268       // CompositeDevice should have an attribute indicating that the _Arg node
269       // represents a packed input.
270       composite_device_to_cluster_nodes[n->assigned_device_name()].emplace(
271           n, n->out_edges().size());
272     }
273   }
274 
275   for (auto& it : composite_device_to_cluster_nodes) {
276     const std::vector<string>& allowed_devices =
277         *composite_devices.at(it.first);
278     if (allowed_devices.empty()) {
279       return errors::InvalidArgument("No allowed device of composite device: ",
280                                      it.first);
281     }
282     absl::flat_hash_map<Node*, int>& cluster_nodes = it.second;
283     if (allowed_devices.size() == 1) {
284       // Reuse the original nodes if there is only one allowed device.
285       for (const auto& pair : it.second) {
286         Node* n = pair.first;
287         n->set_assigned_device_name(allowed_devices.at(0));
288         if (n->IsArg()) {
289           n->AddAttr("sub_index", 0);
290         }
291       }
292       continue;
293     }
294     ReplicateHelper helper;
295     for (const auto& pair : cluster_nodes) {
296       TF_RETURN_IF_ERROR(
297           helper.InitializeNode(pair.first, allowed_devices.size()));
298     }
299 
300     TF_RETURN_IF_ERROR(ReplicateNodesAndEdges(allowed_devices, &cluster_nodes,
301                                               &helper, graph));
302 
303     if (!cluster_nodes.empty()) {
304       return errors::InvalidArgument(
305           "There are still ", cluster_nodes.size(),
306           " nodes on CompositiveDevice ",
307           cluster_nodes.begin()->first->assigned_device_name());
308     }
309   }
310   return Status::OK();
311 }
312 
313 }  // namespace tensorflow
314