/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"

#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {

namespace {

// Control return mapping function for outside compilation host graphs.
// All nodes with kXlaHasHostTransfer attribute are control outputs.
absl::optional<string> HostGraphControlRetMapping(const Node* n) {
  if (HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
    return n->name();
  }
  return absl::nullopt;
}

// Add a key placeholder node to the graph. The key placeholder node will be
// used as input for XlaRecvAtHost/XlaSendFromHost nodes.
StatusOr<Node*> AddHostComputeKeyPlaceholder(const string& xla_cluster_name,
                                             Graph* g) {
  NodeDef key_def;
  NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"),
                         "Placeholder");
  builder.Attr("dtype", DT_STRING);
  builder.Attr("shape", PartialTensorShape({2}));
  builder.Attr("_host_compute_call_node", xla_cluster_name);
  Status s = builder.Finalize(&key_def);
  if (!s.ok()) return s;

  Node* n = g->AddNode(key_def, &s);
  if (!s.ok()) return s;
  return n;
}
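
// For illustration: for an XLA cluster named "cluster_0", the node added above
// would be a DT_STRING Placeholder named "cluster_0_key_placeholder" with
// shape [2]. Roughly, the key it carries at runtime identifies the compiled
// XLA computation so _XlaRecvAtHost/_XlaSendFromHost can rendezvous with the
// right device program.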

// Returns whether the node is an XLA computation key placeholder.
bool IsKeyPlaceholderNode(const Node& n) {
  return n.type_string() == "Placeholder" &&
         absl::EndsWith(n.name(), "_key_placeholder");
}

// Returns nodes with given type.
std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) {
  std::vector<Node*> result;
  for (Node* n : g.nodes()) {
    if (n->type_string() == type) {
      result.push_back(n);
    }
  }
  return result;
}

// Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`.
Status GetArgDataTypes(const std::vector<Node*>& arg_nodes,
                       std::vector<DataType>* recv_at_host_dtypes) {
  recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID);
  for (auto* n : arg_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
    (*recv_at_host_dtypes)[index] = dtype;
  }
  for (int i = 0, end = recv_at_host_dtypes->size(); i < end; i++) {
    if ((*recv_at_host_dtypes)[i] == DT_INVALID) {
      return errors::Internal("Cannot get datatype for input ", i);
    }
  }
  return Status::OK();
}

// Builds XlaRecvAtHost node.
StatusOr<Node*> BuildRecvAtHostNode(
    Graph* g, const string& oc_cluster_name,
    const std::vector<DataType>& recv_at_host_dtypes, Node* key_placeholder) {
  NodeDefBuilder recv_at_host_builder(
      absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"),
      "_XlaRecvAtHost");
  NodeDef recv_at_host_def;
  recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes);
  // The correct device_ordinal will be inserted during replication in a
  // subsequent rewrite.
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("_device_ordinal");
  recv_at_host_builder.Attr("device_ordinal", device_ordinal_value);
  recv_at_host_builder.Attr(
      "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
  recv_at_host_builder.Attr(kXlaHasHostTransferAttrName, true);
  recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
  TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def));
  Status s;
  Node* recv_at_host_node = g->AddNode(recv_at_host_def, &s);
  TF_RETURN_IF_ERROR(s);
  return recv_at_host_node;
}

// Builds XlaRecvAtHost node, and replaces all _Arg nodes with it.
StatusOr<Node*> ReplaceArgNodesWithRecvAtHostNode(
    Graph* g, const string& oc_cluster_name,
    std::vector<DataType>* recv_at_host_dtypes, Node* key_placeholder) {
  // TODO(b/77601805): use out nodes for source node, instead of traversing all
  // nodes.
  std::vector<Node*> arg_nodes = GatherNodesWithType(*g, "_Arg");
  TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes));
  TF_ASSIGN_OR_RETURN(
      Node * recv_at_host_node,
      BuildRecvAtHostNode(g, oc_cluster_name, *recv_at_host_dtypes,
                          key_placeholder));
  for (auto* n : arg_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    // Record out edges and remove `n` before adding those edges to RecvAtHost.
    // This is to avoid multiple producers.
    std::vector<OutEdgeInfo> out_edge_info;
    for (auto edge : n->out_edges()) {
      out_edge_info.push_back(
          {edge->dst(), edge->src_output(), edge->dst_input()});
    }
    g->RemoveNode(n);
    for (const OutEdgeInfo& edge : out_edge_info) {
      if (edge.dst_input == Graph::kControlSlot) {
        g->AddControlEdge(recv_at_host_node, edge.dst);
      } else {
        g->AddEdge(recv_at_host_node, index, edge.dst, edge.dst_input);
      }
    }

    // Rewrite dst nodes because their input changed.
    for (int i = 0, end = out_edge_info.size(); i < end; i++) {
      const OutEdgeInfo edge = out_edge_info[i];
      if (edge.dst_input == Graph::kControlSlot) {
        continue;
      }

      Node* dst = edge.dst;
      NodeDef new_def = dst->def();
      *new_def.mutable_input(edge.dst_input) =
          absl::StrCat(recv_at_host_node->name(), ":", index);
      TF_ASSIGN_OR_RETURN(Node * dst_replace, ReplaceNode(g, dst, new_def));

      // Other edges might have `dst` as dst node as well. Update those edges
      // with `dst_replace`.
      for (int j = i + 1, end = out_edge_info.size(); j < end; j++) {
        if (out_edge_info[j].dst == dst) {
          out_edge_info[j].dst = dst_replace;
        }
      }
    }
  }
  g->AddEdge(key_placeholder, 0, recv_at_host_node, 0);
  return recv_at_host_node;
}
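
// For illustration, the rewrite above roughly turns
//   _Arg(index=i) --> consumer
// into
//   key_placeholder --> _XlaRecvAtHost --(output i)--> consumer
// so that every outside compilation input is received from the device through
// a single _XlaRecvAtHost node keyed by "host_compute_channel_<cluster>".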

// Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`.
Status GetRetDataTypes(const std::vector<Node*>& ret_nodes,
                       std::vector<DataType>* send_from_host_dtypes) {
  send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID);
  for (auto* n : ret_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
    (*send_from_host_dtypes)[index] = dtype;
  }
  for (int i = 0, end = send_from_host_dtypes->size(); i < end; i++) {
    if ((*send_from_host_dtypes)[i] == DT_INVALID) {
      return errors::Internal("Cannot get datatype for output ", i);
    }
  }
  return Status::OK();
}

// Builds XlaSendFromHost node.
StatusOr<Node*> BuildSendFromHostNode(
    Graph* g, const string& oc_cluster_name,
    const std::vector<Node*>& ret_nodes,
    const std::vector<DataType>& send_from_host_dtypes, Node* key_placeholder) {
  NodeDefBuilder send_from_host_builder(
      absl::StrCat("outside_compilation_", oc_cluster_name, "_send"),
      "_XlaSendFromHost");
  NodeDef send_from_host_def;
  send_from_host_builder.Attr("Tinputs", send_from_host_dtypes);
  // The correct device_ordinal will be inserted during replication in a
  // subsequent rewrite.
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("_device_ordinal");
  send_from_host_builder.Attr("device_ordinal", device_ordinal_value);
  send_from_host_builder.Attr(
      "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
  send_from_host_builder.Attr(kXlaHasHostTransferAttrName, true);
  std::vector<NodeDefBuilder::NodeOut> inputs(send_from_host_dtypes.size());
  for (auto* n : ret_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    const int num_dtypes = send_from_host_dtypes.size();
    if (index < 0 || index >= num_dtypes) {
      return errors::Internal("Invalid _Retval index: ", index);
    }
    for (auto edge : n->in_edges()) {
      inputs[index] =
          NodeDefBuilder::NodeOut{edge->src()->name(), edge->src_output(),
                                  edge->src()->output_type(edge->src_output())};
    }
  }
  send_from_host_builder.Input(inputs);
  send_from_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
  TF_RETURN_IF_ERROR(send_from_host_builder.Finalize(&send_from_host_def));
  Status s;
  Node* send_from_host_node = g->AddNode(send_from_host_def, &s);
  TF_RETURN_IF_ERROR(s);
  return send_from_host_node;
}

// Builds XlaSendFromHost node, and replaces all _Retval nodes with it.
StatusOr<Node*> ReplaceRetNodesWithSendFromHostNode(
    Graph* g, const string& oc_cluster_name,
    std::vector<DataType>* send_from_host_dtypes, Node* key_placeholder) {
  // TODO(b/77601805): use in nodes for sink node, instead of traversing all
  // nodes.
  std::vector<Node*> ret_nodes = GatherNodesWithType(*g, "_Retval");
  TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes));
  TF_ASSIGN_OR_RETURN(
      Node * send_from_host_node,
      BuildSendFromHostNode(g, oc_cluster_name, ret_nodes,
                            *send_from_host_dtypes, key_placeholder));
  for (auto* n : ret_nodes) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    for (auto edge : n->in_edges()) {
      if (edge->src_output() == Graph::kControlSlot) {
        g->AddControlEdge(edge->src(), send_from_host_node);
      } else {
        g->AddEdge(edge->src(), edge->src_output(), send_from_host_node, index);
      }
    }
    g->RemoveNode(n);
  }
  g->AddEdge(key_placeholder, 0, send_from_host_node,
             send_from_host_dtypes->size());
  return send_from_host_node;
}
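
// Symmetrically to the _Arg rewrite above, this roughly turns
//   producer --> _Retval(index=i)
// into
//   producer --(input i)--> _XlaSendFromHost <-- key_placeholder (last input)
// so all outside compilation outputs are sent back to the device through one
// _XlaSendFromHost node.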

// Returns input shapes (excluding key placeholder) for `send_from_host_node`
// if they are all fully defined; absl::nullopt otherwise.
absl::optional<std::vector<PartialTensorShape>> GetInferredInputShapes(
    int num_inputs, Node* send_from_host_node) {
  std::vector<PartialTensorShape> results(num_inputs);
  for (int i = 0; i < num_inputs; i++) {
    const Edge* e;
    if (!send_from_host_node->input_edge(i, &e).ok()) {
      return absl::nullopt;
    }

    std::vector<PartialTensorShape> shapes;
    if (!GetNodeAttr(e->src()->attrs(), kXlaInferredShapesAttrName, &shapes)
             .ok()) {
      return absl::nullopt;
    }

    const PartialTensorShape shape = shapes[e->src_output()];
    if (!shape.IsFullyDefined()) {
      return absl::nullopt;
    }

    results[e->dst_input()] = shape;
  }
  return results;
}

string host_compute_node_name(const string& original_oc_name) {
  return absl::StrCat("outside_compilation_", original_oc_name,
                      "_host_compute");
}

// Builds XlaHostCompute NodeDef from the outside compilation call node.
StatusOr<NodeDef> BuildXlaHostComputeNodeDef(
    const Node* call_node, const std::map<string, int>& host_compute_core,
    const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
  string original_oc_name;
  TF_RETURN_IF_ERROR(GetNodeAttr(
      call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name));
  NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name),
                                      "XlaHostCompute");
  // In XlaCompiler, if XlaHostCompute node is in a function call node and that
  // function is inlined, name of the XlaHostCompute node will be changed. So
  // we cannot rely on node name; use an attribute instead.
  host_compute_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
                            host_compute_builder.node_name());

  // Copy all attributes.
  for (const auto& attr : call_node->attrs()) {
    host_compute_builder.Attr(attr.first, attr.second);
  }

  // Populate tpu_core assignment.
  const auto iter = host_compute_core.find(original_oc_name);
  if (iter != host_compute_core.end()) {
    int core = iter->second;
    host_compute_builder.Attr("tpu_core", core);
  }
  // Record the input token (`kXlaTokenArgNodeName`) and the other outside
  // compilation clusters that the current cluster depends on in the
  // `kXlaTokenInputNodesAttrName` attribute. This is needed because when
  // outside compilation subgraphs are encapsulated and moved to the host
  // graph, control/data edges between them are only reflected in the host
  // graph. From XLA's perspective, two originally dependent clusters are no
  // longer connected, which makes them look like they can be scheduled for
  // execution in arbitrary order, even though in fact they must be executed in
  // order according to their host-side graph dependency. This can cause
  // deadlock. Therefore, we hint XLA what the correct ordering of these
  // clusters should be to avoid deadlocks.
  std::vector<string> xla_token_input_nodes;
  xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName);
  auto cluster_deps_it = cluster_deps.find(original_oc_name);
  if (cluster_deps_it != cluster_deps.end()) {
    for (const auto& dep : cluster_deps_it->second) {
      xla_token_input_nodes.emplace_back(host_compute_node_name(dep));
    }
  }
  host_compute_builder.Attr(kXlaTokenInputNodesAttrName, xla_token_input_nodes);

  // Populate inputs.
  std::vector<DataType> input_dtypes;
  TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes));
  std::vector<NodeDefBuilder::NodeOut> inputs(input_dtypes.size());
  for (auto e : call_node->in_edges()) {
    if (e->IsControlEdge()) {
      continue;
    }

    const int input_dtypes_size = input_dtypes.size();
    if (e->dst_input() < 0 || e->dst_input() >= input_dtypes_size) {
      return errors::Internal("Invalid dst_input: ", e->dst_input());
    }
    inputs[e->dst_input()] = NodeDefBuilder::NodeOut{
        e->src()->name(), e->src_output(), input_dtypes[e->dst_input()]};
  }
  host_compute_builder.Input(inputs);

  NodeDef new_def;
  TF_RETURN_IF_ERROR(host_compute_builder.Finalize(&new_def));
  return new_def;
}
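
// Illustrative example of the token ordering hint: if outside compilation
// cluster "1" depends on cluster "0", the XlaHostCompute built for "1" would
// carry the `kXlaTokenInputNodesAttrName` attribute
//   [kXlaTokenArgNodeName, "outside_compilation_0_host_compute"]
// which tells XLA to thread its side-effect token through cluster "0" before
// cluster "1".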

// Replace outside compilation function call node with XlaHostCompute node.
TF_ATTRIBUTE_NOINLINE StatusOr<Node*> ReplaceOutsideCompilationCallNode(
    Graph* g, Node* call_node, const std::map<string, int>& host_compute_core,
    const absl::flat_hash_map<string, std::vector<string>>& cluster_deps) {
  // Build XlaHostCompute NodeDef.
  TF_ASSIGN_OR_RETURN(
      NodeDef node_def,
      BuildXlaHostComputeNodeDef(call_node, host_compute_core, cluster_deps));
  TF_ASSIGN_OR_RETURN(Node * host_compute_node,
                      ReplaceNode(g, call_node, node_def));
  VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString();

  return host_compute_node;
}

// Resets "_device_ordinal" attr to placeholder value for related nodes
// (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes
// containing XlaRecvAtHost/XlaSendFromHost).
Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) {
  AttrValue device_ordinal_value;
  device_ordinal_value.set_placeholder("_device_ordinal");
  for (Node* n : g->nodes()) {
    if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
      continue;
    }

    if (n->type_string() == "_XlaRecvAtHost" ||
        n->type_string() == "_XlaSendFromHost") {
      n->ClearAttr("device_ordinal");
      n->AddAttr("device_ordinal", device_ordinal_value);
    } else if (n->IsIfNode()) {
      for (const string& attr_name :
           std::vector<string>{"then_branch", "else_branch"}) {
        NameAttrList branch_func;
        TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
        (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
        n->ClearAttr(attr_name);
        n->AddAttr(attr_name, branch_func);
      }
    } else if (n->IsWhileNode()) {
      for (const string& attr_name : std::vector<string>{"cond", "body"}) {
        NameAttrList branch_func;
        TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func));
        (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
        n->ClearAttr(attr_name);
        n->AddAttr(attr_name, branch_func);
      }
    } else if (HasNodeAttr(n->def(), "_device_ordinal")) {
      // Function call node containing outside compilation.
      n->ClearAttr("_device_ordinal");
      n->AddAttr("_device_ordinal", device_ordinal_value);
    } else {
      return errors::Internal("Unknown node marked with ",
                              kXlaHasHostTransferAttrName, ": ",
                              n->DebugString());
    }
  }
  return Status::OK();
}
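
// Note on the placeholder mechanism used above (illustrative): setting
// AttrValue::placeholder to "_device_ordinal" leaves the attribute symbolic;
// it is substituted with a concrete integer when the enclosing function is
// instantiated with a "_device_ordinal" attribute (e.g. the temporary value 0
// used in ConstructHostGraph and ExpandHostGraphIntoMainGraph below).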

// Cheap check to tell whether FunctionDef contains a lifted argument.
bool HasLiftedArgs(const FunctionDef& function_def) {
  return absl::c_any_of(function_def.node_def(), [](const NodeDef& node_def) {
    return (node_def.op() == "Placeholder" &&
            node_def.attr().find(kXlaLiftedArgOutsideCompilationAttrName) !=
                node_def.attr().end());
  });
}

// Find lifted arguments in a function body and their corresponding outside
// compilation nodes.
StatusOr<std::vector<std::pair<Node*, Node*>>>
LiftedArgsAndOutsideCompilationNodesInFunctionBody(
    const FunctionBody& function_body,
    const std::unordered_map<string, Node*>& outside_compilation_attr_to_node) {
  std::vector<std::pair<Node*, Node*>>
      lifted_arg_nodes_and_outside_compilation_nodes;
  for (Node* n : function_body.graph->op_nodes()) {
    string oc_cluster;
    if (n->type_string() == "Placeholder" &&
        GetNodeAttr(n->def(), kXlaLiftedArgOutsideCompilationAttrName,
                    &oc_cluster)
            .ok()) {
      TF_RET_CHECK(outside_compilation_attr_to_node.find(oc_cluster) !=
                   outside_compilation_attr_to_node.end());
      lifted_arg_nodes_and_outside_compilation_nodes.emplace_back(
          n, outside_compilation_attr_to_node.at(oc_cluster));
    }
  }
  return lifted_arg_nodes_and_outside_compilation_nodes;
}

// Append lifted args' types to functional control flow node's `type_attr_name`
// attribute.
StatusOr<std::vector<DataType>> UpdateTypesAttribute(
    const std::vector<std::pair<Node*, Node*>>&
        lifted_arg_nodes_and_outside_compilation_nodes,
    const string& type_attr_name, Node* n) {
  std::vector<DataType> data_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), type_attr_name, &data_types));
  for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
    Node* outside_compilation_node = pair.second;
    DataType data_type;
    TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
                 outside_compilation_node->type_string() == "Placeholder");
    if (outside_compilation_node->IsIdentity()) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
    } else {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
    }
    data_types.push_back(data_type);
  }
  n->ClearAttr(type_attr_name);
  n->AddAttr(type_attr_name, data_types);

  return data_types;
}

// Add edges from lifted outside compilation argument nodes to `n` in Graph `g`.
void AddEdgesFromOutsideCompilationNodes(
    const int original_arg_count, const int arg_to_input_edge_offset,
    const std::vector<DataType>& data_types,
    const std::vector<Node*>& outside_compilation_nodes, Graph* g, Node* n) {
  // Add edges from outside compilation nodes to While node.
  for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
    Node* outside_compilation_node =
        outside_compilation_nodes[i - original_arg_count];
    g->AddEdge(outside_compilation_node, 0, n, i + arg_to_input_edge_offset);
  }
}
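
// Illustrative example of the index mapping above: with original_arg_count = 2
// and two lifted args appended (data_types.size() == 4), the lifted args
// become _Arg #2 and #3 of the function. For a While node
// (arg_to_input_edge_offset == 0) they connect to inputs 2 and 3; for an If
// node (offset == 1, since input 0 is the predicate) they connect to inputs 3
// and 4.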

// Construct _Arg that maps to lifted outside compilation argument node input.
StatusOr<Node*> AddOutsideCompilationInputArgToFunctionBody(
    const FunctionBody& function_body, const int arg_idx,
    const DataType& data_type) {
  NodeDefBuilder arg_builder(absl::StrCat("arg_", arg_idx), "_Arg");
  arg_builder.Attr("T", data_type);
  arg_builder.Attr("index", arg_idx);
  NodeDef arg_def;
  TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));

  Status s;
  Node* arg_node = function_body.graph->AddNode(arg_def, &s);
  TF_RETURN_IF_ERROR(s);
  return arg_node;
}

// Add _Retval node that matches newly added `arg_node` and connect `arg_node`
// to it.
Status AddMatchingRetvalNode(const FunctionBody& function_body,
                             const int arg_idx, const DataType& data_type,
                             Node* arg_node) {
  NodeDefBuilder ret_builder(absl::StrCat("ret_", arg_idx), "_Retval");
  ret_builder.Attr("T", data_type);
  ret_builder.Attr("index", arg_idx);
  ret_builder.Input(arg_node->name(), 0, data_type);
  NodeDef ret_def;
  TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
  Status s;
  Node* ret_node = function_body.graph->AddNode(ret_def, &s);
  TF_RETURN_IF_ERROR(s);
  function_body.graph->AddEdge(arg_node, 0, ret_node, 0);

  return Status::OK();
}
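
// The matching _Retval above is needed for While bodies (see
// PostprocessLiftedArgsForWhile): a While body must return the same list of
// types as its arguments, so every appended _Arg gets a pass-through _Retval
// that simply forwards the lifted value to the next iteration.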

void ReplaceLiftedArgNodePlaceholderWithArg(
    const FunctionBody& function_body, const int original_arg_count,
    const int arg_idx, const std::vector<Node*>& lifted_arg_nodes,
    Node* arg_node) {
  Node* lifted_arg_node = lifted_arg_nodes[arg_idx - original_arg_count];
  // This might happen because lifted_arg_node only exists in one branch of an
  // If node, and we are handling the other branch.
  if (!lifted_arg_node) {
    return;
  }

  for (const Edge* e : lifted_arg_node->out_edges()) {
    if (e->IsControlEdge()) {
      function_body.graph->AddControlEdge(arg_node, e->dst());
    } else {
      function_body.graph->AddEdge(arg_node, 0, e->dst(), e->dst_input());
    }
  }
  function_body.graph->RemoveNode(lifted_arg_node);
}
// Adds a function def to the function definition library and updates the
// function callsite operation `callsite_node` to invoke the new function
// instead.
Status AddFunctionWithNewName(const std::string& new_name,
                              const std::string& func_attr_name,
                              const FunctionDef& function_def,
                              NameAttrList* func_attr, Node* callsite_node,
                              FunctionLibraryDefinition* fld) {
  TF_RETURN_IF_ERROR(fld->AddFunctionDef(function_def));
  func_attr->set_name(new_name);
  callsite_node->ClearAttr(func_attr_name);
  callsite_node->AddAttr(func_attr_name, *func_attr);
  return Status::OK();
}

// Reconnect outside compilation lifted arguments in a functional While node to
// its outside compilation tensor sources.
Status PostprocessLiftedArgsForWhile(
    const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
    Graph* g, Node* n, FunctionLibraryDefinition* fld) {
  TF_RET_CHECK(n->IsWhileNode());

  // Check if there are any lifted args in the body function.
  NameAttrList body_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "body", &body_func));
  const FunctionDef* body_function_def = fld->Find(body_func.name());
  TF_RET_CHECK(body_function_def);

  if (!HasLiftedArgs(*body_function_def)) {
    return Status::OK();
  }

  // Gather all lifted args.
  std::unique_ptr<FunctionBody> body_function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_function_def,
                                             AttrSlice(&body_func.attr()), fld,
                                             &body_function_body));

  int original_arg_count = body_function_body->arg_nodes.size();

  TF_ASSIGN_OR_RETURN(
      auto lifted_arg_nodes_and_outside_compilation_nodes,
      LiftedArgsAndOutsideCompilationNodesInFunctionBody(
          *body_function_body, outside_compilation_attr_to_node));

  // Append lifted args' types to While node's T attribute.
  TF_ASSIGN_OR_RETURN(
      std::vector<DataType> data_types,
      UpdateTypesAttribute(lifted_arg_nodes_and_outside_compilation_nodes, "T",
                           n));

  // Add edges from outside compilation nodes to While node.
  std::vector<Node*> outside_compilation_nodes;
  std::transform(
      lifted_arg_nodes_and_outside_compilation_nodes.begin(),
      lifted_arg_nodes_and_outside_compilation_nodes.end(),
      std::back_inserter(outside_compilation_nodes),
      [](const std::pair<Node*, Node*>& pair) { return pair.second; });
  AddEdgesFromOutsideCompilationNodes(original_arg_count,
                                      /*arg_to_input_edge_offset=*/0,
                                      data_types, outside_compilation_nodes, g,
                                      n);

  // In body_graph, create new _Arg/_Retval nodes, and replace lifted arg
  // nodes with the new _Arg nodes.
  std::vector<Node*> lifted_arg_nodes;
  std::transform(
      lifted_arg_nodes_and_outside_compilation_nodes.begin(),
      lifted_arg_nodes_and_outside_compilation_nodes.end(),
      std::back_inserter(lifted_arg_nodes),
      [](const std::pair<Node*, Node*>& pair) { return pair.first; });
  for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
    TF_ASSIGN_OR_RETURN(Node * arg_node,
                        AddOutsideCompilationInputArgToFunctionBody(
                            *body_function_body, i, data_types[i]));

    TF_RETURN_IF_ERROR(
        AddMatchingRetvalNode(*body_function_body, i, data_types[i], arg_node));

    ReplaceLiftedArgNodePlaceholderWithArg(
        *body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node);
  }

  const auto new_body_function_name =
      fld->UniqueFunctionName(absl::StrCat(body_func.name(), "_lifted_arg_"));
  FunctionDef rewritten_body_function_def;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *body_function_body->graph, new_body_function_name,
      HostGraphControlRetMapping, &rewritten_body_function_def));
  TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_body_function_name, "body",
                                            rewritten_body_function_def,
                                            &body_func, n, fld));

  // In cond_graph, just add new _Arg nodes.
  NameAttrList cond_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "cond", &cond_func));
  const FunctionDef* cond_function_def = fld->Find(cond_func.name());
  TF_RET_CHECK(cond_function_def);
  std::unique_ptr<FunctionBody> cond_function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_function_def,
                                             AttrSlice(&cond_func.attr()), fld,
                                             &cond_function_body));

  for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
    StatusOr<Node*> arg_node_or = AddOutsideCompilationInputArgToFunctionBody(
        *cond_function_body, i, data_types[i]);
    TF_RETURN_IF_ERROR(arg_node_or.status());
  }

  const auto new_cond_function_name =
      fld->UniqueFunctionName(absl::StrCat(cond_func.name(), "_lifted_arg_"));
  FunctionDef rewritten_cond_function_def;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *cond_function_body->graph, new_cond_function_name,
      HostGraphControlRetMapping, &rewritten_cond_function_def));
  TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_cond_function_name, "cond",
                                            rewritten_cond_function_def,
                                            &cond_func, n, fld));
  return Status::OK();
}

Status PostprocessLiftedArgsForIf(
    const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
    Graph* g, Node* n, FunctionLibraryDefinition* fld) {
  TF_RET_CHECK(n->IsIfNode());

  NameAttrList then_branch_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "then_branch", &then_branch_func));
  const FunctionDef* then_branch_function_def =
      fld->Find(then_branch_func.name());
  TF_RET_CHECK(then_branch_function_def);

  NameAttrList else_branch_func;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "else_branch", &else_branch_func));
  const FunctionDef* else_branch_function_def =
      fld->Find(else_branch_func.name());
  TF_RET_CHECK(else_branch_function_def);

  // Nothing to do if neither branch contains any lifted arguments.
  if (!HasLiftedArgs(*then_branch_function_def) &&
      !HasLiftedArgs(*else_branch_function_def)) {
    return Status::OK();
  }

  std::unique_ptr<FunctionBody> then_branch_function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *then_branch_function_def, AttrSlice(&then_branch_func.attr()), fld,
      &then_branch_function_body));

  std::unique_ptr<FunctionBody> else_branch_function_body;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *else_branch_function_def, AttrSlice(&else_branch_func.attr()), fld,
      &else_branch_function_body));

  // Then and else branches have same argument count and argument data types.
  int original_arg_count = then_branch_function_body->arg_nodes.size();

  TF_ASSIGN_OR_RETURN(
      auto then_branch_lifted_arg_nodes_and_outside_compilation_nodes,
      LiftedArgsAndOutsideCompilationNodesInFunctionBody(
          *then_branch_function_body, outside_compilation_attr_to_node));

  TF_ASSIGN_OR_RETURN(
      auto else_branch_lifted_arg_nodes_and_outside_compilation_nodes,
      LiftedArgsAndOutsideCompilationNodesInFunctionBody(
          *else_branch_function_body, outside_compilation_attr_to_node));

  // Merge lifted args from then and else branches.
  std::vector<Node*> outside_compilation_nodes;
  std::vector<Node*> then_branch_lifted_arg_nodes;
  for (const auto& pair :
       then_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
    outside_compilation_nodes.push_back(pair.second);
    then_branch_lifted_arg_nodes.push_back(pair.first);
  }
  for (const auto& pair :
       else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
    if (std::find(outside_compilation_nodes.begin(),
                  outside_compilation_nodes.end(),
                  pair.second) == outside_compilation_nodes.end()) {
      outside_compilation_nodes.push_back(pair.second);
      // Then branch does not contain this lifted arg. Add an empty item to
      // then_branch_lifted_arg_nodes.
      then_branch_lifted_arg_nodes.push_back(nullptr);
    }
  }
  // Reorder else_branch_lifted_arg_nodes_and_outside_compilation_nodes.
  std::vector<Node*> else_branch_lifted_arg_nodes(
      outside_compilation_nodes.size());
  for (const auto& pair :
       else_branch_lifted_arg_nodes_and_outside_compilation_nodes) {
    auto iter = std::find(outside_compilation_nodes.begin(),
                          outside_compilation_nodes.end(), pair.second);
    TF_RET_CHECK(iter != outside_compilation_nodes.end());
    int index = iter - outside_compilation_nodes.begin();
    else_branch_lifted_arg_nodes[index] = pair.first;
  }
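
  // For illustration, suppose the then branch lifts tensors from outside
  // compilation clusters {A} and the else branch lifts from {A, B}. After the
  // merge above:
  //   outside_compilation_nodes    = [A, B]
  //   then_branch_lifted_arg_nodes = [a_then, nullptr]   (no B in then branch)
  //   else_branch_lifted_arg_nodes = [a_else, b_else]
  // The nullptr slot is why ReplaceLiftedArgNodePlaceholderWithArg tolerates
  // missing lifted arg nodes.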

  // Append lifted args' types to If node's Tin attribute.
  std::vector<DataType> data_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &data_types));
  for (Node* n : outside_compilation_nodes) {
    data_types.push_back(n->output_type(0));
  }
  n->ClearAttr("Tin");
  n->AddAttr("Tin", data_types);

  // Add edges from outside compilation nodes to If node. If node's input #0
  // is predicate input, input #1 maps to _Arg #0 of branch functions, thus
  // arg_to_input_edge_offset is set to 1.
  AddEdgesFromOutsideCompilationNodes(original_arg_count,
                                      /*arg_to_input_edge_offset=*/1,
                                      data_types, outside_compilation_nodes, g,
                                      n);

  for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
    TF_ASSIGN_OR_RETURN(Node * then_branch_arg_node,
                        AddOutsideCompilationInputArgToFunctionBody(
                            *then_branch_function_body, i, data_types[i]));

    ReplaceLiftedArgNodePlaceholderWithArg(
        *then_branch_function_body, original_arg_count, i,
        then_branch_lifted_arg_nodes, then_branch_arg_node);

    TF_ASSIGN_OR_RETURN(Node * else_branch_arg_node,
                        AddOutsideCompilationInputArgToFunctionBody(
                            *else_branch_function_body, i, data_types[i]));

    ReplaceLiftedArgNodePlaceholderWithArg(
        *else_branch_function_body, original_arg_count, i,
        else_branch_lifted_arg_nodes, else_branch_arg_node);
  }

  const auto new_then_function_name = fld->UniqueFunctionName(
      absl::StrCat(then_branch_func.name(), "_lifted_arg_"));
  FunctionDef rewritten_then_branch_function_def;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *then_branch_function_body->graph, new_then_function_name,
      HostGraphControlRetMapping, &rewritten_then_branch_function_def));
  TF_RETURN_IF_ERROR(AddFunctionWithNewName(
      new_then_function_name, "then_branch", rewritten_then_branch_function_def,
      &then_branch_func, n, fld));

  const auto new_else_function_name = fld->UniqueFunctionName(
      absl::StrCat(else_branch_func.name(), "_lifted_arg_"));
  FunctionDef rewritten_else_branch_function_def;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      *else_branch_function_body->graph, new_else_function_name,
      HostGraphControlRetMapping, &rewritten_else_branch_function_def));
  TF_RETURN_IF_ERROR(AddFunctionWithNewName(
      new_else_function_name, "else_branch", rewritten_else_branch_function_def,
      &else_branch_func, n, fld));
  return Status::OK();
}

Status PostprocessLiftedArgsForCall(
    const std::unordered_map<string, Node*>& outside_compilation_attr_to_node,
    Graph* g, Node* n, FunctionLibraryDefinition* fld) {
  const FunctionDef* fdef = fld->Find(n->type_string());
  TF_RET_CHECK(fdef);

  // Nothing to do if the function does not contain any lifted arguments.
  if (!HasLiftedArgs(*fdef)) {
    return Status::OK();
  }

  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, n->attrs(), fld, &fbody));

  int original_arg_count = fbody->arg_nodes.size();

  TF_ASSIGN_OR_RETURN(auto lifted_arg_nodes_and_outside_compilation_nodes,
                      LiftedArgsAndOutsideCompilationNodesInFunctionBody(
                          *fbody, outside_compilation_attr_to_node));

  // Append lifted args' types to call node's input data types.
  std::vector<DataType> data_types(n->input_types().begin(),
                                   n->input_types().end());
  for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) {
    Node* outside_compilation_node = pair.second;
    DataType data_type;
    TF_RET_CHECK(outside_compilation_node->IsIdentity() ||
                 outside_compilation_node->type_string() == "Placeholder");
    if (outside_compilation_node->IsIdentity()) {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(outside_compilation_node->def(), "T", &data_type));
    } else {
      TF_RETURN_IF_ERROR(
          GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type));
    }
    data_types.push_back(data_type);
  }

  std::vector<Node*> lifted_arg_nodes;
  std::transform(
      lifted_arg_nodes_and_outside_compilation_nodes.begin(),
      lifted_arg_nodes_and_outside_compilation_nodes.end(),
      std::back_inserter(lifted_arg_nodes),
      [](const std::pair<Node*, Node*>& pair) { return pair.first; });
  for (int i = original_arg_count, end = data_types.size(); i < end; ++i) {
    TF_ASSIGN_OR_RETURN(
        Node * arg_node,
        AddOutsideCompilationInputArgToFunctionBody(*fbody, i, data_types[i]));

    ReplaceLiftedArgNodePlaceholderWithArg(*fbody, original_arg_count, i,
                                           lifted_arg_nodes, arg_node);
  }

  FunctionDef rewritten_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(),
                                        HostGraphControlRetMapping,
                                        &rewritten_fdef));
  const auto new_function_name =
      fld->UniqueFunctionName(absl::StrCat(n->type_string(), "_lifted_arg_"));
  rewritten_fdef.mutable_signature()->set_name(new_function_name);
  TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));

  // We need to recreate the node. Otherwise TF will not know n->num_inputs()
  // has increased.
  NodeDef node_def = n->def();

  // The function name is represented via the op's type. Reset the op type to
  // the new function's name.
  *node_def.mutable_op() = new_function_name;

  for (int i = original_arg_count, end = data_types.size(); i < end; i++) {
    Node* outside_compilation_node =
        lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count]
            .second;
    node_def.add_input(absl::StrCat(outside_compilation_node->name(), ":", 0));
  }
  TF_ASSIGN_OR_RETURN(n, ReplaceNode(g, n, node_def));

  // Add edges from outside compilation nodes to call node.
  std::vector<Node*> outside_compilation_nodes;
  std::transform(
      lifted_arg_nodes_and_outside_compilation_nodes.begin(),
      lifted_arg_nodes_and_outside_compilation_nodes.end(),
      std::back_inserter(outside_compilation_nodes),
      [](const std::pair<Node*, Node*>& pair) { return pair.second; });
  AddEdgesFromOutsideCompilationNodes(original_arg_count,
                                      /*arg_to_input_edge_offset=*/0,
                                      data_types, outside_compilation_nodes, g,
                                      n);

  return Status::OK();
}
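
// Unlike the While/If cases, the call node is rebuilt from scratch above:
// roughly, a direct function call node's op type *is* the function name, so
// pointing the NodeDef at the rewritten function and appending the lifted-arg
// inputs before ReplaceNode() is what makes the extra inputs visible to TF.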

// Creates a mapping from outside compilation cluster name to lifted argument
// placeholder.
StatusOr<std::unordered_map<string, Node*>> OutsideCompilationAttrToNode(
    const Graph& g) {
  std::unordered_map<string, Node*> outside_compilation_attr_to_node;
  for (Node* n : g.op_nodes()) {
    bool is_lifted_arg;
    string outside_compilation_attr;
    if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) &&
        TryGetNodeAttr(n->def(), "_xla_outside_compilation",
                       &outside_compilation_attr)) {
      TF_RET_CHECK(is_lifted_arg);
      TF_RET_CHECK(n->IsIdentity() || n->type_string() == "Placeholder");
      outside_compilation_attr_to_node[outside_compilation_attr] = n;
    }
  }

  return outside_compilation_attr_to_node;
}

Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) {
  TF_ASSIGN_OR_RETURN(auto outside_compilation_attr_to_node,
                      OutsideCompilationAttrToNode(*g));

  std::vector<Node*> call_nodes;
  for (Node* n : g->op_nodes()) {
    if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) {
      continue;
    }

    if (n->IsWhileNode()) {
      TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile(
          outside_compilation_attr_to_node, g, n, fld));
    }

    if (n->IsIfNode()) {
      TF_RETURN_IF_ERROR(PostprocessLiftedArgsForIf(
          outside_compilation_attr_to_node, g, n, fld));
    }

    // Outside compilation host-side function calls will always be direct
    // function call nodes.
    // Function call nodes need to be handled separately because we rewrite
    // nodes in `PostprocessLiftedArgsForCall`.
    if (fld->Contains(n->type_string())) {
      call_nodes.push_back(n);
    }
  }

  for (Node* n : call_nodes) {
    TF_RETURN_IF_ERROR(PostprocessLiftedArgsForCall(
        outside_compilation_attr_to_node, g, n, fld));
  }

  return Status::OK();
}

// For an XLA computation, builds host side graph given all outside compilation
// graphs inside it. The host side graph contains:
// 1) a "sequencer" node (we will add control edges from XlaRecvAtHost and
//    XlaSendFromHost nodes to this sequencer node, so all outside compilation
//    nodes will be executed *before* this sequencer).
// 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will
//    replace this node with the compilation result node.
// 3) all outside compilation graphs.
Status ConstructHostGraph(
    const string& xla_cluster_name, const string& outside_compilation_attr_name,
    const std::vector<string>& outside_compilation_host_graphs,
    FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph) {
  host_graph->reset(new Graph(fld));

  // Create sequencer node in host graph.
  NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"),
                                   "NoOp");
  sequencer_builder.Attr("_xla_host_transfer_sequencer", xla_cluster_name);
  NodeDef sequencer_def;
  TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def));
  Status s;
  Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s);
  TF_RETURN_IF_ERROR(s);

  // Create key placeholder in host graph.
  TF_ASSIGN_OR_RETURN(
      Node * key_placeholder,
      AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get()));

  // For each outside compilation graph, copy them to host graph with the
  // following changes:
  // a) Use key_placeholder in host graph instead of its own.
  // b) Add control edge from host transfer nodes (XlaRecvAtHost,
  //    XlaSendFromHost, If/While nodes containing
  //    XlaRecvAtHost/XlaSendFromHost) to sequencer node.
  // c) Clear node_def.device(), so device placer won't get confused.
  for (const string& host_func : outside_compilation_host_graphs) {
    VLOG(4) << "Expanding host graph " << host_func;
    // Temporarily use "0" as "_device_ordinal". It will be reset to the
    // placeholder value after we have expanded all host graphs. We cannot just
    // use the placeholder value here because FunctionDef instantiation does
    // not allow placeholder values for attributes.
    AttrValue device_ordinal_attr;
    device_ordinal_attr.set_i(0);
    protobuf::Map<string, AttrValue> attrs;
    attrs["_device_ordinal"] = device_ordinal_attr;
    std::unique_ptr<FunctionBody> host_fbody;
    const FunctionDef* host_fdef = fld->Find(host_func);
    TF_RET_CHECK(host_fdef);
    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_fdef, AttrSlice(&attrs),
                                               fld, &host_fbody));

    // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
    // reachable from sink node so all nodes will be copied.
    // TODO(b/77601805): consolidate copy graph functions.
    FixupSourceAndSinkEdges(host_fbody->graph);

    std::map<const Node*, Node*> node_map;
    node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node();
    node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node();
    Status s;
    ReverseDFS(
        *host_fbody->graph, /*enter=*/nullptr,
        [&](const Node* n) {
          if (!s.ok()) {
            return;
          }

          Node* copy;
          if (node_map.find(n) != node_map.end()) {
            // Already copied this node.
            copy = node_map.at(n);
          } else if (IsKeyPlaceholderNode(*n)) {
            // Change a).
            copy = key_placeholder;
            node_map[n] = copy;
          } else {
            // Copy the node.
            NodeDef copy_def = n->def();
            // Change c).
            copy_def.clear_device();
            copy = (*host_graph)->AddNode(copy_def, &s);
            if (!s.ok()) {
              return;
            }
            node_map[n] = copy;
          }

          // Only handle input edges. Output edges will be added later as
          // its output nodes' input edges.
          for (auto e : n->in_edges()) {
            if (node_map.find(e->src()) == node_map.end()) {
              s = errors::Internal("Cannot find node image for ",
                                   e->src()->DebugString());
              return;
            }
            (*host_graph)
                ->AddEdge(node_map[e->src()], e->src_output(), copy,
                          e->dst_input());
          }

          // Change b).
          if (HasNodeAttr(copy->def(), kXlaHasHostTransferAttrName)) {
            (*host_graph)->AddControlEdge(copy, sequencer);
          }
        },
        NodeComparatorID());

    if (!s.ok()) {
      return s;
    }
  }
  // Reset "_device_ordinal" to placeholder value.
  TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(host_graph->get()));

  // sequencer and key_placeholder might be dead nodes. Prune them if necessary.
  // - sequencer should be pruned iff it has no input control edges from
  //   RecvAtHost/SendFromHost. If it has input control edges, we connect it to
  //   the sink node so it won't be pruned.
  // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost.
  //   We don't need to do anything special.
  if (!sequencer->in_edges().empty()) {
    (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node());
  }
  PruneForReverseReachability(
      host_graph->get(),
      std::unordered_set<const Node*>{(*host_graph)->sink_node()});

  // Postprocess edges between different outside compilations.
  TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
      host_graph->get(), outside_compilation_attr_name));

  // Postprocess lifted arg nodes.
  TF_RETURN_IF_ERROR(PostprocessLiftedArgs(host_graph->get(), fld));

  if (VLOG_IS_ON(4)) {
    DumpGraphToFile(absl::StrCat("extract_outside_compilation_host_graph_for_",
                                 xla_cluster_name),
                    **host_graph, fld);
  }

  return Status::OK();
}
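
// Rough sketch of the resulting host graph for one XLA cluster (illustrative
// only):
//
//   key_placeholder --> _XlaRecvAtHost ... _XlaSendFromHost
//                             |                  |
//                             +-----control------+--> <cluster>_sequencer
//                                                          |
//                                                       (sink)
//
// The sequencer is later connected by a control edge to the XLA computation
// node in ExpandHostGraphIntoMainGraph() below.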

// Expands the XLA computation's outside compilation host side graph into the
// main graph, and adds a control edge between the sequencer node and the XLA
// computation node.
Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
                                    FunctionLibraryDefinition* fld,
                                    const string& host_graph_func_name,
                                    Node* xla_computation_node,
                                    Node* pivot_node) {
  // Temporarily use "0" as "_device_ordinal". It will be rewritten with the
  // correct value in a later pass. We cannot just use placeholder value here
  // because FunctionDef instantiation does not allow placeholder value for
  // attributes.
  AttrValue device_ordinal_attr;
  device_ordinal_attr.set_i(0);
  protobuf::Map<string, AttrValue> attrs;
  attrs["_device_ordinal"] = device_ordinal_attr;
  std::unique_ptr<FunctionBody> fbody;
  const FunctionDef* host_graph_func = fld->Find(host_graph_func_name);
  TF_RET_CHECK(host_graph_func);
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*host_graph_func,
                                             AttrSlice(&attrs), fld, &fbody));
  Graph* host_graph = fbody->graph;

  // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse
  // reachable from sink node so all nodes will be copied.
  // TODO(b/77601805): consolidate copy graph functions.
  FixupSourceAndSinkEdges(host_graph);

  // Copy all nodes.
  std::map<const Node*, Node*> node_map;
  if (pivot_node) {
    node_map[host_graph->source_node()] = pivot_node;
  } else {
    node_map[host_graph->source_node()] = main_graph->source_node();
  }
  node_map[host_graph->sink_node()] = main_graph->sink_node();
  Status s = Status::OK();
  auto copy_node_fn = [&](const Node* n) {
    if (!s.ok()) {
      return;
    }

    Node* copy;
    if (node_map.find(n) != node_map.end()) {
      // Already copied this node.
      copy = node_map.at(n);
    } else {
      // Copy the node.
      NodeDef copy_def = n->def();
      copy = main_graph->AddNode(copy_def, &s);
      if (!s.ok()) {
        return;
      }
      node_map[n] = copy;
    }

    // Only handle input edges. Output edges will be added later as its output
    // nodes' input edges.
    for (auto e : n->in_edges()) {
      if (node_map.find(e->src()) == node_map.end()) {
        s = errors::Internal("Cannot find node image for ",
                             e->src()->DebugString());
        return;
      }
      main_graph->AddEdge(node_map[e->src()], e->src_output(), copy,
                          e->dst_input());
    }

    // Add control edge from sequencer to XLA computation node.
    if (copy->type_string() == "NoOp" &&
        HasNodeAttr(copy->def(), "_xla_host_transfer_sequencer")) {
      main_graph->AddControlEdge(copy, xla_computation_node);
    }
  };
  ReverseDFS(*host_graph, /*enter=*/nullptr, copy_node_fn, NodeComparatorID());
  return s;
}

// Rewrites shape inference graph for outside compilation:
// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from
//    `host_graph`. The shape inference graph might still contain
//    outside-compilation-to-outside-compilation placeholder nodes, which would
//    prevent us from inferring the XlaSendFromHost shape; those placeholders
//    have already been removed from `host_graph`.
// 2) Remove control edges.
// 3) Prune nodes that are not useful for shape inference.
Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
                                  Graph* host_graph, Node* pivot_node,
                                  FunctionLibraryDefinition* fld) {
  // Use "0" as "_device_ordinal". It does not matter for shape inference.
  AttrValue device_ordinal_attr;
  device_ordinal_attr.set_i(0);
  protobuf::Map<string, AttrValue> attrs;
  attrs["_device_ordinal"] = device_ordinal_attr;
  std::unique_ptr<FunctionBody> fbody;
  const FunctionDef* shape_inference_graph =
      fld->Find(shape_inference_graph_name);
  TF_RET_CHECK(shape_inference_graph);
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*shape_inference_graph,
                                             AttrSlice(&attrs), fld, &fbody));
  Graph* g = fbody->graph;

  // Find SendFromHost node.
  Node* send_from_host = nullptr;
  for (Node* n : g->nodes()) {
    if (n->type_string() == "_XlaSendFromHost") {
      send_from_host = n;
      break;
    }
  }
  if (!send_from_host) {
    return errors::Internal("Shape inference graph ",
                            shape_inference_graph_name,
                            " does not have _XlaSendFromHost node.");
  }

  // See if the SendFromHost node exists in `host_graph`.
  Node* send_node_in_host_graph = nullptr;
  for (Node* n : host_graph->nodes()) {
    if (n->name() == send_from_host->name()) {
      send_node_in_host_graph = n;
      break;
    }
  }
  if (send_node_in_host_graph) {
    // This is a "top-level" outside compilation. Clear the graph, and copy
    // SendFromHost and all its predecessors from `host_graph`.
    std::vector<Node*> nodes;
    for (Node* n : g->op_nodes()) {
      nodes.push_back(n);
    }
    for (Node* n : nodes) {
      g->RemoveNode(n);
    }
    Node* start_node = pivot_node ? pivot_node : host_graph->source_node();
    // Reverse DFS from send_node_in_host_graph, and stop at start_node.
1251 struct Visit {
1252 Node* n;
1253 bool is_exiting;
1254 };
1255 std::vector<Visit> stack{{send_node_in_host_graph, false}};
1256 std::map<Node*, Node*> node_map;
1257 node_map[host_graph->source_node()] = g->source_node();
1258 while (!stack.empty()) {
1259 Visit& curr = stack.back();
1260 if (curr.is_exiting) {
1261 if (node_map.find(curr.n) == node_map.end()) {
1262 Node* copy = g->CopyNode(curr.n);
1263 if (curr.n != start_node) {
1264 for (const Edge* e : curr.n->in_edges()) {
1265 auto node_iter = node_map.find(e->src());
1266 if (node_iter == node_map.end()) {
1267 return errors::Internal("Cannot find node image for ",
1268 e->src()->DebugString());
1269 }
1270 g->AddEdge(node_iter->second, e->src_output(), copy,
1271 e->dst_input());
1272 }
1273 }
1274 node_map[curr.n] = copy;
1275 }
1276 stack.pop_back();
1277 } else {
1278 curr.is_exiting = true;
1279 if (curr.n != start_node) {
1280 for (const Edge* e : curr.n->in_edges()) {
1281 if (node_map.find(e->src()) != node_map.end()) {
1282 continue;
1283 }
1284 stack.push_back({e->src(), false});
1285 }
1286 }
1287 }
1288 }
1289
1290 send_from_host = node_map[send_node_in_host_graph];
1291 } else {
1292 // This is an outside compilation generated for If/While/gradient/etc.
1293 // It will be enough for shape inference. Leave `g` unchanged.
1294 }
1295
1296 // Control edges are not useful for shape inference. Remove them.
1297 for (auto e : g->edges()) {
1298 if (e->IsControlEdge()) {
1299 g->RemoveEdge(e);
1300 }
1301 }
1302
1303 // Nodes that are not reverse reachable from SendFromHost are not useful for
1304 // shape inference. Prune them.
1305 PruneForReverseReachability(g,
1306 std::unordered_set<const Node*>{send_from_host});
1307
1308 if (VLOG_IS_ON(4)) {
1309 DumpGraphToFile(shape_inference_graph_name, *g, fld);
1310 }
1311
1312 // Replace original shape inference graph.
1313 FunctionDef fdef_replace;
1314 TF_RETURN_IF_ERROR(
1315 GraphToFunctionDef(*g, shape_inference_graph_name, &fdef_replace));
1316 TF_RETURN_IF_ERROR(
1317 fld->ReplaceFunction(shape_inference_graph_name, fdef_replace));
1318
1319 return Status::OK();
1320 }
1321
1322 // Builds XlaSendToHost node which sends cond predicate to host.
BuildSendIfPredNode(const string & name,const string & host_transfer_key,Node * pred_node,Graph * g)1323 TF_ATTRIBUTE_NOINLINE StatusOr<Node*> BuildSendIfPredNode(
1324 const string& name, const string& host_transfer_key, Node* pred_node,
1325 Graph* g) {
1326 NodeDefBuilder send_pred_builder(name, "XlaSendToHost");
1327 send_pred_builder.Attr("Tinput", DT_BOOL);
1328 send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0"));
1329 send_pred_builder.Attr(kXlaTokenInputNodesAttrName,
1330 std::vector<string>{kXlaTokenArgNodeName});
1331 send_pred_builder.Attr(kXlaOriginalOutsideCompilationNodeName, name);
1332 send_pred_builder.Input(pred_node->name(), 0, DT_BOOL);
1333 NodeDef send_pred_def;
1334 TF_RETURN_IF_ERROR(send_pred_builder.Finalize(&send_pred_def));
1335 Status s;
1336 Node* send_pred_node = g->AddNode(send_pred_def, &s);
1337 TF_RETURN_IF_ERROR(s);
1338 g->AddEdge(pred_node, 0, send_pred_node, 0);
1339 return send_pred_node;
1340 }
1341
1342 // Replaces key placeholder node with an _Arg node.
1343 Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name,
1344 const string& func_name,
1345 FunctionLibraryDefinition* fld) {
1346 // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder
1347 // value after rewriting.
1348 AttrValue device_ordinal_attr;
1349 device_ordinal_attr.set_i(0);
1350 protobuf::Map<string, AttrValue> attrs;
1351 attrs["_device_ordinal"] = device_ordinal_attr;
1352 std::unique_ptr<FunctionBody> fbody;
1353 const FunctionDef* func = fld->Find(func_name);
1354 TF_RETURN_IF_ERROR(
1355 FunctionDefToBodyHelper(*func, AttrSlice(&attrs), fld, &fbody));
1356 Graph* g = fbody->graph;
1357
1358 // Find or create the key placeholder node.
1359 Node* key_placeholder = nullptr;
1360 for (Node* n : g->nodes()) {
1361 if (IsKeyPlaceholderNode(*n)) {
1362 key_placeholder = n;
1363 break;
1364 }
1365 }
1366 if (!key_placeholder) {
1367 TF_ASSIGN_OR_RETURN(key_placeholder,
1368 AddHostComputeKeyPlaceholder(xla_cluster_name, g));
1369 }
1370
1371 // Build the _Arg node, and replace key placeholder node with it.
1372 NodeDefBuilder arg_builder("key_arg", FunctionLibraryDefinition::kArgOp);
1373 arg_builder.Attr("T", DT_STRING);
1374 arg_builder.Attr("index", 0);
1375 NodeDef arg_def;
1376 TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
1377 TF_RETURN_IF_ERROR(ReplaceNode(g, key_placeholder, arg_def).status());
1378
1379 // Reset "_device_ordinal" to placeholder value.
1380 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(g));
1381
1382 FunctionDef replace_fdef;
1383 TF_RETURN_IF_ERROR(GraphToFunctionDef(
1384 *g, func_name, HostGraphControlRetMapping, &replace_fdef));
1385 TF_RETURN_IF_ERROR(fld->ReplaceFunction(func_name, replace_fdef));
1386 return Status::OK();
1387 }
1388
1389 // Builds host side graph for If node.
1390 TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode(
1391 const string& xla_cluster_attr_name,
1392 const string& outside_compilation_attr_name, const string& xla_cluster_name,
1393 const string& if_node_name, const string& host_transfer_key,
1394 const string& host_graph_func_name, FunctionLibraryDefinition* fld,
1395 const string& then_branch_host_func_name,
1396 const string& else_branch_host_func_name) {
1397 Graph host_graph(fld);
1398 string outside_compilation_name = absl::StrCat("oc_if_", if_node_name);
1399 AttrValue device_ordinal_value;
1400 device_ordinal_value.set_placeholder("_device_ordinal");
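  // "_device_ordinal" is left as an attr placeholder here; it is expected to
  // be substituted with a concrete device ordinal when the host function is
  // instantiated for a particular device.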
1401
1402 // Step 1: add key placeholder node.
1403 TF_ASSIGN_OR_RETURN(
1404 Node * key_placeholder,
1405 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1406
1407 // Step 2: build XlaRecvAtHost node to recv predicate.
1408 NodeDefBuilder recv_pred_builder(
1409 absl::StrCat("recv_oc_if_pred_", if_node_name), "_XlaRecvAtHost");
1410 recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
1411 recv_pred_builder.Attr("key", host_transfer_key);
1412 recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
1413 recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1414 recv_pred_builder.Attr(outside_compilation_attr_name,
1415 outside_compilation_name);
1416 recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
1417 recv_pred_builder.Input(key_placeholder->name(), 0, DT_STRING);
1418 NodeDef recv_pred_def;
1419 TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
1420 Status s;
1421 Node* recv_pred_node = host_graph.AddNode(recv_pred_def, &s);
1422 TF_RETURN_IF_ERROR(s);
1423 host_graph.AddEdge(key_placeholder, 0, recv_pred_node, 0);
1424
1425 // Step 3: rewrite `{then, else}_branch_host_func_name`, replace key
1426 // placeholder with an _Arg node.
1427 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1428 xla_cluster_name, then_branch_host_func_name, fld));
1429 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1430 xla_cluster_name, else_branch_host_func_name, fld));
1431
1432 // Step 4: build If node to choose between `{then, else}_branch_host_graph`.
1433 NodeDefBuilder if_builder(absl::StrCat("oc_if_", if_node_name), "If");
1434 if_builder.Attr("Tcond", DT_BOOL);
1435 if_builder.Attr("Tin", std::vector<DataType>{DT_STRING});
1436 if_builder.Attr("Tout", std::vector<DataType>{});
1437 NameAttrList host_then_branch, host_else_branch;
1438 host_then_branch.set_name(then_branch_host_func_name);
1439 (*host_then_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1440 host_else_branch.set_name(else_branch_host_func_name);
1441 (*host_else_branch.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1442 if_builder.Attr("then_branch", host_then_branch);
1443 if_builder.Attr("else_branch", host_else_branch);
1444 if_builder.Attr(kXlaHasHostTransferAttrName, true);
1445 if_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1446 if_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
1447 if_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
1448 std::vector<NodeDefBuilder::NodeOut> if_inputs{
1449 {key_placeholder->name(), 0, DT_STRING}};
1450 if_builder.Input(if_inputs);
1451 NodeDef if_def;
1452 TF_RETURN_IF_ERROR(if_builder.Finalize(&if_def));
1453 Node* if_node = host_graph.AddNode(if_def, &s);
1454 TF_RETURN_IF_ERROR(s);
1455 host_graph.AddEdge(recv_pred_node, 0, if_node, 0);
1456 host_graph.AddEdge(key_placeholder, 0, if_node, 1);
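  // The resulting host graph is: key_placeholder -> _XlaRecvAtHost (predicate)
  // -> If, with the key also fed to the If node so both host branches can
  // perform their own host transfers.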
1457
1458 // Convert `host_graph` to function.
1459 FunctionDef oc_host_graph_fdef;
1460 TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1461 &oc_host_graph_fdef));
1462 if (fld->Find(host_graph_func_name)) {
1463 TF_RETURN_IF_ERROR(
1464 fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1465 } else {
1466 TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1467 }
1468
1469 return Status::OK();
1470 }
1471
1472 // Rewrites loop cond to add a node which sends loop cond to host.
1473 TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond(
1474 const string& cond_xla_func_name, const string& host_transfer_key,
1475 NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld,
1476 Node* while_node) {
1477 // Instantiate the loop cond function.
1478 std::unique_ptr<FunctionBody> fbody;
1479 const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func->name());
1480 TF_RET_CHECK(loop_cond_fdef);
1481 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1482 *loop_cond_fdef, AttrSlice(&loop_cond_func->attr()), fld, &fbody));
1483 Graph* g = fbody->graph;
1484
1485 // Find the _Retval node and the loop cond node.
1486 Node* ret_node = nullptr;
1487 for (Node* n : g->nodes()) {
1488 if (n->type_string() == "_Retval") {
1489 if (ret_node) {
1490         return errors::Internal("Multiple return nodes for loop cond function ",
1491 loop_cond_func->name(), ": ",
1492 ret_node->DebugString(), " and ",
1493 n->DebugString());
1494 } else {
1495 ret_node = n;
1496 }
1497 }
1498 }
1499 if (!ret_node) {
1500 return errors::Internal("No _Retval node for loop cond function ",
1501 loop_cond_func->name());
1502 }
1503 Node* loop_cond;
1504 TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond));
1505
1506 // Build the XlaSendToHost node.
1507 NodeDefBuilder send_loop_cond_builder(
1508 absl::StrCat("send_oc_while_cond_", while_node->name()), "XlaSendToHost");
1509 send_loop_cond_builder.Attr("Tinput", DT_BOOL);
1510 send_loop_cond_builder.Attr("key",
1511 absl::StrCat(host_transfer_key, "_dtoh_0"));
1512 send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName,
1513 std::vector<string>{kXlaTokenArgNodeName});
1514 send_loop_cond_builder.Attr(kXlaOriginalOutsideCompilationNodeName,
1515 send_loop_cond_builder.node_name());
1516 send_loop_cond_builder.Input(loop_cond->name(), 0, DT_BOOL);
1517 NodeDef send_loop_cond_def;
1518 TF_RETURN_IF_ERROR(send_loop_cond_builder.Finalize(&send_loop_cond_def));
1519 Status s;
1520 Node* send_loop_cond_node = g->AddNode(send_loop_cond_def, &s);
1521 TF_RETURN_IF_ERROR(s);
1522 g->AddEdge(loop_cond, 0, send_loop_cond_node, 0);
1523
1524   // Replace the original function if `loop_cond_func` has already been
1525   // rewritten for outside compilation.
1526 FunctionDef replace_fdef;
1527 if (loop_cond_func->name() == cond_xla_func_name) {
1528 TF_RETURN_IF_ERROR(
1529 GraphToFunctionDef(*g, loop_cond_func->name(), &replace_fdef));
1530 TF_RETURN_IF_ERROR(
1531 fld->ReplaceFunction(loop_cond_func->name(), replace_fdef));
1532 } else {
1533     // If the original while cond function has not been modified, add a new
1534     // function with the send-loop-predicate node added, and update the While
1535     // node's callsite to use it.
1536 const auto new_name = fld->UniqueFunctionName(
1537 absl::StrCat(loop_cond_func->name(), "_send_pred_added_"));
1538 TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, new_name, &replace_fdef));
1539 TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
1540 loop_cond_func->set_name(new_name);
1541 while_node->ClearAttr("cond");
1542 while_node->AddAttr("cond", *loop_cond_func);
1543 }
1544
1545 return Status::OK();
1546 }
1547
1548 // Rewrites while loop cond function for host.
1549 Status RewriteHostWhileLoopCond(
1550 const string& cond_host_func_name, const string& while_node_name,
1551 const string& host_transfer_key, const string& xla_cluster_attr_name,
1552 const string& xla_cluster_name, const string& outside_compilation_attr_name,
1553 const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1554 // Replace key placeholder node with _Arg node.
1555 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1556 xla_cluster_name, cond_host_func_name, fld));
1557
1558 // Instantiate cond function.
1559 AttrValue device_ordinal_temp_value;
1560 device_ordinal_temp_value.set_i(0);
1561 protobuf::Map<string, AttrValue> attrs;
1562 attrs["_device_ordinal"] = device_ordinal_temp_value;
1563 std::unique_ptr<FunctionBody> cond_fbody;
1564 const FunctionDef* cond_host_func = fld->Find(cond_host_func_name);
1565 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*cond_host_func, AttrSlice(&attrs),
1566 fld, &cond_fbody));
1567 Graph* cond_graph = cond_fbody->graph;
1568 Node* key_arg = nullptr;
1569 for (Node* n : cond_graph->nodes()) {
1570 if (n->type_string() == "_Arg") {
1571 key_arg = n;
1572 }
1573 }
1574 if (!key_arg) {
1575 return errors::Internal(
1576 "No _Arg node found for host compute key in function ",
1577 cond_host_func_name);
1578 }
1579
1580 // Add an XlaRecvAtHost node to use as cond function return value.
1581 NodeDefBuilder recv_pred_builder(
1582 absl::StrCat("recv_oc_while_cond_", while_node_name), "_XlaRecvAtHost");
1583 recv_pred_builder.Attr("Toutputs", std::vector<DataType>{DT_BOOL});
1584 recv_pred_builder.Attr("key", host_transfer_key);
1585 AttrValue device_ordinal_value;
1586 device_ordinal_value.set_placeholder("_device_ordinal");
1587 recv_pred_builder.Attr("device_ordinal", device_ordinal_value);
1588 recv_pred_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1589 recv_pred_builder.Attr(outside_compilation_attr_name,
1590 outside_compilation_name);
1591 recv_pred_builder.Attr(kXlaHasHostTransferAttrName, true);
1592 recv_pred_builder.Input(key_arg->name(), 0, DT_STRING);
1593 NodeDef recv_pred_def;
1594 TF_RETURN_IF_ERROR(recv_pred_builder.Finalize(&recv_pred_def));
1595 Status s;
1596 Node* recv_pred_node = cond_graph->AddNode(recv_pred_def, &s);
1597 TF_RETURN_IF_ERROR(s);
1598 cond_graph->AddEdge(key_arg, 0, recv_pred_node, 0);
1599 NodeDefBuilder ret_builder(
1600 absl::StrCat("recv_oc_while_cond_ret_", while_node_name), "_Retval");
1601 ret_builder.Attr("T", DT_BOOL);
1602 ret_builder.Attr("index", 0);
1603 ret_builder.Input(recv_pred_node->name(), 0, DT_BOOL);
1604 NodeDef ret_def;
1605 TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1606 Node* ret_node = cond_graph->AddNode(ret_def, &s);
1607 TF_RETURN_IF_ERROR(s);
1608 cond_graph->AddEdge(recv_pred_node, 0, ret_node, 0);
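  // The host-side cond function now returns the predicate value received from
  // the device, which keeps the host While loop in lockstep with the XLA While
  // loop.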
1609
1610 // Reset device_ordinal to placeholder value.
1611 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(cond_graph));
1612
1613 // Replace original function.
1614 FunctionDef cond_replace_fdef;
1615 TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_graph, cond_host_func_name,
1616 HostGraphControlRetMapping,
1617 &cond_replace_fdef));
1618 TF_RETURN_IF_ERROR(
1619 fld->ReplaceFunction(cond_host_func_name, cond_replace_fdef));
1620
1621 return Status::OK();
1622 }
1623
1624 // Rewrites while loop body function for host.
1625 Status RewriteHostWhileLoopBody(
1626 const string& body_host_func_name, const string& while_node_name,
1627 const string& host_transfer_key, const string& xla_cluster_attr_name,
1628 const string& xla_cluster_name, const string& outside_compilation_attr_name,
1629 const string& outside_compilation_name, FunctionLibraryDefinition* fld) {
1630 // Replace key placeholder node with _Arg node.
1631 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1632 xla_cluster_name, body_host_func_name, fld));
1633
1634 // Instantiate body function.
1635 AttrValue device_ordinal_temp_value;
1636 device_ordinal_temp_value.set_i(0);
1637 protobuf::Map<string, AttrValue> attrs;
1638 attrs["_device_ordinal"] = device_ordinal_temp_value;
1639 std::unique_ptr<FunctionBody> body_fbody;
1640 const FunctionDef* body_host_func = fld->Find(body_host_func_name);
1641 TF_RET_CHECK(body_host_func);
1642 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*body_host_func, AttrSlice(&attrs),
1643 fld, &body_fbody));
1644 Graph* body_graph = body_fbody->graph;
1645 Node* key_arg = nullptr;
1646 for (Node* n : body_graph->nodes()) {
1647 if (n->type_string() == "_Arg") {
1648 key_arg = n;
1649 }
1650 }
1651 if (!key_arg) {
1652 return errors::Internal(
1653 "No _Arg node found for host compute key in function ",
1654 body_host_func_name);
1655 }
1656
1657 // Add a _Retval node to loop body.
1658 NodeDefBuilder ret_builder(
1659 absl::StrCat("recv_oc_while_body_ret_", while_node_name), "_Retval");
1660 ret_builder.Attr("T", DT_STRING);
1661 ret_builder.Attr("index", 0);
1662 ret_builder.Input(key_arg->name(), 0, DT_STRING);
1663 NodeDef ret_def;
1664 TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
1665 Status s;
1666 Node* ret_node = body_graph->AddNode(ret_def, &s);
1667 TF_RETURN_IF_ERROR(s);
1668 body_graph->AddEdge(key_arg, 0, ret_node, 0);
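  // The host-side body returns the key string unchanged, satisfying the host
  // While node's single DT_STRING loop variable.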
1669
1670 // Reset device_ordinal to placeholder value.
1671 TF_RETURN_IF_ERROR(ResetDeviceOrdinalToPlaceholderValue(body_graph));
1672
1673 // Replace original function.
1674 FunctionDef body_replace_fdef;
1675 TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_graph, body_host_func_name,
1676 HostGraphControlRetMapping,
1677 &body_replace_fdef));
1678 TF_RETURN_IF_ERROR(
1679 fld->ReplaceFunction(body_host_func_name, body_replace_fdef));
1680
1681 return Status::OK();
1682 }
1683
1684 // Builds host side graph for while node.
1685 TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode(
1686 const string& xla_cluster_attr_name,
1687 const string& outside_compilation_attr_name, const string& xla_cluster_name,
1688 const string& while_node_name, const string& host_transfer_key,
1689 const string& host_graph_func_name, FunctionLibraryDefinition* fld,
1690 const string& cond_host_func_name, const string& body_host_func_name) {
1691 Graph host_graph(fld);
1692 string outside_compilation_name = absl::StrCat("oc_while_", while_node_name);
1693
1694 // Step 1: add key placeholder node.
1695 TF_ASSIGN_OR_RETURN(
1696 Node * key_placeholder,
1697 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1698
1699 // Step 2: rewrite cond function.
1700 TF_RETURN_IF_ERROR(RewriteHostWhileLoopCond(
1701 cond_host_func_name, while_node_name, host_transfer_key,
1702 xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1703 outside_compilation_name, fld));
1704
1705 // Step 3: rewrite body function.
1706 TF_RETURN_IF_ERROR(RewriteHostWhileLoopBody(
1707 body_host_func_name, while_node_name, host_transfer_key,
1708 xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1709 outside_compilation_name, fld));
1710
1711 // Step 4: build While node.
1712 NodeDefBuilder while_builder(absl::StrCat("oc_while_", while_node_name),
1713 "While");
1714 while_builder.Attr("T", std::vector<DataType>{DT_STRING});
1715 NameAttrList func;
1716 AttrValue device_ordinal_value;
1717 device_ordinal_value.set_placeholder("_device_ordinal");
1718 (*func.mutable_attr())["_device_ordinal"] = device_ordinal_value;
1719 func.set_name(cond_host_func_name);
1720 while_builder.Attr("cond", func);
1721 func.set_name(body_host_func_name);
1722 while_builder.Attr("body", func);
1723 while_builder.Attr(kXlaHasHostTransferAttrName, true);
1724 while_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1725 while_builder.Attr(outside_compilation_attr_name, outside_compilation_name);
1726   // Make sure the loop body of the i-th iteration happens before the loop
1727   // cond of the (i+1)-th iteration.
1728 while_builder.Attr("parallel_iterations", 1);
1729 std::vector<NodeDefBuilder::NodeOut> while_inputs{
1730 {key_placeholder->name(), 0, DT_STRING}};
1731 while_builder.Input(while_inputs);
1732 NodeDef while_def;
1733 TF_RETURN_IF_ERROR(while_builder.Finalize(&while_def));
1734 Status s;
1735 Node* while_node = host_graph.AddNode(while_def, &s);
1736 TF_RETURN_IF_ERROR(s);
1737 host_graph.AddEdge(key_placeholder, 0, while_node, 0);
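  // The host While loop carries a single DT_STRING loop variable (the key);
  // its cond receives the loop predicate sent from the device, and its body
  // runs the host-side computation for the loop body's outside compilation.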
1738
1739 // Convert `host_graph` to function.
1740 FunctionDef oc_host_graph_fdef;
1741 TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1742 &oc_host_graph_fdef));
1743 if (fld->Find(host_graph_func_name)) {
1744 TF_RETURN_IF_ERROR(
1745 fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1746 } else {
1747 TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1748 }
1749
1750 return Status::OK();
1751 }
1752
1753 // Builds host graph for func call nodes.
1754 Status BuildHostGraphForFuncCallNode(
1755 const string& xla_cluster_attr_name, const string& xla_cluster_name,
1756 const string& outside_compilation_attr_name,
1757 const string& func_call_node_name, const string& func_call_host_func_name,
1758 const string& host_graph_func_name, FunctionLibraryDefinition* fld) {
1759 Graph host_graph(fld);
1760 AttrValue device_ordinal_value;
1761 device_ordinal_value.set_placeholder("_device_ordinal");
1762
1763 // Step 1: add key placeholder node.
1764 TF_ASSIGN_OR_RETURN(
1765 Node * key_placeholder,
1766 AddHostComputeKeyPlaceholder(xla_cluster_name, &host_graph));
1767
1768   // Step 2: rewrite `func_call_host_func_name`, replace key placeholder with
1769   // an _Arg node.
1770 TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode(
1771 xla_cluster_name, func_call_host_func_name, fld));
1772
1773   // Step 3: build a function call node calling `func_call_host_func_name`,
1774   // with `key_placeholder` as input.
1775 NodeDefBuilder call_builder(absl::StrCat("oc_call_", func_call_node_name),
1776 func_call_host_func_name, fld);
1777 call_builder.Input(key_placeholder->name(), 0, DT_STRING);
1778 call_builder.Attr("_device_ordinal", device_ordinal_value);
1779 call_builder.Attr(kXlaHasHostTransferAttrName, true);
1780 call_builder.Attr(xla_cluster_attr_name, xla_cluster_name);
1781 call_builder.Attr(outside_compilation_attr_name, call_builder.node_name());
1782 NodeDef call_def;
1783 TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def));
1784 Status s;
1785 Node* call_node = host_graph.AddNode(call_def, &s);
1786 TF_RETURN_IF_ERROR(s);
1787 host_graph.AddEdge(key_placeholder, 0, call_node, 0);
1788
1789 // Convert `host_graph` to function.
1790 FunctionDef oc_host_graph_fdef;
1791 TF_RETURN_IF_ERROR(GraphToFunctionDef(host_graph, host_graph_func_name,
1792 HostGraphControlRetMapping,
1793 &oc_host_graph_fdef));
1794 if (fld->Find(host_graph_func_name)) {
1795 TF_RETURN_IF_ERROR(
1796 fld->ReplaceFunction(host_graph_func_name, oc_host_graph_fdef));
1797 } else {
1798 TF_RETURN_IF_ERROR(fld->AddFunctionDef(oc_host_graph_fdef));
1799 }
1800
1801 return Status::OK();
1802 }
1803
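// Extracts outside compilation for a function call node (a direct call, a
// PartitionedCall, or a SymbolicGradient). If the called function contains
// outside compilation, rewrites `n` to call the rewritten XLA function and
// builds a host-side graph that invokes the corresponding host function.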
1804 TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode(
1805 const string& xla_cluster_attr_name,
1806 const string& outside_compilation_attr_name, const string& xla_cluster_name,
1807 const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1808 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1809 std::vector<string>* host_graphs,
1810 std::vector<string>* shape_inference_graphs,
1811 bool* has_outside_compilation) {
1812 bool func_has_outside_compilation = false;
1813 NameAttrList func;
1814 if (fld->Contains(n->type_string())) {
1815 func.set_name(n->type_string());
1816 typedef protobuf::Map<string, AttrValue> AttrMap;
1817 *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
1818 } else if (n->IsPartitionedCall()) {
1819 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func));
1820 } else {
1821 TF_RET_CHECK(n->type_string() == FunctionLibraryDefinition::kGradientOp);
1822 func.set_name(FunctionLibraryDefinition::kGradientOp);
1823 *func.mutable_attr() = n->def().attr();
1824 }
1825 string canonical_func_name;
1826 if (func.name() == FunctionLibraryDefinition::kGradientOp) {
1827 NameAttrList forward_func;
1828 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &forward_func));
1829 canonical_func_name = absl::StrCat("gradient_", forward_func.name());
1830 } else {
1831 canonical_func_name = func.name();
1832 }
1833 string new_func_name = absl::StrCat(canonical_func_name, "_oc");
1834 string host_func_name =
1835 absl::StrCat("oc_func_call_host_", canonical_func_name);
1836 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1837 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1838 func, new_func_name, host_func_name, host_compute_core, flr, fld,
1839 shape_inference_graphs, &func_has_outside_compilation));
1840
1841 // If the function call does not have outside compilation, nothing to do.
1842 if (!func_has_outside_compilation) {
1843 return Status::OK();
1844 }
1845
1846 *has_outside_compilation = true;
1847
1848 // Change `n` to call the new function directly.
1849 auto replace_builder =
1850 absl::make_unique<NodeDefBuilder>(n->name(), new_func_name, fld);
1851 std::vector<NodeDefBuilder::NodeOut> inputs(n->num_inputs());
1852 for (const Edge* e : n->in_edges()) {
1853 if (e->IsControlEdge()) {
1854 continue;
1855 }
1856
1857 const bool input_size_check =
1858 e->dst_input() < static_cast<int>(inputs.size());
1859 TF_RET_CHECK(e->dst_input() >= 0 && input_size_check);
1860 inputs[e->dst_input()] =
1861 NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
1862 e->src()->output_type(e->src_output())};
1863 }
1864 for (const auto& input : inputs) {
1865 replace_builder->Input(input);
1866 }
1867 for (const auto& attr : n->attrs()) {
1868 replace_builder->Attr(attr.first, attr.second);
1869 }
1870 auto replace_def = absl::make_unique<NodeDef>();
1871 TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get()));
1872 TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def));
1873 replace->AddAttr(kXlaTokenInputNodesAttrName,
1874 std::vector<string>{kXlaTokenArgNodeName});
1875 replace->AddAttr(kXlaOriginalOutsideCompilationNodeName, replace->name());
1876
1877 // Build host side graph for the function call.
1878 string oc_host_graph_name =
1879 absl::StrCat("oc_func_host_graph_", replace->name());
1880 TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode(
1881 xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
1882 replace->name(), host_func_name, oc_host_graph_name, fld));
1883
1884 // Record the host graph.
1885 host_graphs->push_back(oc_host_graph_name);
1886
1887 return Status::OK();
1888 }
1889
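// Extracts outside compilation for an If node: rewrites its branch functions,
// adds an XlaSendToHost node for the predicate, and builds a host-side If that
// selects between the branches' host graphs.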
1890 Status ExtractOutsideCompilationForIfNode(
1891 const string& xla_cluster_attr_name,
1892 const string& outside_compilation_attr_name, const string& xla_cluster_name,
1893 const std::map<string, int>& host_compute_core, Graph* g, Node* n,
1894 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
1895 std::vector<string>* host_graphs,
1896 std::vector<string>* shape_inference_graphs,
1897 bool* has_outside_compilation) {
1898 // Instantiate "then_branch" and "else_branch".
1899 NameAttrList then_branch, else_branch;
1900 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch));
1901 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch));
1902
1903 // Extract outside compilation for then_branch and else_branch.
1904 bool then_branch_has_outside_compilation = false;
1905 bool else_branch_has_outside_compilation = false;
1906 string then_branch_host_func_name =
1907 absl::StrCat("oc_then_branch_host_if_", then_branch.name()),
1908 else_branch_host_func_name =
1909 absl::StrCat("oc_else_branch_host_if_", else_branch.name());
1910 string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"),
1911 else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc");
1912 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1913 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1914 then_branch, then_branch_xla_func_name, then_branch_host_func_name,
1915 host_compute_core, flr, fld, shape_inference_graphs,
1916 &then_branch_has_outside_compilation));
1917 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
1918 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
1919 else_branch, else_branch_xla_func_name, else_branch_host_func_name,
1920 host_compute_core, flr, fld, shape_inference_graphs,
1921 &else_branch_has_outside_compilation));
1922
1923   // If neither branch has outside compilation, nothing to do.
1924 if (!then_branch_has_outside_compilation &&
1925 !else_branch_has_outside_compilation) {
1926 return Status::OK();
1927 }
1928
1929 *has_outside_compilation = true;
1930
1931 // Change If node to call the new functions.
1932 if (then_branch_has_outside_compilation) {
1933 then_branch.set_name(then_branch_xla_func_name);
1934 n->ClearAttr("then_branch");
1935 n->AddAttr("then_branch", then_branch);
1936 }
1937 if (else_branch_has_outside_compilation) {
1938 else_branch.set_name(else_branch_xla_func_name);
1939 n->ClearAttr("else_branch");
1940 n->AddAttr("else_branch", else_branch);
1941 }
1942 n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
1943
1944 string host_transfer_key = absl::StrCat("oc_if_pred_", n->name());
1945
1946 // XLA computation: add a SendToHost node to send cond predicate.
1947 Node* pred_node;
1948 TF_RETURN_IF_ERROR(n->input_node(0, &pred_node));
1949 TF_ASSIGN_OR_RETURN(
1950 Node * send_pred_node,
1951 BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()),
1952 host_transfer_key, pred_node, g));
1953 n->AddAttr(kXlaTokenInputNodesAttrName,
1954 std::vector<string>{send_pred_node->name()});
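  // The If node lists `send_pred_node` as its token input so that, presumably,
  // the side-effect token produced by the send is threaded into the If
  // computation in the correct order.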
1955
1956 // Add a control edge from `send_pred_node` to If node, so XlaCompiler will
1957 // visit If node after `send_pred_node`, thus the token output for
1958 // `send_pred_node` has been generated.
1959 g->AddControlEdge(send_pred_node, n);
1960
1961 // Build host side graph for the "If" node.
1962   // If a branch does not have outside compilation, no host graph was built
1963   // for it. But the host-side If needs a host graph for both branches, so
1964   // create a no-op host graph for any such branch.
1965 if (!then_branch_has_outside_compilation) {
1966 std::unique_ptr<Graph> then_branch_host_graph(new Graph(fld));
1967 std::vector<string> then_branch_host_graphs;
1968 TF_RETURN_IF_ERROR(ConstructHostGraph(
1969 xla_cluster_name, outside_compilation_attr_name,
1970 then_branch_host_graphs, fld, &then_branch_host_graph));
1971 FunctionDef then_branch_host_fdef;
1972 TF_RETURN_IF_ERROR(GraphToFunctionDef(*then_branch_host_graph,
1973 then_branch_host_func_name,
1974 &then_branch_host_fdef));
1975 if (fld->Find(then_branch_host_func_name)) {
1976 TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_host_func_name,
1977 then_branch_host_fdef));
1978 } else {
1979 TF_RETURN_IF_ERROR(fld->AddFunctionDef(then_branch_host_fdef));
1980 }
1981 }
1982 if (!else_branch_has_outside_compilation) {
1983 std::unique_ptr<Graph> else_branch_host_graph(new Graph(fld));
1984 std::vector<string> else_branch_host_graphs;
1985 TF_RETURN_IF_ERROR(ConstructHostGraph(
1986 xla_cluster_name, outside_compilation_attr_name,
1987 else_branch_host_graphs, fld, &else_branch_host_graph));
1988 FunctionDef else_branch_host_fdef;
1989 TF_RETURN_IF_ERROR(GraphToFunctionDef(*else_branch_host_graph,
1990 else_branch_host_func_name,
1991 &else_branch_host_fdef));
1992 if (fld->Find(else_branch_host_func_name)) {
1993 TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_host_func_name,
1994 else_branch_host_fdef));
1995 } else {
1996 TF_RETURN_IF_ERROR(fld->AddFunctionDef(else_branch_host_fdef));
1997 }
1998 }
1999 string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name());
2000 TF_RETURN_IF_ERROR(BuildHostGraphForIfNode(
2001 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2002 n->name(), host_transfer_key, oc_host_graph_name, fld,
2003 then_branch_host_func_name, else_branch_host_func_name));
2004 host_graphs->push_back(oc_host_graph_name);
2005
2006 return Status::OK();
2007 }
2008
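// Extracts outside compilation for a While node: rewrites its cond/body
// functions, makes the XLA cond send the loop predicate to the host, and
// builds a host-side While that mirrors the device loop.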
2009 Status ExtractOutsideCompilationForWhileNode(
2010 const string& xla_cluster_attr_name,
2011 const string& outside_compilation_attr_name, const string& xla_cluster_name,
2012 const std::map<string, int>& host_compute_core, Graph* g, Node* n,
2013 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2014 std::vector<string>* host_graphs,
2015 std::vector<string>* shape_inference_graphs,
2016 bool* has_outside_compilation) {
2017 // Instantiate "cond" and "body".
2018 NameAttrList cond, body;
2019 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond));
2020 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body));
2021
2022 // Extract outside compilation for cond and body.
2023 bool cond_has_outside_compilation = false;
2024 bool body_has_outside_compilation = false;
2025 string cond_host_func_name = absl::StrCat("oc_cond_host_while_", cond.name()),
2026 body_host_func_name = absl::StrCat("oc_body_host_while_", body.name());
2027 string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"),
2028 body_xla_func_name = absl::StrCat(body.name(), "_oc");
2029 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2030 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2031 cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr,
2032 fld, shape_inference_graphs, &cond_has_outside_compilation));
2033 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2034 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2035 body, body_xla_func_name, body_host_func_name, host_compute_core, flr,
2036 fld, shape_inference_graphs, &body_has_outside_compilation));
2037
2038   // If neither cond nor body has outside compilation, nothing to do.
2039 if (!cond_has_outside_compilation && !body_has_outside_compilation) {
2040 return Status::OK();
2041 }
2042
2043 *has_outside_compilation = true;
2044
2045 // Change While node to call the new functions.
2046 if (cond_has_outside_compilation) {
2047 cond.set_name(cond_xla_func_name);
2048 n->ClearAttr("cond");
2049 n->AddAttr("cond", cond);
2050 }
2051 if (body_has_outside_compilation) {
2052 body.set_name(body_xla_func_name);
2053 n->ClearAttr("body");
2054 n->AddAttr("body", body);
2055 }
2056 n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name());
2057
2058 string host_transfer_key = absl::StrCat("oc_while_pred_", n->name());
2059
2060 // XLA computation: rewrite cond function to add a SendToHost node to send
2061 // loop predicate.
2062 TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond(
2063 cond_xla_func_name, host_transfer_key, &cond, fld, n));
2064 n->AddAttr(kXlaTokenInputNodesAttrName,
2065 std::vector<string>{kXlaTokenArgNodeName});
2066
2067 // Build host side graph for the "While" node.
2068 if (!cond_has_outside_compilation) {
2069 std::unique_ptr<Graph> cond_host_graph(new Graph(fld));
2070 std::vector<string> host_graphs;
2071 TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2072 outside_compilation_attr_name,
2073 host_graphs, fld, &cond_host_graph));
2074 FunctionDef cond_host_fdef;
2075 TF_RETURN_IF_ERROR(GraphToFunctionDef(*cond_host_graph, cond_host_func_name,
2076 &cond_host_fdef));
2077 if (fld->Find(cond_host_func_name)) {
2078 TF_RETURN_IF_ERROR(
2079 fld->ReplaceFunction(cond_host_func_name, cond_host_fdef));
2080 } else {
2081 TF_RETURN_IF_ERROR(fld->AddFunctionDef(cond_host_fdef));
2082 }
2083 }
2084 if (!body_has_outside_compilation) {
2085 std::unique_ptr<Graph> body_host_graph(new Graph(fld));
2086 std::vector<string> host_graphs;
2087 TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name,
2088 outside_compilation_attr_name,
2089 host_graphs, fld, &body_host_graph));
2090 FunctionDef body_host_fdef;
2091 TF_RETURN_IF_ERROR(GraphToFunctionDef(*body_host_graph, body_host_func_name,
2092 &body_host_fdef));
2093 if (fld->Find(body_host_func_name)) {
2094 TF_RETURN_IF_ERROR(
2095 fld->ReplaceFunction(body_host_func_name, body_host_fdef));
2096 } else {
2097 TF_RETURN_IF_ERROR(fld->AddFunctionDef(body_host_fdef));
2098 }
2099 }
2100 string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
2101 TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
2102 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2103 n->name(), host_transfer_key, oc_host_graph_name, fld,
2104 cond_host_func_name, body_host_func_name));
2105 host_graphs->push_back(oc_host_graph_name);
2106
2107 return Status::OK();
2108 }
2109
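// Recursively extracts outside compilation for If/While/function call nodes in
// `g`, since their associated functions may themselves contain outside
// compilation clusters.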
2110 Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
2111 Graph* g, const string& xla_cluster_attr_name,
2112 const string& outside_compilation_attr_name, const string& xla_cluster_name,
2113 const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
2114 FunctionLibraryDefinition* fld, std::vector<string>* host_graphs,
2115 std::vector<string>* shape_inference_graphs,
2116 bool* has_outside_compilation) {
2117 std::vector<Node*> if_nodes, while_nodes, func_call_nodes;
2118 for (Node* n : g->nodes()) {
2119 if (n->IsIfNode()) {
2120 if_nodes.push_back(n);
2121 } else if (n->IsWhileNode()) {
2122 while_nodes.push_back(n);
2123 } else if (IsFunctionCall(*fld, *n)) {
2124 func_call_nodes.push_back(n);
2125 }
2126 }
2127
2128 for (Node* n : func_call_nodes) {
2129 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFuncCallNode(
2130 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2131 host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2132 has_outside_compilation));
2133 }
2134
2135 for (Node* n : if_nodes) {
2136 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForIfNode(
2137 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2138 host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2139 has_outside_compilation));
2140 }
2141
2142 for (Node* n : while_nodes) {
2143 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForWhileNode(
2144 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2145 host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs,
2146 has_outside_compilation));
2147 }
2148
2149 return Status::OK();
2150 }
2151
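// For each outside compilation Const node that also has users outside the
// outside compilation cluster, makes a copy of the Const without the outside
// compilation attribute and reroutes those users' data edges to the copy, so
// the original Const can be extracted to the host without breaking its
// remaining device-side users.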
2152 Status CopyOutsideCompilationConstNodes(
2153 Graph* g, const string& outside_compilation_attr_name) {
2154 for (Node* n : g->op_nodes()) {
2155 if (!n->IsConstant() ||
2156 !HasNodeAttr(n->def(), outside_compilation_attr_name)) {
2157 continue;
2158 }
2159
2160 std::vector<const Edge*> out_edges(n->out_edges().begin(),
2161 n->out_edges().end());
2162 bool has_non_oc_output = false;
2163 for (const Edge* e : out_edges) {
2164 if (!e->IsControlEdge() &&
2165 !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2166 has_non_oc_output = true;
2167 break;
2168 }
2169 }
2170 if (!has_non_oc_output) {
2171 continue;
2172 }
2173
2174 NodeDef copy_def = n->def();
2175 copy_def.set_name(g->NewName(n->name()));
2176 copy_def.mutable_attr()->erase(outside_compilation_attr_name);
2177 Status s;
2178 Node* copy_node = g->AddNode(copy_def, &s);
2179 TF_RETURN_IF_ERROR(s);
2180 for (const Edge* e : n->in_edges()) {
2181 if (e->IsControlEdge()) {
2182 g->AddControlEdge(e->src(), copy_node);
2183 }
2184 }
2185 for (const Edge* e : out_edges) {
2186 if (!e->IsControlEdge() &&
2187 !HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
2188 Node* dst = e->dst();
2189 int dst_input = e->dst_input();
2190 g->RemoveEdge(e);
2191 g->AddEdge(copy_node, 0, dst, dst_input);
2192 }
2193 }
2194 }
2195
2196 return Status::OK();
2197 }
2198
2199 } // namespace
2200
2201 Status RewriteOutsideCompilationSubgraphFn::operator()(
2202 const std::vector<OutputTensor>& arg_source_tensors,
2203 std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
2204 std::vector<int>* output_permutation, NodeDef* node_def) {
2205 string old_name = node_def->op();
2206 string new_name =
2207 absl::StrCat(xla_cluster_name_, "_", new_function_name_, "_", old_name);
2208 node_def->set_op(new_name);
2209 node_def->set_name(new_name);
2210
2211 // Later we will run PruneForReverseReachability(), so make sure all original
2212 // nodes are reachable from sink node and won't be removed.
2213 FixupSourceAndSinkEdges(graph->get());
2214
2215 // Step 1: create a key placeholder node.
2216 TF_ASSIGN_OR_RETURN(
2217 Node * key_placeholder,
2218 AddHostComputeKeyPlaceholder(xla_cluster_name_, graph->get()));
2219
2220 // Step 2: build RecvAtHost node, and replace all _Arg nodes with it.
2221 std::vector<DataType> recv_at_host_dtypes;
2222 TF_ASSIGN_OR_RETURN(
2223 Node * recv_at_host_node,
2224 ReplaceArgNodesWithRecvAtHostNode(graph->get(), new_name,
2225 &recv_at_host_dtypes, key_placeholder));
2226
2227 // Step 3: build SendFromHost node, and replace all _Retval nodes with it.
2228 std::vector<DataType> send_from_host_dtypes;
2229 TF_ASSIGN_OR_RETURN(
2230 Node * send_from_host_node,
2231 ReplaceRetNodesWithSendFromHostNode(
2232 graph->get(), new_name, &send_from_host_dtypes, key_placeholder));
2233
2234 // Step 4: add XLA cluster and outside compilation attr.
2235 for (Node* n : (*graph)->nodes()) {
2236 if (IsKeyPlaceholderNode(*n)) {
2237 continue;
2238 }
2239
2240 n->AddAttr(xla_cluster_attr_name_, xla_cluster_name_);
2241 n->AddAttr(outside_compilation_attr_name_, old_name);
2242 }
2243
2244 // Check whether we have all input shapes for XlaSendFromHost. If we do, we
2245 // will set `shapes` attr for the call node; otherwise we will save the
2246 // shape inference graph and set `shape_inference_graph` for the call node.
2247 absl::optional<std::vector<PartialTensorShape>> shapes =
2248 GetInferredInputShapes(send_from_host_dtypes.size(), send_from_host_node);
2249 for (Node* n : (*graph)->nodes()) {
2250 n->ClearAttr(kXlaInferredShapesAttrName);
2251 }
2252
2253   // Step 5: add control edges for the original XLA <-> outside compilation
2254   // control edges.
2255 for (Node* n : (*graph)->nodes()) {
2256 if (HasNodeAttr(n->def(), kXlaConnectedToXlaComputationAttrName)) {
2257 (*graph)->AddControlEdge(n, send_from_host_node);
2258 n->ClearAttr(kXlaConnectedToXlaComputationAttrName);
2259 }
2260 if (HasNodeAttr(n->def(), kXlaConnectedFromXlaComputationAttrName)) {
2261 (*graph)->AddControlEdge(recv_at_host_node, n);
2262 n->ClearAttr(kXlaConnectedFromXlaComputationAttrName);
2263 }
2264 }
2265
2266 // Step 6: RecvAtHost/SendFromHost/key_placeholder might be dead nodes. Prune
2267 // them if necessary.
2268 // - RecvAtHost should be pruned iff it has no output data/control edges. If
2269 // it has any output edge, it will be reverse reachable from sink node. We
2270 // don't need to do anything special.
2271 // - SendFromHost should be pruned iff it has no input data/control edges. If
2272 // it has input edges other than key_placeholder, we connect it to sink
2273 // node so it won't be pruned.
2274 // - key_placeholder should be pruned iff RecvAtHost/SendFromHost are pruned.
2275 // We don't need to do anything special.
2276 if (send_from_host_node->in_edges().size() > 1) {
2277 (*graph)->AddControlEdge(send_from_host_node, (*graph)->sink_node());
2278 }
2279 PruneForReverseReachability(
2280 graph->get(), std::unordered_set<const Node*>{(*graph)->sink_node()});
2281
2282 // Step 7: add necessary attributes to function call node, so we can replace
2283 // it with HostCompute node later.
2284 AddNodeAttr("_outside_compilation_subgraph", old_name, node_def);
2285 if (shapes) {
2286 NameAttrList shape_inference_graph;
2287 AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
2288 AddNodeAttr("shapes", *shapes, node_def);
2289 } else {
2290 string shape_inference_func_name =
2291 absl::StrCat("_outside_compilation_shape_inference_", new_name);
2292 NameAttrList shape_inference_graph;
2293 shape_inference_graph.set_name(shape_inference_func_name);
2294 AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def);
2295 AddNodeAttr("shapes", std::vector<TensorShapeProto>{}, node_def);
2296 }
2297 AddNodeAttr("ancestors", std::vector<string>{}, node_def);
2298 AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def);
2299 AddNodeAttr("Toutputs", send_from_host_dtypes, node_def);
2300 AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def);
2301
2302 return Status::OK();
2303 }
2304
2305 Status ExtractOutsideCompilationForFunction(
2306 const string& xla_cluster_attr_name,
2307 const string& outside_compilation_attr_name, const string& xla_cluster_name,
2308 const NameAttrList& func_name_attrs, const string& new_func_name,
2309 const string& host_graph_func_name,
2310 const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
2311 FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs,
2312 bool* has_outside_compilation) {
2313 // Convert the function to graph.
2314 const string& func_name = func_name_attrs.name();
2315 FunctionLibraryRuntime::Handle handle;
2316 TF_RETURN_IF_ERROR(
2317 flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle));
2318 Status ret_status = Status::OK();
2319 auto cleanup_handle = gtl::MakeCleanup([&]() {
2320 auto s = flr->ReleaseHandle(handle);
2321 if (!s.ok()) {
2322 ret_status.Update(s);
2323 }
2324 });
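  // `cleanup_handle` releases the instantiated function handle on every exit
  // path; any error from ReleaseHandle is folded into `ret_status`.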
2325 const FunctionBody* fbody = flr->GetFunctionBody(handle);
2326
2327 // Check if we have outside compilation nodes.
2328 *has_outside_compilation = false;
2329 for (Node* n : fbody->graph->nodes()) {
2330 if (HasNodeAttr(n->def(), outside_compilation_attr_name)) {
2331 *has_outside_compilation = true;
2332 break;
2333 }
2334 }
2335   // We cannot return early here, because there might be outside compilation
2336   // inside the functions called by If/While/function call nodes in this body.
2337
2338 if (VLOG_IS_ON(4)) {
2339 DumpGraphToFile(
2340 absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
2341 *fbody->graph, fld);
2342 }
2343
2344 std::unique_ptr<Graph> graph_out;
2345 std::vector<string> outside_compilation_host_graphs;
2346 std::vector<string> shape_inference_graphs_to_rewrite;
2347 if (*has_outside_compilation) {
2348     // Copy outside compilation Const nodes with non-outside-compilation users.
2349 TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes(
2350 fbody->graph, outside_compilation_attr_name));
2351
2352 // Find dependencies between outside compilation clusters.
2353 TF_ASSIGN_OR_RETURN(auto cluster_deps,
2354 OutsideCompilationClusterDependencies(
2355 fbody->graph, outside_compilation_attr_name));
2356
2357 // Preprocess edges between different outside compilations. They will be
2358 // restored in `ConstructHostGraph()`.
2359 TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
2360 fbody->graph, outside_compilation_attr_name));
2361
2362 // Encapsulate outside_compilation cluster into function call node.
2363 auto rewrite_fn = absl::make_unique<RewriteOutsideCompilationSubgraphFn>(
2364 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2365 new_func_name);
2366 TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
2367 outside_compilation_attr_name, *fbody->graph, *rewrite_fn,
2368 /*reuse_existing_functions=*/true, &graph_out, fld));
2369
2370 // Replace outside_compilation function nodes with HostCompute ops.
2371 std::vector<Node*> outside_compilation_nodes;
2372 for (Node* n : graph_out->nodes()) {
2373 if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) {
2374 outside_compilation_nodes.push_back(n);
2375 outside_compilation_host_graphs.push_back(n->name());
2376
2377 // If we could not infer shapes for XlaSendFromHost inputs statically,
2378 // we will set the "shape_inference_graph" attribute. In that case, copy
2379 // outside compilation subgraph as shape inference graph in `fld`.
2380 auto shape_inference_graph = absl::make_unique<NameAttrList>();
2381 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph",
2382 shape_inference_graph.get()));
2383 if (!shape_inference_graph->name().empty()) {
2384 shape_inference_graphs->push_back(shape_inference_graph->name());
2385 shape_inference_graphs_to_rewrite.push_back(
2386 shape_inference_graph->name());
2387
2388 const FunctionDef* xla_fdef = fld->Find(n->name());
2389 if (!xla_fdef) {
2390 return errors::Internal("Cannot find XLA function ", n->name());
2391 }
2392 auto shape_inference_fdef = absl::make_unique<FunctionDef>(*xla_fdef);
2393 shape_inference_fdef->mutable_signature()->set_name(
2394 shape_inference_graph->name());
2395 if (fld->Find(shape_inference_graph->name())) {
2396 TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2397 shape_inference_graph->name(), *shape_inference_fdef));
2398 } else {
2399 TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef));
2400 }
2401 }
2402 }
2403 }
2404 std::map<string, Node*> host_compute_nodes;
2405 for (Node* n : outside_compilation_nodes) {
2406 auto host_compute_node_or = ReplaceOutsideCompilationCallNode(
2407 graph_out.get(), n, host_compute_core, *cluster_deps);
2408 TF_RETURN_IF_ERROR(host_compute_node_or.status());
2409 Node* host_compute_node = host_compute_node_or.ValueOrDie();
2410 host_compute_nodes[host_compute_node->name()] = host_compute_node;
2411 }
2412 // For XlaHostCompute nodes with dependencies, add control edges between
2413 // them so XlaCompiler can handle them in correct order.
2414 for (const auto& iter : host_compute_nodes) {
2415 Node* host_compute_node = iter.second;
2416 std::vector<string> token_input_node_names;
2417 TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(),
2418 kXlaTokenInputNodesAttrName,
2419 &token_input_node_names));
2420 for (const string& node_name : token_input_node_names) {
2421 if (node_name == kXlaTokenArgNodeName) {
2422 continue;
2423 }
2424
2425 auto iter = host_compute_nodes.find(node_name);
2426 TF_RET_CHECK(iter != host_compute_nodes.end());
2427 graph_out->AddControlEdge(iter->second, host_compute_node);
2428 }
2429 }
2430 }
2431
2432 // Handle nodes with associated functions.
2433 Graph* g = (*has_outside_compilation) ? graph_out.get() : fbody->graph;
2434 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForNodesWithAssociatedFunctions(
2435 g, xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2436 host_compute_core, flr, fld, &outside_compilation_host_graphs,
2437 shape_inference_graphs, has_outside_compilation));
2438
2439 if (*has_outside_compilation) {
2440 // Construct host graph.
2441 std::unique_ptr<Graph> host_graph;
2442 TF_RETURN_IF_ERROR(
2443 ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
2444 outside_compilation_host_graphs, fld, &host_graph));
2445 auto host_graph_fdef = absl::make_unique<FunctionDef>();
2446 TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name,
2447 HostGraphControlRetMapping,
2448 host_graph_fdef.get()));
2449 if (fld->Find(host_graph_func_name)) {
2450 TF_RETURN_IF_ERROR(
2451 fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef));
2452 } else {
2453 TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef));
2454 }
2455
2456 // Shape inference graphs might contain Placeholder nodes for outside
2457 // compilation to outside compilation edges. Rewrite shape inference graphs
2458 // to remove such nodes.
2459 for (const string& shape_inference_graph :
2460 shape_inference_graphs_to_rewrite) {
2461 TF_RETURN_IF_ERROR(
2462 RewriteShapeInferenceGraph(shape_inference_graph, host_graph.get(),
2463 /*pivot_node=*/nullptr, fld));
2464 }
2465
2466 // Remove the outside compilation graphs from function library.
2467 for (const string& func : outside_compilation_host_graphs) {
2468 TF_RETURN_IF_ERROR(fld->RemoveFunction(func));
2469 }
2470
2471 // Replace original function.
2472 auto updated_fdef = absl::make_unique<FunctionDef>();
2473 TF_RETURN_IF_ERROR(
2474 GraphToFunctionDef(*g, new_func_name, updated_fdef.get()));
2475 updated_fdef->mutable_signature()->set_is_stateful(true);
2476 const FunctionDef* original_fdef = fld->Find(func_name);
2477 if (original_fdef) {
2478 for (const auto& attr : original_fdef->attr()) {
2479 (*updated_fdef->mutable_attr())[attr.first] = attr.second;
2480 }
2481 }
2482 if (fld->Find(new_func_name)) {
2483 TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef));
2484 } else {
2485 TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef));
2486 }
2487 if (VLOG_IS_ON(4)) {
2488 DumpGraphToFile(
2489 absl::StrCat("extract_outside_compilation_for_func_after_",
2490 func_name),
2491 *g, fld);
2492 }
2493 }
2494
2495 return ret_status;
2496 }
2497
2498 Status ExtractOutsideCompilation(
2499 const string& xla_cluster_attr_name,
2500 const string& outside_compilation_attr_name,
2501 const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
2502 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld,
2503 bool* modified) {
2504 if (VLOG_IS_ON(4)) {
2505 DumpGraphToFile("extract_outside_compilation_before", *g, fld);
2506 }
2507
2508 *modified = false;
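  // Name -> Node* index over the main graph, used below to look up each
  // cluster's pivot node ("<cluster name>/pivot") when expanding host graphs.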
2509 auto node_name_index = g->BuildNodeNameIndex();
2510 for (auto& iter : clusters) {
2511 string xla_cluster_name = iter.first;
2512 Node* n = iter.second.node;
2513 auto const& func_name_attrs = iter.second.func_name_attrs;
2514 auto const& host_compute_core = iter.second.host_compute_core;
2515
2516 std::vector<string> shape_inference_graphs;
2517 bool has_outside_compilation;
2518 string host_graph_func_name =
2519 absl::StrCat("oc_host_graph_", xla_cluster_name);
2520 TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
2521 xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
2522 func_name_attrs, func_name_attrs.name(), host_graph_func_name,
2523 host_compute_core, flr, fld, &shape_inference_graphs,
2524 &has_outside_compilation));
2525 *modified |= has_outside_compilation;
2526
2527 if (has_outside_compilation) {
2528 string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
2529 Node* pivot_node = node_name_index[pivot_name];
2530 TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
2531 g, fld, host_graph_func_name, n, pivot_node));
2532
2533 TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
2534
2535 for (const auto& shape_inference_graph_name : shape_inference_graphs) {
2536 TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(
2537 shape_inference_graph_name, g, pivot_node, fld));
2538 }
2539 }
2540 }
2541
2542 if (VLOG_IS_ON(4)) {
2543 DumpGraphToFile("extract_outside_compilation_after", *g, fld);
2544 }
2545 return Status::OK();
2546 }
2547
2548 } // namespace tensorflow
2549