• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
16 
17 #include "tensorflow/core/framework/node_def.pb.h"
18 #include "tensorflow/core/framework/node_def_builder.h"
19 
20 namespace tensorflow {
21 namespace {
22 
23 // A helper for rewriting nodes assigned to a virtual composite device.
24 class ReplicateHelper {
25  public:
26   // Replicate the given node to all allowed devices.
ReplicateNode(const Node * node,const std::vector<string> & allowed_devices,Graph * graph)27   Status ReplicateNode(const Node* node,
28                        const std::vector<string>& allowed_devices,
29                        Graph* graph) {
30     if (replicated_nodes_map_.find(node) != replicated_nodes_map_.end()) {
31       return errors::InvalidArgument("Node ", node->name(),
32                                      " has been replicated.");
33     }
34 
35     std::vector<Node*> replicated_nodes(allowed_devices.size());
36     for (int i = 0; i < allowed_devices.size(); ++i) {
37       const auto& device = allowed_devices.at(i);
38       NodeDef node_def = node->def();
39       const string suffix = strings::StrCat("/R", i);
40       node_def.set_name(
41           graph->NewName(strings::StrCat(node_def.name(), suffix)));
42       Status status;
43       Node* replicated_node = graph->AddNode(node_def, &status);
44       TF_RETURN_IF_ERROR(status);
45       replicated_node->set_assigned_device_name(device);
46       if (replicated_node->IsArg()) {
47         replicated_node->AddAttr("sub_index", i);
48       }
49       replicated_nodes[i] = replicated_node;
50     }
51     replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
52     return Status::OK();
53   }
54 
55   // Replace an edge (a regular device -> composite device) with
56   // N edges (a regular device -> allowed devices).
ReplicateFromRegularDeviceToCompositeDevice(const Edge * edge,Graph * graph) const57   void ReplicateFromRegularDeviceToCompositeDevice(const Edge* edge,
58                                                    Graph* graph) const {
59     Node* src = edge->src();
60     const std::vector<Node*>& dst_replicated_nodes =
61         replicated_nodes_map_.at(edge->dst());
62     for (Node* dst : dst_replicated_nodes) {
63       graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
64     }
65   }
66 
67   // Replace an edge (composite device -> composite device) with
68   // N edges (allowed devices -> allowed devices).
ReplicateFromCompositeDeviceToCompositeDevice(const Edge * edge,Graph * graph) const69   Status ReplicateFromCompositeDeviceToCompositeDevice(const Edge* edge,
70                                                        Graph* graph) const {
71     const std::vector<Node*>& src_replicated_nodes =
72         replicated_nodes_map_.at(edge->src());
73     const std::vector<Node*>& dst_replicated_nodes =
74         replicated_nodes_map_.at(edge->dst());
75     if (src_replicated_nodes.size() != dst_replicated_nodes.size()) {
76       return errors::InvalidArgument(
77           "Nodes assigned to the same composite device should have the "
78           "same number of replicated nodes. Found an edge from node ",
79           edge->src()->name(), " (", src_replicated_nodes.size(),
80           " replicated nodes) to node ", edge->dst()->name(), " (",
81           dst_replicated_nodes.size(), " replicated nodes).");
82     }
83     for (int i = 0; i < src_replicated_nodes.size(); ++i) {
84       graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(),
85                      dst_replicated_nodes.at(i), edge->dst_input());
86     }
87     return Status::OK();
88   }
89 
90   // Data edge: replace an edge (composite device -> a regular device) with
91   // one edge (one allowed device -> a regular device).
92   // Control edge: replace an edge (composite device -> a regular device) with
93   // N edges (allowed devices -> a regular device).
ReplicateFromCompositeDeviceToRegularDevice(const Edge * edge,Graph * graph) const94   Status ReplicateFromCompositeDeviceToRegularDevice(const Edge* edge,
95                                                      Graph* graph) const {
96     const std::vector<Node*>& src_replicated_nodes =
97         replicated_nodes_map_.at(edge->src());
98     Node* dst = edge->dst();
99     if (edge->IsControlEdge()) {
100       for (Node* replicated_node : src_replicated_nodes) {
101         graph->AddControlEdge(replicated_node, dst);
102       }
103     } else {
104       const string& dst_device = dst->assigned_device_name();
105       bool found_src_node = false;
106       for (Node* replicated_node : src_replicated_nodes) {
107         if (replicated_node->assigned_device_name() == dst_device) {
108           graph->AddEdge(replicated_node, edge->src_output(), dst,
109                          edge->dst_input());
110           found_src_node = true;
111           break;
112         }
113       }
114       if (!found_src_node) {
115         if (edge->src()->type_string() == "_Arg") {
116           // This happens when the dst node runs on a host CPU and
117           // captures a function with an arg node assigned to the same
118           // composite device (e.g. ScanDataset).
119           // For this case, we insert a PackOp between replicated nodes and the
120           // dst node. The dst node is responsible for unpacking the packed
121           // tensor.
122           // Add '/Packed' as a substring to the name of the new node, which
123           // could be helpful when debugging the graph.
124           NodeDefBuilder pack_builder(
125               graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
126               "Pack");
127           const int num_replicas = src_replicated_nodes.size();
128           pack_builder.Attr("N", num_replicas);
129           const DataType dtype = edge->src()->output_type(edge->src_output());
130           pack_builder.Attr("T", dtype);
131           std::vector<NodeDefBuilder::NodeOut> inputs;
132           inputs.reserve(src_replicated_nodes.size());
133           for (Node* replicated_node : src_replicated_nodes) {
134             inputs.emplace_back(NodeDefBuilder::NodeOut{
135                 replicated_node->name(), edge->src_output(), dtype});
136           }
137           pack_builder.Input(inputs);
138           NodeDef pack_def;
139           TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
140           Status status;
141           Node* pack_node = graph->AddNode(pack_def, &status);
142           TF_RETURN_IF_ERROR(status);
143           pack_node->set_assigned_device_name(dst->assigned_device_name());
144           for (int i = 0; i < src_replicated_nodes.size(); ++i) {
145             graph->AddEdge(src_replicated_nodes[i], edge->src_output(),
146                            pack_node, i);
147           }
148           graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
149         } else {
150           return errors::InvalidArgument(
151               "Dst node should be assigned to an allowed device. Found an "
152               "edge from node ",
153               edge->src()->name(), " assigned to ",
154               edge->src()->assigned_device_name(), " to node ", dst->name(),
155               " assigned to ", dst_device);
156         }
157       }
158     }
159     return Status::OK();
160   }
161 
RemoveDeadReplicatedArgs(Graph * graph)162   void RemoveDeadReplicatedArgs(Graph* graph) {
163     for (const auto& entry : replicated_nodes_map_) {
164       for (Node* replicated_node : entry.second) {
165         if (replicated_node->IsArg() && replicated_node->out_edges().empty()) {
166           graph->RemoveNode(replicated_node);
167         }
168       }
169     }
170   }
171 
172  private:
173   // Map from original nodes to corresponding replicated nodes.
174   absl::flat_hash_map<const Node*, std::vector<Node*>> replicated_nodes_map_;
175 };
176 
177 // Replicate the nodes in cluster_nodes to all allowed devices.
ReplicateNodes(const std::vector<Node * > & cluster_nodes,const std::vector<string> & allowed_devices,ReplicateHelper * helper,Graph * graph)178 Status ReplicateNodes(const std::vector<Node*>& cluster_nodes,
179                       const std::vector<string>& allowed_devices,
180                       ReplicateHelper* helper, Graph* graph) {
181   for (Node* n : cluster_nodes) {
182     TF_RETURN_IF_ERROR(helper->ReplicateNode(n, allowed_devices, graph));
183   }
184   return Status::OK();
185 }
186 
187 // Replicate the edges connecting original nodes for replicated nodes.
ReplicateEdges(const ReplicateHelper & helper,const std::vector<Node * > & cluster_nodes,Graph * graph)188 Status ReplicateEdges(const ReplicateHelper& helper,
189                       const std::vector<Node*>& cluster_nodes, Graph* graph) {
190   for (const auto* node : cluster_nodes) {
191     // Replicate input edges.
192     for (const Edge* edge : node->in_edges()) {
193       Node* src = edge->src();
194       if (src->assigned_device_name() != node->assigned_device_name()) {
195         // The source node is assigned to a different device.
196         helper.ReplicateFromRegularDeviceToCompositeDevice(edge, graph);
197       } else {
198         // The source node is assigned to the same composite device.
199         TF_RETURN_IF_ERROR(
200             helper.ReplicateFromCompositeDeviceToCompositeDevice(edge, graph));
201       }
202     }
203 
204     // Replicate output edges.
205     for (const Edge* edge : node->out_edges()) {
206       Node* dst = edge->dst();
207       if (dst->assigned_device_name() != node->assigned_device_name()) {
208         // The dst node is assigned to a different device.
209         TF_RETURN_IF_ERROR(
210             helper.ReplicateFromCompositeDeviceToRegularDevice(edge, graph));
211       }
212       // The else branch has been covered when iterating over input edges.
213     }
214   }
215   return Status::OK();
216 }
217 
218 }  // namespace
219 
ReplicatePerReplicaNodesInFunctionGraph(const absl::flat_hash_map<string,const std::vector<string> * > & composite_devices,Graph * graph)220 Status ReplicatePerReplicaNodesInFunctionGraph(
221     const absl::flat_hash_map<string, const std::vector<string>*>&
222         composite_devices,
223     Graph* graph) {
224   std::set<string> composite_device_names;
225   for (const auto& it : composite_devices) {
226     composite_device_names.insert(it.first);
227   }
228   // Map from a composite device to a cluster of nodes assigned to the
229   // composite device.
230   absl::flat_hash_map<string, std::vector<Node*>>
231       composite_device_to_cluster_nodes;
232   for (Node* n : graph->op_nodes()) {
233     if (composite_device_names.find(n->assigned_device_name()) !=
234         composite_device_names.end()) {
235       // TODO(b/145922293): Validate that an _Arg node assigned to a
236       // CompositeDevice should have an attribute indicating that the _Arg node
237       // represents a packed input.
238       composite_device_to_cluster_nodes[n->assigned_device_name()].push_back(n);
239     }
240   }
241 
242   for (const auto& it : composite_device_to_cluster_nodes) {
243     const std::vector<string>& allowed_devices =
244         *composite_devices.at(it.first);
245     if (allowed_devices.empty()) {
246       return errors::InvalidArgument("No allowed device of composite device: ",
247                                      it.first);
248     }
249     const std::vector<Node*>& cluster_nodes = it.second;
250     if (allowed_devices.size() == 1) {
251       // Reuse the original nodes if there is only one allowed device.
252       for (Node* n : cluster_nodes) {
253         n->set_assigned_device_name(allowed_devices.at(0));
254         if (n->IsArg()) {
255           n->AddAttr("sub_index", 0);
256         }
257       }
258       continue;
259     }
260     ReplicateHelper helper;
261     TF_RETURN_IF_ERROR(
262         ReplicateNodes(cluster_nodes, allowed_devices, &helper, graph));
263     TF_RETURN_IF_ERROR(ReplicateEdges(helper, cluster_nodes, graph));
264 
265     // Remove orignial nodes.
266     for (auto* n : cluster_nodes) {
267       graph->RemoveNode(n);
268     }
269 
270     helper.RemoveDeadReplicatedArgs(graph);
271   }
272   return Status::OK();
273 }
274 
275 }  // namespace tensorflow
276