• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
22 #include "tensorflow/compiler/jit/encapsulate_util.h"
23 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
24 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/core/common_runtime/function.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/graph_to_functiondef.h"
29 #include "tensorflow/core/framework/node_def_builder.h"
30 #include "tensorflow/core/framework/node_def_util.h"
31 #include "tensorflow/core/framework/tensor_shape.pb.h"
32 #include "tensorflow/core/graph/algorithm.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/gtl/cleanup.h"
35 #include "tensorflow/core/platform/macros.h"
36 #include "tensorflow/core/util/dump_graph.h"
37 #include "tensorflow/stream_executor/lib/statusor.h"
38 
39 namespace tensorflow {
40 
41 namespace {
42 
43 // Control return mapping function for outside compilation host graphs.
44 // All nodes with kXlaHasHostTransfer attribute are control outputs.
HostGraphControlRetMapping(const Node * n)45 absl::optional<string> HostGraphControlRetMapping(const Node* n) {
46   if (HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
47     return n->name();
48   }
49   return absl::nullopt;
50 }
51 
52 // Add a key placeholder node to the graph. The key placeholder node will be
53 // used as input for XlaRecvAtHost/XlaSendFromHost nodes.
AddHostComputeKeyPlaceholder(const string & xla_cluster_name,Graph * g)54 StatusOr<Node*> AddHostComputeKeyPlaceholder(const string& xla_cluster_name,
55                                              Graph* g) {
56   NodeDef key_def;
57   NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"),
58                          "Placeholder");
59   builder.Attr("dtype", DT_STRING);
60   builder.Attr("shape", PartialTensorShape({2}));
61   builder.Attr("_host_compute_call_node", xla_cluster_name);
62   Status s = builder.Finalize(&key_def);
63   if (!s.ok()) return s;
64 
65   Node* n = g->AddNode(key_def, &s);
66   if (!s.ok()) return s;
67   return n;
68 }
69 
70 // Returns if the node is a XLA computation key placeholder.
IsKeyPlaceholderNode(const Node & n)71 bool IsKeyPlaceholderNode(const Node& n) {
72   return n.type_string() == "Placeholder" &&
73          absl::EndsWith(n.name(), "_key_placeholder");
74 }
75 
76 // Returns nodes with given type.
GatherNodesWithType(const Graph & g,const string & type)77 std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) {
78   std::vector<Node*> result;
79   for (Node* n : g.nodes()) {
80     if (n->type_string() == type) {
81       result.push_back(n);
82     }
83   }
84   return result;
85 }
86 
87 // Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`.
GetArgDataTypes(const std::vector<Node * > & arg_nodes,std::vector<DataType> * recv_at_host_dtypes)88 Status GetArgDataTypes(const std::vector<Node*>& arg_nodes,
89                        std::vector<DataType>* recv_at_host_dtypes) {
90   recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID);
91   for (auto* n : arg_nodes) {
92     int index;
93     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
94     DataType dtype;
95     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
96     (*recv_at_host_dtypes)[index] = dtype;
97   }
98   for (int i = 0, end = recv_at_host_dtypes->size(); i < end; i++) {
99     if ((*recv_at_host_dtypes)[i] == DT_INVALID) {
100       return errors::Internal("Cannot get datatype for input ", i);
101     }
102   }
103   return Status::OK();
104 }
105 
106 // Builds XlaRecvAtHost node.
BuildRecvAtHostNode(Graph * g,const string & oc_cluster_name,const std::vector<DataType> & recv_at_host_dtypes,Node * key_placeholder)107 StatusOr<Node*> BuildRecvAtHostNode(
108     Graph* g, const string& oc_cluster_name,
109     const std::vector<DataType>& recv_at_host_dtypes, Node* key_placeholder) {
110   NodeDefBuilder recv_at_host_builder(
111       absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"),
112       "_XlaRecvAtHost");
113   NodeDef recv_at_host_def;
114   recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes);
115   // The correct device_ordinal will be inserted during replication in a
116   // subsequent rewrite.
117   AttrValue device_ordinal_value;
118   device_ordinal_value.set_placeholder("_device_ordinal");
119   recv_at_host_builder.Attr("device_ordinal", device_ordinal_value);
120   recv_at_host_builder.Attr(
121       "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
122   recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true);
123   recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
124   TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def));
125   Status s;
126   Node* recv_at_host_node = g->AddNode(recv_at_host_def, &s);
127   TF_RETURN_IF_ERROR(s);
128   return recv_at_host_node;
129 }
130 
131 // Builds XlaRecvAtHost node, and replaces all _Arg nodes with it.
ReplaceArgNodesWithRecvAtHostNode(Graph * g,const string & oc_cluster_name,std::vector<DataType> * recv_at_host_dtypes,Node * key_placeholder)132 StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode(
133     Graph* g, const string& oc_cluster_name,
134     std::vector<DataType>* recv_at_host_dtypes, Node* key_placeholder) {
135   // TODO(b/77601805): use out nodes for source node, instead of traversing all
136   // nodes.
137   std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg");
138   TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes));
139   TF_ASSIGN_OR_RETURN(
140       Node * recv_at_host_node,
141       BuildRecvAtHostNode(g, oc_cluster_name, *recv_at_host_dtypes,
142                           key_placeholder));
143   for (auto* n : arg_nodes) {
144     int index;
145     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
146     // Record out edges and remove `n` before adding those edges to RecvAtHost.
147     // This is to avoid multiple producers.
148     std::vector<OutEdgeInfo> out_edge_info;
149     for (auto edge : n->out_edges()) {
150       out_edge_info.push_back(
151           {edge->dst(), edge->src_output(), edge->dst_input()});
152     }
153     g->RemoveNode(n);
154     for (const OutEdgeInfo& edge : out_edge_info) {
155       if (edge.dst_input == Graph::kControlSlot) {
156         g->AddControlEdge(recv_at_host_node, edge.dst);
157       } else {
158         g->AddEdge(recv_at_host_node, index, edge.dst, edge.dst_input);
159       }
160     }
161 
162     // Rewrite dst nodes because their input changed.
163     for (int i = 0, end = out_edge_info.size(); i < end; i++) {
164       const OutEdgeInfo edge = out_edge_info[i];
165       if (edge.dst_input == Graph::kControlSlot) {
166         continue;
167       }
168 
169       Node* dst = edge.dst;
170       NodeDef new_def = dst->def();
171       *new_def.mutable_input(edge.dst_input) =
172           absl::StrCat(recv_at_host_node->name(), ":", index);
173       TF_ASSIGN_OR_RETURN(Node * dst_replace, ReplaceNode(g, dst, new_def));
174 
175       // Other edges might have `dst` as dst node as well. Update those edges
176       // with `dst_replace`.
177       for (int j = i + 1, end = out_edge_info.size(); j < end; j++) {
178         if (out_edge_info[j].dst == dst) {
179           out_edge_info[j].dst = dst_replace;
180         }
181       }
182     }
183   }
184   g->AddEdge(key_placeholder, 0, recv_at_host_node, 0);
185   return recv_at_host_node;
186 }
187 
188 // Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`.
GetRetDataTypes(const std::vector<Node * > & ret_nodes,std::vector<DataType> * send_from_host_dtypes)189 Status GetRetDataTypes(const std::vector<Node*>& ret_nodes,
190                        std::vector<DataType>* send_from_host_dtypes) {
191   send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID);
192   for (auto* n : ret_nodes) {
193     int index;
194     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
195     DataType dtype;
196     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
197     (*send_from_host_dtypes)[index] = dtype;
198   }
199   for (int i = 0, end = send_from_host_dtypes->size(); i < end; i++) {
200     if ((*send_from_host_dtypes)[i] == DT_INVALID) {
201       return errors::Internal("Cannot get datatype for output ", i);
202     }
203   }
204   return Status::OK();
205 }
206 
207 // Builds XlaSendFromHost node.
BuildSendFromHostNode(Graph * g,const string & oc_cluster_name,const std::vector<Node * > & ret_nodes,const std::vector<DataType> & send_from_host_dtypes,Node * key_placeholder)208 StatusOr<Node*> BuildSendFromHostNode(
209     Graph* g, const string& oc_cluster_name,
210     const std::vector<Node*>& ret_nodes,
211     const std::vector<DataType>& send_from_host_dtypes, Node* key_placeholder) {
212   NodeDefBuilder send_from_host_builder(
213       absl::StrCat("outside_compilation_", oc_cluster_name, "_send"),
214       "_XlaSendFromHost");
215   NodeDef send_from_host_def;
216   send_from_host_builder.Attr("Tinputs", send_from_host_dtypes);
217   // The correct device_ordinal will be inserted during replication in a
218   // subsequent rewrite.
219   AttrValue device_ordinal_value;
220   device_ordinal_value.set_placeholder("_device_ordinal");
221   send_from_host_builder.Attr("device_ordinal", device_ordinal_value);
222   send_from_host_builder.Attr(
223       "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
224   send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true);
225   std::vector<NodeDefBuilder::NodeOut> inputs(send_from_host_dtypes.size());
226   for (auto* n : ret_nodes) {
227     int index;
228     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
229     const int num_dtypes = send_from_host_dtypes.size();
230     if (index < 0 || index >= num_dtypes) {
231       return errors::Internal("Invalid _Retval index: ", index);
232     }
233     for (auto edge : n->in_edges()) {
234       inputs[index] =
235           NodeDefBuilder::NodeOut{edge->src()->name(), edge->src_output(),
236                                   edge->src()->output_type(edge->src_output())};
237     }
238   }
239   send_from_host_builder.Input(inputs);
240   send_from_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
241   TF_RETURN_IF_ERROR(send_from_host_builder.Finalize(&send_from_host_def));
242   Status s;
243   Node* send_from_host_node = g->AddNode(send_from_host_def, &s);
244   TF_RETURN_IF_ERROR(s);
245   return send_from_host_node;
246 }
247 
248 // Builds XlaSendFromHost node, and replaces all _Retval nodes with it.
ReplaceRetNodesWithSendFromHostNode(Graph * g,const string & oc_cluster_name,std::vector<DataType> * send_from_host_dtypes,Node * key_placeholder)249 StatusOr<Node*> ReplaceRetNodesWithSendFromHostNode(
250     Graph* g, const string& oc_cluster_name,
251     std::vector<DataType>* send_from_host_dtypes, Node* key_placeholder) {
252   // TODO(b/77601805): use in nodes for sink node, instead of traversing all
253   // nodes.
254   std::vector<Node*> ret_nodes = GatherNodesWithType(*g, "_Retval");
255   TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes));
256   TF_ASSIGN_OR_RETURN(
257       Node * send_from_host_node,
258       BuildSendFromHostNode(g, oc_cluster_name, ret_nodes,
259                             *send_from_host_dtypes, key_placeholder));
260   for (auto* n : ret_nodes) {
261     int index;
262     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
263     for (auto edge : n->in_edges()) {
264       if (edge->src_output() == Graph::kControlSlot) {
265         g->AddControlEdge(edge->src(), send_from_host_node);
266       } else {
267         g->AddEdge(edge->src(), edge->src_output(), send_from_host_node, index);
268       }
269     }
270     g->RemoveNode(n);
271   }
272   g->AddEdge(key_placeholder, 0, send_from_host_node,
273              send_from_host_dtypes->size());
274   return send_from_host_node;
275 }
276 
277 // Returns input shapes (excluding key placeholder) for `send_from_host_node`
278 // if they are all fully defined; absl::nullopt otherwise.
GetInferredInputShapes(int num_inputs,Node * send_from_host_node)279 absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
280     int num_inputs, Node* send_from_host_node) {
281   std::vector<PartialTensorShape> results(num_inputs);
282   for (int i = 0; i < num_inputs; i++) {
283     const Edge* e;
284     if (!send_from_host_node->input_edge(i, &e).ok()) {
285       return absl::nullopt;
286     }
287 
288     std::vector<PartialTensorShape> shapes;
289     if (!GetNodeAttr(e->src()->attrs(), kXlaInferredShapesAttrName, &shapes)
290              .ok()) {
291       return absl::nullopt;
292     }
293 
294     const PartialTensorShape shape = shapes[e->src_output()];
295     if (!shape.IsFullyDefined()) {
296       return absl::nullopt;
297     }
298 
299     results[e->dst_input()] = shape;
300   }
301   return results;
302 }
303 
// Returns the node name used for the XlaHostCompute node of outside
// compilation cluster `original_oc_name`:
// "outside_compilation_<original_oc_name>_host_compute".
string host_compute_node_name(const string& original_oc_name) {
  static constexpr char kPrefix[] = "outside_compilation_";
  static constexpr char kSuffix[] = "_host_compute";
  return absl::StrCat(kPrefix, original_oc_name, kSuffix);
}
308 
309 // Builds XlaHostCompute NodeDef from the outside compilation call node.
BuildXlaHostComputeNodeDef(const Node * call_node,const std::map<string,int> & host_compute_core,const absl::flat_hash_map<string,std::vector<string>> & cluster_deps)310 StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
311     const Node* call_node, const std::map<string, int>& host_compute_core,
312     const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
313   string original_oc_name;
314   TF_RETURN_IF_ERROR(GetNodeAttr(
315       call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
316   NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name),
317                                       "XlaHostCompute");
318   // In XlaCompiler, if XlaHostCompute node is in a function call node and that
319   // function is inlined, name of the XlaHostCompute node will be changed. So
320   // we cannot rely on node name; use an attribute instead.
321   host_compute_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
322                             host_compute_builder.node_name());
323 
324   // Copy all attributes.
325   for (const auto& attr : call_node->attrs()) {
326     host_compute_builder.Attr(attr.first, attr.second);
327   }
328 
329   // Populate tpu_core assignment.
330   const auto iter = host_compute_core.find(original_oc_name);
331   if (iter != host_compute_core.end()) {
332     int core = iter->second;
333     host_compute_builder.Attr("tpu_core", core);
334   }
335 
336   // Set input tokens and other outside compilation clusters that current
337   // cluster depends in `kXlaTokenArgNodeName`. This is needed because when
338   // outside compilation subgraphs are encapsulated and moved to host graph,
339   // control/data edges between them will only be reflected in host graph.
340   // From XLA's perspective, two originally dependent clusters are no longer
341   // connected, which makes them look like they can be scheduled for execution
342   // in arbitrary order even though in fact they must be executed in order
343   // according to their host-side graph dependency. This can cause deadlock.
344   // Therefore, we hint XLA what the correct ordering of these clusters should
345   // be to avoid deadlocks.
346   std::vector<string> xla_token_input_nodes;
347   xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
348   auto cluster_deps_it = cluster_deps.find(original_oc_name);
349   if (cluster_deps_it != cluster_deps.end()) {
350     for (const auto& dep : cluster_deps_it->second) {
351       xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
352     }
353   }
354   host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes);
355 
356   // Populate inputs.
357   std::vector<DataType> input_dtypes;
358   TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes));
359   std::vector<NodeDefBuilder::NodeOut> inputs(input_dtypes.size());
360   for (auto e : call_node->in_edges()) {
361     if (e->IsControlEdge()) {
362       continue;
363     }
364 
365     const int input_dtypes_size = input_dtypes.size();
366     if (e->dst_input() < 0 || e->dst_input() >= input_dtypes_size) {
367       return errors::Internal("Invalid dst_input: ", e->dst_input());
368     }
369     inputs[e->dst_input()] = NodeDefBuilder::NodeOut{
370         e->src()->name(), e->src_output(), input_dtypes[e->dst_input()]};
371   }
372   host_compute_builder.Input(inputs);
373 
374   NodeDef new_def;
375   TF_RETURN_IF_ERROR(host_compute_builder.Finalize(&new_def));
376   return new_def;
377 }
378 
379 // Replace outside compilation function call node with XlaHostCompute node.
ReplaceOutsideCompilationCallNode(Graph * g,Node * call_node,const std::map<string,int> & host_compute_core,const absl::flat_hash_map<string,std::vector<string>> & cluster_deps)380 TF_ATTRIBUTE_NOINLINE StatusOr<Node*> ReplaceOutsideCompilationCallNode(
381     Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
382     const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
383   // Build XlaHostCompute NodeDef.
384   TF_ASSIGN_OR_RETURN(
385       NodeDef node_def,
386       BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps));
387   TF_ASSIGN_OR_RETURN(Node * host_compute_node,
388                       ReplaceNode(g, call_node, node_def));
389   VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();
390 
391   return host_compute_node;
392 }
393 
394 // Resets "_device_ordinal" attr to placeholder value for related nodes
395 // (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes
396 // containing XlaRecvAtHost/XlaSendFromHost).
ResetDeviceOrdinalToPlaceholderValue(Graph * g)397 Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) {
398   AttrValue device_ordinal_value;
399   device_ordinal_value.set_placeholder("_device_ordinal");
400   for (Node* n : g->nodes()) {
401     if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
402       continue;
403     }
404 
405     if (n->type_string() == "_XlaRecvAtHost" ||
406         n->type_string() == "_XlaSendFromHost") {
407       n->ClearAttr("device_ordinal");
408       n->AddAttr("device_ordinal", device_ordinal_value);
409     } else if (n->IsIfNode()) {
410       for (const string& attr_name :
411            std::vector<string>{"then_branch", "else_branch"}) {
412         NameAttrList branch_func;
413         TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
414         (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
415         n->ClearAttr(attr_name);
416         n->AddAttr(attr_name, branch_func);
417       }
418     } else if (n->IsWhileNode()) {
419       for (const string& attr_name : std::vector<string>{"cond", "body"}) {
420         NameAttrList branch_func;
421         TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
422         (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
423         n->ClearAttr(attr_name);
424         n->AddAttr(attr_name, branch_func);
425       }
426     } else if (HasNodeAttr(n->def(), "_device_ordinal")) {
427       // Function call node containing outside compilation.
428       n->ClearAttr("_device_ordinal");
429       n->AddAttr("_device_ordinal", device_ordinal_value);
430     } else {
431       return errors::Internal("Unknown node marked with ",
432                               kXlaHasHostTransferAttrName, ": ",
433                               n->DebugString());
434     }
435   }
436   return Status::OK();
437 }
438 
439 // Cheap check to tell whether FunctionDef contains a lifted argument.
HasLiftedArgs(const FunctionDef & function_def)440 bool HasLiftedArgs(const FunctionDef& function_def) {
441   return absl::c_any_of(function_def.node_def(), [](const NodeDef& node_def) {
442     return (node_def.op() == "Placeholder" &&
443             node_def.attr().find(kXlaLiftedArgOutsideCompilationAttrName) !=
444                 node_def.attr().end());
445   });
446 }
447 
448 // Find lifted arguments in a function body and their corresponding outside
449 // compilation nodes.
450 StatusOr<std::vector<std::pair<Node*, Node*>>>
LiftedArgsAndOutsideCompilationNodesInFunctionBody(const FunctionBody & function_body,const std::unordered_map<string,Node * > & outside_compilation_attr_to_node)451 LiftedArgsAndOutsideCompilationNodesInFunctionBody(
452     const FunctionBody& function_body,
453     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node) {
454   std::vector<std::pair<Node*, Node*>>
455       lifted_arg_nodes_and_outside_compilation_nodes;
456   for (Node* n : function_body.graph->op_nodes()) {
457     string oc_cluster;
458     if (n->type_string() == "Placeholder" &&
459         GetNodeAttr(n->def(), kXlaLiftedArgOutsideCompilationAttrName,
460                     &oc_cluster)
461             .ok()) {
462       TF_RET_CHECK(outside_compilation_attr_to_node.find(oc_cluster) !=
463                    outside_compilation_attr_to_node.end());
464       lifted_arg_nodes_and_outside_compilation_nodes.emplace_back(
465           n, outside_compilation_attr_to_node.at(oc_cluster));
466     }
467   }
468   return lifted_arg_nodes_and_outside_compilation_nodes;
469 }
470 
471 // Append lifted args' types to functional control flow node's `type_attr_name`
472 // attribute.
UpdateTypesAttribute(const std::vector<std::pair<Node *,Node * >> & lifted_arg_nodes_and_outside_compilation_nodes,const string & type_attr_name,Node * n)473 StatusOr<std::vector<DataType>> UpdateTypesAttribute(
474     const std::vector<std::pair<Node*, Node*>>&
475         lifted_arg_nodes_and_outside_compilation_nodes,
476     const string& type_attr_name, Node* n) {
477   std::vector<DataType> data_types;
478   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), type_attr_name, &data_types));
479   for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
480     Node* outside_compilation_node = pair.second;
481     DataType data_type;
482     TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
483                  outside_compilation_node->type_string() == "Placeholder");
484     if (outside_compilation_node->IsIdentity()) {
485       TF_RETURN_IF_ERROR(
486           GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
487     } else {
488       TF_RETURN_IF_ERROR(
489           GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
490     }
491     data_types.push_back(data_type);
492   }
493   n->ClearAttr(type_attr_name);
494   n->AddAttr(type_attr_name, data_types);
495 
496   return data_types;
497 }
498 
499 // Add edges from lifted outside compilation argument nodes to `n` in Graph `g`.
AddEdgesFromOutsideCompilationNodes(const int original_arg_count,const int arg_to_input_edge_offset,const std::vector<DataType> & data_types,const std::vector<Node * > & outside_compilation_nodes,Graph * g,Node * n)500 void AddEdgesFromOutsideCompilationNodes(
501     const int original_arg_count, const int arg_to_input_edge_offset,
502     const std::vector<DataType>& data_types,
503     const std::vector<Node*>& outside_compilation_nodes, Graph* g, Node* n) {
504   // Add edges from outside compilation nodes to While node.
505   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
506     Node* outside_compilation_node =
507         outside_compilation_nodes[i - original_arg_count];
508     g->AddEdge(outside_compilation_node, 0, n, i + arg_to_input_edge_offset);
509   }
510 }
511 
512 // Construct _Arg that maps to lifted outside compilation argument node input.
AddOutsideCompilationInputArgToFunctionBody(const FunctionBody & function_body,const int arg_idx,const DataType & data_type)513 StatusOr<Node*> AddOutsideCompilationInputArgToFunctionBody(
514     const FunctionBody& function_body, const int arg_idx,
515     const DataType& data_type) {
516   NodeDefBuilder arg_builder(absl::StrCat("arg_", arg_idx), "_Arg");
517   arg_builder.Attr("T", data_type);
518   arg_builder.Attr("index", arg_idx);
519   NodeDef arg_def;
520   TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
521 
522   Status s;
523   Node* arg_node = function_body.graph->AddNode(arg_def, &s);
524   TF_RETURN_IF_ERROR(s);
525   return arg_node;
526 }
527 
528 // Add _Retval node that matches newly added `arg_node` and connect `arg_node`
529 // to it.
AddMatchingRetvalNode(const FunctionBody & function_body,const int arg_idx,const DataType & data_type,Node * arg_node)530 Status AddMatchingRetvalNode(const FunctionBody& function_body,
531                              const int arg_idx, const DataType& data_type,
532                              Node* arg_node) {
533   NodeDefBuilder ret_builder(absl::StrCat("ret_", arg_idx), "_Retval");
534   ret_builder.Attr("T", data_type);
535   ret_builder.Attr("index", arg_idx);
536   ret_builder.Input(arg_node->name(), 0, data_type);
537   NodeDef ret_def;
538   TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
539   Status s;
540   Node* ret_node = function_body.graph->AddNode(ret_def, &s);
541   TF_RETURN_IF_ERROR(s);
542   function_body.graph->AddEdge(arg_node, 0, ret_node, 0);
543 
544   return Status::OK();
545 }
546 
ReplaceLiftedArgNodePlaceholderWithArg(const FunctionBody & function_body,const int original_arg_count,const int arg_idx,const std::vector<Node * > & lifted_arg_nodes,Node * arg_node)547 void ReplaceLiftedArgNodePlaceholderWithArg(
548     const FunctionBody& function_body, const int original_arg_count,
549     const int arg_idx, const std::vector<Node*>& lifted_arg_nodes,
550     Node* arg_node) {
551   Node* lifted_arg_node = lifted_arg_nodes[arg_idx - original_arg_count];
552   // This might happen because lifted_arg_node only exists in one branch of an
553   // If node, and we are handling the other branch.
554   if (!lifted_arg_node) {
555     return;
556   }
557 
558   for (const Edge* e : lifted_arg_node->out_edges()) {
559     if (e->IsControlEdge()) {
560       function_body.graph->AddControlEdge(arg_node, e->dst());
561     } else {
562       function_body.graph->AddEdge(arg_node, 0, e->dst(), e->dst_input());
563     }
564   }
565   function_body.graph->RemoveNode(lifted_arg_node);
566 }
567 
568 // Adds function def to function definition library and update the function
569 // callsite operation `callsite_node` to invoke new function instead.
AddFunctionWithNewName(const std::string & new_name,const std::string & func_attr_name,const FunctionDef & function_def,NameAttrList * func_attr,Node * callsite_node,FunctionLibraryDefinition * fld)570 Status AddFunctionWithNewName(const std::string& new_name,
571                               const std::string& func_attr_name,
572                               const FunctionDef& function_def,
573                               NameAttrList* func_attr, Node* callsite_node,
574                               FunctionLibraryDefinition* fld) {
575   TF_RETURN_IF_ERROR(fld->AddFunctionDef(function_def));
576   func_attr->set_name(new_name);
577   callsite_node->ClearAttr(func_attr_name);
578   callsite_node->AddAttr(func_attr_name, *func_attr);
579   return Status::OK();
580 }
581 
582 // Reconnect outside compilation lifted arguments in a functional While node to
583 // its outside compilation tensor sources.
PostprocessLiftedArgsForWhile(const std::unordered_map<string,Node * > & outside_compilation_attr_to_node,Graph * g,Node * n,FunctionLibraryDefinition * fld)584 Status PostprocessLiftedArgsForWhile(
585     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
586     Graph* g, Node* n, FunctionLibraryDefinition* fld) {
587   TF_RET_CHECK(n->IsWhileNode());
588 
589   // Check if there is any lifted args in body function.
590   NameAttrList body_func;
591   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "body", &body_func));
592   const FunctionDef* body_function_def = fld->Find(body_func.name());
593   TF_RET_CHECK(body_function_def);
594 
595   if (!HasLiftedArgs(*body_function_def)) {
596     return Status::OK();
597   }
598 
599   // Gather all lifted args.
600   std::unique_ptr<FunctionBody> body_function_body;
601   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_function_def,
602                                              AttrSlice(&body_func.attr()), fld,
603                                              &body_function_body));
604 
605   int original_arg_count = body_function_body->arg_nodes.size();
606 
607   TF_ASSIGN_OR_RETURN(
608       auto lifted_arg_nodes_and_outside_compilation_nodes,
609       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
610           *body_function_body, outside_compilation_attr_to_node));
611 
612   // Append lifted args' types to While node's T attribute.
613   TF_ASSIGN_OR_RETURN(
614       std::vector<DataType> data_types,
615       UpdateTypesAttribute(lifted_arg_nodes_and_outside_compilation_nodes, "T",
616                            n));
617 
618   // Add edges from outside compilation nodes to While node.
619   std::vector<Node*> outside_compilation_nodes;
620   std::transform(
621       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
622       lifted_arg_nodes_and_outside_compilation_nodes.end(),
623       std::back_inserter(outside_compilation_nodes),
624       [](const std::pair<Node*, Node*>& pair) { return pair.second; });
625   AddEdgesFromOutsideCompilationNodes(original_arg_count,
626                                       /*arg_to_input_edge_offset=*/0,
627                                       data_types, outside_compilation_nodes, g,
628                                       n);
629 
630   // In body_graph, create new _Arg/_Retval nodes, and replace lifted arg
631   // nodes with the new _Arg nodes.
632   std::vector<Node*> lifted_arg_nodes;
633   std::transform(
634       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
635       lifted_arg_nodes_and_outside_compilation_nodes.end(),
636       std::back_inserter(lifted_arg_nodes),
637       [](const std::pair<Node*, Node*>& pair) { return pair.first; });
638   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
639     TF_ASSIGN_OR_RETURN(Node * arg_node,
640                         AddOutsideCompilationInputArgToFunctionBody(
641                             *body_function_body, i, data_types[i]));
642 
643     TF_RETURN_IF_ERROR(
644         AddMatchingRetvalNode(*body_function_body, i, data_types[i], arg_node));
645 
646     ReplaceLiftedArgNodePlaceholderWithArg(
647         *body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node);
648   }
649 
650   const auto new_body_function_name =
651       fld->UniqueFunctionName(absl::StrCat(body_func.name(), "_lifted_arg_"));
652   FunctionDef rewritten_body_function_def;
653   TF_RETURN_IF_ERROR(GraphToFunctionDef(
654       *body_function_body->graph, new_body_function_name,
655       HostGraphControlRetMapping, &rewritten_body_function_def));
656   TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_body_function_name, "body",
657                                             rewritten_body_function_def,
658                                             &body_func, n, fld));
659 
660   // In cond_graph, just add new _Arg nodes.
661   NameAttrList cond_func;
662   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "cond", &cond_func));
663   const FunctionDef* cond_function_def = fld->Find(cond_func.name());
664   TF_RET_CHECK(cond_function_def);
665   std::unique_ptr<FunctionBody> cond_function_body;
666   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_function_def,
667                                              AttrSlice(&cond_func.attr()), fld,
668                                              &cond_function_body));
669 
670   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
671     StatusOr<Node*> arg_node_or = AddOutsideCompilationInputArgToFunctionBody(
672         *cond_function_body, i, data_types[i]);
673     TF_RETURN_IF_ERROR(arg_node_or.status());
674   }
675 
676   const auto new_cond_function_name =
677       fld->UniqueFunctionName(absl::StrCat(cond_func.name(), "_lifted_arg_"));
678   FunctionDef rewritten_cond_function_def;
679   TF_RETURN_IF_ERROR(GraphToFunctionDef(
680       *cond_function_body->graph, new_cond_function_name,
681       HostGraphControlRetMapping, &rewritten_cond_function_def));
682   TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_cond_function_name, "cond",
683                                             rewritten_cond_function_def,
684                                             &cond_func, n, fld));
685   return Status::OK();
686 }
687 
PostprocessLiftedArgsForIf(const std::unordered_map<string,Node * > & outside_compilation_attr_to_node,Graph * g,Node * n,FunctionLibraryDefinition * fld)688 Status PostprocessLiftedArgsForIf(
689     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
690     Graph* g, Node* n, FunctionLibraryDefinition* fld) {
691   TF_RET_CHECK(n->IsIfNode());
692 
693   NameAttrList then_branch_func;
694   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "then_branch", &then_branch_func));
695   const FunctionDef* then_branch_function_def =
696       fld->Find(then_branch_func.name());
697   TF_RET_CHECK(then_branch_function_def);
698 
699   NameAttrList else_branch_func;
700   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "else_branch", &else_branch_func));
701   const FunctionDef* else_branch_function_def =
702       fld->Find(else_branch_func.name());
703   TF_RET_CHECK(else_branch_function_def);
704 
705   // Nothing to do if neither branch contains any lifted arguments.
706   if (!HasLiftedArgs(*then_branch_function_def) &&
707       !HasLiftedArgs(*else_branch_function_def)) {
708     return Status::OK();
709   }
710 
711   std::unique_ptr<FunctionBody> then_branch_function_body;
712   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
713       *then_branch_function_def, AttrSlice(&then_branch_func.attr()), fld,
714       &then_branch_function_body));
715 
716   std::unique_ptr<FunctionBody> else_branch_function_body;
717   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
718       *else_branch_function_def, AttrSlice(&else_branch_func.attr()), fld,
719       &else_branch_function_body));
720 
721   // Then and else branches have same argument count and argument data types.
722   int original_arg_count = then_branch_function_body->arg_nodes.size();
723 
724   TF_ASSIGN_OR_RETURN(
725       auto then_branch_lifted_arg_nodes_and_outside_compilation_nodes,
726       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
727           *then_branch_function_body, outside_compilation_attr_to_node));
728 
729   TF_ASSIGN_OR_RETURN(
730       auto else_branch_lifted_arg_nodes_and_outside_compilation_nodes,
731       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
732           *else_branch_function_body, outside_compilation_attr_to_node));
733 
734   // Merge lifted args from then and else branches.
735   std::vector<Node*> outside_compilation_nodes;
736   std::vector<Node*> then_branch_lifted_arg_nodes;
737   for (const auto& pair :
738        then_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
739     outside_compilation_nodes.push_back(pair.second);
740     then_branch_lifted_arg_nodes.push_back(pair.first);
741   }
742   for (const auto& pair :
743        else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
744     if (std::find(outside_compilation_nodes.begin(),
745                   outside_compilation_nodes.end(),
746                   pair.second) == outside_compilation_nodes.end()) {
747       outside_compilation_nodes.push_back(pair.second);
748       // Then branch does not contain this lifted arg. Add an empty item to
749       // then_branch_lifted_arg_nodes.
750       then_branch_lifted_arg_nodes.push_back(nullptr);
751     }
752   }
753   // Reorder else_branch_lifted_arg_nodes_and_outside_compilation_nodes.
754   std::vector<Node*> else_branch_lifted_arg_nodes(
755       outside_compilation_nodes.size());
756   for (const auto& pair :
757        else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
758     auto iter = std::find(outside_compilation_nodes.begin(),
759                           outside_compilation_nodes.end(), pair.second);
760     TF_RET_CHECK(iter != outside_compilation_nodes.end());
761     int index = iter - outside_compilation_nodes.begin();
762     else_branch_lifted_arg_nodes[index] = pair.first;
763   }
764 
765   // Append lifted args' types to If node's Tin attribute.
766   std::vector<DataType> data_types;
767   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &data_types));
768   for (Node* n : outside_compilation_nodes) {
769     data_types.push_back(n->output_type(0));
770   }
771   n->ClearAttr("Tin");
772   n->AddAttr("Tin", data_types);
773 
774   // Add edges from outside compilation nodes to If node. If node's input #0
775   // is predicate input, input #1 maps to _Arg #0 of branch functions, thus
776   // arg_to_input_edge_offset is set to 1.
777   AddEdgesFromOutsideCompilationNodes(original_arg_count,
778                                       /*arg_to_input_edge_offset=*/1,
779                                       data_types, outside_compilation_nodes, g,
780                                       n);
781 
782   for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
783     TF_ASSIGN_OR_RETURN(Node * then_branch_arg_node,
784                         AddOutsideCompilationInputArgToFunctionBody(
785                             *then_branch_function_body, i, data_types[i]));
786 
787     ReplaceLiftedArgNodePlaceholderWithArg(
788         *then_branch_function_body, original_arg_count, i,
789         then_branch_lifted_arg_nodes, then_branch_arg_node);
790 
791     TF_ASSIGN_OR_RETURN(Node * else_branch_arg_node,
792                         AddOutsideCompilationInputArgToFunctionBody(
793                             *else_branch_function_body, i, data_types[i]));
794 
795     ReplaceLiftedArgNodePlaceholderWithArg(
796         *else_branch_function_body, original_arg_count, i,
797         else_branch_lifted_arg_nodes, else_branch_arg_node);
798   }
799 
800   const auto new_then_function_name = fld->UniqueFunctionName(
801       absl::StrCat(then_branch_func.name(), "_lifted_arg_"));
802   FunctionDef rewritten_then_branch_function_def;
803   TF_RETURN_IF_ERROR(GraphToFunctionDef(
804       *then_branch_function_body->graph, new_then_function_name,
805       HostGraphControlRetMapping, &rewritten_then_branch_function_def));
806   TF_RETURN_IF_ERROR(AddFunctionWithNewName(
807       new_then_function_name, "then_branch", rewritten_then_branch_function_def,
808       &then_branch_func, n, fld));
809 
810   const auto new_else_function_name = fld->UniqueFunctionName(
811       absl::StrCat(else_branch_func.name(), "_lifted_arg_"));
812   FunctionDef rewritten_else_branch_function_def;
813   TF_RETURN_IF_ERROR(GraphToFunctionDef(
814       *else_branch_function_body->graph, new_else_function_name,
815       HostGraphControlRetMapping, &rewritten_else_branch_function_def));
816   TF_RETURN_IF_ERROR(AddFunctionWithNewName(
817       new_else_function_name, "else_branch", rewritten_else_branch_function_def,
818       &else_branch_func, n, fld));
819   return Status::OK();
820 }
821 
PostprocessLiftedArgsForCall(const std::unordered_map<string,Node * > & outside_compilation_attr_to_node,Graph * g,Node * n,FunctionLibraryDefinition * fld)822 Status PostprocessLiftedArgsForCall(
823     const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
824     Graph* g, Node* n, FunctionLibraryDefinition* fld) {
825   const FunctionDef* fdef = fld->Find(n->type_string());
826   TF_RET_CHECK(fdef);
827 
828   // Nothing to do if the function does not contain any lifted arguments.
829   if (!HasLiftedArgs(*fdef)) {
830     return Status::OK();
831   }
832 
833   std::unique_ptr<FunctionBody> fbody;
834   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, n->attrs(), fld, &fbody));
835 
836   int original_arg_count = fbody->arg_nodes.size();
837 
838   TF_ASSIGN_OR_RETURN(auto lifted_arg_nodes_and_outside_compilation_nodes,
839                       LiftedArgsAndOutsideCompilationNodesInFunctionBody(
840                           *fbody, outside_compilation_attr_to_node));
841 
842   // Append lifted args' types to call node's input data types.
843   std::vector<DataType> data_types(n->input_types().begin(),
844                                    n->input_types().end());
845   for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
846     Node* outside_compilation_node = pair.second;
847     DataType data_type;
848     TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
849                  outside_compilation_node->type_string() == "Placeholder");
850     if (outside_compilation_node->IsIdentity()) {
851       TF_RETURN_IF_ERROR(
852           GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
853     } else {
854       TF_RETURN_IF_ERROR(
855           GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
856     }
857     data_types.push_back(data_type);
858   }
859 
860   std::vector<Node*> lifted_arg_nodes;
861   std::transform(
862       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
863       lifted_arg_nodes_and_outside_compilation_nodes.end(),
864       std::back_inserter(lifted_arg_nodes),
865       [](const std::pair<Node*, Node*>& pair) { return pair.first; });
866   for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
867     TF_ASSIGN_OR_RETURN(
868         Node * arg_node,
869         AddOutsideCompilationInputArgToFunctionBody(*fbody, i, data_types[i]));
870 
871     ReplaceLiftedArgNodePlaceholderWithArg(*fbody, original_arg_count, i,
872                                            lifted_arg_nodes, arg_node);
873   }
874 
875   FunctionDef rewritten_fdef;
876   TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(),
877                                         HostGraphControlRetMapping,
878                                         &rewritten_fdef));
879   const auto new_function_name =
880       fld->UniqueFunctionName(absl::StrCat(n->type_string(), "_lifted_arg_"));
881   rewritten_fdef.mutable_signature()->set_name(new_function_name);
882   TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
883 
884   // We need to recreate the node. Otherwise TF will not know n->num_inputs()
885   // has increased.
886   NodeDef node_def = n->def();
887 
888   // Function name is represented via the Op's type. Reset the op type to new
889   // function def name;
890   *node_def.mutable_op() = new_function_name;
891 
892   for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
893     Node* outside_compilation_node =
894         lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
895             .second;
896     node_def.add_input(absl::StrCat(outside_compilation_node->name(), ":", 0));
897   }
898   TF_ASSIGN_OR_RETURN(n, ReplaceNode(g, n, node_def));
899 
900   // Add edges from outside compilation nodes to call node.
901   std::vector<Node*> outside_compilation_nodes;
902   std::transform(
903       lifted_arg_nodes_and_outside_compilation_nodes.begin(),
904       lifted_arg_nodes_and_outside_compilation_nodes.end(),
905       std::back_inserter(outside_compilation_nodes),
906       [](const std::pair<Node*, Node*>& pair) { return pair.second; });
907   AddEdgesFromOutsideCompilationNodes(original_arg_count,
908                                       /*arg_to_input_edge_offset=*/0,
909                                       data_types, outside_compilation_nodes, g,
910                                       n);
911 
912   return Status::OK();
913 }
914 
915 // Creates a mapping from outside compilation cluster name to lifted argument
916 // placeholder.
OutsideCompilationAttrToNode(const Graph & g)917 StatusOr<std::unordered_map<string, Node*>> OutsideCompilationAttrToNode(
918     const Graph& g) {
919   std::unordered_map<string, Node*> outside_compilation_attr_to_node;
920   for (Node* n : g.op_nodes()) {
921     bool is_lifted_arg;
922     string outside_compilation_attr;
923     if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) &&
924         TryGetNodeAttr(n->def(), "_xla_outside_compilation",
925                        &outside_compilation_attr)) {
926       TF_RET_CHECK(is_lifted_arg);
927       TF_RET_CHECK(n->IsIdentity() || n->type_string() == "Placeholder");
928       outside_compilation_attr_to_node[outside_compilation_attr] = n;
929     }
930   }
931 
932   return outside_compilation_attr_to_node;
933 }
934 
PostprocessLiftedArgs(Graph * g,FunctionLibraryDefinition * fld)935 Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) {
936   TF_ASSIGN_OR_RETURN(auto outside_compilation_attr_to_node,
937                       OutsideCompilationAttrToNode(*g));
938 
939   std::vector<Node*> call_nodes;
940   for (Node* n : g->op_nodes()) {
941     if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
942       continue;
943     }
944 
945     if (n->IsWhileNode()) {
946       TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile(
947           outside_compilation_attr_to_node, g, n, fld));
948     }
949 
950     if (n->IsIfNode()) {
951       TF_RETURN_IF_ERROR(PostprocessLiftedArgsForIf(
952           outside_compilation_attr_to_node, g, n, fld));
953     }
954 
955     // Outside compilation host side function call will always be direct
956     // function call nodes.
957     // Function call nodes need to be handled separately because we rewrite
958     // nodes in `PostprocessLiftedArgsForCall`.
959     if (fld->Contains(n->type_string())) {
960       call_nodes.push_back(n);
961     }
962   }
963 
964   for (Node* n : call_nodes) {
965     TF_RETURN_IF_ERROR(PostprocessLiftedArgsForCall(
966         outside_compilation_attr_to_node, g, n, fld));
967   }
968 
969   return Status::OK();
970 }
971 
972 // For an XLA computation, builds host side graph given all outside compilation
973 // graphs inside it. The host side graph contains:
974 // 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and
975 //    XlaSendFromHost to this sequencer node, so all outside compilation nodes
976 //    will be executed *before* this sequencer).
977 // 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will
978 //    replace this node with compilation result node.
979 // 3) all outside compilation graphs.
ConstructHostGraph(const string & xla_cluster_name,const string & outside_compilation_attr_name,const std::vector<string> & outside_compilation_host_graphs,FunctionLibraryDefinition * fld,std::unique_ptr<Graph> * host_graph)980 Status ConstructHostGraph(
981     const string& xla_cluster_name, const string& outside_compilation_attr_name,
982     const std::vector<string>& outside_compilation_host_graphs,
983     FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph) {
984   host_graph->reset(new Graph(fld));
985 
986   // Create sequencer node in host graph.
987   NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"),
988                                    "NoOp");
989   sequencer_builder.Attr("_xla_host_transfer_sequencer", xla_cluster_name);
990   NodeDef sequencer_def;
991   TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def));
992   Status s;
993   Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s);
994   TF_RETURN_IF_ERROR(s);
995 
996   // Create key placeholder in host graph.
997   TF_ASSIGN_OR_RETURN(
998       Node * key_placeholder,
999       AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get()));
1000 
1001   // For each outside compilation graph, copy them to host graph with the
1002   // following changes:
1003   // a) Use key_placeholder in host graph instead of its own.
1004   // b) Add control edge from host transfer nodes (XlaRecvAtHost,
1005   //    XlaSendFromHost, If/While nodes containing
1006   //    XlaRecvAtHost/XlaSendFromHost) to sequencer node.
1007   // c) Clear node_def.device(), so device placer won't get confused.
1008   for (const string& host_func : outside_compilation_host_graphs) {
1009     VLOG(4) << "Expanding host graph " << host_func;
1010     // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder
1011     // value after we expanded all host graphs. We cannot just use placeholder
1012     // value here because FunctionDef instantiation does not allow placeholder
1013     // value for attributes.
1014     AttrValue device_ordinal_attr;
1015     device_ordinal_attr.set_i(0);
1016     protobuf::Map<string, AttrValue> attrs;
1017     attrs["_device_ordinal"] = device_ordinal_attr;
1018     std::unique_ptr<FunctionBody> host_fbody;
1019     const FunctionDef* host_fdef = fld->Find(host_func);
1020     TF_RET_CHECK(host_fdef);
1021     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_fdef, AttrSlice(&attrs),
1022                                                fld, &host_fbody));
1023 
1024     // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
1025     // reachable from sink node so all nodes will be copied.
1026     // TODO(b/77601805): consolidate copy graph functions.
1027     FixupSourceAndSinkEdges(host_fbody->graph);
1028 
1029     std::map<const Node*, Node*> node_map;
1030     node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node();
1031     node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node();
1032     Status s;
1033     ReverseDFS(
1034         *host_fbody->graph, /*enter=*/nullptr,
1035         [&](const Node* n) {
1036           if (!s.ok()) {
1037             return;
1038           }
1039 
1040           Node* copy;
1041           if (node_map.find(n) != node_map.end()) {
1042             // Already copied this node.
1043             copy = node_map.at(n);
1044           } else if (IsKeyPlaceholderNode(*n)) {
1045             // Change a).
1046             copy = key_placeholder;
1047             node_map[n] = copy;
1048           } else {
1049             // Copy the node.
1050             NodeDef copy_def = n->def();
1051             // Change c).
1052             copy_def.clear_device();
1053             copy = (*host_graph)->AddNode(copy_def, &s);
1054             if (!s.ok()) {
1055               return;
1056             }
1057             node_map[n] = copy;
1058           }
1059 
1060           // Only handle input edges. Output edges will be added later as
1061           // its output nodes' input edges.
1062           for (auto e : n->in_edges()) {
1063             if (node_map.find(e->src()) == node_map.end()) {
1064               s = errors::Internal("Cannot find node image for ",
1065                                    e->src()->DebugString());
1066               return;
1067             }
1068             (*host_graph)
1069                 ->AddEdge(node_map[e->src()], e->src_output(), copy,
1070                           e->dst_input());
1071           }
1072 
1073           // Change b).
1074           if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) {
1075             (*host_graph)->AddControlEdge(copy, sequencer);
1076           }
1077         },
1078         NodeComparatorID());
1079 
1080     if (!s.ok()) {
1081       return s;
1082     }
1083   }
1084   // Reset "_device_ordinal" to placeholder value.
1085   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(host_graph->get()));
1086 
1087   // sequencer and key_placeholder might be dead nodes. Prune them if necessary.
1088   // - sequencer should be pruned iff it has no input control edges from
1089   //   RecvAtHost/SendFromHost. If it has input control edge, we connect it to
1090   //   sink node so it won't be pruned.
1091   // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost.
1092   //   We don't need to do anything special.
1093   if (!sequencer->in_edges().empty()) {
1094     (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node());
1095   }
1096   PruneForReverseReachability(
1097       host_graph->get(),
1098       std::unordered_set<const Node*>{(*host_graph)->sink_node()});
1099 
1100   // Postprocess edges between different outside compilations.
1101   TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
1102       host_graph->get(), outside_compilation_attr_name));
1103 
1104   // Postprocess lifted arg nodes.
1105   TF_RETURN_IF_ERROR(PostprocessLiftedArgs(host_graph->get(), fld));
1106 
1107   if (VLOG_IS_ON(4)) {
1108     DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_",
1109                                  xla_cluster_name),
1110                     **host_graph, fld);
1111   }
1112 
1113   return Status::OK();
1114 }
1115 
1116 // Expand XLA computation's outside compilation host side graph into main graph.
1117 // Add a control edge between sequencer node and the XLA computation node.
ExpandHostGraphIntoMainGraph(Graph * main_graph,FunctionLibraryDefinition * fld,const string & host_graph_func_name,Node * xla_computation_node,Node * pivot_node)1118 Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
1119                                     FunctionLibraryDefinition* fld,
1120                                     const string& host_graph_func_name,
1121                                     Node* xla_computation_node,
1122                                     Node* pivot_node) {
1123   // Temporarily use "0" as "_device_ordinal". It will be rewritten with the
1124   // correct value in a later pass. We cannot just use placeholder value here
1125   // because FunctionDef instantiation does not allow placeholder value for
1126   // attributes.
1127   AttrValue device_ordinal_attr;
1128   device_ordinal_attr.set_i(0);
1129   protobuf::Map<string, AttrValue> attrs;
1130   attrs["_device_ordinal"] = device_ordinal_attr;
1131   std::unique_ptr<FunctionBody> fbody;
1132   const FunctionDef* host_graph_func = fld->Find(host_graph_func_name);
1133   TF_RET_CHECK(host_graph_func);
1134   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_graph_func,
1135                                              AttrSlice(&attrs), fld, &fbody));
1136   Graph* host_graph = fbody->graph;
1137 
1138   // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
1139   // reachable from sink node so all nodes will be copied.
1140   // TODO(b/77601805): consolidate copy graph functions.
1141   FixupSourceAndSinkEdges(host_graph);
1142 
1143   // Copy all nodes.
1144   std::map<const Node*, Node*> node_map;
1145   if (pivot_node) {
1146     node_map[host_graph->source_node()] = pivot_node;
1147   } else {
1148     node_map[host_graph->source_node()] = main_graph->source_node();
1149   }
1150   node_map[host_graph->sink_node()] = main_graph->sink_node();
1151   Status s = Status::OK();
1152   auto copy_node_fn = [&](const Node* n) {
1153     if (!s.ok()) {
1154       return;
1155     }
1156 
1157     Node* copy;
1158     if (node_map.find(n) != node_map.end()) {
1159       // Already copied this node.
1160       copy = node_map.at(n);
1161     } else {
1162       // Copy the node.
1163       NodeDef copy_def = n->def();
1164       copy = main_graph->AddNode(copy_def, &s);
1165       if (!s.ok()) {
1166         return;
1167       }
1168       node_map[n] = copy;
1169     }
1170 
1171     // Only handle input edges. Output edges will be added later as its output
1172     // nodes' input edges.
1173     for (auto e : n->in_edges()) {
1174       if (node_map.find(e->src()) == node_map.end()) {
1175         s = errors::Internal("Cannot find node image for ",
1176                              e->src()->DebugString());
1177         return;
1178       }
1179       main_graph->AddEdge(node_map[e->src()], e->src_output(), copy,
1180                           e->dst_input());
1181     }
1182 
1183     // Add control edge from sequencer to XLA computation node.
1184     if (copy->type_string() == "NoOp" &&
1185         HasNodeAttr(copy->def(), "_xla_host_transfer_sequencer")) {
1186       main_graph->AddControlEdge(copy, xla_computation_node);
1187     }
1188   };
1189   ReverseDFS(*host_graph, /*enter=*/nullptr, copy_node_fn, NodeComparatorID());
1190   return s;
1191 }
1192 
1193 // Rewrites shape inference graph for outside compilation:
1194 // 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from
1195 //    `host_graph`. Because we might still have outside compilation to outside
1196 //    compilation placeholder nodes in shape inference graph, which will prevent
1197 //    us from inferring XlaSendFromHost shape. But in `host_graph`, we already
1198 //    removed those placeholder nodes.
1199 // 2) Remove control edges.
1200 // 3) Prune nodes that are not useful for shape inference.
RewriteShapeInferenceGraph(const string & shape_inference_graph_name,Graph * host_graph,Node * pivot_node,FunctionLibraryDefinition * fld)1201 Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
1202                                   Graph* host_graph, Node* pivot_node,
1203                                   FunctionLibraryDefinition* fld) {
1204   // Use "0" as "_device_ordinal". It does not matter for shape inference.
1205   AttrValue device_ordinal_attr;
1206   device_ordinal_attr.set_i(0);
1207   protobuf::Map<string, AttrValue> attrs;
1208   attrs["_device_ordinal"] = device_ordinal_attr;
1209   std::unique_ptr<FunctionBody> fbody;
1210   const FunctionDef* shape_inference_graph =
1211       fld->Find(shape_inference_graph_name);
1212   TF_RET_CHECK(shape_inference_graph);
1213   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*shape_inference_graph,
1214                                              AttrSlice(&attrs), fld, &fbody));
1215   Graph* g = fbody->graph;
1216 
1217   // Find SendFromHost node.
1218   Node* send_from_host = nullptr;
1219   for (Node* n : g->nodes()) {
1220     if (n->type_string() == "_XlaSendFromHost") {
1221       send_from_host = n;
1222       break;
1223     }
1224   }
1225   if (!send_from_host) {
1226     return errors::Internal("Shape inference graph ",
1227                             shape_inference_graph_name,
1228                             " does not have _XlaSendFromHost node.");
1229   }
1230 
1231   // See if the SendFromHost node exists in `host_graph`.
1232   Node* send_node_in_host_graph = nullptr;
1233   for (Node* n : host_graph->nodes()) {
1234     if (n->name() == send_from_host->name()) {
1235       send_node_in_host_graph = n;
1236       break;
1237     }
1238   }
1239   if (send_node_in_host_graph) {
1240     // This is an "top-level" outside compilation. Clear the graph, and copy
1241     // SendFromHost and all its predecessors from `host_graph`.
1242     std::vector<Node*> nodes;
1243     for (Node* n : g->op_nodes()) {
1244       nodes.push_back(n);
1245     }
1246     for (Node* n : nodes) {
1247       g->RemoveNode(n);
1248     }
1249     Node* start_node = pivot_node ? pivot_node : host_graph->source_node();
1250     // Reverse DFS from send_from_host_main_graph, and stop at start_node.
1251     struct Visit {
1252       Node* n;
1253       bool is_exiting;
1254     };
1255     std::vector<Visit> stack{{send_node_in_host_graph, false}};
1256     std::map<Node*, Node*> node_map;
1257     node_map[host_graph->source_node()] = g->source_node();
1258     while (!stack.empty()) {
1259       Visit& curr = stack.back();
1260       if (curr.is_exiting) {
1261         if (node_map.find(curr.n) == node_map.end()) {
1262           Node* copy = g->CopyNode(curr.n);
1263           if (curr.n != start_node) {
1264             for (const Edge* e : curr.n->in_edges()) {
1265               auto node_iter = node_map.find(e->src());
1266               if (node_iter == node_map.end()) {
1267                 return errors::Internal("Cannot find node image for ",
1268                                         e->src()->DebugString());
1269               }
1270               g->AddEdge(node_iter->second, e->src_output(), copy,
1271                          e->dst_input());
1272             }
1273           }
1274           node_map[curr.n] = copy;
1275         }
1276         stack.pop_back();
1277       } else {
1278         curr.is_exiting = true;
1279         if (curr.n != start_node) {
1280           for (const Edge* e : curr.n->in_edges()) {
1281             if (node_map.find(e->src()) != node_map.end()) {
1282               continue;
1283             }
1284             stack.push_back({e->src(), false});
1285           }
1286         }
1287       }
1288     }
1289 
1290     send_from_host = node_map[send_node_in_host_graph];
1291   } else {
1292     // This is an outside compilation generated for If/While/gradient/etc.
1293     // It will be enough for shape inference. Leave `g` unchanged.
1294   }
1295 
1296   // Control edges are not useful for shape inference. Remove them.
1297   for (auto e : g->edges()) {
1298     if (e->IsControlEdge()) {
1299       g->RemoveEdge(e);
1300     }
1301   }
1302 
1303   // Nodes that are not reverse reachable from SendFromHost are not useful for
1304   // shape inference. Prune them.
1305   PruneForReverseReachability(g,
1306                               std::unordered_set<const Node*>{send_from_host});
1307 
1308   if (VLOG_IS_ON(4)) {
1309     DumpGraphToFile(shape_inference_graph_name, *g, fld);
1310   }
1311 
1312   // Replace original shape inference graph.
1313   FunctionDef fdef_replace;
1314   TF_RETURN_IF_ERROR(
1315       GraphToFunctionDef(*g, shape_inference_graph_name, &fdef_replace));
1316   TF_RETURN_IF_ERROR(
1317       fld->ReplaceFunction(shape_inference_graph_name, fdef_replace));
1318 
1319   return Status::OK();
1320 }
1321 
1322 // Builds XlaSendToHost node which sends cond predicate to host.
BuildSendIfPredNode(const string & name,const string & host_transfer_key,Node * pred_node,Graph * g)1323 TF_ATTRIBUTE_NOINLINE StatusOr<Node*> BuildSendIfPredNode(
1324     const string& name, const string& host_transfer_key, Node* pred_node,
1325     Graph* g) {
1326   NodeDefBuilder send_pred_builder(name, "XlaSendToHost");
1327   send_pred_builder.Attr("Tinput", DT_BOOL);
1328   send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0"));
1329   send_pred_builder.Attr(kXlaTokenInputNodesAttrName,
1330                          std::vector<string>{kXlaTokenArgNodeName});
1331   send_pred_builder.Attr(kXlaOriginalOutsideCompilationNodeName, name);
1332   send_pred_builder.Input(pred_node->name(), 0, DT_BOOL);
1333   NodeDef send_pred_def;
1334   TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def));
1335   Status s;
1336   Node* send_pred_node = g->AddNode(send_pred_def, &s);
1337   TF_RETURN_IF_ERROR(s);
1338   g->AddEdge(pred_node, 0, send_pred_node, 0);
1339   return send_pred_node;
1340 }
1341 
1342 // Replaces key placeholder node with an _Arg node.
ReplaceKeyPlaceholderWithArgNode(const string & xla_cluster_name,const string & func_name,FunctionLibraryDefinition * fld)1343 Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name,
1344                                         const string& func_name,
1345                                         FunctionLibraryDefinition* fld) {
1346   // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder
1347   // value after rewriting.
1348   AttrValue device_ordinal_attr;
1349   device_ordinal_attr.set_i(0);
1350   protobuf::Map<string, AttrValue> attrs;
1351   attrs["_device_ordinal"] = device_ordinal_attr;
1352   std::unique_ptr<FunctionBody> fbody;
1353   const FunctionDef* func = fld->Find(func_name);
1354   TF_RETURN_IF_ERROR(
1355       FunctionDefToBodyHelper(*func, AttrSlice(&attrs), fld, &fbody));
1356   Graph* g = fbody->graph;
1357 
1358   // Find or create the key placeholder node.
1359   Node* key_placeholder = nullptr;
1360   for (Node* n : g->nodes()) {
1361     if (IsKeyPlaceholderNode(*n)) {
1362       key_placeholder = n;
1363       break;
1364     }
1365   }
1366   if (!key_placeholder) {
1367     TF_ASSIGN_OR_RETURN(key_placeholder,
1368                         AddHostComputeKeyPlaceholder(xla_cluster_name, g));
1369   }
1370 
1371   // Build the _Arg node, and replace key placeholder node with it.
1372   NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp);
1373   arg_builder.Attr("T", DT_STRING);
1374   arg_builder.Attr("index", 0);
1375   NodeDef arg_def;
1376   TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
1377   TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status());
1378 
1379   // Reset "_device_ordinal" to placeholder value.
1380   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g));
1381 
1382   FunctionDef replace_fdef;
1383   TF_RETURN_IF_ERROR(GraphToFunctionDef(
1384       *g, func_name, HostGraphControlRetMapping, &replace_fdef));
1385   TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef));
1386   return Status::OK();
1387 }
1388 
1389 // Builds host side graph for If node.
BuildHostGraphForIfNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const string & if_node_name,const string & host_transfer_key,const string & host_graph_func_name,FunctionLibraryDefinition * fld,const string & then_branch_host_func_name,const string & else_branch_host_func_name)1390 TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode(
1391     const string& xla_cluster_attr_name,
1392     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1393     const string& if_node_name, const string& host_transfer_key,
1394     const string& host_graph_func_name, FunctionLibraryDefinition* fld,
1395     const string& then_branch_host_func_name,
1396     const string& else_branch_host_func_name) {
1397   Graph host_graph(fld);
1398   string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
1399   AttrValue device_ordinal_value;
1400   device_ordinal_value.set_placeholder("_device_ordinal");
1401 
1402   // Step 1: add key placeholder node.
1403   TF_ASSIGN_OR_RETURN(
1404       Node * key_placeholder,
1405       AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1406 
1407   // Step 2: build XlaRecvAtHost node to recv predicate.
1408   NodeDefBuilder recv_pred_builder(
1409       absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost");
1410   recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
1411   recv_pred_builder.Attr("key", host_transfer_key);
1412   recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
1413   recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1414   recv_pred_builder.Attr(outside_compilation_attr_name,
1415                          outside_compilation_name);
1416   recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
1417   recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING);
1418   NodeDef recv_pred_def;
1419   TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
1420   Status s;
1421   Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s);
1422   TF_RETURN_IF_ERROR(s);
1423   host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0);
1424 
1425   // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key
1426   // placeholder with an _Arg node.
1427   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1428       xla_cluster_name, then_branch_host_func_name, fld));
1429   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1430       xla_cluster_name, else_branch_host_func_name, fld));
1431 
1432   // Step 4: build If node to choose between `{then, else}_branch_host_graph`.
1433   NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If");
1434   if_builder.Attr("Tcond", DT_BOOL);
1435   if_builder.Attr("Tin", std::vector<DataType>{DT_STRING});
1436   if_builder.Attr("Tout", std::vector<DataType>{});
1437   NameAttrList host_then_branch, host_else_branch;
1438   host_then_branch.set_name(then_branch_host_func_name);
1439   (*host_then_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1440   host_else_branch.set_name(else_branch_host_func_name);
1441   (*host_else_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1442   if_builder.Attr("then_branch", host_then_branch);
1443   if_builder.Attr("else_branch", host_else_branch);
1444   if_builder.Attr(kXlaHasHostTransferAttrName, true);
1445   if_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1446   if_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
1447   if_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
1448   std::vector<NodeDefBuilder::NodeOut> if_inputs{
1449       {key_placeholder->name(), 0, DT_STRING}};
1450   if_builder.Input(if_inputs);
1451   NodeDef if_def;
1452   TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def));
1453   Node* if_node = host_graph.AddNode(if_def, &s);
1454   TF_RETURN_IF_ERROR(s);
1455   host_graph.AddEdge(recv_pred_node, 0, if_node, 0);
1456   host_graph.AddEdge(key_placeholder, 0, if_node, 1);
1457 
1458   // Convert `host_graph` to function.
1459   FunctionDef oc_host_graph_fdef;
1460   TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1461                                         &oc_host_graph_fdef));
1462   if (fld->Find(host_graph_func_name)) {
1463     TF_RETURN_IF_ERROR(
1464         fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1465   } else {
1466     TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1467   }
1468 
1469   return Status::OK();
1470 }
1471 
1472 // Rewrites loop cond to add a node which sends loop cond to host.
AddSendLoopPredToLoopCond(const string & cond_xla_func_name,const string & host_transfer_key,NameAttrList * loop_cond_func,FunctionLibraryDefinition * fld,Node * while_node)1473 TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
1474     const string& cond_xla_func_name, const string& host_transfer_key,
1475     NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld,
1476     Node* while_node) {
1477   // Instantiate the loop cond function.
1478   std::unique_ptr<FunctionBody> fbody;
1479   const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func->name());
1480   TF_RET_CHECK(loop_cond_fdef);
1481   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1482       *loop_cond_fdef, AttrSlice(&loop_cond_func->attr()), fld, &fbody));
1483   Graph* g = fbody->graph;
1484 
1485   // Find the _Retval node and the loop cond node.
1486   Node* ret_node = nullptr;
1487   for (Node* n : g->nodes()) {
1488     if (n->type_string() == "_Retval") {
1489       if (ret_node) {
1490         return errors::Internal("Multiple return node for loop cond function ",
1491                                 loop_cond_func->name(), ": ",
1492                                 ret_node->DebugString(), " and ",
1493                                 n->DebugString());
1494       } else {
1495         ret_node = n;
1496       }
1497     }
1498   }
1499   if (!ret_node) {
1500     return errors::Internal("No _Retval node for loop cond function ",
1501                             loop_cond_func->name());
1502   }
1503   Node* loop_cond;
1504   TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
1505 
1506   // Build the XlaSendToHost node.
1507   NodeDefBuilder send_loop_cond_builder(
1508       absl::StrCat("send_oc_while_cond_", while_node->name()), "XlaSendToHost");
1509   send_loop_cond_builder.Attr("Tinput", DT_BOOL);
1510   send_loop_cond_builder.Attr("key",
1511                               absl::StrCat(host_transfer_key, "_dtoh_0"));
1512   send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName,
1513                               std::vector<string>{kXlaTokenArgNodeName});
1514   send_loop_cond_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
1515                               send_loop_cond_builder.node_name());
1516   send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL);
1517   NodeDef send_loop_cond_def;
1518   TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def));
1519   Status s;
1520   Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s);
1521   TF_RETURN_IF_ERROR(s);
1522   g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
1523 
1524   // Replace original function if loop_cond_func already has been re-written
1525   // for outside compilation.
1526   FunctionDef replace_fdef;
1527   if (loop_cond_func->name() == cond_xla_func_name) {
1528     TF_RETURN_IF_ERROR(
1529         GraphToFunctionDef(*g, loop_cond_func->name(), &replace_fdef));
1530     TF_RETURN_IF_ERROR(
1531         fld->ReplaceFunction(loop_cond_func->name(), replace_fdef));
1532   } else {
1533     // If original while cond function has not been modified, add a new function
1534     // with send loop predicated added and update the while node callsite
1535     // operation.
1536     const auto new_name = fld->UniqueFunctionName(
1537         absl::StrCat(loop_cond_func->name(), "_send_pred_added_"));
1538     TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, new_name, &replace_fdef));
1539     TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
1540     loop_cond_func->set_name(new_name);
1541     while_node->ClearAttr("cond");
1542     while_node->AddAttr("cond", *loop_cond_func);
1543   }
1544 
1545   return Status::OK();
1546 }
1547 
1548 // Rewrites while loop cond function for host.
RewriteHostWhileLoopCond(const string & cond_host_func_name,const string & while_node_name,const string & host_transfer_key,const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & outside_compilation_name,FunctionLibraryDefinition * fld)1549 Status RewriteHostWhileLoopCond(
1550     const string& cond_host_func_name, const string& while_node_name,
1551     const string& host_transfer_key, const string& xla_cluster_attr_name,
1552     const string& xla_cluster_name, const string& outside_compilation_attr_name,
1553     const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1554   // Replace key placeholder node with _Arg node.
1555   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1556       xla_cluster_name, cond_host_func_name, fld));
1557 
1558   // Instantiate cond function.
1559   AttrValue device_ordinal_temp_value;
1560   device_ordinal_temp_value.set_i(0);
1561   protobuf::Map<string, AttrValue> attrs;
1562   attrs["_device_ordinal"] = device_ordinal_temp_value;
1563   std::unique_ptr<FunctionBody> cond_fbody;
1564   const FunctionDef* cond_host_func = fld->Find(cond_host_func_name);
1565   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_host_func, AttrSlice(&attrs),
1566                                              fld, &cond_fbody));
1567   Graph* cond_graph = cond_fbody->graph;
1568   Node* key_arg = nullptr;
1569   for (Node* n : cond_graph->nodes()) {
1570     if (n->type_string() == "_Arg") {
1571       key_arg = n;
1572     }
1573   }
1574   if (!key_arg) {
1575     return errors::Internal(
1576         "No _Arg node found for host compute key in function ",
1577         cond_host_func_name);
1578   }
1579 
1580   // Add an XlaRecvAtHost node to use as cond function return value.
1581   NodeDefBuilder recv_pred_builder(
1582       absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost");
1583   recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
1584   recv_pred_builder.Attr("key", host_transfer_key);
1585   AttrValue device_ordinal_value;
1586   device_ordinal_value.set_placeholder("_device_ordinal");
1587   recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
1588   recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1589   recv_pred_builder.Attr(outside_compilation_attr_name,
1590                          outside_compilation_name);
1591   recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
1592   recv_pred_builder.Input(key_arg->name(), 0, DT_STRING);
1593   NodeDef recv_pred_def;
1594   TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
1595   Status s;
1596   Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s);
1597   TF_RETURN_IF_ERROR(s);
1598   cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0);
1599   NodeDefBuilder ret_builder(
1600       absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval");
1601   ret_builder.Attr("T", DT_BOOL);
1602   ret_builder.Attr("index", 0);
1603   ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
1604   NodeDef ret_def;
1605   TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1606   Node* ret_node = cond_graph->AddNode(ret_def, &s);
1607   TF_RETURN_IF_ERROR(s);
1608   cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0);
1609 
1610   // Reset device_ordinal to placeholder value.
1611   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph));
1612 
1613   // Replace original function.
1614   FunctionDef cond_replace_fdef;
1615   TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_graph, cond_host_func_name,
1616                                         HostGraphControlRetMapping,
1617                                         &cond_replace_fdef));
1618   TF_RETURN_IF_ERROR(
1619       fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef));
1620 
1621   return Status::OK();
1622 }
1623 
1624 // Rewrites while loop body function for host.
RewriteHostWhileLoopBody(const string & body_host_func_name,const string & while_node_name,const string & host_transfer_key,const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & outside_compilation_name,FunctionLibraryDefinition * fld)1625 Status RewriteHostWhileLoopBody(
1626     const string& body_host_func_name, const string& while_node_name,
1627     const string& host_transfer_key, const string& xla_cluster_attr_name,
1628     const string& xla_cluster_name, const string& outside_compilation_attr_name,
1629     const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1630   // Replace key placeholder node with _Arg node.
1631   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1632       xla_cluster_name, body_host_func_name, fld));
1633 
1634   // Instantiate body function.
1635   AttrValue device_ordinal_temp_value;
1636   device_ordinal_temp_value.set_i(0);
1637   protobuf::Map<string, AttrValue> attrs;
1638   attrs["_device_ordinal"] = device_ordinal_temp_value;
1639   std::unique_ptr<FunctionBody> body_fbody;
1640   const FunctionDef* body_host_func = fld->Find(body_host_func_name);
1641   TF_RET_CHECK(body_host_func);
1642   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_host_func, AttrSlice(&attrs),
1643                                              fld, &body_fbody));
1644   Graph* body_graph = body_fbody->graph;
1645   Node* key_arg = nullptr;
1646   for (Node* n : body_graph->nodes()) {
1647     if (n->type_string() == "_Arg") {
1648       key_arg = n;
1649     }
1650   }
1651   if (!key_arg) {
1652     return errors::Internal(
1653         "No _Arg node found for host compute key in function ",
1654         body_host_func_name);
1655   }
1656 
1657   // Add a _Retval node to loop body.
1658   NodeDefBuilder ret_builder(
1659       absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval");
1660   ret_builder.Attr("T", DT_STRING);
1661   ret_builder.Attr("index", 0);
1662   ret_builder.Input(key_arg->name(), 0, DT_STRING);
1663   NodeDef ret_def;
1664   TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1665   Status s;
1666   Node* ret_node = body_graph->AddNode(ret_def, &s);
1667   TF_RETURN_IF_ERROR(s);
1668   body_graph->AddEdge(key_arg, 0, ret_node, 0);
1669 
1670   // Reset device_ordinal to placeholder value.
1671   TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph));
1672 
1673   // Replace original function.
1674   FunctionDef body_replace_fdef;
1675   TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_graph, body_host_func_name,
1676                                         HostGraphControlRetMapping,
1677                                         &body_replace_fdef));
1678   TF_RETURN_IF_ERROR(
1679       fld->ReplaceFunction(body_host_func_name, body_replace_fdef));
1680 
1681   return Status::OK();
1682 }
1683 
1684 // Builds host side graph for while node.
BuildHostGraphForWhileNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const string & while_node_name,const string & host_transfer_key,const string & host_graph_func_name,FunctionLibraryDefinition * fld,const string & cond_host_func_name,const string & body_host_func_name)1685 TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode(
1686     const string& xla_cluster_attr_name,
1687     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1688     const string& while_node_name, const string& host_transfer_key,
1689     const string& host_graph_func_name, FunctionLibraryDefinition* fld,
1690     const string& cond_host_func_name, const string& body_host_func_name) {
1691   Graph host_graph(fld);
1692   string outside_compilation_name = absl::StrCat("oc_while_", while_node_name);
1693 
1694   // Step 1: add key placeholder node.
1695   TF_ASSIGN_OR_RETURN(
1696       Node * key_placeholder,
1697       AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1698 
1699   // Step 2: rewrite cond function.
1700   TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond(
1701       cond_host_func_name, while_node_name, host_transfer_key,
1702       xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1703       outside_compilation_name, fld));
1704 
1705   // Step 3: rewrite body function.
1706   TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody(
1707       body_host_func_name, while_node_name, host_transfer_key,
1708       xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1709       outside_compilation_name, fld));
1710 
1711   // Step 4: build While node.
1712   NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name),
1713                                "While");
1714   while_builder.Attr("T", std::vector<DataType>{DT_STRING});
1715   NameAttrList func;
1716   AttrValue device_ordinal_value;
1717   device_ordinal_value.set_placeholder("_device_ordinal");
1718   (*func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1719   func.set_name(cond_host_func_name);
1720   while_builder.Attr("cond", func);
1721   func.set_name(body_host_func_name);
1722   while_builder.Attr("body", func);
1723   while_builder.Attr(kXlaHasHostTransferAttrName, true);
1724   while_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1725   while_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
1726   // Make sure loop body of i-th iteration happens before loop cond of (i+1)-th
1727   // iteration.
1728   while_builder.Attr("parallel_iterations", 1);
1729   std::vector<NodeDefBuilder::NodeOut> while_inputs{
1730       {key_placeholder->name(), 0, DT_STRING}};
1731   while_builder.Input(while_inputs);
1732   NodeDef while_def;
1733   TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def));
1734   Status s;
1735   Node* while_node = host_graph.AddNode(while_def, &s);
1736   TF_RETURN_IF_ERROR(s);
1737   host_graph.AddEdge(key_placeholder, 0, while_node, 0);
1738 
1739   // Convert `host_graph` to function.
1740   FunctionDef oc_host_graph_fdef;
1741   TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1742                                         &oc_host_graph_fdef));
1743   if (fld->Find(host_graph_func_name)) {
1744     TF_RETURN_IF_ERROR(
1745         fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1746   } else {
1747     TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1748   }
1749 
1750   return Status::OK();
1751 }
1752 
1753 // Builds host graph for func call nodes.
BuildHostGraphForFuncCallNode(const string & xla_cluster_attr_name,const string & xla_cluster_name,const string & outside_compilation_attr_name,const string & func_call_node_name,const string & func_call_host_func_name,const string & host_graph_func_name,FunctionLibraryDefinition * fld)1754 Status BuildHostGraphForFuncCallNode(
1755     const string& xla_cluster_attr_name, const string& xla_cluster_name,
1756     const string& outside_compilation_attr_name,
1757     const string& func_call_node_name, const string& func_call_host_func_name,
1758     const string& host_graph_func_name, FunctionLibraryDefinition* fld) {
1759   Graph host_graph(fld);
1760   AttrValue device_ordinal_value;
1761   device_ordinal_value.set_placeholder("_device_ordinal");
1762 
1763   // Step 1: add key placeholder node.
1764   TF_ASSIGN_OR_RETURN(
1765       Node * key_placeholder,
1766       AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1767 
1768   // Step 2: rewrite `host_func_name`, replace key placeholder with an _Arg
1769   // node.
1770   TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1771       xla_cluster_name, func_call_host_func_name, fld));
1772 
1773   // Step 3: build a function call node with `host_func_name`, with
1774   // `key_placeholder` as input.
1775   NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name),
1776                               func_call_host_func_name, fld);
1777   call_builder.Input(key_placeholder->name(), 0, DT_STRING);
1778   call_builder.Attr("_device_ordinal", device_ordinal_value);
1779   call_builder.Attr(kXlaHasHostTransferAttrName, true);
1780   call_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1781   call_builder.Attr(outside_compilation_attr_name, call_builder.node_name());
1782   NodeDef call_def;
1783   TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def));
1784   Status s;
1785   Node* call_node = host_graph.AddNode(call_def, &s);
1786   TF_RETURN_IF_ERROR(s);
1787   host_graph.AddEdge(key_placeholder, 0, call_node, 0);
1788 
1789   // Convert `host_graph` to function.
1790   FunctionDef oc_host_graph_fdef;
1791   TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1792                                         HostGraphControlRetMapping,
1793                                         &oc_host_graph_fdef));
1794   if (fld->Find(host_graph_func_name)) {
1795     TF_RETURN_IF_ERROR(
1796         fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1797   } else {
1798     TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1799   }
1800 
1801   return Status::OK();
1802 }
1803 
ExtractOutsideCompilationForFuncCallNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)1804 TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode(
1805     const string& xla_cluster_attr_name,
1806     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1807     const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1808     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1809     std::vector<string>* host_graphs,
1810     std::vector<string>* shape_inference_graphs,
1811     bool* has_outside_compilation) {
1812   bool func_has_outside_compilation = false;
1813   NameAttrList func;
1814   if (fld->Contains(n->type_string())) {
1815     func.set_name(n->type_string());
1816     typedef protobuf::Map<string, AttrValue> AttrMap;
1817     *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
1818   } else if (n->IsPartitionedCall()) {
1819     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func));
1820   } else {
1821     TF_RET_CHECK(n->type_string() == FunctionLibraryDefinition::kGradientOp);
1822     func.set_name(FunctionLibraryDefinition::kGradientOp);
1823     *func.mutable_attr() = n->def().attr();
1824   }
1825   string canonical_func_name;
1826   if (func.name() == FunctionLibraryDefinition::kGradientOp) {
1827     NameAttrList forward_func;
1828     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &forward_func));
1829     canonical_func_name = absl::StrCat("gradient_", forward_func.name());
1830   } else {
1831     canonical_func_name = func.name();
1832   }
1833   string new_func_name = absl::StrCat(canonical_func_name, "_oc");
1834   string host_func_name =
1835       absl::StrCat("oc_func_call_host_", canonical_func_name);
1836   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1837       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1838       func, new_func_name, host_func_name, host_compute_core, flr, fld,
1839       shape_inference_graphs, &func_has_outside_compilation));
1840 
1841   // If the function call does not have outside compilation, nothing to do.
1842   if (!func_has_outside_compilation) {
1843     return Status::OK();
1844   }
1845 
1846   *has_outside_compilation = true;
1847 
1848   // Change `n` to call the new function directly.
1849   auto replace_builder =
1850       absl::make_unique<NodeDefBuilder>(n->name(), new_func_name, fld);
1851   std::vector<NodeDefBuilder::NodeOut> inputs(n->num_inputs());
1852   for (const Edge* e : n->in_edges()) {
1853     if (e->IsControlEdge()) {
1854       continue;
1855     }
1856 
1857     const bool input_size_check =
1858         e->dst_input() < static_cast<int>(inputs.size());
1859     TF_RET_CHECK(e->dst_input() >= 0 && input_size_check);
1860     inputs[e->dst_input()] =
1861         NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
1862                                 e->src()->output_type(e->src_output())};
1863   }
1864   for (const auto& input : inputs) {
1865     replace_builder->Input(input);
1866   }
1867   for (const auto& attr : n->attrs()) {
1868     replace_builder->Attr(attr.first, attr.second);
1869   }
1870   auto replace_def = absl::make_unique<NodeDef>();
1871   TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get()));
1872   TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def));
1873   replace->AddAttr(kXlaTokenInputNodesAttrName,
1874                    std::vector<string>{kXlaTokenArgNodeName});
1875   replace->AddAttr(kXlaOriginalOutsideCompilationNodeName, replace->name());
1876 
1877   // Build host side graph for the function call.
1878   string oc_host_graph_name =
1879       absl::StrCat("oc_func_host_graph_", replace->name());
1880   TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode(
1881       xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1882       replace->name(), host_func_name, oc_host_graph_name, fld));
1883 
1884   // Record the host graph.
1885   host_graphs->push_back(oc_host_graph_name);
1886 
1887   return Status::OK();
1888 }
1889 
ExtractOutsideCompilationForIfNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)1890 Status ExtractOutsideCompilationForIfNode(
1891     const string& xla_cluster_attr_name,
1892     const string& outside_compilation_attr_name, const string& xla_cluster_name,
1893     const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1894     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1895     std::vector<string>* host_graphs,
1896     std::vector<string>* shape_inference_graphs,
1897     bool* has_outside_compilation) {
1898   // Instantiate "then_branch" and "else_branch".
1899   NameAttrList then_branch, else_branch;
1900   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
1901   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
1902 
1903   // Extract outside compilation for then_branch and else_branch.
1904   bool then_branch_has_outside_compilation = false;
1905   bool else_branch_has_outside_compilation = false;
1906   string then_branch_host_func_name =
1907              absl::StrCat("oc_then_branch_host_if_", then_branch.name()),
1908          else_branch_host_func_name =
1909              absl::StrCat("oc_else_branch_host_if_", else_branch.name());
1910   string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
1911          else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
1912   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1913       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1914       then_branch, then_branch_xla_func_name, then_branch_host_func_name,
1915       host_compute_core, flr, fld, shape_inference_graphs,
1916       &then_branch_has_outside_compilation));
1917   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1918       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1919       else_branch, else_branch_xla_func_name, else_branch_host_func_name,
1920       host_compute_core, flr, fld, shape_inference_graphs,
1921       &else_branch_has_outside_compilation));
1922 
1923   // If then/else branch do not have outside compilation, nothing to do.
1924   if (!then_branch_has_outside_compilation &&
1925       !else_branch_has_outside_compilation) {
1926     return Status::OK();
1927   }
1928 
1929   *has_outside_compilation = true;
1930 
1931   // Change If node to call the new functions.
1932   if (then_branch_has_outside_compilation) {
1933     then_branch.set_name(then_branch_xla_func_name);
1934     n->ClearAttr("then_branch");
1935     n->AddAttr("then_branch", then_branch);
1936   }
1937   if (else_branch_has_outside_compilation) {
1938     else_branch.set_name(else_branch_xla_func_name);
1939     n->ClearAttr("else_branch");
1940     n->AddAttr("else_branch", else_branch);
1941   }
1942   n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
1943 
1944   string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
1945 
1946   // XLA computation: add a SendToHost node to send cond predicate.
1947   Node* pred_node;
1948   TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
1949   TF_ASSIGN_OR_RETURN(
1950       Node * send_pred_node,
1951       BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
1952                           host_transfer_key, pred_node, g));
1953   n->AddAttr(kXlaTokenInputNodesAttrName,
1954              std::vector<string>{send_pred_node->name()});
1955 
1956   // Add a control edge from `send_pred_node` to If node, so XlaCompiler will
1957   // visit If node after `send_pred_node`, thus the token output for
1958   // `send_pred_node` has been generated.
1959   g->AddControlEdge(send_pred_node, n);
1960 
1961   // Build host side graph for the "If" node.
1962   // If then/else branch does not have outside compilation, we won't build host
1963   // graph for the branch. But here we need a host graph for both branches, so
1964   // we need to create a no-op host graph.
1965   if (!then_branch_has_outside_compilation) {
1966     std::unique_ptr<Graph> then_branch_host_graph(new Graph(fld));
1967     std::vector<string> then_branch_host_graphs;
1968     TF_RETURN_IF_ERROR(ConstructHostGraph(
1969         xla_cluster_name, outside_compilation_attr_name,
1970         then_branch_host_graphs, fld, &then_branch_host_graph));
1971     FunctionDef then_branch_host_fdef;
1972     TF_RETURN_IF_ERROR(GraphToFunctionDef(*then_branch_host_graph,
1973                                           then_branch_host_func_name,
1974                                           &then_branch_host_fdef));
1975     if (fld->Find(then_branch_host_func_name)) {
1976       TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_host_func_name,
1977                                               then_branch_host_fdef));
1978     } else {
1979       TF_RETURN_IF_ERROR(fld->AddFunctionDef(then_branch_host_fdef));
1980     }
1981   }
1982   if (!else_branch_has_outside_compilation) {
1983     std::unique_ptr<Graph> else_branch_host_graph(new Graph(fld));
1984     std::vector<string> else_branch_host_graphs;
1985     TF_RETURN_IF_ERROR(ConstructHostGraph(
1986         xla_cluster_name, outside_compilation_attr_name,
1987         else_branch_host_graphs, fld, &else_branch_host_graph));
1988     FunctionDef else_branch_host_fdef;
1989     TF_RETURN_IF_ERROR(GraphToFunctionDef(*else_branch_host_graph,
1990                                           else_branch_host_func_name,
1991                                           &else_branch_host_fdef));
1992     if (fld->Find(else_branch_host_func_name)) {
1993       TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_host_func_name,
1994                                               else_branch_host_fdef));
1995     } else {
1996       TF_RETURN_IF_ERROR(fld->AddFunctionDef(else_branch_host_fdef));
1997     }
1998   }
1999   string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
2000   TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
2001       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2002       n->name(), host_transfer_key, oc_host_graph_name, fld,
2003       then_branch_host_func_name, else_branch_host_func_name));
2004   host_graphs->push_back(oc_host_graph_name);
2005 
2006   return Status::OK();
2007 }
2008 
ExtractOutsideCompilationForWhileNode(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,Graph * g,Node * n,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)2009 Status ExtractOutsideCompilationForWhileNode(
2010     const string& xla_cluster_attr_name,
2011     const string& outside_compilation_attr_name, const string& xla_cluster_name,
2012     const std::map<string, int>& host_compute_core, Graph* g, Node* n,
2013     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2014     std::vector<string>* host_graphs,
2015     std::vector<string>* shape_inference_graphs,
2016     bool* has_outside_compilation) {
2017   // Instantiate "cond" and "body".
2018   NameAttrList cond, body;
2019   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
2020   TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
2021 
2022   // Extract outside compilation for cond and body.
2023   bool cond_has_outside_compilation = false;
2024   bool body_has_outside_compilation = false;
2025   string cond_host_func_name = absl::StrCat("oc_cond_host_while_", cond.name()),
2026          body_host_func_name = absl::StrCat("oc_body_host_while_", body.name());
2027   string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
2028          body_xla_func_name = absl::StrCat(body.name(), "_oc");
2029   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2030       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2031       cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
2032       fld, shape_inference_graphs, &cond_has_outside_compilation));
2033   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2034       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2035       body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
2036       fld, shape_inference_graphs, &body_has_outside_compilation));
2037 
2038   // If cond/body do not have outside compilation, nothing to do.
2039   if (!cond_has_outside_compilation && !body_has_outside_compilation) {
2040     return Status::OK();
2041   }
2042 
2043   *has_outside_compilation = true;
2044 
2045   // Change While node to call the new functions.
2046   if (cond_has_outside_compilation) {
2047     cond.set_name(cond_xla_func_name);
2048     n->ClearAttr("cond");
2049     n->AddAttr("cond", cond);
2050   }
2051   if (body_has_outside_compilation) {
2052     body.set_name(body_xla_func_name);
2053     n->ClearAttr("body");
2054     n->AddAttr("body", body);
2055   }
2056   n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
2057 
2058   string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
2059 
2060   // XLA computation: rewrite cond function to add a SendToHost node to send
2061   // loop predicate.
2062   TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond(
2063       cond_xla_func_name, host_transfer_key, &cond, fld, n));
2064   n->AddAttr(kXlaTokenInputNodesAttrName,
2065              std::vector<string>{kXlaTokenArgNodeName});
2066 
2067   // Build host side graph for the "While" node.
2068   if (!cond_has_outside_compilation) {
2069     std::unique_ptr<Graph> cond_host_graph(new Graph(fld));
2070     std::vector<string> host_graphs;
2071     TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2072                                           outside_compilation_attr_name,
2073                                           host_graphs, fld, &cond_host_graph));
2074     FunctionDef cond_host_fdef;
2075     TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_host_graph, cond_host_func_name,
2076                                           &cond_host_fdef));
2077     if (fld->Find(cond_host_func_name)) {
2078       TF_RETURN_IF_ERROR(
2079           fld->ReplaceFunction(cond_host_func_name, cond_host_fdef));
2080     } else {
2081       TF_RETURN_IF_ERROR(fld->AddFunctionDef(cond_host_fdef));
2082     }
2083   }
2084   if (!body_has_outside_compilation) {
2085     std::unique_ptr<Graph> body_host_graph(new Graph(fld));
2086     std::vector<string> host_graphs;
2087     TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2088                                           outside_compilation_attr_name,
2089                                           host_graphs, fld, &body_host_graph));
2090     FunctionDef body_host_fdef;
2091     TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_host_graph, body_host_func_name,
2092                                           &body_host_fdef));
2093     if (fld->Find(body_host_func_name)) {
2094       TF_RETURN_IF_ERROR(
2095           fld->ReplaceFunction(body_host_func_name, body_host_fdef));
2096     } else {
2097       TF_RETURN_IF_ERROR(fld->AddFunctionDef(body_host_fdef));
2098     }
2099   }
2100   string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
2101   TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
2102       xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2103       n->name(), host_transfer_key, oc_host_graph_name, fld,
2104       cond_host_func_name, body_host_func_name));
2105   host_graphs->push_back(oc_host_graph_name);
2106 
2107   return Status::OK();
2108 }
2109 
ExtractOutsideCompilationForNodesWithAssociatedFunctions(Graph * g,const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const std::map<string,int> & host_compute_core,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * host_graphs,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)2110 Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
2111     Graph* g, const string& xla_cluster_attr_name,
2112     const string& outside_compilation_attr_name, const string& xla_cluster_name,
2113     const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
2114     FunctionLibraryDefinition* fld, std::vector<string>* host_graphs,
2115     std::vector<string>* shape_inference_graphs,
2116     bool* has_outside_compilation) {
2117   std::vector<Node*> if_nodes, while_nodes, func_call_nodes;
2118   for (Node* n : g->nodes()) {
2119     if (n->IsIfNode()) {
2120       if_nodes.push_back(n);
2121     } else if (n->IsWhileNode()) {
2122       while_nodes.push_back(n);
2123     } else if (IsFunctionCall(*fld, *n)) {
2124       func_call_nodes.push_back(n);
2125     }
2126   }
2127 
2128   for (Node* n : func_call_nodes) {
2129     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFuncCallNode(
2130         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2131         host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2132         has_outside_compilation));
2133   }
2134 
2135   for (Node* n : if_nodes) {
2136     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForIfNode(
2137         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2138         host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2139         has_outside_compilation));
2140   }
2141 
2142   for (Node* n : while_nodes) {
2143     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForWhileNode(
2144         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2145         host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2146         has_outside_compilation));
2147   }
2148 
2149   return Status::OK();
2150 }
2151 
CopyOutsideCompilationConstNodes(Graph * g,const string & outside_compilation_attr_name)2152 Status CopyOutsideCompilationConstNodes(
2153     Graph* g, const string& outside_compilation_attr_name) {
2154   for (Node* n : g->op_nodes()) {
2155     if (!n->IsConstant() ||
2156         !HasNodeAttr(n->def(), outside_compilation_attr_name)) {
2157       continue;
2158     }
2159 
2160     std::vector<const Edge*> out_edges(n->out_edges().begin(),
2161                                        n->out_edges().end());
2162     bool has_non_oc_output = false;
2163     for (const Edge* e : out_edges) {
2164       if (!e->IsControlEdge() &&
2165           !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2166         has_non_oc_output = true;
2167         break;
2168       }
2169     }
2170     if (!has_non_oc_output) {
2171       continue;
2172     }
2173 
2174     NodeDef copy_def = n->def();
2175     copy_def.set_name(g->NewName(n->name()));
2176     copy_def.mutable_attr()->erase(outside_compilation_attr_name);
2177     Status s;
2178     Node* copy_node = g->AddNode(copy_def, &s);
2179     TF_RETURN_IF_ERROR(s);
2180     for (const Edge* e : n->in_edges()) {
2181       if (e->IsControlEdge()) {
2182         g->AddControlEdge(e->src(), copy_node);
2183       }
2184     }
2185     for (const Edge* e : out_edges) {
2186       if (!e->IsControlEdge() &&
2187           !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2188         Node* dst = e->dst();
2189         int dst_input = e->dst_input();
2190         g->RemoveEdge(e);
2191         g->AddEdge(copy_node, 0, dst, dst_input);
2192       }
2193     }
2194   }
2195 
2196   return Status::OK();
2197 }
2198 
2199 }  // namespace
2200 
operator ()(const std::vector<OutputTensor> & arg_source_tensors,std::unique_ptr<Graph> * graph,std::vector<int> * input_permutation,std::vector<int> * output_permutation,NodeDef * node_def)2201 Status RewriteOutsideCompilationSubgraphFn::operator()(
2202     const std::vector<OutputTensor>& arg_source_tensors,
2203     std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
2204     std::vector<int>* output_permutation, NodeDef* node_def) {
2205   string old_name = node_def->op();
2206   string new_name =
2207       absl::StrCat(xla_cluster_name_, "_", new_function_name_, "_", old_name);
2208   node_def->set_op(new_name);
2209   node_def->set_name(new_name);
2210 
2211   // Later we will run PruneForReverseReachability(), so make sure all original
2212   // nodes are reachable from sink node and won't be removed.
2213   FixupSourceAndSinkEdges(graph->get());
2214 
2215   // Step 1: create a key placeholder node.
2216   TF_ASSIGN_OR_RETURN(
2217       Node * key_placeholder,
2218       AddHostComputeKeyPlaceholder(xla_cluster_name_, graph->get()));
2219 
2220   // Step 2: build RecvAtHost node, and replace all _Arg nodes with it.
2221   std::vector<DataType> recv_at_host_dtypes;
2222   TF_ASSIGN_OR_RETURN(
2223       Node * recv_at_host_node,
2224       ReplaceArgNodesWithRecvAtHostNode(graph->get(), new_name,
2225                                         &recv_at_host_dtypes, key_placeholder));
2226 
2227   // Step 3: build SendFromHost node, and replace all _Retval nodes with it.
2228   std::vector<DataType> send_from_host_dtypes;
2229   TF_ASSIGN_OR_RETURN(
2230       Node * send_from_host_node,
2231       ReplaceRetNodesWithSendFromHostNode(
2232           graph->get(), new_name, &send_from_host_dtypes, key_placeholder));
2233 
2234   // Step 4: add XLA cluster and outside compilation attr.
2235   for (Node* n : (*graph)->nodes()) {
2236     if (IsKeyPlaceholderNode(*n)) {
2237       continue;
2238     }
2239 
2240     n->AddAttr(xla_cluster_attr_name_, xla_cluster_name_);
2241     n->AddAttr(outside_compilation_attr_name_, old_name);
2242   }
2243 
2244   // Check whether we have all input shapes for XlaSendFromHost. If we do, we
2245   // will set `shapes` attr for the call node; otherwise we will save the
2246   // shape inference graph and set `shape_inference_graph` for the call node.
2247   absl::optional<std::vector<PartialTensorShape>> shapes =
2248       GetInferredInputShapes(send_from_host_dtypes.size(), send_from_host_node);
2249   for (Node* n : (*graph)->nodes()) {
2250     n->ClearAttr(kXlaInferredShapesAttrName);
2251   }
2252 
2253   // Step 5: add control edges for originally XLA <-> outside compilation
2254   // control edges.
2255   for (Node* n : (*graph)->nodes()) {
2256     if (HasNodeAttr(n->def(), kXlaConnectedToXlaComputationAttrName)) {
2257       (*graph)->AddControlEdge(n, send_from_host_node);
2258       n->ClearAttr(kXlaConnectedToXlaComputationAttrName);
2259     }
2260     if (HasNodeAttr(n->def(), kXlaConnectedFromXlaComputationAttrName)) {
2261       (*graph)->AddControlEdge(recv_at_host_node, n);
2262       n->ClearAttr(kXlaConnectedFromXlaComputationAttrName);
2263     }
2264   }
2265 
2266   // Step 6: RecvAtHost/SendFromHost/key_placeholder might be dead nodes. Prune
2267   // them if necessary.
2268   // - RecvAtHost should be pruned iff it has no output data/control edges. If
2269   //   it has any output edge, it will be reverse reachable from sink node. We
2270   //   don't need to do anything special.
2271   // - SendFromHost should be pruned iff it has no input data/control edges. If
2272   //   it has input edges other than key_placeholder, we connect it to sink
2273   //   node so it won't be pruned.
2274   // - key_placeholder should be pruned iff RecvAtHost/SendFromHost are pruned.
2275   //   We don't need to do anything special.
2276   if (send_from_host_node->in_edges().size() > 1) {
2277     (*graph)->AddControlEdge(send_from_host_node, (*graph)->sink_node());
2278   }
2279   PruneForReverseReachability(
2280       graph->get(), std::unordered_set<const Node*>{(*graph)->sink_node()});
2281 
2282   // Step 7: add necessary attributes to function call node, so we can replace
2283   // it with HostCompute node later.
2284   AddNodeAttr("_outside_compilation_subgraph", old_name, node_def);
2285   if (shapes) {
2286     NameAttrList shape_inference_graph;
2287     AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
2288     AddNodeAttr("shapes", *shapes, node_def);
2289   } else {
2290     string shape_inference_func_name =
2291         absl::StrCat("_outside_compilation_shape_inference_", new_name);
2292     NameAttrList shape_inference_graph;
2293     shape_inference_graph.set_name(shape_inference_func_name);
2294     AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
2295     AddNodeAttr("shapes", std::vector<TensorShapeProto>{}, node_def);
2296   }
2297   AddNodeAttr("ancestors", std::vector<string>{}, node_def);
2298   AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def);
2299   AddNodeAttr("Toutputs", send_from_host_dtypes, node_def);
2300   AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def);
2301 
2302   return Status::OK();
2303 }
2304 
ExtractOutsideCompilationForFunction(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const string & xla_cluster_name,const NameAttrList & func_name_attrs,const string & new_func_name,const string & host_graph_func_name,const std::map<string,int> & host_compute_core,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,std::vector<string> * shape_inference_graphs,bool * has_outside_compilation)2305 Status ExtractOutsideCompilationForFunction(
2306     const string& xla_cluster_attr_name,
2307     const string& outside_compilation_attr_name, const string& xla_cluster_name,
2308     const NameAttrList& func_name_attrs, const string& new_func_name,
2309     const string& host_graph_func_name,
2310     const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
2311     FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs,
2312     bool* has_outside_compilation) {
2313   // Convert the function to graph.
2314   const string& func_name = func_name_attrs.name();
2315   FunctionLibraryRuntime::Handle handle;
2316   TF_RETURN_IF_ERROR(
2317       flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle));
2318   Status ret_status = Status::OK();
2319   auto cleanup_handle = gtl::MakeCleanup([&]() {
2320     auto s = flr->ReleaseHandle(handle);
2321     if (!s.ok()) {
2322       ret_status.Update(s);
2323     }
2324   });
2325   const FunctionBody* fbody = flr->GetFunctionBody(handle);
2326 
2327   // Check if we have outside compilation nodes.
2328   *has_outside_compilation = false;
2329   for (Node* n : fbody->graph->nodes()) {
2330     if (HasNodeAttr(n->def(), outside_compilation_attr_name)) {
2331       *has_outside_compilation = true;
2332       break;
2333     }
2334   }
2335   // We cannot early return here, because we might have outside compilation in
2336   // If/While function body.
2337 
2338   if (VLOG_IS_ON(4)) {
2339     DumpGraphToFile(
2340         absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
2341         *fbody->graph, fld);
2342   }
2343 
2344   std::unique_ptr<Graph> graph_out;
2345   std::vector<string> outside_compilation_host_graphs;
2346   std::vector<string> shape_inference_graphs_to_rewrite;
2347   if (*has_outside_compilation) {
2348     // Copy outside compilation Const nodes with non outside compilation users.
2349     TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
2350         fbody->graph, outside_compilation_attr_name));
2351 
2352     // Find dependencies between outside compilation clusters.
2353     TF_ASSIGN_OR_RETURN(auto cluster_deps,
2354                         OutsideCompilationClusterDependencies(
2355                             fbody->graph, outside_compilation_attr_name));
2356 
2357     // Preprocess edges between different outside compilations. They will be
2358     // restored in `ConstructHostGraph()`.
2359     TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
2360         fbody->graph, outside_compilation_attr_name));
2361 
2362     // Encapsulate outside_compilation cluster into function call node.
2363     auto rewrite_fn = absl::make_unique<RewriteOutsideCompilationSubgraphFn>(
2364         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2365         new_func_name);
2366     TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
2367         outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
2368         /*reuse_existing_functions=*/true, &graph_out, fld));
2369 
2370     // Replace outside_compilation function nodes with HostCompute ops.
2371     std::vector<Node*> outside_compilation_nodes;
2372     for (Node* n : graph_out->nodes()) {
2373       if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) {
2374         outside_compilation_nodes.push_back(n);
2375         outside_compilation_host_graphs.push_back(n->name());
2376 
2377         // If we could not infer shapes for XlaSendFromHost inputs statically,
2378         // we will set the "shape_inference_graph" attribute. In that case, copy
2379         // outside compilation subgraph as shape inference graph in `fld`.
2380         auto shape_inference_graph = absl::make_unique<NameAttrList>();
2381         TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
2382                                        shape_inference_graph.get()));
2383         if (!shape_inference_graph->name().empty()) {
2384           shape_inference_graphs->push_back(shape_inference_graph->name());
2385           shape_inference_graphs_to_rewrite.push_back(
2386               shape_inference_graph->name());
2387 
2388           const FunctionDef* xla_fdef = fld->Find(n->name());
2389           if (!xla_fdef) {
2390             return errors::Internal("Cannot find XLA function ", n->name());
2391           }
2392           auto shape_inference_fdef = absl::make_unique<FunctionDef>(*xla_fdef);
2393           shape_inference_fdef->mutable_signature()->set_name(
2394               shape_inference_graph->name());
2395           if (fld->Find(shape_inference_graph->name())) {
2396             TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2397                 shape_inference_graph->name(), *shape_inference_fdef));
2398           } else {
2399             TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
2400           }
2401         }
2402       }
2403     }
2404     std::map<string, Node*> host_compute_nodes;
2405     for (Node* n : outside_compilation_nodes) {
2406       auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
2407           graph_out.get(), n, host_compute_core, *cluster_deps);
2408       TF_RETURN_IF_ERROR(host_compute_node_or.status());
2409       Node* host_compute_node = host_compute_node_or.ValueOrDie();
2410       host_compute_nodes[host_compute_node->name()] = host_compute_node;
2411     }
2412     // For XlaHostCompute nodes with dependencies, add control edges between
2413     // them so XlaCompiler can handle them in correct order.
2414     for (const auto& iter : host_compute_nodes) {
2415       Node* host_compute_node = iter.second;
2416       std::vector<string> token_input_node_names;
2417       TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
2418                                      kXlaTokenInputNodesAttrName,
2419                                      &token_input_node_names));
2420       for (const string& node_name : token_input_node_names) {
2421         if (node_name == kXlaTokenArgNodeName) {
2422           continue;
2423         }
2424 
2425         auto iter = host_compute_nodes.find(node_name);
2426         TF_RET_CHECK(iter != host_compute_nodes.end());
2427         graph_out->AddControlEdge(iter->second, host_compute_node);
2428       }
2429     }
2430   }
2431 
2432   // Handle nodes with associated functions.
2433   Graph* g = (*has_outside_compilation) ? graph_out.get() : fbody->graph;
2434   TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions(
2435       g, xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2436       host_compute_core, flr, fld, &outside_compilation_host_graphs,
2437       shape_inference_graphs, has_outside_compilation));
2438 
2439   if (*has_outside_compilation) {
2440     // Construct host graph.
2441     std::unique_ptr<Graph> host_graph;
2442     TF_RETURN_IF_ERROR(
2443         ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
2444                            outside_compilation_host_graphs, fld, &host_graph));
2445     auto host_graph_fdef = absl::make_unique<FunctionDef>();
2446     TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
2447                                           HostGraphControlRetMapping,
2448                                           host_graph_fdef.get()));
2449     if (fld->Find(host_graph_func_name)) {
2450       TF_RETURN_IF_ERROR(
2451           fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
2452     } else {
2453       TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
2454     }
2455 
2456     // Shape inference graphs might contain Placeholder nodes for outside
2457     // compilation to outside compilation edges. Rewrite shape inference graphs
2458     // to remove such nodes.
2459     for (const string& shape_inference_graph :
2460          shape_inference_graphs_to_rewrite) {
2461       TF_RETURN_IF_ERROR(
2462           RewriteShapeInferenceGraph(shape_inference_graph, host_graph.get(),
2463                                      /*pivot_node=*/nullptr, fld));
2464     }
2465 
2466     // Remove the outside compilation graphs from function library.
2467     for (const string& func : outside_compilation_host_graphs) {
2468       TF_RETURN_IF_ERROR(fld->RemoveFunction(func));
2469     }
2470 
2471     // Replace original function.
2472     auto updated_fdef = absl::make_unique<FunctionDef>();
2473     TF_RETURN_IF_ERROR(
2474         GraphToFunctionDef(*g, new_func_name, updated_fdef.get()));
2475     updated_fdef->mutable_signature()->set_is_stateful(true);
2476     const FunctionDef* original_fdef = fld->Find(func_name);
2477     if (original_fdef) {
2478       for (const auto& attr : original_fdef->attr()) {
2479         (*updated_fdef->mutable_attr())[attr.first] = attr.second;
2480       }
2481     }
2482     if (fld->Find(new_func_name)) {
2483       TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
2484     } else {
2485       TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
2486     }
2487     if (VLOG_IS_ON(4)) {
2488       DumpGraphToFile(
2489           absl::StrCat("extract_outside_compilation_for_func_after_",
2490                        func_name),
2491           *g, fld);
2492     }
2493   }
2494 
2495   return ret_status;
2496 }
2497 
ExtractOutsideCompilation(const string & xla_cluster_attr_name,const string & outside_compilation_attr_name,const std::unordered_map<string,XlaClusterInfo> & clusters,Graph * g,FunctionLibraryRuntime * flr,FunctionLibraryDefinition * fld,bool * modified)2498 Status ExtractOutsideCompilation(
2499     const string& xla_cluster_attr_name,
2500     const string& outside_compilation_attr_name,
2501     const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
2502     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2503     bool* modified) {
2504   if (VLOG_IS_ON(4)) {
2505     DumpGraphToFile("extract_outside_compilation_before", *g, fld);
2506   }
2507 
2508   *modified = false;
2509   auto node_name_index = g->BuildNodeNameIndex();
2510   for (auto& iter : clusters) {
2511     string xla_cluster_name = iter.first;
2512     Node* n = iter.second.node;
2513     auto const& func_name_attrs = iter.second.func_name_attrs;
2514     auto const& host_compute_core = iter.second.host_compute_core;
2515 
2516     std::vector<string> shape_inference_graphs;
2517     bool has_outside_compilation;
2518     string host_graph_func_name =
2519         absl::StrCat("oc_host_graph_", xla_cluster_name);
2520     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2521         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2522         func_name_attrs, func_name_attrs.name(), host_graph_func_name,
2523         host_compute_core, flr, fld, &shape_inference_graphs,
2524         &has_outside_compilation));
2525     *modified |= has_outside_compilation;
2526 
2527     if (has_outside_compilation) {
2528       string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
2529       Node* pivot_node = node_name_index[pivot_name];
2530       TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
2531           g, fld, host_graph_func_name, n, pivot_node));
2532 
2533       TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
2534 
2535       for (const auto& shape_inference_graph_name : shape_inference_graphs) {
2536         TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(
2537             shape_inference_graph_name, g, pivot_node, fld));
2538       }
2539     }
2540   }
2541 
2542   if (VLOG_IS_ON(4)) {
2543     DumpGraphToFile("extract_outside_compilation_after", *g, fld);
2544   }
2545   return Status::OK();
2546 }
2547 
2548 }  // namespace tensorflow
2549