/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"

#include <queue>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

const char* const kTPUReplicatedInput = "TPUReplicatedInput";
const char* const kTPUReplicatedOutput = "TPUReplicatedOutput";
const char* const kPivotForClusterAttr = "_pivot_for_cluster";
const char* const kTPUPartitionedInput = "TPUPartitionedInput";

// Finds the `index` of an _Arg or _Retval node.
Status GetIndexAttr(const Node& n, int num_args, int* index) {
  TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index));
  if (*index < 0 || *index >= num_args) {
    return errors::InvalidArgument("Invalid ", n.type_string(), " number ",
                                   *index);
  }
  return Status::OK();
}

// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
// the arguments into the order expected by TPUReplicate computations:
// 1) replicated arguments
// 2) non-replicated (broadcast) arguments
// 3) resource variable arguments
// See the documentation of EncapsulateSubgraphsInFunctions for the meaning
// of the arguments.
Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
                       std::unique_ptr<Graph>* graph_ptr,
                       std::vector<int>* input_permutation,
                       std::vector<int>* output_permutation,
                       NodeDef* call_def) {
  // Replicated inputs have TPUReplicatedInput nodes as predecessors in the
  // input graph.
  auto is_replicated_input = [&](const Node& n, bool* is_packed = nullptr) {
    CHECK_EQ("_Arg", n.type_string());
    int index;
    TF_CHECK_OK(GetIndexAttr(n, arg_source_tensors.size(), &index));
    bool ret =
        arg_source_tensors.at(index).node->type_string() == kTPUReplicatedInput;
    if (is_packed) {
      if (!ret || !GetNodeAttr(arg_source_tensors.at(index).node->attrs(),
                               "is_packed", is_packed)
                      .ok()) {
        *is_packed = false;
      }
    }
    return ret;
  };

  auto get_replicated_input_index = [&](const Node& n) {
    CHECK_EQ("_Arg", n.type_string());
    int index;
    TF_CHECK_OK(GetIndexAttr(n, arg_source_tensors.size(), &index));
    if (arg_source_tensors.at(index).node->type_string() !=
        kTPUReplicatedInput) {
      return -1;
    }
    int replicated_index;
    TF_CHECK_OK(GetNodeAttr(arg_source_tensors.at(index).node->attrs(), "index",
                            &replicated_index));

    return replicated_index;
  };

  auto is_guaranteed_constant = [&](const Node& n) {
    bool guaranteed_constant = false;
    if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
             .ok()) {
      return false;
    }
    // Replicated input nodes can be marked as guaranteed constants if they are
    // const.
    return guaranteed_constant && !is_replicated_input(n);
  };

  Graph* graph = graph_ptr->get();
  Node* metadata_node = nullptr;
  const int num_args = input_permutation->size();
  const int num_retvals = output_permutation->size();

  std::vector<Node*> args;
  std::vector<Node*> retvals;
  args.reserve(num_args);
  retvals.reserve(num_retvals);
  for (Node* n : graph->nodes()) {
    if (n->type_string() == "_Arg") {
      args.push_back(n);
    } else if (n->type_string() == "_Retval") {
      retvals.push_back(n);
    } else if (n->type_string() == "TPUReplicateMetadata") {
      metadata_node = n;
    } else if (!str_util::StrContains(n->requested_device(),
                                      DEVICE_TPU_REPLICATED_CORE)) {
      // If an operator isn't assigned to a TPU core device, assign it to
      // TPU_REPLICATED_CORE without a specific core ID. For some operators,
      // such as variable reads/writes, the operator may be assigned to non-TPU
      // devices due to colocation.
      n->set_assigned_device_name(
          strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE));
    }
  }

  // Read the metadata node and remove it from the graph.
  if (metadata_node == nullptr) {
    return errors::InvalidArgument("Missing TPUReplicateMetadata node");
  }

  for (const auto& attr : metadata_node->attrs()) {
    if (attr.first == "computation_shape") {
      // Convert the deprecated computation_shape attribute into a
      // num_cores_per_replica value. If a computation_shape is present, it
      // overrides num_cores_per_replica.
      std::vector<int> shape;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(metadata_node->attrs(), "computation_shape", &shape));
      if (!shape.empty()) {
        int64 num_cores_per_replica = 1LL;
        for (int dim : shape) {
          num_cores_per_replica *= dim;
        }
        call_def->mutable_attr()->erase("num_cores_per_replica");
        AddNodeAttr("num_cores_per_replica", num_cores_per_replica, call_def);
      }
    } else {
      call_def->mutable_attr()->insert(attr);
    }
  }
  MergeDebugInfo(NodeDebugInfo(metadata_node->def()), call_def);
  graph->RemoveNode(metadata_node);

  if (std::find(args.begin(), args.end(), nullptr) != args.end()) {
    return errors::InvalidArgument("Missing or non-consecutive arguments");
  }

  // Reorders the arguments.
  std::sort(args.begin(), args.end(), [&](Node* a, Node* b) {
    // Non-constants appear before constants.
    bool a_is_guaranteed_constant = is_guaranteed_constant(*a);
    bool b_is_guaranteed_constant = is_guaranteed_constant(*b);
    // Non-packed values appear before packed values.
    bool a_is_packed;
    bool b_is_packed;
    // Replicated values appear before non-replicated values.
    bool a_not_replicated = !is_replicated_input(*a, &a_is_packed);
    bool b_not_replicated = !is_replicated_input(*b, &b_is_packed);
    int a_replicated_index = get_replicated_input_index(*a);
    int b_replicated_index = get_replicated_input_index(*b);
    // Non-resources appear before resources.
    bool a_is_resource = (a->output_type(0) == DT_RESOURCE);
    bool b_is_resource = (b->output_type(0) == DT_RESOURCE);
    // Uses the name as a tiebreaker so the output is deterministic.
    StringPiece a_name(a->name());
    StringPiece b_name(b->name());
    return std::tie(a_is_guaranteed_constant, a_not_replicated, a_is_packed,
                    a_is_resource, a_replicated_index, a_name) <
           std::tie(b_is_guaranteed_constant, b_not_replicated, b_is_packed,
                    b_is_resource, b_replicated_index, b_name);
  });
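  // Illustrative example (hypothetical arguments): with replicated args r0 and
  // r1, a broadcast arg b, a resource variable v, and a guaranteed constant c,
  // the sort above yields [r0, r1, b, v, c]: replicated values first,
  // non-resources before resources, guaranteed constants last, and names
  // breaking any remaining ties.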
  // Sorts the retvals by name so the order is deterministic.
  std::sort(retvals.begin(), retvals.end(),
            [](Node* a, Node* b) { return a->name() < b->name(); });

  // Computes the permutation to produce the correct argument order, and
  // updates the argument indices.
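  // For example (hypothetical): if the sorted order is [old#1, old#2, old#0],
  // the loop below sets input_permutation to {1->0, 2->1, 0->2}, i.e. the
  // vector [2, 0, 1], mapping each old argument index to its new position.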
  int variable_start_index = num_args;
  int guaranteed_const_start_index = num_args;
  for (int i = 0; i < num_args; ++i) {
    int index;
    TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index));
    if (args[i]->output_type(0) == DT_RESOURCE &&
        !is_replicated_input(*args[i]) && variable_start_index == num_args) {
      variable_start_index = i;
    } else if (is_guaranteed_constant(*args[i]) &&
               guaranteed_const_start_index == num_args) {
      guaranteed_const_start_index = i;
    }
    (*input_permutation)[index] = i;
    args[i]->AddAttr("index", i);
  }
  VLOG(4) << "variable_start_index: " << variable_start_index
          << " guaranteed_const_start_index: " << guaranteed_const_start_index;

  // Computes the permutation to produce the correct retval order, and updates
  // the retval indices.
  for (int i = 0; i < num_retvals; ++i) {
    int index;
    TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index));
    (*output_permutation)[index] = i;
    retvals[i]->AddAttr("index", i);
  }

  AddNodeAttr(kTPUReplicateAttr, call_def->name(), call_def);
  AddNodeAttr("_variable_start_index", variable_start_index, call_def);
  AddNodeAttr("_guaranteed_const_start_index", guaranteed_const_start_index,
              call_def);

  // Uniquify the function name.
  GraphDef gdef;
  graph->ToGraphDef(&gdef);

  // Before serialization, sort each node's control inputs to achieve
  // determinism. Sorting control inputs could help (but not necessarily)
  // create a deterministic serialization and fingerprint. Other sources of
  // nondeterminism include unstable node ordering.
  SortControlInputs(&gdef);
  // Fingerprint the function.
  // Nondeterminism in serialization would not lead to incorrect results, but
  // may cause spurious cache misses. DeterministicSerialization is a
  // best-effort deterministic serialization.
  string serialized;
  TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized));
  uint64 fingerprint =
      TpuCompileInterface::Get()->FingerprintString(serialized);
  LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
  call_def->set_op(strings::StrCat(call_def->op(), "_", fingerprint));
  return Status::OK();
}

DataType EdgeType(const Edge* edge) {
  return edge->dst()->input_type(edge->dst_input());
}

// Adds the control inputs of `node` to `*deps`.
void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
  for (const Edge* edge : node.in_edges()) {
    if (edge->IsControlEdge()) {
      deps->insert(edge->src());
    }
  }
}

// Adds the control outputs of `node` to `*deps`.
void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
  for (const Edge* edge : node.out_edges()) {
    if (edge->IsControlEdge()) {
      deps->insert(edge->dst());
    }
  }
}

// We add Identity nodes for _Arg/_Retval in the XLA computation. Remove those
// Identity nodes to simplify further processing.
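// The rewiring below is a simple bypass, sketched here:
//   before: src:out -> Identity -> dst:in  (plus any control edges)
//   after:  src:out -> dst:in              (control edges re-attached to src)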
Status RemoveIdentityNodesForArgRetval(Graph* g) {
  // Collect Identity nodes for _Arg/_Retval.
  std::vector<Node*> identity_nodes;
  for (Node* n : g->nodes()) {
    if (n->type_string() == "Identity" &&
        (HasNodeAttr(n->def(), "_tpu_input_identity") ||
         HasNodeAttr(n->def(), "_tpu_output_identity"))) {
      identity_nodes.push_back(n);
    }
  }

  // Remove those Identity nodes.
  for (Node* n : identity_nodes) {
    const Edge* input_edge;
    TF_RETURN_IF_ERROR(n->input_edge(0, &input_edge));

    std::vector<const Edge*> output_edges;
    for (const Edge* e : n->out_edges()) {
      output_edges.push_back(e);
    }
    for (const Edge* e : output_edges) {
      if (e->IsControlEdge()) {
        Node* dst = e->dst();
        g->RemoveEdge(e);
        g->AddControlEdge(input_edge->src(), dst);
      } else {
        Node* dst = e->dst();
        int dst_input = e->dst_input();
        g->RemoveEdge(e);
        g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
      }
    }
    g->RemoveNode(n);
  }

  return Status::OK();
}

// Updates the TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR when
// `additional_per_replica_inputs` are added to the inputs of `xla_node`.
Status UpdateMirroredVariableIndices(int additional_per_replica_inputs,
                                     Node* xla_node) {
  std::vector<int> mirrored_variable_indices;
  if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) !=
      nullptr) {
    TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(),
                                   TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                                   &mirrored_variable_indices));
  }

  if (!mirrored_variable_indices.empty()) {
    for (int i = 0; i < mirrored_variable_indices.size(); ++i)
      mirrored_variable_indices[i] += additional_per_replica_inputs;
    xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR);
    xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                      mirrored_variable_indices);
  }
  return Status::OK();
}

// Move outside compilation nodes at the beginning of the XLA computation to
// the host. For the XLA computation graph, we will add new _Arg nodes to
// replace those outside compilation nodes. For the host graph, we will move
// those outside compilation nodes to the host, replicate them, and use them
// as the XLA node's inputs.
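// Sketch of the transformation (illustrative):
//   before (xla_graph): _Arg -> oc_node -> consumer
//   after  (xla_graph): new _Arg -> consumer
//   after  (host graph): input -> oc_node copy (one per replica) -> xla_node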
Status MoveHeadOutsideCompilationToHost(
    const string& outside_compilation_attr_name, const string& xla_func_name,
    const std::string& cluster_name, Graph* g, Graph* xla_graph, Node* xla_node,
    Node* pivot_node) {
  // Find outside compilation nodes that only have _Arg or other outside
  // compilation nodes as input. These nodes will be moved to host graph.
  std::vector<Node*> oc_nodes_at_head;
  const string kOnlyArgOrOcInputAttrName = "_xla_only_arg_or_oc_input";
  ReverseDFS(
      *xla_graph, /*enter=*/nullptr,
      [&](Node* n) {
        bool has_non_arg_or_oc_input = false;
        for (const Edge* e : n->in_edges()) {
          if (e->src() == xla_graph->source_node()) {
            continue;
          }
          if (!e->src()->IsArg() &&
              (!HasNodeAttr(e->src()->def(), outside_compilation_attr_name) ||
               !HasNodeAttr(e->src()->def(), kOnlyArgOrOcInputAttrName))) {
            has_non_arg_or_oc_input = true;
            break;
          }
        }
        if (HasNodeAttr(n->def(), outside_compilation_attr_name) &&
            !has_non_arg_or_oc_input &&
            !HasNodeAttr(n->def(), kXlaIsPlaceholderForArg)) {
          n->AddAttr(kOnlyArgOrOcInputAttrName, true);
          oc_nodes_at_head.push_back(n);
        }
      },
      NodeComparatorName());
  std::vector<Node*> const_nodes_to_remove;
  for (Node* n : oc_nodes_at_head) {
    // If a Const node is in "oc_nodes_at_head" but some of its successors are
    // not, copy this Const node and use the copied node for those successors.
    if (n->type_string() != "Const") {
      continue;
    }

    std::vector<const Edge*> edges_to_replace;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() &&
          HasNodeAttr(e->dst()->def(), outside_compilation_attr_name) &&
          !HasNodeAttr(e->dst()->def(), kOnlyArgOrOcInputAttrName)) {
        edges_to_replace.push_back(e);
      }
    }
    if (edges_to_replace.empty()) {
      continue;
    }

    Node* const_copy = xla_graph->CopyNode(n);
    for (const Edge* e : edges_to_replace) {
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      xla_graph->RemoveEdge(e);
      xla_graph->AddEdge(const_copy, 0, dst, dst_input);
    }
    // Make sure the copied node can be traced from source node.
    xla_graph->AddControlEdge(xla_graph->source_node(), const_copy);

    // If this Const node has no data output any more, remove it later.
    bool has_output_edge = false;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge()) {
        has_output_edge = true;
        break;
      }
    }
    if (!has_output_edge) {
      const_nodes_to_remove.push_back(n);
    }
  }
  for (Node* n : const_nodes_to_remove) {
    xla_graph->RemoveNode(n);
    oc_nodes_at_head.erase(
        std::remove(oc_nodes_at_head.begin(), oc_nodes_at_head.end(), n),
        oc_nodes_at_head.end());
  }
  if (VLOG_IS_ON(5)) {
    for (Node* n : oc_nodes_at_head) {
      VLOG(5) << "oc_nodes_at_head: " << n->DebugString();
    }
  }

  // Copy all nodes in `oc_nodes_at_head` to host graph, and also replicate
  // them.

  // Sometimes `xla_node` can have a lot of inputs; calling Node::input_edge
  // for each of them becomes very expensive because it does a linear search
  // internally. Build an input_edges vector ahead of time to make the lookups
  // faster.
  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));

  std::vector<DataType> input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tinputs", &input_types));
  int num_distributed_vars;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
                                 &num_distributed_vars));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
  int old_num_per_replica_inputs =
      (input_types.size() - num_distributed_vars) / num_replicas;
  VLOG(5) << "old_num_per_replica_inputs: " << old_num_per_replica_inputs;
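  // Note on the flattened input layout of `xla_node`, which the index
  // arithmetic below relies on: per-replica inputs come first in
  // replica-major order, i.e.
  //   [replica 0 inputs][replica 1 inputs]...[distributed variables]
  //   [broadcast inputs][variables][guaranteed constants],
  // so per-replica input `index` of replica `r` lives at flat position
  // r * old_num_per_replica_inputs + index.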
  std::map<Node*, std::vector<Node*>> node_images;
  for (Node* n : oc_nodes_at_head) {
    for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
      NodeDef copy_def = n->def();
      copy_def.set_name(absl::StrCat(n->name(), "_head_oc/R", replica_id));
      copy_def.clear_device();

      Status s;
      Node* copy_node = g->AddNode(copy_def, &s);
      TF_RETURN_IF_ERROR(s);

      copy_node->AddAttr(kXlaReplicaIdAttrName, replica_id);
      copy_node->AddAttr(kTPUReplicateAttr, cluster_name);

      for (const Edge* e : n->in_edges()) {
        if (e->src() == xla_graph->source_node()) {
          continue;
        }
        // Either e->src() is _Arg node, or it's in `node_images`.
        if (e->src()->IsArg()) {
          int index;
          TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->attrs(), "index", &index));
          const int new_index =
              (index < old_num_per_replica_inputs)
                  ? (old_num_per_replica_inputs * replica_id + index)
                  : (old_num_per_replica_inputs * num_replicas +
                     (index - old_num_per_replica_inputs));
          const Edge* original_edge = input_edges.at(new_index);
          g->AddEdge(original_edge->src(), original_edge->src_output(),
                     copy_node, e->dst_input());
        } else {
          g->AddEdge(node_images[e->src()][replica_id], e->src_output(),
                     copy_node, e->dst_input());
        }
      }

      // Add control edge between `copy_node` and `xla_node`, so these outside
      // compilation nodes will be executed before XLA computation happens.
      g->AddControlEdge(copy_node, xla_node);

      // Add control edge between `pivot_node` and `copy_node`, so `copy_node`
      // belongs to same while loop as `xla_node`.
      if (pivot_node) {
        g->AddControlEdge(pivot_node, copy_node);
      }

      node_images[n].push_back(copy_node);
    }
  }

  // Record output edges from `oc_nodes_at_head`. We will create an _Arg node
  // for each of these edges. An obvious optimization here is to deduplicate
  // these edges by <src, src_output>. But that optimization will complicate
  // the code, and in practice we usually do not have output edges with the
  // same <src, src_output>.
  std::vector<const Edge*> oc_output_edges;
  std::vector<DataType> new_arg_types;
  for (Node* n : oc_nodes_at_head) {
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() &&
          node_images.find(e->dst()) == node_images.end()) {
        VLOG(5) << "oc_output_edges: " << e->DebugString();
        oc_output_edges.push_back(e);
        new_arg_types.push_back(e->src()->output_type(e->src_output()));
      }
    }
  }
  int new_num_per_replica_inputs =
      old_num_per_replica_inputs + oc_output_edges.size();
  VLOG(5) << "new_num_per_replica_inputs: " << new_num_per_replica_inputs;

  // Process input edges for XLA node.
  int num_variables;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "NumVariables", &num_variables));
  std::vector<DataType> broadcast_input_types, guaranteed_constant_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tbroadcast_inputs",
                                 &broadcast_input_types));
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tguaranteed_constants",
                                 &guaranteed_constant_types));
  int num_other_inputs = num_distributed_vars + num_variables +
                         broadcast_input_types.size() +
                         guaranteed_constant_types.size();
  VLOG(5) << "num_other_inputs: " << num_other_inputs;

  // Update `Tinputs` attribute for `xla_node`.
  std::vector<DataType> new_input_types;
  // Order of new_input_types: old per-replica inputs -> new per-replica inputs
  // -> distributed variables
  new_input_types.reserve(num_replicas * new_num_per_replica_inputs +
                          num_distributed_vars);
  for (int replica_id = 0; replica_id < num_replicas; ++replica_id) {
    for (int i = 0; i < old_num_per_replica_inputs; ++i) {
      new_input_types.push_back(input_types[i]);
    }
    for (int i = old_num_per_replica_inputs; i < new_num_per_replica_inputs;
         ++i) {
      new_input_types.push_back(new_arg_types[i - old_num_per_replica_inputs]);
    }
  }
  const int num_new_per_replica_input_types = new_input_types.size();
  for (int i = input_types.size() - num_distributed_vars;
       i < input_types.size(); i++) {
    new_input_types.push_back(input_types[i]);
  }
  xla_node->ClearAttr("Tinputs");
  xla_node->AddAttr("Tinputs", new_input_types);

  TF_RETURN_IF_ERROR(UpdateMirroredVariableIndices(
      /*additional_per_replica_inputs=*/oc_output_edges.size(), xla_node));

  int new_variable_start_index =
      num_new_per_replica_input_types / num_replicas + num_distributed_vars +
      broadcast_input_types.size();
  if (xla_node->attrs().Find("_variable_start_index") != nullptr) {
    xla_node->ClearAttr("_variable_start_index");
    xla_node->AddAttr("_variable_start_index", new_variable_start_index);
  }
  int new_guaranteed_const_start_index =
      new_variable_start_index + num_variables;
  if (xla_node->attrs().Find("_guaranteed_const_start_index") != nullptr) {
    xla_node->ClearAttr("_guaranteed_const_start_index");
    xla_node->AddAttr("_guaranteed_const_start_index",
                      new_guaranteed_const_start_index);
  }

  // Move non per-replica input edges.
  std::vector<const Edge*> new_input_edges(
      num_replicas * new_num_per_replica_inputs + num_other_inputs);
  int end_input_index =
      num_replicas * new_num_per_replica_inputs + num_other_inputs - 1;
  int start_input_index = end_input_index + 1 - num_other_inputs;
  for (int input_index = end_input_index; input_index >= start_input_index;
       input_index--) {
    const Edge* e =
        input_edges.at(input_index - num_replicas * new_arg_types.size());
    Node* src = e->src();
    int src_output = e->src_output();
    g->RemoveEdge(e);
    const Edge* new_input_edge =
        g->AddEdge(src, src_output, xla_node, input_index);
    new_input_edges[input_index] = new_input_edge;
  }

  // Re-order old per-replica input edges, and add new per-replica input edges.
  std::vector<std::pair<Node*, int>> per_replica_inputs;
  std::vector<const Edge*> old_per_replica_edges;
  for (int i = 0; i < old_num_per_replica_inputs * num_replicas; i++) {
    const Edge* e = input_edges.at(i);
    per_replica_inputs.push_back(std::make_pair(e->src(), e->src_output()));
    old_per_replica_edges.push_back(e);
  }
  for (const Edge* e : old_per_replica_edges) {
    g->RemoveEdge(e);
  }
  for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
    for (int input_index = 0; input_index < old_num_per_replica_inputs;
         input_index++) {
      Node* src = per_replica_inputs[replica_id * old_num_per_replica_inputs +
                                     input_index]
                      .first;
      int src_output =
          per_replica_inputs[replica_id * old_num_per_replica_inputs +
                             input_index]
              .second;
      const Edge* new_input_edge =
          g->AddEdge(src, src_output, xla_node,
                     replica_id * new_num_per_replica_inputs + input_index);
      new_input_edges[input_index] = new_input_edge;
    }
    for (int input_index = old_num_per_replica_inputs;
         input_index < new_num_per_replica_inputs; input_index++) {
      Node* original_src =
          oc_output_edges[input_index - old_num_per_replica_inputs]->src();
      int original_src_output =
          oc_output_edges[input_index - old_num_per_replica_inputs]
              ->src_output();
      Node* src = node_images[original_src][replica_id];
      const Edge* new_input_edge =
          g->AddEdge(src, original_src_output, xla_node,
                     replica_id * new_num_per_replica_inputs + input_index);
      new_input_edges[input_index] = new_input_edge;
    }
  }

  // Adjust original _Arg nodes in `xla_graph`.
  for (Node* n : xla_graph->nodes()) {
    if (n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      if (index >= old_num_per_replica_inputs) {
        index += new_arg_types.size();
        n->ClearAttr("index");
        n->AddAttr("index", index);
      }
    }
  }

  // Create new _Arg nodes in `xla_graph`.
  for (int i = old_num_per_replica_inputs; i < new_num_per_replica_inputs;
       i++) {
    NodeDefBuilder arg_builder(absl::StrCat("arg_", i),
                               FunctionLibraryDefinition::kArgOp);
    arg_builder.Attr("T", new_arg_types[i - old_num_per_replica_inputs]);
    arg_builder.Attr("index", i);
    NodeDef arg_def;
    TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
    Status s;
    Node* arg_node = xla_graph->AddNode(arg_def, &s);
    TF_RETURN_IF_ERROR(s);
    const Edge* original_edge = oc_output_edges[i - old_num_per_replica_inputs];
    Node* dst = original_edge->dst();
    int dst_input = original_edge->dst_input();
    xla_graph->RemoveEdge(original_edge);
    xla_graph->AddEdge(arg_node, 0, dst, dst_input);
  }

  // For lifted arg nodes:
  // 1. Add a Placeholder node in `xla_graph`. When we build the host side
  //    graph in ExtractOutsideCompilationPass, we will use this new
  //    Placeholder node instead of the lifted arg node here.
  // 2. Add an IdentityN node in `g` to indicate its inputs. We will reconnect
  //    this IdentityN node and this lifted arg node's usage nodes in
  //    DistributedTPURewritePass.
  for (Node* n : oc_nodes_at_head) {
    bool is_lifted_arg;
    string outside_compilation_attr;
    if (!TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) ||
        !TryGetNodeAttr(n->def(), kOutsideCompilationAttr,
                        &outside_compilation_attr)) {
      continue;
    }

    TF_RET_CHECK(n->IsIdentity());
    NodeDefBuilder ph_builder(absl::StrCat("placeholder_", n->name()),
                              "Placeholder");
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
    ph_builder.Attr("dtype", dtype);
    ph_builder.Attr(kXlaIsLiftedArgAttrName, true);
    ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_attr);
    NodeDef ph_def;
    TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));
    Status s;
    xla_graph->AddNode(ph_def, &s);
    TF_RETURN_IF_ERROR(s);

    Node* input_node;
    TF_RETURN_IF_ERROR(n->input_node(0, &input_node));
    TF_RET_CHECK(input_node->type_string() == "_Arg");
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(input_node->def(), "index", &index));
    // TODO(b/74023706): for now we only support resource input (e.g. summary
    // writer), which is non-replicated input. Support replicated input as
    // well.
    TF_RET_CHECK(index >= new_num_per_replica_inputs + num_distributed_vars);
    const Edge* input_edge =
        new_input_edges.at(num_replicas * new_num_per_replica_inputs + index -
                           new_num_per_replica_inputs);
    NodeDefBuilder id_builder(absl::StrCat("lifted_arg_input_", index),
                              "IdentityN");
    DataType input_dtype =
        input_edge->src()->output_type(input_edge->src_output());
    id_builder.Attr("T", std::vector<DataType>(num_replicas, input_dtype));
    std::vector<NodeDefBuilder::NodeOut> inputs(
        num_replicas,
        NodeDefBuilder::NodeOut{input_edge->src()->name(),
                                input_edge->src_output(), input_dtype});
    id_builder.Attr(kXlaOutsideCompilationInputsAttrName,
                    outside_compilation_attr);
    id_builder.Input(inputs);
    NodeDef id_def;
    TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
    Node* id_node = g->AddNode(id_def, &s);
    TF_RETURN_IF_ERROR(s);
    for (int i = 0; i < num_replicas; i++) {
      g->AddEdge(input_edge->src(), input_edge->src_output(), id_node, i);
    }
  }

  // Remove `oc_nodes_at_head`.
  for (Node* n : oc_nodes_at_head) {
    xla_graph->RemoveNode(n);
  }

  VLOG(4) << "MoveHeadOutsideCompilationToHost host graph: "
          << DumpGraphToFile(absl::StrCat("move_head_oc_host_", xla_func_name),
                             *g);
  VLOG(4) << "MoveHeadOutsideCompilationToHost XLA graph: "
          << DumpGraphToFile(absl::StrCat("move_head_oc_xla_", xla_func_name),
                             *xla_graph);

  return Status::OK();
}

// If there are any unused _Arg nodes in `xla_graph`, remove them from
// `xla_graph` and remove the corresponding input edges in host graph `g`.
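// For example (hypothetical): with 4 _Arg nodes where _Arg #1 is unused, the
// pass drops it and remaps the remaining indices {0->0, 2->1, 3->2}, then
// shrinks the type attributes and input edges of `xla_node` to match.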
Status RemoveUnusedXlaInput(const string& xla_func_name, Graph* g,
                            Graph* xla_graph, Node* xla_node) {
  // Find unused _Arg nodes, and remove them.
  std::vector<DataType> input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tinputs", &input_types));
  std::vector<int> mirrored_variable_indices;
  if (xla_node->attrs().Find(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR) !=
      nullptr) {
    TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(),
                                   TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                                   &mirrored_variable_indices));
  }
  std::vector<DataType> broadcast_input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tbroadcast_inputs",
                                 &broadcast_input_types));
  std::vector<DataType> guaranteed_constant_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tguaranteed_constants",
                                 &guaranteed_constant_types));
  int num_variables;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "NumVariables", &num_variables));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
  int num_distributed_vars;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
                                 &num_distributed_vars));
  int num_per_replica_inputs =
      (input_types.size() - num_distributed_vars) / num_replicas;
  std::set<int> arg_indices_to_remove;
  std::vector<Node*> arg_nodes_to_update, nodes_to_remove;
  int num_args = 0, num_removed_per_replica_inputs = 0,
      num_removed_distributed_vars = 0;
  for (Node* n : xla_graph->nodes()) {
    if (!n->IsArg()) {
      continue;
    }

    bool has_output = false;
    for (const Edge* e : n->out_edges()) {
      if (e->dst() != xla_graph->sink_node()) {
        has_output = true;
        break;
      }
    }

    num_args++;
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    if (has_output) {
      arg_nodes_to_update.push_back(n);
      continue;
    }

    arg_indices_to_remove.insert(index);
    if (index < num_per_replica_inputs) {
      num_removed_per_replica_inputs++;
    } else if (index < num_per_replica_inputs + num_distributed_vars) {
      num_removed_distributed_vars++;
    }
    nodes_to_remove.push_back(n);
  }
  for (Node* n : nodes_to_remove) {
    xla_graph->RemoveNode(n);
  }

  // Update `index` for other _Arg nodes.
  std::map<int, int> arg_index_mapping;
  int new_arg_index = 0;
  for (int i = 0; i < num_args; i++) {
    if (arg_indices_to_remove.find(i) != arg_indices_to_remove.end()) {
      continue;
    } else {
      arg_index_mapping[i] = new_arg_index;
      new_arg_index++;
    }
  }
  for (Node* n : arg_nodes_to_update) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    n->ClearAttr("index");
    n->AddAttr("index", arg_index_mapping[index]);
  }

  // Re-order replicated input edges for `xla_node`.

  // Sometimes `xla_node` can have a lot of inputs; calling Node::input_edge
  // for each of them becomes very expensive because it does a linear search
  // internally. Build an input_edges vector ahead of time to make the lookups
  // faster.
  std::vector<const Edge*> input_edges;
  TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));

  const int num_new_per_replica_inputs =
      num_per_replica_inputs - num_removed_per_replica_inputs;
  for (int i = 0; i < num_replicas; i++) {
    for (int j = 0; j < num_per_replica_inputs; j++) {
      auto iter = arg_index_mapping.find(j);
      if (iter != arg_index_mapping.end()) {
        const Edge* e = input_edges.at(i * num_per_replica_inputs + j);
        Node* src = e->src();
        int src_output = e->src_output();
        int dst_input = i * num_new_per_replica_inputs + iter->second;

        g->RemoveEdge(e);
        g->AddEdge(src, src_output, xla_node, dst_input);
      } else {
        const Edge* e = input_edges.at(i * num_per_replica_inputs + j);
        g->RemoveEdge(e);
      }
    }
  }

  // Move other data input edges.
  for (int i = num_replicas * num_per_replica_inputs;
       i < xla_node->num_inputs(); i++) {
    int arg_index =
        num_per_replica_inputs + i - num_replicas * num_per_replica_inputs;
    auto iter = arg_index_mapping.find(arg_index);
    if (iter != arg_index_mapping.end()) {
      const Edge* e = input_edges.at(i);
      Node* src = e->src();
      int src_output = e->src_output();
      int dst_input = num_replicas * num_new_per_replica_inputs + iter->second -
                      num_new_per_replica_inputs;

      g->RemoveEdge(e);
      g->AddEdge(src, src_output, xla_node, dst_input);
    } else {
      const Edge* e = input_edges.at(i);
      g->RemoveEdge(e);
    }
  }

  // Update attributes for `xla_node`.
  std::vector<DataType> new_input_types;
  for (int i = 0; i < num_replicas; i++) {
    for (int j = 0; j < num_per_replica_inputs; j++) {
      auto iter = arg_index_mapping.find(j);
      if (iter != arg_index_mapping.end()) {
        new_input_types.push_back(input_types[iter->first]);
      }
    }
  }
  for (int i = 0; i < num_distributed_vars; ++i) {
    auto iter = arg_index_mapping.find(i + num_per_replica_inputs);
    if (iter != arg_index_mapping.end()) {
      new_input_types.push_back(
          input_types[iter->first - num_per_replica_inputs +
                      num_per_replica_inputs * num_replicas]);
    }
  }
  xla_node->ClearAttr("Tinputs");
  xla_node->AddAttr("Tinputs", new_input_types);

  const int num_new_distributed_vars =
      num_distributed_vars - num_removed_distributed_vars;
  xla_node->ClearAttr("num_distributed_variables");
  xla_node->AddAttr("num_distributed_variables", num_new_distributed_vars);

  if (!mirrored_variable_indices.empty()) {
    std::vector<int> new_mirrored_variable_indices;
    absl::flat_hash_set<int> old_mirrored_variable_indices_set;
    for (int index : mirrored_variable_indices) {
      old_mirrored_variable_indices_set.insert(index);
    }
    for (int i = 0; i < num_per_replica_inputs + num_distributed_vars; i++) {
      auto iter = arg_index_mapping.find(i);
      if (iter != arg_index_mapping.end() &&
          old_mirrored_variable_indices_set.contains(iter->first)) {
        new_mirrored_variable_indices.push_back(iter->second);
      }
    }
    xla_node->ClearAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR);
    xla_node->AddAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                      new_mirrored_variable_indices);
  }

  int num_replicated_inputs = num_per_replica_inputs + num_distributed_vars;
  std::vector<DataType> new_broadcast_input_types;
  for (int i = 0; i < broadcast_input_types.size(); i++) {
    int arg_index = num_replicated_inputs + i;
    if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
      new_broadcast_input_types.push_back(broadcast_input_types[i]);
    }
  }
  xla_node->ClearAttr("Tbroadcast_inputs");
  xla_node->AddAttr("Tbroadcast_inputs", new_broadcast_input_types);
  int new_num_variables = 0;
  for (int i = 0; i < num_variables; i++) {
    int arg_index = num_replicated_inputs + broadcast_input_types.size() + i;
    if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
      new_num_variables++;
    }
  }
  xla_node->ClearAttr("NumVariables");
  xla_node->AddAttr("NumVariables", new_num_variables);
  std::vector<DataType> new_guaranteed_constant_types;
  for (int i = 0; i < guaranteed_constant_types.size(); i++) {
    int arg_index = num_replicated_inputs + broadcast_input_types.size() +
                    num_variables + i;
    if (arg_index_mapping.find(arg_index) != arg_index_mapping.end()) {
      new_guaranteed_constant_types.push_back(guaranteed_constant_types[i]);
    }
  }
  xla_node->ClearAttr("Tguaranteed_constants");
  xla_node->AddAttr("Tguaranteed_constants", new_guaranteed_constant_types);

  int new_variable_start_index = num_new_per_replica_inputs +
                                 num_new_distributed_vars +
                                 new_broadcast_input_types.size();
  if (xla_node->attrs().Find("_variable_start_index") != nullptr) {
    xla_node->ClearAttr("_variable_start_index");
    xla_node->AddAttr("_variable_start_index", new_variable_start_index);
  }
  int new_guaranteed_const_start_index =
      new_variable_start_index + new_num_variables;
  if (xla_node->attrs().Find("_guaranteed_const_start_index") != nullptr) {
    xla_node->ClearAttr("_guaranteed_const_start_index");
    xla_node->AddAttr("_guaranteed_const_start_index",
                      new_guaranteed_const_start_index);
  }

  VLOG(4) << "RemoveUnusedXlaInput host graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_input_host_", xla_func_name), *g);
  VLOG(4) << "RemoveUnusedXlaInput XLA graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_input_xla_", xla_func_name),
                 *xla_graph);

  return Status::OK();
}

// Move outside compilation nodes at the end of the XLA computation to the
// host. For the XLA computation graph, we will add new _Retval nodes to
// replace those outside compilation nodes. For the host graph, we will move
// those outside compilation nodes to the host, replicate them, and use them
// as the XLA node's outputs.
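// Sketch of the transformation (illustrative):
//   before (xla_graph): producer -> oc_node -> _Retval
//   after  (xla_graph): producer -> new _Retval
//   after  (host graph): xla_node output -> oc_node copy (one per replica)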
Status MoveTailOutsideCompilationToHost(
    const string& outside_compilation_attr_name, const string& xla_func_name,
    const std::string& cluster_name, Graph* g, Graph* xla_graph, Node* xla_node,
    Node* pivot_node) {
  // Find outside compilation nodes that only have _Retval or other outside
  // compilation nodes as output. These nodes will be moved to host graph.
  std::vector<Node*> oc_nodes_at_tail;
  const string kOnlyRetOrOcOutputAttrName = "_xla_only_ret_or_oc_output";
  DFS(
      *xla_graph, /*enter=*/nullptr,
      [&](Node* n) {
        bool has_non_ret_or_oc_output = false;
        for (const Edge* e : n->out_edges()) {
          if (e->dst() == xla_graph->sink_node()) {
            continue;
          }
          if (!e->dst()->IsRetval() &&
              (!HasNodeAttr(e->dst()->def(), outside_compilation_attr_name) ||
               !HasNodeAttr(e->dst()->def(), kOnlyRetOrOcOutputAttrName))) {
            has_non_ret_or_oc_output = true;
            break;
          }
        }
        if (HasNodeAttr(n->def(), outside_compilation_attr_name) &&
            !has_non_ret_or_oc_output) {
          n->AddAttr(kOnlyRetOrOcOutputAttrName, true);
          oc_nodes_at_tail.push_back(n);
        }
      },
      NodeComparatorName());
  if (VLOG_IS_ON(5)) {
    for (Node* n : oc_nodes_at_tail) {
      VLOG(5) << "oc_nodes_at_tail: " << n->DebugString();
    }
  }

  // Record input edges of `oc_nodes_at_tail`. We will create a _Retval node
  // for each of these edges. An obvious optimization here is to deduplicate
  // these edges by <src, src_output>. But that optimization will complicate
  // the code, and in practice we usually do not have input edges with the
  // same <src, src_output>.
  std::vector<const Edge*> oc_input_edges;
  std::vector<DataType> new_ret_types;
  for (Node* n : oc_nodes_at_tail) {
    for (const Edge* e : n->in_edges()) {
      if (!e->IsControlEdge() &&
          !HasNodeAttr(e->src()->def(), kOnlyRetOrOcOutputAttrName)) {
        VLOG(5) << "oc_input_edges: " << e->DebugString();
        oc_input_edges.push_back(e);
        new_ret_types.push_back(e->src()->output_type(e->src_output()));
      }
    }
  }
  std::vector<DataType> output_types;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "output_types", &output_types));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
  int old_num_replicated_outputs = output_types.size() / num_replicas;
  int new_num_replicated_outputs =
      old_num_replicated_outputs + oc_input_edges.size();
  VLOG(5) << "old_num_replicated_outputs: " << old_num_replicated_outputs;
  VLOG(5) << "new_num_replicated_outputs: " << new_num_replicated_outputs;

  // Update `output_types` attribute for `xla_node`.
  std::vector<DataType> new_output_types;
  for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
    for (int i = 0; i < old_num_replicated_outputs; i++) {
      new_output_types.push_back(output_types[i]);
    }
    for (int i = old_num_replicated_outputs; i < new_num_replicated_outputs;
         i++) {
      new_output_types.push_back(new_ret_types[i - old_num_replicated_outputs]);
    }
  }
  xla_node->ClearAttr("output_types");
  xla_node->AddAttr("output_types", new_output_types);

  // Re-order old replicated output edges. Since a node could potentially
  // connect to multiple nodes, build a vector<vector<pair>> mapping of
  // output index to input nodes/index.
  // The outer vector represents the output index, the inner vector
  // represents the destination node and input index pair with the possibility
  // of multiple node/index pairs.
  std::vector<std::vector<std::pair<Node*, int>>> replicated_outputs(
      old_num_replicated_outputs * num_replicas);
  std::vector<const Edge*> old_replicated_edges;
  for (const Edge* e : xla_node->out_edges()) {
    if (e->src_output() >= 0 &&
        e->src_output() < old_num_replicated_outputs * num_replicas) {
      replicated_outputs[e->src_output()].push_back(
          std::make_pair(e->dst(), e->dst_input()));
      old_replicated_edges.push_back(e);
    }
  }
  for (const Edge* e : old_replicated_edges) {
    g->RemoveEdge(e);
  }
  for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
    for (int output_index = 0; output_index < old_num_replicated_outputs;
         output_index++) {
      for (const auto& node_input_pair :
           replicated_outputs[replica_id * old_num_replicated_outputs +
                              output_index]) {
        Node* dst = node_input_pair.first;
        int dst_input = node_input_pair.second;
        g->AddEdge(xla_node,
                   replica_id * new_num_replicated_outputs + output_index, dst,
                   dst_input);
      }
    }
  }

  // Copy all nodes in `oc_nodes_at_tail` to host graph, and also replicate
  // them.
  std::map<Node*, std::vector<Node*>> node_images;
  for (Node* n : oc_nodes_at_tail) {
    for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
      NodeDef copy_def = n->def();
      copy_def.set_name(absl::StrCat(n->name(), "_tail_oc/R", replica_id));
      copy_def.clear_device();

      Status s;
      Node* copy_node = g->AddNode(copy_def, &s);
      TF_RETURN_IF_ERROR(s);

      copy_node->AddAttr(kXlaReplicaIdAttrName, replica_id);
      copy_node->AddAttr(kTPUReplicateAttr, cluster_name);

      for (const Edge* e : n->out_edges()) {
        if (e->dst() == xla_graph->sink_node()) {
          continue;
        }
        // Either e->dst() is _Retval, or it's in `node_images`.
        if (e->dst()->IsRetval()) {
          int index;
          TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->attrs(), "index", &index));
          for (const auto& output :
               replicated_outputs[replica_id * old_num_replicated_outputs +
                                  index]) {
            // Remove the original input edge, if it exists.
            const Edge* original_edge;
            Status s = output.first->input_edge(output.second, &original_edge);
            if (s.ok()) {
              g->RemoveEdge(original_edge);
            }
            g->AddEdge(copy_node, e->src_output(), output.first, output.second);
          }
        } else {
          g->AddEdge(copy_node, e->src_output(),
                     node_images[e->dst()][replica_id], e->dst_input());
        }
      }

      // Add attribute "_xla_tail_outside_compilation" to `copy_node`, and add a
      // control edge between `xla_node` and `copy_node`. As a result, in later
      // rewriting pass, a control edge will be added between `copy_node` and
      // "control_after" node for the XLA computation, so `copy_node` will be
      // executed before XLA computation's final results.
      copy_node->AddAttr("_xla_tail_outside_compilation", true);
      g->AddControlEdge(xla_node, copy_node);

      // Add control edge between `pivot_node` and `copy_node`, so `copy_node`
      // belongs to same while loop as `xla_node`.
      if (pivot_node) {
        g->AddControlEdge(pivot_node, copy_node);
      }

      node_images[n].push_back(copy_node);
    }
  }

  // Connect new output values of `xla_node` to dst nodes of `oc_input_edges`.
  for (int i = 0; i < new_ret_types.size(); i++) {
    const Edge* original_edge = oc_input_edges[i];
    for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
      int src_output = replica_id * new_num_replicated_outputs +
                       old_num_replicated_outputs + i;
      Node* dst = node_images[original_edge->dst()][replica_id];
      g->AddEdge(xla_node, src_output, dst, original_edge->dst_input());
    }
  }

  // Create new _Retval nodes in `xla_graph`.
  for (int i = old_num_replicated_outputs; i < new_num_replicated_outputs;
       i++) {
    NodeDefBuilder ret_builder(absl::StrCat("ret_", i),
                               FunctionLibraryDefinition::kRetOp);
    ret_builder.Attr("T", new_ret_types[i - old_num_replicated_outputs]);
    ret_builder.Attr("index", i);
    const Edge* original_edge = oc_input_edges[i - old_num_replicated_outputs];
    Node* src = original_edge->src();
    int src_output = original_edge->src_output();
    ret_builder.Input(src->name(), src_output, src->output_type(src_output));
    NodeDef ret_def;
    TF_RETURN_IF_ERROR(ret_builder.Finalize(&ret_def));
    Status s;
    Node* ret_node = xla_graph->AddNode(ret_def, &s);
    TF_RETURN_IF_ERROR(s);
    xla_graph->RemoveEdge(original_edge);
    xla_graph->AddEdge(src, src_output, ret_node, 0);
  }

  // Remove `oc_nodes_at_tail`.
  for (Node* n : oc_nodes_at_tail) {
    xla_graph->RemoveNode(n);
  }

  // We cannot leave _Retval with no input. Add a placeholder input, which will
  // be removed later with unused _Retval.
  std::vector<Node*> unused_rets;
  for (Node* n : xla_graph->nodes()) {
    if (n->IsRetval() && n->in_edges().empty()) {
      unused_rets.push_back(n);
    }
  }
  for (Node* n : unused_rets) {
    NodeDefBuilder builder(absl::StrCat("placeholder_", n->name()),
                           "Placeholder");
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
    builder.Attr("dtype", dtype);
    builder.Attr(kXlaIsPlaceholderForTailOcAttrName, true);
    NodeDef def;
    TF_RETURN_IF_ERROR(builder.Finalize(&def));
    Status s;
    Node* placeholder = xla_graph->AddNode(def, &s);
    TF_RETURN_IF_ERROR(s);
    xla_graph->AddEdge(placeholder, 0, n, 0);
  }

  VLOG(4) << "MoveTailOutsideCompilationToHost host graph: "
          << DumpGraphToFile(absl::StrCat("move_tail_oc_host_", xla_func_name),
                             *g);
1233 VLOG(4) << "MoveTaildOutsideCompilationToHost XLA graph: "
          << DumpGraphToFile(absl::StrCat("move_tail_oc_xla_", xla_func_name),
                             *xla_graph);

  return Status::OK();
}

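// Replaces uses of an _Arg node inside outside compilation regions of
// `xla_graph` with Placeholder nodes, and records the corresponding host-side
// inputs with an IdentityN node in `g`. Only DT_RESOURCE args are handled for
// now (see the TODO inside).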
Status ReplaceArgUsedByOutsideCompilationWithPlaceholder(
    const string& outside_compilation_attr_name, const string& xla_func_name,
    Graph* g, Graph* xla_graph, Node* xla_node) {
  std::vector<DataType> input_types;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "Tinputs", &input_types));
  int num_distributed_vars;
  TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
                                 &num_distributed_vars));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->attrs(), "num_replicas", &num_replicas));
  int num_per_replica_inputs =
      (input_types.size() - num_distributed_vars) / num_replicas;

  for (Node* n : xla_graph->op_nodes()) {
    if (!n->IsArg()) {
      continue;
    }

    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype));
    // TODO(b/74023706): enable moving normal data tensors.
    if (dtype != DT_RESOURCE) {
      continue;
    }

    std::vector<const Edge*> oc_out_edges;
    for (const Edge* e : n->out_edges()) {
      if (e->IsControlEdge() ||
          !HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
        continue;
      }

      oc_out_edges.push_back(e);
    }
    if (oc_out_edges.empty()) {
      continue;
    }

    // Sometimes `xla_node` can have a lot of inputs; calling Node::input_edge
    // for each of them becomes very expensive because it does a linear search
    // internally. Build an input_edges vector ahead of time to make the
    // lookups faster.
    std::vector<const Edge*> input_edges;
    TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));

    // Build an IdentityN node to record inputs for this _Arg node.
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    string oc_identifier = absl::StrCat("oc_only_arg_", index);
    NodeDefBuilder id_builder(absl::StrCat(oc_identifier, "_inputs"),
                              "IdentityN");
    std::vector<DataType> dtypes(num_replicas, dtype);
    id_builder.Attr("T", dtypes);
    id_builder.Attr(kXlaOutsideCompilationInputsAttrName, oc_identifier);
    std::vector<NodeDefBuilder::NodeOut> inputs(num_replicas);
    if (index >= num_per_replica_inputs) {
      const Edge* e = input_edges.at(num_replicas * num_per_replica_inputs +
                                     (index - num_per_replica_inputs));
      for (int i = 0; i < num_replicas; i++) {
        inputs[i] =
            NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
                                    e->src()->output_type(e->src_output())};
      }
    } else {
      for (int i = 0; i < num_replicas; i++) {
        const Edge* e = input_edges.at(i * num_per_replica_inputs + index);
        inputs[i] =
            NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(),
                                    e->src()->output_type(e->src_output())};
      }
    }
    id_builder.Input(inputs);
    NodeDef id_def;
    TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
    Status s;
    Node* id_node = g->AddNode(id_def, &s);
    TF_RETURN_IF_ERROR(s);
    if (index >= num_per_replica_inputs) {
      const Edge* e = input_edges.at(num_replicas * num_per_replica_inputs +
                                     (index - num_per_replica_inputs));
      for (int i = 0; i < num_replicas; i++) {
        g->AddEdge(e->src(), e->src_output(), id_node, i);
      }
    } else {
      for (int i = 0; i < num_replicas; i++) {
        const Edge* e = input_edges.at(i * num_per_replica_inputs + index);
        g->AddEdge(e->src(), e->src_output(), id_node, i);
      }
    }

    for (const Edge* e : oc_out_edges) {
      // 'e' will use a new Placeholder node as input.
      NodeDefBuilder ph_builder(xla_graph->NewName("ph_for_arg_in_oc_"),
                                "Placeholder");
      ph_builder.Attr("dtype", dtype);

      string outside_compilation_attr;
      TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), kOutsideCompilationAttr,
                                     &outside_compilation_attr));
      ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_attr);
      ph_builder.Attr(kXlaOutsideCompilationInputsAttrName, oc_identifier);
      ph_builder.Attr(kXlaIsPlaceholderForArg, true);
      NodeDef ph_def;
      TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));
      Status s;
      Node* ph_node = xla_graph->AddNode(ph_def, &s);
      TF_RETURN_IF_ERROR(s);
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      xla_graph->RemoveEdge(e);
      xla_graph->AddEdge(ph_node, 0, dst, dst_input);
      xla_graph->AddControlEdge(xla_graph->source_node(), ph_node);
    }
  }
1355 VLOG(4) << "ReplaceOutsideCompilationOnlyArgWithPlaceholder host graph: "
1356 << DumpGraphToFile(
1357 absl::StrCat("replace_oc_only_arg_host_", xla_func_name), *g);
1358 VLOG(4) << "ReplaceOutsideCompilationOnlyArgWithPlaceholder XLA graph: "
1359 << DumpGraphToFile(
1360 absl::StrCat("replace_oc_only_arg_xla_", xla_func_name),
1361 *xla_graph);
  return Status::OK();
}

// If there are any unused _Retval nodes in `xla_graph` (whose input is a
// Placeholder node), remove them from `xla_graph` and remove the
// corresponding output edges in host graph `g`.
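// For example (hypothetical): with 3 replicated outputs per replica where
// ret #1 is fed only by a tail outside compilation placeholder, the pass
// drops ret #1 and remaps {0->0, 2->1}, so each replica keeps 2 outputs and
// the host-side output edges are renumbered accordingly.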
Status RemoveUnusedXlaOutput(const string& xla_func_name, Graph* g,
                             Graph* xla_graph, Node* xla_node) {
  // Find unused _Retval nodes, and remove them.
  std::vector<DataType> output_types;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "output_types", &output_types));
  int num_replicas;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
  int num_replicated_outputs = output_types.size() / num_replicas;
  std::set<int> ret_indices_to_remove;
  std::vector<Node*> ret_nodes_to_update, nodes_to_remove;
  int num_rets = 0;
  for (Node* n : xla_graph->nodes()) {
    if (!n->IsRetval()) {
      continue;
    }

    num_rets++;

    const Edge* e;
    TF_RETURN_IF_ERROR(n->input_edge(0, &e));
    if (e->src()->type_string() != "Placeholder" ||
        !HasNodeAttr(e->src()->def(), kXlaIsPlaceholderForTailOcAttrName)) {
      ret_nodes_to_update.push_back(n);
      continue;
    }

    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    ret_indices_to_remove.insert(index);
    nodes_to_remove.push_back(e->src());
    nodes_to_remove.push_back(n);
  }
  for (Node* n : nodes_to_remove) {
    xla_graph->RemoveNode(n);
  }

  // Update `index` for other _Retval nodes.
  std::map<int, int> ret_index_mapping;
  int new_ret_index = 0;
  for (int i = 0; i < num_rets; i++) {
    if (ret_indices_to_remove.find(i) != ret_indices_to_remove.end()) {
      continue;
    } else {
      ret_index_mapping[i] = new_ret_index;
      new_ret_index++;
    }
  }
  for (Node* n : ret_nodes_to_update) {
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
    n->ClearAttr("index");
    n->AddAttr("index", ret_index_mapping[index]);
  }

  // Update `output_types` attribute for `xla_node`.
  std::vector<DataType> new_output_types;
  for (int i = 0; i < num_replicas; i++) {
    for (const auto& e : ret_index_mapping) {
      new_output_types.push_back(output_types[e.first]);
    }
  }

  xla_node->ClearAttr("output_types");
  xla_node->AddAttr("output_types", new_output_types);

  // Re-order replicated output edges for `xla_node`.
  std::vector<std::vector<const Edge*>> output_edges(num_replicas *
                                                     num_replicated_outputs);
  for (const Edge* e : xla_node->out_edges()) {
    if (e->src_output() >= 0 &&
        e->src_output() < num_replicas * num_replicated_outputs) {
      output_edges[e->src_output()].push_back(e);
    }
  }
  for (int i = 0; i < num_replicas; i++) {
    for (int j = 0; j < num_replicated_outputs; j++) {
      auto iter = ret_index_mapping.find(j);
      if (iter != ret_index_mapping.end()) {
        for (const Edge* e : output_edges[i * num_replicated_outputs + j]) {
          Node* dst = e->dst();
          int dst_input = e->dst_input();
          int src_output =
              i * (num_replicated_outputs - ret_indices_to_remove.size()) +
              iter->second;
          g->RemoveEdge(e);
          g->AddEdge(xla_node, src_output, dst, dst_input);
        }
      } else {
        TF_RET_CHECK(output_edges[i * num_replicated_outputs + j].empty())
            << "Output edge not removed: "
            << output_edges[i * num_replicated_outputs + j][0]->DebugString();
      }
    }
  }

  VLOG(4) << "RemoveUnusedXlaOutput host graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_output_host_", xla_func_name), *g);
  VLOG(4) << "RemoveUnusedXlaOutput XLA graph: "
          << DumpGraphToFile(
                 absl::StrCat("remove_unused_output_xla_", xla_func_name),
                 *xla_graph);

  return Status::OK();
}

1476 // For data edges between _Arg and _Retval in `xla_graph`, remove them and
1477 // change input/output edges in `g` (host graph). For now, we only consider
1478 // replicated inputs.
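// For example, if replicated arg k feeds _Retval m directly inside
// `xla_graph`, then in `g` every consumer of replica r's output m is rewired
// to the source of replica r's input k, bypassing the replicate node; the
// now-dead _Retval is fed from a placeholder and cleaned up later by
// RemoveUnusedXlaOutput().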
1479 Status RemoveEdgesBetweenArgAndRetval(const string& xla_func_name, Graph* g,
1480 Graph* xla_graph, Node* xla_node) {
1481 // Collect data edges between _Arg and _Retval.
1482 int num_replicas;
1483 TF_RETURN_IF_ERROR(
1484 GetNodeAttr(xla_node->def(), "num_replicas", &num_replicas));
1485 std::vector<DataType> input_types;
1486 TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->def(), "Tinputs", &input_types));
1487 int num_distributed_vars;
1488 TF_RETURN_IF_ERROR(GetNodeAttr(xla_node->attrs(), "num_distributed_variables",
1489 &num_distributed_vars));
1490 int old_num_per_replica_inputs =
1491 (input_types.size() - num_distributed_vars) / num_replicas;
1492 std::vector<DataType> output_types;
1493 TF_RETURN_IF_ERROR(
1494 GetNodeAttr(xla_node->def(), "output_types", &output_types));
1495 int old_num_outputs = output_types.size() / num_replicas;
1496 std::vector<const Edge*> edges;
1497 for (const Edge* e : xla_graph->edges()) {
1498 if (!e->IsControlEdge() && e->src()->IsArg() && e->dst()->IsRetval()) {
1499 edges.push_back(e);
1500 }
1501 }
1502
1503   // In the host graph `g`, remove output edges from `xla_node` and connect
1504   // inputs & outputs directly.
1505 std::vector<std::vector<const Edge*>> xla_node_out_edges(
1506 xla_node->num_outputs());
1507 for (const Edge* e : xla_node->out_edges()) {
1508 if (!e->IsControlEdge()) {
1509 xla_node_out_edges[e->src_output()].push_back(e);
1510 }
1511 }
1512
1513   // Sometimes `xla_node` can have a lot of inputs; calling Node::input_edge
1514   // would then be very expensive because it does a linear search
1515   // internally. Build an input_edges vector ahead of time to make the
1516   // lookups faster.
1517 std::vector<const Edge*> input_edges;
1518 TF_RETURN_IF_ERROR(xla_node->input_edges(&input_edges));
1519 for (const Edge* e : edges) {
1520 int arg_index;
1521 TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->def(), "index", &arg_index));
1522 int ret_index;
1523 TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), "index", &ret_index));
1524
1525 for (int replica_id = 0; replica_id < num_replicas; replica_id++) {
1526 int input_index;
1527 if (arg_index < old_num_per_replica_inputs) {
1528 input_index = replica_id * old_num_per_replica_inputs + arg_index;
1529 } else {
1530 input_index = num_replicas * old_num_per_replica_inputs +
1531 (arg_index - old_num_per_replica_inputs);
1532 }
1533 const Edge* input_edge = input_edges.at(input_index);
1534
1535 int output_index = replica_id * old_num_outputs + ret_index;
1536 for (const Edge* output_edge : xla_node_out_edges[output_index]) {
1537 Node* dst = output_edge->dst();
1538 int dst_input = output_edge->dst_input();
1539
1540 g->RemoveEdge(output_edge);
1541 g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
1542 }
1543 }
1544 }
1545
1546   // Remove edges from `xla_graph`. Add a Placeholder node for each _Retval
1547   // node, which will be removed by `RemoveUnusedXlaOutput()` later.
1548 for (const Edge* e : edges) {
1549 NodeDefBuilder placeholder_builder(
1550 absl::StrCat("placeholder_", e->dst()->name()), "Placeholder");
1551 placeholder_builder.Attr("dtype", e->src()->output_type(e->src_output()));
1552 placeholder_builder.Attr(kXlaIsPlaceholderForTailOcAttrName, true);
1553 NodeDef placeholder_def;
1554 TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
1555 Status s;
1556 Node* placeholder_node = xla_graph->AddNode(placeholder_def, &s);
1557 TF_RETURN_IF_ERROR(s);
1558
1559 Node* dst = e->dst();
1560 int dst_input = e->dst_input();
1561 xla_graph->RemoveEdge(e);
1562 xla_graph->AddEdge(placeholder_node, 0, dst, dst_input);
1563 }
1564
1565 VLOG(4) << "RemoveUnusedArgRetvalPair host graph: "
1566 << DumpGraphToFile(
1567 absl::StrCat("remove_unused_arg_ret_host_", xla_func_name),
1568 *g);
1569 VLOG(4) << "RemoveUnusedArgRetvalPair XLA graph: "
1570 << DumpGraphToFile(
1571 absl::StrCat("remove_unused_arg_ret_xla_", xla_func_name),
1572 *xla_graph);
1573
1574 return Status::OK();
1575 }
1576
1577 // Remove any TPUReplicatedInput nodes with no output edges. Those nodes are
1578 // usually TPUMirroredVariable handles which are not used by any computations.
1579 void RemoveUnusedTPUReplicatedInputs(Graph* graph) {
1580 for (Node* n : graph->nodes()) {
1581 if (n->type_string() == kTPUReplicatedInput) {
1582 bool has_output = false;
1583 for (const Edge* e : n->out_edges()) {
1584 if (!e->dst()->IsSink()) {
1585 has_output = true;
1586 break;
1587 }
1588 }
1589 if (!has_output) {
1590         // Remove any TPUPartitionedInput nodes from the src nodes of the
1591         // to-be-removed TPUReplicatedInput node.
1592 std::vector<Node*> to_be_removed_src_nodes;
1593 for (const auto& e_in : n->in_edges()) {
1594 if (!e_in->IsControlEdge() &&
1595 e_in->src()->type_string() == kTPUPartitionedInput)
1596 to_be_removed_src_nodes.push_back(e_in->src());
1597 }
1598 graph->RemoveNode(n);
1599 for (Node* node : to_be_removed_src_nodes) {
1600 graph->RemoveNode(node);
1601 }
1602 }
1603 }
1604 }
1605 }
1606
1607 // We might have duplicated cluster names in the graph, e.g. when a tf.function
1608 // containing tpu_strategy.run() is called multiple times with
1609 // the same inputs. Find clusters with duplicated names and rename them.
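// For example, if two TPUReplicateMetadata nodes both carry
// _tpu_replicate = "cluster", the second cluster is renamed to "cluster_1"
// (the first unused suffix), and the new name is propagated along output
// edges to every node of that cluster.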
1610 Status RenameClustersWithDuplicatedNames(Graph* g) {
1611 // Find all TPU clusters by finding all TPUReplicateMetadata nodes.
1612 std::unordered_map<string, std::vector<Node*>> cluster_name_to_metadata_nodes;
1613 std::unordered_set<string> cluster_names;
1614 for (Node* n : g->nodes()) {
1615 if (n->type_string() != "TPUReplicateMetadata") {
1616 continue;
1617 }
1618 string cluster_name;
1619 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &cluster_name));
1620 cluster_name_to_metadata_nodes[cluster_name].push_back(n);
1621 cluster_names.insert(cluster_name);
1622 }
1623   // Look for clusters with duplicated names.
1624 for (const auto& iter : cluster_name_to_metadata_nodes) {
1625 if (iter.second.size() == 1) {
1626 continue;
1627 }
1628
1629 // Rename clusters.
1630 for (int i = 1; i < iter.second.size(); i++) {
1631 // Find an available cluster name.
1632 string new_cluster_name;
1633 int cluster_name_suffix = 1;
1634 while (true) {
1635 new_cluster_name = absl::StrCat(iter.first, "_", cluster_name_suffix);
1636 if (cluster_names.find(new_cluster_name) == cluster_names.end()) {
1637 break;
1638 }
1639 cluster_name_suffix++;
1640 }
1641 cluster_names.insert(new_cluster_name);
1642
1643 // Change _tpu_replicate attribute for all nodes in this cluster.
1644 // Start with outputs of TPUReplicateMetadata and follow output edges.
1645 std::queue<Node*> queue;
1646 queue.push(iter.second.at(i));
1647 std::unordered_set<Node*> visited;
1648 while (!queue.empty()) {
1649 Node* n = queue.front();
1650 queue.pop();
1651
1652 visited.insert(n);
1653
1654 n->ClearAttr(kTPUReplicateAttr);
1655 n->AddAttr(kTPUReplicateAttr, new_cluster_name);
1656
1657 string cluster_name;
1658 for (const Edge* e : n->out_edges()) {
1659 if (GetNodeAttr(e->dst()->def(), kTPUReplicateAttr, &cluster_name)
1660 .ok() &&
1661 cluster_name == iter.first &&
1662 visited.find(e->dst()) == visited.end()) {
1663 queue.push(e->dst());
1664 }
1665 }
1666 }
1667 // Change "_tpu_compilation_status" attr for TPUCompilationResult node.
1668 for (const Edge* e : iter.second.at(i)->out_edges()) {
1669 if (e->dst()->type_string() == "TPUCompilationResult") {
1670 e->dst()->ClearAttr("_tpu_compilation_status");
1671 e->dst()->AddAttr("_tpu_compilation_status", new_cluster_name);
1672 }
1673 }
1674 }
1675 }
1676 return Status::OK();
1677 }
1678
1679 // Instantiate a function that is associated with a functional control flow
1680 // node. The function name is found by looking up `function_name_attr` of the
1681 // given node.
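// Usage sketch, mirroring the call sites below:
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> body_fbody,
//                       InstantiateAssociatedFunction(*while_node, "body", fld));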
1682 xla::StatusOr<std::unique_ptr<FunctionBody>> InstantiateAssociatedFunction(
1683 const Node& n, absl::string_view function_name_attr,
1684 FunctionLibraryDefinition* fld) {
1685 std::unique_ptr<FunctionBody> fbody;
1686 NameAttrList func_attr_list;
1687 TF_RETURN_IF_ERROR(GetNodeAttr(n.def(), function_name_attr, &func_attr_list));
1688 const FunctionDef* fdef = fld->Find(func_attr_list.name());
1689 if (fdef == nullptr) {
1690 return errors::Internal("Cannot find ", function_name_attr, " function",
1691 "for node ", n.DebugString());
1692 }
1693 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1694 *fdef, AttrSlice(&func_attr_list.attr()), fld, &fbody));
1695 return fbody;
1696 }
1697
1698 // Finds the inputs of an If node that are used only for outside compilation
1699 // (if they are used at all) in both the then/else branches.
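// For example, a DT_RESOURCE input whose only data consumers in both
// branches carry the outside compilation attribute is lifted; an input read
// by a regular (non-outside-compiled) op in either branch, or one that is
// not used at all, is left in place.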
1700 xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForIfNode(
1701 const Node& if_node, FunctionLibraryDefinition* fld) {
1702 absl::flat_hash_set<int> args_to_lift_indices;
1703 std::vector<DataType> dtypes;
1704 TF_RETURN_IF_ERROR(GetNodeAttr(if_node.def(), "Tin", &dtypes));
1705
1706 int num_args = dtypes.size();
1707
1708 for (int i = 0; i < num_args; i++) {
1709 // TODO(b/74023706): enable non resource inputs as well.
1710 if (dtypes[i] == DT_RESOURCE) {
1711 args_to_lift_indices.insert(i);
1712 }
1713 }
1714
1715 TF_ASSIGN_OR_RETURN(
1716 std::unique_ptr<FunctionBody> then_branch_fbody,
1717 InstantiateAssociatedFunction(if_node, "then_branch", fld));
1718
1719 TF_ASSIGN_OR_RETURN(
1720 std::unique_ptr<FunctionBody> else_branch_fbody,
1721 InstantiateAssociatedFunction(if_node, "else_branch", fld));
1722
1723 for (int i = 0; i < num_args; ++i) {
1724 bool used = false;
1725
1726 const Node* then_arg_node = then_branch_fbody->arg_nodes[i];
1727 for (const Edge* e : then_arg_node->out_edges()) {
1728 used = true;
1729 if (e->IsControlEdge() ||
1730 HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr))
1731 continue;
1732
1733 args_to_lift_indices.erase(i);
1734 break;
1735 }
1736
1737 const Node* else_arg_node = else_branch_fbody->arg_nodes[i];
1738 for (const Edge* e : else_arg_node->out_edges()) {
1739 used = true;
1740 if (e->IsControlEdge() ||
1741 HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr))
1742 continue;
1743
1744 args_to_lift_indices.erase(i);
1745 break;
1746 }
1747
1748     // Do not lift arguments that are not used at all. Otherwise, such an
1749     // unused arg would be outside compiled, and its output tensor would be
1750     // transferred to the host needlessly.
1751 if (!used) args_to_lift_indices.erase(i);
1752 }
1753
1754 return args_to_lift_indices;
1755 }
1756
1757 // Find inputs of While node that are:
1758 // 1. not used in cond func,
1759 // 2. only used for outside compilation in body func,
1760 // 3. loop invariant.
1761 // These inputs can be lifted out of the while loop.
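// For example, a resource variable handle that the cond func never reads,
// that the body passes straight through to its _Retval (possibly via
// Identity nodes), and whose only other users in the body are outside
// compiled ops satisfies all three conditions and is lifted.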
1762 xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForWhileNode(
1763 Node* while_node, FunctionLibraryDefinition* fld) {
1764 // DT_RESOURCE inputs are candidates.
1765 absl::flat_hash_set<int> result;
1766 std::vector<DataType> dtypes;
1767 TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "T", &dtypes));
1768 for (int i = 0; i < dtypes.size(); i++) {
1769 // TODO(b/74023706): enable non resource inputs as well.
1770 if (dtypes[i] == DT_RESOURCE) {
1771 result.insert(i);
1772 }
1773 }
1774
1775 // Remove inputs that are used in cond func.
1776 NameAttrList cond_func;
1777 TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "cond", &cond_func));
1778 const FunctionDef* cond_fdef = fld->Find(cond_func.name());
1779 if (cond_fdef == nullptr) {
1780 return errors::Internal("Cannot find cond function ", cond_func.name(),
1781 " for while node ", while_node->DebugString());
1782 }
1783 std::unique_ptr<FunctionBody> cond_fbody;
1784 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1785 *cond_fdef, AttrSlice(&cond_func.attr()), fld, &cond_fbody));
1786 for (int i = 0; i < cond_fbody->arg_nodes.size(); i++) {
1787 const Node* arg_node = cond_fbody->arg_nodes[i];
1788 for (const Edge* e : arg_node->out_edges()) {
1789 if (!e->IsControlEdge()) {
1790 result.erase(i);
1791 }
1792 }
1793 }
1794
1795 // Remove inputs that are not loop invariant.
1796 NameAttrList body_func;
1797 TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_func));
1798 const FunctionDef* body_fdef = fld->Find(body_func.name());
1799 if (body_fdef == nullptr) {
1800 return errors::Internal("Cannot find body function ", body_func.name(),
1801 " for while node ", while_node->DebugString());
1802 }
1803 std::unique_ptr<FunctionBody> body_fbody;
1804 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1805 *body_fdef, AttrSlice(&body_func.attr()), fld, &body_fbody));
1806 for (int i = 0; i < body_fbody->ret_nodes.size(); i++) {
1807 const Node* node = body_fbody->ret_nodes[i];
1808 do {
1809 TF_RETURN_IF_ERROR(node->input_node(0, &node));
1810 } while (node->IsIdentity());
1811 if (node != body_fbody->arg_nodes[i]) {
1812 result.erase(i);
1813 }
1814 }
1815
1816   // Remove inputs whose only data edge is the loop pass-through itself
1817   // (loop invariant, but not used in outside compilation).
1818 for (int i = 0; i < body_fbody->arg_nodes.size(); i++) {
1819 const Node* arg_node = body_fbody->arg_nodes[i];
1820 int data_edge_count = std::count_if(
1821 arg_node->out_edges().begin(), arg_node->out_edges().end(),
1822 [](const Edge* e) { return !e->IsControlEdge(); });
1823 if (data_edge_count == 1) {
1824 result.erase(i);
1825 }
1826 }
1827
1828 // Remove inputs that have non-outside-compilation usage.
1829 for (int i = 0; i < body_fbody->arg_nodes.size(); i++) {
1830 const Node* arg_node = body_fbody->arg_nodes[i];
1831 for (const Edge* e : arg_node->out_edges()) {
1832 if (!e->dst()->IsRetval() &&
1833 !HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
1834 result.erase(i);
1835 break;
1836 }
1837 }
1838 }
1839
1840 return result;
1841 }
1842
1843 // Finds inputs of a function call node that are used only for outside
1844 // compilation. These inputs can be lifted out of the function call node.
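// For example, a DT_RESOURCE input is lifted only if every consumer of the
// corresponding _Arg carries the outside compilation attribute; args with no
// users are skipped so that no needless host transfer is introduced.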
1845 xla::StatusOr<absl::flat_hash_set<int>> FindArgsToLiftForCallNode(
1846 Node* call_node, const FunctionBody& fbody) {
1847 // DT_RESOURCE inputs are candidates.
1848 absl::flat_hash_set<int> result;
1849 std::vector<DataType> dtypes(call_node->input_types().begin(),
1850 call_node->input_types().end());
1851 for (int i = 0; i < dtypes.size(); i++) {
1852 // TODO(b/74023706): enable for non resource inputs as well.
1853 if (dtypes[i] == DT_RESOURCE) {
1854 result.insert(i);
1855 }
1856 }
1857
1858   // Remove inputs that have non-outside-compilation usage, or are not used
  // at all.
1859 for (int i = 0; i < fbody.arg_nodes.size(); i++) {
1860 const Node* arg_node = fbody.arg_nodes[i];
1861 if (arg_node->out_edges().empty()) {
1862 result.erase(i);
1863 continue;
1864 }
1865
1866 for (const Edge* e : arg_node->out_edges()) {
1867 if (!HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
1868 result.erase(i);
1869 break;
1870 }
1871 }
1872 }
1873 return result;
1874 }
1875
1876 Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr,
1877 FunctionLibraryDefinition* fld,
1878 int* lifted_arg_count, bool* rewritten);
1879
1880 Status LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
1881 const FunctionBody& fbody, FunctionLibraryRuntime* flr,
1882 FunctionLibraryDefinition* fld, int* lifted_arg_count,
1883 absl::optional<string> new_func_name, bool* rewritten) {
1884 *rewritten = false;
1885 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgs(
1886 fbody.graph, flr, fld, lifted_arg_count, rewritten));
1887
1888 if (*rewritten) {
1889 FunctionDef rewritten_fdef;
1890 TF_RETURN_IF_ERROR(GraphToFunctionDef(
1891 *(fbody.graph), fbody.fdef.signature().name(), &rewritten_fdef));
1892 if (new_func_name) {
1893 rewritten_fdef.mutable_signature()->set_name(*new_func_name);
1894 TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
1895 } else {
1896 TF_RETURN_IF_ERROR(
1897 fld->ReplaceFunction(fbody.fdef.signature().name(), rewritten_fdef));
1898 }
1899 }
1900
1901 return Status::OK();
1902 }
1903
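// For each arg in `args_to_lift`, adds an outside compiled Identity node fed
// by the arg's source tensor; each Identity forms its own outside compilation
// cluster named after the node, and the mapping from arg index to that
// cluster name is recorded in `lifted_arg_index_to_oc_cluster_name`.
// `arg_to_input_edge_offset` is the offset from arg index to input edge index
// (e.g. 1 for If nodes, whose input 0 is the predicate).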
1904 Status MakeIdentityNodesForArgsToLift(
1905 const absl::flat_hash_set<int>& args_to_lift,
1906 const int arg_to_input_edge_offset, Graph* g, Node* n,
1907 absl::flat_hash_map<int, string>* lifted_arg_index_to_oc_cluster_name,
1908 int* lifted_arg_count) {
1909 int num_input = n->num_inputs();
1910 for (int arg_index = 0; arg_index < num_input; ++arg_index) {
1911 if (!args_to_lift.contains(arg_index)) continue;
1912
1913 int input_edge_index = arg_index + arg_to_input_edge_offset;
1914 const Edge* arg_edge;
1915 TF_RETURN_IF_ERROR(n->input_edge(input_edge_index, &arg_edge));
1916
1917 string node_name =
1918 g->NewName(absl::StrCat("lifted_arg", *lifted_arg_count));
1919 (*lifted_arg_count)++;
1920 (*lifted_arg_index_to_oc_cluster_name)[arg_index] = node_name;
1921 NodeDefBuilder id_builder(node_name, "Identity");
1922 id_builder.Attr("T", n->input_type(input_edge_index));
1923 id_builder.Attr(kOutsideCompilationAttr, id_builder.node_name());
1924 id_builder.Attr(kXlaIsLiftedArgAttrName, true);
1925 id_builder.Input(arg_edge->src()->name(), arg_edge->src_output(),
1926 n->input_type(input_edge_index));
1927 NodeDef id_def;
1928 TF_RETURN_IF_ERROR(id_builder.Finalize(&id_def));
1929 Status s;
1930 Node* id_node = g->AddNode(id_def, &s);
1931 TF_RETURN_IF_ERROR(s);
1932 g->AddEdge(arg_edge->src(), arg_edge->src_output(), id_node, 0);
1933 g->AddControlEdge(id_node, n);
1934 }
1935
1936 return Status::OK();
1937 }
1938
1939 // Replaces all usages of lifted args with placeholder nodes. Afterwards,
1940 // removing these args should be safe since they no longer have users.
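// Each placeholder carries kXlaLiftedArgOutsideCompilationAttrName pointing
// at the Identity node created by MakeIdentityNodesForArgsToLift(), so later
// outside compilation rewriting can match the placeholder with the lifted
// value it stands for.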
1941 Status RemoveArgsToLiftFromFunctionBody(
1942 const absl::flat_hash_set<int>& args_to_lift,
1943 const std::vector<DataType>& arg_dtypes,
1944 const absl::flat_hash_map<int, string>& lifted_arg_index_to_oc_cluster_name,
1945 const absl::flat_hash_map<int, int>& index_mapping,
1946 const FunctionBody* fbody) {
1947 for (int i = 0; i < fbody->arg_nodes.size(); ++i) {
1948 Node* arg_node = fbody->arg_nodes[i];
1949
1950 if (!args_to_lift.contains(i)) {
1951 int new_index = index_mapping.at(i);
1952 arg_node->ClearAttr("index");
1953 arg_node->AddAttr("index", new_index);
1954 arg_node->ClearAttr("T");
1955 arg_node->AddAttr("T", arg_dtypes[i]);
1956 continue;
1957 }
1958
1959 std::vector<const Edge*> out_edges_to_oc;
1960 for (const Edge* e : arg_node->out_edges()) {
1961 if (HasNodeAttr(e->dst()->def(), kOutsideCompilationAttr)) {
1962 out_edges_to_oc.push_back(e);
1963 }
1964 }
1965
1966 for (const Edge* e : out_edges_to_oc) {
1967 string outside_compilation_cluster;
1968 TF_RETURN_IF_ERROR(GetNodeAttr(e->dst()->def(), kOutsideCompilationAttr,
1969 &outside_compilation_cluster));
1970 NodeDefBuilder ph_builder(fbody->graph->NewName("lifted_arg"),
1971 "Placeholder");
1972 ph_builder.Attr("dtype", arg_dtypes[i]);
1973 ph_builder.Attr(kOutsideCompilationAttr, outside_compilation_cluster);
1974 TF_RET_CHECK(lifted_arg_index_to_oc_cluster_name.contains(i));
1975 ph_builder.Attr(kXlaLiftedArgOutsideCompilationAttrName,
1976 lifted_arg_index_to_oc_cluster_name.at(i));
1977
1978 NodeDef ph_def;
1979 TF_RETURN_IF_ERROR(ph_builder.Finalize(&ph_def));
1980
1981 Status s;
1982 Node* ph_node = fbody->graph->AddNode(ph_def, &s);
1983 TF_RETURN_IF_ERROR(s);
1984
1985 Node* dst = e->dst();
1986 int dst_input = e->dst_input();
1987 fbody->graph->RemoveEdge(e);
1988 fbody->graph->AddEdge(ph_node, 0, dst, dst_input);
1989 }
1990
1991 fbody->graph->RemoveNode(arg_node);
1992 }
1993
1994 return Status::OK();
1995 }
1996
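// Rewires the input edges of `n` after lifting: edges feeding lifted args are
// dropped, and the remaining edges are shifted to their new input ports
// according to `index_mapping`.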
1997 Status CleanUpInEdges(const absl::flat_hash_map<int, int>& index_mapping,
1998 const int arg_to_input_edge_offset, Graph* g, Node* n) {
1999 int num_inputs = n->num_inputs();
2000 for (int i = 0; i < num_inputs; ++i) {
2001 if (i < arg_to_input_edge_offset) continue;
2002
2003 int arg_idx = i - arg_to_input_edge_offset;
2004 const Edge* e;
2005 TF_RETURN_IF_ERROR(n->input_edge(i, &e));
2006
2007     // If an edge maps to a lifted argument, simply remove that edge from the
    // graph.
2008 if (!index_mapping.contains(arg_idx)) {
2009 g->RemoveEdge(e);
2010 continue;
2011 }
2012
2013     // If an edge maps to the same input port, there is nothing to do.
2014 if (index_mapping.at(arg_idx) == arg_idx) continue;
2015
2016 g->AddEdge(e->src(), e->src_output(), n,
2017 index_mapping.at(arg_idx) + arg_to_input_edge_offset);
2018 g->RemoveEdge(e);
2019 }
2020
2021 return Status::OK();
2022 }
2023
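// Rebuilds the list-type attribute `type_attr_name` (e.g. "T" for While
// nodes, "Tin" for If nodes), keeping only the dtypes of args that survive
// lifting.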
2024 Status UpdateTypeAttribute(const absl::flat_hash_map<int, int>& index_mapping,
2025 const string& type_attr_name,
2026 const std::vector<DataType>& dtypes, Node* n) {
2027 std::vector<DataType> new_dtypes;
2028 new_dtypes.reserve(index_mapping.size());
2029 for (int i = 0; i < dtypes.size(); ++i) {
2030 if (index_mapping.contains(i)) {
2031 new_dtypes.emplace_back(dtypes[i]);
2032 }
2033 }
2034
2035 n->ClearAttr(type_attr_name);
2036 n->AddAttr(type_attr_name, new_dtypes);
2037
2038 return Status::OK();
2039 }
2040
2041 // While V2 always creates an Identity node for each While node output, which
2042 // is unnecessary for the XLA computation. Remove those Identity nodes.
2043 void RemoveOutputIdentityNodesForWhileV2(Graph* g, Node* while_node) {
2044 std::vector<const Edge*> edges_to_identity_node;
2045 for (const Edge* e : while_node->out_edges()) {
2046 if (!e->IsControlEdge() && e->dst()->IsIdentity()) {
2047 edges_to_identity_node.push_back(e);
2048 }
2049 }
2050 for (const Edge* e : edges_to_identity_node) {
2051 Node* identity = e->dst();
2052 std::vector<const Edge*> out_edges(identity->out_edges().begin(),
2053 identity->out_edges().end());
2054 for (const Edge* out_edge : out_edges) {
2055 if (out_edge->IsControlEdge()) {
2056 g->AddControlEdge(while_node, out_edge->dst());
2057 } else {
2058 Node* dst = out_edge->dst();
2059 int dst_input = out_edge->dst_input();
2060 g->RemoveEdge(out_edge);
2061 g->AddEdge(while_node, e->src_output(), dst, dst_input);
2062 }
2063 }
2064 g->RemoveNode(identity);
2065 }
2066 }
2067
2068 // If a While node output corresponding to a lifted arg is used, rewire its
2069 // consumers to use the matching While node input instead.
2070 Status ReplaceOutputEdgesWithInputEdgeSourceForWhile(
2071 const absl::flat_hash_set<int>& args_to_lift, Graph* g, Node* while_node) {
2072 std::vector<const Edge*> edges_to_replace;
2073 for (const Edge* e : while_node->out_edges()) {
2074 if (args_to_lift.contains(e->src_output())) {
2075 edges_to_replace.push_back(e);
2076 }
2077 }
2078 for (const Edge* e : edges_to_replace) {
2079 const Edge* input_edge;
2080 TF_RETURN_IF_ERROR(while_node->input_edge(e->src_output(), &input_edge));
2081 Node* dst = e->dst();
2082 int dst_input = e->dst_input();
2083 g->RemoveEdge(e);
2084 g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
2085 }
2086
2087 return Status::OK();
2088 }
2089
2090 // Calculates mapping from argument index before lifting to index afterwards.
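// For example, num_args = 4 with args_to_lift = {1} yields
// {0 -> 0, 2 -> 1, 3 -> 2}.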
2091 absl::flat_hash_map<int, int> ArgIndexMapping(
2092 const int num_args, const absl::flat_hash_set<int>& args_to_lift) {
2093 absl::flat_hash_map<int, int> index_mapping;
2094 int new_index = 0;
2095 for (int i = 0; i < num_args; i++) {
2096 if (!args_to_lift.contains(i)) {
2097 index_mapping[i] = new_index;
2098 ++new_index;
2099 }
2100 }
2101
2102 return index_mapping;
2103 }
2104
2105 // Removes outputs of the While node body function that map to lifted
// arguments.
2106 void CleanUpRetvalsForWhileBody(
2107 const absl::flat_hash_map<int, int>& index_mapping,
2108 const std::vector<DataType>& dtypes, FunctionBody* fbody) {
2109 for (int i = 0; i < fbody->ret_nodes.size(); i++) {
2110 Node* ret_node = fbody->ret_nodes[i];
2111 if (index_mapping.contains(i)) {
2112 int new_index = index_mapping.at(i);
2113 ret_node->ClearAttr("index");
2114 ret_node->AddAttr("index", new_index);
2115 ret_node->ClearAttr("T");
2116 ret_node->AddAttr("T", dtypes[i]);
2117 } else {
2118 fbody->graph->RemoveNode(ret_node);
2119 }
2120 }
2121 }
2122
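// Lifts outside-compilation-only args out of a While node: consumers of the
// corresponding While outputs are redirected to the matching While inputs,
// outside compiled Identity nodes are added for the lifted values, the lifted
// _Arg/_Retval nodes are removed from the cond/body functions, and the While
// node's input edges and "T" attr are shrunk to match.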
2123 Status LiftOutsideCompilationOnlyArgsFromWhileNode(
2124 Graph* g, Node* while_node, FunctionLibraryDefinition* fld,
2125 int* lifted_arg_count, bool* rewritten) {
2126 *rewritten = false;
2127
2128 TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
2129 FindArgsToLiftForWhileNode(while_node, fld));
2130 if (args_to_lift.empty()) return Status::OK();
2131
2132 RemoveOutputIdentityNodesForWhileV2(g, while_node);
2133
2134 TF_RETURN_IF_ERROR(ReplaceOutputEdgesWithInputEdgeSourceForWhile(
2135 args_to_lift, g, while_node));
2136
2137 std::vector<DataType> dtypes;
2138 TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "T", &dtypes));
2139
2140 absl::flat_hash_map<int, int> index_mapping =
2141 ArgIndexMapping(dtypes.size(), args_to_lift);
2142
2143 // For each lifted arg, add an outside compilation Identity node to send
2144   // it to the host.
2145 absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
2146 TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
2147 args_to_lift, /*arg_to_input_edge_offset=*/0, g, while_node,
2148 &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));
2149
2150 // For cond func, remove _Arg nodes.
2151 TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> cond_fbody,
2152 InstantiateAssociatedFunction(*while_node, "cond", fld));
2153 TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2154 args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2155 cond_fbody.get()));
2156
2157 FunctionDef rewritten_cond_fdef;
2158 TF_RETURN_IF_ERROR(GraphToFunctionDef(*(cond_fbody->graph),
2159 cond_fbody->fdef.signature().name(),
2160 &rewritten_cond_fdef));
2161 TF_RETURN_IF_ERROR(fld->ReplaceFunction(cond_fbody->fdef.signature().name(),
2162 rewritten_cond_fdef));
2163
2164 // For body func, remove _Retval nodes, and replace _Arg nodes with
2165 // Placeholder nodes.
2166 TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> body_fbody,
2167 InstantiateAssociatedFunction(*while_node, "body", fld));
2168
2169 TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2170 args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2171 body_fbody.get()));
2172
2173 CleanUpRetvalsForWhileBody(index_mapping, dtypes, body_fbody.get());
2174
2175 FunctionDef rewritten_body_fdef;
2176 TF_RETURN_IF_ERROR(GraphToFunctionDef(*(body_fbody->graph),
2177 body_fbody->fdef.signature().name(),
2178 &rewritten_body_fdef));
2179 TF_RETURN_IF_ERROR(fld->ReplaceFunction(body_fbody->fdef.signature().name(),
2180 rewritten_body_fdef));
2181
2182 // Remove edges from lifted args to While node, and change "T" attr of the
2183 // While node.
2184 TF_RETURN_IF_ERROR(CleanUpInEdges(
2185 index_mapping, /*arg_to_input_edge_offset=*/0, g, while_node));
2186
2187 TF_RETURN_IF_ERROR(
2188 UpdateTypeAttribute(index_mapping, "T", dtypes, while_node));
2189
2190 *rewritten = true;
2191
2192 return Status::OK();
2193 }
2194
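// Same lifting as above, but for If nodes: both branch functions are
// rewritten, and arg_to_input_edge_offset is 1 because input 0 of an If node
// is the predicate.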
2195 Status LiftOutsideCompilationOnlyArgsFromIfNode(Graph* g, Node* if_node,
2196 FunctionLibraryDefinition* fld,
2197 int* lifted_arg_count,
2198 bool* rewritten) {
2199 *rewritten = false;
2200 TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
2201 FindArgsToLiftForIfNode(*if_node, fld));
2202 if (args_to_lift.empty()) return Status::OK();
2203
2204 std::vector<DataType> dtypes;
2205 TF_RETURN_IF_ERROR(GetNodeAttr(if_node->def(), "Tin", &dtypes));
2206
2207   absl::flat_hash_map<int, int> index_mapping =
2208       ArgIndexMapping(dtypes.size(), args_to_lift);
2215
2216 // For each lifted arg, add an outside compilation Identity node to send
2217   // it to the host.
2218 absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
2219 TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
2220 args_to_lift, /*arg_to_input_edge_offset=*/1, g, if_node,
2221 &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));
2222
2223 TF_ASSIGN_OR_RETURN(
2224 std::unique_ptr<FunctionBody> then_branch_fbody,
2225 InstantiateAssociatedFunction(*if_node, "then_branch", fld));
2226
2227 TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2228 args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2229 then_branch_fbody.get()));
2230
2231 FunctionDef rewritten_then_branch_fdef;
2232 TF_RETURN_IF_ERROR(GraphToFunctionDef(
2233 *(then_branch_fbody->graph), then_branch_fbody->fdef.signature().name(),
2234 &rewritten_then_branch_fdef));
2235 TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2236 then_branch_fbody->fdef.signature().name(), rewritten_then_branch_fdef));
2237
2238 TF_ASSIGN_OR_RETURN(
2239 std::unique_ptr<FunctionBody> else_branch_fbody,
2240 InstantiateAssociatedFunction(*if_node, "else_branch", fld));
2241
2242 TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2243 args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2244 else_branch_fbody.get()));
2245
2246 FunctionDef rewritten_else_branch_fdef;
2247 TF_RETURN_IF_ERROR(GraphToFunctionDef(
2248 *(else_branch_fbody->graph), else_branch_fbody->fdef.signature().name(),
2249 &rewritten_else_branch_fdef));
2250 TF_RETURN_IF_ERROR(fld->ReplaceFunction(
2251 else_branch_fbody->fdef.signature().name(), rewritten_else_branch_fdef));
2252
2253 // Remove edges from lifted args to If node, and change "Tin" attr of the
2254 // If node.
2255 TF_RETURN_IF_ERROR(CleanUpInEdges(
2256 index_mapping, /*arg_to_input_edge_offset=*/1, g, if_node));
2257 TF_RETURN_IF_ERROR(
2258 UpdateTypeAttribute(index_mapping, "Tin", dtypes, if_node));
2259
2260 *rewritten = true;
2261
2262 return Status::OK();
2263 }
2264
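// Same lifting for function call nodes (direct function calls,
// PartitionedCall, and SymbolicGradient). The rewritten body is stored under
// a fresh function name and the call node is replaced to point at it, because
// the original function may be user-defined and must not be modified in
// place.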
2265 Status LiftOutsideCompilationOnlyArgsFromCallNode(
2266 Graph* g, Node* call_node, FunctionLibraryRuntime* flr,
2267 FunctionLibraryDefinition* fld, int* lifted_arg_count, bool* rewritten) {
2268 *rewritten = false;
2269
2270 // Instantiate the function.
2271 NameAttrList func;
2272 if (fld->Contains(call_node->type_string())) {
2273 func.set_name(call_node->type_string());
2274 *func.mutable_attr() = call_node->def().attr();
2275 } else if (call_node->IsPartitionedCall()) {
2276 TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &func));
2277 } else {
2278 TF_RET_CHECK(call_node->type_string() ==
2279 FunctionLibraryDefinition::kGradientOp);
2280 func.set_name(FunctionLibraryDefinition::kGradientOp);
2281 *func.mutable_attr() = call_node->def().attr();
2282 }
2283 FunctionLibraryRuntime::Handle handle;
2284 TF_RETURN_IF_ERROR(
2285 flr->Instantiate(func.name(), AttrSlice(&func.attr()), &handle));
2286 auto cleanup_handle = gtl::MakeCleanup(
2287 [&flr, &handle]() { flr->ReleaseHandle(handle).IgnoreError(); });
2288 const FunctionBody* fbody = flr->GetFunctionBody(handle);
2289
2290 // Find _Arg nodes to lift.
2291 TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> args_to_lift,
2292 FindArgsToLiftForCallNode(call_node, *fbody));
2293 if (args_to_lift.empty()) return Status::OK();
2294
2295   std::vector<DataType> dtypes(call_node->input_types().begin(),
2296                                call_node->input_types().end());
2298
2299 absl::flat_hash_map<int, int> index_mapping =
2300 ArgIndexMapping(dtypes.size(), args_to_lift);
2301
2302 // For each lifted arg, add an outside compilation Identity node to send
2303   // it to the host.
2304 absl::flat_hash_map<int, string> lifted_arg_index_to_oc_cluster_name;
2305 TF_RETURN_IF_ERROR(MakeIdentityNodesForArgsToLift(
2306 args_to_lift, /*arg_to_input_edge_offset=*/0, g, call_node,
2307 &lifted_arg_index_to_oc_cluster_name, lifted_arg_count));
2308
2309 // Remove _Arg nodes.
2310 TF_RETURN_IF_ERROR(RemoveArgsToLiftFromFunctionBody(
2311 args_to_lift, dtypes, lifted_arg_index_to_oc_cluster_name, index_mapping,
2312 fbody));
2313
2314   // Store the rewritten function as a new function, because the original
2315   // function might be defined by the user and we should not modify it.
2316 FunctionDef rewritten_fdef;
2317 TF_RETURN_IF_ERROR(GraphToFunctionDef(
2318 *(fbody->graph), fbody->fdef.signature().name(), &rewritten_fdef));
2319 string new_func_name =
2320 fld->UniqueFunctionName(fbody->fdef.signature().name());
2321 rewritten_fdef.mutable_signature()->set_name(new_func_name);
2322 TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef));
2323
2324 // Remove edges from lifted args to call node.
2325 TF_RETURN_IF_ERROR(CleanUpInEdges(
2326 index_mapping, /*arg_to_input_edge_offset=*/0, g, call_node));
2327
2328 // Rewrite the call node to use the rewritten function.
2329 NodeDef node_def;
2330 node_def.set_name(g->NewName(call_node->name()));
2331 node_def.set_op(new_func_name);
2332 if (call_node->IsPartitionedCall()) {
2333 NameAttrList f;
2334 TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &f));
2335 *node_def.mutable_attr() = f.attr();
2336 } else if (fld->Contains(call_node->type_string())) {
2337 *node_def.mutable_attr() = call_node->def().attr();
2338 } else {
2339 TF_RET_CHECK(call_node->type_string() ==
2340 FunctionLibraryDefinition::kGradientOp);
2341 *node_def.mutable_attr() = call_node->def().attr();
2342 node_def.mutable_attr()->erase(FunctionLibraryDefinition::kFuncAttr);
2343 }
2344 TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2345
2346 *rewritten = true;
2347
2348 return Status::OK();
2349 }
2350
2351 // Lifts outside compilation only _Arg nodes out of If/While/function nodes.
2352 Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr,
2353 FunctionLibraryDefinition* fld,
2354 int* lifted_arg_count, bool* rewritten) {
2355 *rewritten = false;
2356
2357 // Handle deeper functional nodes first.
2358 std::vector<Node*> while_nodes, if_nodes, call_nodes;
2359 for (Node* n : g->op_nodes()) {
2360 if (HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
2361 continue;
2362 }
2363
2364 if (n->IsWhileNode()) {
2365 TF_ASSIGN_OR_RETURN(std::unique_ptr<FunctionBody> body_fbody,
2366 InstantiateAssociatedFunction(*n, "body", fld));
2367 bool func_rewritten = false;
2368 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2369 *body_fbody, flr, fld, lifted_arg_count,
2370 /*new_func_name=*/absl::nullopt, &func_rewritten));
2371 *rewritten = *rewritten || func_rewritten;
2372
2373 while_nodes.push_back(n);
2374 } else if (n->IsIfNode()) {
2375 TF_ASSIGN_OR_RETURN(
2376 std::unique_ptr<FunctionBody> then_branch_fbody,
2377 InstantiateAssociatedFunction(*n, "then_branch", fld));
2378 bool func_rewritten = false;
2379 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2380 *then_branch_fbody, flr, fld, lifted_arg_count,
2381 /*new_func_name=*/absl::nullopt, &func_rewritten));
2382 *rewritten |= func_rewritten;
2383
2384 TF_ASSIGN_OR_RETURN(
2385 std::unique_ptr<FunctionBody> else_branch_fbody,
2386 InstantiateAssociatedFunction(*n, "else_branch", fld));
2387 func_rewritten = false;
2388 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2389 *else_branch_fbody, flr, fld, lifted_arg_count,
2390 /*new_func_name=*/absl::nullopt, &func_rewritten));
2391 *rewritten |= func_rewritten;
2392
2393 if_nodes.push_back(n);
2394 } else if (IsFunctionCall(*fld, *n)) {
2395 // Function call nodes need to be rewritten, so handle them later.
2396 call_nodes.push_back(n);
2397 }
2398 }
2399
2400 std::vector<Node*> rewritten_call_nodes;
2401 for (Node* call_node : call_nodes) {
2402 if (call_node->IsPartitionedCall()) {
2403 std::unique_ptr<FunctionBody> function_fbody;
2404 TF_ASSIGN_OR_RETURN(function_fbody,
2405 InstantiateAssociatedFunction(*call_node, "f", fld));
2406 bool func_rewritten = false;
2407 string new_func_name =
2408 fld->UniqueFunctionName(function_fbody->fdef.signature().name());
2409 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2410 *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2411 &func_rewritten));
2412 if (func_rewritten) {
2413 NameAttrList f;
2414 TF_RETURN_IF_ERROR(GetNodeAttr(call_node->def(), "f", &f));
2415 f.set_name(new_func_name);
2416 call_node->ClearAttr("f");
2417 call_node->AddAttr("f", f);
2418 }
2419
2420 *rewritten |= func_rewritten;
2421 rewritten_call_nodes.push_back(call_node);
2422 } else if (fld->Contains(call_node->type_string())) {
2423 std::unique_ptr<FunctionBody> function_fbody;
2424 const FunctionDef* fdef = fld->Find(call_node->type_string());
2425 TF_RET_CHECK(fdef);
2426 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, call_node->attrs(), fld,
2427 &function_fbody));
2428 bool func_rewritten = false;
2429 string new_func_name =
2430 fld->UniqueFunctionName(function_fbody->fdef.signature().name());
2431 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2432 *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2433 &func_rewritten));
2434 if (func_rewritten) {
2435 NodeDef node_def;
2436 node_def.set_name(g->NewName(call_node->name()));
2437 node_def.set_op(new_func_name);
2438 *node_def.mutable_attr() = call_node->def().attr();
2439 TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2440 }
2441
2442 *rewritten |= func_rewritten;
2443 rewritten_call_nodes.push_back(call_node);
2444 } else {
2445 TF_RET_CHECK(call_node->type_string() ==
2446 FunctionLibraryDefinition::kGradientOp);
2447 FunctionLibraryRuntime::Handle handle;
2448 TF_RETURN_IF_ERROR(flr->Instantiate(call_node->type_string(),
2449 call_node->attrs(), &handle));
2450 auto cleanup_handle = gtl::MakeCleanup(
2451 [&flr, &handle]() { flr->ReleaseHandle(handle).IgnoreError(); });
2452 bool func_rewritten = false;
2453 string new_func_name = fld->UniqueFunctionName(
2454 absl::StrCat(call_node->name(), "_lift_args"));
2455 const FunctionBody* function_fbody = flr->GetFunctionBody(handle);
2456 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsAndReplaceFunctionDef(
2457 *function_fbody, flr, fld, lifted_arg_count, new_func_name,
2458 &func_rewritten));
2459 if (func_rewritten) {
2460 NodeDef node_def;
2461 node_def.set_name(g->NewName(call_node->name()));
2462 node_def.set_op(new_func_name);
2463 *node_def.mutable_attr() = call_node->def().attr();
2464 node_def.mutable_attr()->erase(FunctionLibraryDefinition::kFuncAttr);
2465 TF_ASSIGN_OR_RETURN(call_node, ReplaceNode(g, call_node, node_def));
2466 }
2467
2468 *rewritten |= func_rewritten;
2469 rewritten_call_nodes.push_back(call_node);
2470 }
2471 }
2472
2473 for (Node* n : while_nodes) {
2474 bool node_rewritten = false;
2475 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromWhileNode(
2476 g, n, fld, lifted_arg_count, &node_rewritten));
2477 *rewritten = *rewritten || node_rewritten;
2478 }
2479
2480 for (Node* n : if_nodes) {
2481 bool node_rewritten = false;
2482 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromIfNode(
2483 g, n, fld, lifted_arg_count, &node_rewritten));
2484 *rewritten = *rewritten || node_rewritten;
2485 }
2486
2487 for (Node* n : rewritten_call_nodes) {
2488 bool node_rewritten = false;
2489 TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgsFromCallNode(
2490 g, n, flr, fld, lifted_arg_count, &node_rewritten));
2491 *rewritten = *rewritten || node_rewritten;
2492 }
2493
2494 if (*rewritten) {
2495 VLOG(4) << DumpGraphToFile("after_lifting_args", *g, fld);
2496 }
2497
2498 return Status::OK();
2499 }
2500
2501 } // namespace
2502
2503 /*static*/ Status EncapsulateTPUComputationsPass::Encapsulate(
2504 std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
2505 // Check for undeclared outputs before Encapsulation, so we can give a better
2506 // error message.
2507 // TODO(phawkins): merge this with the encapsulation code to avoid the extra
2508 // O(n) pass over the edges.
2509 for (const Edge* e : (*graph)->edges()) {
2510 if (!e->IsControlEdge() &&
2511 e->src()->attrs().Find(kTPUReplicateAttr) != nullptr &&
2512 e->src()->attrs().Find(kOutsideCompilationAttr) == nullptr &&
2513 e->dst()->attrs().Find(kTPUReplicateAttr) == nullptr &&
2514 e->dst()->type_string() != kTPUReplicatedOutput) {
2515 return errors::InvalidArgument(
2516 "Undeclared output of TPU computation. A common cause of this error "
2517 "is variable initializers that depend on the TPU computation. Edge: ",
2518 FormatNodeForError(*e->src()), ":", e->src_output(), " -> ",
2519 FormatNodeForError(*e->dst()), ":", e->dst_input());
2520 }
2521 }
2522
2523 RemoveUnusedTPUReplicatedInputs(graph->get());
2524
2525 TF_RETURN_IF_ERROR(RenameClustersWithDuplicatedNames(graph->get()));
2526
2527 TF_RETURN_IF_ERROR(
2528 PerformStaticShapeInferenceBeforeEncapsulation(graph->get()));
2529
2530 auto output = absl::make_unique<Graph>((*graph)->op_registry());
2531 TF_RETURN_WITH_CONTEXT_IF_ERROR(
2532 EncapsulateSubgraphsInFunctions(
2533 kTPUReplicateAttr, **graph, RewriteSubgraph,
2534 /*reuse_existing_functions=*/true, &output, flib_def),
2535 "EncapsulateTPUComputationsPass failed");
2536 graph->swap(output);
2537
2538 return Status::OK();
2539 }
2540
2541 /*static*/ Status EncapsulateTPUComputationsPass::BuildTPUReplicateOps(
2542 Graph* graph) {
2543 // Finds all of the replicate function calls, to avoid mutating the graph
2544 // while iterating.
2545 std::vector<Node*> replicate_nodes;
2546 std::vector<Node*> guarantee_const_nodes;
2547 for (Node* n : graph->nodes()) {
2548 string name;
2549 if (TryGetNodeAttr(n->attrs(), kTPUReplicateAttr, &name) &&
2550 !TryGetNodeAttr(n->attrs(), kOutsideCompilationAttr, &name)) {
2551 replicate_nodes.push_back(n);
2552 } else if (n->type_string() == "GuaranteeConst") {
2553 guarantee_const_nodes.push_back(n);
2554 }
2555 }
2556
2557 // Replace any GuaranteeConst nodes with Identity nodes. These nodes have now
2558 // served their purpose and have no runtime effect, except increasing
2559 // inference latency due to executor overhead. Subsequent rewrites will remove
2560 // the Identity nodes.
2561 for (Node* n : guarantee_const_nodes) {
2562 std::vector<std::pair<Node*, int>> predecessors;
2563 for (const Edge* e : n->in_edges()) {
2564 predecessors.emplace_back(e->src(), e->src_output());
2565 }
2566 std::vector<std::pair<Node*, int>> successors;
2567 for (const Edge* e : n->out_edges()) {
2568 successors.emplace_back(e->dst(), e->dst_input());
2569 }
2570 NodeDef ndef;
2571 ndef.set_name(n->name());
2572 ndef.set_op("Identity");
2573 ndef.set_device(n->requested_device());
2574 MergeDebugInfo(NodeDebugInfo(n->def()), &ndef);
2575 AddNodeAttr("T", n->output_type(0), &ndef);
2576
2577 graph->RemoveNode(n);
2578 Status s;
2579 Node* id_node = graph->AddNode(ndef, &s);
2580 TF_RETURN_IF_ERROR(s);
2581
2582 for (const auto& pred : predecessors) {
2583 if (pred.second < 0) {
2584 graph->AddControlEdge(pred.first, id_node);
2585 } else {
2586 graph->AddEdge(pred.first, pred.second, id_node, 0);
2587 }
2588 }
2589 for (const auto& succ : successors) {
2590 if (succ.second < 0) {
2591 graph->AddControlEdge(id_node, succ.first);
2592 } else {
2593 graph->AddEdge(id_node, 0, succ.first, succ.second);
2594 }
2595 }
2596 }
2597
2598 // Replaces each replicate function call together with its neighboring
2599 // TPUReplicatedInput/TPUReplicatedOutput nodes with a TPUReplicate node.
2600 for (Node* replicate : replicate_nodes) {
2601 int num_replicas;
2602 TF_RETURN_IF_ERROR(
2603 GetNodeAttr(replicate->attrs(), "num_replicas", &num_replicas));
2604 int variable_start_index;
2605 TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(), "_variable_start_index",
2606 &variable_start_index));
2607 int guaranteed_const_start_index;
2608 TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(),
2609 "_guaranteed_const_start_index",
2610 &guaranteed_const_start_index));
2611
2612 if (HasNodeAttr(replicate->def(), "use_tpu")) {
2613 bool use_tpu;
2614 TF_RETURN_IF_ERROR(GetNodeAttr(replicate->attrs(), "use_tpu", &use_tpu));
2615 if (!use_tpu) {
2616 LOG(WARNING) << "use_tpu=false attr on a TPUReplicate node is ignored.";
2617 }
2618 }
2619
2620 std::vector<const Edge*> in_edges;
2621 TF_RETURN_IF_ERROR(replicate->input_edges(&in_edges));
2622
2623 // Counts the number of replicated, non-replicated, and variable inputs.
2624 int pos = 0;
2625 std::vector<int> mirrored_variable_indices;
2626 int distributed_var_start_index = 0;
2627 while (pos < in_edges.size() &&
2628 in_edges[pos]->src()->type_string() == kTPUReplicatedInput) {
2629 // Checks that each TPUReplicatedInput node has the correct number of
2630 // replicas.
2631 int input_num_replicas;
2632 TF_RETURN_IF_ERROR(
2633 GetNodeAttr(in_edges[pos]->src()->attrs(), "N", &input_num_replicas));
2634
2635 bool is_mirrored_variable;
2636 CHECK(GetNodeAttr(in_edges[pos]->src()->attrs(), "is_mirrored_variable",
2637 &is_mirrored_variable)
2638 .ok());
2639 if (is_mirrored_variable) {
2640 mirrored_variable_indices.push_back(pos);
2641 }
2642
2643 bool is_packed = false;
2644 GetNodeAttr(in_edges[pos]->src()->attrs(), "is_packed", &is_packed)
2645 .IgnoreError();
2646
2647 bool is_distributed_variable =
2648 is_packed && (in_edges[pos]->src()->output_type(
2649 in_edges[pos]->src_output()) == DT_RESOURCE);
2650
2651 if (!is_distributed_variable && input_num_replicas != num_replicas) {
2652 return errors::InvalidArgument(
2653 "Mismatched number of replicas. Computation has ", num_replicas,
2654 " replicas, input '", FormatNodeForError(*in_edges[pos]->src()),
2655 "' has ", input_num_replicas, " replicas.");
2656 }
2657
2658 if (!is_distributed_variable) {
2659 if (distributed_var_start_index < pos) {
2660 return errors::InvalidArgument(
2661 "Expect a distributed resource after index ",
2662 distributed_var_start_index,
2663 ", but got a replicated resource at index ", pos);
2664 } else {
2665 ++distributed_var_start_index;
2666 }
2667 }
2668 ++pos;
2669 }
2670 const int num_replicated_inputs = distributed_var_start_index;
2671 const int num_distributed_vars = pos - num_replicated_inputs;
2672
2673 const int num_variables =
2674 std::max(0, guaranteed_const_start_index - variable_start_index);
2675
2676 const int num_guaranteed_constants =
2677 in_edges.size() - guaranteed_const_start_index;
2678 TF_RET_CHECK(num_guaranteed_constants >= 0);
2679
2680 VLOG(1) << "Replicate node '" << replicate->name() << "'"
2681 << " input edges: " << in_edges.size()
2682 << " num_replicated_inputs: " << num_replicated_inputs
2683 << " num_distributed_vars: " << num_distributed_vars
2684 << " num_variables: " << num_variables
2685 << " num_guaranteed_constants: " << num_guaranteed_constants
2686 << " num_mirrored_variables: " << mirrored_variable_indices.size();
2687
2688 const int num_broadcast_inputs =
2689 in_edges.size() - (num_replicated_inputs + num_distributed_vars +
2690 num_variables + num_guaranteed_constants);
2691 TF_RET_CHECK(num_broadcast_inputs >= 0);
2692
2693 const int num_inputs = num_replicated_inputs * num_replicas +
2694 num_distributed_vars + num_broadcast_inputs +
2695 num_guaranteed_constants + num_variables;
2696
2697 std::vector<Node*> nodes_to_remove = {replicate};
2698
2699 // Data and control inputs to the new TPUReplicate node.
2700 std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
2701 gtl::FlatSet<Node*> control_inputs;
2702
2703 AddControlInputs(*replicate, &control_inputs);
2704
2705 // Replicated inputs. Adds the inputs from the TPUReplicatedInput inputs,
2706 // in replica-major order. See the comments in
2707 // distributed_tpu_rewrite_pass.h for a description of the argument order.
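    // For example, with 2 replicas and per-replica inputs (a, b), the
    // flattened order is [a@r0, b@r0, a@r1, b@r1].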
2708 DataTypeVector replicated_input_types(num_replicated_inputs * num_replicas +
2709 num_distributed_vars);
2710
2711 // Inputs with is_distributed_variable = false.
2712 for (int i = 0; i < num_replicated_inputs; ++i) {
2713 std::vector<const Edge*> replica_in_edges;
2714 TF_RETURN_IF_ERROR(in_edges[i]->src()->input_edges(&replica_in_edges));
2715 for (int replica = 0; replica < num_replicas; ++replica) {
2716 int pos = replica * num_replicated_inputs + i;
2717 const Edge* edge = replica_in_edges[replica];
2718 data_inputs[pos] = {edge->src(), edge->src_output()};
2719 replicated_input_types[pos] = EdgeType(edge);
2720 }
2721 AddControlInputs(*in_edges[i]->src(), &control_inputs);
2722 nodes_to_remove.push_back(in_edges[i]->src());
2723 }
2724
2725 // Inputs with is_distributed_variable = true.
2726 for (int i = 0; i < num_distributed_vars; ++i) {
2727 int pos = num_replicas * num_replicated_inputs + i;
2728 std::vector<const Edge*> replica_in_edges;
2729 TF_RETURN_IF_ERROR(
2730 in_edges[num_replicated_inputs + i]->src()->input_edges(
2731 &replica_in_edges));
2732 TF_RET_CHECK(replica_in_edges.size() == 1);
2733 const Edge* edge = replica_in_edges[0];
2734 data_inputs[pos] = {edge->src(), edge->src_output()};
2735 replicated_input_types[pos] = EdgeType(edge);
2736 AddControlInputs(*in_edges[num_replicated_inputs + i]->src(),
2737 &control_inputs);
2738 nodes_to_remove.push_back(in_edges[num_replicated_inputs + i]->src());
2739 }
2740
2741 // Appends the broadcast inputs.
2742 DataTypeVector broadcast_input_types(num_broadcast_inputs);
2743 for (int i = 0; i < num_broadcast_inputs; ++i) {
2744 int pos = num_replicas * num_replicated_inputs + num_distributed_vars + i;
2745 const Edge* edge =
2746 in_edges[num_replicated_inputs + num_distributed_vars + i];
2747 data_inputs[pos] = {edge->src(), edge->src_output()};
2748 broadcast_input_types[i] = EdgeType(edge);
2749 }
2750
2751 // Appends the variable inputs.
2752 for (int i = 0; i < num_variables; ++i) {
2753 int pos = num_replicas * num_replicated_inputs + num_distributed_vars +
2754 num_broadcast_inputs + i;
2755 const Edge* edge = in_edges[num_replicated_inputs + num_distributed_vars +
2756 num_broadcast_inputs + i];
2757 data_inputs[pos] = {edge->src(), edge->src_output()};
2758 }
2759
2760 DataTypeVector guaranteed_constant_types(num_guaranteed_constants);
2761 for (int i = 0; i < num_guaranteed_constants; ++i) {
2762 int pos = num_replicas * num_replicated_inputs + num_distributed_vars +
2763 num_broadcast_inputs + num_variables + i;
2764 const Edge* edge = in_edges[num_replicated_inputs + num_distributed_vars +
2765 num_broadcast_inputs + num_variables + i];
2766 data_inputs[pos] = {edge->src(), edge->src_output()};
2767 guaranteed_constant_types[i] = EdgeType(edge);
2768 }
2769
2770 // Outputs. All outputs from a replicated computation are replicated.
2771 const int num_outputs = replicate->output_types().size();
2772 gtl::FlatSet<Node*> control_outputs;
2773 std::vector<Node*> replicated_outputs(num_outputs);
2774 for (const Edge* e : replicate->out_edges()) {
2775 if (e->IsControlEdge()) {
2776 control_outputs.insert(e->dst());
2777 } else {
2778 TF_RET_CHECK(e->src_output() < num_outputs);
2779 TF_RET_CHECK(e->dst()->type_string() == kTPUReplicatedOutput)
2780 << e->DebugString();
2781 TF_RET_CHECK(e->dst()->output_types().size() == num_replicas);
2782 replicated_outputs[e->src_output()] = e->dst();
2783 nodes_to_remove.push_back(e->dst());
2784
2785 AddControlOutputs(*e->dst(), &control_outputs);
2786 }
2787 }
2788
2789 // Flattens the edges outgoing from the TPUReplicatedOutput nodes in
2790 // replica-major order.
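    // For example, with 2 replicas and outputs (x, y), positions are
    // [x@r0, y@r0, x@r1, y@r1], matching the replicated input layout above.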
2791 std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_replicas *
2792 num_outputs);
2793 DataTypeVector output_types(num_replicas * num_outputs);
2794 for (int i = 0; i < num_outputs; ++i) {
2795 std::vector<std::vector<const Edge*>> replica_out_edges(num_replicas);
2796 TF_RET_CHECK(replicated_outputs[i] != nullptr);
2797 for (const Edge* e : replicated_outputs[i]->out_edges()) {
2798 TF_RET_CHECK(!e->IsControlEdge());
2799 replica_out_edges[e->src_output()].push_back(e);
2800 }
2801
2802 for (int replica = 0; replica < num_replicas; ++replica) {
2803 const int pos = replica * num_outputs + i;
2804 for (const Edge* edge : replica_out_edges[replica]) {
2805 data_outputs[pos].push_back({edge->dst(), edge->dst_input()});
2806 }
2807 output_types[pos] = replicated_outputs[i]->input_type(0);
2808 }
2809 }

    // TODO(b/79092708): Consolidate and cleanup to avoid TPU specialization.
    NodeDef def;
    def.set_name(replicate->name());
    def.set_op("_TPUReplicate");
    MergeDebugInfo(NodeDebugInfo(replicate->def()), &def);
    NameAttrList computation;
    computation.set_name(replicate->type_string());
    AddNodeAttr("computation", computation, &def);
    for (const auto& attr : replicate->attrs()) {
      def.mutable_attr()->insert(attr);
    }
    AddNodeAttr("Tinputs", replicated_input_types, &def);
    AddNodeAttr("Tbroadcast_inputs", broadcast_input_types, &def);
    AddNodeAttr("NumVariables", num_variables, &def);
    AddNodeAttr("Tguaranteed_constants", guaranteed_constant_types, &def);
    AddNodeAttr("output_types", output_types, &def);
    AddNodeAttr(TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
                mirrored_variable_indices, &def);
    AddNodeAttr("num_distributed_variables", num_distributed_vars, &def);

    for (Node* node : nodes_to_remove) {
      VLOG(2) << "Deleting node " << node->DebugString();
      // Ensure that we do not attempt to add control edges to nodes that are
      // deleted.
      control_inputs.erase(node);
      control_outputs.erase(node);
      graph->RemoveNode(node);
    }

    // Adds the _TPUReplicate node to the graph, then rewires the flattened
    // data inputs/outputs and the collected control edges to it.
    Status status;
    Node* tpu_replicate = graph->AddNode(def, &status);
    if (!status.ok()) {
      return status;
    }
    for (int i = 0; i < data_inputs.size(); ++i) {
      graph->AddEdge(data_inputs[i].first, data_inputs[i].second, tpu_replicate,
                     i);
    }
    for (Node* n : control_inputs) {
      graph->AddControlEdge(n, tpu_replicate);
    }
    for (int i = 0; i < data_outputs.size(); ++i) {
      for (const auto& successor : data_outputs[i]) {
        graph->AddEdge(tpu_replicate, i, successor.first, successor.second);
      }
    }
    for (Node* n : control_outputs) {
      graph->AddControlEdge(tpu_replicate, n);
    }
  }
  return Status::OK();
}

Status EncapsulateTPUComputationsPass::Run(
    const GraphOptimizationPassOptions& options) {
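  // The graph dumps below are written by DumpGraphToFile, which only produces
  // files when the TF_DUMP_GRAPH_PREFIX environment variable names a writable
  // directory.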
  VLOG(1) << "EncapsulateTPUComputations(): "
          << DumpGraphToFile("encapsulate_tpu_computations_before",
                             **options.graph, options.flib_def);

  TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
  VLOG(1) << "EncapsulateTPUComputations() half-way: "
          << DumpGraphToFile("encapsulate_tpu_computations_halfway",
                             **options.graph, options.flib_def);

  TF_RETURN_IF_ERROR(BuildTPUReplicateOps(options.graph->get()));
  VLOG(1) << "EncapsulateTPUComputations() finished: "
          << DumpGraphToFile("encapsulate_tpu_computations_after",
                             **options.graph, options.flib_def);
  return Status::OK();
}

Status ExtractOutsideCompilationPass::ProcessHeadTailOutsideCompilation(
    const string& outside_compilation_attr_name, int* lifted_arg_count,
    std::unordered_map<string, XlaClusterInfo>* clusters, Graph* g,
    FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld) {
  // Gather a list of pivots by cluster so we can easily look them up.
  absl::node_hash_map<string, Node*> pivots;
  string cluster_name;
  for (Node* node : g->nodes()) {
    if (TryGetNodeAttr(node->attrs(), kPivotForClusterAttr, &cluster_name)) {
      pivots[cluster_name] = node;
    }
  }
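  // Each cluster's pivot node (tagged with kPivotForClusterAttr) is handed to
  // the head/tail rewrites below so that computation moved to the host stays
  // sequenced against the right cluster.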
  for (auto& iter : *clusters) {
    // Find the pivot node for this XLA cluster.
    Node* pivot_node = pivots[iter.first];

    // Instantiate the XLA computation function.
    string xla_func_name = iter.second.func_name_attrs.name();
    std::unique_ptr<FunctionBody> xla_fbody;
    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
        *fld->Find(xla_func_name),
        AttrSlice(&iter.second.func_name_attrs.attr()), fld, &xla_fbody));
    Graph* xla_graph = xla_fbody->graph;

    // Make sure all nodes can be traced from the sink node.
    FixupSourceAndSinkEdges(xla_graph);

    // We create Identity nodes for all _Arg/_Retval nodes in the XLA
    // computation. Remove those Identity nodes to simplify further
    // processing.
    TF_RETURN_IF_ERROR(RemoveIdentityNodesForArgRetval(xla_graph));

    bool rewritten;
    TF_RETURN_IF_ERROR(LiftOutsideCompilationOnlyArgs(
        xla_graph, flr, fld, lifted_arg_count, &rewritten));

    // Move head outside compilation to host.
    TF_RETURN_IF_ERROR(MoveHeadOutsideCompilationToHost(
        outside_compilation_attr_name, iter.second.func_name_attrs.name(),
        iter.second.cluster_name, g, xla_graph, iter.second.node, pivot_node));

    // Move tail outside compilation to host.
    TF_RETURN_IF_ERROR(MoveTailOutsideCompilationToHost(
        outside_compilation_attr_name, iter.second.func_name_attrs.name(),
        iter.second.cluster_name, g, xla_graph, iter.second.node, pivot_node));

    // Replace outside-compilation-only _Arg nodes with Placeholder nodes.
    TF_RETURN_IF_ERROR(ReplaceArgUsedByOutsideCompilationWithPlaceholder(
        outside_compilation_attr_name, xla_func_name, g, xla_graph,
        iter.second.node));

    // There might be direct data edges between an _Arg node and a _Retval
    // node in `xla_graph`. Remove those edges to avoid back-and-forth data
    // transfers between host and XLA.
    TF_RETURN_IF_ERROR(RemoveEdgesBetweenArgAndRetval(
        iter.second.func_name_attrs.name(), g, xla_graph, iter.second.node));

    // After `MoveHeadOutsideCompilationToHost`, there might be unused XLA
    // inputs. Remove them.
    TF_RETURN_IF_ERROR(RemoveUnusedXlaInput(iter.second.func_name_attrs.name(),
                                            g, xla_graph, iter.second.node));

    // After `MoveTailOutsideCompilationToHost`, there might be unused XLA
    // outputs. Remove them.
    TF_RETURN_IF_ERROR(RemoveUnusedXlaOutput(iter.second.func_name_attrs.name(),
                                             g, xla_graph, iter.second.node));

    // Replace the original function with the rewritten graph.
    FunctionDef replace_fdef;
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*xla_graph, xla_func_name, &replace_fdef));
    TF_RETURN_IF_ERROR(fld->ReplaceFunction(xla_func_name, replace_fdef));

    FixupSourceAndSinkEdges(g);
  }

  return Status::OK();
}

Status ExtractOutsideCompilationPass::Run(
    const GraphOptimizationPassOptions& options) {
  const auto* config =
      (options.session_options ? &options.session_options->config : nullptr);
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(
          /*device_mgr=*/nullptr, options.session_options->env,
          /*config=*/config, TF_GRAPH_DEF_VERSION, options.flib_def,
          config ? config->graph_options().optimizer_options()
                 : OptimizerOptions()));
  FunctionLibraryRuntime* flr =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  // Find XLA compile ops and their corresponding FunctionDefs.
  static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
      new std::map<string, string>{
          {"_TPUReplicate", "computation"},
      };
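  // Only _TPUReplicate nodes are rewritten by this pass; each such node names
  // its XLA computation function in the "computation" attr (set when the node
  // is constructed in BuildTPUReplicateOps above).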
  std::unordered_map<string, XlaClusterInfo> clusters;
  int lifted_arg_count = 0;
  for (Node* n : (*options.graph)->nodes()) {
    auto iter = kNodeTypeToFunctionAttrMapping->find(n->type_string());
    if (iter == kNodeTypeToFunctionAttrMapping->end()) {
      continue;
    }

    string xla_cluster_name = n->name();

    string func_attr = iter->second;
    NameAttrList func;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));

    std::vector<string> core_list;
    TF_RETURN_IF_ERROR(
        GetNodeAttr(n->attrs(), "host_compute_core", &core_list));
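    // Each entry of `core_list` is expected to have the form
    // "<outside_compilation_cluster>:<core_number>"; ParseHostComputeCoreList
    // converts the list into the cluster-name -> core-number map stored in
    // XlaClusterInfo below.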
    std::map<string, int> host_compute_core;
    TF_RETURN_IF_ERROR(ParseHostComputeCoreList(core_list, &host_compute_core));

    clusters.emplace(xla_cluster_name, XlaClusterInfo{xla_cluster_name, func, n,
                                                      host_compute_core});
  }
  TF_RETURN_IF_ERROR(ProcessHeadTailOutsideCompilation(
      kOutsideCompilationAttr, &lifted_arg_count, &clusters,
      options.graph->get(), flr, options.flib_def));
  bool modified;
  TF_RETURN_IF_ERROR(ExtractOutsideCompilation(
      kTPUReplicateAttr, kOutsideCompilationAttr, clusters,
      options.graph->get(), flr, options.flib_def, &modified));
  if (modified) {
    TF_RETURN_IF_ERROR(
        PruneUnreachableFunctionsFromGraph(**options.graph, options.flib_def));
  }

  return Status::OK();
}

}  // namespace tensorflow