/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph_partition.h"

#include <deque>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

namespace {
inline bool IsMerge(const NodeDef& node_def) {
  return node_def.op() == "Merge" || node_def.op() == "RefMerge";
}

inline bool IsNextIteration(const NodeDef& node_def) {
  return node_def.op() == "NextIteration" ||
         node_def.op() == "RefNextIteration";
}

struct DupRecvKey {
  int src_node_id;           // Edge's src node id
  int src_output_slot;       // Edge's src node output slot
  GraphDef* dst_graph;       // Edge's dst node is in this subgraph
  bool recv_output_on_host;  // The output of recv is on host

  template <typename H>
  friend H AbslHashValue(H h, const DupRecvKey& c) {
    return H::combine(std::move(h), c.src_node_id, c.src_output_slot,
                      reinterpret_cast<std::uintptr_t>(c.dst_graph),
                      c.recv_output_on_host);
  }

  friend bool operator==(const DupRecvKey& x, const DupRecvKey& y) {
    return (x.src_node_id == y.src_node_id) &&
           (x.src_output_slot == y.src_output_slot) &&
           (x.dst_graph == y.dst_graph) &&
           (x.recv_output_on_host == y.recv_output_on_host);
  }
};

// Struct used to store the recvs, so that start times can be properly
// updated.
struct RecvInfo {
  NodeDef* recv;
  NodeDef* real_recv;
  int64 start_time;
};

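// Maps the key identifying a transferred tensor (src node/slot, destination
// subgraph, host-memory placement) to the Recv already created for it, so
// the same value crossing to the same partition is transferred only once.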
typedef absl::flat_hash_map<DupRecvKey, RecvInfo> DupRecvTable;

// A map used to store memory types for the inputs/outputs of every node.
// The key is a pair of ints consisting of a node id and input/output index.
// TODO(power): migrate back to std::pair when absl::Hash is fixed for MSVC.
struct NodePort {
  int node_id;
  int index;

  friend bool operator==(const NodePort& x, const NodePort& y) {
    return x.node_id == y.node_id && x.index == y.index;
  }

  template <typename H>
  friend H AbslHashValue(H h, const NodePort& c) {
    return H::combine(std::move(h), c.node_id, c.index);
  }
};

typedef absl::flat_hash_map<NodePort, MemoryType> MemoryTypeMap;

// We collect the following information about the graph before performing
// graph partitioning.
struct GraphInfo {
  std::vector<DeviceType> device_types;
  MemoryTypeMap input_types;
  MemoryTypeMap output_types;
  std::vector<ControlFlowInfo> cf_info;
};

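// The type carried over 'e'. Control edges are materialized as a dummy
// float tensor during partitioning (see AddDummyConst below), so DT_FLOAT
// stands in for them.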
DataType EdgeType(const Edge* e) {
  if (e->IsControlEdge()) {
    return DT_FLOAT;
  } else {
    return e->dst()->input_type(e->dst_input());
  }
}

// Return true iff we need to add the same-device send/recv for 'edge'.
bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
  if (edge->IsControlEdge()) {
    return false;
  }

  const Node* src = edge->src();
  const Node* dst = edge->dst();
  if (src->assigned_device_name() == dst->assigned_device_name()) {
    int src_port = edge->src_output();
    int dst_port = edge->dst_input();
    if (info.device_types[src->id()] != DEVICE_CPU) {
      auto src_it = info.output_types.find({src->id(), src_port});
      DCHECK(src_it != info.output_types.end());
      auto dst_it = info.input_types.find({dst->id(), dst_port});
      DCHECK(dst_it != info.input_types.end());
      return src_it->second != dst_it->second;
    }
  }
  return false;
}

// Return true iff (dst, dst_input) is specified on host memory.
bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
  const Node* dst = edge->dst();
  int dst_port = edge->dst_input();
  if (info.device_types[dst->id()] != DEVICE_CPU) {
    if (edge->IsControlEdge()) return false;
    auto dst_it = info.input_types.find({dst->id(), dst_port});
    DCHECK(dst_it != info.input_types.end());
    return dst_it->second == HOST_MEMORY;
  }
  return true;
}

// Add an input to dst that comes from the "src_slot" output of the
// node named by "src_name".
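// For example, slot 2 yields the input string "src:2", slot 0 yields just
// "src", and Graph::kControlSlot yields the control input "^src".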
void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
  if (src_slot == Graph::kControlSlot) {
    dst->add_input(strings::StrCat("^", src_name));
  } else if (src_slot == 0) {
    dst->add_input(src_name.data(), src_name.size());
  } else {
    dst->add_input(strings::StrCat(src_name, ":", src_slot));
  }
}

// Add a control edge from each input to each recv.
void AddReadControl(const std::vector<NodeDef*>& recvs,
                    const std::vector<string>& inputs) {
  for (NodeDef* recv : recvs) {
    for (const string& input : inputs) {
      recv->add_input(strings::StrCat("^", input));
    }
  }
}

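// Sets the attributes shared by the Send/Recv pair created for 'edge': a
// unique tensor_name, the sending/receiving devices, and the sender's
// incarnation number.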
void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
                      NodeDefBuilder* builder) {
  builder->Attr("tensor_name",
                strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
  builder->Attr("send_device", edge->src()->assigned_device_name());
  builder->Attr("send_device_incarnation",
                static_cast<int64>(
                    opts.get_incarnation(edge->src()->assigned_device_name())));
  builder->Attr("recv_device", edge->dst()->assigned_device_name());
  builder->Attr("client_terminated", false);
}

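// Adds to 'gdef' a Send node for 'edge', preceded by a Cast (or _HostCast)
// when 'opts.should_cast' requests a different wire type. Returns the Send
// node, or nullptr with '*status' set on error.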
NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge,
                 NodeDefBuilder::NodeOut send_from, int64 start_time,
                 Status* status) {
  const DataType dtype = send_from.data_type;
  const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
  const Node* src = edge->src();
  const int src_port = edge->src_output();

  // host_memory = true iff we need to use HostSend/HostCast.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto src_it = g_info.output_types.find({src->id(), src_port});
    DCHECK(src_it != g_info.output_types.end());
    host_memory = (src_it->second == HOST_MEMORY);
  }

  // Add a cast node that casts dtype to cast_dtype.
  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                NodeDebugInfo(*src));
    cast_builder.Device(src->assigned_device_name()).Input(send_from);
    if (opts.scheduling_for_recvs) {
      cast_builder.Attr("_start_time", start_time);
    }
    cast_builder.Attr("DstT", cast_dtype);

    if (cast_dtype == DT_BFLOAT16) {
      // The below attribute specifies that the cast to bfloat16 should use
      // truncation. This is needed to retain legacy behavior when we change
      // the default bfloat16 casts to use rounding instead of truncation.
      cast_builder.Attr("Truncate", true);
    }

    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast);
    if (!status->ok()) return nullptr;

    // Connect the Send op to the cast.
    send_from.Reset(cast->name(), 0, cast_dtype);
  }

  // Add the send node.
  const string send_op = (host_memory) ? "_HostSend" : "_Send";
  NodeDefBuilder send_builder(opts.new_name(src->name()), send_op,
                              NodeDebugInfo(*src));
  SetSendRecvAttrs(opts, edge, &send_builder);
  send_builder.Device(src->assigned_device_name()).Input(send_from);
  if (opts.scheduling_for_recvs) {
    send_builder.Attr("_start_time", start_time);
  }
  NodeDef* send = gdef->add_node();
  *status = send_builder.Finalize(send);
  return send;
}

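// Adds to 'gdef' the Recv side for 'edge'. Returns the node consumers should
// read from: the Recv itself, a Cast back to the expected dtype, or an
// Identity for control edges. '*real_recv' is always set to the underlying
// _Recv/_HostRecv node.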
NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
                 Status* status) {
  const DataType dtype = EdgeType(edge);
  const Node* src = edge->src();
  const Node* dst = edge->dst();
  const int dst_port = edge->dst_input();
  DataType cast_dtype = dtype;

  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
    cast_dtype = opts.should_cast(edge);
  }

  // host_memory = true iff we need to use HostRecv/HostCast.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto dst_it = g_info.input_types.find({dst->id(), dst_port});
    DCHECK(dst_it != g_info.input_types.end());
    host_memory = (dst_it->second == HOST_MEMORY);
  }

  // Add the recv node.
  const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
  NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
                              NodeDebugInfo(*src));
  SetSendRecvAttrs(opts, edge, &recv_builder);
  recv_builder.Device(dst->assigned_device_name())
      .Attr("tensor_type", cast_dtype);
  NodeDef* recv = gdef->add_node();
  *status = recv_builder.Finalize(recv);
  if (!status->ok()) return nullptr;
  *real_recv = recv;

  // Add the cast node (from cast_dtype to dtype) or an Identity node.
  if (dtype != cast_dtype) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                NodeDebugInfo(*src));
    cast_builder.Attr("DstT", dtype);
    cast_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast);
    if (!status->ok()) return nullptr;
    return cast;
  } else if (edge->IsControlEdge()) {
    // An Identity is only needed for control edges.
    NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
                              NodeDebugInfo(*src));
    id_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* id = gdef->add_node();
    *status = id_builder.Finalize(id);
    if (!status->ok()) return nullptr;
    return id;
  } else {
    return recv;
  }
}

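// Adds a Const producing an empty float tensor. It serves as the payload of
// the Send/Recv pair that replaces a cross-device control edge; the receiver
// only waits on its arrival.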
NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
                       const Edge* edge, Status* status) {
  const Node* src = edge->src();
  Tensor tensor(DT_FLOAT, TensorShape({0}));
  NodeDef* result = gdef->add_node();
  *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
                .Device(src->assigned_device_name())
                .Attr("dtype", DT_FLOAT)
                .Attr("value", tensor)
                .Finalize(result);
  return result;
}

// A dummy node for scheduling.
NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
                           const string& assigned_device_name, int64 epoch,
                           int64 starttime, Status* status) {
  NodeDef* result = gdef->add_node();
  *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
                           "ControlTrigger")
                .Device(assigned_device_name)
                .Attr("_start_time", starttime)
                .Finalize(result);
  return result;
}

// Optimize colocation for control flow nodes. For cond, we want the
// switch nodes to colocate with their data inputs. This is particularly
// needed for conditional reading of a remote variable. It may also
// reduce the number of devices involved in a loop.
// TODO(yuanbyu): In this case, we don't respect the requested device in
// the GraphDef for these nodes. Ideally, the placer would enforce the
// colocation to render this unnecessary.
void OptimizeControlFlowColocation(Graph* graph) {
  auto visit = [](Node* node) {
    if (IsSwitch(node)) {
      for (const Edge* in_edge : node->in_edges()) {
        if (in_edge->dst_input() == 0) {
          // Colocate with the data input.
          node->set_assigned_device_name(
              in_edge->src()->assigned_device_name());
          return;
        }
      }
    } else if (IsExit(node)) {
      for (const Edge* in_edge : node->in_edges()) {
        if (!in_edge->IsControlEdge()) {
          // Colocate with upstream node.
          node->set_assigned_device_name(
              in_edge->src()->assigned_device_name());
          return;
        }
      }
    } else {
      if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
          IsNextIteration(node)) {
        const Edge* data_edge = nullptr;
        for (const Edge* out_edge : node->out_edges()) {
          if (!out_edge->IsControlEdge()) {
            data_edge = out_edge;
            break;
          }
        }
        // Colocate with the first downstream data node.
        if (data_edge) {
          node->set_assigned_device_name(
              data_edge->dst()->assigned_device_name());
        }
      }
    }
  };
  DFS(*graph, visit, {});
}

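// Dummy control-loop nodes are named with the "_cloop" prefix so they can
// be recognized later during partitioning (see IsControlLoop).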
string ControlLoopName(const string& name) {
  return strings::StrCat("_cloop", name);
}

bool IsControlLoop(const Node* node) {
  const string& name = node->name();
  return str_util::StartsWith(name, "_cloop");
}

// An enter node for control flow.
Node* AddControlEnter(Graph* g, const string& node_name,
                      const string& device_name, const string& frame_name,
                      const int parallel_iterations, Status* status) {
  NodeBuilder node_builder(node_name, "Enter", g->op_registry());
  node_builder.Input({"dummy", 0, DT_FLOAT});
  node_builder.Attr("frame_name", frame_name);
  node_builder.Attr("parallel_iterations", parallel_iterations);
  Node* res_node;
  *status = node_builder.Finalize(g, &res_node);
  if (!status->ok()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A merge node for control flow.
Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
                      const string& node_name, const string& device_name,
                      Status* status) {
  NodeBuilder node_builder(node_name, "Merge", g->op_registry());
  node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
  Node* res_node;
  *status = node_builder.Finalize(g, &res_node);
  if (!status->ok()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A switch node for control flow.
Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
                       const string& device_name,
                       const GraphDefBuilder::Options& bopts) {
  Node* res_node =
      ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts);
  if (bopts.HaveError()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A next_iteration node for control flow.
Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
                     const GraphDefBuilder::Options& bopts) {
  Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts);
  if (bopts.HaveError()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

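// A float Const with shape {0}. AddControlConst below uses it to give the
// control loop's Enter node a well-typed dummy data input.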
Node* EmptyConst(const GraphDefBuilder::Options& options) {
  if (options.HaveError()) return nullptr;
  NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
                           options.op_registry());
  const DataType dt = DataTypeToEnum<float>::v();
  TensorProto proto;
  proto.set_dtype(dt);
  TensorShape empty_shape({0});
  empty_shape.AsProto(proto.mutable_tensor_shape());
  node_builder.Attr("dtype", dt).Attr("value", proto);
  return options.FinalizeBuilder(&node_builder);
}

// A dummy const node for control flow.
Node* AddControlConst(const string& device_name,
                      const GraphDefBuilder::Options& bopts) {
  Node* res_node = EmptyConst(bopts);
  if (bopts.HaveError()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A synthetic loop, made up of dummy nodes. It performs control-flow actions
// on behalf of a leader on a different device.
struct ControlLoop {
  Node* enter = nullptr;
  Node* merge = nullptr;
  Node* switch_node = nullptr;
};

// Add the control flow info of a new node added during partitioning.
// The new node has the same control flow info as src.
void AddControlFlowInfo(const Node* node, const Node* src,
                        std::vector<ControlFlowInfo>* cf_info) {
  int id = node->id();
  if (static_cast<size_t>(id) >= cf_info->size()) {
    cf_info->resize(id + 1);
  }
  const ControlFlowInfo& src_info = (*cf_info)[src->id()];
  ControlFlowInfo* info = &(*cf_info)[id];
  info->frame = src_info.frame;
  info->parent_frame = src_info.parent_frame;
  info->frame_name = src_info.frame_name;
}

// Constructs a control loop. Returns a struct containing the newly created
// enter, merge, and switch nodes. The enter and merge nodes are used in the
// recursive construction of control loops for nested frames (loops). The
// switch node will be connected to the LoopCond node. The merge node will
// be connected to all the recvs of the same frame by control edges when
// the actual partitioning happens.
Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
                      const Edge* edge, Node* loop_cond,
                      std::vector<ControlFlowInfo>* cf_info,
                      ControlLoop* loop) {
  Status status;
  GraphDefBuilder::Options bopts(g, &status);
  const ControlFlowInfo& src_info = (*cf_info)[src->id()];
  const string& device_name = edge->dst()->assigned_device_name();
  const string& frame_name = src_info.frame_name;
  int parallel_iterations;
  status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
                       &parallel_iterations);
  if (!status.ok()) return status;

  // The names of the nodes to be added.
  const string& enter_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& merge_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& switch_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));

  // Add the nodes to the graph g.
  Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
                                parallel_iterations, &status);
  if (!status.ok()) return status;
  Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
                                device_name, &status);
  if (!status.ok()) return status;
  Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
                                       bopts.WithName(switch_name));
  if (!status.ok()) return status;
  Node* next =
      AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
  if (!status.ok()) return status;

  // Add control flow info for these new nodes:
  AddControlFlowInfo(enter, src, cf_info);
  AddControlFlowInfo(merge, src, cf_info);
  AddControlFlowInfo(switch_node, src, cf_info);
  AddControlFlowInfo(next, src, cf_info);

  // Add input edges for the newly created merge node:
  g->AddEdge(enter, 0, merge, 0);
  g->AddEdge(next, 0, merge, 1);

  loop->enter = enter;
  loop->merge = merge;
  loop->switch_node = switch_node;
  return Status::OK();
}

// Build memory and device type info for every node in the graph.
// TODO(yuanbyu): It might be simpler if we convert MemoryType to
// DeviceType for the inputs/outputs of each node.
Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
  MemoryTypeVector input_memory_types;
  MemoryTypeVector output_memory_types;

  info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
  for (const Node* node : g.op_nodes()) {
    DeviceNameUtils::ParsedName parsed;
    if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
                                        &parsed)) {
      return errors::Internal("Malformed assigned device '",
                              node->assigned_device_name(), "'");
    }

    TF_RETURN_IF_ERROR(MemoryTypesForNode(
        g.op_registry(), DeviceType(parsed.type), node->def(),
        &input_memory_types, &output_memory_types));

    int node_id = node->id();
    info->device_types[node_id] = DeviceType(parsed.type);
    for (int i = 0; i < input_memory_types.size(); ++i) {
      info->input_types[{node_id, i}] = input_memory_types[i];
    }
    for (int i = 0; i < output_memory_types.size(); ++i) {
      info->output_types[{node_id, i}] = output_memory_types[i];
    }
  }
  return Status::OK();
}

const Node* InputFrame(const Node* node,
                       const std::vector<ControlFlowInfo>& cf_info) {
  // An input is in the same frame as the node except for Enter nodes.
  // The input of Enter is in the parent frame of the Enter node.
  if (!node->IsEnter()) {
    return node;
  }
  return cf_info[node->id()].parent_frame;
}

const Node* OutputFrame(const Node* node,
                        const std::vector<ControlFlowInfo>& cf_info) {
  // An output is in the same frame as the node except for Exit nodes.
  // The output of Exit is in the parent frame of the Exit node.
  if (!node->IsExit()) {
    return node;
  }
  return cf_info[node->id()].parent_frame;
}

// Each participating device needs to decide a) if there is a next iteration,
// and b) if the loop terminates. We take the approach of encoding this
// control-flow logic in the dataflow graph. There are at least two possible
// encodings. In a completely decentralized encoding, the participants
// communicate peer to peer. The other encoding uses a frame leader (the
// participant who owns the pivot termination predicate) to broadcast the
// termination condition to all the participants. For now we take the latter
// because it is simpler.
//
// TODO(yuanbyu): The correctness of this construction is rather subtle. I got
// it wrong many times so it would be nice to write a proof to be sure.
Status AddControlFlow(const PartitionOptions& opts, Graph* g,
                      GraphInfo* g_info) {
  Status status;
  GraphDefBuilder::Options bopts(g, &status);
  std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;

  // Build the control flow info for every node.
  status = BuildControlFlowInfo(g, &cf_info);
  if (!status.ok()) return status;

  OptimizeControlFlowColocation(g);

  // The map from frames to their LoopCond nodes.
  std::unordered_map<string, Node*> frame_cond_map;
  int num_node_ids = g->num_node_ids();
  for (int i = 0; i < num_node_ids; ++i) {
    Node* node = g->FindNodeId(i);
    if (node == nullptr) continue;

    if (IsLoopCond(node)) {
      const string& frame_name = cf_info[node->id()].frame_name;
      DCHECK(!frame_name.empty());
      frame_cond_map[frame_name] = node;
    }
  }

  // Add all control loops for cross-device frames.
  // A control loop is added only when there is a cross-device edge in a
  // non-root frame. Nothing is added if there are no loops. We also don't
  // add anything for a frame that is completely local to a device. For
  // nested loops, we stack the control loops together by connecting
  // the merge of the outer loop to the enter of the inner loop.
  //
  // A map from <frame_name, device_name> to ControlLoop.
  std::unordered_map<string, ControlLoop> control_loops;
  int num_edge_ids = g->num_edge_ids();
  for (int i = 0; i < num_edge_ids; ++i) {
    const Edge* edge = g->FindEdgeId(i);
    if (edge == nullptr) continue;

    const Node* src = edge->src();
    const Node* dst = edge->dst();
    // Skip Sink/Source nodes.
    if (!src->IsOp() || !dst->IsOp()) continue;

    const string& src_device = src->assigned_device_name();
    const string& dst_device = dst->assigned_device_name();
    // Skip local edges.
    if (src_device == dst_device) continue;

    const Node* src_frame = OutputFrame(src, cf_info);
    const Node* dst_frame = InputFrame(dst, cf_info);
    const string& src_frame_name = cf_info[src_frame->id()].frame_name;
    const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
    // Skip if src and dst are not in the same frame.
    if (src_frame_name.empty() || src_frame_name != dst_frame_name) {
      continue;
    }

    // Add the control loop. Start by adding the control loop for the
    // current frame if needed, and recursively adding the control loop
    // for its outer frame when nested.
    ControlLoop child_loop;
    while (true) {
      const string& curr_frame_name = cf_info[src_frame->id()].frame_name;
      if (curr_frame_name.empty()) {
        // We have reached the root frame.
        if (child_loop.merge != nullptr) {
          const string& node_name = opts.new_name(edge->dst()->name());
          const string& device_name = edge->dst()->assigned_device_name();
          Node* const_node =
              AddControlConst(device_name, bopts.WithName(node_name));
          if (!status.ok()) return status;
          AddControlFlowInfo(const_node, src_frame, &cf_info);
          g->AddEdge(const_node, 0, child_loop.enter, 0);
        }
        break;
      }

      const string& cl_key = strings::StrCat(curr_frame_name, "$$", dst_device);
      auto it = control_loops.find(cl_key);
      if (it != control_loops.end()) {
        if (child_loop.enter != nullptr) {
          g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
        }
        break;
      }

      // Get the frame's LoopCond.
      auto cond_it = frame_cond_map.find(curr_frame_name);
      if (cond_it == frame_cond_map.end()) {
        return errors::InvalidArgument(
            "A cross-device loop must have a pivot predicate: ",
            curr_frame_name);
      }
      Node* loop_cond = cond_it->second;

      // Add the control loop.
      ControlLoop curr_loop;
      status = AddControlLoop(opts, g, src_frame, edge, loop_cond, &cf_info,
                              &curr_loop);
      if (!status.ok()) return status;
      control_loops[cl_key] = curr_loop;

      if (child_loop.enter != nullptr) {
        // Connect the merge of the outer loop to the enter of the inner.
        g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
      }
      src_frame = cf_info[src_frame->id()].parent_frame;
      child_loop = curr_loop;
    }
  }

  // For a cross-device edge, on the dst device, add a control edge
  // from the merge node of the control loop to dst. If a send/recv is
  // introduced for this edge in future partitioning, we delete this
  // control edge and add a new control edge from the merge to the recv.
  num_edge_ids = g->num_edge_ids();
  for (int i = 0; i < num_edge_ids; ++i) {
    const Edge* edge = g->FindEdgeId(i);
    if (edge == nullptr) continue;

    const Node* src = edge->src();
    Node* dst = edge->dst();
    // Skip Sink/Source nodes.
    if (!src->IsOp() || !dst->IsOp()) continue;

    const string& src_device = src->assigned_device_name();
    const string& dst_device = dst->assigned_device_name();
    if (src_device != dst_device) {
      const Node* src_frame = OutputFrame(src, cf_info);
      const Node* dst_frame = InputFrame(dst, cf_info);
      const string& src_frame_name = cf_info[src_frame->id()].frame_name;
      const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
      if (!src_frame_name.empty() && src_frame_name == dst_frame_name) {
        const string& cl_key =
            strings::StrCat(dst_frame_name, "$$", dst_device);
        ControlLoop loop = control_loops[cl_key];
        DCHECK(loop.enter != nullptr);
        // Note that we'll create multiple duplicate edges if dst has multiple
        // cross-device inputs. This is expected by the logic in Partition(),
        // so it can add control edges to the recv nodes once they're created.
        g->AddControlEdge(loop.merge, dst, /*allow_duplicates=*/true);
      }
    }
  }
  return Status::OK();
}

struct PriorityTopoSortNode {
  PriorityTopoSortNode(const NodeDef* n, int64 st) : node(n), start_time(st) {}

  const NodeDef* node;
  int64 start_time;
};

struct PriorityTopoSortNodeGreater {
  bool operator()(const PriorityTopoSortNode& left,
                  const PriorityTopoSortNode& right) {
    return left.start_time > right.start_time;
  }
};

}  // namespace

// Returns in <nodes> the nodes that should participate in epoch-based recv
// scheduling, along with their times; <nodes> is ordered by increasing
// start_time. Returns in <node_to_start_time_out> the timing for all nodes,
// even those not in <nodes>.
//
// Compared with sorting on the node's start time only, this also processes
// the nodes in dependency order, and updates start times to ensure a node's
// start_time > the start time for all dependencies.
//
// Note that graph_partition_test.cc accesses this function for testing, even
// though it's not declared in the header.
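// For example, if a node with start time 5 feeds a node whose recorded start
// time is 3, the consumer is pushed back to start time 6 before it is
// emitted.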
Status TopologicalSortNodesWithTimePriority(
    const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes,
    std::unordered_map<const NodeDef*, int64>* node_to_start_time_out) {
  // Queue of nodes to process; lowest start time is returned first.
  std::priority_queue<PriorityTopoSortNode, std::vector<PriorityTopoSortNode>,
                      PriorityTopoSortNodeGreater>
      q;
  std::unordered_map<const NodeDef*, int64> node_to_start_time;
  auto enqueue = [&q, &node_to_start_time](const NodeDef* node) {
    const int64 start_time = node_to_start_time[node];
    q.emplace(node, start_time);
  };

  // Build initial structures, initial contents of queue.
  std::unordered_map<string, std::vector<const NodeDef*>> node_to_output_nodes;
  std::unordered_map<const NodeDef*, int> inputs_needed;
  for (int n = 0; n < gdef->node_size(); ++n) {
    const NodeDef* ndef = &gdef->node(n);
    for (int i = 0; i < ndef->input_size(); ++i) {
      node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)]
          .push_back(ndef);
    }
    int64 start_time;
    TF_RETURN_IF_ERROR(GetNodeAttr(*ndef, "_start_time", &start_time));
    node_to_start_time[ndef] = start_time;
    inputs_needed[ndef] = ndef->input_size();
    if (ndef->input_size() == 0) {
      enqueue(ndef);
    }
  }

  // Determine which merge nodes are part of loops; these
  // need to happen in the traversal after all non-NextIteration inputs
  // are run.
  for (int n = 0; n < gdef->node_size(); ++n) {
    const NodeDef* ndef = &gdef->node(n);
    if (IsNextIteration(*ndef)) {
      for (const NodeDef* n : node_to_output_nodes[ndef->name()]) {
        if (IsMerge(*n)) {
          // n is a merge that is part of a loop structure.
          // It doesn't need to wait for this NextIteration loop
          // when doing the traversal.
          --inputs_needed[n];
        }
      }
    }
  }

  // Traverse.
  std::vector<std::pair<const NodeDef*, int64>> start_times;
  start_times.reserve(gdef->node_size());
  while (!q.empty()) {
    PriorityTopoSortNode cur = q.top();
    q.pop();

    start_times.emplace_back(cur.node, cur.start_time);

    for (const NodeDef* n : node_to_output_nodes[cur.node->name()]) {
      auto& output_start_time = node_to_start_time[n];
      if (output_start_time <= cur.start_time) {
        output_start_time = cur.start_time + 1;
      }
      if (--inputs_needed[n] == 0) {
        enqueue(n);
      }
    }
  }

  // Done.
  nodes->swap(start_times);
  node_to_start_time_out->swap(node_to_start_time);
  return Status::OK();
}

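// Divides each partition's schedule into 'num_epochs' epochs, adds one
// ControlTrigger node per epoch, and gives every _Recv a control dependency
// on the trigger 'prefetch' epochs before its start time, so recvs do not
// run arbitrarily far ahead of their consumers.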
Status AddControlEdges(const PartitionOptions& opts,
                       std::unordered_map<string, GraphDef>* partitions) {
  Status status;
  // TODO(yuanbyu): Very naive for now. To be improved.
  const int num_epochs = 100;
  const int prefetch = 6;

  for (auto& part : *partitions) {
    GraphDef* gdef = &part.second;
    std::vector<std::pair<const NodeDef*, int64>> start_times;
    std::unordered_map<const NodeDef*, int64> node_to_start_time;
    status = TopologicalSortNodesWithTimePriority(gdef, &start_times,
                                                  &node_to_start_time);
    if (!status.ok()) {
      return status;
    }

    // Add a dummy node for every epoch, and add a control edge from the
    // "last" node in the preceding epoch to the dummy node.
    string device_name = gdef->node(0).device();
    int64 makespan = start_times.back().second;
    int64 resolution = (makespan / num_epochs) + 1;

    int i = 0;
    int j = 0;
    std::vector<NodeDef*> dummys;
    while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
      if (i * resolution > start_times[j].second) {
        j++;
      } else {
        NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
                                           i * resolution, &status);
        if (!status.ok()) {
          return status;
        }
        dummys.push_back(dummy);
        if (j > 0) {
          string src_name = start_times[j - 1].first->name();
          AddInput(dummy, src_name, Graph::kControlSlot);
        }
        i++;
      }
    }

    // Finally, add the control edges to recvs.
    for (int n = 0; n < gdef->node_size(); ++n) {
      NodeDef* ndef = gdef->mutable_node(n);
      if (ndef->op() == "_Recv") {
        const int64 start_time = node_to_start_time[ndef];
        const int recv_epoch = start_time / resolution;
        if (recv_epoch >= prefetch) {
          NodeDef* dummy = dummys[recv_epoch - prefetch];
          AddInput(ndef, dummy->name(), Graph::kControlSlot);
        }
      }
    }
  }
  return Status::OK();
}

// If 'ndef' is a Send or Recv, fills its attr send_device_incarnation
// if possible.
void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
  StringPiece op(ndef->op());
  if (op != "_Send" && op != "_Recv") {
    // Not related to send/recv.
    return;
  }
  string send_device;
  if (!GetNodeAttr(*ndef, "send_device", &send_device).ok()) {
    // No known send_device. The runtime will detect it later.
    return;
  }
  int64 incarnation = PartitionOptions::kIllegalIncarnation;
  if (!GetNodeAttr(*ndef, "send_device_incarnation", &incarnation).ok() ||
      (incarnation == PartitionOptions::kIllegalIncarnation)) {
    incarnation = opts.get_incarnation(send_device);
    SetAttrValue(incarnation,
                 &((*ndef->mutable_attr())["send_device_incarnation"]));
  }
}

// Sets attribute send_device_incarnation of all Send/Recv nodes in
// 'gdef', if possible.
void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) {
  for (NodeDef& ndef : *gdef->mutable_node()) {
    SetIncarnation(opts, &ndef);
  }
  for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) {
    for (NodeDef& ndef : *fdef.mutable_node_def()) {
      SetIncarnation(opts, &ndef);
    }
  }
}

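// Splits 'g' into one GraphDef per location returned by opts.node_to_loc(),
// rewriting every cross-partition edge into a Send/Recv pair (deduplicated
// through 'dup_recv') and patching in the distributed control-flow machinery
// built by AddControlFlow().
//
// A minimal usage sketch (assuming nodes already have assigned devices and
// 'opts' maps each node to its device name):
//
//   std::unordered_map<string, GraphDef> partitions;
//   TF_RETURN_IF_ERROR(Partition(opts, g, &partitions));
//   // partitions now holds one subgraph per device, with _Send/_Recv
//   // nodes inserted at the cut edges.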
Status Partition(const PartitionOptions& opts, Graph* g,
                 std::unordered_map<string, GraphDef>* partitions) {
  Status status;
  partitions->clear();

  GraphInfo g_info;
  if (!opts.control_flow_added) {
    // Add the "code" for distributed execution of control flow. Code is
    // added only for the frames that are placed on multiple devices. The
    // new graph is an equivalent transformation of the original graph and
    // has the property that it can be subsequently partitioned arbitrarily
    // (down to the level of individual devices) for distributed execution.
    status = AddControlFlow(opts, g, &g_info);
    if (!status.ok()) return status;
  }

  // At this point, all the graph mutations have been done. Build memory
  // and device type info for every node and edge in the graph.
  status = BuildMemoryDeviceInfo(*g, &g_info);
  if (!status.ok()) return status;

  string dstp;
  std::vector<const Edge*> inputs;
  DupRecvTable dup_recv(3);
  // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
  // edge to dst. 'ref_control_inputs' remembers the inputs via a non-ref
  // edge to dst. We will add a control edge for every pair in
  // (ref_recvs x ref_control_inputs).
  std::vector<NodeDef*> ref_recvs;
  std::vector<string> ref_control_inputs;

  int32 num_data = 0;
  int32 num_control = 0;
  for (const Node* dst : g->op_nodes()) {
    dstp = opts.node_to_loc(dst);
    GraphDef* dst_graph = &(*partitions)[dstp];
    NodeDef* dst_def = dst_graph->add_node();
    *dst_def = dst->def();
    MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def);
    dst_def->set_device(dst->assigned_device_name());
    dst_def->clear_input();  // Inputs are filled below
    if (opts.need_to_record_start_times) {
      int64 start_time;
      status = GetNodeAttr(*dst_def, "_start_time", &start_time);
      if (errors::IsNotFound(status)) {
        start_time = opts.start_times[dst->id()].value();
        AddNodeAttr("_start_time", start_time, dst_def);
      } else if (!status.ok()) {
        return status;
      }
    }

    // Arrange the incoming edges to dst so that input[i] holds the
    // input flowing into slot numbered i. Trailing entries in input[]
    // hold control edges.
    inputs.clear();
    inputs.resize(dst->num_inputs(), nullptr);
    ref_recvs.clear();
    ref_control_inputs.clear();
    const Edge* control_flow_edge = nullptr;
    int32 num_control_flow_edges = 0;
    int32 num_input_edges = 0;
    for (const Edge* edge : dst->in_edges()) {
      if (edge->IsControlEdge()) {
        if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
          // This is one of the control edges added for control flow. There
          // can be multiple such edges as the dest node may have multiple
          // remote inputs. We keep track of the number of such edges.
          control_flow_edge = edge;
          ++num_control_flow_edges;
        } else {
          inputs.push_back(edge);
        }
      } else {
        DCHECK(inputs[edge->dst_input()] == nullptr);
        inputs[edge->dst_input()] = edge;
        ++num_input_edges;
      }
    }

    if (num_input_edges != dst->num_inputs()) {
      return errors::InvalidArgument("Incomplete graph, missing ",
                                     (dst->num_inputs() - num_input_edges),
                                     " inputs for ", dst->name());
    }

    // Process in order so that all data edges are added as inputs to
    // dst in Edge::dst_input() order.
    for (const Edge* edge : inputs) {
      const Node* src = edge->src();
      if (!src->IsOp()) continue;  // Skip Sink/Source nodes.

      GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
      if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
        // Same partition and compatible memory types:
        AddInput(dst_def, src->name(), edge->src_output());
        if (edge->IsControlEdge() ||
            !IsRefType(src->output_type(edge->src_output()))) {
          ref_control_inputs.push_back(src->name());
        }
        continue;
      }

      int64 send_start_time = 0;
      int64 recv_start_time = 0;
      if (opts.scheduling_for_recvs) {
        status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          send_start_time = opts.start_times[src->id()].value();
        } else if (!status.ok()) {
          return status;
        }

        status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          recv_start_time = opts.start_times[dst->id()].value();
        } else if (!status.ok()) {
          return status;
        }
      }

      // Check whether there is already a send/recv pair transferring
      // the same tensor/control from the src to dst partition.
      const bool on_host = IsDstInputOnHost(edge, g_info);
      DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
      auto iter = dup_recv.find(key);
      if (iter != dup_recv.end()) {
        // We found one. Reuse the data/control transferred already.
        const string& recv_node_name = iter->second.recv->name();
        if (edge->IsControlEdge()) {
          AddInput(dst_def, recv_node_name, Graph::kControlSlot);
        } else {
          AddInput(dst_def, recv_node_name, 0);
        }
        ref_control_inputs.push_back(recv_node_name);

        // We want the start_time for the recv to be the smallest of the start
        // times of its consumers. So we update this whenever we use a recv,
        // and write it out to the attribute at the end of the subroutine.
        if (iter->second.start_time > recv_start_time) {
          iter->second.start_time = recv_start_time;
        }
        continue;
      }

      NodeDefBuilder::NodeOut send_from;
      if (edge->IsControlEdge()) {
        // Insert a dummy const node that will generate a tiny
        // data element to be sent from send to recv.
        VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
                << src->name() << "] -> " << dst->assigned_device_name() << "["
                << dst->name() << "]";
        NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
        if (!status.ok()) return status;
        // Set the start time for this dummy node.
        if (opts.scheduling_for_recvs) {
          AddNodeAttr("_start_time", send_start_time, dummy);
        }
        AddInput(dummy, src->name(), Graph::kControlSlot);
        send_from.Reset(dummy->name(), 0, DT_FLOAT);
      } else {
        send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
      }

      // Need to split edge by placing matching send/recv nodes on
      // the src/dst sides of the edge.
      NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
                              send_start_time, &status);
      if (!status.ok()) return status;

      NodeDef* real_recv = nullptr;
      NodeDef* recv =
          AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
      if (!status.ok()) return status;

      // Fix up the control flow edge.
      // NOTE(yuanbyu): 'real_recv' must be the real recv node.
      if (src_graph == dst_graph) {
        // For same device send/recv, add a control edge from send to recv.
        // This prevents the asynchronous recv kernel from being scheduled
        // before the data is available.
        AddInput(real_recv, send->name(), Graph::kControlSlot);
      } else if (control_flow_edge != nullptr) {
        // Redirect control edge to the real recv since this is not the same
        // device send/recv.
        --num_control_flow_edges;
        AddInput(real_recv, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }

      if (!edge->IsControlEdge() &&
          IsRefType(src->output_type(edge->src_output()))) {
        AddNodeAttr("_start_time", recv_start_time, recv);
        if (real_recv != recv) {
          AddNodeAttr("_start_time", recv_start_time, real_recv);
        }
        // If src is of ref type and the edge is not a control edge, dst has
        // read semantics and therefore we must control the recv.
        ref_recvs.push_back(real_recv);
      } else {
        // Memorize the send/recv pair, but only if this is not a "ref" edge.
        // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
        // for now we don't do it.
        dup_recv[key] = {recv, real_recv, recv_start_time};
        ref_control_inputs.push_back(recv->name());
      }

      if (edge->IsControlEdge()) {
        ++num_control;
        AddInput(dst_def, recv->name(), Graph::kControlSlot);
      } else {
        ++num_data;
        AddInput(dst_def, recv->name(), 0);
      }
    }

    // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
    // NOTE(yuanbyu): Adding these control edges should not introduce
    // deadlocks. 'dst' has implicit "read" nodes that, when we split
    // across devices, are made explicit; retargeting the dependencies
    // on 'dst' to those nodes would not introduce cycles if there wasn't
    // one before the transformation.
    // NOTE(yuanbyu): This may impact performance because it defers the
    // execution of recvs until all the other inputs become available.
    AddReadControl(ref_recvs, ref_control_inputs);

    // Add back the control edges for control flow that are not used.
    if (control_flow_edge != nullptr) {
      for (int i = 0; i < num_control_flow_edges; ++i) {
        AddInput(dst_def, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }
    }
  }

  const FunctionLibraryDefinition* flib_def = opts.flib_def;
  if (flib_def == nullptr) {
    flib_def = &g->flib_def();
  }

  // Set versions, function library and send/recv incarnation.
  for (auto& it : *partitions) {
    GraphDef* gdef = &it.second;
    *gdef->mutable_versions() = g->versions();
    // Prune unreachable functions from `flib_def` before adding them to
    // `gdef`.
    *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();

    // Traverse the graph to fill every send/recv op's incarnation
    // information.
    SetIncarnation(opts, gdef);
  }

  // Set the start times for recvs at the very end.
  if (opts.scheduling_for_recvs) {
    for (auto& it : dup_recv) {
      AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
      if (it.second.real_recv != it.second.recv) {
        AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
      }
    }
  }

  VLOG(1) << "Added send/recv: controls=" << num_control
          << ", data=" << num_data;
  return Status::OK();
}

}  // namespace tensorflow