• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
16 
17 #include <queue>
18 
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/platform/errors.h"
22 
23 namespace tensorflow {
24 namespace {
25 
// A helper for rewriting nodes assigned to a virtual composite device.
// For every original node on the composite device it tracks one replica slot
// per allowed (physical) device; replicas are materialized lazily and edges
// are rewired between them.
class ReplicateHelper {
 public:
  // Registers `node` in the replication map with `num_allowed_devices`
  // nullptr slots (replicas are created later, on demand, by ReplicateNode).
  // Returns InvalidArgument if `node` was already registered.
  Status InitializeNode(const Node* node, int num_allowed_devices) {
    if (replicated_nodes_map_.find(node) != replicated_nodes_map_.end()) {
      return errors::InvalidArgument("Node ", node->name(),
                                     " has been replicated.");
    }
    std::vector<Node*> replicated_nodes(num_allowed_devices, nullptr);
    replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
    return OkStatus();
  }

  // Replicate the given node to an allowed device, if not already done.
  // The replica reuses the original NodeDef with a "/R<index>" name suffix
  // (uniquified via Graph::NewName) and is assigned to
  // allowed_devices[allowed_device_index]. An _Arg replica additionally
  // records its replica index in a "sub_index" attribute.
  Status ReplicateNode(const Node* node,
                       const std::vector<string>& allowed_devices,
                       int allowed_device_index, Graph* graph) {
    auto& replicated_nodes = replicated_nodes_map_.at(node);
    if (replicated_nodes[allowed_device_index] != nullptr) {
      // Already materialized on this device; nothing to do.
      return OkStatus();
    }
    const auto& device = allowed_devices.at(allowed_device_index);
    NodeDef node_def = node->def();
    const string suffix = strings::StrCat("/R", allowed_device_index);
    node_def.set_name(graph->NewName(strings::StrCat(node_def.name(), suffix)));
    TF_ASSIGN_OR_RETURN(Node * replicated_node, graph->AddNode(node_def));
    replicated_node->set_assigned_device_name(device);
    if (replicated_node->IsArg()) {
      replicated_node->AddAttr("sub_index", allowed_device_index);
    }
    replicated_nodes[allowed_device_index] = replicated_node;
    return OkStatus();
  }

  // Replace an edge (a regular device -> composite device) with
  // N edges (a regular device -> allowed devices): fan the single source
  // output out to every materialized replica of the destination.
  void ReplicateFromRegularDeviceToCompositeDevice(const Edge* edge,
                                                   Graph* graph) const {
    Node* src = edge->src();
    const std::vector<Node*>& dst_replicated_nodes =
        replicated_nodes_map_.at(edge->dst());
    for (Node* dst : dst_replicated_nodes) {
      // Skip a replicated dst node without any consumer.
      if (dst == nullptr) {
        continue;
      }
      graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
    }
  }

  // Replace an edge (composite device -> composite device) with
  // N edges (allowed devices -> allowed devices), connecting src/dst
  // replicas pairwise by device index. Source replicas are materialized on
  // demand; destination replicas that were never created (no consumer) are
  // skipped. Returns InvalidArgument if the two nodes have differing replica
  // counts.
  Status ReplicateFromCompositeDeviceToCompositeDevice(
      const Edge* edge, const std::vector<string>& allowed_devices,
      Graph* graph) {
    const std::vector<Node*>& src_replicated_nodes =
        replicated_nodes_map_.at(edge->src());
    const std::vector<Node*>& dst_replicated_nodes =
        replicated_nodes_map_.at(edge->dst());
    if (src_replicated_nodes.size() != dst_replicated_nodes.size()) {
      return errors::InvalidArgument(
          "Nodes assigned to the same composite device should have the "
          "same number of replicated nodes. Found an edge from node ",
          edge->src()->name(), " (", src_replicated_nodes.size(),
          " replicated nodes) to node ", edge->dst()->name(), " (",
          dst_replicated_nodes.size(), " replicated nodes).");
    }
    for (int i = 0; i < src_replicated_nodes.size(); ++i) {
      Node* dst = dst_replicated_nodes.at(i);
      // Skip a replicated dst node without any consumer.
      if (dst == nullptr) {
        continue;
      }
      TF_RETURN_IF_ERROR(ReplicateNode(edge->src(), allowed_devices, i, graph));
      graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
                     edge->dst_input());
    }
    return OkStatus();
  }

  // Data edge: replace an edge (composite device -> a regular device) with
  // one edge (one allowed device -> a regular device).
  // Control edge: replace an edge (composite device -> a regular device) with
  // N edges (allowed devices -> a regular device).
  //
  // Three cases:
  //   1. The dst device is one of the allowed devices: connect only the
  //      matching src replica to dst.
  //   2. Otherwise, for a control edge: add a control edge from every src
  //      replica to dst.
  //   3. Otherwise, for a data edge from an _Arg node: insert a Pack op that
  //      combines all src replicas and feed its output to dst; any other
  //      data edge is an InvalidArgument error.
  Status ReplicateFromCompositeDeviceToRegularDevice(
      const Edge* edge, const std::vector<string>& allowed_devices,
      Graph* graph) {
    const std::vector<Node*>& src_replicated_nodes =
        replicated_nodes_map_.at(edge->src());
    Node* dst = edge->dst();
    const string& dst_device = dst->assigned_device_name();
    bool found_src_node = false;
    for (int i = 0; i < allowed_devices.size(); ++i) {
      if (allowed_devices.at(i) == dst_device) {
        TF_RETURN_IF_ERROR(
            ReplicateNode(edge->src(), allowed_devices, i, graph));
        graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
                       edge->dst_input());
        found_src_node = true;
        break;
      }
    }
    if (!found_src_node) {
      // Materialize every src replica: each one either feeds a control edge
      // or becomes an input of the Pack node below.
      for (int i = 0; i < allowed_devices.size(); ++i) {
        TF_RETURN_IF_ERROR(
            ReplicateNode(edge->src(), allowed_devices, i, graph));
      }
      if (edge->IsControlEdge()) {
        for (Node* replicated_node : src_replicated_nodes) {
          // Duplication check in `Graph::AddControlEdge` is expensive for the
          // dst node with a lot of input edges. Here each (src, dst) pair
          // will only occur once so it is safe to skip the duplication check.
          graph->AddControlEdge(replicated_node, dst,
                                /*allow_duplicates=*/true);
        }
        return OkStatus();
      }
      if (edge->src()->type_string() == "_Arg") {
        // This happens when the dst node runs on a host CPU and
        // captures a function with an arg node assigned to the same
        // composite device (e.g. ScanDataset).
        // For this case, we insert a PackOp between replicated nodes and the
        // dst node. The dst node is responsible for unpacking the packed
        // tensor.
        // Add '/Packed' as a substring to the name of the new node, which
        // could be helpful when debugging the graph.
        NodeDefBuilder pack_builder(
            graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
            "Pack");
        const int num_replicas = src_replicated_nodes.size();
        pack_builder.Attr("N", num_replicas);
        const DataType dtype = edge->src()->output_type(edge->src_output());
        pack_builder.Attr("T", dtype);
        std::vector<NodeDefBuilder::NodeOut> inputs;
        inputs.reserve(src_replicated_nodes.size());
        for (Node* replicated_node : src_replicated_nodes) {
          inputs.emplace_back(NodeDefBuilder::NodeOut{
              replicated_node->name(), edge->src_output(), dtype});
        }
        pack_builder.Input(inputs);
        NodeDef pack_def;
        TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
        TF_ASSIGN_OR_RETURN(Node * pack_node, graph->AddNode(pack_def));
        // The Pack op runs where its consumer runs.
        pack_node->set_assigned_device_name(dst->assigned_device_name());
        for (int i = 0; i < src_replicated_nodes.size(); ++i) {
          graph->AddEdge(src_replicated_nodes[i], edge->src_output(), pack_node,
                         i);
        }
        graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
      } else {
        return errors::InvalidArgument(
            "Dst node should be assigned to an allowed device. Found an "
            "edge from node ",
            edge->src()->name(), " assigned to ",
            edge->src()->assigned_device_name(), " to node ", dst->name(),
            " assigned to ", dst_device);
      }
    }
    return OkStatus();
  }

 private:
  // Map from original nodes to corresponding replicated nodes.
  absl::flat_hash_map<const Node*, std::vector<Node*>> replicated_nodes_map_;
};
192 
193 // Replicate the nodes in cluster_nodes and update edges.
ReplicateNodesAndEdges(const std::vector<string> & allowed_devices,absl::flat_hash_map<Node *,int> * cluster_nodes,ReplicateHelper * helper,Graph * graph)194 Status ReplicateNodesAndEdges(const std::vector<string>& allowed_devices,
195                               absl::flat_hash_map<Node*, int>* cluster_nodes,
196                               ReplicateHelper* helper, Graph* graph) {
197   // Contains nodes in cluster_nodes whose out nodes are all on physical
198   // devices.
199   std::queue<Node*> nodes_ready_to_delete;
200   for (auto& pair : *cluster_nodes) {
201     Node* node = pair.first;
202     for (const Edge* edge : node->out_edges()) {
203       Node* dst = edge->dst();
204       if (dst->assigned_device_name() != node->assigned_device_name()) {
205         // The dst node is assigned to a different device.
206         TF_RETURN_IF_ERROR(helper->ReplicateFromCompositeDeviceToRegularDevice(
207             edge, allowed_devices, graph));
208         --pair.second;
209       }
210     }
211     // Node is ready to delete when all its consumer nodes are assigned to a
212     // physical device.
213     if (cluster_nodes->at(node) == 0) {
214       nodes_ready_to_delete.push(node);
215     }
216   }
217 
218   while (!nodes_ready_to_delete.empty()) {
219     Node* node = nodes_ready_to_delete.front();
220     nodes_ready_to_delete.pop();
221 
222     // Update input edges.
223     for (const Edge* edge : node->in_edges()) {
224       Node* src = edge->src();
225       if (src->assigned_device_name() != node->assigned_device_name()) {
226         // The source node is assigned to a different device.
227         helper->ReplicateFromRegularDeviceToCompositeDevice(edge, graph);
228       } else {
229         // The source node is assigned to the same composite device.
230         TF_RETURN_IF_ERROR(
231             helper->ReplicateFromCompositeDeviceToCompositeDevice(
232                 edge, allowed_devices, graph));
233         if (--(*cluster_nodes)[src] == 0) {
234           nodes_ready_to_delete.push(src);
235         }
236       }
237     }
238 
239     // Remove the original node.
240     cluster_nodes->erase(node);
241     graph->RemoveNode(node);
242   }
243   return OkStatus();
244 }
245 
246 }  // namespace
247 
ReplicatePerReplicaNodesInFunctionGraph(const absl::flat_hash_map<string,const std::vector<string> * > & composite_devices,Graph * graph)248 Status ReplicatePerReplicaNodesInFunctionGraph(
249     const absl::flat_hash_map<string, const std::vector<string>*>&
250         composite_devices,
251     Graph* graph) {
252   std::set<string> composite_device_names;
253   for (const auto& it : composite_devices) {
254     composite_device_names.insert(it.first);
255   }
256   // Map from a composite device to a cluster of nodes assigned to the
257   // composite device and the numbers of their out edges to process.
258   absl::flat_hash_map<string, absl::flat_hash_map<Node*, int>>
259       composite_device_to_cluster_nodes;
260   for (Node* n : graph->op_nodes()) {
261     if (composite_device_names.find(n->assigned_device_name()) !=
262         composite_device_names.end()) {
263       // TODO(b/145922293): Validate that an _Arg node assigned to a
264       // CompositeDevice should have an attribute indicating that the _Arg node
265       // represents a packed input.
266       composite_device_to_cluster_nodes[n->assigned_device_name()].emplace(
267           n, n->out_edges().size());
268     }
269   }
270 
271   for (auto& it : composite_device_to_cluster_nodes) {
272     const std::vector<string>& allowed_devices =
273         *composite_devices.at(it.first);
274     if (allowed_devices.empty()) {
275       return errors::InvalidArgument("No allowed device of composite device: ",
276                                      it.first);
277     }
278     absl::flat_hash_map<Node*, int>& cluster_nodes = it.second;
279     if (allowed_devices.size() == 1) {
280       // Reuse the original nodes if there is only one allowed device.
281       for (const auto& pair : it.second) {
282         Node* n = pair.first;
283         n->set_assigned_device_name(allowed_devices.at(0));
284         if (n->IsArg()) {
285           n->AddAttr("sub_index", 0);
286         }
287       }
288       continue;
289     }
290     ReplicateHelper helper;
291     for (const auto& pair : cluster_nodes) {
292       TF_RETURN_IF_ERROR(
293           helper.InitializeNode(pair.first, allowed_devices.size()));
294     }
295 
296     TF_RETURN_IF_ERROR(ReplicateNodesAndEdges(allowed_devices, &cluster_nodes,
297                                               &helper, graph));
298 
299     if (!cluster_nodes.empty()) {
300       return errors::InvalidArgument(
301           "There are still ", cluster_nodes.size(),
302           " nodes on CompositiveDevice ",
303           cluster_nodes.begin()->first->assigned_device_name());
304     }
305   }
306   return OkStatus();
307 }
308 
309 }  // namespace tensorflow
310