1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph_partition.h"
17 
18 #include <deque>
19 #include <queue>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/memory_types.h"
28 #include "tensorflow/core/framework/node_def_builder.h"
29 #include "tensorflow/core/framework/tensor.pb.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/graph/control_flow.h"
34 #include "tensorflow/core/graph/costmodel.h"
35 #include "tensorflow/core/graph/graph_def_builder.h"
36 #include "tensorflow/core/graph/node_builder.h"
37 #include "tensorflow/core/graph/tensor_id.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/hash/hash.h"
40 #include "tensorflow/core/lib/strings/str_util.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/util/device_name_utils.h"
43 #include "tensorflow/core/util/dump_graph.h"
44 
45 namespace tensorflow {
46 
47 namespace {
48 
IsMerge(const NodeDef & node_def)49 inline bool IsMerge(const NodeDef& node_def) {
50   return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
51          node_def.op() == "_XlaMerge";
52 }
53 
IsNextIteration(const NodeDef & node_def)54 inline bool IsNextIteration(const NodeDef& node_def) {
55   return node_def.op() == "NextIteration" ||
56          node_def.op() == "RefNextIteration";
57 }
58 
59 struct DupRecvKey {
60   int src_node_id;           // Edge's src node id
61   int src_output_slot;       // Edge's src node output slot
62   GraphDef* dst_graph;       // Edge's dst node is in this subgraph
63   bool recv_output_on_host;  // The output of recv is on host
64 
65   template <typename H>
AbslHashValue(H h,const DupRecvKey & c)66   friend H AbslHashValue(H h, const DupRecvKey& c) {
67     return H::combine(std::move(h), c.src_node_id, c.src_output_slot,
68                       reinterpret_cast<std::uintptr_t>(c.dst_graph),
69                       c.recv_output_on_host);
70   }
71 
operator ==(const DupRecvKey & x,const DupRecvKey & y)72   friend bool operator==(const DupRecvKey& x, const DupRecvKey& y) {
73     return (x.src_node_id == y.src_node_id) &&
74            (x.src_output_slot == y.src_output_slot) &&
75            (x.dst_graph == y.dst_graph) &&
76            (x.recv_output_on_host == y.recv_output_on_host);
77   }
78 };
79 
80 // struct used to store the recvs, so that start times can be properly updated
81 struct RecvInfo {
82   NodeDef* recv;
83   NodeDef* real_recv;
84   int64 start_time;
85 };
86 
87 typedef absl::flat_hash_map<DupRecvKey, RecvInfo> DupRecvTable;
88 
// Key for the maps that store memory types for the inputs/outputs of every
// node: a node id paired with an input/output index.
// TODO(power): migrate back to std::pair when absl::Hash is fixed for MSVC.
struct NodePort {
  int node_id;
  int index;

  // absl hashing hook combining both fields.
  template <typename H>
  friend H AbslHashValue(H h, const NodePort& c) {
    return H::combine(std::move(h), c.node_id, c.index);
  }

  // Two ports are equal when both the node id and the index match.
  friend bool operator==(const NodePort& lhs, const NodePort& rhs) {
    if (lhs.node_id != rhs.node_id) return false;
    return lhs.index == rhs.index;
  }
};
105 
106 typedef absl::flat_hash_map<NodePort, MemoryType> MemoryTypeMap;
107 
108 // We collect the following information about the graph before performing
109 // graph partitioning.
110 struct GraphInfo {
111   std::vector<DeviceType> device_types;
112   MemoryTypeMap input_types;
113   MemoryTypeMap output_types;
114   std::vector<ControlFlowInfo> cf_info;
115 };
116 
EdgeType(const Edge * e)117 DataType EdgeType(const Edge* e) {
118   if (e->IsControlEdge()) {
119     return DT_FLOAT;
120   } else {
121     return e->dst()->input_type(e->dst_input());
122   }
123 }
124 
125 // Return true iff we need to add the same device send/recv for 'edge'.
NeedSameDeviceSendRecv(const Edge * edge,const GraphInfo & info)126 bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
127   if (edge->IsControlEdge()) {
128     return false;
129   }
130 
131   const Node* src = edge->src();
132   const Node* dst = edge->dst();
133   if (src->assigned_device_name() == dst->assigned_device_name()) {
134     int src_port = edge->src_output();
135     int dst_port = edge->dst_input();
136     if (info.device_types[src->id()] != DEVICE_CPU) {
137       auto src_it = info.output_types.find({src->id(), src_port});
138       DCHECK(src_it != info.output_types.end());
139       auto dst_it = info.input_types.find({dst->id(), dst_port});
140       DCHECK(dst_it != info.input_types.end());
141       return src_it->second != dst_it->second;
142     }
143   }
144   return false;
145 }
146 
147 // Return true iff (dst, dst_input) is specified on host memory.
IsDstInputOnHost(const Edge * edge,const GraphInfo & info)148 bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
149   const Node* dst = edge->dst();
150   int dst_port = edge->dst_input();
151   if (info.device_types[dst->id()] != DEVICE_CPU) {
152     if (edge->IsControlEdge()) return false;
153     auto dst_it = info.input_types.find({dst->id(), dst_port});
154     DCHECK(dst_it != info.input_types.end());
155     return dst_it->second == HOST_MEMORY;
156   }
157   return true;
158 }
159 
160 // Add an input to dst that comes from the "src_slot" output of the
161 // node named by "src_name".
AddInput(NodeDef * dst,StringPiece src_name,int src_slot)162 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
163   if (src_slot == Graph::kControlSlot) {
164     dst->add_input(strings::StrCat("^", src_name));
165   } else if (src_slot == 0) {
166     dst->add_input(src_name.data(), src_name.size());
167   } else {
168     dst->add_input(strings::StrCat(src_name, ":", src_slot));
169   }
170 }
171 
172 // Add a control edge from each input to each recv.
AddReadControl(const std::vector<NodeDef * > & recvs,const std::vector<string> & inputs)173 void AddReadControl(const std::vector<NodeDef*>& recvs,
174                     const std::vector<string>& inputs) {
175   for (NodeDef* recv : recvs) {
176     for (const string& input : inputs) {
177       recv->add_input(strings::StrCat("^", input));
178     }
179   }
180 }
181 
SetSendRecvAttrs(const PartitionOptions & opts,const Edge * edge,NodeDefBuilder * builder)182 void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
183                       NodeDefBuilder* builder) {
184   builder->Attr("tensor_name",
185                 strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
186   builder->Attr("send_device", edge->src()->assigned_device_name());
187   builder->Attr("send_device_incarnation",
188                 static_cast<int64>(
189                     opts.get_incarnation(edge->src()->assigned_device_name())));
190   builder->Attr("recv_device", edge->dst()->assigned_device_name());
191   builder->Attr("client_terminated", false);
192   builder->Attr("_src", edge->src()->name());
193   builder->Attr("_dst", edge->dst()->name());
194 }
195 
AddSend(const PartitionOptions & opts,const GraphInfo & g_info,GraphDef * gdef,const Edge * edge,NodeDefBuilder::NodeOut send_from,int64_t start_time,Status * status)196 NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
197                  GraphDef* gdef, const Edge* edge,
198                  NodeDefBuilder::NodeOut send_from, int64_t start_time,
199                  Status* status) {
200   const DataType dtype = send_from.data_type;
201   const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
202   const Node* src = edge->src();
203   const int src_port = edge->src_output();
204 
205   // host_memory = true iff we need to use HostSend/HostCast.
206   bool host_memory = false;
207   if (!edge->IsControlEdge()) {
208     auto src_it = g_info.output_types.find({src->id(), src_port});
209     DCHECK(src_it != g_info.output_types.end());
210     host_memory = (src_it->second == HOST_MEMORY);
211   }
212 
213   // Add a cast node that casts dtype to cast_dtype.
214   // NOTE(yuanbyu): Only cast for cross-device send/recv.
215   if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
216     const string cast_op = (host_memory) ? "_HostCast" : "Cast";
217     NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
218                                 NodeDebugInfo(*src));
219     cast_builder.Device(src->assigned_device_name()).Input(send_from);
220     if (opts.scheduling_for_recvs) {
221       cast_builder.Attr("_start_time", start_time);
222     }
223     cast_builder.Attr("DstT", cast_dtype);
224 
225     if (cast_dtype == DT_BFLOAT16) {
226       // the below attribute specifies that the cast to bfloat16 should use
227       // truncation. This is needed to retain legacy behavior when we change
228       // the default bfloat16 casts to use rounding instead of truncation
229       cast_builder.Attr("Truncate", true);
230     }
231 
232     NodeDef* cast = gdef->add_node();
233     *status = cast_builder.Finalize(cast, /*consume=*/true);
234     if (!status->ok()) return nullptr;
235 
236     // Connect the Send op to the cast.
237     send_from.Reset(cast->name(), 0, cast_dtype);
238   }
239 
240   // Add the send node.
241   const string send_op = (host_memory) ? "_HostSend" : "_Send";
242   NodeDefBuilder send_builder(opts.new_name(src->name()), send_op,
243                               NodeDebugInfo(*src));
244   SetSendRecvAttrs(opts, edge, &send_builder);
245   send_builder.Device(src->assigned_device_name()).Input(send_from);
246   if (opts.scheduling_for_recvs) {
247     send_builder.Attr("_start_time", start_time);
248   }
249   NodeDef* send = gdef->add_node();
250   *status = send_builder.Finalize(send, /*consume=*/true);
251   return send;
252 }
253 
AddRecv(const PartitionOptions & opts,const GraphInfo & g_info,GraphDef * gdef,const Edge * edge,NodeDef ** real_recv,Status * status)254 NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
255                  GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
256                  Status* status) {
257   const DataType dtype = EdgeType(edge);
258   const Node* src = edge->src();
259   const Node* dst = edge->dst();
260   const int dst_port = edge->dst_input();
261   DataType cast_dtype = dtype;
262 
263   // NOTE(yuanbyu): Only cast for cross-device send/recv.
264   if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
265     cast_dtype = opts.should_cast(edge);
266   }
267 
268   // host_memory = true iff we need to use HostRecv/HostCast.
269   // Also log the introduction of the send-recv pair, for performance debugging.
270   bool host_memory = false;
271   if (!edge->IsControlEdge()) {
272     auto dst_it = g_info.input_types.find({dst->id(), dst_port});
273     DCHECK(dst_it != g_info.input_types.end());
274     host_memory = (dst_it->second == HOST_MEMORY);
275     bool src_host_memory = false;
276     if (VLOG_IS_ON(1)) {
277       const int src_port = edge->src_output();
278       auto src_it = g_info.output_types.find({src->id(), src_port});
279       DCHECK(src_it != g_info.output_types.end());
280       src_host_memory = (src_it->second == HOST_MEMORY);
281     }
282     VLOG(1) << "Receiving data"
283             << " from " << src->name() << " (" << src->type_string() << ")"
284             << " on " << src->assigned_device_name() << " in "
285             << (src_host_memory ? "host memory" : "device memory") << " for "
286             << dst->name() << " (" << dst->type_string() << ")"
287             << " on " << dst->assigned_device_name() << " in "
288             << (host_memory ? "host memory" : "device memory");
289   } else {
290     // Log control-edge transfers too, but don't mention memory space since it's
291     // irrelevant.
292     VLOG(1) << "Receiving control"
293             << " from " << src->name() << " (" << src->type_string() << ")"
294             << " on " << src->assigned_device_name() << " for " << dst->name()
295             << " (" << dst->type_string() << ")"
296             << " on " << dst->assigned_device_name();
297   }
298 
299   // Add the recv node.
300   const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
301   NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
302                               NodeDebugInfo(*src));
303   SetSendRecvAttrs(opts, edge, &recv_builder);
304   recv_builder.Device(dst->assigned_device_name())
305       .Attr("tensor_type", cast_dtype);
306   NodeDef* recv = gdef->add_node();
307   *status = recv_builder.Finalize(recv, /*consume=*/true);
308   if (!status->ok()) return nullptr;
309   *real_recv = recv;
310 
311   // Add the cast node (from cast_dtype to dtype) or an Identity node.
312   if (dtype != cast_dtype) {
313     const string cast_op = (host_memory) ? "_HostCast" : "Cast";
314     NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
315                                 NodeDebugInfo(*src));
316     cast_builder.Attr("DstT", dtype);
317     cast_builder.Device(dst->assigned_device_name())
318         .Input(recv->name(), 0, cast_dtype);
319     NodeDef* cast = gdef->add_node();
320     *status = cast_builder.Finalize(cast, /*consume=*/true);
321     if (!status->ok()) return nullptr;
322     return cast;
323   } else if (edge->IsControlEdge()) {
324     // An Identity is only needed for control edges.
325     NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
326                               NodeDebugInfo(*src));
327     id_builder.Device(dst->assigned_device_name())
328         .Input(recv->name(), 0, cast_dtype);
329     NodeDef* id = gdef->add_node();
330     *status = id_builder.Finalize(id, /*consume=*/true);
331     if (!status->ok()) return nullptr;
332     return id;
333   } else {
334     return recv;
335   }
336 }
337 
AddDummyConst(const PartitionOptions & opts,GraphDef * gdef,const Edge * edge,Status * status)338 NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
339                        const Edge* edge, Status* status) {
340   const Node* src = edge->src();
341   Tensor tensor(DT_FLOAT, TensorShape({0}));
342   NodeDef* result = gdef->add_node();
343   *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
344                 .Device(src->assigned_device_name())
345                 .Attr("dtype", DT_FLOAT)
346                 .Attr("value", tensor)
347                 .Finalize(result, /*consume=*/true);
348   return result;
349 }
350 
351 // A dummy node for scheduling.
AddControlTrigger(const PartitionOptions & opts,GraphDef * gdef,const string & assigned_device_name,int64_t epoch,int64_t starttime,Status * status)352 NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
353                            const string& assigned_device_name, int64_t epoch,
354                            int64_t starttime, Status* status) {
355   NodeDef* result = gdef->add_node();
356   *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
357                            "ControlTrigger")
358                 .Device(assigned_device_name)
359                 .Attr("_start_time", starttime)
360                 .Finalize(result, /*consume=*/true);
361   return result;
362 }
363 
364 // Optimize colocation for control flow nodes. For cond, we want the
365 // switch nodes to colocate with its data input. This is particularly
366 // needed for conditional reading of a remote variable. It may also
367 // reduce the number of devices involved in a loop.
368 // TODO(yuanbyu): In this case, we don't respect the requested device in
369 // the GraphDef for these nodes. Ideally, the placer would enforce the
370 // colocation to render this unnecessary.
OptimizeControlFlowColocation(Graph * graph)371 void OptimizeControlFlowColocation(Graph* graph) {
372   auto visit = [](Node* node) {
373     if (IsSwitch(node)) {
374       for (const Edge* in_edge : node->in_edges()) {
375         if (in_edge->dst_input() == 0) {
376           // Colocate with the data input.
377           node->set_assigned_device_name(
378               in_edge->src()->assigned_device_name());
379           return;
380         }
381       }
382     } else if (IsExit(node)) {
383       for (const Edge* in_edge : node->in_edges()) {
384         if (!in_edge->IsControlEdge()) {
385           // Colocate with upstream node.
386           node->set_assigned_device_name(
387               in_edge->src()->assigned_device_name());
388           return;
389         }
390       }
391     } else {
392       if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
393           IsNextIteration(node)) {
394         const Edge* data_edge = nullptr;
395         for (const Edge* out_edge : node->out_edges()) {
396           if (!out_edge->IsControlEdge()) {
397             data_edge = out_edge;
398             break;
399           }
400         }
401         // Colocate with the first downstream data node.
402         if (data_edge) {
403           node->set_assigned_device_name(
404               data_edge->dst()->assigned_device_name());
405         }
406       }
407     }
408   };
409   DFS(*graph, visit, {});
410 }
411 
// Returns `name` with the "_cloop" prefix that marks control loop nodes.
string ControlLoopName(const string& name) {
  string loop_name = strings::StrCat("_cloop", name);
  return loop_name;
}
415 
IsControlLoop(const Node * node)416 bool IsControlLoop(const Node* node) {
417   const string& name = node->name();
418   return absl::StartsWith(name, "_cloop");
419 }
420 
421 // An enter node for control flow.
AddControlEnter(Graph * g,const string & node_name,const string & device_name,const string & frame_name,const int parallel_iterations,Status * status)422 Node* AddControlEnter(Graph* g, const string& node_name,
423                       const string& device_name, const string& frame_name,
424                       const int parallel_iterations, Status* status) {
425   NodeBuilder node_builder(node_name, "Enter", g->op_registry());
426   node_builder.Input({"dummy", 0, DT_FLOAT});
427   node_builder.Attr("frame_name", frame_name);
428   node_builder.Attr("parallel_iterations", parallel_iterations);
429   Node* res_node;
430   *status = node_builder.Finalize(g, &res_node, /*consume=*/true);
431   if (!status->ok()) return nullptr;
432   res_node->set_assigned_device_name(device_name);
433   return res_node;
434 }
435 
436 // A merge node for control flow.
AddControlMerge(const string & in_name1,const string & in_name2,Graph * g,const string & node_name,const string & device_name,Status * status)437 Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
438                       const string& node_name, const string& device_name,
439                       Status* status) {
440   NodeBuilder node_builder(node_name, "Merge", g->op_registry());
441   node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
442   Node* res_node;
443   *status = node_builder.Finalize(g, &res_node, /*consume=*/true);
444   if (!status->ok()) return nullptr;
445   res_node->set_assigned_device_name(device_name);
446   return res_node;
447 }
448 
449 // A switch node for control flow.
AddControlSwitch(NodeBuilder::NodeOut input1,NodeBuilder::NodeOut input2,const string & device_name,const GraphDefBuilder::Options & bopts)450 Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
451                        const string& device_name,
452                        const GraphDefBuilder::Options& bopts) {
453   Node* res_node =
454       ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts);
455   if (bopts.HaveError()) return nullptr;
456   res_node->set_assigned_device_name(device_name);
457   return res_node;
458 }
459 
460 // A next_iteration node for control flow.
AddControlNext(NodeBuilder::NodeOut input,const string & device_name,const GraphDefBuilder::Options & bopts)461 Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
462                      const GraphDefBuilder::Options& bopts) {
463   Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts);
464   if (bopts.HaveError()) return nullptr;
465   res_node->set_assigned_device_name(device_name);
466   return res_node;
467 }
468 
EmptyConst(const GraphDefBuilder::Options & options)469 Node* EmptyConst(const GraphDefBuilder::Options& options) {
470   if (options.HaveError()) return nullptr;
471   NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
472                            options.op_registry());
473   const DataType dt = DataTypeToEnum<float>::v();
474   TensorProto proto;
475   proto.set_dtype(dt);
476   TensorShape empty_shape({0});
477   empty_shape.AsProto(proto.mutable_tensor_shape());
478   node_builder.Attr("dtype", dt).Attr("value", proto);
479   return options.FinalizeBuilder(&node_builder);
480 }
481 
482 // A dummy const node for control flow.
AddControlConst(const string & device_name,const GraphDefBuilder::Options & bopts)483 Node* AddControlConst(const string& device_name,
484                       const GraphDefBuilder::Options& bopts) {
485   Node* res_node = EmptyConst(bopts);
486   if (bopts.HaveError()) return nullptr;
487   res_node->set_assigned_device_name(device_name);
488   return res_node;
489 }
490 
491 // A synthetic loop, made up of dummy nodes. It performs control-flow actions
492 // on behalf of a leader on a different device.
493 struct ControlLoop {
494   Node* enter = nullptr;
495   Node* merge = nullptr;
496   Node* switch_node = nullptr;
497 };
498 
499 // Add the control flow info of a new node added during partitioning.
500 // The new node has the same control flow info as src.
AddControlFlowInfo(const Node * node,const Node * src,std::vector<ControlFlowInfo> * cf_info)501 void AddControlFlowInfo(const Node* node, const Node* src,
502                         std::vector<ControlFlowInfo>* cf_info) {
503   int id = node->id();
504   if (static_cast<size_t>(id) >= cf_info->size()) {
505     cf_info->resize(id + 1);
506   }
507   const ControlFlowInfo& src_info = (*cf_info)[src->id()];
508   ControlFlowInfo* info = &(*cf_info)[id];
509   info->frame = src_info.frame;
510   info->parent_frame = src_info.parent_frame;
511   info->frame_name = src_info.frame_name;
512 }
513 
514 // Constructs a control loop. Returns a struct containing the newly created
515 // enter, merge, and switch nodes. The enter and merge nodes are used in the
516 // recursive construction of control loops for nested frames (loops). The
517 // switch node will be connected to the LoopCond node. The merge node will
518 // be connected to all the recvs of the same frame by control edges when
519 // the actual partitioning happens.
AddControlLoop(const PartitionOptions & opts,Graph * g,const Node * src,const Edge * edge,Node * loop_cond,std::vector<ControlFlowInfo> * cf_info,ControlLoop * loop)520 Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
521                       const Edge* edge, Node* loop_cond,
522                       std::vector<ControlFlowInfo>* cf_info,
523                       ControlLoop* loop) {
524   Status status;
525   GraphDefBuilder::Options bopts(g, &status);
526   const ControlFlowInfo& src_info = (*cf_info)[src->id()];
527   const string& device_name = edge->dst()->assigned_device_name();
528   const string& frame_name = src_info.frame_name;
529   int parallel_iterations;
530   status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
531                        ¶llel_iterations);
532   if (!status.ok()) return status;
533 
534   // The names of the nodes to be added.
535   const string& enter_name =
536       ControlLoopName(opts.new_name(edge->dst()->name()));
537   const string& merge_name =
538       ControlLoopName(opts.new_name(edge->dst()->name()));
539   const string& switch_name =
540       ControlLoopName(opts.new_name(edge->dst()->name()));
541   const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));
542 
543   // Add the nodes to the graph g.
544   Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
545                                 parallel_iterations, &status);
546   if (!status.ok()) return status;
547   Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
548                                 device_name, &status);
549   if (!status.ok()) return status;
550   Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
551                                        bopts.WithName(switch_name));
552   if (!status.ok()) return status;
553   Node* next =
554       AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
555   if (!status.ok()) return status;
556 
557   // Add control flow info for these new nodes:
558   AddControlFlowInfo(enter, src, cf_info);
559   AddControlFlowInfo(merge, src, cf_info);
560   AddControlFlowInfo(switch_node, src, cf_info);
561   AddControlFlowInfo(next, src, cf_info);
562 
563   // Add input edges for the newly created merge node:
564   g->AddEdge(enter, 0, merge, 0);
565   g->AddEdge(next, 0, merge, 1);
566 
567   loop->enter = enter;
568   loop->merge = merge;
569   loop->switch_node = switch_node;
570   return Status::OK();
571 }
572 
573 // Build memory and device type info for every node in the graph.
574 // TODO(yuanbyu): It might be simpler if we convert MemoryType to
575 // DeviceType for the inputs/outputs of each node.
BuildMemoryDeviceInfo(const Graph & g,GraphInfo * info)576 Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
577   MemoryTypeVector input_memory_types;
578   MemoryTypeVector output_memory_types;
579 
580   info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
581   for (const Node* node : g.op_nodes()) {
582     DeviceNameUtils::ParsedName parsed;
583     if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
584                                         &parsed)) {
585       return errors::Internal("Malformed assigned device '",
586                               node->assigned_device_name(), "'");
587     }
588 
589     TF_RETURN_IF_ERROR(MemoryTypesForNode(
590         g.op_registry(), DeviceType(parsed.type), node->def(),
591         &input_memory_types, &output_memory_types));
592 
593     int node_id = node->id();
594     info->device_types[node_id] = DeviceType(parsed.type);
595     for (int i = 0; i < input_memory_types.size(); ++i) {
596       info->input_types[{node_id, i}] = input_memory_types[i];
597     }
598     for (int i = 0; i < output_memory_types.size(); ++i) {
599       info->output_types[{node_id, i}] = output_memory_types[i];
600     }
601   }
602   return Status::OK();
603 }
604 
// Returns the node used to identify the frame of the inputs of `node`.
// An input is in the same frame as the node except for Enter nodes: the
// input of Enter is in the parent frame of the Enter node.
const Node* InputFrame(const Node* node,
                       const std::vector<ControlFlowInfo>& cf_info) {
  return node->IsEnter() ? cf_info[node->id()].parent_frame : node;
}
614 
// Returns the node used to identify the frame of the outputs of `node`.
// An output is in the same frame as the node except for Exit nodes: the
// output of Exit is in the parent frame of the Exit node.
const Node* OutputFrame(const Node* node,
                        const std::vector<ControlFlowInfo>& cf_info) {
  return node->IsExit() ? cf_info[node->id()].parent_frame : node;
}
624 
625 // Each participating device needs to decide a) if there is a next iteration,
626 // and b) if the loop terminates. We take the approach to encode this control
627 // flow logic in the dataflow graph. There are at least two possible encodings.
628 // In a completely decentralized encoding, the participants communicate peer
629 // to peer. The other encoding uses a frame leader (the participant who owns
630 // the pivot termination predicate) to broadcast the termination condition to
631 // all the participants. For now we take the latter because it is simpler.
632 //
633 // TODO(yuanbyu): The correctness of this construction is rather subtle. I got
634 // it wrong many times so it would be nice to write a proof to be sure.
AddControlFlow(const PartitionOptions & opts,Graph * g,GraphInfo * g_info)635 Status AddControlFlow(const PartitionOptions& opts, Graph* g,
636                       GraphInfo* g_info) {
637   Status status;
638   GraphDefBuilder::Options bopts(g, &status);
639   std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;
640 
641   // Build the control flow info for every node.
642   status = BuildControlFlowInfo(g, &cf_info);
643   if (!status.ok()) return status;
644 
645   OptimizeControlFlowColocation(g);
646 
647   // The map from frames to their LoopCond nodes.
648   std::unordered_map<string, Node*> frame_cond_map;
649   int num_node_ids = g->num_node_ids();
650   for (int i = 0; i < num_node_ids; ++i) {
651     Node* node = g->FindNodeId(i);
652     if (node == nullptr) continue;
653 
654     if (IsLoopCond(node)) {
655       const string& frame_name = cf_info[node->id()].frame_name;
656       DCHECK(!frame_name.empty());
657       frame_cond_map[frame_name] = node;
658     }
659   }
660 
661   // Add all control loops for cross-device frames.
662   // A control loop is added only when there is a cross-device edge in a
663   // non-root frame. Nothing is added if there is no loops. We also don't
664   // add anything for a frame that is completely local to a device. For
665   // nested loops, we stack the control loops together by connecting
666   // the merge of the outer loop to the enter of the inner loop.
667   //
668   // A map from <frame_name, device_name> to ControlLoop.
669   std::unordered_map<string, ControlLoop> control_loops;
670   int num_edge_ids = g->num_edge_ids();
671   for (int i = 0; i < num_edge_ids; ++i) {
672     const Edge* edge = g->FindEdgeId(i);
673     if (edge == nullptr) continue;
674 
675     const Node* src = edge->src();
676     const Node* dst = edge->dst();
677     // Skip Sink/Source nodes.
678     if (!src->IsOp() || !dst->IsOp()) continue;
679 
680     const string& src_device = src->assigned_device_name();
681     const string& dst_device = dst->assigned_device_name();
682     // Skip local edges.
683     if (src_device == dst_device) continue;
684 
685     const Node* src_frame = OutputFrame(src, cf_info);
686     const Node* dst_frame = InputFrame(dst, cf_info);
687     const string& src_frame_name = cf_info[src_frame->id()].frame_name;
688     const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
689     // Skip if src and dst are not in the same frame.
690     if (src_frame_name.empty() || src_frame_name != dst_frame_name) {
691       continue;
692     }
693 
694     // Add the control loop. Start by adding the control loop for the
695     // current frame if needed, and recursively adding the control loop
696     // for its outer frame when nested.
697     ControlLoop child_loop;
698     while (true) {
699       const string& curr_frame_name = cf_info[src_frame->id()].frame_name;
700       if (curr_frame_name.empty()) {
701         // We have reached the root frame.
702         if (child_loop.merge != nullptr) {
703           const string& node_name = opts.new_name(edge->dst()->name());
704           const string& device_name = edge->dst()->assigned_device_name();
705           Node* const_node =
706               AddControlConst(device_name, bopts.WithName(node_name));
707           if (!status.ok()) return status;
708           AddControlFlowInfo(const_node, src_frame, &cf_info);
709           g->AddEdge(const_node, 0, child_loop.enter, 0);
710         }
711         break;
712       }
713 
714       const string& cl_key = strings::StrCat(curr_frame_name, "$$", dst_device);
715       auto it = control_loops.find(cl_key);
716       if (it != control_loops.end()) {
717         if (child_loop.enter != nullptr) {
718           g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
719         }
720         break;
721       }
722 
723       // Get the frame's LoopCond.
724       auto cond_it = frame_cond_map.find(curr_frame_name);
725       if (cond_it == frame_cond_map.end()) {
726         return errors::InvalidArgument(
727             "A cross-device loop must have a pivot predicate: ",
728             curr_frame_name);
729       }
730       Node* loop_cond = cond_it->second;
731 
732       // Add the control loop.
733       ControlLoop curr_loop;
734       status = AddControlLoop(opts, g, src_frame, edge, loop_cond, &cf_info,
735                               &curr_loop);
736       if (!status.ok()) return status;
737       control_loops[cl_key] = curr_loop;
738 
739       if (child_loop.enter != nullptr) {
740         // Connect the merge of the outer loop to the enter of the inner.
741         g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
742       }
743       src_frame = cf_info[src_frame->id()].parent_frame;
744       child_loop = curr_loop;
745     }
746   }
747 
748   // For a cross-device edge, on the dst device, add a control edge
749   // from the merge node of the control loop to dst. If a send/recv is
750   // introduced for this edge in future partitioning, we delete this
751   // control edge and add a new control edge from the merge to the recv.
752   num_edge_ids = g->num_edge_ids();
753   for (int i = 0; i < num_edge_ids; ++i) {
754     const Edge* edge = g->FindEdgeId(i);
755     if (edge == nullptr) continue;
756 
757     const Node* src = edge->src();
758     Node* dst = edge->dst();
759     // Skip Sink/Source nodes.
760     if (!src->IsOp() || !dst->IsOp()) continue;
761 
762     const string& src_device = src->assigned_device_name();
763     const string& dst_device = dst->assigned_device_name();
764     if (src_device != dst_device) {
765       const Node* src_frame = OutputFrame(src, cf_info);
766       const Node* dst_frame = InputFrame(dst, cf_info);
767       const string& src_frame_name = cf_info[src_frame->id()].frame_name;
768       const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
769       if (!src_frame_name.empty() && src_frame_name == dst_frame_name) {
770         const string& cl_key =
771             strings::StrCat(dst_frame_name, "$$", dst_device);
772         ControlLoop loop = control_loops[cl_key];
773         DCHECK(loop.enter != nullptr);
774         // Note that we'll create multiple duplicate edges if dst has multiple
775         // cross-device inputs. This is expected by the logic in Partition(), so
776         // it can add control edges to the recv nodes once they're created.
777         g->AddControlEdge(loop.merge, dst, /*allow_duplicates=*/true);
778       }
779     }
780   }
781   return Status::OK();
782 }
783 
784 struct PriorityTopoSortNode {
PriorityTopoSortNodetensorflow::__anonbceffdd90111::PriorityTopoSortNode785   PriorityTopoSortNode(const NodeDef* n, int64_t st)
786       : node(n), start_time(st) {}
787 
788   const NodeDef* node;
789   int64 start_time;
790 };
791 
792 struct PriorityTopoSortNodeGreater {
operator ()tensorflow::__anonbceffdd90111::PriorityTopoSortNodeGreater793   bool operator()(const PriorityTopoSortNode& left,
794                   const PriorityTopoSortNode& right) {
795     return left.start_time > right.start_time;
796   }
797 };
798 
799 }  // namespace
800 
801 // Returns in <nodes> the nodes that should participate in epoch-based recv
802 // scheduling, along with their times; <nodes> is ordered by increasing
803 // start_time. Returns in <node_to_start_time_out> the timing for all nodes,
804 // even those not in <nodes>.
805 //
806 // Comparing to sorting on the node's start time only, this also processes the
807 // nodes in dependency order, and updates start times to ensure a node's
808 // start_time > the start time for all dependencies.
809 //
810 // Note that graph_partition_test.cc accesses this function for testing, even
811 // though it's not declared in the header.
TopologicalSortNodesWithTimePriority(const GraphDef * gdef,std::vector<std::pair<const NodeDef *,int64>> * nodes,std::unordered_map<const NodeDef *,int64> * node_to_start_time_out)812 Status TopologicalSortNodesWithTimePriority(
813     const GraphDef* gdef, std::vector<std::pair<const NodeDef*, int64>>* nodes,
814     std::unordered_map<const NodeDef*, int64>* node_to_start_time_out) {
815   // Queue of nodes to process; lowest start time is returned first.
816   std::priority_queue<PriorityTopoSortNode, std::vector<PriorityTopoSortNode>,
817                       PriorityTopoSortNodeGreater>
818       q;
819   std::unordered_map<const NodeDef*, int64> node_to_start_time;
820   auto enqueue = [&q, &node_to_start_time](const NodeDef* node) {
821     const int64_t start_time = node_to_start_time[node];
822     q.emplace(node, start_time);
823   };
824 
825   // Build initial structures, initial contents of queue.
826   std::unordered_map<string, std::vector<const NodeDef*>> node_to_output_nodes;
827   std::unordered_map<const NodeDef*, int> inputs_needed;
828   for (int n = 0; n < gdef->node_size(); ++n) {
829     const NodeDef* ndef = &gdef->node(n);
830     for (int i = 0; i < ndef->input_size(); ++i) {
831       node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)]
832           .push_back(ndef);
833     }
834     int64_t start_time;
835     TF_RETURN_IF_ERROR(GetNodeAttr(*ndef, "_start_time", &start_time));
836     node_to_start_time[ndef] = start_time;
837     inputs_needed[ndef] = ndef->input_size();
838     if (ndef->input_size() == 0) {
839       enqueue(ndef);
840     }
841   }
842 
843   // Determine which merge nodes are parts of loops; these
844   // need to happen in the traversal after all non-NextIteration inputs
845   // are run.
846   for (int n = 0; n < gdef->node_size(); ++n) {
847     const NodeDef* ndef = &gdef->node(n);
848     if (IsNextIteration(*ndef)) {
849       for (const NodeDef* n : node_to_output_nodes[ndef->name()]) {
850         if (IsMerge(*n)) {
851           // n is a merge that is part of a loop structure.
852           // It doesn't need to wait for this NextIteration loop
853           // when doing the traversal.
854           --inputs_needed[n];
855         }
856       }
857     }
858   }
859 
860   // Traverse.
861   std::vector<std::pair<const NodeDef*, int64>> start_times;
862   start_times.reserve(gdef->node_size());
863   while (!q.empty()) {
864     PriorityTopoSortNode cur = q.top();
865     q.pop();
866 
867     start_times.emplace_back(cur.node, cur.start_time);
868 
869     for (const NodeDef* n : node_to_output_nodes[cur.node->name()]) {
870       auto& output_start_time = node_to_start_time[n];
871       if (output_start_time <= cur.start_time) {
872         output_start_time = cur.start_time + 1;
873       }
874       if (--inputs_needed[n] == 0) {
875         enqueue(n);
876       }
877     }
878   }
879 
880   // Done.
881   nodes->swap(start_times);
882   node_to_start_time_out->swap(node_to_start_time);
883   return Status::OK();
884 }
885 
AddControlEdges(const PartitionOptions & opts,std::unordered_map<string,GraphDef> * partitions)886 Status AddControlEdges(const PartitionOptions& opts,
887                        std::unordered_map<string, GraphDef>* partitions) {
888   Status status;
889   // TODO(yuanbyu): Very naive for now. To be improved.
890   const int num_epochs = 100;
891   const int prefetch = 6;
892 
893   for (auto& part : *partitions) {
894     GraphDef* gdef = &part.second;
895     std::vector<std::pair<const NodeDef*, int64>> start_times;
896     std::unordered_map<const NodeDef*, int64> node_to_start_time;
897     status = TopologicalSortNodesWithTimePriority(gdef, &start_times,
898                                                   &node_to_start_time);
899     if (!status.ok()) {
900       return status;
901     }
902 
903     // Add a dummy node for every epoch, and add a control edge from the
904     // "last" node in the preceding epoch to the dummy node.
905     string device_name = gdef->node(0).device();
906     int64_t makespan = start_times.back().second;
907     int64_t resolution = (makespan / num_epochs) + 1;
908 
909     int i = 0;
910     int j = 0;
911     std::vector<NodeDef*> dummys;
912     while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
913       if (i * resolution > start_times[j].second) {
914         j++;
915       } else {
916         NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
917                                            i * resolution, &status);
918         if (!status.ok()) {
919           return status;
920         }
921         dummys.push_back(dummy);
922         if (j > 0) {
923           string src_name = start_times[j - 1].first->name();
924           AddInput(dummy, src_name, Graph::kControlSlot);
925         }
926         i++;
927       }
928     }
929 
930     // Finally, add the control edges to recvs.
931     for (int n = 0; n < gdef->node_size(); ++n) {
932       NodeDef* ndef = gdef->mutable_node(n);
933       if (ndef->op() == "_Recv") {
934         const int64_t start_time = node_to_start_time[ndef];
935         const int recv_epoch = start_time / resolution;
936         if (recv_epoch >= prefetch) {
937           NodeDef* dummy = dummys[recv_epoch - prefetch];
938           AddInput(ndef, dummy->name(), Graph::kControlSlot);
939         }
940       }
941     }
942   }
943   return Status::OK();
944 }
945 
946 // If 'ndef' is a Send or Recv, fills its attr send_device_incarnation
947 // if possible.
SetIncarnation(const PartitionOptions & opts,NodeDef * ndef)948 void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
949   StringPiece op(ndef->op());
950   if (op != "_Send" && op != "_Recv") {
951     // Not related to send/recv.
952     return;
953   }
954   const string& send_device = GetNodeAttrString(*ndef, "send_device");
955   if (send_device.empty()) {
956     // No known send_device. The runtime will detect it later.
957     return;
958   }
959   int64_t incarnation = PartitionOptions::kIllegalIncarnation;
960   if (!TryGetNodeAttr(*ndef, "send_device_incarnation", &incarnation) ||
961       (incarnation == PartitionOptions::kIllegalIncarnation)) {
962     incarnation = opts.get_incarnation(send_device);
963     SetAttrValue(incarnation,
964                  &((*ndef->mutable_attr())["send_device_incarnation"]));
965   }
966 }
967 
968 // Sets attribute send_device_incarnation of all Send/Recv nodes in
969 // 'gdef', if possible.
SetIncarnation(const PartitionOptions & opts,GraphDef * gdef)970 void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) {
971   for (NodeDef& ndef : *gdef->mutable_node()) {
972     SetIncarnation(opts, &ndef);
973   }
974   for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) {
975     for (NodeDef& ndef : *fdef.mutable_node_def()) {
976       SetIncarnation(opts, &ndef);
977     }
978   }
979 }
980 
// Splits graph `g` into one GraphDef per location, as reported by
// opts.node_to_loc(), and stores the results in `partitions` (cleared first).
//
// Every op node is copied into the GraphDef of its location. An incoming edge
// is kept as a plain input when source and destination share a partition and
// NeedSameDeviceSendRecv() is false for that edge; otherwise the edge is
// split by a send node in the source partition and a recv node in the
// destination partition. Non-ref edges carrying the same source output to the
// same partition share a single send/recv pair.
//
// Unless opts.control_flow_added is set, AddControlFlow() is first run on
// `g`, so `g` may be modified by this call.
//
// After the nodes are placed, each partition gets the versions of `g`, the
// function definitions reachable from it (from opts.flib_def, or from the
// library of `g` when that is null) and send/recv incarnations.
//
// Returns the first error encountered, otherwise OK.
Status Partition(const PartitionOptions& opts, Graph* g,
                 std::unordered_map<string, GraphDef>* partitions) {
  Status status;
  partitions->clear();

  GraphInfo g_info;
  if (!opts.control_flow_added) {
    // Add the "code" for distributed execution of control flow. Code is
    // added only for the frames that are placed on multiple devices. The
    // new graph is an equivalent transformation of the original graph and
    // has the property that it can be subsequently partitioned arbitrarily
    // (down to the level of individual device) for distributed execution.
    status = AddControlFlow(opts, g, &g_info);
    if (!status.ok()) return status;
  }

  // At this point, all the graph mutations have been done. Build memory
  // and device type info for every node and edge in the graph.
  status = BuildMemoryDeviceInfo(*g, &g_info);
  if (!status.ok()) return status;

  string dstp;
  std::vector<const Edge*> inputs;
  // Recvs created so far, keyed by (src node id, src output, dst partition,
  // on-host flag), so later edges carrying the same value can reuse them.
  DupRecvTable dup_recv(3);
  // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
  // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
  // edge to dst. We will add a control edge for every pair in
  // (ref_recvs x ref_control_inputs).
  std::vector<NodeDef*> ref_recvs;
  std::vector<string> ref_control_inputs;

  // Number of send/recv pairs created for data and for control edges; only
  // used for the log message at the end.
  int32_t num_data = 0;
  int32_t num_control = 0;
  for (const Node* dst : g->op_nodes()) {
    dstp = opts.node_to_loc(dst);
    GraphDef* dst_graph = &(*partitions)[dstp];
    NodeDef* dst_def = dst_graph->add_node();
    *dst_def = dst->def();
    MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def);
    dst_def->set_device(dst->assigned_device_name());
    dst_def->clear_input();  // Inputs are filled below
    if (opts.need_to_record_start_times) {
      // Keep an existing "_start_time" attr; otherwise take the value from
      // opts.start_times.
      int64_t start_time;
      status = GetNodeAttr(*dst_def, "_start_time", &start_time);
      if (errors::IsNotFound(status)) {
        start_time = opts.start_times[dst->id()].value();
        AddNodeAttr("_start_time", start_time, dst_def);
      } else if (!status.ok()) {
        return status;
      }
    }

    // Arrange the incoming edges to dst so that input[i] holds the
    // input flowing into slot numbered i. Trailing entries in input[]
    // hold control edges.
    inputs.clear();
    inputs.resize(dst->num_inputs(), nullptr);
    ref_recvs.clear();
    ref_control_inputs.clear();
    // Last control-loop control edge seen for dst (if any) and how many such
    // edges dst has.
    const Edge* control_flow_edge = nullptr;
    int32_t num_control_flow_edges = 0;
    int32_t num_input_edges = 0;
    for (const Edge* edge : dst->in_edges()) {
      if (edge->IsControlEdge()) {
        if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
          // This is one of the control edges added for control flow. There
          // can be multiple such edges as the dest node may have multiple
          // remote inputs. We keep track of the number of such edges.
          control_flow_edge = edge;
          ++num_control_flow_edges;
        } else {
          inputs.push_back(edge);
        }
      } else {
        DCHECK(inputs[edge->dst_input()] == nullptr);
        inputs[edge->dst_input()] = edge;
        ++num_input_edges;
      }
    }

    // Every data input slot of dst must be fed by an edge.
    if (num_input_edges != dst->num_inputs()) {
      return errors::InvalidArgument("Incomplete graph, missing ",
                                     (dst->num_inputs() - num_input_edges),
                                     " inputs for ", dst->name());
    }

    // Process in order so that all data edges are added as inputs to
    // dst in Edge::dst_input() order.
    for (const Edge* edge : inputs) {
      const Node* src = edge->src();
      if (!src->IsOp()) continue;  // Skip Sink/Source nodes.

      GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
      if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
        // Same partition and compatible memory types:
        AddInput(dst_def, src->name(), edge->src_output());
        if (edge->IsControlEdge() ||
            !IsRefType(src->output_type(edge->src_output()))) {
          ref_control_inputs.push_back(src->name());
        }
        continue;
      }

      // Start times stay 0 unless recv scheduling is requested. A missing
      // "_start_time" attr is only tolerated when the start times are
      // supplied through opts.start_times.
      int64_t send_start_time = 0;
      int64_t recv_start_time = 0;
      if (opts.scheduling_for_recvs) {
        status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          send_start_time = opts.start_times[src->id()].value();
        } else if (!status.ok()) {
          return status;
        }

        status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          recv_start_time = opts.start_times[dst->id()].value();
        } else if (!status.ok()) {
          return status;
        }
      }

      // Check whether there is already a send/recv pair transferring
      // the same tensor/control from the src to dst partition.
      const bool on_host = IsDstInputOnHost(edge, g_info);
      DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
      auto iter = dup_recv.find(key);
      if (iter != dup_recv.end()) {
        // We found one. Reuse the data/control transferred already.
        const string& recv_node_name = iter->second.recv->name();
        if (edge->IsControlEdge()) {
          AddInput(dst_def, recv_node_name, Graph::kControlSlot);
        } else {
          AddInput(dst_def, recv_node_name, 0);
        }
        ref_control_inputs.push_back(recv_node_name);

        // We want the start_time for the recv to be the smallest of the start
        // times of its consumers. So we update this whenever we use a recv,
        // and write it out to the attribute at the end of the subroutine.
        if (iter->second.start_time > recv_start_time) {
          iter->second.start_time = recv_start_time;
        }
        continue;
      }

      // Decide which node output feeds the send.
      NodeDefBuilder::NodeOut send_from;
      if (edge->IsControlEdge()) {
        // Insert a dummy const node that will generate a tiny
        // data element to be sent from send to recv.
        VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
                << src->name() << "] -> " << dst->assigned_device_name() << "["
                << dst->name() << "]";
        NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
        if (!status.ok()) return status;
        // Set the start time for this dummy node.
        if (opts.scheduling_for_recvs) {
          AddNodeAttr("_start_time", send_start_time, dummy);
        }
        AddInput(dummy, src->name(), Graph::kControlSlot);
        send_from.Reset(dummy->name(), 0, DT_FLOAT);
      } else {
        send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
      }

      // Need to split edge by placing matching send/recv nodes on
      // the src/dst sides of the edge.
      NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
                              send_start_time, &status);
      if (!status.ok()) return status;

      NodeDef* real_recv = nullptr;
      NodeDef* recv =
          AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
      if (!status.ok()) return status;

      // Fix up the control flow edge.
      // NOTE(yuanbyu): 'real_recv' must be the real recv node.
      if (src_graph == dst_graph) {
        // For same device send/recv, add a control edge from send to recv.
        // This prevents the asynchronous recv kernel from being scheduled
        // before the data is available.
        AddInput(real_recv, send->name(), Graph::kControlSlot);
      } else if (control_flow_edge != nullptr) {
        // Redirect control edge to the real recv since this is not the same
        // device send/recv.
        --num_control_flow_edges;
        AddInput(real_recv, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }

      if (!edge->IsControlEdge() &&
          IsRefType(src->output_type(edge->src_output()))) {
        // Recvs for ref edges are never entered into 'dup_recv', so their
        // start time is written here rather than in the loop at the end.
        AddNodeAttr("_start_time", recv_start_time, recv);
        if (real_recv != recv) {
          AddNodeAttr("_start_time", recv_start_time, real_recv);
        }
        // If src is of ref type and the edge is not a control edge, dst has
        // read semantics and therefore we must control the recv.
        ref_recvs.push_back(real_recv);
      } else {
        // Memorize the send/recv pair, only if this is not a "ref" edge.
        // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
        // for now we don't do it.
        dup_recv[key] = {recv, real_recv, recv_start_time};
        ref_control_inputs.push_back(recv->name());
      }

      // Finally connect dst to the recv in place of the original edge.
      if (edge->IsControlEdge()) {
        ++num_control;
        AddInput(dst_def, recv->name(), Graph::kControlSlot);
      } else {
        ++num_data;
        AddInput(dst_def, recv->name(), 0);
      }
    }

    // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
    // NOTE(yuanbyu): Adding these control edges should not introduce
    // deadlocks. 'dst' has implicit "read" nodes that, when we split
    // across devices, are made explicit; Retargeting the dependencies
    // to 'dst' to those nodes would not introduce cycles if there isn't
    // one before the transformation.
    // NOTE(yuanbyu): This may impact performance because it defers the
    // execution of recvs until all the other inputs become available.
    AddReadControl(ref_recvs, ref_control_inputs);

    // Add back the control edges for control flow that are not used.
    if (control_flow_edge != nullptr) {
      for (int i = 0; i < num_control_flow_edges; ++i) {
        AddInput(dst_def, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }
    }
  }

  // Fall back to the function library of the graph when the options do not
  // provide one.
  const FunctionLibraryDefinition* flib_def = opts.flib_def;
  if (flib_def == nullptr) {
    flib_def = &g->flib_def();
  }

  // Set versions, function library and send/recv incarnation.
  for (auto& it : *partitions) {
    GraphDef* gdef = &it.second;
    *gdef->mutable_versions() = g->versions();
    // Prune unreachable functions from `flib_def` before adding them to `gdef`.
    *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();

    // Traverse the graph to fill every send/recv op's incarnation
    // information.
    SetIncarnation(opts, gdef);
  }

  // Set the start times for recvs at the very end.
  if (opts.scheduling_for_recvs) {
    for (auto& it : dup_recv) {
      AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
      if (it.second.real_recv != it.second.recv) {
        AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
      }
    }
  }

  VLOG(1) << "Added send/recv: controls=" << num_control
          << ", data=" << num_data;
  if (VLOG_IS_ON(2)) {
    // Dump every partition; the GraphDef address keeps the file names unique.
    for (auto& it : *partitions) {
      GraphDef* gdef = &it.second;
      DumpGraphDefToFile(strings::StrCat("partition_", it.first, "_",
                                         reinterpret_cast<uintptr_t>(gdef)),
                         *gdef);
    }
  }
  return Status::OK();
}
1255 
1256 }  // namespace tensorflow
1257