/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Compilation for distributed TPU (TPU_REPLICATED_CORE devices).

#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"

#include <algorithm>
#include <queue>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/escaping.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/lower_function_call_op.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/lower_while_op.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_fingerprint_utils.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {

namespace {

// Device coordinates are defined as (x, y, z, core), thus resulting in a
// rank 4 topology.
constexpr int kTPUTopologyRank = 4;

// An upper bound on how many cores may be present in the topology.
static constexpr int kTPUMaxTopologySize = 4096;

// Attribute containing the serialized xla::OpSharding to be passed to the
// corresponding XLA HLO operation, which represents how a shape is distributed
// across logical cores, e.g., replication, single-device, or partitioning.
const char kShardingAttribute[] = "_XlaSharding";

const char kTPUPartitionedInput[] = "TPUPartitionedInput";
const char kTPUPartitionedOutput[] = "TPUPartitionedOutput";

const char kVarHandleOp[] = "VarHandleOp";

static const char* const kTPUCompilationResultAttr = "_tpu_compilation_status";
static const char* const kPostDeviceRewriteAttr = "_post_device_rewrite";

using NodeAndId = std::pair<const Node*, int>;

struct NodeAndPort {
  explicit NodeAndPort(Node* node, int port) : node(node), port(port) {}

  Node* node;
  // Port of the node, e.g. this can be the `src_output` index of an Edge.
  int port;
};

class IntrusiveHeapLink {
 public:
  using size_type = size_t;
  static constexpr size_type kNotMember = -1;

  IntrusiveHeapLink() = default;

  // Only IntrusiveHeap and LinkAccess objects should make these objects.
  explicit IntrusiveHeapLink(size_type pos) : pos_{pos} {}

  // Only IntrusiveHeap and LinkAccess should get the value.
  size_type get() const { return pos_; }

 private:
  size_type pos_{kNotMember};
};

template <typename T, IntrusiveHeapLink T::*M>
struct IntrusiveHeapDataMemberLinkAccess {
  IntrusiveHeapLink Get(const T* elem) const { return elem->*M; }
  void Set(T* elem, IntrusiveHeapLink link) const { elem->*M = link; }
};

template <typename T>
struct DefaultIntrusiveHeapLinkAccess {
  IntrusiveHeapLink Get(const T* elem) const { return elem->heap; }
  void Set(T* elem, IntrusiveHeapLink link) const { elem->heap = link; }
};

template <typename T, typename PtrCompare,
          typename LinkAccess = DefaultIntrusiveHeapLinkAccess<T>,
          typename Alloc = std::allocator<T*>>
class IntrusiveHeap {
 public:
  typedef typename IntrusiveHeapLink::size_type size_type;
  typedef T value_type;
  typedef T* pointer;
  typedef const T* const_pointer;
  typedef PtrCompare pointer_compare_type;
  typedef LinkAccess link_access_type;
  typedef Alloc allocator_type;

  explicit IntrusiveHeap(
      const pointer_compare_type& comp = pointer_compare_type(),
      const link_access_type& link_access = link_access_type(),
      const allocator_type& alloc = allocator_type())
      : rep_(comp, link_access, alloc) {}

  size_type size() const { return heap().size(); }

  bool empty() const { return heap().empty(); }

  // Return the top element, but don't remove it.
  pointer top() const {
    DCHECK(!empty());
    return heap()[0];
  }

  // Remove the top() pointer from the heap and return it.
  pointer Pop() {
    pointer t = top();
    Remove(t);
    return t;
  }

  // Insert 't' into the heap.
  void Push(pointer t) {
    SetPositionOf(t, heap().size());
    heap().push_back(t);
    FixHeapUp(t);
  }

  // Adjust the heap to accommodate changes in '*t'.
  void Adjust(pointer t) {
    DCHECK(Contains(t));
    size_type h = GetPositionOf(t);
    if (h != 0 && compare()(t, heap()[(h - 1) >> 1])) {
      FixHeapUp(t);
    } else {
      FixHeapDown(t);
    }
  }

  // Remove the specified pointer from the heap.
  void Remove(pointer t) {
    DCHECK(Contains(t));
    size_type h = GetPositionOf(t);
    SetPositionOf(t, IntrusiveHeapLink::kNotMember);
    if (h == heap().size() - 1) {
      // Fast path for removing from back of heap.
      heap().pop_back();
      return;
    }
    // Move the element from the back of the heap to overwrite 't'.
    pointer& elem = heap()[h];
    elem = heap().back();
    SetPositionOf(elem, h);  // Element has moved, so update its link.
    heap().pop_back();
    Adjust(elem);  // Restore the heap invariant.
  }

  void Clear() { heap().clear(); }

  bool Contains(const_pointer t) const {
    size_type h = GetPositionOf(t);
    return (h != IntrusiveHeapLink::kNotMember) && (h < size()) &&
           heap()[h] == t;
  }

  void reserve(size_type n) { heap().reserve(n); }

  size_type capacity() const { return heap().capacity(); }

  allocator_type get_allocator() const { return rep_.heap_.get_allocator(); }

 private:
  typedef std::vector<pointer, allocator_type> heap_type;

  // Empty base class optimization for pointer_compare and link_access.
  // The heap_ data member retains a copy of the allocator, so it is not
  // stored explicitly.
  struct Rep : pointer_compare_type, link_access_type {
    explicit Rep(const pointer_compare_type& cmp,
                 const link_access_type& link_access,
                 const allocator_type& alloc)
        : pointer_compare_type(cmp),
          link_access_type(link_access),
          heap_(alloc) {}
    heap_type heap_;  // NOLINT
  };

  const pointer_compare_type& compare() const { return rep_; }

  const link_access_type& link_access() const { return rep_; }

  const heap_type& heap() const { return rep_.heap_; }
  heap_type& heap() { return rep_.heap_; }

  size_type GetPositionOf(const_pointer t) const {
    return link_access().Get(t).get();
  }

  void SetPositionOf(pointer t, size_type pos) const {
    return link_access().Set(t, IntrusiveHeapLink(pos));
  }

  void FixHeapUp(pointer t) {
    size_type h = GetPositionOf(t);
    while (h != 0) {
      size_type parent = (h - 1) >> 1;
      if (compare()(heap()[parent], t)) {
        break;
      }
      heap()[h] = heap()[parent];
      SetPositionOf(heap()[h], h);
      h = parent;
    }
    heap()[h] = t;
    SetPositionOf(t, h);
  }

  void FixHeapDown(pointer t) {
    size_type h = GetPositionOf(t);
    for (;;) {
      size_type kid = (h << 1) + 1;
      if (kid >= heap().size()) {
        break;
      }
      if (kid + 1 < heap().size() && compare()(heap()[kid + 1], heap()[kid])) {
        ++kid;
      }
      if (compare()(t, heap()[kid])) {
        break;
      }
      heap()[h] = heap()[kid];
      SetPositionOf(heap()[h], h);
      h = kid;
    }

    heap()[h] = t;
    SetPositionOf(t, h);
  }

  Rep rep_;
};
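
// Illustrative sketch (not used by the pass): IntrusiveHeap is a pointer-based
// min-heap whose elements embed their own IntrusiveHeapLink, so membership
// tests and position updates are O(1). The `Item` type below is hypothetical.
//
//   struct Item {
//     int priority;
//     IntrusiveHeapLink heap;  // Used by DefaultIntrusiveHeapLinkAccess.
//     struct Compare {
//       bool operator()(const Item* a, const Item* b) const {
//         return a->priority < b->priority;
//       }
//     };
//   };
//
//   std::vector<Item> items(3);
//   items[0].priority = 5;
//   items[1].priority = 1;
//   items[2].priority = 3;
//   IntrusiveHeap<Item, Item::Compare> min_heap;
//   for (Item& item : items) min_heap.Push(&item);
//   Item* smallest = min_heap.Pop();  // The priority-1 item.
//   items[2].priority = 0;
//   min_heap.Adjust(&items[2]);       // Re-establish heap order after mutation.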

string CoreDeviceLabel(int core) {
  return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
}

// Creates a unique node name with a particular prefix.
string UniqueNodeName(const StringPiece prefix, Graph* graph) {
  return graph->NewName(strings::StrCat(prefix, "/_", internal::GetNodeId()));
}

Status SetNodeDeviceForTPUCommunication(DeviceNameUtils::ParsedName device,
                                        const string& target_device_type,
                                        Node* node) {
  TF_RET_CHECK(device.has_type && device.type == DEVICE_TPU_NODE);
  TF_RET_CHECK(device.has_id);
  TF_RET_CHECK(HasNodeAttr(node->def(), kXlaHasHostTransferAttrName));

  // Store the device instance as an attr on the Node.
  TF_RETURN_IF_ERROR(SetDeviceOrdinalAttributeForNode(node, device.id));

  // Place the execute Op on the TPU_SYSTEM device so it can access the cache
  // of compiled protos in the resource manager.
  device.type = target_device_type;
  device.id = 0;

  node->set_assigned_device_name(DeviceNameUtils::ParsedNameToString(device));
  return Status::OK();
}

// Iterate over the nodes in the original graph and find all the TPUReplicate
// nodes, and all the nodes that are part of outside_compilation clusters.
Status FindTaggedNodes(
    Graph* graph, std::vector<Node*>* replicate_nodes,
    std::map<string, DistributedTPURewritePass::OutsideCompilationNodeMap>*
        outside_compilation_nodes,
    std::map<string, std::vector<Node*>>* head_tail_outside_compilation_nodes) {
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == "_TPUReplicate") {
      replicate_nodes->push_back(node);
      const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
      if (cluster_attr == nullptr) {
        return errors::Internal("TPUReplicate node ", node->name(), " has no ",
                                kTPUReplicateAttr, " attr.");
      } else {
        const string& cluster = cluster_attr->s();
        if (cluster.empty()) {
          return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
                                  node->name(), " has no string value.");
        }
        if (outside_compilation_nodes->find(cluster) !=
            outside_compilation_nodes->end()) {
          return errors::Internal(
              "TPUReplicate node ", node->name(), " has ", kTPUReplicateAttr,
              " attr value '", cluster,
              "' which is a duplicate of another TPUReplicate node in the "
              "graph.");
        }
        (*outside_compilation_nodes)[cluster] =
            DistributedTPURewritePass::OutsideCompilationNodeMap();
        (*head_tail_outside_compilation_nodes)[cluster] = std::vector<Node*>();
      }
    }
  }
  for (Node* node : graph->op_nodes()) {
    if (node->type_string() != "_TPUReplicate") {
      const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
      const AttrValue* outside_compilation_attr =
          node->attrs().Find(kOutsideCompilationAttr);
      if (cluster_attr == nullptr) {
        if (outside_compilation_attr != nullptr) {
          return errors::Internal("Node ", node->name(), " has ",
                                  kOutsideCompilationAttr, " attr but no ",
                                  kTPUReplicateAttr, " attr.");
        }
      } else {
        const string& cluster = cluster_attr->s();
        if (cluster.empty()) {
          return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
                                  node->name(), " has no string value.");
        }
        const auto iter = outside_compilation_nodes->find(cluster);
        if (iter == outside_compilation_nodes->end()) {
          return errors::Internal(
              "Attr ", kTPUReplicateAttr, " on node ", node->name(),
              " does not correspond to a TPUReplicate node.");
        }
        if (outside_compilation_attr == nullptr) {
          return errors::Internal("Node ", node->name(), " has ",
                                  kTPUReplicateAttr, " attr but no ",
                                  kOutsideCompilationAttr, " attr.");
        }
        const string& oc_cluster = outside_compilation_attr->s();
        if (oc_cluster.empty()) {
          return errors::Internal("Attr ", kOutsideCompilationAttr, " on node ",
                                  node->name(), " has no string value.");
        }

        // Outside compilation clusters at the head and tail of a TPU
        // computation have already been moved to the host and are already
        // replicated. As such, do not replicate outside compilation nodes
        // that carry a replica id attribute.
        int replica_id;
        if (TryGetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id)) {
          const AttrValue* head_attr =
              node->attrs().Find("_xla_only_arg_or_oc_input");
          const AttrValue* tail_attr =
              node->attrs().Find("_xla_only_ret_or_oc_output");
          if (((head_attr != nullptr) && (head_attr->b())) ||
              ((tail_attr != nullptr) && (tail_attr->b()))) {
            // This is safe as this has the same keys as
            // outside_compilation_nodes which we already know has this key.
            (*head_tail_outside_compilation_nodes)[cluster].push_back(node);
          }
          continue;
        }
        iter->second[oc_cluster].push_back(node);
      }
    }
  }
  return Status::OK();
}

// Helper class to spread TPU computation arguments and return values
// across cores.
// If all shapes are fully defined, balance by their size.
// If some of them are not fully defined, the size of each undefined shape is
// estimated as the average size of the fully defined ones.
// If none are defined, fall back to round-robin.
class TensorDevicePlacer {
 public:
  // Creates a TensorDevicePlacer object to distribute arguments or
  // return values to a set of num_devices devices, where the types and
  // the inferred shapes of the inputs (arguments or return values) are
  // passed in types and shapes.
  TensorDevicePlacer(int64_t num_devices, const DataTypeVector& types,
                     const std::vector<InferredShape>& shapes)
      : index_nodes_(num_devices), sizes_(types.size()) {
    int64_t total_size = 0;
    int64_t num_defined = 0;
    for (int64_t i = 0; i < types.size(); ++i) {
      sizes_[i] = GetInferredShapeSize(shapes[i], types[i]);
      if (sizes_[i] >= 0) {
        total_size += sizes_[i];
        ++num_defined;
      }
    }
    // If a shape is undefined, select a size for it which is the average
    // of the defined shapes. If no shapes are defined, assign 1 so that we
    // get round-robin behavior.
    int64_t undefined_shape_size =
        (num_defined > 0) ? total_size / num_defined : 1;
    for (int64_t i = 0; i < sizes_.size(); ++i) {
      if (sizes_[i] < 0) {
        sizes_[i] = undefined_shape_size;
      }
    }

    for (int64_t i = 0; i < num_devices; ++i) {
      heap_.Push(&index_nodes_[i]);
    }
  }

  // Reports that the argument/return-value at index has been assigned
  // by the user to a given device.
  void ReportDeviceAssigned(int64_t device, int64_t index) {
    if (device >= index_nodes_.size()) {
      LOG(FATAL) << "Sharding assignment is out of bounds. "  // Crash OK
                    "Check that the number of nodes is properly set.";
    }
    DeviceNode* node = &index_nodes_.at(device);
    node->size += sizes_.at(index);
    heap_.Adjust(node);
  }

  // Retrieves the device to which the argument/return-value at index
  // should be assigned.
  int64 RetrieveAssignment(int64_t index) {
    DeviceNode* node = heap_.top();
    int64_t device = node - index_nodes_.data();
    node->size += sizes_.at(index);
    heap_.Adjust(node);
    return device;
  }

 private:
  struct DeviceNode {
    struct Compare {
      // Compare functor to implement a min heap using the ::gtl::IntrusiveHeap
      // infrastructure.
      bool operator()(const DeviceNode* lhs, const DeviceNode* rhs) const {
        return lhs->size < rhs->size;
      }
    };

    IntrusiveHeapLink heap;
    int64 size = 0;
  };

  static int64 GetInferredShapeSize(const InferredShape& ishape,
                                    DataType dtype) {
    return ishape.shape.IsFullyDefined()
               ? ishape.shape.num_elements() * DataTypeSize(dtype)
               : -1;
  }

  std::vector<DeviceNode> index_nodes_;
  IntrusiveHeap<DeviceNode, typename DeviceNode::Compare> heap_;
  std::vector<int64> sizes_;
};
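
// Illustrative sketch (not part of the pass): with two devices and three
// arguments whose inferred sizes are 8, 2, and 4 bytes, RetrieveAssignment()
// greedily picks the currently least-loaded device for each argument, so one
// device ends up with the 8-byte argument and the other with the 2- and 4-byte
// arguments. The `types` and `shapes` vectors below are hypothetical inputs.
//
//   TensorDevicePlacer placer(/*num_devices=*/2, types, shapes);
//   placer.RetrieveAssignment(0);  // An empty device; loads become {8, 0}.
//   placer.RetrieveAssignment(1);  // The other device; loads become {8, 2}.
//   placer.RetrieveAssignment(2);  // The lighter device again (2 + 4 < 8).
//
// ReportDeviceAssigned() records a user-specified placement instead, so that
// subsequent RetrieveAssignment() calls account for that load.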

Status ValidateCoreNumber(int64_t core, int64_t num_cores_per_replica) {
  if (core < 0 || core >= num_cores_per_replica) {
    return tensorflow::errors::InvalidArgument("Invalid core ID: ", core,
                                               ". The valid core IDs are [0..",
                                               num_cores_per_replica, ")");
  }
  return Status::OK();
}

Status FindHostComputeKeyPlaceholderNodes(
    const Graph* graph, const std::vector<Node*>& replicate_nodes,
    std::unordered_map<string, Node*>* host_compute_key_placeholder_map) {
  host_compute_key_placeholder_map->clear();
  for (const auto node : replicate_nodes) {
    (*host_compute_key_placeholder_map)[node->name()] = nullptr;
  }

  for (Node* node : graph->op_nodes()) {
    if (node->type_string() == "Placeholder" &&
        str_util::EndsWith(node->name(), "_key_placeholder")) {
      const AttrValue* call_node_attr =
          node->attrs().Find("_host_compute_call_node");
      if (call_node_attr != nullptr) {
        auto iter = host_compute_key_placeholder_map->find(call_node_attr->s());
        if (iter == host_compute_key_placeholder_map->end()) {
          return errors::InvalidArgument(
              "Node ", node->name(), " has _host_compute_call_node attribute '",
              call_node_attr->s(), "' that doesn't correspond to a call node");
        }
        if (iter->second != nullptr) {
          return errors::InvalidArgument(
              "Key placeholder node ", iter->second->name(), " for call node ",
              call_node_attr->s(), " previously found as ",
              iter->second->name());
        }
        iter->second = node;
      }
    }
  }

  return Status::OK();
}

Status ReplaceCompilationResultNodeWithIdentity(Graph* graph, Node** node) {
  Node* old_node = *node;
  // We want to replace the node with an identity node with the same name.
  const string& node_name = old_node->name();

  // Create identity node.
  TF_ASSIGN_OR_RETURN(
      Node * id_node,
      BuildIdentityNode(graph, node_name, DT_STRING,
                        /*input=*/nullptr, /*requested_device=*/""));

  // No incoming edges are copied as a new one will be added from the compile
  // node to id_node.

  // Copy outgoing edges to the id node.
  std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
                                     old_node->out_edges().end());
  for (const Edge* edge : out_edges) {
    Node* dst = edge->dst();
    int src_output = edge->src_output();
    int dst_input = edge->dst_input();

    if (src_output == Graph::kControlSlot) {
      graph->AddControlEdge(id_node, dst);
    } else {
      graph->AddEdge(id_node, src_output, dst, dst_input);
    }
    graph->RemoveEdge(edge);
  }
  graph->RemoveNode(old_node);

  *node = id_node;
  return Status::OK();
}

Status GetStepMarkerLocation(const Node& replicate_node,
                             xla::DebugOptions::StepMarkerLocation* location) {
  string step_marker_location_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "step_marker_location",
                                 &step_marker_location_attr));
  if (step_marker_location_attr.empty()) {
    *location = xla::DebugOptions::STEP_MARK_AT_ENTRY;
  } else {
    if (!xla::DebugOptions::StepMarkerLocation_Parse(step_marker_location_attr,
                                                     location)) {
      return errors::InvalidArgument("Malformed step_marker_location: ",
                                     step_marker_location_attr);
    }
  }
  return Status::OK();
}

// Extracts a map from dimension index to number of splits for a tiled input,
// from its xla sharding attribute.
Status GetDimensionIndicesAndNumSplitsFromSharding(
    const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
  int64_t tensor_tile_rank = sharding.tile_assignment_dimensions_size();
  if (sharding.replicate_on_last_tile_dim()) {
    tensor_tile_rank--;
  }
  for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) {
    if (sharding.tile_assignment_dimensions(dim_index) > 1) {
      split_dimension_map->emplace(
          dim_index, sharding.tile_assignment_dimensions(dim_index));
    }
  }

  if (split_dimension_map->empty()) {
    return errors::InvalidArgument("Arg has unnecessary tiled sharding: ",
                                   sharding.DebugString());
  }
  return Status::OK();
}
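
// Illustrative sketch of the mapping above (a hypothetical sharding; field
// names are from xla::OpSharding): a [2, 1, 3] tile assignment splits
// dimension 0 two ways and dimension 2 three ways, and the unsplit dimension 1
// is omitted from the resulting map.
//
//   xla::OpSharding sharding;
//   sharding.set_type(xla::OpSharding::OTHER);
//   sharding.add_tile_assignment_dimensions(2);  // dim 0 split 2 ways.
//   sharding.add_tile_assignment_dimensions(1);  // dim 1 not split.
//   sharding.add_tile_assignment_dimensions(3);  // dim 2 split 3 ways.
//   std::map<int, int> splits;
//   Status s = GetDimensionIndicesAndNumSplitsFromSharding(sharding, &splits);
//   // splits == {{0, 2}, {2, 3}}.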

// Updates contents of the function with `function_name` in function library
// definition `flib_def` to `new_graph`. This is required when a graph
// transformation happens inside a function call body.
Status UpdateFunctionLibDefinition(const Graph& new_graph,
                                   const std::string& function_name,
                                   FunctionLibraryDefinition* flib_def) {
  FunctionDef graph_fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(new_graph, function_name, &graph_fdef));
  TF_RETURN_IF_ERROR(flib_def->ReplaceFunction(function_name, graph_fdef));
  return Status::OK();
}

struct NodeOut {
  Node* node;
  int index;
};

struct ShardedInputIndex {
  int replica_id;
  int argument_index;

  bool operator<(const ShardedInputIndex& rhs) const {
    return std::tie(replica_id, argument_index) <
           std::tie(rhs.replica_id, rhs.argument_index);
  }
};

struct ShardedPerHostInputIndex {
  string host_device;
  int argument_index;
  bool operator<(const ShardedPerHostInputIndex& rhs) const {
    return std::tie(host_device, argument_index) <
           std::tie(rhs.host_device, rhs.argument_index);
  }
  bool operator==(const ShardedPerHostInputIndex& rhs) const {
    return (argument_index == rhs.argument_index) &&
           (host_device == rhs.host_device);
  }
};

struct ShardedInputInfo {
  // Split node that would be connected to the tiled input node.
  Node* split_node;
  // List of split nodes and the output index of the split node from which the
  // sharded input will be connected to the TPUExecute node. The inputs are
  // ordered by logical core ids.
  std::vector<NodeOut> sharded_inputs;
};

// Adds a pad node after the split node to the graph for unevenly sharded tiled
// inputs.
// |graph| owns the returned Node* instance.
xla::StatusOr<Node*> CreatePadNode(const int padding, const int num_dims,
                                   const int split_dim, DataType dtype,
                                   Node* control_predecessor, Node* split_node,
                                   const int split_index, Graph* graph) {
  // Add paddings node.
  Status s;
  NodeDef paddings_def;
  paddings_def.set_name(
      graph->NewName(absl::StrCat(split_node->name(), "/paddings")));
  paddings_def.set_op("Const");
  AddNodeAttr("dtype", DT_INT32, &paddings_def);
  paddings_def.set_device(split_node->assigned_device_name());
  TensorProto sizes_tensor_proto;
  sizes_tensor_proto.set_dtype(DT_INT32);
  for (int i = 0; i < num_dims; ++i) {
    sizes_tensor_proto.add_int_val(0);
    if (i == split_dim) {
      sizes_tensor_proto.add_int_val(padding);
    } else {
      sizes_tensor_proto.add_int_val(0);
    }
  }
  TensorShape sizes_shape({num_dims, 2});
  sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", sizes_tensor_proto, &paddings_def);
  Node* paddings_node = graph->AddNode(paddings_def, &s);
  TF_RETURN_IF_ERROR(s);

  // Add Pad node.
  NodeDef pad_def;
  pad_def.set_name(graph->NewName(
      absl::StrCat(split_node->name(), "/pad_shard_", split_index)));
  pad_def.set_op("Pad");
  pad_def.set_device(split_node->assigned_device_name());
  AddNodeAttr("T", dtype, &pad_def);
  AddNodeAttr("Tpaddings", DT_INT32, &pad_def);
  pad_def.add_input(absl::StrCat(split_node->name(), ":", split_index));
  pad_def.add_input(absl::StrCat(paddings_node->name(), ":0"));
  Node* pad_node = graph->AddNode(pad_def, &s);
  pad_node->set_assigned_device_name(split_node->assigned_device_name());
  TF_RETURN_IF_ERROR(s);
  // Add edges for pad node.
  graph->AddEdge(split_node, split_index, pad_node, 0);
  graph->AddEdge(paddings_node, 0, pad_node, 1);
  graph->AddControlEdge(control_predecessor, pad_node);
  return pad_node;
}

// Adds a split node and a split dimension node to the graph for sharding tiled
// inputs.
// |graph| owns the returned Node* instance.
xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim,
                                     const int num_dims, const int64_t padding,
                                     const int orig_src_output, DataType dtype,
                                     absl::string_view name_prefix,
                                     Node* control_predecessor, Node* orig_src,
                                     Graph* graph) {
  const std::string input_assigned_device = orig_src->assigned_device_name();
  Node* to_split_node = orig_src;
  int to_split_index = orig_src_output;
  if (padding > 0) {
    TF_ASSIGN_OR_RETURN(
        Node * pad_node,
        CreatePadNode(padding, num_dims, dim, dtype, control_predecessor,
                      orig_src, orig_src_output, graph));
    to_split_node = pad_node;
    to_split_index = 0;
  }

  // Add a split dimension node.
  NodeDef split_dim_def;
  split_dim_def.set_name(
      graph->NewName(absl::StrCat(name_prefix, "/split_dim")));
  split_dim_def.set_op("Const");
  split_dim_def.set_device(input_assigned_device);
  AddNodeAttr("dtype", DT_INT32, &split_dim_def);
  TensorProto tensor_proto;
  tensor_proto.set_dtype(DT_INT32);
  tensor_proto.add_int_val(dim);
  TensorShape shape({});
  shape.AsProto(tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", tensor_proto, &split_dim_def);
  Status s;
  Node* split_dim_node = graph->AddNode(split_dim_def, &s);
  TF_RETURN_IF_ERROR(s);
  // Add a split node.
  NodeDef split_def;
  split_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/split")));
  split_def.set_op("Split");
  split_def.set_device(input_assigned_device);
  AddNodeAttr("num_split", num_splits, &split_def);
  AddNodeAttr("T", dtype, &split_def);
  split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
  split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index));
  Node* split_node = graph->AddNode(split_def, &s);
  TF_RETURN_IF_ERROR(s);

  split_node->set_assigned_device_name(input_assigned_device);

  // Colocate the newly created split op with the source node of the input to
  // the TPU computation.
  split_node->AddAttr(kColocationAttrName,
                      std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                       orig_src->name())});

  graph->AddEdge(split_dim_node, 0, split_node, 0);
  graph->AddEdge(to_split_node, to_split_index, split_node, 1);

  // Add a control dependency from `control_predecessor` to the newly created
  // constant node. This ensures that newly added split/split dim
  // nodes are placed inside correct while loop frames when the TPUExecute
  // node is inside a host training loop.
  graph->AddControlEdge(control_predecessor, split_dim_node);
  return split_node;
}

int64 GetPadding(const int split_dim, const int num_splits,
                 const PartialTensorShape& partial_tensor_shape) {
  // If the split dimension is not defined, uneven sharding is not supported,
  // so no padding is applied.
  if (partial_tensor_shape.dim_size(split_dim) <= 0) {
    return 0;
  }
  int64_t per_split_size = tensorflow::MathUtil::CeilOfRatio<int64>(
      partial_tensor_shape.dim_size(split_dim), num_splits);
  int64_t total_padding =
      per_split_size * num_splits - partial_tensor_shape.dim_size(split_dim);
  return total_padding;
}
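
// Worked example of the computation above: splitting a dimension of size 10
// into 3 shards gives a per-shard size of ceil(10 / 3) = 4, so the padding is
// 4 * 3 - 10 = 2; the input is padded by 2 along that dimension before being
// split evenly. An unknown (or zero-sized) dimension yields a padding of 0.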

// Creates a set of split nodes that shard a tiled input node in the graph.
xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
    const xla::OpSharding& sharding, int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, int replica_id,
    int orig_src_output, Node* orig_src, Node* control_predecessor,
    Graph* graph,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }
  // Maps each input dimension to the number of splits with which that
  // dimension is sharded.
  std::map<int, int> split_dimension_map;
  TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
      sharding, &split_dimension_map));
  TF_RET_CHECK(!split_dimension_map.empty())
      << "Unnecessary sharding attribute found.";

  // For a v1 while loop, nodes inside the loop body must either
  // 1) Have data edges from the while loop input node.
  // or
  // 2) Have a direct control dependency from the while loop input control
  //    node.
  //
  // As such, if we are adding a Split node inside a while loop body, we must
  // manually add a control dependency from a node inside the while loop
  // (i.e. `control_predecessor`) to constant nodes without data in-edges to
  // make sure that the added split nodes have the correct frame name.
  // Otherwise, the placer will complain when `BuildControlFlow()` is invoked.

  auto sharding_it = split_dimension_map.begin();
  std::queue<Node*> split_nodes_for_dimension;
  absl::flat_hash_map<Node*, int> node_to_split_dim;
  int split_dimension = sharding_it->first;
  int num_split = sharding_it->second;

  // Creates a tree of split nodes for sharding tiled inputs. Split nodes
  // are created such that input data is sharded in row major order.
  // Split nodes at the ith depth from the original input node represent nodes
  // that split the input data at the ith dimension.
  TF_ASSIGN_OR_RETURN(
      Node * root_split_node,
      CreateSplitNode(
          num_split, split_dimension, partial_tensor_shape.dims(),
          GetPadding(split_dimension, num_split, partial_tensor_shape),
          orig_src_output, dtype,
          absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
                       split_dimension),
          control_predecessor, orig_src, graph));
  sharding_it++;

  split_nodes_for_dimension.emplace(root_split_node);
  node_to_split_dim[root_split_node] = split_dimension;

  while (sharding_it != split_dimension_map.end()) {
    split_dimension = sharding_it->first;
    num_split = sharding_it->second;
    int num_split_nodes_in_dimension = split_nodes_for_dimension.size();
    for (int i = 0; i < num_split_nodes_in_dimension; ++i) {
      Node* input_split_node = split_nodes_for_dimension.front();
      split_nodes_for_dimension.pop();
      for (int src_output_index = 0;
           src_output_index < input_split_node->num_outputs();
           ++src_output_index) {
        TF_ASSIGN_OR_RETURN(
            Node * split_node,
            CreateSplitNode(
                num_split, split_dimension, partial_tensor_shape.dims(),
                GetPadding(split_dimension, num_split, partial_tensor_shape),
                src_output_index, dtype,
                absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
                             split_dimension),
                control_predecessor, input_split_node, graph));
        split_nodes_for_dimension.emplace(split_node);
        node_to_split_dim[split_node] = split_dimension;
      }
    }
    sharding_it++;
  }

  // `split_nodes_for_dimension` now includes the final split nodes
  // from which sharded data will be fed into TPUExecute nodes -- sorted in
  // row major order.
  std::vector<NodeOut> sharded_inputs_list(
      sharding.tile_assignment_devices_size());
  int64_t next_core_tile_index = 0;
  while (!split_nodes_for_dimension.empty()) {
    Node* split_node = split_nodes_for_dimension.front();
    split_nodes_for_dimension.pop();
    int num_splits;
    TF_RETURN_IF_ERROR(
        GetNodeAttr(split_node->def(), "num_split", &num_splits));
    for (int out_index = 0; out_index < num_splits; ++out_index) {
      int64_t repeat_count =
          sharding.replicate_on_last_tile_dim()
              ? *sharding.tile_assignment_dimensions().rbegin()
              : 1;
      for (int64_t i = 0; i < repeat_count; ++i) {
        int64_t next_core =
            sharding.tile_assignment_devices(next_core_tile_index++);
        sharded_inputs_list[next_core] = NodeOut{split_node, out_index};
      }
    }
  }

  ShardedInputInfo sharded_input_info{root_split_node,
                                      std::move(sharded_inputs_list)};
  (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
  return sharded_input_info;
}
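
// Illustrative sketch of the split tree built above: for a [2, 2] tiled
// sharding over a rank-2 input, the root Split divides dimension 0 into two
// halves and each half is split again along dimension 1. The four leaf outputs
// are assigned to cores following sharding.tile_assignment_devices() in
// row-major order (here assuming tile_assignment_devices is {0, 1, 2, 3}):
//
//                  input
//                    |
//              Split(dim=0)
//               /        \
//       Split(dim=1)   Split(dim=1)
//         /     \        /     \
//     core 0  core 1  core 2  core 3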

// Creates an XlaSplitND node to shard an input, and adds that new node to the
// graph.
StatusOr<Node*> CreateXlaSplitOp(absl::string_view node_name,
                                 const bool is_resource, const NodeOut& input,
                                 const PartialTensorShape& partial_tensor_shape,
                                 const std::vector<Node*>& control_inputs,
                                 const std::vector<Node*>& control_outputs,
                                 const DataType dtype, const int num_shards,
                                 const xla::OpSharding& sharding,
                                 Graph* graph) {
  const std::string& input_assigned_device = input.node->assigned_device_name();
  NodeDef xla_split_def;
  xla_split_def.set_name(graph->NewName(node_name));
  xla_split_def.set_op(is_resource ? "ReadVariableXlaSplitND" : "XlaSplitND");
  xla_split_def.set_device(input_assigned_device);
  AddNodeAttr("T", dtype, &xla_split_def);
  AddNodeAttr("N", num_shards, &xla_split_def);
  const std::vector<int64> num_splits(
      sharding.tile_assignment_dimensions().begin(),
      sharding.replicate_on_last_tile_dim()
          ? std::prev(sharding.tile_assignment_dimensions().end())
          : sharding.tile_assignment_dimensions().end());
  AddNodeAttr("num_splits", num_splits, &xla_split_def);
  const int rank = sharding.replicate_on_last_tile_dim()
                       ? sharding.tile_assignment_dimensions_size() - 1
                       : sharding.tile_assignment_dimensions_size();
  std::vector<int32> paddings;
  paddings.reserve(rank);
  for (int dim = 0; dim < rank; ++dim) {
    paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
                                  partial_tensor_shape));
  }
  AddNodeAttr("paddings", paddings, &xla_split_def);

  if (!is_resource) {
    AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &xla_split_def);
    AddNodeAttr(kColocationAttrName,
                std::vector<string>{
                    absl::StrCat(kColocationGroupPrefix, input.node->name())},
                &xla_split_def);
  }

  Status s;
  Node* xla_split = graph->AddNode(xla_split_def, &s);
  TF_RETURN_IF_ERROR(s);
  if (is_resource) {
    xla_split->set_requested_device(input.node->requested_device());
  }
  xla_split->set_assigned_device_name(input_assigned_device);
  graph->AddEdge(input.node, input.index, xla_split, 0);
  for (Node* control_input : control_inputs) {
    graph->AddControlEdge(control_input, xla_split);
  }
  for (Node* control_output : control_outputs) {
    graph->AddControlEdge(xla_split, control_output);
  }
  return xla_split;
}

// Creates a sharded tensor list for all input shards of an input with sharding.
xla::StatusOr<std::vector<NodeOut>> ShardInputWithXlaSplitOp(
    absl::string_view node_name, const bool is_resource, const NodeOut& input,
    const PartialTensorShape& partial_tensor_shape,
    const std::vector<Node*>& control_inputs,
    const std::vector<Node*>& control_outputs, const DataType dtype,
    const xla::OpSharding& sharding, Graph* graph) {
  const int repeat = sharding.replicate_on_last_tile_dim()
                         ? *sharding.tile_assignment_dimensions().rbegin()
                         : 1;
  const int num_shards = sharding.tile_assignment_devices_size() / repeat;

  TF_ASSIGN_OR_RETURN(
      Node * xla_split,
      CreateXlaSplitOp(node_name, is_resource, input, partial_tensor_shape,
                       control_inputs, control_outputs, dtype, num_shards,
                       sharding, graph));

  std::vector<NodeOut> sharded_inputs_list(
      sharding.tile_assignment_devices_size());

  for (int i = 0; i < num_shards; ++i) {
    for (int j = 0; j < repeat; ++j) {
      const int index = i * repeat + j;
      const int core = sharding.tile_assignment_devices(index);
      sharded_inputs_list[core] = {xla_split, i};
    }
  }

  return sharded_inputs_list;
}
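
// Illustrative example of the shard-to-core mapping above (hypothetical
// sharding): with tile_assignment_dimensions = [2, 2], replicate_on_last_tile_dim
// set, and tile_assignment_devices = {0, 1, 2, 3}, there are
// num_shards = 4 / 2 = 2 distinct shards and each one feeds two cores:
// shard 0 -> cores {0, 1}, shard 1 -> cores {2, 3}.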

// Creates an XlaSplitND op to shard a per-replica arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
    const xla::OpSharding& sharding, const int replica_id,
    const int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
    const int orig_src_output, Graph* graph,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<NodeOut> sharded_inputs_list,
      ShardInputWithXlaSplitOp(
          absl::StrCat(orig_src->name(), "/replica_", replica_id, "_split"),
          /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
          partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
          dtype, sharding, graph));

  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
  (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
  return sharded_input_info;
}

// Creates an XlaSplitND op to shard a distributed arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForDistributedArg(
    const xla::OpSharding& sharding, const int num_replicas,
    const int replica_id, const int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
    const int orig_src_output, Graph* graph,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<NodeOut> sharded_inputs_list,
      ShardInputWithXlaSplitOp(
          absl::StrCat(orig_src->name(), "/distributed_split"),
          /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
          partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
          dtype, sharding, graph));

  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
  for (int replica = 0; replica < num_replicas; ++replica) {
    (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
        sharded_input_info;
  }
  return sharded_input_info;
}

// Creates a ReadVariableXlaSplitND op to shard a variable arg.
xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForVariableArg(
    const xla::OpSharding& sharding, const int num_replicas,
    const int replica_id, const int orig_arg_num, DataType dtype,
    const PartialTensorShape& partial_tensor_shape, Node* orig_src,
    const int orig_src_output, Graph* graph,
    std::vector<Node*>* to_be_removed_nodes,
    std::map<ShardedInputIndex, ShardedInputInfo>*
        arg_index_to_sharded_input_map) {
  ShardedInputIndex input_index{replica_id, orig_arg_num};
  auto iter = arg_index_to_sharded_input_map->find(input_index);
  if (iter != arg_index_to_sharded_input_map->end()) {
    return iter->second;
  }

  DCHECK_EQ(orig_src->type_string(), "ReadVariableOp");
  std::vector<Node*> control_outputs;
  std::vector<const Edge*> edges_to_remove;
  for (const Edge* edge : orig_src->out_edges()) {
    if (edge->IsControlEdge()) {
      control_outputs.push_back(edge->dst());
    }
    edges_to_remove.push_back(edge);
  }

  to_be_removed_nodes->push_back(orig_src);

  const Edge* resource = nullptr;
  TF_RETURN_IF_ERROR(orig_src->input_edge(0, &resource));

  std::vector<Node*> control_inputs;
  for (const Edge* edge : orig_src->in_edges()) {
    if (edge->IsControlEdge()) {
      control_inputs.push_back(edge->src());
    }
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<NodeOut> sharded_inputs_list,
      ShardInputWithXlaSplitOp(
          absl::StrCat(resource->src()->name(), "/read_variable_split"),
          /*is_resource=*/true,
          /*input=*/{resource->src(), resource->src_output()},
          partial_tensor_shape, control_inputs, control_outputs, dtype,
          sharding, graph));

  for (const Edge* edge : edges_to_remove) {
    graph->RemoveControlEdge(edge);
  }

  DCHECK(orig_src->out_edges().empty());

  ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
  for (int replica = 0; replica < num_replicas; ++replica) {
    ShardedInputIndex idx{replica, orig_arg_num};
    // Refrain from overwriting, if dummy inputs were already placed instead.
    arg_index_to_sharded_input_map->insert({idx, sharded_input_info});
  }
  return sharded_input_info;
}

// Creates a concat node to be used for aggregating sharded retvals across
// logical cores.
xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype,
                                      absl::string_view name_prefix,
                                      const std::vector<NodeOut>& inputs,
                                      Graph* graph, absl::string_view device) {
  // Add a Concat dim node.
  NodeDef concat_dim_def;
  concat_dim_def.set_name(
      graph->NewName(absl::StrCat(name_prefix, "/concat_dim")));
  concat_dim_def.set_op("Const");
  AddNodeAttr("dtype", DT_INT32, &concat_dim_def);
  concat_dim_def.set_device(std::string(device));
  TensorProto tensor_proto;
  tensor_proto.set_dtype(DT_INT32);
  tensor_proto.add_int_val(dim);
  TensorShape shape({});
  shape.AsProto(tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", tensor_proto, &concat_dim_def);
  Status s;
  Node* concat_dim_node = graph->AddNode(concat_dim_def, &s);
  TF_RETURN_IF_ERROR(s);

  // Add a Concat node.
  NodeDef concat_def;
  concat_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/concat")));
  concat_def.set_op("Concat");
  AddNodeAttr("N", num_splits, &concat_def);
  AddNodeAttr("T", dtype, &concat_def);
  concat_def.add_input(absl::StrCat(concat_dim_node->name(), ":0"));
  concat_def.set_device(std::string(device));
  for (const auto& i : inputs) {
    concat_def.add_input(absl::StrCat(i.node->name(), ":", i.index));
  }
  Node* concat_node = graph->AddNode(concat_def, &s);
  TF_RETURN_IF_ERROR(s);

  graph->AddEdge(concat_dim_node, 0, concat_node, 0);

  // The 0th input to the concat node is the concat dim node, so we start from
  // the 1st input and add all input edges.
  int dst_input = 1;
  for (const auto& i : inputs) {
    graph->AddEdge(i.node, i.index, concat_node, dst_input);
    ++dst_input;
  }
  return concat_node;
}

// Adds a slice node after the concat node to the graph for unevenly sharded
// tiled inputs.
xla::StatusOr<Node*> CreateSliceNode(DataType dtype,
                                     const PartialTensorShape& shape,
                                     Node* concat_node,
                                     const int concat_out_index, Graph* graph,
                                     absl::string_view device) {
  Status s;
  // Add begin node for concat.
  NodeDef begin_def;
  begin_def.set_name(
      graph->NewName(absl::StrCat(concat_node->name(), "/slice_begin")));
  begin_def.set_op("Const");
  AddNodeAttr("dtype", DT_INT32, &begin_def);
  begin_def.set_device(std::string(device));
  TensorProto begin_tensor_proto;
  begin_tensor_proto.set_dtype(DT_INT32);
  for (int i = 0; i < shape.dims(); ++i) {
    begin_tensor_proto.add_int_val(0);
  }
  TensorShape begin_shape({shape.dims()});
  begin_shape.AsProto(begin_tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", begin_tensor_proto, &begin_def);
  Node* begin_node = graph->AddNode(begin_def, &s);
  TF_RETURN_IF_ERROR(s);

  // Add size node.
  NodeDef size_def;
  size_def.set_name(
      graph->NewName(absl::StrCat(concat_node->name(), "/slice_size")));
  size_def.set_op("Const");
  AddNodeAttr("dtype", DT_INT32, &size_def);
  size_def.set_device(std::string(device));
  TensorProto sizes_tensor_proto;
  sizes_tensor_proto.set_dtype(DT_INT32);
  for (int i = 0; i < shape.dims(); ++i) {
    sizes_tensor_proto.add_int_val(shape.dim_size(i));
  }
  TensorShape sizes_shape({shape.dims()});
  sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
  AddNodeAttr("value", sizes_tensor_proto, &size_def);
  Node* size_node = graph->AddNode(size_def, &s);
  TF_RETURN_IF_ERROR(s);

  // Add Slice node.
  NodeDef slice_def;
  slice_def.set_name(
      graph->NewName(absl::StrCat(concat_node->name(), "/slice")));
  slice_def.set_op("Slice");
  slice_def.set_device(std::string(device));
  AddNodeAttr("T", dtype, &slice_def);
  AddNodeAttr("Index", DT_INT32, &slice_def);
  slice_def.add_input(absl::StrCat(concat_node->name(), ":", concat_out_index));
  slice_def.add_input(absl::StrCat(begin_node->name(), ":0"));
  slice_def.add_input(absl::StrCat(size_node->name(), ":0"));
  Node* slice_node = graph->AddNode(slice_def, &s);
  TF_RETURN_IF_ERROR(s);
  // Add edges for slice node.
  graph->AddEdge(concat_node, concat_out_index, slice_node, 0);
  graph->AddEdge(begin_node, 0, slice_node, 1);
  graph->AddEdge(size_node, 0, slice_node, 2);
  return slice_node;
}

// Creates a set of Concat nodes that aggregates sharded outputs from TPUExecute
// nodes into a single output. Sharded outputs are concatenated in row major
// order. That is, tiled output along the 0th dimension will be concatenated
// last.
xla::StatusOr<Node*> CreateConcatNodesForRetval(
    const xla::OpSharding& sharding, DataType dtype,
    const PartialTensorShape& inferred_shape, int replica_id,
    const std::vector<NodeOut>& orig_inputs, Graph* graph,
    absl::string_view device) {
  std::map<int, int> split_dimension_map;
  TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
      sharding, &split_dimension_map));
  std::vector<NodeOut> inputs_to_sharded_retval = orig_inputs;
  bool has_paddings = false;

  for (auto it = split_dimension_map.rbegin(); it != split_dimension_map.rend();
       it++) {
    auto dim = it->first;
    auto num_splits = it->second;

    int num_concat_nodes = inputs_to_sharded_retval.size() / num_splits;
    int input_index_to_concat_node = 0;

    std::vector<NodeOut> new_concat_nodes;
    for (int i = 0; i < num_concat_nodes; ++i) {
      auto concat_input_it =
          inputs_to_sharded_retval.begin() + input_index_to_concat_node;
      std::vector<NodeOut> inputs(concat_input_it,
                                  concat_input_it + num_splits);
      input_index_to_concat_node += num_splits;

      TF_ASSIGN_OR_RETURN(
          Node * concat_node,
          CreateConcatNode(
              dim, num_splits, dtype,
              absl::StrCat("sharded_output/replica_", replica_id, "_dim_", dim),
              inputs, graph, device));
      int64_t paddings = GetPadding(dim, num_splits, inferred_shape);
      has_paddings |= paddings > 0;
      new_concat_nodes.emplace_back(NodeOut{concat_node, 0});
    }
    inputs_to_sharded_retval = new_concat_nodes;
  }

  TF_RET_CHECK(inputs_to_sharded_retval.size() == 1);
  if (has_paddings) {
    TF_ASSIGN_OR_RETURN(Node * slice_node,
                        CreateSliceNode(dtype, inferred_shape,
                                        inputs_to_sharded_retval.at(0).node,
                                        /*concat_out_index*/ 0, graph, device));
    return slice_node;
  }
  return inputs_to_sharded_retval.at(0).node;
}
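
// Illustrative example of the concatenation order above: for a retval tiled as
// [2, 2] over a rank-2 tensor, the four sharded outputs (in row-major core
// order) are first concatenated pairwise along dimension 1, and the two
// intermediate results are then concatenated along dimension 0. If any
// dimension was padded for uneven sharding, a final Slice trims the result
// back to the inferred shape.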
1292
CreateXlaConcatNode(const xla::OpSharding & sharding,const int replica_id,DataType dtype,const PartialTensorShape & partial_tensor_shape,const std::vector<NodeOut> & orig_inputs,absl::string_view device,Graph * graph)1293 xla::StatusOr<Node*> CreateXlaConcatNode(
1294 const xla::OpSharding& sharding, const int replica_id, DataType dtype,
1295 const PartialTensorShape& partial_tensor_shape,
1296 const std::vector<NodeOut>& orig_inputs, absl::string_view device,
1297 Graph* graph) {
1298 NodeDef xla_concat_def;
1299 xla_concat_def.set_name(graph->NewName(
1300 absl::StrCat("sharded_output/replica_", replica_id, "_concat")));
1301 xla_concat_def.set_op("XlaConcatND");
1302 xla_concat_def.set_device(std::string(device));
1303 AddNodeAttr("T", dtype, &xla_concat_def);
1304 AddNodeAttr("N", static_cast<int64>(orig_inputs.size()), &xla_concat_def);
1305 const std::vector<int64> num_concats(
1306 sharding.tile_assignment_dimensions().begin(),
1307 sharding.replicate_on_last_tile_dim()
1308 ? std::prev(sharding.tile_assignment_dimensions().end())
1309 : sharding.tile_assignment_dimensions().end());
1310 AddNodeAttr("num_concats", num_concats, &xla_concat_def);
1311 const int rank = sharding.replicate_on_last_tile_dim()
1312 ? sharding.tile_assignment_dimensions_size() - 1
1313 : sharding.tile_assignment_dimensions_size();
1314 std::vector<int32> paddings;
1315 paddings.reserve(rank);
1316 for (int dim = 0; dim < rank; ++dim) {
1317 paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
1318 partial_tensor_shape));
1319 }
1320 AddNodeAttr("paddings", paddings, &xla_concat_def);
1321
1322 Status s;
1323 Node* xla_concat = graph->AddNode(xla_concat_def, &s);
1324 TF_RETURN_IF_ERROR(s);
1325 for (int i = 0, e = orig_inputs.size(); i < e; ++i) {
1326 const NodeOut& input = orig_inputs[i];
1327 graph->AddEdge(input.node, input.index, xla_concat, i);
1328 }
1329 return xla_concat;
1330 }
1331
1332 // Sets the padding ops to the same devices as the original inputs. If the
1333 // original inputs are on TPUs, the padding ops will be placed on TPUs and XLA
1334 // on-demand mode will be triggered, so we don't need to copy the data back to
1335 // the host to do the padding.
1336 Status SetPaddingNodesDevices(Graph* graph) {
1337 for (Node* n : graph->op_nodes()) {
1338 bool tpu_padding_attr;
1339 if (n->type_string() == "Pad" &&
1340 GetNodeAttr(n->attrs(), kPostDeviceRewriteAttr, &tpu_padding_attr)
1341 .ok()) {
1342 Node* unpadded_input;
1343 TF_RETURN_IF_ERROR(n->input_node(0, &unpadded_input));
1344
1345 const string& requested_device = unpadded_input->requested_device();
1346 const string& assigned_device = unpadded_input->assigned_device_name();
1347 if (!requested_device.empty() || !assigned_device.empty()) {
1348         // The output nodes of the original unpadded inputs include the padded
1349         // inputs and the real shapes of the inputs; we assign those to the same
1350         // device as the original inputs.
1351 for (Node* out : unpadded_input->out_nodes()) {
1352 if (GetNodeAttr(out->attrs(), kPostDeviceRewriteAttr,
1353 &tpu_padding_attr)
1354 .ok()) {
1355 out->set_requested_device(requested_device);
1356 out->set_assigned_device_name(assigned_device);
1357 }
1358 }
1359         // There might be a tf.shape node added before TPUCompileOp; we need to
1360         // set its device as well.
1361         for (Node* out : n->out_nodes()) {
1362           if (out->type_string() == "Shape") {
1363 out->set_requested_device(requested_device);
1364 out->set_assigned_device_name(assigned_device);
1365 }
1366 }
1367 }
1368 }
1369 }
1370 return Status::OK();
1371 }
1372
1373 const string& AssignedOrRequestedDevice(const Node* node) {
1374 if (!node->assigned_device_name().empty()) {
1375 return node->assigned_device_name();
1376 }
1377 return node->requested_device();
1378 }
1379
1380 bool IsTpuDevice(const string& device_string) {
1381 DeviceNameUtils::ParsedName device;
1382 return DeviceNameUtils::ParseFullName(device_string, &device) &&
1383 device.type == DEVICE_TPU_NODE;
1384 }
1385
1386 // Returns the set of ops that can be placed on TPU. There is no strict rule
1387 // of thumb to decide which ops should be in the list, but empirically they are
1388 // mostly dummy ops such as Identity-like ops or control-flow-related ops.
1389 // However, people can also add other ops like Pad to allow data to stay on TPU.
1390 const absl::flat_hash_set<std::string>& PlaceOnTPUOpList() {
1391 static const auto place_on_tpu_ops = new absl::flat_hash_set<std::string>(
1392 {"Identity", "IdentityN", "Enter", "Exit", "Switch", "Merge",
1393 "NextIteration", "Shape", "_Retval"});
1394 return *place_on_tpu_ops;
1395 }
1396
1397 // If an op satisfies the following conditions, it will be placed on the same
1398 // TPU device as its inputs:
1399 // (1) The op can be placed on TPU (it is in the PlaceOnTPUOpList).
1400 // (2) The op itself has no requested or assigned devices.
1401 // (3) All the data inputs of this op are placed on the same TPU device.
1402 //     There are exceptions, e.g. the LoopCond input of a Switch node may be
1403 //     placed on the CPU as it is just a boolean.
1404 //
1405 // Returns true if the node device has been changed, otherwise returns false.
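// For illustration (hypothetical device names): an Identity node with no
// device annotation whose only data input is assigned to
// "/job:worker/replica:0/task:0/device:TPU:3" is moved to that same device.
// A Switch node whose LoopCond input lives on the CPU is still eligible,
// because that input edge is deliberately skipped below.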
1406 bool PlaceOpsOnTPU(Node* node) {
1407 if (!AssignedOrRequestedDevice(node).empty() ||
1408 !PlaceOnTPUOpList().contains(node->type_string())) {
1409 return false;
1410 }
1411 string src_tpu_device = "";
1412 Node* src_node;
1413 for (const Edge* e : node->in_edges()) {
1414 if (e->IsControlEdge()) {
1415 continue;
1416 }
1417 Node* src = e->src();
1418 const string& src_device = AssignedOrRequestedDevice(src);
1419
1420     // Make an exception: we don't force some inputs to be placed on TPUs.
1421 if (node->IsSwitch() && src->IsLoopCond()) {
1422 continue;
1423 }
1424
1425 if (!IsTpuDevice(src_device) ||
1426 (!src_tpu_device.empty() && src_device != src_tpu_device)) {
1427 return false;
1428 }
1429 if (src_tpu_device.empty()) {
1430 src_tpu_device = src_device;
1431 src_node = src;
1432 }
1433 }
1434 node->set_assigned_device_name(src_node->assigned_device_name());
1435 node->set_requested_device(src_node->requested_device());
1436 return true;
1437 }
1438
1439 xla::OpMetadata CreateOpMetadataFromNode(const Node& node) {
1440 xla::OpMetadata metadata;
1441 metadata.set_op_type(node.type_string());
1442 metadata.set_op_name(node.name());
1443 return metadata;
1444 }
1445
1446 // Helper struct holding node (nullable) and associated sharding.
1447 struct NodeAndSharding {
1448   explicit NodeAndSharding(const Node* node, const xla::OpSharding& sharding)
1449 : node(node), sharding(sharding) {}
1450
1451 const Node* node;
1452 xla::OpSharding sharding;
1453 };
1454
1455 // Validate sharding configuration derived from XlaSharding attribute.
1456 // Infer the core id from the OpSharding, if necessary.
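// For illustration (hypothetical shardings): if two MAXIMAL annotations
// request cores 2 and 1, the lower core id (1) wins and becomes the inferred
// core. If two non-MAXIMAL annotations disagree (ignoring metadata), the
// result falls back to maximal sharding on core 0.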
1457 Status ParseAndValidateSharding(const NodeAndSharding& node_and_sharding,
1458 const int num_cores_per_replica,
1459 int64* inferred_core_id,
1460 absl::optional<NodeAndSharding>* result) {
1461 if (node_and_sharding.sharding.type() == xla::OpSharding::MAXIMAL) {
1462 int64_t core_annotation =
1463 node_and_sharding.sharding.tile_assignment_devices(0);
1464 TF_RETURN_IF_ERROR(
1465 ValidateCoreNumber(core_annotation, num_cores_per_replica));
1466 if (*inferred_core_id == -1 || *inferred_core_id > core_annotation) {
1467 *inferred_core_id = core_annotation;
1468 result->emplace(node_and_sharding);
1469 }
1470 } else {
1471 if (node_and_sharding.sharding.type() == xla::OpSharding::OTHER) {
1472 for (int64_t core :
1473 node_and_sharding.sharding.tile_assignment_devices()) {
1474 TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
1475 }
1476 }
1477
1478 if (!result->has_value()) {
1479 *result = node_and_sharding;
1480 } else {
1481 std::string result_value_serialized;
1482 xla::OpSharding result_value = result->value().sharding;
1483 result_value.clear_metadata();
1484 SerializeToStringDeterministic(result_value, &result_value_serialized);
1485
1486 std::string sharding_serialized;
1487 xla::OpSharding sharding = node_and_sharding.sharding;
1488 sharding.clear_metadata();
1489 SerializeToStringDeterministic(sharding, &sharding_serialized);
1490
1491 // TODO(lyandy): Choose the more granular sharding instead of always
1492 // assigning to core 0 (maximal).
1493 if (result_value_serialized != sharding_serialized) {
1494 // We see different shardings, assign to core 0.
1495 auto core_zero_sharding = xla::sharding_builder::AssignDevice(0);
1496 DCHECK_NE(node_and_sharding.node, nullptr);
1497 *core_zero_sharding.add_metadata() =
1498 CreateOpMetadataFromNode(*node_and_sharding.node);
1499 result->emplace(
1500 NodeAndSharding(node_and_sharding.node, core_zero_sharding));
1501 }
1502 }
1503 }
1504 return Status::OK();
1505 }
1506
1507 // As an XlaSharding node may be followed by a Cast op or an Identity op,
1508 // recursively walk the graph and aggregate nodes connected to
1509 // |input_node| or to a Cast/Identity op following |input_node|.
1510 void FindNodesMaybeContainingShardingInfo(const Node& input_node,
1511 std::vector<const Node*>* nodes) {
1512 if (input_node.IsIdentity() || input_node.type_string() == "Cast") {
1513 for (const Node* connected_node : input_node.out_nodes())
1514 FindNodesMaybeContainingShardingInfo(*connected_node, nodes);
1515 }
1516 nodes->emplace_back(&input_node);
1517 }
1518
1519 // Parse sharding configuration from |node| or its adjacent nodes.
1520 // XlaSharding configuration may be derived from
1521 // a) Connected Identity op node.
1522 // b) Connected Cast op node.
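// For example (hypothetical graph): if |node| is an Identity op whose output
// feeds an XlaSharding op, the walk below visits the Identity op's consumers,
// finds the XlaSharding op, and returns its parsed OpSharding.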
1523 xla::StatusOr<absl::optional<NodeAndSharding>>
1524 ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
1525 const Node& node) {
1526   // If |node| has a `device` attribute or is an XlaSharding op,
1527   // return the parsed OpSharding.
1528 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
1529 ParseShardingFromDevice(node, num_cores_per_replica,
1530 /*add_metadata=*/true));
1531 if (sharding.has_value()) {
1532 return absl::optional<NodeAndSharding>(NodeAndSharding(&node, *sharding));
1533 }
1534
1535   // An XlaSharding op may be followed by an Identity op, or by an Identity
1536   // op and a Cast op.
1537 std::vector<const Node*> potential_nodes_with_input_sharding;
1538 FindNodesMaybeContainingShardingInfo(node,
1539 &potential_nodes_with_input_sharding);
1540 for (const Node* maybe_node_with_sharding_info :
1541 potential_nodes_with_input_sharding) {
1542 if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue;
1543
1544 TF_ASSIGN_OR_RETURN(
1545 absl::optional<xla::OpSharding> sharding_config,
1546 ParseShardingFromDevice(*maybe_node_with_sharding_info,
1547 num_cores_per_replica, /*add_metadata=*/true));
1548 if (sharding_config.has_value()) {
1549 return absl::optional<NodeAndSharding>(
1550 NodeAndSharding(maybe_node_with_sharding_info, *sharding_config));
1551 }
1552 }
1553 return absl::optional<NodeAndSharding>();
1554 }
1555
1556 // Walk the graph from an argument node to find OpSharding configuration
1557 // from its neighbor nodes. Sharding configuration may be inferred from
1558 // 1) Parsing XlaSharding attribute from neighboring node.
1559 // 2) If argument node is a resource, then by parsing adjacent nodes
1560 // of the connected ReadVariable op.
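// For example (hypothetical graph): for a resource argument wired as
//   _Arg(resource) -> ReadVariableOp -> XlaSharding,
// the lookup on the ReadVariableOp itself finds no sharding, so the code
// below inspects the ReadVariableOp's consumers and picks up the sharding
// annotation from the XlaSharding op.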
1561 Status ParseAndValidateShardingFromNeighbors(
1562 const int num_cores_per_replica, const std::string& arg_node_name,
1563 const Node& neighbor_node, int64* inferred_core_id, bool* is_fast_mem,
1564 absl::optional<NodeAndSharding>* result) {
1565 if (neighbor_node.attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1566 *is_fast_mem = true;
1567 VLOG(2) << "place " << neighbor_node.name() << " on fast memory because "
1568 << arg_node_name << " has " << TPU_FAST_MEM_ATTR << " attribute";
1569 }
1570
1571   // XlaSharding information may be encoded on a node directly connected to
1572   // the argument node.
1573 TF_ASSIGN_OR_RETURN(
1574 absl::optional<NodeAndSharding> node_and_sharding,
1575 ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
1576 if (node_and_sharding.has_value()) {
1577 TF_RETURN_IF_ERROR(ParseAndValidateSharding(
1578 *node_and_sharding, num_cores_per_replica, inferred_core_id, result));
1579 return Status::OK();
1580 }
1581
1582   // When we use a variable in a TPU computation, we always have a
1583   // ReadVariableOp followed by an XlaSharding op. As such, parse
1584   // the users of the ReadVariableOp for potential sharding configuration.
1585 if (neighbor_node.type_string() == "ReadVariableOp") {
1586 for (const Edge* e : neighbor_node.out_edges()) {
1587 if (e->IsControlEdge()) continue;
1588
1589 if (e->dst()->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1590 *is_fast_mem = true;
1591 VLOG(2) << "place " << arg_node_name << " on fast memory because "
1592                 << e->dst()->name() << " has " << TPU_FAST_MEM_ATTR
1593                 << " attribute";
1593 }
1594
1595 TF_ASSIGN_OR_RETURN(
1596 absl::optional<NodeAndSharding> node_and_sharding,
1597 ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
1598 if (node_and_sharding.has_value()) {
1599 TF_RETURN_IF_ERROR(ParseAndValidateSharding(*node_and_sharding,
1600 num_cores_per_replica,
1601 inferred_core_id, result));
1602 return Status::OK();
1603 }
1604 }
1605 }
1606 return Status::OK();
1607 }
1608
1609 } // namespace
1610
1611 // Inputs:
1612 // replication_spec_string: the device to which the TPUReplicate node was
1613 // assigned.
1614 // device_set: the set of TF devices.
1615 // Outputs:
1616 // tpu_compilation_device: the name of the TPU compilation device.
1617 // num_tpus_per_task: the number of TPUs in each task. Verifies that all tasks
1618 // have the same number of TPU devices.
1619 // tpu_devices: the TPU devices, indexed by [task][device].
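// For example (hypothetical device names): if the TPUReplicate node was
// assigned to "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", the
// compilation device becomes "/job:worker/replica:0/task:0/device:CPU:0",
// since DEVICE_TPU_SYSTEM is replaced by DEVICE_CPU below.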
1620 static Status GetTPUDeviceNames(
1621 const string& replication_spec_string, const DeviceSet& device_set,
1622 string* tpu_compilation_device, int* num_tpus_per_task,
1623 std::vector<std::vector<Device*>>* tpu_devices) {
1624 // TODO(b/110910013) GetSystemDevice parses the spec and returns the name of
1625 // the tpu_system device, which we replace by the cpu device. We do this
1626 // replacement because we want to place the TPUCompileOp (and the compile
1627 // assert op) explicitly on cpu devices on the same job as the tpu_system
1628 // device.
1629 DeviceNameUtils::ParsedName replication_spec;
1630 Device* replication_device;
1631 TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice(
1632 replication_spec_string, device_set, &replication_spec,
1633 &replication_device));
1634 *tpu_compilation_device =
1635 str_util::StringReplace(replication_device->name(), DEVICE_TPU_SYSTEM,
1636 DEVICE_CPU, /*replace_all=*/true);
1637
1638 // Finds the set of TPU devices attached to the tasks in the job.
1639 TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetTPUDevices(
1640 replication_spec, device_set, num_tpus_per_task, tpu_devices));
1641
1642 return Status::OK();
1643 }
1644
1645 // Parses the topology attribute of TPUReplicate, and populates *topology with
1646 // a physical mesh coordinate to (task, device) mapping.
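// The serialized TopologyProto lists device_coordinates as groups of
// kTPUTopologyRank (= 4) integers, (x, y, z, core), ordered by task and then
// by device within each task. For example (hypothetical 1-task, 2-device
// slice): [0, 0, 0, 0, 0, 0, 0, 1] maps chip (0, 0, 0) core 0 to
// (task 0, device 0) and chip (0, 0, 0) core 1 to (task 0, device 1).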
1647 static Status ParseTopologyAttr(const string& topology_attr,
1648 const tpu::TpuTopologyExternal& tpu_topology,
1649 int num_tasks, int num_tpus_per_task,
1650 xla::Array4D<std::pair<int, int>>* topology) {
1651 static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1652 tpu::TopologyProto proto;
1653 proto.ParseFromString(topology_attr);
1654 if (proto.mesh_shape_size() != kTPUTopologyRank) {
1655 return errors::InvalidArgument("TPU topology must be rank ",
1656 kTPUTopologyRank);
1657 }
1658 if (proto.num_tasks() != num_tasks) {
1659 return errors::InvalidArgument("Mismatched number of TPU tasks");
1660 }
1661 if (proto.num_tpu_devices_per_task() != num_tpus_per_task) {
1662 return errors::InvalidArgument("Mismatched number of TPUs per task (",
1663 proto.num_tpu_devices_per_task(),
1664 " != ", num_tpus_per_task, ").");
1665 }
1666 if (proto.device_coordinates_size() !=
1667 num_tasks * num_tpus_per_task * kTPUTopologyRank) {
1668 return errors::InvalidArgument(
1669 "device coordinates should be ", num_tasks, "x", num_tpus_per_task, "x",
1670 kTPUTopologyRank, "; got ", proto.device_coordinates_size());
1671 }
1672
1673 int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1674 *topology = xla::Array4D<std::pair<int, int>>(
1675 tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1676 tpu_topology.chip_bounds().z, devices_per_chip, {-1, -1});
1677 int pos = 0;
1678 for (int task = 0; task < num_tasks; ++task) {
1679 for (int device = 0; device < num_tpus_per_task; ++device) {
1680 int32_t x = proto.device_coordinates(pos++);
1681 int32_t y = proto.device_coordinates(pos++);
1682 int32_t z = proto.device_coordinates(pos++);
1683 int32_t core = proto.device_coordinates(pos++);
1684
1685 if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1686 core >= devices_per_chip) {
1687 return errors::InvalidArgument(
1688 "Mesh coordinates (", x, ",", y, ",", z, ",", core,
1689 ") are not valid for the current TPU topology");
1690 }
1691 if ((*topology)(x, y, z, core).first != -1) {
1692 return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1693 ",", z, ",", core, ") in TPU topology");
1694 }
1695 (*topology)(x, y, z, core) = {task, device};
1696 }
1697 }
1698 return Status::OK();
1699 }
1700
1701 // Parses the value of the device_assignment attribute to TPUReplicate.
1702 // Populates *device_assignment; *device_assignment must be a 2D array with
1703 // shape (num_replicas, num_cores_per_replica).
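// For example (hypothetical values): with num_replicas = 2 and
// num_cores_per_replica = 1, device_assignment_attr must contain
// 2 * 1 * 4 = 8 integers, read as (x, y, z, core) per (replica, logical
// core); [0, 0, 0, 0, 1, 0, 0, 0] places replica 0 on chip (0, 0, 0) core 0
// and replica 1 on chip (1, 0, 0) core 0.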
1704 static Status ParseDeviceAssignmentAttr(
1705 absl::Span<const int> device_assignment_attr,
1706 const tpu::TpuTopologyExternal& tpu_topology, int num_replicas,
1707 int num_cores_per_replica,
1708 xla::Array2D<tpu::TpuCoreLocationExternal>* device_assignment) {
1709 static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1710
1711 const int64_t device_assignment_attr_size =
1712 num_replicas * num_cores_per_replica * kTPUTopologyRank;
1713 if (device_assignment_attr.size() != device_assignment_attr_size) {
1714 return errors::InvalidArgument(
1715 "Length of device_assignment attribute must be equal to num_replicas (",
1716 num_replicas, ") * num_cores_per_replica (", num_cores_per_replica,
1717 ") * ", kTPUTopologyRank, " got ", device_assignment_attr.size());
1718 }
1719 for (int core : device_assignment_attr) {
1720 if (core < 0 || core >= kTPUMaxTopologySize) {
1721 return errors::InvalidArgument(
1722 "Invalid core number in device assignment: ", core);
1723 }
1724 }
1725
1726 *device_assignment = xla::Array2D<tpu::TpuCoreLocationExternal>(
1727 num_replicas, num_cores_per_replica);
1728 int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1729 xla::Array4D<int> replica_assignment(
1730 tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1731 tpu_topology.chip_bounds().z, devices_per_chip, -1);
1732 int pos = 0;
1733 for (int replica = 0; replica < num_replicas; ++replica) {
1734 for (int logical_core = 0; logical_core < num_cores_per_replica;
1735 ++logical_core) {
1736 int32_t x = device_assignment_attr[pos++];
1737 int32_t y = device_assignment_attr[pos++];
1738 int32_t z = device_assignment_attr[pos++];
1739 int32_t core = device_assignment_attr[pos++];
1740
1741 if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1742 core >= devices_per_chip) {
1743 return errors::InvalidArgument(
1744 "Mesh coordinates (", x, ",", y, ",", core,
1745 ") are not valid for the current TPU topology");
1746 }
1747 tpu::TpuCoreLocationExternal core_location =
1748 tpu_topology.Core(kTensorCore, x, y, z, core);
1749
1750 if (replica_assignment(x, y, z, core) != -1) {
1751 return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1752 ",", z, ",", core,
1753 ") in TPU device assignment");
1754 }
1755 replica_assignment(x, y, z, core) = replica;
1756 (*device_assignment)(replica, logical_core) = core_location;
1757 }
1758 }
1759 return Status::OK();
1760 }
1761
1762 // Builds TensorFlow device assignments for the special case of a single core
1763 // computation that is replicated to every core in the mesh.
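// For example (hypothetical slice): with 2 tasks of 4 TPUs each and
// num_replicas = 8, replica i is assigned to task i / 4, device i % 4, so
// replicas 0..3 land on task 0 and replicas 4..7 on task 1.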
1764 // LINT.IfChange
1765 static Status BuildFullMeshDeviceAssignment(
1766 int num_replicas, const std::vector<std::vector<Device*>>& tpu_devices,
1767 int num_tasks, int num_tpus_per_task,
1768 std::vector<std::vector<string>>* tf_device_assignment,
1769 std::vector<int>* devices_to_lock) {
1770 // Assign TensorFlow devices to replicas arbitrarily.
1771 for (int i = 0; i < num_replicas; ++i) {
1772 int task = i / num_tpus_per_task;
1773 int device = i % num_tpus_per_task;
1774 TF_RET_CHECK(task >= 0 && task < num_tasks);
1775 TF_RET_CHECK(device >= 0 && device < num_tpus_per_task);
1776
1777 // We don't actually know which TF device corresponds to which physical
1778 // device, but it doesn't matter—they're all identical.
1779 (*tf_device_assignment)[i] = {tpu_devices[task][device]->name()};
1780 devices_to_lock->push_back(i);
1781 }
1782 return Status::OK();
1783 }
1784 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
1785
1786 // Builds TensorFlow device assignments for a replicated computation and
1787 // converts device_assignment into xla_device_assignment.
1788 static Status BuildGeneralDeviceAssignment(
1789 int num_replicas, int num_cores_per_replica,
1790 const std::vector<std::vector<Device*>>& tpu_devices,
1791 const xla::Array2D<tpu::TpuCoreLocationExternal>& device_assignment,
1792 const xla::Array4D<std::pair<int, int>>& topology,
1793 std::vector<std::vector<string>>* tf_device_assignment,
1794 std::vector<int>* devices_to_lock,
1795 std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1796 // Assign TensorFlow devices to each computation's replicas according to
1797 // device_assignment and 'topology'.
1798 *xla_device_assignment = absl::make_unique<xla::DeviceAssignment>(
1799 num_replicas, num_cores_per_replica);
1800 for (int replica = 0; replica < num_replicas; ++replica) {
1801 for (int computation = 0; computation < num_cores_per_replica;
1802 ++computation) {
1803 const tpu::TpuCoreLocationExternal& core_location =
1804 device_assignment(replica, computation);
1805
1806 int task;
1807 int device;
1808 std::tie(task, device) =
1809 topology(core_location.chip_coordinates().x,
1810 core_location.chip_coordinates().y,
1811 core_location.chip_coordinates().z, core_location.index());
1812
1813 CHECK_LT(computation, num_cores_per_replica);
1814 (**xla_device_assignment)(replica, computation) = core_location.Id();
1815
1816 // The communication pattern between replicas will be determined later by
1817 // BuildAllReduceRing.
1818 TF_RET_CHECK(task >= 0 && task < tpu_devices.size());
1819 TF_RET_CHECK(device >= 0 && device < tpu_devices[task].size());
1820 (*tf_device_assignment)[replica].push_back(
1821 tpu_devices[task][device]->name());
1822 devices_to_lock->push_back((task * tpu_devices[task].size()) + device);
1823 }
1824 }
1825 return Status::OK();
1826 }
1827
1828 /*static*/ Status DistributedTPURewritePass::BuildDeviceAssignment(
1829 const tpu::TpuTopologyExternal& tpu_topology, int num_tpus_per_task,
1830 const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
1831 int num_cores_per_replica, const string& topology_attr,
1832 absl::Span<const int> device_assignment_attr,
1833 std::vector<std::vector<string>>* tf_device_assignment,
1834 std::vector<int>* devices_to_lock,
1835 std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1836 const int num_tasks = tpu_devices.size();
1837 const int num_tpu_devices = num_tasks * num_tpus_per_task;
1838 VLOG(2) << "num_tasks=" << num_tasks
1839 << " num_tpus_per_task=" << num_tpus_per_task;
1840
1841 // Checks num_replicas is sane first to avoid integer overflow.
1842 if (num_replicas > num_tpu_devices) {
1844     return errors::InvalidArgument("Requested num_replicas=", num_replicas,
1845                                    " but there are only ", num_tpu_devices,
1846                                    " cores in the TPU topology.");
1852 }
1853 if (num_replicas * num_cores_per_replica > num_tpu_devices) {
1854 return errors::InvalidArgument(
1855 "Requested num_replicas=", num_replicas, " with ",
1856 num_cores_per_replica, " cores per replica, but there are only ",
1857 num_tpu_devices, " cores in the TPU topology");
1858 }
1859
1860 tf_device_assignment->clear();
1861 tf_device_assignment->resize(num_replicas);
1862
1863 devices_to_lock->clear();
1864 devices_to_lock->reserve(num_replicas * num_cores_per_replica);
1865
1866 // Special case: we allow the user to omit the topology and device assignment
1867 // information in two cases:
1868 // * there is only one replica and one core per replica. In this case, we
1869 // don't need to know topology information because we don't communicate with
1870 // other cores.
1871 // * the number of replicas is equal to the number of cores in the slice. In
1872 // this case, all cores are running the same program so we don't need to
1873 // know which is which.
1874 if (topology_attr.empty()) {
1875 // LINT.IfChange
1876 if (num_replicas != 1 && num_replicas != num_tpu_devices) {
1877 return errors::InvalidArgument(
1878 "TPUReplicate asked to create ", num_replicas,
1879 " replicas, but the number of cores in the TPU topology is ",
1880 num_tpu_devices,
1881 " and no TPU device assignment was supplied. "
1882 "A TPU device assignment is required if the number of replicas is "
1883 "not 1 or the number of cores in the topology (",
1884 num_tpu_devices, ")");
1885 }
1886
1887 if (num_cores_per_replica != 1) {
1888 return errors::InvalidArgument(
1889 "A TPU topology must be provided if num_cores_per_replica != 1");
1890 }
1891
1892 if (!device_assignment_attr.empty()) {
1893 return errors::InvalidArgument(
1894 "A TPU topology must be provided if device_assignment_attr is "
1895 "non-empty");
1896 }
1897
1898 // If there is only one replica, assign the Tensorflow computation to task 0
1899 // device 0, and leave the XLA device assignment empty. We don't know which
1900 // core this is in the TPU topology, but it doesn't matter—we don't need to
1901 // communicate with any other cores.
1902 if (num_replicas == 1) {
1903 (*tf_device_assignment)[0] = {tpu_devices[0][0]->name()};
1904 devices_to_lock->push_back(0);
1905 return Status::OK();
1906 }
1907
1908 // Otherwise, num_replicas is equal to the number of cores, and we build a
1909 // device assignment that covers the entire mesh. We do not need to know
1910 // the topology to do so because all cores are identical.
1911 return BuildFullMeshDeviceAssignment(num_replicas, tpu_devices, num_tasks,
1912 num_tpus_per_task,
1913 tf_device_assignment, devices_to_lock);
1914 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
1915 }
1916
1917 // Array that maps mesh coordinates to {TF task, TF TPU device #} pairs.
1918 xla::Array4D<std::pair<int, int>> topology;
1919 TF_RETURN_IF_ERROR(ParseTopologyAttr(topology_attr, tpu_topology, num_tasks,
1920 num_tpus_per_task, &topology));
1921
1922 // Array that maps logical (replica, core) pairs to physical mesh coordinates.
1923 xla::Array2D<tpu::TpuCoreLocationExternal> device_assignment;
1924 TF_RETURN_IF_ERROR(ParseDeviceAssignmentAttr(
1925 device_assignment_attr, tpu_topology, num_replicas, num_cores_per_replica,
1926 &device_assignment));
1927
1928 return BuildGeneralDeviceAssignment(
1929 num_replicas, num_cores_per_replica, tpu_devices, device_assignment,
1930 topology, tf_device_assignment, devices_to_lock, xla_device_assignment);
1931 }
1932
1933 Status DistributedTPURewritePass::GetComputationForTPUReplicateOp(
1934 const NameAttrList& function, FunctionLibraryRuntime* flr,
1935 Graph* computation, DataTypeVector* arg_types,
1936 DataTypeVector* retval_types) {
1937 FunctionLibraryRuntime::Handle handle;
1938
1939 TF_RETURN_IF_ERROR(
1940 flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
1941
1942 const FunctionBody* fbody = flr->GetFunctionBody(handle);
1943
1944 CopyGraph(*fbody->graph, computation);
1945 *arg_types = fbody->arg_types;
1946 *retval_types = fbody->ret_types;
1947 return Status::OK();
1948 }
1949
1950 // Grab the InferredShape corresponding to an edge input.
1951 static Status GetEdgeShape(const GraphShapeInfo& shape_info, const Edge& edge,
1952 const InferredShape** info) {
1953 auto it = shape_info.find(edge.src()->name());
1954 if (it == shape_info.end()) {
1955 return errors::InvalidArgument(
1956 "Input to replicated TPU computation is missing InferredShape: ",
1957 edge.src()->name());
1958 }
1959 TF_RET_CHECK(it->second.size() > edge.src_output());
1960 *info = &it->second[edge.src_output()];
1961 return Status::OK();
1962 }
1963
1964 Status DistributedTPURewritePass::GetArgAndRetvalShapes(
1965 const GraphShapeInfo& shape_info, const Node& node,
1966 const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
1967 std::vector<InferredShape>* retval_shapes) {
1968 std::vector<const Edge*> input_edges;
1969 TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
1970
1971   // If any replica's arg shape is unknown, we will mark the computation's arg
1972   // shape as being unknown. If the shapes differ, the TPUExecute op will raise
1973   // a runtime error.
1974 std::vector<bool> any_replica_shape_unknown(
1975 params_info.NumInputsToEachReplica());
1976 arg_shapes->clear();
1977 arg_shapes->resize(params_info.NumInputsToEachReplica());
1978 TF_RET_CHECK(input_edges.size() == params_info.NumInputsFromHost());
1979 // Determines the shapes of the per-replica arguments and checks that all
1980 // replicas have identical shapes.
1981 int64_t edge_pos = 0;
1982 auto check_shape = [&](int input_index) -> Status {
1983 const InferredShape* info;
1984 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1985 ++edge_pos;
1986
1987 if ((info->handle_type == DT_INVALID && !info->shape.IsFullyDefined()) ||
1988 (info->handle_type != DT_INVALID &&
1989 !info->handle_shape.IsFullyDefined())) {
1990 any_replica_shape_unknown[input_index] = true;
1991 }
1992 xla::StatusOr<InferredShape> status =
1993 MergeInferredShapes((*arg_shapes)[input_index], *info);
1994 if (!status.ok()) {
1995 return errors::InvalidArgument(
1996 "Mismatched shapes for input ", input_index, ": ",
1997 (*arg_shapes)[input_index].shape.DebugString(), " vs. ",
1998 info->shape.DebugString());
1999 }
2000 (*arg_shapes)[input_index] = status.ValueOrDie();
2001 return Status::OK();
2002 };
2003
2004 for (int64_t i = 0; i < params_info.NumReplicas(); ++i) {
2005 for (int64_t j = 0; j < params_info.NumPerReplicaArgs(); ++j) {
2006 TF_RETURN_IF_ERROR(check_shape(j));
2007 }
2008 }
2009
2010 for (int64_t i = 0; i < params_info.NumDistributedArgs(); ++i) {
2011 TF_RETURN_IF_ERROR(check_shape(params_info.NumPerReplicaArgs() + i));
2012 }
2013
2014 for (int64_t i = 0;
2015 i < params_info.NumPerReplicaArgs() + params_info.NumDistributedArgs();
2016 ++i) {
2017 if (any_replica_shape_unknown[i]) {
2018 (*arg_shapes)[i].shape = PartialTensorShape();
2019 (*arg_shapes)[i].handle_shape = PartialTensorShape();
2020 }
2021 }
2022
2023 // Determines the shape of the broadcast arguments.
2024 for (int64_t i = 0; i < params_info.NumBroadcastArgs(); ++i) {
2025 TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
2026 const InferredShape* info;
2027 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
2028 (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
2029 params_info.NumDistributedArgs()]
2030 .shape = info->shape;
2031 ++edge_pos;
2032 }
2033
2034 // Determines the handle shape and handle type of the resource variable
2035 // arguments.
2036 for (int64_t i = 0; i < params_info.NumVariables(); ++i) {
2037 TF_RET_CHECK(node.input_type(edge_pos) == DT_RESOURCE);
2038 const InferredShape* info;
2039 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
2040 InferredShape& arg_shape =
2041 (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
2042 params_info.NumDistributedArgs() +
2043 params_info.NumBroadcastArgs()];
2044 arg_shape.shape = TensorShape(); // Variables are always scalars.
2045 arg_shape.handle_shape = info->handle_shape;
2046 arg_shape.handle_type = info->handle_type;
2047 TF_RET_CHECK(arg_shape.handle_type != DT_INVALID)
2048 << " input edge: " << input_edges[edge_pos]->DebugString();
2049 ++edge_pos;
2050 }
2051
2052 // Determines the shape of the guaranteed constants.
2053 // TODO(vinuraja): Can be removed because they are not required for any
2054 // calculations. Leaving them here for symmetry with other structures like
2055 // arg_types, arg_sharding, etc.
2056 for (int64_t i = 0; i < params_info.NumGuaranteedConstants(); ++i) {
2057 TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
2058 const InferredShape* info;
2059 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
2060 (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
2061 params_info.NumDistributedArgs() +
2062 params_info.NumBroadcastArgs() + params_info.NumVariables()]
2063 .shape = info->shape;
2064 ++edge_pos;
2065 }
2066
2067 // Extract the return value shapes.
2068 auto it = shape_info.find(node.name());
2069 retval_shapes->clear();
2070 if (it != shape_info.end()) {
2071 TF_RET_CHECK(it->second.size() >= node.num_outputs());
2072 retval_shapes->resize(node.num_outputs());
2073 for (int i = 0; i < node.num_outputs(); ++i) {
2074 (*retval_shapes)[i].shape = it->second[i].shape;
2075 }
2076 } else if (node.num_outputs() > 0) {
2077 return errors::InvalidArgument(
2078 "Replicated TPU computation is missing InferredShape: ",
2079 FormatNodeForError(node));
2080 }
2081 return Status::OK();
2082 }
2083
2084 // Verifies that all nodes have legal sharding.
2085 static Status ValidateCoreNumbers(const Graph& graph,
2086 int num_cores_per_replica) {
2087 for (Node* n : graph.nodes()) {
2088 TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
2089 ParseShardingFromDevice(*n, num_cores_per_replica,
2090 /*add_metadata=*/true));
2091 }
2092 return Status::OK();
2093 }
2094
2095 static Status InferXlaShardingFromNeighbors(
2096 const Node& n, int num_cores_per_replica, FunctionLibraryRuntime* flr,
2097 CachedFunctionHandles* cached_function_handles,
2098 absl::optional<NodeAndSharding>* output_node_and_sharding,
2099 bool* is_fast_mem) {
2100 int64_t core = -1;
2101 absl::optional<NodeAndSharding> result;
2102 // We assume the variable has been allocated on fast memory if any consuming
2103   // op has the TPU_FAST_MEM_ATTR attribute. This is a protocol between the
2104   // runtime and the compiler.
2105 *is_fast_mem = false;
2106 for (const Edge* edge : n.out_edges()) {
2107 if (edge->IsControlEdge()) continue;
2108
2109 TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
2110 num_cores_per_replica, n.name(), *edge->dst(), &core, is_fast_mem,
2111 &result));
2112
2113 if (!flr) continue;
2114
2115 // The nodes deciding this arg's device assignment might be in
2116 // FunctionDef. Instantiate FunctionDefs associated with this node
2117 // and check nodes using this arg.
2118 std::function<Status(const Edge* call_edge)> parse_sharding_from_function =
2119 [&](const Edge* call_edge) {
2120 auto associated_functions = GetAssociatedFunctions(
2121 *call_edge->dst(), flr->GetFunctionLibraryDefinition());
2122 for (auto& associated_function : associated_functions) {
2123 FunctionLibraryRuntime::Handle handle;
2124 TF_RETURN_IF_ERROR(cached_function_handles->GetOrInstantiate(
2125 associated_function.func_name(),
2126 AttrSlice(&associated_function.attrs()), &handle));
2127 const FunctionBody* body = flr->GetFunctionBody(handle);
2128 Graph* g = body->graph;
2129
2130 for (Node* body_node : g->nodes()) {
2131 if (!body_node->IsArg()) continue;
2132
2133 int index;
2134 TF_RETURN_IF_ERROR(
2135 GetNodeAttr(body_node->attrs(), "index", &index));
2136 if (index != call_edge->dst_input()) continue;
2137
2138 for (const Edge* out_edge : body_node->out_edges()) {
2139 if (out_edge->IsControlEdge()) continue;
2140
2141 TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
2142 num_cores_per_replica, n.name(), *out_edge->dst(), &core,
2143 is_fast_mem, &result));
2144
2145 TF_RETURN_IF_ERROR(parse_sharding_from_function(out_edge));
2146 }
2147 }
2148 }
2149 return Status::OK();
2150 };
2151 TF_RETURN_IF_ERROR(parse_sharding_from_function(edge));
2152 }
2153 *output_node_and_sharding = result;
2154 return Status::OK();
2155 }
2156
2157 bool UseSpmdForXlaPartitioning(const Node* replicate_node) {
2158 bool spmd_attr;
2159 if (!replicate_node ||
2160 !TryGetNodeAttr(replicate_node->attrs(), "use_spmd_for_xla_partitioning",
2161 &spmd_attr)) {
2162 spmd_attr = false;
2163 }
2164 return spmd_attr;
2165 }
2166
2167 std::string FormatNodeAndShardingMsg(
2168 const absl::optional<NodeAndSharding>& node_and_sharding) {
2169 DCHECK(node_and_sharding.has_value());
2170
2171 xla::OpSharding sharding_no_metadata = node_and_sharding->sharding;
2172 sharding_no_metadata.clear_metadata();
2173 std::string escaped_sharding_str =
2174 absl::CEscape(sharding_no_metadata.SerializeAsString());
2175 if (node_and_sharding->node == nullptr) {
2176 return absl::StrCat(" via default sharding '", escaped_sharding_str, "'");
2177 }
2178
2179 return absl::StrCat(" via node ", node_and_sharding->node->DebugString(),
2180 " sharding '", escaped_sharding_str, "'");
2181 }
2182
2183 Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
2184 int num_cores_per_replica, const ParameterInfo& params_info,
2185 const DataTypeVector& arg_types,
2186 const std::vector<InferredShape>& arg_shapes,
2187 const DataTypeVector& retval_types,
2188 const std::vector<InferredShape>& retval_shapes, const Graph& graph,
2189 const Node* replicate_node, FunctionLibraryRuntime* flr,
2190 bool allow_parameter_replication_for_spmd,
2191 std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem,
2192 std::vector<xla::OpSharding>* retval_sharding,
2193 std::vector<std::string>* arg_names) {
2194 // Builds vectors of the argument and return nodes.
2195 std::vector<Node*> args(arg_types.size());
2196 std::vector<Node*> retvals(retval_types.size());
2197 absl::flat_hash_map<int, Node*> partitioned_output_nodes;
2198 for (Node* node : graph.op_nodes()) {
2199 if (node->IsArg()) {
2200 int index;
2201 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
2202 TF_RET_CHECK(index >= 0 && index < args.size());
2203 args[index] = node;
2204 } else if (node->IsRetval()) {
2205 int index;
2206 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
2207 TF_RET_CHECK(index >= 0 && index < retvals.size());
2208 retvals[index] = node;
2209 }
2210 }
2211 for (const Edge* edge : replicate_node->out_edges()) {
2212 int num_partitioned_outputs = 0;
2213 for (const Edge* out_edge : edge->dst()->out_edges()) {
2214 if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
2215 partitioned_output_nodes[edge->src_output()] = out_edge->dst();
2216 num_partitioned_outputs++;
2217 }
2218 }
2219 if (num_partitioned_outputs > 1) {
2220 return errors::InvalidArgument(
2221           "More than one TPUPartitionedOutput per replicated output.");
2222 }
2223 }
2224
2225 // Verifies there are no missing arguments/return values.
2226 for (int i = 0; i < args.size(); ++i) {
2227 if (args[i] == nullptr) {
2228 return errors::Internal("Missing function argument: ", i);
2229 }
2230 }
2231 for (int i = 0; i < retvals.size(); ++i) {
2232 if (retvals[i] == nullptr) {
2233 return errors::Internal("Missing function return value: ", i);
2234 }
2235 }
2236
2237 // Assigns a core to each _Arg. Chooses the lowest-numbered core that
2238 // consumes the argument. We choose the lowest-numbered core so the
2239 // assignment is deterministic.
2240 TensorDevicePlacer args_device_selector(num_cores_per_replica, arg_types,
2241 arg_shapes);
2242 arg_sharding->resize(args.size());
2243 arg_names->resize(args.size());
2244 arg_fast_mem->resize(args.size());
2245 CachedFunctionHandles cached_function_handles(flr);
2246 const bool use_spmd = (UseSpmdForXlaPartitioning(replicate_node) ||
2247 replicate_inputs_outputs_by_default_for_xla_spmd_) &&
2248 allow_parameter_replication_for_spmd;
2249
2250   // Offset _TPUReplicate non-per-replica argument indices by
2251   // (num_replicas - 1) * num_per_replica_args, as _TPUReplicate nodes are
2252   // constructed with all per-replica args across all replicas while the
2253   // encapsulated function only has one replica's per-replica args. Per-replica
2254   // args are ordered by replica first, so the index here does not require an
2255   // offset, and the first replica's input nodes are sufficient for determining
2256   // argument sharding.
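  // For example (hypothetical sizes): with NumReplicas() = 2 and
  // NumPerReplicaArgs() = 3, index_offset = (2 - 1) * 3 = 3, so a variable
  // argument at encapsulated index 4 corresponds to _TPUReplicate input
  // index 4 + 3 = 7.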
2257 const int index_offset =
2258 (params_info.NumReplicas() - 1) * params_info.NumPerReplicaArgs();
2259 for (int i = 0; i < args.size(); ++i) {
2260 const Node* n = args[i];
2261 absl::optional<int64> assigned_core;
2262 absl::optional<NodeAndSharding> node_and_sharding;
2263 bool is_fast_mem;
2264 TF_RETURN_IF_ERROR(InferXlaShardingFromNeighbors(
2265 *n, num_cores_per_replica, flr, &cached_function_handles,
2266 &node_and_sharding, &is_fast_mem));
2267
2268 const bool is_per_replica_arg = params_info.IsPerReplicaArg(i);
2269 if (is_per_replica_arg || params_info.IsDistributedArg(i)) {
2270 Node* input_node;
2271 TF_RETURN_IF_ERROR(replicate_node->input_node(
2272 i + (is_per_replica_arg ? 0 : index_offset), &input_node));
2273 if (input_node->type_string() == kTPUPartitionedInput) {
2274 TF_ASSIGN_OR_RETURN(
2275 absl::optional<xla::OpSharding> parsed_sharding,
2276 GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2277 if (!parsed_sharding.has_value())
2278 return errors::InvalidArgument("Missing _XlaSharding attr from: ",
2279 input_node->DebugString());
2280 node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2281 VLOG(1) << "Arg " << i << " parsed sharding information from "
2282 << input_node->DebugString() << " : "
2283 << parsed_sharding->DebugString();
2284 }
2285 }
2286
2287 if (params_info.IsVariableArg(i)) {
2288 Node* input_node;
2289 TF_RETURN_IF_ERROR(
2290 replicate_node->input_node(i + index_offset, &input_node));
2291 if (input_node->type_string() == kVarHandleOp) {
2292 TF_ASSIGN_OR_RETURN(
2293 absl::optional<xla::OpSharding> parsed_sharding,
2294 GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2295 if (parsed_sharding.has_value()) {
2296 node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2297 VLOG(1) << "Arg " << i << " parsed sharding information from "
2298 << input_node->DebugString() << " : "
2299 << parsed_sharding->DebugString();
2300 }
2301 }
2302 }
2303
2304 if (node_and_sharding.has_value() && enable_automatic_model_parallelism_) {
2305 return tensorflow::errors::InvalidArgument(
2306 "Specifying manual sharding is not allowed when automatic "
2307 "model parallelism is enabled.",
2308 node_and_sharding->sharding.DebugString());
2309 }
2310
2311 if (!node_and_sharding.has_value()) {
2312 if (use_spmd &&
2313 (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2314 ((params_info.IsPerReplicaArg(i) ||
2315 params_info.IsDistributedArg(i)) &&
2316 arg_types[i] != DT_RESOURCE))) {
2317 // Use replication for host variables or non-variable per-replica
2318 // inputs.
2319 node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2320 xla::sharding_builder::Replicate());
2321 } else {
2322 // TODO(dlibenzi): Distributing variables to cores other than 0 makes
2323 // learning/brain/research/babelfish/trainer:trainer_tpu_test fail.
2324 // For now distribute only per replica arguments, unless
2325 // tf_jf_distribute_vars is set, to allow debugging the issue.
2326 if (((params_info.IsPerReplicaArg(i) ||
2327 params_info.IsDistributedArg(i)) &&
2328 arg_types[i] != DT_RESOURCE) ||
2329 (distribute_vars_ && params_info.IsVariableArg(i))) {
2330 assigned_core = args_device_selector.RetrieveAssignment(i);
2331 } else {
2332 assigned_core = 0;
2333 }
2334 node_and_sharding = NodeAndSharding(
2335 /*node=*/nullptr,
2336 xla::sharding_builder::AssignDevice(*assigned_core));
2337 }
2338 *node_and_sharding->sharding.add_metadata() =
2339 CreateOpMetadataFromNode(*replicate_node);
2340 } else if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2341 assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
2342 } else if (node_and_sharding->sharding.type() !=
2343 xla::OpSharding::REPLICATED &&
2344 node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2345 return tensorflow::errors::InvalidArgument(
2346 "Unsupported argument sharding (for arg ", n->DebugString(),
2347 "): ", node_and_sharding->sharding.DebugString());
2348 }
2349 if (assigned_core.has_value()) {
2350 args_device_selector.ReportDeviceAssigned(*assigned_core, i);
2351 VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2352 << ") to core " << *assigned_core
2353 << FormatNodeAndShardingMsg(node_and_sharding);
2354 args[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2355 } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2356 for (int64_t core :
2357 node_and_sharding->sharding.tile_assignment_devices()) {
2358 TF_RET_CHECK(core >= 0 && core < num_cores_per_replica)
2359 << "core " << core << " should be between [0, "
2360 << num_cores_per_replica << "). sharding is "
2361 << node_and_sharding->sharding.DebugString();
2362 args_device_selector.ReportDeviceAssigned(core, i);
2363 }
2364 VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2365 << ") with tiled sharding to cores "
2366 << absl::StrJoin(
2367 node_and_sharding->sharding.tile_assignment_devices(), ",")
2368 << " " << FormatNodeAndShardingMsg(node_and_sharding);
2369 } else {
2370 DCHECK_EQ(node_and_sharding->sharding.type(),
2371 xla::OpSharding::REPLICATED);
2372 for (int64_t core = 0; core < num_cores_per_replica; ++core) {
2373 args_device_selector.ReportDeviceAssigned(core, i);
2374 }
2375 VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2376 << ") to all cores"
2377 << FormatNodeAndShardingMsg(node_and_sharding);
2378 }
2379 (*arg_sharding)[i] = node_and_sharding->sharding;
2380 (*arg_fast_mem)[i] = is_fast_mem;
2381 (*arg_names)[i] = n->name();
2382 if (is_fast_mem) {
2383 VLOG(3) << "Add " << TPU_FAST_MEM_ATTR << " attribute to "
2384 << args[i]->name();
2385 }
2386 args[i]->AddAttr(kShardingAttribute,
2387 node_and_sharding->sharding.SerializeAsString());
2388 }
2389 TF_RETURN_IF_ERROR(cached_function_handles.ReleaseAllHandles());
2390
2391 // Assigns each _Retval node to the core that produces its value.
2392 TensorDevicePlacer retvals_device_selector(num_cores_per_replica,
2393 retval_types, retval_shapes);
2394 retval_sharding->resize(retvals.size());
2395 for (int i = 0; i < retvals.size(); ++i) {
2396 const Edge* edge;
2397 TF_RETURN_IF_ERROR(retvals[i]->input_edge(0, &edge));
2398
2399 TF_ASSIGN_OR_RETURN(
2400 absl::optional<xla::OpSharding> edge_sharding,
2401 ParseShardingFromEdgeSource(*edge, num_cores_per_replica,
2402 /*add_metadata=*/true));
2403
2404 absl::optional<NodeAndSharding> node_and_sharding;
2405 if (edge_sharding.has_value()) {
2406 node_and_sharding.emplace(NodeAndSharding(edge->src(), *edge_sharding));
2407 }
2408
2409 if (partitioned_output_nodes.contains(i)) {
2410 Node* output_node = partitioned_output_nodes[i];
2411 TF_ASSIGN_OR_RETURN(
2412 absl::optional<xla::OpSharding> parsed_sharding,
2413 GetShardingFromNodeDef(output_node->def(), /*add_metadata=*/true));
2414 if (parsed_sharding.has_value()) {
2415 node_and_sharding = NodeAndSharding(output_node, *parsed_sharding);
2416 VLOG(1) << "Retval " << i << " parsed sharding information from "
2417 << output_node->DebugString() << " : "
2418 << parsed_sharding->DebugString();
2419 }
2420 }
2421 absl::optional<int64> assigned_core;
2422 if (node_and_sharding.has_value()) {
2423 if (enable_automatic_model_parallelism_) {
2424 return tensorflow::errors::InvalidArgument(
2425 "Specifying manual sharding is not allowed when automatic "
2426 "model parallelism is enabled.",
2427 node_and_sharding->sharding.DebugString());
2428 }
2429
2430 if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2431 assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
2432 TF_RETURN_IF_ERROR(
2433 ValidateCoreNumber(*assigned_core, num_cores_per_replica));
2434 } else if (node_and_sharding->sharding.type() !=
2435 xla::OpSharding::REPLICATED &&
2436 node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2437 return tensorflow::errors::InvalidArgument(
2438 "Unsupported argument sharding for retval ",
2439 retvals[i]->DebugString(), " edge=", edge->DebugString(), ": ",
2440 node_and_sharding->sharding.DebugString());
2441 }
2442 } else {
2443 if (use_spmd) {
2444 node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2445 xla::sharding_builder::Replicate());
2446 } else {
2447 if (distribute_vars_) {
2448 assigned_core = retvals_device_selector.RetrieveAssignment(i);
2449 } else {
2450 assigned_core = 0;
2451 }
2452 node_and_sharding = NodeAndSharding(
2453 /*node=*/nullptr,
2454 xla::sharding_builder::AssignDevice(*assigned_core));
2455 }
2456 *node_and_sharding->sharding.add_metadata() =
2457 CreateOpMetadataFromNode(*replicate_node);
2458 }
2459 if (assigned_core.has_value()) {
2460 retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2461 retvals_device_selector.ReportDeviceAssigned(*assigned_core, i);
2462 VLOG(3) << "Assigning return value " << i << " ("
2463 << retvals[i]->DebugString() << ") to core " << *assigned_core
2464 << FormatNodeAndShardingMsg(node_and_sharding);
2465 } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2466 for (int64_t core :
2467 node_and_sharding->sharding.tile_assignment_devices()) {
2468 TF_RET_CHECK(core >= 0 && core < num_cores_per_replica)
2469 << "core " << core << " should be between [0, "
2470 << num_cores_per_replica << "). sharding is "
2471 << node_and_sharding->sharding.DebugString();
2472 retvals_device_selector.ReportDeviceAssigned(core, i);
2473 }
2474 VLOG(3) << "Assigning return value " << i << " ("
2475 << retvals[i]->DebugString() << ") with tiled sharding to cores "
2476 << absl::StrJoin(
2477 node_and_sharding->sharding.tile_assignment_devices(), ",")
2478 << " " << FormatNodeAndShardingMsg(node_and_sharding);
2479 } else {
2480 DCHECK_EQ(node_and_sharding->sharding.type(),
2481 xla::OpSharding::REPLICATED);
2482 for (int64_t core = 0; core < num_cores_per_replica; ++core) {
2483 retvals_device_selector.ReportDeviceAssigned(core, i);
2484 }
2485 VLOG(3) << "Assigning return value " << i << " ("
2486 << retvals[i]->DebugString() << ") to all cores"
2487 << FormatNodeAndShardingMsg(node_and_sharding);
2488 }
2489 retvals[i]->AddAttr(kShardingAttribute,
2490 node_and_sharding->sharding.SerializeAsString());
2491 (*retval_sharding)[i] = node_and_sharding->sharding;
2492 }
2493 if (use_spmd &&
2494 (absl::c_any_of(*arg_sharding,
2495 [](const xla::OpSharding& s) {
2496 return s.type() == xla::OpSharding::MAXIMAL;
2497 }) ||
2498 absl::c_any_of(*retval_sharding, [](const xla::OpSharding& s) {
2499 return s.type() == xla::OpSharding::MAXIMAL;
2500 }))) {
2501 LOG(WARNING) << "XLA SPMD only supports cases where all inputs/outputs "
2502 "exist on every partition (sharded or replicated). Fall "
2503 "back to MPMD.";
2504 return AssignArgsAndRetvalsToCores(
2505 num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
2506 retval_shapes, graph, replicate_node, flr,
2507 /*allow_parameter_replication_for_spmd=*/false, arg_sharding,
2508 arg_fast_mem, retval_sharding, arg_names);
2509 }
2510 return Status::OK();
2511 }
2512
2513 // Builds Shape nodes that compute the shapes of arguments whose shapes are not
2514 // statically known.
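// For example (hypothetical shapes): a per-replica argument with inferred
// shape [?, 128] is not fully defined, so a Shape node is attached to the
// first replica's input; a resource-variable argument with an unknown handle
// shape instead gets a VariableShape node attached to its resource handle.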
2515 /* static */ Status DistributedTPURewritePass::BuildDynamicShapeNodes(
2516 const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
2517 const ParameterInfo& params_info, const std::vector<Node*>& variable_reads,
2518 Graph* graph, std::vector<Node*>* dynamic_shape_nodes) {
2519 dynamic_shape_nodes->clear();
2520
2521 std::vector<const Edge*> replicate_input_edges;
2522 TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
2523
2524 // The compiler determines the shape of each constant by inspecting the value
2525 // of its corresponding host-memory tensor; this happens when a step is run.
2526 // As a result, the shapes of constants are not needed at graph rewrite time.
2527 const int num_args = arg_shapes.size() - params_info.NumGuaranteedConstants();
2528 TF_RET_CHECK(num_args == params_info.NumPerReplicaArgs() +
2529 params_info.NumDistributedArgs() +
2530 params_info.NumBroadcastArgs() +
2531 params_info.NumVariables());
2532
2533 for (int i = 0; i < num_args; ++i) {
2534 const PartialTensorShape* shape = arg_shapes[i].handle_type == DT_INVALID
2535 ? &arg_shapes[i].shape
2536 : &arg_shapes[i].handle_shape;
2537 if (!shape->IsFullyDefined()) {
2538 NodeDef def;
2539 Node* src;
2540 int src_output;
2541 std::vector<Node*> control_inputs;
2542
2543 if (params_info.IsVariableArg(i)) {
2544 int64_t var_num = i - params_info.NumPerReplicaArgs() -
2545 params_info.NumDistributedArgs() -
2546 params_info.NumBroadcastArgs();
2547 TF_RET_CHECK(0 <= var_num && var_num < variable_reads.size());
2548 Node* read = variable_reads[var_num];
2549
2550 DCHECK_EQ(read->type_string(), "ReadVariableOp");
2551
2552 for (const Edge* edge : read->in_edges()) {
2553 if (edge->IsControlEdge()) {
2554 control_inputs.push_back(edge->src());
2555 }
2556 }
2557
2558 const Edge* variable_input = nullptr;
2559 TF_RETURN_IF_ERROR(read->input_edge(/*idx=*/0, &variable_input));
2560 src = variable_input->src();
2561 src_output = variable_input->src_output();
2562
2563 def.set_name(
2564 graph->NewName(strings::StrCat(src->name(), "/variable_shape")));
2565 def.set_op("VariableShape");
2566 } else {
2567 if (params_info.IsPerReplicaArg(i)) {
2568 TF_RET_CHECK(i < replicate_input_edges.size());
2569 // All replicas must have the same input shapes. Uses the shape of the
2570 // inputs from the first replica.
2571 src = replicate_input_edges[i]->src();
2572 src_output = replicate_input_edges[i]->src_output();
2573 } else {
2574 DCHECK(params_info.IsDistributedArg(i) ||
2575 params_info.IsBroadcastArg(i));
2576 int64_t input_num =
2577 params_info.NumPerReplicaArgs() * params_info.NumReplicas() + i -
2578 params_info.NumPerReplicaArgs();
2579 TF_RET_CHECK(0 <= input_num &&
2580 input_num < replicate_input_edges.size());
2581 src = replicate_input_edges[input_num]->src();
2582 src_output = replicate_input_edges[input_num]->src_output();
2583 }
2584
2585 def.set_name(graph->NewName(strings::StrCat(src->name(), "/shape")));
2586 def.set_op("Shape");
2587 AddNodeAttr("T", src->output_type(src_output), &def);
2588 }
2589
2590 def.set_device(src->assigned_device_name());
2591 AddNodeAttr("out_type", DT_INT64, &def);
2592 MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
2593
2594 Status status;
2595 Node* shape_node = graph->AddNode(def, &status);
2596 if (!status.ok()) return status;
2597 dynamic_shape_nodes->push_back(shape_node);
2598
2599 shape_node->set_assigned_device_name(src->assigned_device_name());
2600 graph->AddEdge(src, src_output, shape_node, 0);
2601 for (Node* control_input : control_inputs) {
2602 graph->AddControlEdge(control_input, shape_node);
2603 }
2604 }
2605 }
2606 return Status::OK();
2607 }
2608
2609 namespace {
2610
2611 bool XlaBroadcastTypeSupported(const DataType dtype) {
2612 return (dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
2613 dtype == DT_BOOL);
2614 }
2615
2616 bool XlaBroadcastKindSupported(
2617 const DistributedTPURewritePass::ParameterInfo& params_info,
2618 int param_num) {
2619 // NOTE: This is intended to cover non-sharded data parallel variables, for
2620 // training only. Is it correct to just check if the arg_type is
2621 // DT_RESOURCE?
2622 return params_info.IsVariableArg(param_num) &&
2623 !(params_info.IsPerReplicaArg(param_num) ||
2624 params_info.IsDistributedArg(param_num) ||
2625 params_info.IsBroadcastArg(param_num) ||
2626 params_info.IsConstantArg(param_num));
2627 }
2628
2629 bool EnableXlaParamBroadcast(
2630 bool enable_xla_param_broadcast,
2631 const DistributedTPURewritePass::ParameterInfo& params_info, int param_num,
2632 DataType dtype) {
2633 // Conditions necessary to use XLA collectives for arg broadcast:
2634 // 1. Globally enabled via enable_xla_param_broadcast.
2635 // 2. DataType must be supported.
2636 // 3. Parameter must be a variable, and not distributed or broadcasted.
2637 return enable_xla_param_broadcast && XlaBroadcastTypeSupported(dtype) &&
2638 XlaBroadcastKindSupported(params_info, param_num);
2639 }
2640
2641 } // namespace
2642
2643 // Builds a TPUCompile node that compiles the bodies of the function call
2644 // `nodes`.
2645 Status DistributedTPURewritePass::BuildCompileNode(
2646 const Node* replicate_node, const NameAttrList& function,
2647 uint64 library_fingerprint, const ParameterInfo& params_info,
2648 const std::vector<InferredShape>& arg_shapes,
2649 const DataTypeVector& arg_types,
2650 const std::vector<Node*>& guaranteed_constant_nodes,
2651 const string& session_handle,
2652 const std::vector<xla::OpSharding>& arg_sharding,
2653 const std::vector<bool>& arg_fast_mem,
2654 const std::vector<std::string>& arg_names,
2655 const std::vector<xla::OpSharding>& retval_sharding,
2656 int num_cores_per_replica, const string& compile_device,
2657 const xla::DeviceAssignment* xla_device_assignment,
2658 const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
2659 Node** compile_node, int64_t autotuner_thresh) {
2660 VLOG(1) << "BuildCompileNode";
2661
2662 tpu::TPUCompileMetadataProto proto;
2663 proto.set_num_replicas(params_info.NumReplicas());
2664 proto.set_num_cores_per_replica(num_cores_per_replica);
2665 proto.set_function_library_fingerprint(library_fingerprint);
2666 proto.set_enable_automatic_model_parallelism(
2667 enable_cross_replica_sharding_mirrored_variables_);
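// XLA SPMD partitioning requires every input and output to exist on all
// partitions (sharded or replicated), so use_spmd below is disabled whenever
// any argument or return value carries MAXIMAL sharding; in that case the
// computation falls back to MPMD (see the warning emitted in
// AssignArgsAndRetvalsToCores above).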
2668 const bool use_spmd =
2669 UseSpmdForXlaPartitioning(replicate_node) && allow_xla_spmd_partition_ &&
2670 !absl::c_any_of(arg_sharding,
2671 [](const xla::OpSharding& s) {
2672 return s.type() == xla::OpSharding::MAXIMAL;
2673 }) &&
2674 !absl::c_any_of(retval_sharding, [](const xla::OpSharding& s) {
2675 return s.type() == xla::OpSharding::MAXIMAL;
2676 });
2677 proto.set_use_spmd_for_xla_partitioning(use_spmd);
2678
2679 // Set the step marker location from the replicate node, if present.
2680 if (replicate_node != nullptr) {
2681 xla::DebugOptions::StepMarkerLocation location;
2682 TF_RETURN_IF_ERROR(GetStepMarkerLocation(*replicate_node, &location));
2683 proto.set_step_marker_location(location);
2684 }
2685
2686 if (xla_device_assignment != nullptr) {
2687 TF_RETURN_IF_ERROR(
2688 xla_device_assignment->Serialize(proto.mutable_device_assignment()));
2689 }
2690
2691 const int num_args = arg_types.size();
2692 const int num_guaranteed_constants = guaranteed_constant_nodes.size();
2693 const int guaranteed_const_start_index = num_args - num_guaranteed_constants;
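// Guaranteed constants, if any, occupy the tail of the argument list. For
// example (hypothetical counts), with num_args = 10 and 2 guaranteed
// constants, guaranteed_const_start_index is 8, so arg indices 8 and 9 are
// marked GUARANTEED_CONSTANT below.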
2694 TF_RET_CHECK(num_args == arg_shapes.size());
2695 TF_RET_CHECK(num_args == arg_sharding.size())
2696 << num_args << " != " << arg_sharding.size();
2697
2698 for (int i = 0; i < num_args; ++i) {
2699 tpu::TPUCompileMetadataProto::Arg* arg = proto.add_args();
2700 DataType type = arg_types[i];
2701 const InferredShape& arg_shape = arg_shapes[i];
2702 arg->set_name(arg_names[i]);
2703 if (type == DT_RESOURCE) {
2704 TF_RET_CHECK(arg_shape.handle_type != DT_INVALID) << i;
2705 arg->set_dtype(arg_shape.handle_type);
2706 arg_shape.handle_shape.AsProto(arg->mutable_shape());
2707 arg->set_kind(tpu::TPUCompileMetadataProto::Arg::VARIABLE);
2708 arg->set_fast_mem(arg_fast_mem[i]);
2709 } else {
2710 arg->set_dtype(type);
2711 arg_shape.shape.AsProto(arg->mutable_shape());
2712 if (i >= guaranteed_const_start_index) {
2713 const DataType edge_type =
2714 guaranteed_constant_nodes[i - guaranteed_const_start_index]
2715 ->output_type(0);
2716 TF_RET_CHECK(type == edge_type)
2717 << "Arg type: " << type << " but edge type: " << edge_type;
2718 arg->set_kind(tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT);
2719 } else {
2720 arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
2721 }
2722 }
2723
2724 // Use XLA collective primitives to distribute variables to all replicas.
2725 arg->set_requires_xla_broadcast(
2726 params_info.NumReplicas() > 1 &&
2727 EnableXlaParamBroadcast(enable_xla_param_broadcast_, params_info, i,
2728 arg_shape.handle_type /*arg.dtype?*/));
2729
2730 // As long as the argument is not a per-replica one, it should have the same
2731 // value for all replicas. For clarity, we keep the (redundant) checks for
2732 // variable, broadcast and constant types, to prevent bugs in case new types
2733 // with different semantics are introduced in the future.
2734 arg->set_is_same_data_across_replicas(
2735 !params_info.IsPerReplicaArg(i) && !params_info.IsDistributedArg(i) &&
2736 (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2737 params_info.IsConstantArg(i)));
2738 if (params_info.mirrored_variable_indices().count(i) > 0) {
2739 CHECK_EQ(type, DT_RESOURCE);
2740 arg->set_is_same_data_across_replicas(true);
2741 // 64-bit types are not shardable by XLA:TPU yet.
2742 bool sharding_enabled = (arg_shape.handle_type != DT_COMPLEX64 &&
2743 arg_shape.handle_type != DT_INT64 &&
2744 arg_shape.handle_type != DT_UINT64 &&
2745 arg_shape.handle_type != DT_DOUBLE);
2746 arg->set_enable_xla_sharding(
2747 sharding_enabled ? tpu::TPUCompileMetadataProto::Arg::TENTATIVE
2748 : tpu::TPUCompileMetadataProto::Arg::DISALLOWED);
2749 }
2750 *arg->mutable_sharding() = arg_sharding[i];
2751 }
2752
2753 const int num_retvals = retval_sharding.size();
2754 for (int i = 0; i < num_retvals; ++i) {
2755 *proto.add_retvals()->mutable_sharding() = retval_sharding[i];
2756 }
2757 proto.set_session_handle(session_handle);
2758
2759 DataTypeVector constant_arg_types;
2760 constant_arg_types.reserve(num_guaranteed_constants);
2761 for (int i = 0; i < num_guaranteed_constants; ++i) {
2762 constant_arg_types.push_back(arg_types[guaranteed_const_start_index + i]);
2763 }
2764 proto.set_xla_fusion_autotuner_thresh(autotuner_thresh);
2765
2766 string metadata;
2767 proto.SerializeToString(&metadata);
2768
2769 NodeDef def;
2770 def.set_name(UniqueNodeName("TPUReplicate/_compile", graph));
2771 def.set_op("TPUCompile");
2772 def.set_device(compile_device);
2773 if (replicate_node) {
2774 MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
2775 }
2776
2777 AddNodeAttr("function", function, &def);
2778 AddNodeAttr("num_computations", num_cores_per_replica, &def);
2779 AddNodeAttr("NumDynamicShapes", static_cast<int>(dynamic_shape_nodes.size()),
2780 &def);
2781 AddNodeAttr("metadata", metadata, &def);
2782 AddNodeAttr("Tguaranteed_constants", constant_arg_types, &def);
2783
2784 Status status;
2785 *compile_node = graph->AddNode(def, &status);
2786 TF_RETURN_IF_ERROR(status);
2787
2788 (*compile_node)->set_assigned_device_name(compile_device);
2789
2790 for (int i = 0; i < dynamic_shape_nodes.size(); ++i) {
2791 graph->AddEdge(dynamic_shape_nodes[i], 0, *compile_node, i);
2792 }
2793
2794 for (int i = 0; i < num_guaranteed_constants; ++i) {
2795 graph->AddEdge(guaranteed_constant_nodes[i], 0, *compile_node,
2796 dynamic_shape_nodes.size() + i);
2797 }
2798 VLOG(1) << "BuildCompileNode(): " << status;
2799 return status;
2800 }
2801
2802 Status DistributedTPURewritePass::FindGuaranteedConstantInputs(
2803 const Node& node, const NameRangeMap& input_range_map,
2804 std::vector<Node*>* guaranteed_constants) {
2805 std::vector<const Edge*> input_edges;
2806 TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2807 std::pair<int, int> variables_limits =
2808 input_range_map.at("guaranteed_constants");
2809 for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2810 guaranteed_constants->push_back(input_edges[i]->src());
2811 }
2812 return Status::OK();
2813 }
2814
2815 Status DistributedTPURewritePass::FindVariableInputs(
2816 const Node& node, const NameRangeMap& input_range_map,
2817 std::vector<VariableInput>* variables) {
2818 std::vector<const Edge*> input_edges;
2819 TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2820 std::pair<int, int> variables_limits = input_range_map.at("variables");
2821 for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2822 Node* node = input_edges[i]->src();
2823
2824 // Find the type of the VarHandleOp that feeds this node, looking through
2825 // any wrapping Enter or Switch nodes.
2826 while (node->IsEnter() || node->IsSwitch()) {
2827 TF_RETURN_IF_ERROR(node->input_node(0, &node));
2828 }
2829 // Fix the variable device assignment if it is requested with a full name.
2830 if (!node->has_assigned_device_name() &&
2831 !node->requested_device().empty()) {
2832 DeviceNameUtils::ParsedName var_device;
2833 TF_RET_CHECK(DeviceNameUtils::ParseFullName(node->requested_device(),
2834 &var_device));
2835 if (var_device.has_job && var_device.has_replica && var_device.has_task &&
2836 var_device.has_type && var_device.has_id) {
2837 node->set_assigned_device_name(node->requested_device());
2838 if (node != input_edges[i]->src() &&
2839 !input_edges[i]->src()->has_assigned_device_name()) {
2840 input_edges[i]->src()->set_assigned_device_name(
2841 node->requested_device());
2842 }
2843 }
2844 }
2845 if (node->type_string() == kVarHandleOp) {
2846 DataType dtype;
2847 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "dtype", &dtype));
2848 variables->push_back(VariableInput{input_edges[i]->src(),
2849 input_edges[i]->src_output(), dtype});
2850 } else if (node->type_string() == "_Arg") {
2851 std::vector<DataType> dtypes;
2852 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "_handle_dtypes", &dtypes));
2853 if (dtypes.empty()) {
2854 return errors::Internal(
2855 "_Arg node with resource output must have non-empty _handle_dtypes "
2856 "attribute: ",
2857 node->DebugString());
2858 }
2859 variables->push_back(VariableInput{
2860 input_edges[i]->src(), input_edges[i]->src_output(), dtypes[0]});
2861 } else {
2862 return errors::Internal(
2863 "Cannot handle variable input with node type other than VarHandleOp "
2864 "and _Arg: ",
2865 node->DebugString());
2866 }
2867 }
2868 return Status::OK();
2869 }
2870
2871 // Builds a NoOp node, used for building control dependencies.
2872 static Status BuildNoopNode(const Node& source, StringPiece name,
2873 const string& device, Graph* graph, Node** node) {
2874 NodeDefBuilder builder(name, "NoOp", NodeDebugInfo(source));
2875 if (!device.empty()) {
2876 builder.Device(device);
2877 }
2878 NodeDef def;
2879 TF_RETURN_IF_ERROR(builder.Finalize(&def));
2880
2881 Status status;
2882 *node = graph->AddNode(def, &status);
2883 if (!device.empty()) {
2884 (*node)->set_assigned_device_name(device);
2885 }
2886 return status;
2887 }
2888
2889 Status DistributedTPURewritePass::ConnectHostComputeNodes(
2890 Node* compile_node, Node* key_placeholder_node, Graph* graph) {
2891 // First find all the downstream nodes of the key placeholder node, since we
2892 // want to delete the connecting edges from key_placeholder_node which would
2893 // invalidate the out_nodes iterator.
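// Each of these host transfer nodes currently consumes the key placeholder as
// one of its inputs; below, that edge is replaced with an edge from output 1
// of the TPUCompile node (compile_node), and the now-unused placeholder is
// removed from the graph.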
2894 std::vector<Node*> host_transfer_nodes;
2895 for (Node* node : key_placeholder_node->out_nodes()) {
2896 host_transfer_nodes.push_back(node);
2897 }
2898 for (Node* node : host_transfer_nodes) {
2899 int input_index = -1;
2900 for (int i = 0; i < node->num_inputs(); i++) {
2901 const Edge* e;
2902 TF_RETURN_IF_ERROR(node->input_edge(i, &e));
2903 if (e->src() == key_placeholder_node) {
2904 if (input_index != -1) {
2905 return errors::Internal(
2906 "Node ", node->name(),
2907 " has multiple input edges from key placeholder node");
2908 }
2909 input_index = e->dst_input();
2910 }
2911 }
2912 if (input_index == -1) {
2913 return errors::Internal("Node ", node->name(),
2914 " has no input edge from key placeholder node");
2915 }
2916 const Edge* key_edge;
2917 TF_RETURN_IF_ERROR(node->input_edge(input_index, &key_edge));
2918 graph->RemoveEdge(key_edge);
2919 graph->AddEdge(compile_node, 1, node, input_index);
2920 }
2921 graph->RemoveNode(key_placeholder_node);
2922 return Status::OK();
2923 }
2924
2925 Status DistributedTPURewritePass::BuildVariableReads(
2926 absl::Span<const VariableInput> variables, Node* control_predecessor,
2927 Graph* graph, std::vector<Node*>* variable_reads) {
2928 variable_reads->resize(variables.size());
2929 for (int i = 0; i < variables.size(); ++i) {
2930 string name =
2931 graph->NewName(strings::StrCat(variables[i].node->name(), "/read"));
2932 NodeDefBuilder builder(name, "ReadVariableOp",
2933 NodeDebugInfo(*variables[i].node));
2934
2935 builder.Attr("dtype", variables[i].dtype);
2936 builder.Device(variables[i].node->assigned_device_name());
2937 builder.Input(variables[i].node->name(), 0, DT_RESOURCE);
2938 NodeDef def;
2939 TF_RETURN_IF_ERROR(builder.Finalize(&def));
2940
2941 Status status;
2942 Node* read_node;
2943 (*variable_reads)[i] = read_node = graph->AddNode(def, &status);
2944 if (!status.ok()) return status;
2945
2946 read_node->set_requested_device(variables[i].node->requested_device());
2947 read_node->set_assigned_device_name(
2948 variables[i].node->assigned_device_name());
2949 graph->AddEdge(variables[i].node, variables[i].index, read_node, 0);
2950
2951 graph->AddControlEdge(control_predecessor, read_node);
2952 }
2953 return Status::OK();
2954 }
2955
2956 bool DistributedTPURewritePass::ContainsResourceWriteOp(
2957 const Graph& graph, const FunctionLibraryDefinition& fld) {
2958 for (const Node* n : graph.nodes()) {
2959 const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string());
2960 if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2961 VLOG(2) << "Found write resource op inside computation";
2962 return true;
2963 }
2964 }
2965 for (const string& func_name : fld.ListFunctionNames()) {
2966 const FunctionDef* func_def = fld.Find(func_name);
2967 for (const NodeDef& n : func_def->node_def()) {
2968 const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.op());
2969 if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2970 VLOG(2) << "Found write resource op inside " << func_name;
2971 return true;
2972 }
2973 }
2974 }
2975 return false;
2976 }
2977
2978 Status DistributedTPURewritePass::BuildVariableWrites(
2979 absl::Span<const VariableInput> variables, Node* control_successor,
2980 absl::Span<const VariableWrite> variable_writes, Graph* graph) {
2981 CHECK_EQ(variables.size(), variable_writes.size());
2982 for (int i = 0; i < variables.size(); ++i) {
2983 const VariableWrite& write = variable_writes[i];
2984 NodeDebugInfo debug_info(*variables[i].node);
2985
2986 auto name = [&](string suffix) {
2987 return graph->NewName(
2988 strings::StrCat(variables[i].node->name(), "/", suffix));
2989 };
2990
2991 Node* write_node;
2992 TF_RETURN_IF_ERROR(
2993 IncompleteNodeDefBuilder(name("assign"), "AssignVariableOp", debug_info)
2994 .AddAttr("dtype", variables[i].dtype)
2995 .Device(variables[i].node->assigned_device_name())
2996 .Build(graph, &write_node));
2997
2998 // Colocate the control flow with the variable.
2999 CondBuilder cb(variables[i].node->name(),
3000 variables[i].node->assigned_device_name(), debug_info,
3001 graph);
3002
3003 // Inputs to conditional.
3004 Node* switch_val;
3005 TF_RETURN_IF_ERROR(
3006 cb.AddInput("switch_val", variables[i].dtype,
3007 /*device=*/write.value->assigned_device_name(), debug_info,
3008 &switch_val));
3009 Node* switch_var;
3010 TF_RETURN_IF_ERROR(
3011 cb.AddInput("switch_var", DT_RESOURCE,
3012 /*device=*/variables[i].node->assigned_device_name(),
3013 debug_info, &switch_var));
3014 // Conditionally write the value back.
3015 graph->AddEdge(variables[i].node, variables[i].index, switch_var, 0);
3016 graph->AddEdge(switch_var, CondBuilder::kThenBranch, write_node, 0);
3017 graph->AddEdge(switch_val, CondBuilder::kThenBranch, write_node, 1);
3018 // Add a control edge from the write to the value that will be merged. There
3019 // is no output from the write, so this control edge ensures the write completes.
3020 graph->AddControlEdge(write_node, cb.switch_t());
3021
3022 graph->AddControlEdge(cb.control_successor(), control_successor);
3023
3024 graph->AddEdge(write.predicate, write.predicate_output, cb.pred(), 0);
3025 graph->AddEdge(write.value, write.value_output, switch_val, 0);
3026 }
3027 return Status::OK();
3028 }
3029
3030 namespace {
3031
3032 // Computes the shape of the sharded tensor and modifies 'shape' in place.
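// Illustrative example (hypothetical values): for a fully defined shape
// [9, 8] and a sharding with tile_assignment_dimensions = [2, 2], the shape
// becomes [ceil(9/2), ceil(8/2)] = [5, 4]. If replicate_on_last_tile_dim is
// set with tile_assignment_dimensions = [2, 2, 2], the trailing dimension
// denotes replication rather than splitting, so only the first two dimensions
// are divided and the result is still [5, 4].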
3033 Status ComputeShardedArgShapes(TensorShape* shape,
3034 const xla::OpSharding& sharding) {
3035 if (sharding.type() != xla::OpSharding::OTHER) {
3036 return Status::OK();
3037 }
3038 if (!shape->IsFullyDefined()) {
3039 return errors::Internal(
3040 "Arg shape must be fully defined before sharded shape inference.");
3041 }
3042 int sharded_rank = sharding.tile_assignment_dimensions_size();
3043 if (sharding.replicate_on_last_tile_dim()) {
3044 sharded_rank--;
3045 }
3046 for (int dim_idx = 0; dim_idx < sharded_rank; ++dim_idx) {
3047 auto sharded_dim = tensorflow::MathUtil::CeilOfRatio<int64>(
3048 shape->dim_size(dim_idx), sharding.tile_assignment_dimensions(dim_idx));
3049 shape->set_dim(dim_idx, sharded_dim);
3050 }
3051 if (sharded_rank != shape->dims()) {
3052 LOG(WARNING) << "Rank of sharded arg should match sharding spec. Rank: "
3053 << sharded_rank << ", tiled shape: " << shape->DebugString()
3054 << ", sharding: " << sharding.DebugString();
3055 }
3056
3057 return Status::OK();
3058 }
3059
3060 // Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
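// The resulting subgraph, placed on `host_cpu_device`, is roughly (sketch,
// hypothetical shape [128, 64]):
//   Const([128, 64], DT_INT32) --+
//                                +--> Fill --> zero tensor of shape [128, 64]
//   Const(0 scalar of `dtype`) --+
// i.e. a shape-as-tensor constant and a scalar zero constant feeding a Fill
// node that materializes the zero-initialized dummy argument.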
3061 xla::StatusOr<Node*> CreateTpuExecuteDummyArg(const TensorShape& var_shape,
3062 const DataType& dtype,
3063 const string& host_cpu_device,
3064 Node* var_read, int replica_id,
3065 Graph* graph) {
3066 Status status;
3067
3068 // Const - shape_as_tensor
3069 const std::string name_prefix = strings::StrCat(
3070 var_read->name(), absl::StrFormat("/dummy_%d", replica_id));
3071 NodeDef shape_tensor_def;
3072 shape_tensor_def.set_op("Const");
3073 shape_tensor_def.set_name(graph->NewName(
3074 strings::StrCat(name_prefix, "/Initializer/zeros/shape_as_tensor")));
3075 shape_tensor_def.set_device(host_cpu_device);
3076 AddNodeAttr("dtype", DT_INT32, &shape_tensor_def);
3077 TensorProto tensorshape_proto;
3078 tensorshape_proto.set_dtype(DT_INT32);
3079 for (int i = 0; i < var_shape.dims(); ++i) {
3080 tensorshape_proto.add_int_val(var_shape.dim_size(i));
3081 }
3082 TensorShape shape_shape({var_shape.dims()});
3083 shape_shape.AsProto(tensorshape_proto.mutable_tensor_shape());
3084 AddNodeAttr("value", tensorshape_proto, &shape_tensor_def);
3085 Node* shape_as_tensor_node = graph->AddNode(shape_tensor_def, &status);
3086 TF_RETURN_IF_ERROR(status);
3087
3088 // Const - initializer value
3089 NodeDef init_val_def;
3090 init_val_def.set_op("Const");
3091 init_val_def.set_name(graph->NewName(
3092 strings::StrCat(name_prefix, "/Initializer/zeros/const_val")));
3093 init_val_def.set_device(host_cpu_device);
3094 TensorProto tensor_proto;
3095 tensor_proto.set_dtype(dtype);
3096 if (dtype == DT_FLOAT) {
3097 tensor_proto.add_float_val(0.0f);
3098 } else if (dtype == DT_BFLOAT16) {
3099 tensor_proto.add_half_val(0);
3100 } else if (dtype == DT_INT32) {
3101 tensor_proto.add_int_val(0);
3102 } else if (dtype == DT_BOOL) {
3103 tensor_proto.add_bool_val(false);
3104 } else {
3105 return errors::Internal(
3106 "Unable to create zero-init dummy arg tensor for type ", dtype);
3107 }
3108 TensorShape scalar_shape({});
3109 scalar_shape.AsProto(tensor_proto.mutable_tensor_shape());
3110 AddNodeAttr("value", tensor_proto, &init_val_def);
3111 AddNodeAttr("dtype", dtype, &init_val_def);
3112 Node* init_val_node = graph->AddNode(init_val_def, &status);
3113 TF_RETURN_IF_ERROR(status);
3114
3115 // Fill node
3116 NodeDef fill_def;
3117 fill_def.set_op("Fill");
3118 fill_def.set_device(host_cpu_device);
3119 fill_def.set_name(
3120 graph->NewName(strings::StrCat(name_prefix, "/Initializer/zeros")));
3121 AddNodeAttr("T", dtype, &fill_def);
3122 AddNodeAttr("index_type", DT_INT32, &fill_def);
3123 Node* fill_node = graph->AddNode(fill_def, &status);
3124 TF_RETURN_IF_ERROR(status);
3125 graph->AddEdge(shape_as_tensor_node, 0, fill_node, 0);
3126 graph->AddEdge(init_val_node, 0, fill_node, 1);
3127
3128 return fill_node;
3129 }
3130
3131 // Creates dummy inputs for partitioned variables that use XLA broadcast for
3132 // inputs.
3133 Status CreatePartitionedDummyVarArgs(
3134 const xla::OpSharding& sharding, const int num_replicas,
3135 const int replica_id, const InferredShape& raw_shape, Node* orig_var_read,
3136 const int orig_arg_num, DataType dtype, const string& device, Graph* graph,
3137 const std::vector<std::vector<string>>& tpu_device_names,
3138 absl::btree_map<ShardedPerHostInputIndex, Node*>* per_host_index,
3139 std::map<ShardedInputIndex, ShardedInputInfo>*
3140 arg_index_to_sharded_input_map) {
3141 ShardedInputIndex input_index{replica_id, orig_arg_num};
3142 auto iter = arg_index_to_sharded_input_map->find(input_index);
3143 if (iter != arg_index_to_sharded_input_map->end()) {
3144 return Status::OK();
3145 }
3146 const int repeat = sharding.replicate_on_last_tile_dim()
3147 ? *sharding.tile_assignment_dimensions().rbegin()
3148 : 1;
3149 const int num_shards = sharding.tile_assignment_devices_size() / repeat;
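// Illustrative example (hypothetical sharding): with tile_assignment_devices
// = [0, 1, 2, 3], tile_assignment_dimensions = [2, 2] and
// replicate_on_last_tile_dim = true, repeat is 2 and num_shards is 2; the
// loop below treats devices 0 and 1 as replicated copies of the first shard
// and devices 2 and 3 as copies of the second (index = i * repeat + j).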
3150
3151 TensorShape var_shape;
3152 if (!raw_shape.handle_shape.AsTensorShape(&var_shape) &&
3153 !raw_shape.shape.AsTensorShape(&var_shape)) {
3154 return errors::FailedPrecondition("Failed to read arg shape.");
3155 }
3156 TF_RETURN_IF_ERROR(ComputeShardedArgShapes(&var_shape, sharding));
3157
3158 for (int replica = 1; replica < num_replicas; ++replica) {
3159 std::vector<NodeOut> sharded_inputs_list(
3160 sharding.tile_assignment_devices_size());
3161 for (int i = 0; i < num_shards; ++i) {
3162 for (int j = 0; j < repeat; ++j) {
3163 const int index = i * repeat + j;
3164 const int core = sharding.tile_assignment_devices(index);
3165 string host_device;
3166 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
3167 tpu_device_names[replica][core], &host_device));
3168 ShardedPerHostInputIndex idx{host_device, orig_arg_num};
3169 if (!per_host_index->contains(idx)) {
3170 TF_ASSIGN_OR_RETURN(
3171 auto dummy_node,
3172 CreateTpuExecuteDummyArg(var_shape, dtype, host_device,
3173 orig_var_read, replica, graph));
3174 (*per_host_index)[idx] = dummy_node;
3175 }
3176 sharded_inputs_list[core] = {(*per_host_index)[idx], /*index=*/0};
3177 }
3178 }
3179 ShardedInputInfo sharded_input_info{nullptr,
3180 std::move(sharded_inputs_list)};
3181 (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
3182 sharded_input_info;
3183 }
3184
3185 return Status::OK();
3186 }
3187
3188 // Helper that creates an IdentityN node containing all of the variable
3189 // values on CPU device 'device', except for those that will be split across
3190 // cores. (For split variables, this may cause additional cross-host data
3191 // transfers if more than one device shares the same variable partition on a
3192 // remote host.)
3193 //
3194 // A previous iteration of this code built one Identity node per TPU core per
3195 // variable, but this can rapidly become hundreds of thousands of nodes. This
3196 // formulation creates a single IdentityN node containing all of the variables
3197 // on each host. This may cause some unnecessary variable copies if only a
3198 // subset of hosts consume a given variable, but has the virtue of being
3199 // simple, and most models use pure replication where all cores want all the
3200 // variables.
3201 //
3202 // If enable_xla_param_broadcast is set to true, then per-host dummy
3203 // tensor args are created on all hosts except for the primary host. In this
3204 // scheme, the dummy args feed the IdentityN node on their local host. All
3205 // are zero-initialized.
3206 //
3207 // Returns the node and its output index to be consumed by TPUExecute for the
3208 // requested variable index.
3209 xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
3210 const string& host_cpu_device, int64_t var_index,
3211 const std::vector<Node*>& variable_reads,
3212 const DistributedTPURewritePass::ParameterInfo& params_info,
3213 const std::vector<xla::OpSharding>& arg_shardings,
3214 const Node& replicate_node, const bool enable_xla_param_broadcast,
3215 const int num_cores_per_replica, int replica_id,
3216 const std::vector<InferredShape>& arg_shapes,
3217 absl::flat_hash_map<string, std::vector<NodeOut>>* per_host_var_copies,
3218 Graph* graph) {
3219 auto it = per_host_var_copies->find(host_cpu_device);
3220 if (it != per_host_var_copies->end()) {
3221 return it->second[var_index];
3222 }
3223
3224 DataTypeVector dtypes;
3225 // Per-variable data source for TPUExecute.
3226 std::vector<NodeOut> index_mapping;
3227 index_mapping.reserve(variable_reads.size());
3228 dtypes.reserve(variable_reads.size());
3229 for (int64_t i = 0; i < variable_reads.size(); ++i) {
3230 Node* read = variable_reads[i];
3231 int64_t orig_arg_num = i + params_info.NumPerReplicaArgs() +
3232 params_info.NumDistributedArgs() +
3233 params_info.NumBroadcastArgs();
3234 if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) {
3235 // We haven't built the IdentityN node yet, so temporarily use nullptr.
3236 index_mapping.push_back(
3237 NodeOut{nullptr, static_cast<int>(dtypes.size())});
3238 dtypes.push_back(read->output_type(0));
3239 } else {
3240 // Do not copy the full tensor of partitioned variables.
3241 index_mapping.push_back(NodeOut{read, 0});
3242 }
3243 }
3244 NodeDef ndef;
3245 ndef.set_name(graph->NewName(
3246 absl::StrCat(replicate_node.name(), "/", kTpuExecuteStagingNodeName)));
3247 ndef.set_op(kTpuExecuteStagingOp);
3248 ndef.set_device(host_cpu_device);
3249 AddNodeAttr("T", dtypes, &ndef);
3250 // TF meta-optimizer should skip this node for constant folding.
3251 AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &ndef);
3252 Status s;
3253 Node* id_node = graph->AddNode(ndef, &s);
3254 TF_RETURN_IF_ERROR(s);
3255 id_node->set_assigned_device_name(host_cpu_device);
3256
3257 for (int64_t i = 0; i < variable_reads.size(); ++i) {
3258 Node* read = variable_reads[i];
3259 int64_t orig_arg_num = i + params_info.NumPerReplicaArgs() +
3260 params_info.NumDistributedArgs() +
3261 params_info.NumBroadcastArgs();
3262 DataType dtype = read->output_type(0);
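// Only non-zero replicas are fed zero-valued dummy tensors; replica 0 keeps
// the real variable read, whose value is distributed to the other replicas
// via XLA collectives (see the function comment above).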
3263 bool use_xla_broadcast =
3264 EnableXlaParamBroadcast(enable_xla_param_broadcast, params_info,
3265 orig_arg_num, dtype) &&
3266 replica_id != 0;
3267 if (index_mapping[i].node == nullptr) {
3268 // Fill index_mapping with the actual IdentityN node.
3269 index_mapping[i].node = id_node;
3270 if (!use_xla_broadcast) {
3271 // Add the variable read edge to id_node.
3272 graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
3273 } else {
3274 // XLA param broadcast mode is enabled. Create zero-valued dummy
3275 // tensors to use as variable args in the TPUExecuteOp, instead of
3276 // original variable reads.
3277 TensorShape var_shape;
3278 auto inferred_shape = arg_shapes[orig_arg_num];
3279 if (!inferred_shape.handle_shape.AsTensorShape(&var_shape) &&
3280 !inferred_shape.shape.AsTensorShape(&var_shape)) {
3281 return errors::FailedPrecondition("Failed to read arg shape.");
3282 }
3283 TF_ASSIGN_OR_RETURN(
3284 Node * dummy_read,
3285 CreateTpuExecuteDummyArg(var_shape, dtype, host_cpu_device,
3286 variable_reads[i], replica_id, graph));
3287 graph->AddEdge(dummy_read, 0, id_node, index_mapping[i].index);
3288 }
3289 }
3290 }
3291
3292 auto result = index_mapping[var_index];
3293 (*per_host_var_copies)[host_cpu_device] = std::move(index_mapping);
3294 return result;
3295 }
3296
3297 } // namespace
3298
3299 Status DistributedTPURewritePass::BuildExecuteNodes(
3300 const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
3301 const Node& replicate_node, const std::vector<std::string>& arg_names,
3302 const DataTypeVector& arg_types,
3303 const std::vector<InferredShape>& arg_shapes,
3304 const DataTypeVector& retval_types,
3305 const std::vector<xla::OpSharding>& arg_shardings,
3306 const std::vector<xla::OpSharding>& retval_shardings,
3307 const std::vector<std::vector<string>>& tpu_device_names,
3308 Node* compile_node, const std::vector<Node*>& variable_reads,
3309 Node* control_predecessor, Node* control_successor, Node* multilock_acquire,
3310 std::vector<VariableWrite>* variable_writes, Graph* graph) {
3311 VLOG(1) << "BuildExecuteNodes " << replicate_node.DebugString();
3312 TF_RET_CHECK(params_info.NumReplicas() == tpu_device_names.size());
3313
3314 const int num_variables = variable_reads.size();
3315 const int num_retvals_per_replica = retval_types.size();
3316
3317 variable_writes->resize(num_variables);
3318
3319 std::vector<const Edge*> replicate_input_edges;
3320 TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
3321
3322 // Map from replicate input index to the fan-in nodes.
3323 absl::flat_hash_map<int, std::vector<NodeAndPort>>
3324 replicate_input_fan_in_nodes;
3325 absl::flat_hash_map<int, std::vector<Node*>> replicate_output_fan_out_nodes;
3326 absl::flat_hash_map<int, std::vector<int>>
3327 replicate_output_fan_out_dst_inputs;
3328 std::vector<Node*> to_be_removed_nodes;
3329
3330 for (const Edge* e : replicate_input_edges) {
3331 if (e->src()->type_string() == kTPUPartitionedInput) {
3332 int num_users = 0;
3333 for (const auto& ue : e->src()->out_edges()) {
3334 if (!ue->IsControlEdge()) ++num_users;
3335 }
3336 if (num_users != 1) {
3337 return tensorflow::errors::InvalidArgument(
3338 e->src()->name(), " must only have one user. Found ", num_users);
3339 }
3340 to_be_removed_nodes.push_back(e->src());
3341 std::vector<NodeAndPort>& nodes =
3342 replicate_input_fan_in_nodes[e->dst_input()];
3343 nodes.resize(num_cores_per_replica, NodeAndPort(nullptr, 0));
3344 VLOG(2) << "allocate " << num_cores_per_replica
3345 << " for replicate_input_fan_in_nodes[" << e->dst_input() << "]";
3346 std::vector<const Edge*> fan_in_edges;
3347 TF_RETURN_IF_ERROR(e->src()->input_edges(&fan_in_edges));
3348 TF_RET_CHECK(fan_in_edges.size() == num_cores_per_replica);
3349
3350 for (const Edge* fe : fan_in_edges) {
3351 nodes[fe->dst_input()].node = fe->src();
3352 nodes[fe->dst_input()].port = fe->src_output();
3353 VLOG(2) << "replicate_input_fan_in_nodes[" << e->dst_input() << "]["
3354 << fe->dst_input() << "] = " << fe->src()->name();
3355 }
3356 }
3357 }
3358
3359 // Replicate output edges are sorted by replica id and then by outputs for
3360 // each replica. For example, if TPU Computation has outputs (output_1,
3361 // output_2, and output_3) and number of replicas is 2, then
3362 // replicate_output_edges order would be:
3363 // output_1_replica_1, output_2_replica_1, output_3_replica_1,
3364 // output_1_replica_2, output_2_replica_2, output_3_replica_2.
3365 std::vector<const Edge*> replicate_output_edges(replicate_node.num_outputs(),
3366 nullptr);
3367 for (const Edge* edge : replicate_node.out_edges()) {
3368 if (edge->IsControlEdge()) continue;
3369
3370 int num_partitioned_outputs = 0;
3371
3372 for (const Edge* out_edge : edge->dst()->out_edges()) {
3373 if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
3374 num_partitioned_outputs++;
3375 // Paths between replicate_node and replicate_output_fan_out_nodes:
3376 // ReplicateNode->TpuOutIdentity->kTPUPartitionedOutput->fan-out-nodes
3377 TF_RET_CHECK(edge->dst()->out_edges().size() == 1);
3378 to_be_removed_nodes.push_back(edge->dst());
3379 to_be_removed_nodes.push_back(out_edge->dst());
3380 // Get the right replicated id from the replicate_output_edge.
3381 std::vector<Node*>& nodes =
3382 replicate_output_fan_out_nodes[edge->src_output()];
3383 std::vector<int>& dst_inputs =
3384 replicate_output_fan_out_dst_inputs[edge->src_output()];
3385 nodes.resize(num_cores_per_replica, nullptr);
3386 dst_inputs.resize(num_cores_per_replica, 0);
3387 TF_RET_CHECK(out_edge->dst()->out_edges().size() ==
3388 num_cores_per_replica);
3389
3390 for (const Edge* fe : out_edge->dst()->out_edges()) {
3391 nodes[fe->src_output()] = fe->dst();
3392 dst_inputs[fe->src_output()] = fe->dst_input();
3393 VLOG(2) << "replicate_output_fan_out_nodes[" << out_edge->src_output()
3394 << "][" << fe->src_output()
3395 << "] = " << fe->dst()->DebugString() << " with dst_input "
3396 << fe->dst_input();
3397 }
3398 }
3399 }
3400 replicate_output_edges[edge->src_output()] = edge;
3401 if (num_partitioned_outputs > 1) {
3402 return errors::InvalidArgument(
3403 "More than one TPUPartitionedOutput per replicated output.");
3404 }
3405 }
3406
3407 const int num_execute_args =
3408 arg_shardings.size() - params_info.NumGuaranteedConstants();
3409 // Inverts the arg_shardings and retval_shardings mappings to
3410 // form core -> {argument number} maps.
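// Illustrative example (hypothetical shardings, 2 cores per replica): if
// arg 0 is MAXIMAL on core 1, arg 1 is REPLICATED, and arg 2 is OTHER with
// tile_assignment_devices = [0, 1], then core_arg_nums becomes
// {core 0: [1, 2], core 1: [0, 1, 2]}; core_retval_nums is built the same
// way from retval_shardings.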
3411 std::vector<std::vector<int>> core_arg_nums(num_cores_per_replica);
3412 for (int i = 0; i < num_execute_args; ++i) {
3413 const auto& sharding = arg_shardings[i];
3414 if (sharding.type() == xla::OpSharding::MAXIMAL) {
3415 int core = sharding.tile_assignment_devices(0);
3416 TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3417 core_arg_nums[core].push_back(i);
3418 } else if (sharding.type() == xla::OpSharding::OTHER) {
3419 for (int64_t core : sharding.tile_assignment_devices()) {
3420 core_arg_nums[core].push_back(i);
3421 }
3422 } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3423 for (int core = 0; core < num_cores_per_replica; ++core) {
3424 core_arg_nums[core].push_back(i);
3425 }
3426 } else {
3427 return tensorflow::errors::InvalidArgument(
3428 "Unsupported argument sharding for arg=", arg_names[i],
3429 " shape=", arg_shapes[i].shape.DebugString(), ": ",
3430 sharding.DebugString());
3431 }
3432 }
3433 std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
3434 for (int i = 0; i < retval_shardings.size(); ++i) {
3435 const auto& sharding = retval_shardings[i];
3436 if (sharding.type() == xla::OpSharding::MAXIMAL) {
3437 int core = sharding.tile_assignment_devices(0);
3438 TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3439 core_retval_nums[core].push_back(i);
3440 } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3441 for (int core = 0; core < num_cores_per_replica; ++core) {
3442 core_retval_nums[core].push_back(i);
3443 }
3444 } else if (sharding.type() == xla::OpSharding::OTHER) {
3445 for (int64_t core : sharding.tile_assignment_devices()) {
3446 core_retval_nums[core].push_back(i);
3447 }
3448 } else {
3449 return tensorflow::errors::InvalidArgument(
3450 "Unsupported argument sharding: ", sharding.DebugString());
3451 }
3452 }
3453
3454 // Maps host device name to a list of per-variable pairs (variable_copy_node,
3455 // output_index_of_copy_node).
3456 absl::flat_hash_map<string, std::vector<NodeOut>> per_host_var_copies;
3457
3458 Node* execute_successor = control_successor;
3459
3460 int num_total_cores = params_info.NumReplicas() * num_cores_per_replica;
3461 if (enable_multicore_locking_ && num_total_cores > 1) {
3462 // Add a node to release exclusive access once all the cores have finished
3463 // execution.
3464 NodeDef lock_def;
3465 lock_def.set_name(graph->NewName(
3466 strings::StrCat(compile_node->name(), "/", "tpu_release_multilock")));
3467 lock_def.set_op("ConsumeTpuMultilock");
3468 MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &lock_def);
3469 Status status;
3470 Node* multilock_release = graph->AddNode(lock_def, &status);
3471 TF_RETURN_IF_ERROR(status);
3472 multilock_release->set_assigned_device_name(
3473 compile_node->assigned_device_name());
3474 TF_RET_CHECK(multilock_acquire != nullptr);
3475 graph->AddEdge(multilock_acquire, 0, multilock_release, 0);
3476 graph->AddControlEdge(multilock_release, control_successor);
3477 // Make sure all execute Ops happen before the multilock_release.
3478 execute_successor = multilock_release;
3479 }
3480
3481 // Mapping from original resource arg number to a second level map. Second
3482 // level map is from core id to output index of updated variable value.
3483 absl::flat_hash_map<int, absl::flat_hash_map<int, int>>
3484 orig_arg_num_to_output_index_mapping;
3485 // Mapping from retval index to a second level map. Second level map is from
3486 // core id to output index of sharded output value.
3487 std::unordered_map<int, std::unordered_map<int, int>>
3488 retval_index_to_output_index_mapping;
3489
3490 // Represents mapping of argument index of sharded input to each
3491 // TPUExecute node to its corresponding Split node and its output index
3492 // from which sharded input will be fed into TPUExecute node.
3493 std::map<ShardedInputIndex, ShardedInputInfo> input_index_to_sharded_inputs;
3494
3495 // Additional map of {host, arg_num} to dummy input. Per-task copies of the
3496 // inputs reduce cross-task communication and allow sharing across replicas.
3497 absl::btree_map<ShardedPerHostInputIndex, Node*> sharded_per_host_index;
3498
3499 // Builds one TPUExecute node per core per replica.
3500 std::vector<std::vector<Node*>> execute_nodes(params_info.NumReplicas());
3501 for (int core = 0; core < num_cores_per_replica; ++core) {
3502 DataTypeVector core_retval_types;
3503 for (int output : core_retval_nums[core]) {
3504 core_retval_types.push_back(retval_types[output]);
3505 }
3506 DataTypeVector core_arg_types;
3507 std::vector<int> core_variable_writes;
3508 for (int input : core_arg_nums[core]) {
3509 // Resource variables can be passed either by reference (as a DT_RESOURCE)
3510 // tensor or by value (as the variable's current value). Per-replica or
3511 // distributed resource arguments are always passed by reference and
3512 // broadcast variables are always passed by value.
3513 if (arg_types[input] == DT_RESOURCE &&
3514 !params_info.IsPerReplicaArg(input) &&
3515 !params_info.IsDistributedArg(input)) {
3516 DataType handle_type = arg_shapes[input].handle_type;
3517 TF_RET_CHECK(handle_type != DT_INVALID) << DataTypeString(handle_type);
3518 core_arg_types.push_back(handle_type);
3519 int base = input - params_info.NumPerReplicaArgs() -
3520 params_info.NumDistributedArgs() -
3521 params_info.NumBroadcastArgs();
3522 // Variables passed by value will have a corresponding additional output
3523 // containing an updated value for the variable.
3524 core_variable_writes.push_back(base);
3525 core_retval_types.push_back(handle_type);
3526 } else {
3527 core_arg_types.push_back(arg_types[input]);
3528 }
3529 }
3530
3531 NodeDef def;
3532 def.set_op("TPUExecute");
3533 MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
3534 AddNodeAttr("Targs", core_arg_types, &def);
3535 AddNodeAttr("Tresults", core_retval_types, &def);
3536
3537 for (int64_t replica = 0; replica < params_info.NumReplicas(); ++replica) {
3538 def.set_name(strings::StrCat(replicate_node.name(), "/_execute_", replica,
3539 "_", core));
3540
3541 Status status;
3542 Node* node = graph->AddNode(def, &status);
3543 if (!status.ok()) return status;
3544 execute_nodes[replica].push_back(node);
3545
3546 node->set_assigned_device_name(tpu_device_names[replica][core]);
3547
3548 // Add control edges to ensure that execution happens after
3549 // `control_predecessor`, happens before `execute_successor`, and is
3550 // triggered by evaluating any operator that depends on the original
3551 // TPUReplicate operator. See the comment at the top of the header file
3552 // for more details.
3553 graph->AddControlEdge(control_predecessor, node);
3554 graph->AddControlEdge(node, execute_successor);
3555
3556 // Add data input edges.
3557 for (int64_t i = 0; i < core_arg_nums[core].size(); ++i) {
3558 int64_t orig_arg_num = core_arg_nums[core][i];
3559 VLOG(2) << " replica " << replica << " core " << core << " i " << i
3560 << " orig_arg_num " << orig_arg_num;
3561 const bool is_per_replica_arg =
3562 params_info.IsPerReplicaArg(orig_arg_num);
3563 if (is_per_replica_arg || params_info.IsDistributedArg(orig_arg_num)) {
3564 // Per-replica input and distributed input
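// The flattened TPUReplicate input list holds all per-replica args for
// replica 0, then for replica 1, and so on, followed by a single copy of each
// distributed arg. Hence (hypothetically) with 2 replicas and 3 per-replica
// args, per-replica arg 1 of replica 1 is input 1 * 3 + 1 = 4, and
// distributed arg 3 is input 2 * 3 + (3 - 3) = 6.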
3565 const int64_t input_num =
3566 is_per_replica_arg ? replica * params_info.NumPerReplicaArgs() +
3567 core_arg_nums[core][i]
3568 : params_info.NumReplicas() *
3569 params_info.NumPerReplicaArgs() +
3570 core_arg_nums[core][i] -
3571 params_info.NumPerReplicaArgs();
3572
3573 const Edge* edge = replicate_input_edges[input_num];
3574 VLOG(2) << "replicate_input_edges[" << input_num << "]";
3575 DataType dtype = edge->src()->output_type(edge->src_output());
3576 if (dtype == DT_RESOURCE) {
3577 DataType handle_dtype = arg_shapes[orig_arg_num].handle_type;
3578 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(),
3579 handle_dtype) == kTpuAllTypes.end()) {
3580 return errors::InvalidArgument(
3581 "Unsupported resource variable data type for TPU: ",
3582 DataTypeString(handle_dtype), ", caused by output ",
3583 edge->src()->name(), ":", edge->src_output());
3584 }
3585 } else {
3586 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3587 kTpuAllTypes.end()) {
3588 return errors::InvalidArgument(
3589 "Unsupported data type for TPU: ", DataTypeString(dtype),
3590 ", caused by output ", edge->src()->name(), ":",
3591 edge->src_output());
3592 }
3593 }
3594 if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3595 // Don't automatically add a split node when the input node is
3596 // kTPUPartitionedInput.
3597 if (edge->src()->type_string() == kTPUPartitionedInput) {
3598 VLOG(2)
3599 << "Connect "
3600 << replicate_input_fan_in_nodes[input_num][core].node->name()
3601 << " to " << node->name() << " at " << i;
3602 graph->AddEdge(replicate_input_fan_in_nodes[input_num][core].node,
3603 replicate_input_fan_in_nodes[input_num][core].port,
3604 node, i);
3605 } else {
3606 if (dtype == DT_RESOURCE) {
3607 return errors::InvalidArgument(
3608 "Tiled sharding for per-replica DT_RESOURCE input must",
3609 "be TPUPartitionedInput. Here got ",
3610 edge->src()->type_string());
3611 }
3612 const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
3613
3614 ShardedInputInfo sharded_input_info;
3615 if (use_nd_sharding_ops_ && is_per_replica_arg) {
3616 TF_ASSIGN_OR_RETURN(
3617 sharded_input_info,
3618 CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
3619 sharding, replica, orig_arg_num, dtype,
3620 PartialTensorShape(), edge->src(), edge->src_output(),
3621 graph, &input_index_to_sharded_inputs));
3622 } else if (use_nd_sharding_ops_) {
3623 TF_ASSIGN_OR_RETURN(
3624 sharded_input_info,
3625 CreateOrGetXlaSplitNodeForDistributedArg(
3626 sharding, params_info.NumReplicas(), replica,
3627 orig_arg_num, dtype, PartialTensorShape(), edge->src(),
3628 edge->src_output(), graph,
3629 &input_index_to_sharded_inputs));
3630 } else {
3631 TF_ASSIGN_OR_RETURN(
3632 sharded_input_info,
3633 CreateOrGetSplitNodesForInputSharding(
3634 sharding, orig_arg_num, dtype, PartialTensorShape(),
3635 replica, edge->src_output(), edge->src(),
3636 control_predecessor, graph,
3637 &input_index_to_sharded_inputs));
3638 }
3639
3640 NodeOut split_node_and_index =
3641 sharded_input_info.sharded_inputs.at(core);
3642 // Connect with Split node output.
3643 graph->AddEdge(split_node_and_index.node,
3644 split_node_and_index.index, node, i);
3645 }
3646 } else if (edge->src()->type_string() == kTPUPartitionedInput &&
3647 arg_shardings[orig_arg_num].type() ==
3648 xla::OpSharding::REPLICATED) {
3649 graph->AddEdge(replicate_input_fan_in_nodes[input_num][core].node,
3650 replicate_input_fan_in_nodes[input_num][core].port,
3651 node, i);
3652 } else {
3653 graph->AddEdge(edge->src(), edge->src_output(), node, i);
3654 }
3655 } else if (params_info.IsBroadcastArg(orig_arg_num)) {
3656 // Broadcast input.
3657 int64_t input_num = params_info.FirstBroadcastArgFromHost() +
3658 core_arg_nums[core][i] -
3659 params_info.NumPerReplicaArgs() -
3660 params_info.NumDistributedArgs();
3661 const Edge* edge = replicate_input_edges[input_num];
3662 DataType dtype = edge->src()->output_type(edge->src_output());
3663 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3664 kTpuAllTypes.end()) {
3665 return errors::InvalidArgument(
3666 "Unsupported data type for TPU: ", DataTypeString(dtype),
3667 ", caused by output ", edge->src()->name(), ":",
3668 edge->src_output());
3669 }
3670 graph->AddEdge(edge->src(), edge->src_output(), node, i);
3671 } else {
3672 // Variable input.
3673 int64_t variable_num =
3674 orig_arg_num - params_info.NumPerReplicaArgs() -
3675 params_info.NumDistributedArgs() - params_info.NumBroadcastArgs();
3676 TF_RET_CHECK(variable_num < num_variables);
3677
3678 Node* variable_read = variable_reads[variable_num];
3679 DataType dtype = variable_read->output_type(0);
3680 if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3681 kTpuAllTypes.end()) {
3682 return errors::InvalidArgument(
3683 "Unsupported resource variable data type for TPU: ",
3684 DataTypeString(dtype), ", caused by ReadVariableOp ",
3685 variable_read->DebugString());
3686 }
3687 DeviceNameUtils::ParsedName requested_device;
3688 string requested = variable_read->requested_device();
3689 TF_RET_CHECK(
3690 DeviceNameUtils::ParseFullName(requested, &requested_device));
3691 if (requested_device.type != "TPU") {
3692 // Stage the value via the CPU device on the remote host. The graph
3693 // partitioner will introduce an intermediate copy rather than
3694 // copying the same tensor multiple times across the network, and we
3695 // would prefer that intermediate copy to be in host memory to avoid
3696 // running out of memory if the TPUExecute op on the staging device
3697 // starts running before the _Send ops to the other TPU devices on
3698 // the same host complete. We don't do this if the variables are
3699 // already placed on TPU, otherwise it will cause an unnecessary
3700 // round trip copy.
3701 // TODO(b/79580121): give each replica its own on-device variable
3702 // replica and then delete this code.
3703 string device;
3704 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
3705 tpu_device_names[replica][core], &device));
3706 TF_ASSIGN_OR_RETURN(
3707 auto var_data,
3708 CreateOrGetPerHostVariableCopy(
3709 device, variable_num, variable_reads, params_info,
3710 arg_shardings, replicate_node, enable_xla_param_broadcast_,
3711 num_cores_per_replica, replica, arg_shapes,
3712 &per_host_var_copies, graph));
3713
3714 if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3715 ShardedInputInfo sharded_input_info;
3716
3717 if (EnableXlaParamBroadcast(enable_xla_param_broadcast_,
3718 params_info, orig_arg_num, dtype)) {
3719 // Populates the sharded dummy vars for non-zero replicas.
3720 TF_RETURN_IF_ERROR(CreatePartitionedDummyVarArgs(
3721 arg_shardings[orig_arg_num], params_info.NumReplicas(),
3722 replica, arg_shapes[orig_arg_num], var_data.node,
3723 orig_arg_num, dtype, device, graph, tpu_device_names,
3724 &sharded_per_host_index, &input_index_to_sharded_inputs));
3725 }
3726
3727 if (use_nd_sharding_ops_) {
3728 TF_ASSIGN_OR_RETURN(
3729 sharded_input_info,
3730 CreateOrGetXlaSplitNodeForVariableArg(
3731 arg_shardings[orig_arg_num], params_info.NumReplicas(),
3732 replica, orig_arg_num,
3733 arg_shapes[orig_arg_num].handle_type,
3734 arg_shapes[orig_arg_num].handle_shape, var_data.node,
3735 var_data.index, graph, &to_be_removed_nodes,
3736 &input_index_to_sharded_inputs));
3737 } else {
3738 TF_ASSIGN_OR_RETURN(
3739 sharded_input_info,
3740 CreateOrGetSplitNodesForInputSharding(
3741 arg_shardings[orig_arg_num], orig_arg_num,
3742 arg_shapes[orig_arg_num].handle_type,
3743 arg_shapes[orig_arg_num].handle_shape, replica,
3744 var_data.index, var_data.node, control_predecessor,
3745 graph, &input_index_to_sharded_inputs));
3746 }
3747
3748 NodeOut split_node_and_index =
3749 sharded_input_info.sharded_inputs[core];
3750 // Connect with Split node output.
3751 graph->AddEdge(split_node_and_index.node,
3752 split_node_and_index.index, node, i);
3753
3754 } else {
3755 graph->AddEdge(var_data.node, var_data.index, node, i);
3756 }
3757 } else {
3758 graph->AddEdge(variable_reads[variable_num], 0, node, i);
3759 }
3760 }
3761 }
3762
3763 // Adds a program input edge from the compiler.
3764 graph->AddEdge(compile_node, core + 1, node, node->num_inputs() - 1);
3765
3766 // Add data output edges.
3767 int num_outputs = core_retval_nums[core].size();
3768 for (int i = 0; i < num_outputs; ++i) {
3769 int output_num =
3770 replica * num_retvals_per_replica + core_retval_nums[core][i];
3771 const auto& sharding = retval_shardings[core_retval_nums[core][i]];
3772 if (sharding.type() == xla::OpSharding::OTHER) {
3773 int retval_index = core_retval_nums[core][i];
3774 retval_index_to_output_index_mapping[retval_index][core] = i;
3775 bool is_last_core =
3776 core ==
3777 *std::max_element(sharding.tile_assignment_devices().begin(),
3778 sharding.tile_assignment_devices().end());
3779 bool isPartitionOutNode = false;
3780
3781 const Edge* e = replicate_output_edges[output_num];
3782 const Edge* e_out = nullptr;
3783 for (const Edge* out_edge : e->dst()->out_edges()) {
3784 if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
3785 isPartitionOutNode = true;
3786 e_out = out_edge;
3787 }
3788 }
3789 if (isPartitionOutNode) {
3790 graph->AddEdge(
3791 node, i, replicate_output_fan_out_nodes[output_num][core],
3792 replicate_output_fan_out_dst_inputs[output_num][core]);
3793 VLOG(2) << "Connect " << node->name() << " at " << i << " to "
3794 << replicate_output_fan_out_nodes[output_num][core]->name()
3795 << " at "
3796 << replicate_output_fan_out_dst_inputs[output_num][core];
3797 if (is_last_core) {
3798 graph->RemoveEdge(e);
3799 graph->RemoveEdge(e_out);
3800 }
3801 continue;
3802 }
3803
3804 // Do this in the iteration of the last core in the tile assignment, so that
3805 // all TPUExecute nodes have been created.
3806 if (!is_last_core) {
3807 continue;
3808 }
3809
3810 // Add a Concat node.
3811 std::vector<NodeOut> orig_inputs;
3812 for (int64_t tile_index = 0;
3813 tile_index < sharding.tile_assignment_devices_size();
3814 ++tile_index) {
3815 int64_t last_tile_dim_size =
3816 *sharding.tile_assignment_dimensions().rbegin();
3817 if (sharding.replicate_on_last_tile_dim() &&
3818 tile_index % last_tile_dim_size != 0) {
3819 continue;
3820 }
3821 int64_t core_id = sharding.tile_assignment_devices(tile_index);
3822 int core_retval_index =
3823 retval_index_to_output_index_mapping[retval_index][core_id];
3824 orig_inputs.push_back(
3825 NodeOut{execute_nodes[replica][core_id],
3826 static_cast<int>(
3827 core_retval_nums[core_id][core_retval_index])});
3828 }
3829 DataType dtype = e->src()->output_type(e->src_output());
3830 Node* concat_node = nullptr;
3831 if (use_nd_sharding_ops_) {
3832 TF_ASSIGN_OR_RETURN(
3833 concat_node, CreateXlaConcatNode(
3834 sharding, replica, dtype,
3835 /*partial_tensor_shape=*/PartialTensorShape(),
3836 orig_inputs, /*device=*/"", graph));
3837 } else {
3838 TF_ASSIGN_OR_RETURN(
3839 concat_node,
3840 CreateConcatNodesForRetval(
3841 sharding, dtype, /*inferred_shape=*/PartialTensorShape(),
3842 replica, orig_inputs, graph, /*device=*/""));
3843 }
3844
3845 const Edge* edge = replicate_output_edges[output_num];
3846 Node* dst = edge->dst();
3847 int dst_input = edge->dst_input();
3848 graph->RemoveEdge(edge);
3849 graph->AddEdge(concat_node, 0, dst, dst_input);
3850
3851 continue;
3852 }
3853
3854 // If this is a replicated output, outputs on all cores will be the
3855 // same, and we only take the output from core 0.
3856 if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3857 continue;
3858 }
3859
3860 // If output has maximal sharding, make sure we only use output from
3861 // TPUExecute node with logical core id equal to core id defined by the
3862 // xla sharding.
3863 if (sharding.type() == xla::OpSharding::MAXIMAL &&
3864 core != sharding.tile_assignment_devices(0)) {
3865 continue;
3866 }
3867
3868 const Edge* replicate_edge_to_replace =
3869 replicate_output_edges[output_num];
3870 Node* dst = replicate_edge_to_replace->dst();
3871 int dst_input = replicate_edge_to_replace->dst_input();
3872 graph->RemoveEdge(replicate_edge_to_replace);
3873 graph->AddEdge(node, i, dst, dst_input);
3874 }
3875
3876 // Feed the updated variable values from the first replica to the
3877 // variable write nodes.
3878 if (replica == 0) {
3879 for (int i = 0; i < core_variable_writes.size(); ++i) {
3880 int orig_arg_num =
3881 core_variable_writes[i] + params_info.NumPerReplicaArgs() +
3882 params_info.NumDistributedArgs() + params_info.NumBroadcastArgs();
3883 const auto& sharding = arg_shardings[orig_arg_num];
3884 // If this is a tiling sharded variable, concat variable updates from
3885 // all cores.
3886 if (sharding.type() == xla::OpSharding::OTHER) {
3887 orig_arg_num_to_output_index_mapping[orig_arg_num][core] = i;
3888
3889           // Do this in the iteration of the last core in the tile
3890           // assignment, so that all TPUExecute nodes have been created.
3891 if (core !=
3892 *std::max_element(sharding.tile_assignment_devices().begin(),
3893 sharding.tile_assignment_devices().end())) {
3894 continue;
3895 }
3896
3897 // Add a Concat node.
3898 std::vector<NodeOut> orig_inputs;
3899 for (int64_t tile_index = 0;
3900 tile_index < sharding.tile_assignment_devices_size();
3901 ++tile_index) {
3902 int64_t last_tile_dim_size =
3903 *sharding.tile_assignment_dimensions().rbegin();
3904 if (sharding.replicate_on_last_tile_dim() &&
3905 tile_index % last_tile_dim_size != 0) {
3906 continue;
3907 }
3908 int64_t core_id = sharding.tile_assignment_devices(tile_index);
3909 int core_retval_num =
3910 orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
3911 orig_inputs.push_back(
3912 NodeOut{execute_nodes[0][core_id],
3913 static_cast<int>(core_retval_nums[core_id].size() +
3914 core_retval_num)});
3915 }
3916
3917 // Use the variable read's device for the concat. They should both
3918 // be collocated with the variable.
3919 absl::string_view device =
3920 variable_reads[core_variable_writes[i]]->assigned_device_name();
3921 Node* concat_node = nullptr;
3922 if (use_nd_sharding_ops_) {
3923 TF_ASSIGN_OR_RETURN(
3924 concat_node,
3925 CreateXlaConcatNode(sharding, replica,
3926 arg_shapes[orig_arg_num].handle_type,
3927 arg_shapes[orig_arg_num].handle_shape,
3928 orig_inputs, device, graph));
3929 } else {
3930 TF_ASSIGN_OR_RETURN(
3931 concat_node,
3932 CreateConcatNodesForRetval(
3933 sharding, arg_shapes[orig_arg_num].handle_type,
3934 arg_shapes[orig_arg_num].handle_shape, replica,
3935 orig_inputs, graph, device));
3936 }
3937 // Populate VariableWrite.
3938 VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3939 write.value = concat_node;
3940 write.value_output = 0;
3941 write.predicate = compile_node;
3942 write.predicate_output = num_cores_per_replica + core + 1;
3943
3944 continue;
3945 }
3946
3947 // If this is a replicated variable, outputs on all cores will be the
3948 // same, and we only take the output from core 0 for the variable
3949 // update.
3950 if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3951 continue;
3952 }
3953 VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3954 write.value = node;
3955 write.value_output = num_outputs + i;
3956 write.predicate = compile_node;
3957 write.predicate_output = num_cores_per_replica + core + 1;
3958 }
3959 }
3960 }
3961 }
3962
3963 for (Node* node : to_be_removed_nodes) {
3964 graph->RemoveNode(node);
3965 }
3966 return Status::OK();
3967 } // NOLINT(readability/fn_size)
3968
3969 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationNodes(
3970 int replica_index, const std::vector<Node*>& outside_compilation_nodes,
3971 const DeviceNameUtils::ParsedName& tpu_device,
3972 const DeviceNameUtils::ParsedName& partial_device,
3973 NodeToNodeReplicasMap* node_images, Graph* graph) {
3974 for (Node* node : outside_compilation_nodes) {
3975 NodeDef image_def = node->def();
3976 MergeDebugInfo(NodeDebugInfo(node->def()), &image_def);
3977 const string suffix = strings::StrCat("/R", replica_index);
3978 // In addition to node name, make the frame name unique to avoid multiple
3979 // LoopCond nodes in one frame.
3980 TF_RETURN_IF_ERROR(
3981 AddPrefixAndSuffixToNode("" /* prefix */, suffix, &image_def));
3982 Status status;
3983 Node* image = graph->AddNode(image_def, &status);
3984 image->AddAttr(kXlaReplicaIdAttrName, replica_index);
3985 TF_RETURN_IF_ERROR(status);
3986 if (HasNodeAttr(image->def(), kXlaHasHostTransferAttrName)) {
3987 TF_RETURN_IF_ERROR(
3988 SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, image));
3989 } else {
3990 const string& original_device_string =
3991 node->assigned_device_name().empty() ? node->requested_device()
3992 : node->assigned_device_name();
3993 DeviceNameUtils::ParsedName device;
3994 TF_RET_CHECK(
3995 DeviceNameUtils::ParseFullName(original_device_string, &device));
3996 // If the requested device can be merged with the replica's host device,
3997 // then do so. For example, if the requested device is "/CPU:0" or
3998 // "/GPU:0" then it will be placed on the CPU/GPU of the host where this
3999 // replica is running. But if the requested device is
4000 // "/task:3/replica:2/CPU:0" then it will be placed on that task/replica.
4001 if (DeviceNameUtils::IsSpecification(device, partial_device)) {
4002 TF_RETURN_IF_ERROR(
4003 DeviceNameUtils::MergeDevNames(&device, partial_device));
4004 }
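      // Hypothetical example of the merge above (names are illustrative): a
      // requested device of "/device:CPU:0" matches the partial device
      // "/job:worker/replica:0/task:3", so merging yields
      // "/job:worker/replica:0/task:3/device:CPU:0", i.e. the CPU of the host
      // running this replica; a request that names a different task is left
      // unchanged, as described in the comment above.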
4005 image->set_requested_device(DeviceNameUtils::ParsedNameToString(device));
4006 }
4007 std::vector<Node*>& node_image_vector = (*node_images)[node];
4008 node_image_vector.resize(replica_index + 1);
4009 node_image_vector[replica_index] = image;
4010 }
4011 return Status::OK();
4012 }
4013
4014 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationNodes(
4015 const std::vector<std::vector<string>>& tf_device_assignment,
4016 const HostComputeCoreMap& host_compute_core,
4017 const OutsideCompilationNodeMap& outside_compilation_nodes,
4018 NodeToNodeReplicasMap* node_images, Graph* graph) {
4019 // Iterate over replicas.
4020 for (int i = 0; i < tf_device_assignment.size(); ++i) {
4021 const auto& core_devices = tf_device_assignment[i];
4022 for (const auto& oc_cluster_iter : outside_compilation_nodes) {
4023 const string& oc_cluster_name = oc_cluster_iter.first;
4024 const auto& oc_cluster_nodes = oc_cluster_iter.second;
4025 // We previously validated that host_compute_core contains an entry for
4026 // each cluster.
4027 int core = host_compute_core.at(oc_cluster_name);
4028 TF_RET_CHECK(core >= 0 && core < core_devices.size());
4029 // tpu_device is the device the HostCompute XLA Op for this cluster runs
4030 // on.
4031 DeviceNameUtils::ParsedName tpu_device;
4032 TF_RET_CHECK(
4033 DeviceNameUtils::ParseFullName(core_devices[core], &tpu_device));
4034 // partial_device contains the replica and task but not the type.
4035 DeviceNameUtils::ParsedName partial_device = tpu_device;
4036 partial_device.has_type = false;
4037 partial_device.has_id = false;
4038
4039 if (tf_device_assignment.size() == 1) {
4040         // With a single replica, don't copy any nodes; just put the original
4041         // nodes into the image map. We leave the device placement alone, except
4042 // that we have to fill in the correct core for the host send and
4043 // receive nodes.
4044 for (Node* node : oc_cluster_nodes) {
4045 (*node_images)[node] = {node};
4046 node->AddAttr(kXlaReplicaIdAttrName, 0);
4047 if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
4048 TF_RETURN_IF_ERROR(
4049 SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, node));
4050 }
4051 }
4052 } else {
4053 // Iterate over outside_compilation clusters in this computation, adding
4054 // all the nodes with appropriate device assignments.
4055 TF_RETURN_IF_ERROR(
4056 CopyOutsideCompilationNodes(i, oc_cluster_nodes, tpu_device,
4057 partial_device, node_images, graph));
4058 }
4059 }
4060 }
4061 return Status::OK();
4062 }
4063
4064 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationEdges(
4065 const std::vector<Node*>& outside_compilation_nodes,
4066 const NodeToNodeReplicasMap& node_images,
4067 const std::unordered_map<string, Node*> outside_compilation_inputs,
4068 Graph* graph) {
4069 for (Node* node : outside_compilation_nodes) {
4070 const auto& images = node_images.at(node);
4071 // Make a copy of all edges and iterate on "in_edges", because we might
4072     // remove edges when iterating through them.
4073 std::vector<const Edge*> in_edges(node->in_edges().begin(),
4074 node->in_edges().end());
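    // The loop below distinguishes several kinds of input edges: edges whose
    // source is an ordinary node outside any rewrite are broadcast to every
    // replica image; edges whose source is a lifted-argument or
    // placeholder-for-argument node are rerouted so that output i of the
    // corresponding IdentityN input node feeds replica i's image; and edges
    // between two replicated outside_compilation nodes are connected
    // image-to-image, provided both sides have the same replication factor.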
4075 for (const Edge* edge : in_edges) {
4076 Node* src = edge->src();
4077 const auto iter = node_images.find(src);
4078 if (iter == node_images.end()) {
4079 if (images.size() > 1) {
4080 // The source node is a 'normal' node not part of any
4081 // rewrite. Broadcast the value to all replicas. (If images.size() ==
4082 // 1 the cluster is not replicated and we can leave the original edge
4083 // in place.)
4084 for (Node* dst : images) {
4085 graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
4086 }
4087 }
4088 continue;
4089 }
4090
4091 // The source node is a replicated outside_compilation node.
4092 const auto& src_images = iter->second;
4093 if (src_images.size() != images.size()) {
4094 return errors::InvalidArgument(
4095 "Graph contains an edge from node ", src->name(),
4096 " in an outside_compilation block replicated ", src_images.size(),
4097 " ways to node ", node->name(),
4098 " in an outside_compilation block replicated ", images.size(),
4099 " ways. Replication factors must match. Leave a comment on "
4100 "tracking bug b/76419636 if you need this to be supported.");
4101 }
4102 bool is_lifted_arg;
4103 string outside_compilation_cluster;
4104 if (GetNodeAttr(src->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg)
4105 .ok() &&
4106 GetNodeAttr(src->def(), kOutsideCompilationAttr,
4107 &outside_compilation_cluster)
4108 .ok()) {
4109 const auto input_iter =
4110 outside_compilation_inputs.find(outside_compilation_cluster);
4111 TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
4112 TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
4113 int dst_input = edge->dst_input();
4114 if (src_images.size() == 1) {
4115 graph->RemoveEdge(edge);
4116 }
4117 for (int i = 0; i < src_images.size(); ++i) {
4118 graph->AddEdge(input_iter->second, i, images[i], dst_input);
4119 }
4120 continue;
4121 }
4122
4123 bool is_placeholder_for_arg;
4124 string outside_compilation_input_attr;
4125 if (GetNodeAttr(src->def(), kXlaIsPlaceholderForArg,
4126 &is_placeholder_for_arg)
4127 .ok() &&
4128 GetNodeAttr(src->def(), kXlaOutsideCompilationInputsAttrName,
4129 &outside_compilation_input_attr)
4130 .ok()) {
4131 const auto input_iter =
4132 outside_compilation_inputs.find(outside_compilation_input_attr);
4133 TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
4134 TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
4135 int dst_input = edge->dst_input();
4136 if (src_images.size() == 1) {
4137 graph->RemoveEdge(edge);
4138 }
4139 for (int i = 0; i < src_images.size(); ++i) {
4140 graph->AddEdge(input_iter->second, i, images[i], dst_input);
4141 }
4142 continue;
4143 }
4144
4145 if (images.size() > 1) {
4146 // If images.size() == 1 neither cluster is replicated and we can
4147 // leave the original edges in place.
4148 for (int i = 0; i < src_images.size(); ++i) {
4149 graph->AddEdge(src_images[i], edge->src_output(), images[i],
4150 edge->dst_input());
4151 }
4152 }
4153 }
4154 for (const Edge* edge : node->out_edges()) {
4155 Node* dst = edge->dst();
4156 const auto iter = node_images.find(dst);
4157 if (iter == node_images.end()) {
4158         // The destination node is a 'normal' node not part of any rewrite.
4159 if (edge->IsControlEdge()) {
4160 // Make the dst node have a control dependency on every replica.
4161 if (images.size() > 1) {
4162 for (int i = 0; i < images.size(); ++i) {
4163 graph->AddControlEdge(images[i], dst);
4164 }
4165 }
4166 // else the cluster is not replicated so we can leave the original
4167 // edge in place.
4168 } else {
4169           // The edge is only valid if the outside_compilation block is not
4170           // replicated.
4171 if (images.size() > 1) {
4172 return errors::InvalidArgument(
4173 "Graph contains an edge from node ", node->name(),
4174 " in an outside_compilation block replicated ", images.size(),
4175 " ways to node ", dst->name(),
4176 " that is not part of an outside_compilation block. Edges from "
4177 "outside_compilation to regular graph nodes are only supported "
4178 "for replication factors of 1. Leave a comment on tracking bug "
4179 "b/76419636 if you need this to be supported.");
4180 }
4181 // else the cluster is not replicated so we can leave the original
4182 // edge in place.
4183 }
4184 }
4185 // The case where src and dst are both in node_images is covered elsewhere
4186 // when iterating over in_edges of dst.
4187 }
4188 }
4189 return Status::OK();
4190 }
4191
4192 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationEdges(
4193 const OutsideCompilationNodeMap& outside_compilation_nodes,
4194 const NodeToNodeReplicasMap& node_images,
4195 const std::unordered_map<string, Node*> outside_compilation_inputs,
4196 Graph* graph) {
4197 for (const auto& oc_cluster_iter : outside_compilation_nodes) {
4198 TF_RETURN_IF_ERROR(
4199 CopyOutsideCompilationEdges(oc_cluster_iter.second, node_images,
4200 outside_compilation_inputs, graph));
4201 }
4202 return Status::OK();
4203 }
4204
4205 /* static */ Status DistributedTPURewritePass::RemoveOutsideCompilationNodes(
4206 const NodeToNodeReplicasMap& node_images, Graph* graph) {
4207 for (const auto& iter : node_images) {
4208 if (iter.second.size() > 1) {
4209 // The cluster was replicated so remove the original node.
4210 Node* node = iter.first;
4211 graph->RemoveNode(node);
4212 }
4213 }
4214 return Status::OK();
4215 }
4216
4217 /* static */ Status
4218 DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
4219 Graph* g, FunctionLibraryDefinition& flib_def,
4220 const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping) {
4221 bool modified = false;
4222 do {
4223 std::vector<Node*> nodes_to_lower;
4224 for (Node* n : g->op_nodes()) {
4225 if (!HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
4226 continue;
4227 }
4228
4229 if (n->IsWhileNode() || n->IsIfNode() || IsFunctionCall(flib_def, *n)) {
4230         // Only lower functional ops with DT_RESOURCE inputs, because otherwise
4231         // the placer will complain. In other cases, lowering would cause a
4232         // slowdown when the related functions are huge (b/139037679).
4233 bool has_resource_input = false;
4234 for (const Edge* e : n->in_edges()) {
4235 if (!e->IsControlEdge() &&
4236 e->src()->output_type(e->src_output()) == DT_RESOURCE) {
4237 has_resource_input = true;
4238 break;
4239 }
4240 }
4241 if (has_resource_input) {
4242 nodes_to_lower.push_back(n);
4243 }
4244 }
4245 }
4246
4247 modified = !nodes_to_lower.empty();
4248
4249 auto lower_functional_node = [&flib_def, &g](Node* n) -> Status {
4250 // Clear device assignment. Otherwise all lowered nodes will have
4251 // device assignment, which is not what we want.
4252 n->set_requested_device("");
4253
4254 int replica_id;
4255 TF_RETURN_IF_ERROR(
4256 GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
4257
4258 string outside_compilation_attr;
4259 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kOutsideCompilationAttr,
4260 &outside_compilation_attr));
4261
4262 // There are two different kinds of functional outside compilation nodes:
4263 // 1. Nodes that are in outside compilation blocks already. They are
4264 // generated by FunctionalizeControlFlowForXlaPass, and only have
4265 // attribute kOutsideCompilationAttr.
4266 // 2. Mirrored control flow built for outside compilation in functional
4267 // nodes. They are generated by ExtractOutsideCompilationPass, and have
4268 // both kOutsideCompilationAttr and kXlaHasHostTransferAttrName.
4269 // When lowering them, they need to be treated differently.
4270 // For 1), their body functions are always V1 functions written by users,
4271 // and their "control outputs" are control inputs of _Retval nodes. They
4272 // should be lowered as V1 functions.
4273 // For 2), we always add necessary "control outputs"
4274 // (_XlaRecvAtHost/_XlaSendAtHost nodes) to "control_ret" field in their
4275 // FunctionDef's. They should be lowered as V2 functions.
4276 bool is_host_side_mirrored_control_flow =
4277 HasNodeAttr(n->def(), kXlaHasHostTransferAttrName);
4278
4279 int num_node_ids = g->num_node_ids();
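      // Record the current node-id watermark so that, after lowering, the
      // loop further below can visit only the nodes created by the lowering
      // (ids in [num_node_ids, g->num_node_ids())) and propagate the outside
      // compilation and replica-id attributes to them.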
4280 bool is_call_node = IsFunctionCall(flib_def, *n);
4281 if (n->IsWhileNode()) {
4282 TF_RETURN_IF_ERROR(RewriteWhileNode(n, g, &flib_def,
4283 /*keep_node_fetchable=*/false));
4284 } else if (n->IsIfNode()) {
4285 TF_RETURN_IF_ERROR(RewriteIfNode(n, g, /*keep_node_fetchable=*/false));
4286 } else {
4287 TF_RET_CHECK(is_call_node);
4288 // See comments for "is_host_side_mirrored_control_flow" above.
4289 // If this is a node that's in outside compilation block, lower it as
4290 // V1 function. This is controlled by removing
4291 // kLowerAsMultiDeviceFunctionAttr from the node.
4292 if (!is_host_side_mirrored_control_flow) {
4293 n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
4294 } else {
4295 n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
4296 n->AddAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr,
4297 true);
4298 }
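        // Net effect of the attribute handling above: nodes originating from
        // user-written outside compilation bodies are lowered as V1 function
        // calls (attribute absent), while host-side mirrored control flow is
        // lowered as a V2 multi-device function call (attribute set to true).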
4299 TF_RETURN_IF_ERROR(
4300 RewriteFunctionCallNode(n, g, flib_def,
4301 /*keep_caller_fetchable=*/false));
4302 }
4303
4304 for (int i = num_node_ids; i < g->num_node_ids(); i++) {
4305 Node* node = g->FindNodeId(i);
4306 if (!node) {
4307 continue;
4308 }
4309
4310 if (!is_call_node && is_host_side_mirrored_control_flow &&
4311 IsFunctionCall(flib_def, *node)) {
4312 // For If/While nodes, if they are host side mirrored control flow,
4313 // mark their body function calls with kXlaHasHostTransferAttrName
4314 // attribute to make sure we lower them as V2 function.
4315 node->AddAttr(kXlaHasHostTransferAttrName, true);
4316 }
4317
4318 if (IsFunctionCall(flib_def, *node) || node->IsWhileNode() ||
4319 node->IsIfNode()) {
4320 // Set kOutsideCompilationAttr attribute so we lower these
4321 // nested function call nodes later.
4322 node->AddAttr(kOutsideCompilationAttr, outside_compilation_attr);
4323 // Set kXlaReplicaIdAttrName attribute so we know replica id when we
4324 // lower this function call node.
4325 node->AddAttr(kXlaReplicaIdAttrName, replica_id);
4326 } else if (node->type_string() == "_XlaRecvAtHost" ||
4327 node->type_string() == "_XlaSendFromHost") {
4328 // For "_XlaRecvAtHost" and "_XlaSendFromHost" nodes, make sure they
4329 // have kXlaReplicaIdAttrName attribute so later we know which host
4330 // device to assign.
4331 node->AddAttr(kXlaReplicaIdAttrName, replica_id);
4332 }
4333 }
4334 return Status::OK();
4335 };
4336
4337 for (Node* n : nodes_to_lower) {
4338 TF_RETURN_IF_ERROR(lower_functional_node(n));
4339 }
4340 } while (modified);
4341
4342 // Set device for all _XlaRecvAtHost and _XlaSendFromHost nodes.
4343 for (Node* n : g->op_nodes()) {
4344 if (n->type_string() != "_XlaRecvAtHost" &&
4345 n->type_string() != "_XlaSendFromHost") {
4346 continue;
4347 }
4348
4349 string replicate;
4350 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &replicate));
4351 auto iter = tpu_replicate_device_names_mapping.find(replicate);
4352 TF_RET_CHECK(iter != tpu_replicate_device_names_mapping.end());
4353 const auto& tpu_device_names = iter->second;
4354
4355 int replica_id;
4356 TF_RETURN_IF_ERROR(
4357 GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
4358 TF_RET_CHECK(replica_id < tpu_device_names.size());
4359 const string& tpu_device_name = tpu_device_names[replica_id][0];
4360 string host_device_name;
4361 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
4362 tpu_device_name, &host_device_name));
4363 n->set_assigned_device_name(host_device_name);
4364 // We may run TPU rewrite passes again on the subgraphs of the resulting
4365 // graph. Clear kTPUReplicateAttr and kOutsideCompilationAttr for
4366 // "_XlaRecvAtHost" nodes and "_XlaSendFromHost" nodes, in order to make
4367 // sure that TPU rewrite passes take no effect on host-side subgraphs for
4368 // outside compilation.
4369 n->ClearAttr(kTPUReplicateAttr);
4370 n->ClearAttr(kOutsideCompilationAttr);
4371 }
4372
4373 // Remove IdentityN nodes generated for outside compilation. IdentityN is
4374 // exempt from resource edge colocation, but here we do need input and output
4375 // for these IdentityN nodes to be colocated.
4376 std::vector<Node*> identityn_nodes;
4377 for (Node* n : g->op_nodes()) {
4378 if (n->type_string() == "IdentityN" &&
4379 HasNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName)) {
4380 identityn_nodes.push_back(n);
4381 }
4382 }
4383 for (Node* n : identityn_nodes) {
4384 std::vector<const Edge*> out_edges(n->out_edges().begin(),
4385 n->out_edges().end());
4386 for (const Edge* e : out_edges) {
4387 if (e->IsControlEdge()) {
4388 continue;
4389 }
4390
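      // IdentityN simply forwards its i-th input to its i-th output, so each
      // data consumer of output `src_output` can be rewired directly to the
      // source of the corresponding input edge; that is what the statements
      // below do before the IdentityN node itself is removed.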
4391 int src_output = e->src_output();
4392 const Edge* input_edge;
4393 TF_RETURN_IF_ERROR(n->input_edge(src_output, &input_edge));
4394 Node* dst = e->dst();
4395 int dst_input = e->dst_input();
4396 g->RemoveEdge(e);
4397 g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
4398 }
4399 g->RemoveNode(n);
4400 }
4401
4402 return Status::OK();
4403 }
4404
4405 /* static */ Status DistributedTPURewritePass::ParseHostComputeCores(
4406 const Node& replicate_node,
4407 const OutsideCompilationNodeMap& outside_compilation_nodes,
4408 HostComputeCoreMap* host_compute_core) {
4409 std::vector<string> hc_core_string;
4410 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "host_compute_core",
4411 &hc_core_string));
4412 TF_RETURN_IF_ERROR(
4413 ParseHostComputeCoreList(hc_core_string, host_compute_core));
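  // Illustrative (hypothetical) attribute value: {"cluster_a:0",
  // "cluster_b:1"} would assign the HostCompute ops of cluster_a to logical
  // core 0 and those of cluster_b to logical core 1; the exact string format
  // is whatever ParseHostComputeCoreList accepts. Clusters without an entry
  // default to core 0 below.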
4414 for (const auto& iter : outside_compilation_nodes) {
4415 const string& oc_cluster_name = iter.first;
4416 if (host_compute_core->find(oc_cluster_name) == host_compute_core->end()) {
4417 // By default put host compute Ops on replicated core 0.
4418 (*host_compute_core)[oc_cluster_name] = 0;
4419 }
4420 }
4421 return Status::OK();
4422 }
4423
4424 /* static */ Status DistributedTPURewritePass::GetDeviceTopology(
4425 const DeviceSet& device_set, const Node& replicate_node, int* num_replicas,
4426 int* num_cores_per_replica, int* num_tasks,
4427 std::vector<std::vector<string>>* tf_device_assignment,
4428 std::vector<int>* devices_to_lock,
4429 std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
4430 string* tpu_compilation_device) {
4431 TF_RETURN_IF_ERROR(
4432 GetNodeAttr(replicate_node.attrs(), "num_replicas", num_replicas));
4433 if (*num_replicas < 1) {
4434 return errors::InvalidArgument("num_replicas must be >= 1, got ",
4435 *num_replicas);
4436 }
4437
4438 // Find the set of TPU devices in the TF job.
4439 // Indexed by [task number][tpu device number].
4440 std::vector<std::vector<Device*>> tpu_devices;
4441 int num_tpus_per_task;
4442 TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
4443 device_set, tpu_compilation_device,
4444 &num_tpus_per_task, &tpu_devices));
4445 *num_tasks = tpu_devices.size();
4446
4447 string topology;
4448 TF_RETURN_IF_ERROR(
4449 GetNodeAttr(replicate_node.attrs(), "topology", &topology));
4450 TF_RETURN_IF_ERROR(GetNodeAttr(
4451 replicate_node.attrs(), "num_cores_per_replica", num_cores_per_replica));
4452 std::vector<int> device_assignment;
4453 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "device_assignment",
4454 &device_assignment));
4455
4456 // TODO(cwhipkey): since we can control multiple pods of different shapes
4457 // from a single worker, it may be desirable to propagate the remote device
4458 // information around (e.g., in DeviceAttributes). This can lead to the mesh
4459 // topology proto being leaked to cloud TPU users (e.g. through GetStatus
4460 // calls); this may be okay, but to be conservative, just assume that the
4461 // master session has the proper flags set.
4462
4463 // We do not initialize platform right now, but we can still retrieve the
4464 // TPU topology even with an uninitialized platform.
4465 auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(
4466 /*initialize_platform=*/false);
4467 TF_RET_CHECK(tpu_platform);
4468 tpu::TpuTopologyExternal tpu_topology(tpu_platform->GetTopologyPtr());
4469 TF_RET_CHECK(num_tpus_per_task ==
4470 tpu_topology.LogicalDevicesPerHost(kTensorCore));
4471 TF_RETURN_IF_ERROR(BuildDeviceAssignment(
4472 tpu_topology, num_tpus_per_task, tpu_devices, *num_replicas,
4473 *num_cores_per_replica, topology, device_assignment, tf_device_assignment,
4474 devices_to_lock, xla_device_assignment));
4475
4476 return Status::OK();
4477 }
4478
4479 /* static */ Status DistributedTPURewritePass::GetIOTypes(
4480 int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
4481 Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
4482 std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
4483 DataTypeVector* retval_types, ParameterInfo* params_info) {
4484 DataTypeVector input_types, broadcast_input_types, guaranteed_constant_types;
4485 TF_RETURN_IF_ERROR(
4486 GetNodeAttr(replicate_node.attrs(), "Tinputs", &input_types));
4487 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "Tbroadcast_inputs",
4488 &broadcast_input_types));
4489 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4490 "Tguaranteed_constants",
4491 &guaranteed_constant_types));
4492 int num_distributed_vars;
4493 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4494 "num_distributed_variables",
4495 &num_distributed_vars));
4496 const int num_per_replica_inputs = input_types.size() - num_distributed_vars;
4497
4498 if (num_per_replica_inputs % num_replicas != 0) {
4499 return errors::InvalidArgument(
4500 "Number of inputs to TPUReplicate (", num_per_replica_inputs,
4501 ") is not divisible by the number of replicas (", num_replicas, ").");
4502 }
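  // Worked example of the bookkeeping above: with input_types.size() == 10,
  // num_distributed_vars == 2 and num_replicas == 4, there are 8 per-replica
  // inputs in total, i.e. 8 / 4 == 2 per-replica arguments supplied by each
  // replica (plus the 2 distributed variables shared across replicas).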
4503
4504 int num_variables;
4505 TF_RETURN_IF_ERROR(
4506 GetNodeAttr(replicate_node.attrs(), "NumVariables", &num_variables));
4507
4508 NameRangeMap output_name_map;
4509 TF_RETURN_IF_ERROR(NameRangesForNode(replicate_node, replicate_node.op_def(),
4510 input_name_map, &output_name_map));
4511
4512 TF_RETURN_IF_ERROR(
4513 GetNodeAttr(replicate_node.attrs(), "computation", function));
4514
4515 *computation = absl::make_unique<Graph>(graph->op_registry());
4516 TF_RETURN_IF_ERROR(GetComputationForTPUReplicateOp(
4517 **function, flr, computation->get(), arg_types, retval_types));
4518
4519 *params_info = ParameterInfo(
4520 num_replicas, num_per_replica_inputs / num_replicas, num_distributed_vars,
4521 broadcast_input_types.size(), num_variables,
4522 guaranteed_constant_types.size(), retval_types->size());
4523
4524 if (arg_types->size() != params_info->NumInputsToEachReplica()) {
4525 return errors::InvalidArgument(
4526 "Computation argument to TPUReplicate has wrong number of "
4527 "arguments. Expected ",
4528 params_info->NumInputsToEachReplica(), " inputs, got ",
4529 arg_types->size());
4530 }
4531 if (replicate_node.num_outputs() != params_info->NumOutputsToHost()) {
4532 return errors::InvalidArgument(
4533 "Wrong number of outputs from TPUReplicate. Expected ",
4534 params_info->NumOutputsToHost(), " outputs, got ",
4535 replicate_node.num_outputs());
4536 }
4537 if (enable_cross_replica_sharding_mirrored_variables_) {
4538 std::vector<int> mirrored_variable_indices;
4539 TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4540 TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
4541 &mirrored_variable_indices));
4542 for (int index : mirrored_variable_indices) {
4543 TF_RET_CHECK(params_info->IsPerReplicaArg(index) ||
4544 params_info->IsDistributedArg(index))
4545 << "Mirrored variables not categorized as per-replica arguments, "
4546 "index: "
4547 << index;
4548 params_info->mutable_mirrored_variable_indices()->insert(index);
4549 }
4550 }
4551 return Status::OK();
4552 }
4553
4554 /* static */ Status DistributedTPURewritePass::BuildSequencingNodes(
4555 const string& tpu_compilation_device, const Node& replicate_node,
4556 Graph* graph, Node** host_transfer_sequencer, Node** control_before,
4557 Node** control_after) {
4558 *host_transfer_sequencer = nullptr;
4559
4560 TF_RETURN_IF_ERROR(
4561 BuildNoopNode(replicate_node,
4562 graph->NewName(strings::StrCat(replicate_node.name(), "/",
4563 "control_before")),
4564 /*device=*/"", graph, control_before));
4565 for (const Edge* e : replicate_node.in_edges()) {
4566 if (!e->IsControlEdge()) {
4567 continue;
4568 }
4569 Node* predecessor = e->src();
4570 if (predecessor->IsSource()) continue;
4571 if (predecessor->type_string() == "NoOp" &&
4572 predecessor->attrs().Find("_xla_host_transfer_sequencer") != nullptr) {
4573 // The node is the sequencer for host transfer operations. Its control
4574 // dependency needs to be placed after the execute node, not before.
4575 if (*host_transfer_sequencer != nullptr) {
4576 return errors::Internal("Replicate node ", replicate_node.name(),
4577 " has two transfer sequencer nodes: ",
4578 (*host_transfer_sequencer)->name(), " and ",
4579 predecessor->name());
4580 }
4581 // Set the correct device to match the other sequencing nodes.
4582 predecessor->set_assigned_device_name(tpu_compilation_device);
4583 *host_transfer_sequencer = predecessor;
4584 } else {
4585 graph->AddControlEdge(predecessor, *control_before);
4586 }
4587 }
4588
4589 TF_RETURN_IF_ERROR(
4590 BuildNoopNode(replicate_node,
4591 graph->NewName(strings::StrCat(replicate_node.name(), "/",
4592 "control_after")),
4593 /*device=*/tpu_compilation_device, graph, control_after));
4594 for (Node* successor : replicate_node.out_nodes()) {
4595 if (successor->attrs().Find("_xla_tail_outside_compilation") != nullptr) {
4596 graph->AddControlEdge(successor, *control_after);
4597 } else {
4598 graph->AddControlEdge(*control_after, successor);
4599 }
4600 }
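  // Resulting control topology (sketch): every non-sequencer control
  // predecessor of the TPUReplicate node now feeds `control_before`, and
  // `control_after` feeds every successor, except that tail outside
  // compilation successors instead feed into `control_after`; the host
  // transfer sequencer, if any, is attached to `control_after` later in
  // DealWithConstantsAndVariables.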
4601 return Status::OK();
4602 }
4603
4604 /* static */ Status DistributedTPURewritePass::DealWithConstantsAndVariables(
4605 const Node& replicate_node, const NameRangeMap& input_name_map,
4606 Graph* graph, Node* host_transfer_sequencer, Node* control_before,
4607 Node* control_after, absl::Span<const VariableInput> variable_nodes,
4608 std::vector<Node*>* guaranteed_constant_nodes,
4609 std::vector<Node*>* variable_reads) {
4610 TF_RETURN_IF_ERROR(FindGuaranteedConstantInputs(
4611 replicate_node, input_name_map, guaranteed_constant_nodes));
4612
4613 TF_RETURN_IF_ERROR(BuildVariableReads(variable_nodes, control_before, graph,
4614 variable_reads));
4615 // Add the control dependency from host transfer nodes.
4616 if (host_transfer_sequencer != nullptr) {
4617 graph->AddControlEdge(host_transfer_sequencer, control_after);
4618 }
4619 return Status::OK();
4620 }
4621
4622 /* static */ Status
4623 DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
4624 Node* replicate_node, Node* compile_node,
4625 absl::Span<const int> devices_to_lock, Node** control_after_compilation,
4626 Node** multilock_acquire, Graph* graph) {
4627 const Edge* compilation_edge = nullptr;
4628 for (const auto* e : replicate_node->out_edges()) {
4629 if (e->IsControlEdge() &&
4630 e->dst()->type_string() == "TPUCompilationResult") {
4631 TF_RET_CHECK(compilation_edge == nullptr)
4632 << "Multiple compilation result nodes attached to the same replicate "
4633 "cluster.";
4634 compilation_edge = e;
4635 }
4636 }
4637
4638   // TODO(jpienaar): This should be checked by default; the current tests not
4639   // using this are the ones that use the "abort upon successful compile" flag,
4640   // which will be removed. Leaving this in until then.
4641 if (compilation_edge != nullptr) {
4642 Node* compilation_status = compilation_edge->dst();
4643 const AttrValue* compile_status_cluster_attr =
4644 compilation_status->attrs().Find(kTPUCompilationResultAttr);
4645 TF_RET_CHECK(compile_status_cluster_attr != nullptr);
4646 const string& compile_status_cluster = compile_status_cluster_attr->s();
4647 TF_RET_CHECK(!compile_status_cluster.empty());
4648 const AttrValue* replicate_cluster_attr =
4649 replicate_node->attrs().Find(kTPUReplicateAttr);
4650 TF_RET_CHECK(replicate_cluster_attr != nullptr);
4651 const string& replicate_cluster = replicate_cluster_attr->s();
4652 TF_RET_CHECK(!replicate_cluster.empty());
4653 TF_RET_CHECK(compile_status_cluster == replicate_cluster);
4654
4655 TF_RETURN_IF_ERROR(
4656 ReplaceCompilationResultNodeWithIdentity(graph, &compilation_status));
4657 graph->AddEdge(compile_node, 0, compilation_status, 0);
4658 }
4659
4660 NodeDef def;
4661 def.set_name(UniqueNodeName("tpu_compile_succeeded_assert", graph));
4662 // Create an op to assert that compilation succeeded. The alternative would
4663 // have been to have each execute op check and return an error.
4664 def.set_op("TPUCompileSucceededAssert");
4665 MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
4666 Status status;
4667 Node* compile_succeeded = graph->AddNode(def, &status);
4668 compile_succeeded->set_assigned_device_name(
4669 compile_node->assigned_device_name());
4670 TF_RETURN_IF_ERROR(status);
4671 graph->AddEdge(compile_node, 0, compile_succeeded, 0);
4672
4673 Node* last_node_before_sequencer = compile_succeeded;
4674
4675 if (enable_multicore_locking_ && devices_to_lock.size() > 1) {
4676 // Add a lock node to acquire exclusive access to all the cores that will
4677 // execute this program. The lock is required to prevent deadlock or
4678 // incorrect results when running concurrent multi-core programs in the
4679 // same distributed runtime when there is no direct graph dependency
4680 // between the programs (either because they are run from different sessions
4681 // or because they are in the same graph, but have no control or data
4682 // dependencies to sequence them). Consider the case of two multi-core
4683 // computations A and B whose cores overlap and include cores X and Y. With
4684 // no locking and no graph dependencies it is possible that A's program
4685 // gets enqueued before B's on core X, while B's program gets enqueued
4686 // before A's on core Y. This will lead either to deadlock or to
4687 // incorrect results, since the runtime has no mechanism to re-sequence
4688     // the programs on the cores. By adding a multi-lock acquisition for all
4689     // the cores before any TPUExecute ops are run, and releasing it after they
4690     // complete, we ensure that the programs are enqueued on the cores in a
4691     // consistent order.
4692 //
4693 // There is a risk when computations are in the same graph, and include a
4694 // data dependency, that the lock acquisition could provoke deadlock.
4695 // Suppose that A must happen before B because B's input depends on A's
4696 // output. Then it is obviously necessary that A's lock acquisition must
4697 // happen before B's lock acquisition, and so we must ensure that there is
4698 // a graph dependency causing B's lock acquisition to be sequenced after A's
4699 // lock acquisition. Right now that dependency is satisfied because the
4700 // shape inference code cannot determine the shape of A's outputs, and so
4701 // B's compilation, which precedes B's lock acquisition, is always sequenced
4702 // after A's execution. If the shape inference is improved it will be
4703 // necessary to add an explicit control edge between dependent lock
4704 // acquisition ops.
4705 NodeDef lock_def;
4706 lock_def.set_name(graph->NewName(
4707 strings::StrCat(compile_node->name(), "/", "tpu_acquire_multilock")));
4708 lock_def.set_op("TpuMultilock");
4709 AddNodeAttr("lock_list", devices_to_lock, &lock_def);
4710 MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &lock_def);
4711 Status status;
4712 *multilock_acquire = graph->AddNode(lock_def, &status);
4713 TF_RETURN_IF_ERROR(status);
4714 (*multilock_acquire)
4715 ->set_assigned_device_name(compile_node->assigned_device_name());
4716 graph->AddControlEdge(compile_succeeded, *multilock_acquire);
4717 last_node_before_sequencer = *multilock_acquire;
4718 } else {
4719 *multilock_acquire = nullptr;
4720 }
4721
4722 // Build a sequencing node for when compilation has completed.
4723 TF_RETURN_IF_ERROR(
4724 BuildNoopNode(*replicate_node,
4725 graph->NewName(strings::StrCat(compile_node->name(), "/",
4726 "after_compilation")),
4727 /*device=*/"", graph, control_after_compilation));
4728 graph->AddControlEdge(last_node_before_sequencer, *control_after_compilation);
4729
4730 return Status::OK();
4731 }
4732
4733 // Updates the head and tail outside compiled nodes so that they are placed on
4734 // the correct device, and removes the replication and outside compilation
4735 // attributes so that these nodes do not trigger further graph optimization passes.
4736 /* static */ Status DistributedTPURewritePass::UpdateHeadTailOutsideCompilation(
4737 const std::vector<std::vector<string>>& tf_device_assignment,
4738 const std::vector<Node*>& head_tail_outside_compilation_nodes) {
4739 for (Node* node : head_tail_outside_compilation_nodes) {
4740 int replica_id;
4741 TF_RETURN_IF_ERROR(
4742 GetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id));
4743 // Since we set the device, this will now run on a task other than 0. We
4744 // clear the two following attributes so that we don't trigger encapsulation
4745 // again on the remote host (which will fail due to a missing
4746 // _TPUReplicateMetadata node for the cluster).
4747 for (const Edge* e : node->in_edges()) {
4748 // Resource consuming ops should colocate with its resource input.
4749 if (e->src()->IsArg() &&
4750 e->src()->output_type(e->src_output()) == DT_RESOURCE) {
4751 node->set_requested_device(tf_device_assignment[replica_id][0]);
4752 }
4753 }
4754 if (node->requested_device().empty()) {
4755 string cpu_device;
4756 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
4757 tf_device_assignment[replica_id][0], &cpu_device));
4758 node->set_requested_device(cpu_device);
4759 }
4760 node->ClearAttr(kTPUReplicateAttr);
4761 node->ClearAttr(kOutsideCompilationAttr);
4762 }
4763 return Status::OK();
4764 }
4765
4766 // Performs the rewrite on a single TPUReplicate node.
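// At a high level the rewrite proceeds as follows: query the device topology
// and build the per-replica device assignment; retarget head/tail outside
// compilation nodes; extract the computation's input/output types, shapes and
// shardings; build sequencing NoOps around the expansion; create variable
// reads, dynamic shape nodes and the TPUCompile node; emit per-replica,
// per-core TPUExecute nodes and variable writes; replicate outside
// compilation clusters; and finally remove the original TPUReplicate node.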
4767 /* static */ Status DistributedTPURewritePass::RewriteTPUReplicateNode(
4768 const string& session_handle, const DeviceSet& device_set,
4769 Node* replicate_node, FunctionLibraryDefinition* flib_def,
4770 FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
4771 const OutsideCompilationNodeMap& outside_compilation_nodes,
4772 const std::vector<Node*>& head_tail_outside_compilation_nodes,
4773 NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
4774 const GraphShapeInfo& shape_info,
4775 TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
4776 int64_t autotuner_thresh) {
4777 VLOG(2) << "Rewriting node " << replicate_node->name();
4778
4779 // num_replicas and num_cores_per_replica are the 'virtual' replicas (copies
4780 // of the computation) and cores (virtual cores within computations) specified
4781 // by the user. They will be mapped to physical TPU cores below.
4782 int num_replicas;
4783 int num_cores_per_replica;
4784 int num_tasks;
4785 std::vector<std::vector<string>> tf_device_assignment;
4786 std::vector<int> devices_to_lock;
4787 std::unique_ptr<xla::DeviceAssignment> xla_device_assignment;
4788 string tpu_compilation_device;
4789 TF_RETURN_IF_ERROR(GetDeviceTopology(
4790 device_set, *replicate_node, &num_replicas, &num_cores_per_replica,
4791 &num_tasks, &tf_device_assignment, &devices_to_lock,
4792 &xla_device_assignment, &tpu_compilation_device));
4793
4794 TF_RETURN_IF_ERROR(UpdateHeadTailOutsideCompilation(
4795 tf_device_assignment, head_tail_outside_compilation_nodes));
4796
4797 string replicate;
4798 TF_RETURN_IF_ERROR(
4799 GetNodeAttr(replicate_node->def(), kTPUReplicateAttr, &replicate));
4800 tpu_replicate_device_names_mapping->emplace(replicate, tf_device_assignment);
4801
4802 NameRangeMap input_name_map;
4803 const NameAttrList* function;
4804 std::unique_ptr<Graph> computation;
4805 DataTypeVector arg_types, retval_types;
4806 ParameterInfo params_info;
4807 TF_RETURN_IF_ERROR(GetIOTypes(num_replicas, *replicate_node, flr, graph,
4808 &input_name_map, &function, &computation,
4809 &arg_types, &retval_types, ¶ms_info));
4810
4811 std::vector<InferredShape> arg_shapes, retval_shapes;
4812 TF_RETURN_IF_ERROR(GetArgAndRetvalShapes(
4813 shape_info, *replicate_node, params_info, &arg_shapes, &retval_shapes));
4814
4815 TF_RETURN_IF_ERROR(ValidateCoreNumbers(*computation, num_cores_per_replica));
4816
4817 std::vector<xla::OpSharding> arg_sharding;
4818 std::vector<bool> arg_fast_mem;
4819 std::vector<std::string> arg_names;
4820 std::vector<xla::OpSharding> retval_sharding;
4821 TF_RETURN_IF_ERROR(AssignArgsAndRetvalsToCores(
4822 num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
4823 retval_shapes, *computation, replicate_node, flr,
4824 allow_xla_spmd_partition_, &arg_sharding, &arg_fast_mem, &retval_sharding,
4825 &arg_names));
4826
4827 VLOG(1) << DumpGraphToFile("distributed_tpu_graph_to_replicate", *computation,
4828 flib_def);
4829
4830 GraphDef graph_def;
4831 graph->ToGraphDef(&graph_def);
4832 FunctionLibraryDefinition reachable_functions =
4833 flib_def->ReachableDefinitions(graph_def);
4834 uint64 library_fingerprint;
4835
4836 TF_RETURN_IF_ERROR(
4837 FingerprintFunctionLibrary(reachable_functions, &library_fingerprint));
4838 VLOG(1) << "Fingerprint functions: "
4839 << absl::StrJoin(reachable_functions.ListFunctionNames(), ", ");
4840 VLOG(1) << "library_fingerprint: " << library_fingerprint;
4841
4842 // Builds trigger nodes that put barriers around the expansion of
4843 // TPUReplicate. In particular, we must guarantee:
4844 // a) variable reads happen after all predecessors of the original
4845 // TPUReplicate.
4846 // b) variable writes happen before all successors of the original
4847 // TPUReplicate.
4848 // c) all replicas execute, even if output tensors are only requested from
4849 // a subset of replicas. This is necessary both to ensure that variable
4850   //    updates happen and because Send/Recv will deadlock if only one half of
4851 // the communicating pair runs.
4852 Node* host_transfer_sequencer;
4853 Node* control_before;
4854 Node* control_after;
4855 TF_RETURN_IF_ERROR(BuildSequencingNodes(
4856 tpu_compilation_device, *replicate_node, graph, &host_transfer_sequencer,
4857 &control_before, &control_after));
4858
4859 // Build a vector of variable nodes that are inputs.
4860 std::vector<VariableInput> variable_inputs;
4861 TF_RETURN_IF_ERROR(
4862 FindVariableInputs(*replicate_node, input_name_map, &variable_inputs));
4863
4864 std::vector<Node*> guaranteed_constant_nodes;
4865 std::vector<Node*> variable_reads;
4866 TF_RETURN_IF_ERROR(DealWithConstantsAndVariables(
4867 *replicate_node, input_name_map, graph, host_transfer_sequencer,
4868 control_before, control_after, variable_inputs,
4869 &guaranteed_constant_nodes, &variable_reads));
4870
4871 // Builds Shape nodes that compute the dynamic shapes of arguments whose
4872 // shapes are not statically known.
4873 std::vector<Node*> dynamic_shape_nodes;
4874 TF_RETURN_IF_ERROR(BuildDynamicShapeNodes(*replicate_node, arg_shapes,
4875 params_info, variable_reads, graph,
4876 &dynamic_shape_nodes));
4877
4878 // Builds a TPUCompile node that compiles `clusters` on `compile_device`.
4879 Node* compile_node;
4880 TF_RETURN_IF_ERROR(BuildCompileNode(
4881 replicate_node, *function, library_fingerprint, params_info, arg_shapes,
4882 arg_types, guaranteed_constant_nodes, session_handle, arg_sharding,
4883 arg_fast_mem, arg_names, retval_sharding, num_cores_per_replica,
4884 /*compile_device=*/tpu_compilation_device, xla_device_assignment.get(),
4885 dynamic_shape_nodes, graph, &compile_node, autotuner_thresh));
4886
4887   // Compilation must be sequenced after the control node if the TPU
4888   // computation is in a control-flow construct, such as a loop.
4889 graph->AddControlEdge(control_before, compile_node);
4890
4891 Node* control_after_compilation;
4892 Node* multilock_acquire;
4893 TF_RETURN_IF_ERROR(BuildCompilationStatusReturnNodes(
4894 replicate_node, compile_node, devices_to_lock, &control_after_compilation,
4895 &multilock_acquire, graph));
4896
4897 std::vector<VariableWrite> variable_writes;
4898 TF_RETURN_IF_ERROR(BuildExecuteNodes(
4899 params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_names,
4900 arg_types, arg_shapes, retval_types, arg_sharding, retval_sharding,
4901 tf_device_assignment, compile_node, variable_reads,
4902 control_after_compilation, control_after, multilock_acquire,
4903 &variable_writes, graph));
4904 bool contains_resource_write_op =
4905 ContainsResourceWriteOp(*graph, reachable_functions);
4906
4907 VLOG(2) << "contains_resource_write_op: " << contains_resource_write_op;
4908 // Skip conditional write if there is no resource writing op inside TPU
4909 // computation.
4910 if (contains_resource_write_op) {
4911 TF_RETURN_IF_ERROR(BuildVariableWrites(variable_inputs, control_after,
4912 variable_writes, graph));
4913 }
4914
4915 if (host_compute_key_placeholder_node != nullptr) {
4916 TF_RETURN_IF_ERROR(ConnectHostComputeNodes(
4917 compile_node, host_compute_key_placeholder_node, graph));
4918 }
4919
4920 HostComputeCoreMap host_compute_core;
4921 TF_RETURN_IF_ERROR(ParseHostComputeCores(
4922 *replicate_node, outside_compilation_nodes, &host_compute_core));
4923 TF_RETURN_IF_ERROR(ReplicateOutsideCompilationNodes(
4924 tf_device_assignment, host_compute_core, outside_compilation_nodes,
4925 outside_compilation_node_images, graph));
4926
4927 graph->RemoveNode(replicate_node);
4928 return Status::OK();
4929 }
4930
4931 // Adds sharded weight update optimization for each host training loop.
4932 //
4933 // For any host training loop found in the graph, TPUVariableReshard ops
4934 // are inserted to match the best layout chosen by XLA.
4935 /* static */ Status
4936 DistributedTPURewritePass::PerformHostTrainingLoopOptimization(
4937 Graph* graph, FunctionLibraryDefinition* flib_def,
4938 FunctionLibraryRuntime* flr) {
4939 std::vector<tpu::HostTrainingLoopInfo> host_training_loops_info;
4940 Status s = tpu::DetectHostTrainingLoop(
4941 /*current_function_name=*/nullptr,
4942 /*current_function_attr=*/nullptr, flib_def, graph, flr,
4943 &host_training_loops_info);
4944 if (!s.ok()) {
4945 VLOG(2) << "No valid host training loop found. Skipping sharded weight "
4946 << "update optimization.";
4947 return Status::OK();
4948 }
4949
4950 for (const auto& host_loop : host_training_loops_info) {
4951 const auto& function_name = host_loop.encapsulating_function_name;
4952 // `function_name` has value when host training loop is inside a
4953 // function call node. When host training loop is found inside a function
4954 // call node, then, in addition to adding TPUVariableReshard ops, function
4955 // library definition needs to be updated as well.
4956 if (function_name.has_value()) {
4957 const auto& function_attr = host_loop.encapsulating_function_attrs;
4958 TF_RET_CHECK(function_attr.has_value())
4959 << "Unable to find function attribute for function: "
4960 << *function_name;
4961
4962 const FunctionDef* function_def = flib_def->Find(*function_name);
4963 TF_RET_CHECK(function_def)
4964 << "Unable to find function : " << *function_name;
4965
4966 std::unique_ptr<FunctionBody> fbody;
4967 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
4968 *function_def, AttrSlice(&function_attr.value()), flib_def, &fbody));
4969 Graph* function_graph = fbody->graph;
4970 TF_RETURN_IF_ERROR(tpu::AddReshardOp(function_graph, host_loop));
4971 TF_RETURN_IF_ERROR(UpdateFunctionLibDefinition(*function_graph,
4972 *function_name, flib_def));
4973 } else {
4974 TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, host_loop));
4975 }
4976 }
4977 return Status::OK();
4978 }
4979
4980 Status DistributedTPURewritePass::PlaceUnassignedDeviceNodesOnTPUIfPossible(
4981 Graph* graph) {
4982 ReverseDFS(*graph, {}, PlaceOpsOnTPU);
4983 return Status::OK();
4984 }
4985
4986 Status DistributedTPURewritePass::Run(
4987 const GraphOptimizationPassOptions& options) {
4988 VLOG(1) << "DistributedTPURewritePass::Run";
4989
4990 Graph* graph = options.graph->get();
4991
4992 VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_before", *graph,
4993 options.flib_def);
4994
4995 const auto* config = &options.session_options->config;
4996 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
4997 new ProcessFunctionLibraryRuntime(
4998 nullptr, options.session_options->env, config,
4999 graph->versions().producer(), options.flib_def,
5000 config ? config->graph_options().optimizer_options()
5001 : OptimizerOptions()));
5002
5003 FunctionLibraryRuntime* flr =
5004 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
5005
5006 // This pass can only run in the session master, which should fill
5007   // in the device_set field of the options.
5008 TF_RET_CHECK(options.device_set != nullptr);
5009
5010 // Find all the replicate nodes before mutating the graph.
5011 std::vector<Node*> replicate_nodes;
5012 // Map from compiled subgraph cluster name to the outside_compilation nodes in
5013 // that cluster.
5014 std::map<string, OutsideCompilationNodeMap> outside_compilation_nodes;
5015 std::map<string, std::vector<Node*>> head_tail_outside_compilation_nodes;
5016 TF_RETURN_IF_ERROR(FindTaggedNodes(graph, &replicate_nodes,
5017 &outside_compilation_nodes,
5018 &head_tail_outside_compilation_nodes));
5019
5020 if (replicate_nodes.empty()) {
5021 // Remove unused TPUPartitionedInput nodes.
5022 for (Node* n : graph->nodes()) {
5023 if (n->type_string() == kTPUPartitionedInput) graph->RemoveNode(n);
5024 }
5025 VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
5026 options.flib_def);
5027 VLOG(1) << "Replicate nodes are empty. DistributedTPURewritePass::Run() "
5028 "finished";
5029 return Status::OK();
5030 }
5031
5032 std::unordered_map<string, Node*> host_compute_key_placeholder_map;
5033 TF_RETURN_IF_ERROR(FindHostComputeKeyPlaceholderNodes(
5034 graph, replicate_nodes, &host_compute_key_placeholder_map));
5035
5036 // This shape inference pass does not compute the shapes of outputs of
5037 // TPU computations. The concurrent multi-core locking implementation
5038 // *relies* on this behavior because it ensures that, if TPU computation B's
5039 // inputs depend on TPU computation A's outputs, then computation B's
5040 // compilation will be sequenced after A's execution, and this ensures that
5041 // locks are acquired in the correct order. If the shape inference is improved
5042 // to compute shapes of TPU computation outputs, it will be necessary to add
5043 // an explicit control edge between lock acquisitions for dependent
5044 // computations in order to avoid deadlock.
5045 GraphShapeInfo shape_info;
5046 TF_RETURN_IF_ERROR(InferShapes(graph, /*arg_shapes=*/{},
5047 flr->GetFunctionLibraryDefinition(),
5048 &shape_info));
5049 int64_t autotuner_thresh = options.session_options->config.experimental()
5050 .xla_fusion_autotuner_thresh();
5051
5052 NodeToNodeReplicasMap outside_compilation_node_images;
5053 TPUReplicateDeviceNamesMapping tpu_replicate_device_names_mapping;
5054 for (Node* node : replicate_nodes) {
5055 TF_RETURN_IF_ERROR(RewriteTPUReplicateNode(
5056 options.session_handle, *options.device_set, node, options.flib_def,
5057 flr, host_compute_key_placeholder_map[node->name()],
5058 outside_compilation_nodes[node->name()],
5059 head_tail_outside_compilation_nodes[node->name()],
5060 &outside_compilation_node_images, graph, shape_info,
5061 &tpu_replicate_device_names_mapping, autotuner_thresh));
5062 }
5063
5064   // Place the padding nodes generated by the dynamic padder on the correct devices.
5065 // TODO(rxsang): Place padding ops on TPUs in
5066 // PlaceUnassignedDeviceNodesOnTPUIfPossible function.
5067 TF_RETURN_IF_ERROR(SetPaddingNodesDevices(graph));
5068
5069 std::unordered_map<string, Node*> outside_compilation_inputs;
5070 for (Node* n : graph->op_nodes()) {
5071 string lifted_arg_inputs_attr;
5072 if (n->type_string() == "IdentityN" &&
5073 GetNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName,
5074 &lifted_arg_inputs_attr)
5075 .ok()) {
5076 outside_compilation_inputs[lifted_arg_inputs_attr] = n;
5077 }
5078 }
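
  // The IdentityN nodes collected above carry lifted arguments for
  // outside_compilation clusters, keyed by the value of the
  // kXlaOutsideCompilationInputsAttrName attribute; the map is consumed by
  // ReplicateOutsideCompilationEdges below.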
  for (const auto& iter : outside_compilation_nodes) {
    TF_RETURN_IF_ERROR(ReplicateOutsideCompilationEdges(
        iter.second, outside_compilation_node_images,
        outside_compilation_inputs, graph));
  }
  TF_RETURN_IF_ERROR(
      RemoveOutsideCompilationNodes(outside_compilation_node_images, graph));
  TF_RETURN_IF_ERROR(LowerOutsideCompilationFunctionalNodes(
      graph, *options.flib_def, tpu_replicate_device_names_mapping));

  TF_RETURN_IF_ERROR(PlaceUnassignedDeviceNodesOnTPUIfPossible(graph));
  VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
                             options.flib_def);
  VLOG(1) << "DistributedTPURewritePass::Run() finished";

  if (enable_cross_replica_sharding_mirrored_variables_) {
    VLOG(1) << "Starting host training loop optimization.";
    VLOG(1) << DumpGraphToFile("host_loop_optimization_before", *graph,
                               options.flib_def);
    TF_RETURN_IF_ERROR(
        PerformHostTrainingLoopOptimization(graph, options.flib_def, flr));
    VLOG(1) << DumpGraphToFile("host_loop_optimization_after", *graph,
                               options.flib_def);
    VLOG(1) << "Host training loop optimization finished.";
  }

  return Status::OK();
}

// Default values for the pass options; they can be overridden through
// SetDistributedTpuRewritePassOptions below.
bool DistributedTPURewritePass::distribute_vars_ = false;
bool DistributedTPURewritePass::allow_xla_spmd_partition_ = true;
bool DistributedTPURewritePass::
    replicate_inputs_outputs_by_default_for_xla_spmd_ = false;
bool DistributedTPURewritePass::
    enable_cross_replica_sharding_mirrored_variables_ = true;
bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
bool DistributedTPURewritePass::enable_xla_param_broadcast_ = false;
bool DistributedTPURewritePass::enable_multicore_locking_ = false;
bool DistributedTPURewritePass::use_nd_sharding_ops_ = false;

/*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
    bool distribute_vars, bool allow_xla_spmd_partition,
    bool replicate_inputs_outputs_by_default_for_xla_spmd,
    bool enable_cross_replica_sharding_mirrored_variables,
    bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast,
    bool enable_multicore_locking, bool use_nd_sharding_ops) {
  distribute_vars_ = distribute_vars;
  allow_xla_spmd_partition_ = allow_xla_spmd_partition;
  replicate_inputs_outputs_by_default_for_xla_spmd_ =
      replicate_inputs_outputs_by_default_for_xla_spmd;
  enable_cross_replica_sharding_mirrored_variables_ =
      enable_cross_replica_sharding_mirrored_variables;
  enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
  enable_xla_param_broadcast_ = enable_xla_param_broadcast;
  enable_multicore_locking_ = enable_multicore_locking;
  use_nd_sharding_ops_ = use_nd_sharding_ops;
}
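
// Example (illustrative only): a TPU system initialization path could
// configure the pass before the optimization pass registry runs it, e.g.
//
//   DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
//       /*distribute_vars=*/true,
//       /*allow_xla_spmd_partition=*/true,
//       /*replicate_inputs_outputs_by_default_for_xla_spmd=*/false,
//       /*enable_cross_replica_sharding_mirrored_variables=*/true,
//       /*enable_automatic_model_parallelism=*/false,
//       /*enable_xla_param_broadcast=*/false,
//       /*enable_multicore_locking=*/false,
//       /*use_nd_sharding_ops=*/false);
//
// The values shown are hypothetical; callers that never invoke this setter
// get the defaults from the static initializers above.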

}  // namespace tensorflow