• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph_partition.h"
17 
18 #include <deque>
19 #include <queue>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/memory_types.h"
28 #include "tensorflow/core/framework/node_def_builder.h"
29 #include "tensorflow/core/framework/tensor.pb.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/graph/control_flow.h"
34 #include "tensorflow/core/graph/costmodel.h"
35 #include "tensorflow/core/graph/graph_def_builder.h"
36 #include "tensorflow/core/graph/node_builder.h"
37 #include "tensorflow/core/graph/tensor_id.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/hash/hash.h"
40 #include "tensorflow/core/lib/strings/str_util.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/util/device_name_utils.h"
43 
44 namespace tensorflow {
45 
46 namespace {
47 
IsMerge(const NodeDef & node_def)48 inline bool IsMerge(const NodeDef& node_def) {
49   return node_def.op() == "Merge" || node_def.op() == "RefMerge";
50 }
51 
IsNextIteration(const NodeDef & node_def)52 inline bool IsNextIteration(const NodeDef& node_def) {
53   return node_def.op() == "NextIteration" ||
54          node_def.op() == "RefNextIteration";
55 }
56 
57 struct DupRecvKey {
58   int src_node_id;           // Edge's src node id
59   int src_output_slot;       // Edge's src node output slot
60   GraphDef* dst_graph;       // Edge's dst node is in this subgraph
61   bool recv_output_on_host;  // The output of recv is on host
62 
63   template <typename H>
AbslHashValue(H h,const DupRecvKey & c)64   friend H AbslHashValue(H h, const DupRecvKey& c) {
65     return H::combine(std::move(h), c.src_node_id, c.src_output_slot,
66                       reinterpret_cast<std::uintptr_t>(c.dst_graph),
67                       c.recv_output_on_host);
68   }
69 
operator ==(const DupRecvKey & x,const DupRecvKey & y)70   friend bool operator==(const DupRecvKey& x, const DupRecvKey& y) {
71     return (x.src_node_id == y.src_node_id) &&
72            (x.src_output_slot == y.src_output_slot) &&
73            (x.dst_graph == y.dst_graph) &&
74            (x.recv_output_on_host == y.recv_output_on_host);
75   }
76 };
77 
78 // struct used to store the recvs, so that start times can be properly updated
79 struct RecvInfo {
80   NodeDef* recv;
81   NodeDef* real_recv;
82   int64 start_time;
83 };
84 
85 typedef absl::flat_hash_map<DupRecvKey, RecvInfo> DupRecvTable;
86 
// Key type for the maps that store memory types for the inputs/outputs of
// every node: a node id paired with an input/output index on that node.
// TODO(power): migrate back to std::pair when absl::Hash is fixed for MSVC.
struct NodePort {
  int node_id;
  int index;

  // Two ports are equal only when both the node id and the index agree.
  friend bool operator==(const NodePort& lhs, const NodePort& rhs) {
    if (lhs.node_id != rhs.node_id) return false;
    return lhs.index == rhs.index;
  }

  // Hash support for absl containers: mixes both fields into the hash state.
  template <typename H>
  friend H AbslHashValue(H state, const NodePort& port) {
    return H::combine(std::move(state), port.node_id, port.index);
  }
};
103 
104 typedef absl::flat_hash_map<NodePort, MemoryType> MemoryTypeMap;
105 
106 // We collect the following information about the graph before performing
107 // graph partitioning.
108 struct GraphInfo {
109   std::vector<DeviceType> device_types;
110   MemoryTypeMap input_types;
111   MemoryTypeMap output_types;
112   std::vector<ControlFlowInfo> cf_info;
113 };
114 
EdgeType(const Edge * e)115 DataType EdgeType(const Edge* e) {
116   if (e->IsControlEdge()) {
117     return DT_FLOAT;
118   } else {
119     return e->dst()->input_type(e->dst_input());
120   }
121 }
122 
123 // Return true iff we need to add the same device send/recv for 'edge'.
NeedSameDeviceSendRecv(const Edge * edge,const GraphInfo & info)124 bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
125   if (edge->IsControlEdge()) {
126     return false;
127   }
128 
129   const Node* src = edge->src();
130   const Node* dst = edge->dst();
131   if (src->assigned_device_name() == dst->assigned_device_name()) {
132     int src_port = edge->src_output();
133     int dst_port = edge->dst_input();
134     if (info.device_types[src->id()] != DEVICE_CPU) {
135       auto src_it = info.output_types.find({src->id(), src_port});
136       DCHECK(src_it != info.output_types.end());
137       auto dst_it = info.input_types.find({dst->id(), dst_port});
138       DCHECK(dst_it != info.input_types.end());
139       return src_it->second != dst_it->second;
140     }
141   }
142   return false;
143 }
144 
145 // Return true iff (dst, dst_input) is specified on host memory.
IsDstInputOnHost(const Edge * edge,const GraphInfo & info)146 bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
147   const Node* dst = edge->dst();
148   int dst_port = edge->dst_input();
149   if (info.device_types[dst->id()] != DEVICE_CPU) {
150     if (edge->IsControlEdge()) return false;
151     auto dst_it = info.input_types.find({dst->id(), dst_port});
152     DCHECK(dst_it != info.input_types.end());
153     return dst_it->second == HOST_MEMORY;
154   }
155   return true;
156 }
157 
158 // Add an input to dst that comes from the "src_slot" output of the
159 // node named by "src_name".
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)160 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
161   if (src_slot == Graph::kControlSlot) {
162     dst->add_input(strings::StrCat("^", src_name));
163   } else if (src_slot == 0) {
164     dst->add_input(src_name.data(), src_name.size());
165   } else {
166     dst->add_input(strings::StrCat(src_name, ":", src_slot));
167   }
168 }
169 
170 // Add a control edge from each input to each recv.
AddReadControl(const std::vector<NodeDef * > & recvs,const std::vector<string> & inputs)171 void AddReadControl(const std::vector<NodeDef*>& recvs,
172                     const std::vector<string>& inputs) {
173   for (NodeDef* recv : recvs) {
174     for (const string& input : inputs) {
175       recv->add_input(strings::StrCat("^", input));
176     }
177   }
178 }
179 
SetSendRecvAttrs(const PartitionOptions & opts,const Edge * edge,NodeDefBuilder * builder)180 void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
181                       NodeDefBuilder* builder) {
182   builder->Attr("tensor_name",
183                 strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
184   builder->Attr("send_device", edge->src()->assigned_device_name());
185   builder->Attr("send_device_incarnation",
186                 static_cast<int64>(
187                     opts.get_incarnation(edge->src()->assigned_device_name())));
188   builder->Attr("recv_device", edge->dst()->assigned_device_name());
189   builder->Attr("client_terminated", false);
190 }
191 
AddSend(const PartitionOptions & opts,const GraphInfo & g_info,GraphDef * gdef,const Edge * edge,NodeDefBuilder::NodeOut send_from,int64 start_time,Status * status)192 NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
193                  GraphDef* gdef, const Edge* edge,
194                  NodeDefBuilder::NodeOut send_from, int64 start_time,
195                  Status* status) {
196   const DataType dtype = send_from.data_type;
197   const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
198   const Node* src = edge->src();
199   const int src_port = edge->src_output();
200 
201   // host_memory = true iff we need to use HostSend/HostCast.
202   bool host_memory = false;
203   if (!edge->IsControlEdge()) {
204     auto src_it = g_info.output_types.find({src->id(), src_port});
205     DCHECK(src_it != g_info.output_types.end());
206     host_memory = (src_it->second == HOST_MEMORY);
207   }
208 
209   // Add a cast node that casts dtype to cast_dtype.
210   // NOTE(yuanbyu): Only cast for cross-device send/recv.
211   if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
212     const string cast_op = (host_memory) ? "_HostCast" : "Cast";
213     NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
214                                 NodeDebugInfo(*src));
215     cast_builder.Device(src->assigned_device_name()).Input(send_from);
216     if (opts.scheduling_for_recvs) {
217       cast_builder.Attr("_start_time", start_time);
218     }
219     cast_builder.Attr("DstT", cast_dtype);
220 
221     if (cast_dtype == DT_BFLOAT16) {
222       // the below attribute specifies that the cast to bfloat16 should use
223       // truncation. This is needed to retain legacy behavior when we change
224       // the default bfloat16 casts to use rounding instead of truncation
225       cast_builder.Attr("Truncate", true);
226     }
227 
228     NodeDef* cast = gdef->add_node();
229     *status = cast_builder.Finalize(cast);
230     if (!status->ok()) return nullptr;
231 
232     // Connect the Send op to the cast.
233     send_from.Reset(cast->name(), 0, cast_dtype);
234   }
235 
236   // Add the send node.
237   const string send_op = (host_memory) ? "_HostSend" : "_Send";
238   NodeDefBuilder send_builder(opts.new_name(src->name()), send_op,
239                               NodeDebugInfo(*src));
240   SetSendRecvAttrs(opts, edge, &send_builder);
241   send_builder.Device(src->assigned_device_name()).Input(send_from);
242   if (opts.scheduling_for_recvs) {
243     send_builder.Attr("_start_time", start_time);
244   }
245   NodeDef* send = gdef->add_node();
246   *status = send_builder.Finalize(send);
247   return send;
248 }
249 
AddRecv(const PartitionOptions & opts,const GraphInfo & g_info,GraphDef * gdef,const Edge * edge,NodeDef ** real_recv,Status * status)250 NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
251                  GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
252                  Status* status) {
253   const DataType dtype = EdgeType(edge);
254   const Node* src = edge->src();
255   const Node* dst = edge->dst();
256   const int dst_port = edge->dst_input();
257   DataType cast_dtype = dtype;
258 
259   // NOTE(yuanbyu): Only cast for cross-device send/recv.
260   if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
261     cast_dtype = opts.should_cast(edge);
262   }
263 
264   // host_memory = true iff we need to use HostRecv/HostCast.
265   bool host_memory = false;
266   if (!edge->IsControlEdge()) {
267     auto dst_it = g_info.input_types.find({dst->id(), dst_port});
268     DCHECK(dst_it != g_info.input_types.end());
269     host_memory = (dst_it->second == HOST_MEMORY);
270   }
271 
272   // Add the recv node.
273   const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
274   NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
275                               NodeDebugInfo(*src));
276   SetSendRecvAttrs(opts, edge, &recv_builder);
277   recv_builder.Device(dst->assigned_device_name())
278       .Attr("tensor_type", cast_dtype);
279   NodeDef* recv = gdef->add_node();
280   *status = recv_builder.Finalize(recv);
281   if (!status->ok()) return nullptr;
282   *real_recv = recv;
283 
284   // Add the cast node (from cast_dtype to dtype) or an Identity node.
285   if (dtype != cast_dtype) {
286     const string cast_op = (host_memory) ? "_HostCast" : "Cast";
287     NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
288                                 NodeDebugInfo(*src));
289     cast_builder.Attr("DstT", dtype);
290     cast_builder.Device(dst->assigned_device_name())
291         .Input(recv->name(), 0, cast_dtype);
292     NodeDef* cast = gdef->add_node();
293     *status = cast_builder.Finalize(cast);
294     if (!status->ok()) return nullptr;
295     return cast;
296   } else if (edge->IsControlEdge()) {
297     // An Identity is only needed for control edges.
298     NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
299                               NodeDebugInfo(*src));
300     id_builder.Device(dst->assigned_device_name())
301         .Input(recv->name(), 0, cast_dtype);
302     NodeDef* id = gdef->add_node();
303     *status = id_builder.Finalize(id);
304     if (!status->ok()) return nullptr;
305     return id;
306   } else {
307     return recv;
308   }
309 }
310 
AddDummyConst(const PartitionOptions & opts,GraphDef * gdef,const Edge * edge,Status * status)311 NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
312                        const Edge* edge, Status* status) {
313   const Node* src = edge->src();
314   Tensor tensor(DT_FLOAT, TensorShape({0}));
315   NodeDef* result = gdef->add_node();
316   *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
317                 .Device(src->assigned_device_name())
318                 .Attr("dtype", DT_FLOAT)
319                 .Attr("value", tensor)
320                 .Finalize(result);
321   return result;
322 }
323 
324 // A dummy node for scheduling.
AddControlTrigger(const PartitionOptions & opts,GraphDef * gdef,const string & assigned_device_name,int64 epoch,int64 starttime,Status * status)325 NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
326                            const string& assigned_device_name, int64 epoch,
327                            int64 starttime, Status* status) {
328   NodeDef* result = gdef->add_node();
329   *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
330                            "ControlTrigger")
331                 .Device(assigned_device_name)
332                 .Attr("_start_time", starttime)
333                 .Finalize(result);
334   return result;
335 }
336 
337 // Optimize colocation for control flow nodes. For cond, we want the
338 // switch nodes to colocate with its data input. This is particularly
339 // needed for conditional reading of a remote variable. It may also
340 // reduce the number of devices involved in a loop.
341 // TODO(yuanbyu): In this case, we don't respect the requested device in
342 // the GraphDef for these nodes. Ideally, the placer would enforce the
343 // colocation to render this unnecessary.
OptimizeControlFlowColocation(Graph * graph)344 void OptimizeControlFlowColocation(Graph* graph) {
345   auto visit = [](Node* node) {
346     if (IsSwitch(node)) {
347       for (const Edge* in_edge : node->in_edges()) {
348         if (in_edge->dst_input() == 0) {
349           // Colocate with the data input.
350           node->set_assigned_device_name(
351               in_edge->src()->assigned_device_name());
352           return;
353         }
354       }
355     } else if (IsExit(node)) {
356       for (const Edge* in_edge : node->in_edges()) {
357         if (!in_edge->IsControlEdge()) {
358           // Colocate with upstream node.
359           node->set_assigned_device_name(
360               in_edge->src()->assigned_device_name());
361           return;
362         }
363       }
364     } else {
365       if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
366           IsNextIteration(node)) {
367         const Edge* data_edge = nullptr;
368         for (const Edge* out_edge : node->out_edges()) {
369           if (!out_edge->IsControlEdge()) {
370             data_edge = out_edge;
371             break;
372           }
373         }
374         // Colocate with the first downstream data node.
375         if (data_edge) {
376           node->set_assigned_device_name(
377               data_edge->dst()->assigned_device_name());
378         }
379       }
380     }
381   };
382   DFS(*graph, visit, {});
383 }
384 
// Returns `name` prefixed with "_cloop", the prefix IsControlLoop looks for.
string ControlLoopName(const string& name) {
  string loop_name = strings::StrCat("_cloop", name);
  return loop_name;
}
388 
IsControlLoop(const Node * node)389 bool IsControlLoop(const Node* node) {
390   const string& name = node->name();
391   return str_util::StartsWith(name, "_cloop");
392 }
393 
394 // An enter node for control flow.
AddControlEnter(Graph * g,const string & node_name,const string & device_name,const string & frame_name,const int parallel_iterations,Status * status)395 Node* AddControlEnter(Graph* g, const string& node_name,
396                       const string& device_name, const string& frame_name,
397                       const int parallel_iterations, Status* status) {
398   NodeBuilder node_builder(node_name, "Enter", g->op_registry());
399   node_builder.Input({"dummy", 0, DT_FLOAT});
400   node_builder.Attr("frame_name", frame_name);
401   node_builder.Attr("parallel_iterations", parallel_iterations);
402   Node* res_node;
403   *status = node_builder.Finalize(g, &res_node);
404   if (!status->ok()) return nullptr;
405   res_node->set_assigned_device_name(device_name);
406   return res_node;
407 }
408 
409 // A merge node for control flow.
AddControlMerge(const string & in_name1,const string & in_name2,Graph * g,const string & node_name,const string & device_name,Status * status)410 Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
411                       const string& node_name, const string& device_name,
412                       Status* status) {
413   NodeBuilder node_builder(node_name, "Merge", g->op_registry());
414   node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
415   Node* res_node;
416   *status = node_builder.Finalize(g, &res_node);
417   if (!status->ok()) return nullptr;
418   res_node->set_assigned_device_name(device_name);
419   return res_node;
420 }
421 
422 // A switch node for control flow.
AddControlSwitch(NodeBuilder::NodeOut input1,NodeBuilder::NodeOut input2,const string & device_name,const GraphDefBuilder::Options & bopts)423 Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
424                        const string& device_name,
425                        const GraphDefBuilder::Options& bopts) {
426   Node* res_node =
427       ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts);
428   if (bopts.HaveError()) return nullptr;
429   res_node->set_assigned_device_name(device_name);
430   return res_node;
431 }
432 
433 // A next_iteration node for control flow.
AddControlNext(NodeBuilder::NodeOut input,const string & device_name,const GraphDefBuilder::Options & bopts)434 Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
435                      const GraphDefBuilder::Options& bopts) {
436   Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts);
437   if (bopts.HaveError()) return nullptr;
438   res_node->set_assigned_device_name(device_name);
439   return res_node;
440 }
441 
EmptyConst(const GraphDefBuilder::Options & options)442 Node* EmptyConst(const GraphDefBuilder::Options& options) {
443   if (options.HaveError()) return nullptr;
444   NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
445                            options.op_registry());
446   const DataType dt = DataTypeToEnum<float>::v();
447   TensorProto proto;
448   proto.set_dtype(dt);
449   TensorShape empty_shape({0});
450   empty_shape.AsProto(proto.mutable_tensor_shape());
451   node_builder.Attr("dtype", dt).Attr("value", proto);
452   return options.FinalizeBuilder(&node_builder);
453 }
454 
455 // A dummy const node for control flow.
AddControlConst(const string & device_name,const GraphDefBuilder::Options & bopts)456 Node* AddControlConst(const string& device_name,
457                       const GraphDefBuilder::Options& bopts) {
458   Node* res_node = EmptyConst(bopts);
459   if (bopts.HaveError()) return nullptr;
460   res_node->set_assigned_device_name(device_name);
461   return res_node;
462 }
463 
464 // A synthetic loop, made up of dummy nodes. It performs control-flow actions
465 // on behalf of a leader on a different device.
466 struct ControlLoop {
467   Node* enter = nullptr;
468   Node* merge = nullptr;
469   Node* switch_node = nullptr;
470 };
471 
472 // Add the control flow info of a new node added during partitioning.
473 // The new node has the same control flow info as src.
AddControlFlowInfo(const Node * node,const Node * src,std::vector<ControlFlowInfo> * cf_info)474 void AddControlFlowInfo(const Node* node, const Node* src,
475                         std::vector<ControlFlowInfo>* cf_info) {
476   int id = node->id();
477   if (static_cast<size_t>(id) >= cf_info->size()) {
478     cf_info->resize(id + 1);
479   }
480   const ControlFlowInfo& src_info = (*cf_info)[src->id()];
481   ControlFlowInfo* info = &(*cf_info)[id];
482   info->frame = src_info.frame;
483   info->parent_frame = src_info.parent_frame;
484   info->frame_name = src_info.frame_name;
485 }
486 
487 // Constructs a control loop. Returns a struct containing the newly created
488 // enter, merge, and switch nodes. The enter and merge nodes are used in the
489 // recursive construction of control loops for nested frames (loops). The
490 // switch node will be connected to the LoopCond node. The merge node will
491 // be connected to all the recvs of the same frame by control edges when
492 // the actual partitioning happens.
AddControlLoop(const PartitionOptions & opts,Graph * g,const Node * src,const Edge * edge,Node * loop_cond,std::vector<ControlFlowInfo> * cf_info,ControlLoop * loop)493 Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
494                       const Edge* edge, Node* loop_cond,
495                       std::vector<ControlFlowInfo>* cf_info,
496                       ControlLoop* loop) {
497   Status status;
498   GraphDefBuilder::Options bopts(g, &status);
499   const ControlFlowInfo& src_info = (*cf_info)[src->id()];
500   const string& device_name = edge->dst()->assigned_device_name();
501   const string& frame_name = src_info.frame_name;
502   int parallel_iterations;
503   status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
504                        &parallel_iterations);
505   if (!status.ok()) return status;
506 
507   // The names of the nodes to be added.
508   const string& enter_name =
509       ControlLoopName(opts.new_name(edge->dst()->name()));
510   const string& merge_name =
511       ControlLoopName(opts.new_name(edge->dst()->name()));
512   const string& switch_name =
513       ControlLoopName(opts.new_name(edge->dst()->name()));
514   const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));
515 
516   // Add the nodes to the graph g.
517   Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
518                                 parallel_iterations, &status);
519   if (!status.ok()) return status;
520   Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
521                                 device_name, &status);
522   if (!status.ok()) return status;
523   Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
524                                        bopts.WithName(switch_name));
525   if (!status.ok()) return status;
526   Node* next =
527       AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
528   if (!status.ok()) return status;
529 
530   // Add control flow info for these new nodes:
531   AddControlFlowInfo(enter, src, cf_info);
532   AddControlFlowInfo(merge, src, cf_info);
533   AddControlFlowInfo(switch_node, src, cf_info);
534   AddControlFlowInfo(next, src, cf_info);
535 
536   // Add input edges for the newly created merge node:
537   g->AddEdge(enter, 0, merge, 0);
538   g->AddEdge(next, 0, merge, 1);
539 
540   loop->enter = enter;
541   loop->merge = merge;
542   loop->switch_node = switch_node;
543   return Status::OK();
544 }
545 
546 // Build memory and device type info for every node in the graph.
547 // TODO(yuanbyu): It might be simpler if we convert MemoryType to
548 // DeviceType for the inputs/outputs of each node.
BuildMemoryDeviceInfo(const Graph & g,GraphInfo * info)549 Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
550   MemoryTypeVector input_memory_types;
551   MemoryTypeVector output_memory_types;
552 
553   info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
554   for (const Node* node : g.op_nodes()) {
555     DeviceNameUtils::ParsedName parsed;
556     if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
557                                         &parsed)) {
558       return errors::Internal("Malformed assigned device '",
559                               node->assigned_device_name(), "'");
560     }
561 
562     TF_RETURN_IF_ERROR(MemoryTypesForNode(
563         g.op_registry(), DeviceType(parsed.type), node->def(),
564         &input_memory_types, &output_memory_types));
565 
566     int node_id = node->id();
567     info->device_types[node_id] = DeviceType(parsed.type);
568     for (int i = 0; i < input_memory_types.size(); ++i) {
569       info->input_types[{node_id, i}] = input_memory_types[i];
570     }
571     for (int i = 0; i < output_memory_types.size(); ++i) {
572       info->output_types[{node_id, i}] = output_memory_types[i];
573     }
574   }
575   return Status::OK();
576 }
577 
// Returns the node standing for the frame that the inputs of `node` live in.
const Node* InputFrame(const Node* node,
                       const std::vector<ControlFlowInfo>& cf_info) {
  // An input is in the same frame as the node except for Enter nodes.
  // The input of Enter is in the parent frame of the Enter node.
  return node->IsEnter() ? cf_info[node->id()].parent_frame : node;
}
587 
// Returns the node standing for the frame that the outputs of `node` live in.
const Node* OutputFrame(const Node* node,
                        const std::vector<ControlFlowInfo>& cf_info) {
  // An output is in the same frame as the node except for Exit nodes.
  // The output of Exit is in the parent frame of the Exit node.
  return node->IsExit() ? cf_info[node->id()].parent_frame : node;
}
597 
598 // Each participating device needs to decide a) if there is a next iteration,
599 // and b) if the loop terminates. We take the approach to encode this control
600 // flow logic in the dataflow graph. There are at least two possible encodings.
601 // In a completely decentralized encoding, the participants communicate peer
602 // to peer. The other encoding uses a frame leader (the participant who owns
603 // the pivot termination predicate) to broadcast the termination condition to
604 // all the participants. For now we take the latter because it is simpler.
605 //
606 // TODO(yuanbyu): The correctness of this construction is rather subtle. I got
607 // it wrong many times so it would be nice to write a proof to be sure.
AddControlFlow(const PartitionOptions & opts,Graph * g,GraphInfo * g_info)608 Status AddControlFlow(const PartitionOptions& opts, Graph* g,
609                       GraphInfo* g_info) {
610   Status status;
611   GraphDefBuilder::Options bopts(g, &status);
612   std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;
613 
614   // Build the control flow info for every node.
615   status = BuildControlFlowInfo(g, &cf_info);
616   if (!status.ok()) return status;
617 
618   OptimizeControlFlowColocation(g);
619 
620   // The map from frames to their LoopCond nodes.
621   std::unordered_map<string, Node*> frame_cond_map;
622   int num_node_ids = g->num_node_ids();
623   for (int i = 0; i < num_node_ids; ++i) {
624     Node* node = g->FindNodeId(i);
625     if (node == nullptr) continue;
626 
627     if (IsLoopCond(node)) {
628       const string& frame_name = cf_info[node->id()].frame_name;
629       DCHECK(!frame_name.empty());
630       frame_cond_map[frame_name] = node;
631     }
632   }
633 
634   // Add all control loops for cross-device frames.
635   // A control loop is added only when there is a cross-device edge in a
636   // non-root frame. Nothing is added if there is no loops. We also don't
637   // add anything for a frame that is completely local to a device. For
638   // nested loops, we stack the control loops together by connecting
639   // the merge of the outer loop to the enter of the inner loop.
640   //
641   // A map from <frame_name, device_name> to ControlLoop.
642   std::unordered_map<string, ControlLoop> control_loops;
643   int num_edge_ids = g->num_edge_ids();
644   for (int i = 0; i < num_edge_ids; ++i) {
645     const Edge* edge = g->FindEdgeId(i);
646     if (edge == nullptr) continue;
647 
648     const Node* src = edge->src();
649     const Node* dst = edge->dst();
650     // Skip Sink/Source nodes.
651     if (!src->IsOp() || !dst->IsOp()) continue;
652 
653     const string& src_device = src->assigned_device_name();
654     const string& dst_device = dst->assigned_device_name();
655     // Skip local edges.
656     if (src_device == dst_device) continue;
657 
658     const Node* src_frame = OutputFrame(src, cf_info);
659     const Node* dst_frame = InputFrame(dst, cf_info);
660     const string& src_frame_name = cf_info[src_frame->id()].frame_name;
661     const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
662     // Skip if src and dst are not in the same frame.
663     if (src_frame_name.empty() || src_frame_name != dst_frame_name) {
664       continue;
665     }
666 
667     // Add the control loop. Start by adding the control loop for the
668     // current frame if needed, and recursively adding the control loop
669     // for its outer frame when nested.
670     ControlLoop child_loop;
671     while (true) {
672       const string& curr_frame_name = cf_info[src_frame->id()].frame_name;
673       if (curr_frame_name.empty()) {
674         // We have reached the root frame.
675         if (child_loop.merge != nullptr) {
676           const string& node_name = opts.new_name(edge->dst()->name());
677           const string& device_name = edge->dst()->assigned_device_name();
678           Node* const_node =
679               AddControlConst(device_name, bopts.WithName(node_name));
680           if (!status.ok()) return status;
681           AddControlFlowInfo(const_node, src_frame, &cf_info);
682           g->AddEdge(const_node, 0, child_loop.enter, 0);
683         }
684         break;
685       }
686 
687       const string& cl_key = strings::StrCat(curr_frame_name, "$$", dst_device);
688       auto it = control_loops.find(cl_key);
689       if (it != control_loops.end()) {
690         if (child_loop.enter != nullptr) {
691           g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
692         }
693         break;
694       }
695 
696       // Get the frame's LoopCond.
697       auto cond_it = frame_cond_map.find(curr_frame_name);
698       if (cond_it == frame_cond_map.end()) {
699         return errors::InvalidArgument(
700             "A cross-device loop must have a pivot predicate: ",
701             curr_frame_name);
702       }
703       Node* loop_cond = cond_it->second;
704 
705       // Add the control loop.
706       ControlLoop curr_loop;
707       status = AddControlLoop(opts, g, src_frame, edge, loop_cond, &cf_info,
708                               &curr_loop);
709       if (!status.ok()) return status;
710       control_loops[cl_key] = curr_loop;
711 
712       if (child_loop.enter != nullptr) {
713         // Connect the merge of the outer loop to the enter of the inner.
714         g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
715       }
716       src_frame = cf_info[src_frame->id()].parent_frame;
717       child_loop = curr_loop;
718     }
719   }
720 
721   // For a cross-device edge, on the dst device, add a control edge
722   // from the merge node of the control loop to dst. If a send/recv is
723   // introduced for this edge in future partitioning, we delete this
724   // control edge and add a new control edge from the merge to the recv.
725   num_edge_ids = g->num_edge_ids();
726   for (int i = 0; i < num_edge_ids; ++i) {
727     const Edge* edge = g->FindEdgeId(i);
728     if (edge == nullptr) continue;
729 
730     const Node* src = edge->src();
731     Node* dst = edge->dst();
732     // Skip Sink/Source nodes.
733     if (!src->IsOp() || !dst->IsOp()) continue;
734 
735     const string& src_device = src->assigned_device_name();
736     const string& dst_device = dst->assigned_device_name();
737     if (src_device != dst_device) {
738       const Node* src_frame = OutputFrame(src, cf_info);
739       const Node* dst_frame = InputFrame(dst, cf_info);
740       const string& src_frame_name = cf_info[src_frame->id()].frame_name;
741       const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
742       if (!src_frame_name.empty() && src_frame_name == dst_frame_name) {
743         const string& cl_key =
744             strings::StrCat(dst_frame_name, "$$", dst_device);
745         ControlLoop loop = control_loops[cl_key];
746         DCHECK(loop.enter != nullptr);
747         // Note that we'll create multiple duplicate edges if dst has multiple
748         // cross-device inputs. This is expected by the logic in Partition(), so
749         // it can add control edges to the recv nodes once they're created.
750         g->AddControlEdge(loop.merge, dst, /*allow_duplicates=*/true);
751       }
752     }
753   }
754   return Status::OK();
755 }
756 
757 struct PriorityTopoSortNode {
PriorityTopoSortNodetensorflow::__anon89c31b760111::PriorityTopoSortNode758   PriorityTopoSortNode(const NodeDef* n, int64 st) : node(n), start_time(st) {}
759 
760   const NodeDef* node;
761   int64 start_time;
762 };
763 
764 struct PriorityTopoSortNodeGreater {
operator ()tensorflow::__anon89c31b760111::PriorityTopoSortNodeGreater765   bool operator()(const PriorityTopoSortNode& left,
766                   const PriorityTopoSortNode& right) {
767     return left.start_time > right.start_time;
768   }
769 };
770 
771 }  // namespace
772 
773 // Returns in <nodes> the nodes that should participate in epoch-based recv
774 // scheduling, along with their times; <nodes> is ordered by increasing
775 // start_time. Returns in <node_to_start_time_out> the timing for all nodes,
776 // even those not in <nodes>.
777 //
778 // Comparing to sorting on the node's start time only, this also processes the
779 // nodes in dependency order, and updates start times to ensure a node's
780 // start_time > the start time for all dependencies.
781 //
782 // Note that graph_partition_test.cc accesses this function for testing, even
783 // though it's not declared in the header.
TopologicalSortNodesWithTimePriority(const GraphDef * gdef,std::vector<std::pair<const NodeDef *,int64>> * nodes,std::unordered_map<const NodeDef *,int64> * node_to_start_time_out)784 Status TopologicalSortNodesWithTimePriority(
785     const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes,
786     std::unordered_map<const NodeDef*, int64>* node_to_start_time_out) {
787   // Queue of nodes to process; lowest start time is returned first.
788   std::priority_queue<PriorityTopoSortNode, std::vector<PriorityTopoSortNode>,
789                       PriorityTopoSortNodeGreater>
790       q;
791   std::unordered_map<const NodeDef*, int64> node_to_start_time;
792   auto enqueue = [&q, &node_to_start_time](const NodeDef* node) {
793     const int64 start_time = node_to_start_time[node];
794     q.emplace(node, start_time);
795   };
796 
797   // Build initial structures, initial contents of queue.
798   std::unordered_map<string, std::vector<const NodeDef*>> node_to_output_nodes;
799   std::unordered_map<const NodeDef*, int> inputs_needed;
800   for (int n = 0; n < gdef->node_size(); ++n) {
801     const NodeDef* ndef = &gdef->node(n);
802     for (int i = 0; i < ndef->input_size(); ++i) {
803       node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)]
804           .push_back(ndef);
805     }
806     int64 start_time;
807     TF_RETURN_IF_ERROR(GetNodeAttr(*ndef, "_start_time", &start_time));
808     node_to_start_time[ndef] = start_time;
809     inputs_needed[ndef] = ndef->input_size();
810     if (ndef->input_size() == 0) {
811       enqueue(ndef);
812     }
813   }
814 
815   // Determine which merge nodes are parts of loops; these
816   // need to happen in the traversal after all non-NextIteration inputs
817   // are run.
818   for (int n = 0; n < gdef->node_size(); ++n) {
819     const NodeDef* ndef = &gdef->node(n);
820     if (IsNextIteration(*ndef)) {
821       for (const NodeDef* n : node_to_output_nodes[ndef->name()]) {
822         if (IsMerge(*n)) {
823           // n is a merge that is part of a loop structure.
824           // It doesn't need to wait for this NextIteration loop
825           // when doing the traversal.
826           --inputs_needed[n];
827         }
828       }
829     }
830   }
831 
832   // Traverse.
833   std::vector<std::pair<const NodeDef*, int64>> start_times;
834   start_times.reserve(gdef->node_size());
835   while (!q.empty()) {
836     PriorityTopoSortNode cur = q.top();
837     q.pop();
838 
839     start_times.emplace_back(cur.node, cur.start_time);
840 
841     for (const NodeDef* n : node_to_output_nodes[cur.node->name()]) {
842       auto& output_start_time = node_to_start_time[n];
843       if (output_start_time <= cur.start_time) {
844         output_start_time = cur.start_time + 1;
845       }
846       if (--inputs_needed[n] == 0) {
847         enqueue(n);
848       }
849     }
850   }
851 
852   // Done.
853   nodes->swap(start_times);
854   node_to_start_time_out->swap(node_to_start_time);
855   return Status::OK();
856 }
857 
AddControlEdges(const PartitionOptions & opts,std::unordered_map<string,GraphDef> * partitions)858 Status AddControlEdges(const PartitionOptions& opts,
859                        std::unordered_map<string, GraphDef>* partitions) {
860   Status status;
861   // TODO(yuanbyu): Very naive for now. To be improved.
862   const int num_epochs = 100;
863   const int prefetch = 6;
864 
865   for (auto& part : *partitions) {
866     GraphDef* gdef = &part.second;
867     std::vector<std::pair<const NodeDef*, int64>> start_times;
868     std::unordered_map<const NodeDef*, int64> node_to_start_time;
869     status = TopologicalSortNodesWithTimePriority(gdef, &start_times,
870                                                   &node_to_start_time);
871     if (!status.ok()) {
872       return status;
873     }
874 
875     // Add a dummy node for every epoch, and add a control edge from the
876     // "last" node in the preceding epoch to the dummy node.
877     string device_name = gdef->node(0).device();
878     int64 makespan = start_times.back().second;
879     int64 resolution = (makespan / num_epochs) + 1;
880 
881     int i = 0;
882     int j = 0;
883     std::vector<NodeDef*> dummys;
884     while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
885       if (i * resolution > start_times[j].second) {
886         j++;
887       } else {
888         NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
889                                            i * resolution, &status);
890         if (!status.ok()) {
891           return status;
892         }
893         dummys.push_back(dummy);
894         if (j > 0) {
895           string src_name = start_times[j - 1].first->name();
896           AddInput(dummy, src_name, Graph::kControlSlot);
897         }
898         i++;
899       }
900     }
901 
902     // Finally, add the control edges to recvs.
903     for (int n = 0; n < gdef->node_size(); ++n) {
904       NodeDef* ndef = gdef->mutable_node(n);
905       if (ndef->op() == "_Recv") {
906         const int64 start_time = node_to_start_time[ndef];
907         const int recv_epoch = start_time / resolution;
908         if (recv_epoch >= prefetch) {
909           NodeDef* dummy = dummys[recv_epoch - prefetch];
910           AddInput(ndef, dummy->name(), Graph::kControlSlot);
911         }
912       }
913     }
914   }
915   return Status::OK();
916 }
917 
918 // If 'ndef' is a Send or Recv, fills its attr send_device_incarnation
919 // if possible.
SetIncarnation(const PartitionOptions & opts,NodeDef * ndef)920 void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
921   StringPiece op(ndef->op());
922   if (op != "_Send" && op != "_Recv") {
923     // Not related to send/recv.
924     return;
925   }
926   string send_device;
927   if (!GetNodeAttr(*ndef, "send_device", &send_device).ok()) {
928     // No known send_device. The runtime will detect it later.
929     return;
930   }
931   int64 incarnation = PartitionOptions::kIllegalIncarnation;
932   if (!GetNodeAttr(*ndef, "send_device_incarnation", &incarnation).ok() ||
933       (incarnation == PartitionOptions::kIllegalIncarnation)) {
934     incarnation = opts.get_incarnation(send_device);
935     SetAttrValue(incarnation,
936                  &((*ndef->mutable_attr())["send_device_incarnation"]));
937   }
938 }
939 
940 // Sets attribute send_device_incarnation of all Send/Recv nodes in
941 // 'gdef', if possible.
SetIncarnation(const PartitionOptions & opts,GraphDef * gdef)942 void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) {
943   for (NodeDef& ndef : *gdef->mutable_node()) {
944     SetIncarnation(opts, &ndef);
945   }
946   for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) {
947     for (NodeDef& ndef : *fdef.mutable_node_def()) {
948       SetIncarnation(opts, &ndef);
949     }
950   }
951 }
952 
Partition(const PartitionOptions & opts,Graph * g,std::unordered_map<string,GraphDef> * partitions)953 Status Partition(const PartitionOptions& opts, Graph* g,
954                  std::unordered_map<string, GraphDef>* partitions) {
955   Status status;
956   partitions->clear();
957 
958   GraphInfo g_info;
959   if (!opts.control_flow_added) {
960     // Add the "code" for distributed execution of control flow. Code is
961     // added only for the frames that are placed on multiple devices. The
962     // new graph is an equivalent transformation of the original graph and
963     // has the property that it can be subsequently partitioned arbitrarily
964     // (down to the level of individual device) for distributed execution.
965     status = AddControlFlow(opts, g, &g_info);
966     if (!status.ok()) return status;
967   }
968 
969   // At this point, all the graph mutations have been done. Build memory
970   // and device type info for every node and edge in the graph.
971   status = BuildMemoryDeviceInfo(*g, &g_info);
972   if (!status.ok()) return status;
973 
974   string dstp;
975   std::vector<const Edge*> inputs;
976   DupRecvTable dup_recv(3);
977   // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
978   // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
979   // edge to dst. We will add a control edge for every pair in
980   // (ref_recvs x ref_control_inputs).
981   std::vector<NodeDef*> ref_recvs;
982   std::vector<string> ref_control_inputs;
983 
984   int32 num_data = 0;
985   int32 num_control = 0;
986   for (const Node* dst : g->op_nodes()) {
987     dstp = opts.node_to_loc(dst);
988     GraphDef* dst_graph = &(*partitions)[dstp];
989     NodeDef* dst_def = dst_graph->add_node();
990     *dst_def = dst->def();
991     MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def);
992     dst_def->set_device(dst->assigned_device_name());
993     dst_def->clear_input();  // Inputs are filled below
994     if (opts.need_to_record_start_times) {
995       int64 start_time;
996       status = GetNodeAttr(*dst_def, "_start_time", &start_time);
997       if (errors::IsNotFound(status)) {
998         start_time = opts.start_times[dst->id()].value();
999         AddNodeAttr("_start_time", start_time, dst_def);
1000       } else if (!status.ok()) {
1001         return status;
1002       }
1003     }
1004 
1005     // Arrange the incoming edges to dst so that input[i] holds the
1006     // input flowing into slot numbered i. Trailing entries in input[]
1007     // hold control edges.
1008     inputs.clear();
1009     inputs.resize(dst->num_inputs(), nullptr);
1010     ref_recvs.clear();
1011     ref_control_inputs.clear();
1012     const Edge* control_flow_edge = nullptr;
1013     int32 num_control_flow_edges = 0;
1014     int32 num_input_edges = 0;
1015     for (const Edge* edge : dst->in_edges()) {
1016       if (edge->IsControlEdge()) {
1017         if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
1018           // This is one of the control edges added for control flow. There
1019           // can be multiple such edges as the dest node may have multiple
1020           // remote inputs. We keep track of the number of such edges.
1021           control_flow_edge = edge;
1022           ++num_control_flow_edges;
1023         } else {
1024           inputs.push_back(edge);
1025         }
1026       } else {
1027         DCHECK(inputs[edge->dst_input()] == nullptr);
1028         inputs[edge->dst_input()] = edge;
1029         ++num_input_edges;
1030       }
1031     }
1032 
1033     if (num_input_edges != dst->num_inputs()) {
1034       return errors::InvalidArgument("Incomplete graph, missing ",
1035                                      (dst->num_inputs() - num_input_edges),
1036                                      " inputs for ", dst->name());
1037     }
1038 
1039     // Process in order so that all data edges are added as inputs to
1040     // dst in Edge::dst_input() order.
1041     for (const Edge* edge : inputs) {
1042       const Node* src = edge->src();
1043       if (!src->IsOp()) continue;  // Skip Sink/Source nodes.
1044 
1045       GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
1046       if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
1047         // Same partition and compatible memory types:
1048         AddInput(dst_def, src->name(), edge->src_output());
1049         if (edge->IsControlEdge() ||
1050             !IsRefType(src->output_type(edge->src_output()))) {
1051           ref_control_inputs.push_back(src->name());
1052         }
1053         continue;
1054       }
1055 
1056       int64 send_start_time = 0;
1057       int64 recv_start_time = 0;
1058       if (opts.scheduling_for_recvs) {
1059         status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
1060         if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
1061           send_start_time = opts.start_times[src->id()].value();
1062         } else if (!status.ok()) {
1063           return status;
1064         }
1065 
1066         status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
1067         if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
1068           recv_start_time = opts.start_times[dst->id()].value();
1069         } else if (!status.ok()) {
1070           return status;
1071         }
1072       }
1073 
1074       // Check whether there is already a send/recv pair transferring
1075       // the same tensor/control from the src to dst partition.
1076       const bool on_host = IsDstInputOnHost(edge, g_info);
1077       DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
1078       auto iter = dup_recv.find(key);
1079       if (iter != dup_recv.end()) {
1080         // We found one. Reuse the data/control transferred already.
1081         const string& recv_node_name = iter->second.recv->name();
1082         if (edge->IsControlEdge()) {
1083           AddInput(dst_def, recv_node_name, Graph::kControlSlot);
1084         } else {
1085           AddInput(dst_def, recv_node_name, 0);
1086         }
1087         ref_control_inputs.push_back(recv_node_name);
1088 
1089         // We want the start_time for the recv to be the smallest of the start
1090         // times of it's consumers. So we update this whenever we use a recv,
1091         // and write it out to the attribute at the end of the subroutine
1092         if (iter->second.start_time > recv_start_time) {
1093           iter->second.start_time = recv_start_time;
1094         }
1095         continue;
1096       }
1097 
1098       NodeDefBuilder::NodeOut send_from;
1099       if (edge->IsControlEdge()) {
1100         // Insert a dummy const node that will generate a tiny
1101         // data element to be sent from send to recv.
1102         VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
1103                 << src->name() << "] -> " << dst->assigned_device_name() << "["
1104                 << dst->name() << "]";
1105         NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
1106         if (!status.ok()) return status;
1107         // Set the start time for this dummy node.
1108         if (opts.scheduling_for_recvs) {
1109           AddNodeAttr("_start_time", send_start_time, dummy);
1110         }
1111         AddInput(dummy, src->name(), Graph::kControlSlot);
1112         send_from.Reset(dummy->name(), 0, DT_FLOAT);
1113       } else {
1114         send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
1115       }
1116 
1117       // Need to split edge by placing matching send/recv nodes on
1118       // the src/dst sides of the edge.
1119       NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
1120                               send_start_time, &status);
1121       if (!status.ok()) return status;
1122 
1123       NodeDef* real_recv = nullptr;
1124       NodeDef* recv =
1125           AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
1126       if (!status.ok()) return status;
1127 
1128       // Fix up the control flow edge.
1129       // NOTE(yuanbyu): 'real_recv' must be the real recv node.
1130       if (src_graph == dst_graph) {
1131         // For same device send/recv, add a control edge from send to recv.
1132         // This prevents the asynchronous recv kernel from being scheduled
1133         // before the data is available.
1134         AddInput(real_recv, send->name(), Graph::kControlSlot);
1135       } else if (control_flow_edge != nullptr) {
1136         // Redirect control edge to the real recv since this is not the same
1137         // device send/recv.
1138         --num_control_flow_edges;
1139         AddInput(real_recv, control_flow_edge->src()->name(),
1140                  Graph::kControlSlot);
1141       }
1142 
1143       if (!edge->IsControlEdge() &&
1144           IsRefType(src->output_type(edge->src_output()))) {
1145         AddNodeAttr("_start_time", recv_start_time, recv);
1146         if (real_recv != recv) {
1147           AddNodeAttr("_start_time", recv_start_time, real_recv);
1148         }
1149         // If src is of ref type and the edge is not a control edge, dst has
1150         // read semantics and therefore we must control the recv.
1151         ref_recvs.push_back(real_recv);
1152       } else {
1153         // Memorize the send/recv pair, only if this is not a "ref" edge.
1154         // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
1155         // for now we don't do it.
1156         dup_recv[key] = {recv, real_recv, recv_start_time};
1157         ref_control_inputs.push_back(recv->name());
1158       }
1159 
1160       if (edge->IsControlEdge()) {
1161         ++num_control;
1162         AddInput(dst_def, recv->name(), Graph::kControlSlot);
1163       } else {
1164         ++num_data;
1165         AddInput(dst_def, recv->name(), 0);
1166       }
1167     }
1168 
1169     // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
1170     // NOTE(yuanbyu): Adding these control edges should not introduce
1171     // deadlocks. 'dst' has implicit "read" nodes that, when we split
1172     // across devices, are made explicit; Retargeting the dependencies
1173     // to 'dst' to those nodes would not introduce cycles if there isn't
1174     // one before the transformation.
1175     // NOTE(yuanbyu): This may impact performance because it defers the
1176     // execution of recvs until all the other inputs become available.
1177     AddReadControl(ref_recvs, ref_control_inputs);
1178 
1179     // Add back the control edges for control flow that are not used.
1180     if (control_flow_edge != nullptr) {
1181       for (int i = 0; i < num_control_flow_edges; ++i) {
1182         AddInput(dst_def, control_flow_edge->src()->name(),
1183                  Graph::kControlSlot);
1184       }
1185     }
1186   }
1187 
1188   const FunctionLibraryDefinition* flib_def = opts.flib_def;
1189   if (flib_def == nullptr) {
1190     flib_def = &g->flib_def();
1191   }
1192 
1193   // Set versions, function library and send/recv incarnation.
1194   for (auto& it : *partitions) {
1195     GraphDef* gdef = &it.second;
1196     *gdef->mutable_versions() = g->versions();
1197     // Prune unreachable functions from `flib_def` before adding them to `gdef`.
1198     *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();
1199 
1200     // Traverse the graph to fill every send/recv op's incarnation
1201     // information.
1202     SetIncarnation(opts, gdef);
1203   }
1204 
1205   // Set the start times for recvs at the very end.
1206   if (opts.scheduling_for_recvs) {
1207     for (auto& it : dup_recv) {
1208       AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
1209       if (it.second.real_recv != it.second.recv) {
1210         AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
1211       }
1212     }
1213   }
1214 
1215   VLOG(1) << "Added send/recv: controls=" << num_control
1216           << ", data=" << num_data;
1217   return Status::OK();
1218 }
1219 
1220 }  // namespace tensorflow
1221