/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"

namespace tensorflow {
namespace {

// A helper for rewriting nodes assigned to a virtual composite device.
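// For example (hypothetical device names): a cluster of nodes assigned to the
// composite device "/device:COMPOSITE:0" with allowed devices
// {"/device:CPU:0", "/device:CPU:1"} is rewritten into one copy of the cluster
// per allowed device, with edges crossing the cluster boundary rewired as
// described in the method comments below.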
class ReplicateHelper {
 public:
  // Replicate the given node to all allowed devices.
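  // Each replica is named "<original name>/R<i>" (uniquified via
  // Graph::NewName) and assigned to allowed_devices[i]; a replicated _Arg
  // node additionally gets a "sub_index" attribute set to i.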
  Status ReplicateNode(const Node* node,
                       const std::vector<string>& allowed_devices,
                       Graph* graph) {
    if (replicated_nodes_map_.find(node) != replicated_nodes_map_.end()) {
      return errors::InvalidArgument("Node ", node->name(),
                                     " has been replicated.");
    }

    std::vector<Node*> replicated_nodes(allowed_devices.size());
    for (int i = 0; i < allowed_devices.size(); ++i) {
      const auto& device = allowed_devices.at(i);
      NodeDef node_def = node->def();
      const string suffix = strings::StrCat("/R", i);
      node_def.set_name(
          graph->NewName(strings::StrCat(node_def.name(), suffix)));
      Status status;
      Node* replicated_node = graph->AddNode(node_def, &status);
      TF_RETURN_IF_ERROR(status);
      replicated_node->set_assigned_device_name(device);
      if (replicated_node->IsArg()) {
        replicated_node->AddAttr("sub_index", i);
      }
      replicated_nodes[i] = replicated_node;
    }
    replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
    return Status::OK();
  }

  // Replace an edge (a regular device -> composite device) with
  // N edges (a regular device -> allowed devices).
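  //
  //   For example, src:out -> dst becomes src:out -> dst/R0, ...,
  //   src:out -> dst/R(N-1), one edge per replica of dst.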
  void ReplicateFromRegularDeviceToCompositeDevice(const Edge* edge,
                                                   Graph* graph) const {
    Node* src = edge->src();
    const std::vector<Node*>& dst_replicated_nodes =
        replicated_nodes_map_.at(edge->dst());
    for (Node* dst : dst_replicated_nodes) {
      graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
    }
  }

  // Replace an edge (composite device -> composite device) with
  // N edges (allowed devices -> allowed devices).
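  // Replicas are connected pairwise by index (src/Ri -> dst/Ri), so both
  // endpoints must have the same number of replicas.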
  Status ReplicateFromCompositeDeviceToCompositeDevice(const Edge* edge,
                                                       Graph* graph) const {
    const std::vector<Node*>& src_replicated_nodes =
        replicated_nodes_map_.at(edge->src());
    const std::vector<Node*>& dst_replicated_nodes =
        replicated_nodes_map_.at(edge->dst());
    if (src_replicated_nodes.size() != dst_replicated_nodes.size()) {
      return errors::InvalidArgument(
          "Nodes assigned to the same composite device should have the "
          "same number of replicated nodes. Found an edge from node ",
          edge->src()->name(), " (", src_replicated_nodes.size(),
          " replicated nodes) to node ", edge->dst()->name(), " (",
          dst_replicated_nodes.size(), " replicated nodes).");
    }
    for (int i = 0; i < src_replicated_nodes.size(); ++i) {
      graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(),
                     dst_replicated_nodes.at(i), edge->dst_input());
    }
    return Status::OK();
  }

  // Data edge: replace an edge (composite device -> a regular device) with
  // one edge (one allowed device -> a regular device).
  // Control edge: replace an edge (composite device -> a regular device) with
  // N edges (allowed devices -> a regular device).
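  // For a data edge, only the replica assigned to the same device as the dst
  // node is connected. If no replica lives on that device and the source is
  // an _Arg node, a Pack node gathering all replicas is inserted instead;
  // any other source is reported as an error.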
  Status ReplicateFromCompositeDeviceToRegularDevice(const Edge* edge,
                                                     Graph* graph) const {
    const std::vector<Node*>& src_replicated_nodes =
        replicated_nodes_map_.at(edge->src());
    Node* dst = edge->dst();
    if (edge->IsControlEdge()) {
      for (Node* replicated_node : src_replicated_nodes) {
        graph->AddControlEdge(replicated_node, dst);
      }
    } else {
      const string& dst_device = dst->assigned_device_name();
      bool found_src_node = false;
      for (Node* replicated_node : src_replicated_nodes) {
        if (replicated_node->assigned_device_name() == dst_device) {
          graph->AddEdge(replicated_node, edge->src_output(), dst,
                         edge->dst_input());
          found_src_node = true;
          break;
        }
      }
      if (!found_src_node) {
        if (edge->src()->type_string() == "_Arg") {
          // This happens when the dst node runs on a host CPU and
          // captures a function with an arg node assigned to the same
          // composite device (e.g. ScanDataset).
          // For this case, we insert a PackOp between the replicated nodes and
          // the dst node. The dst node is responsible for unpacking the packed
          // tensor.
          // Add '/Packed' as a substring to the name of the new node, which
          // could be helpful when debugging the graph.
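          // Resulting wiring (illustrative): src/R0 ... src/R(N-1) feed the
          // new Pack node, whose output 0 feeds dst at the original input
          // index; the Pack node is placed on dst's device.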
          NodeDefBuilder pack_builder(
              graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
              "Pack");
          const int num_replicas = src_replicated_nodes.size();
          pack_builder.Attr("N", num_replicas);
          const DataType dtype = edge->src()->output_type(edge->src_output());
          pack_builder.Attr("T", dtype);
          std::vector<NodeDefBuilder::NodeOut> inputs;
          inputs.reserve(src_replicated_nodes.size());
          for (Node* replicated_node : src_replicated_nodes) {
            inputs.emplace_back(NodeDefBuilder::NodeOut{
                replicated_node->name(), edge->src_output(), dtype});
          }
          pack_builder.Input(inputs);
          NodeDef pack_def;
          TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
          Status status;
          Node* pack_node = graph->AddNode(pack_def, &status);
          TF_RETURN_IF_ERROR(status);
          pack_node->set_assigned_device_name(dst->assigned_device_name());
          for (int i = 0; i < src_replicated_nodes.size(); ++i) {
            graph->AddEdge(src_replicated_nodes[i], edge->src_output(),
                           pack_node, i);
          }
          graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
        } else {
          return errors::InvalidArgument(
              "Dst node should be assigned to an allowed device. Found an "
              "edge from node ",
              edge->src()->name(), " assigned to ",
              edge->src()->assigned_device_name(), " to node ", dst->name(),
              " assigned to ", dst_device);
        }
      }
    }
    return Status::OK();
  }

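  // Removes replicated _Arg nodes that are left without any out edges after
  // edge replication (e.g. when a data edge to a regular device is routed to
  // only one of the replicas).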
  void RemoveDeadReplicatedArgs(Graph* graph) {
    for (const auto& entry : replicated_nodes_map_) {
      for (Node* replicated_node : entry.second) {
        if (replicated_node->IsArg() && replicated_node->out_edges().empty()) {
          graph->RemoveNode(replicated_node);
        }
      }
    }
  }

 private:
  // Map from original nodes to corresponding replicated nodes.
  absl::flat_hash_map<const Node*, std::vector<Node*>> replicated_nodes_map_;
};

// Replicate the nodes in cluster_nodes to all allowed devices.
Status ReplicateNodes(const std::vector<Node*>& cluster_nodes,
                      const std::vector<string>& allowed_devices,
                      ReplicateHelper* helper, Graph* graph) {
  for (Node* n : cluster_nodes) {
    TF_RETURN_IF_ERROR(helper->ReplicateNode(n, allowed_devices, graph));
  }
  return Status::OK();
}

// Replicate the edges of the original cluster nodes so that their replicated
// counterparts are wired up instead.
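// Three cases are handled, depending on where each edge's endpoints are
// assigned:
//   (1) regular device -> composite device: fan the edge out to every replica;
//   (2) composite device -> composite device: connect replicas pairwise;
//   (3) composite device -> regular device: connect every replica for a
//       control edge; for a data edge, connect the replica on the dst node's
//       device (or a Pack of all replicas when the source is an _Arg node).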
Status ReplicateEdges(const ReplicateHelper& helper,
                      const std::vector<Node*>& cluster_nodes, Graph* graph) {
  for (const auto* node : cluster_nodes) {
    // Replicate input edges.
    for (const Edge* edge : node->in_edges()) {
      Node* src = edge->src();
      if (src->assigned_device_name() != node->assigned_device_name()) {
        // The source node is assigned to a different device.
        helper.ReplicateFromRegularDeviceToCompositeDevice(edge, graph);
      } else {
        // The source node is assigned to the same composite device.
        TF_RETURN_IF_ERROR(
            helper.ReplicateFromCompositeDeviceToCompositeDevice(edge, graph));
      }
    }

    // Replicate output edges.
    for (const Edge* edge : node->out_edges()) {
      Node* dst = edge->dst();
      if (dst->assigned_device_name() != node->assigned_device_name()) {
        // The dst node is assigned to a different device.
        TF_RETURN_IF_ERROR(
            helper.ReplicateFromCompositeDeviceToRegularDevice(edge, graph));
      }
      // The else branch has been covered when iterating over input edges.
    }
  }
  return Status::OK();
}

}  // namespace

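// Rewrites `graph` so that no node remains assigned to a composite device:
// nodes are grouped into one cluster per composite device, each cluster is
// replicated onto the composite device's underlying (allowed) devices, the
// edges are rewired, and the original nodes are removed. A composite device
// with a single allowed device is handled by simply reassigning its nodes.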
Status ReplicatePerReplicaNodesInFunctionGraph(
    const absl::flat_hash_map<string, const std::vector<string>*>&
        composite_devices,
    Graph* graph) {
  std::set<string> composite_device_names;
  for (const auto& it : composite_devices) {
    composite_device_names.insert(it.first);
  }
  // Map from a composite device to a cluster of nodes assigned to the
  // composite device.
  absl::flat_hash_map<string, std::vector<Node*>>
      composite_device_to_cluster_nodes;
  for (Node* n : graph->op_nodes()) {
    if (composite_device_names.find(n->assigned_device_name()) !=
        composite_device_names.end()) {
      // TODO(b/145922293): Validate that an _Arg node assigned to a
      // CompositeDevice has an attribute indicating that the _Arg node
      // represents a packed input.
      composite_device_to_cluster_nodes[n->assigned_device_name()].push_back(
          n);
    }
  }

  for (const auto& it : composite_device_to_cluster_nodes) {
    const std::vector<string>& allowed_devices =
        *composite_devices.at(it.first);
    if (allowed_devices.empty()) {
      return errors::InvalidArgument("No allowed device of composite device: ",
                                     it.first);
    }
    const std::vector<Node*>& cluster_nodes = it.second;
    if (allowed_devices.size() == 1) {
      // Reuse the original nodes if there is only one allowed device.
      for (Node* n : cluster_nodes) {
        n->set_assigned_device_name(allowed_devices.at(0));
        if (n->IsArg()) {
          n->AddAttr("sub_index", 0);
        }
      }
      continue;
    }
    ReplicateHelper helper;
    TF_RETURN_IF_ERROR(
        ReplicateNodes(cluster_nodes, allowed_devices, &helper, graph));
    TF_RETURN_IF_ERROR(ReplicateEdges(helper, cluster_nodes, graph));

    // Remove the original nodes.
    for (auto* n : cluster_nodes) {
      graph->RemoveNode(n);
    }

    helper.RemoveDeadReplicatedArgs(graph);
  }
  return Status::OK();
}

}  // namespace tensorflow