1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Compilation for distributed TPU (TPU_REPLICATED_CORE devices).
17 
18 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"
19 
20 #include <algorithm>
21 #include <queue>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/btree_map.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/strings/escaping.h"
29 #include "tensorflow/compiler/jit/encapsulate_util.h"
30 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
31 #include "tensorflow/compiler/tf2xla/sharding_util.h"
32 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
33 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
34 #include "tensorflow/compiler/xla/array3d.h"
35 #include "tensorflow/compiler/xla/array4d.h"
36 #include "tensorflow/compiler/xla/client/sharding_builder.h"
37 #include "tensorflow/compiler/xla/service/computation_placer.h"
38 #include "tensorflow/compiler/xla/xla.pb.h"
39 #include "tensorflow/core/common_runtime/function.h"
40 #include "tensorflow/core/common_runtime/graph_constructor.h"
41 #include "tensorflow/core/common_runtime/lower_function_call_op.h"
42 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
43 #include "tensorflow/core/common_runtime/lower_if_op.h"
44 #include "tensorflow/core/common_runtime/lower_while_op.h"
45 #include "tensorflow/core/common_runtime/optimization_registry.h"
46 #include "tensorflow/core/framework/function.h"
47 #include "tensorflow/core/framework/graph_to_functiondef.h"
48 #include "tensorflow/core/framework/node_def_builder.h"
49 #include "tensorflow/core/framework/node_def_util.h"
50 #include "tensorflow/core/framework/partial_tensor_shape.h"
51 #include "tensorflow/core/framework/tensor.pb.h"
52 #include "tensorflow/core/framework/types.pb.h"
53 #include "tensorflow/core/framework/versions.pb.h"
54 #include "tensorflow/core/graph/algorithm.h"
55 #include "tensorflow/core/graph/graph.h"
56 #include "tensorflow/core/lib/core/errors.h"
57 #include "tensorflow/core/lib/core/status.h"
58 #include "tensorflow/core/lib/gtl/cleanup.h"
59 #include "tensorflow/core/lib/strings/proto_serialization.h"
60 #include "tensorflow/core/lib/strings/str_util.h"
61 #include "tensorflow/core/platform/fingerprint.h"
62 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
63 #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
64 #include "tensorflow/core/protobuf/tpu/topology.pb.h"
65 #include "tensorflow/core/public/session_options.h"
66 #include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"
67 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
68 #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
69 #include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
70 #include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
71 #include "tensorflow/core/tpu/tpu_compile_interface.h"
72 #include "tensorflow/core/tpu/tpu_defs.h"
73 #include "tensorflow/core/tpu/tpu_fingerprint_utils.h"
74 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
75 #include "tensorflow/core/util/device_name_utils.h"
76 #include "tensorflow/core/util/dump_graph.h"
77 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
78 
79 namespace tensorflow {
80 
81 namespace {
82 
83 // Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
84 // topology.
85 constexpr int kTPUTopologyRank = 4;
86 
87 // An upper bound on how many cores may be present in the topology.
88 static constexpr int kTPUMaxTopologySize = 4096;
89 
90 // Attribute containing the serialized xla::OpSharding to be passed to the
91 // corresponding XLA HLO operation, which represents how a shape is distributed
92 // across logical cores, e.g., replication, single-device, or partitioning.
93 const char kShardingAttribute[] = "_XlaSharding";
94 
95 const char kTPUPartitionedInput[] = "TPUPartitionedInput";
96 const char kTPUPartitionedOutput[] = "TPUPartitionedOutput";
97 
98 const char kVarHandleOp[] = "VarHandleOp";
99 
100 static const char* const kTPUCompilationResultAttr = "_tpu_compilation_status";
101 static const char* const kPostDeviceRewriteAttr = "_post_device_rewrite";
102 
103 using NodeAndId = std::pair<const Node*, int>;
104 
105 struct NodeAndPort {
106   explicit NodeAndPort(Node* node, int port) : node(node), port(port) {}
107 
108   Node* node;
109   // Port of the node, e.g. this can be the `src_output` index of an Edge.
110   int port;
111 };
112 
113 class IntrusiveHeapLink {
114  public:
115   using size_type = size_t;
116   static constexpr size_type kNotMember = -1;
117 
118   IntrusiveHeapLink() = default;
119 
120   // Only IntrusiveHeap and LinkAccess objects should make these objects.
121   explicit IntrusiveHeapLink(size_type pos) : pos_{pos} {}
122 
123   // Only IntrusiveHeap and LinkAccess should get the value.
124   size_type get() const { return pos_; }
125 
126  private:
127   size_type pos_{kNotMember};
128 };
129 
130 template <typename T, IntrusiveHeapLink T::*M>
131 struct IntrusiveHeapDataMemberLinkAccess {
132   IntrusiveHeapLink Get(const T* elem) const { return elem->*M; }
133   void Set(T* elem, IntrusiveHeapLink link) const { elem->*M = link; }
134 };
135 
136 template <typename T>
137 struct DefaultIntrusiveHeapLinkAccess {
138   IntrusiveHeapLink Get(const T* elem) const { return elem->heap; }
139   void Set(T* elem, IntrusiveHeapLink link) const { elem->heap = link; }
140 };
141 
142 template <typename T, typename PtrCompare,
143           typename LinkAccess = DefaultIntrusiveHeapLinkAccess<T>,
144           typename Alloc = std::allocator<T*>>
145 class IntrusiveHeap {
146  public:
147   typedef typename IntrusiveHeapLink::size_type size_type;
148   typedef T value_type;
149   typedef T* pointer;
150   typedef const T* const_pointer;
151   typedef PtrCompare pointer_compare_type;
152   typedef LinkAccess link_access_type;
153   typedef Alloc allocator_type;
154 
155   explicit IntrusiveHeap(
156       const pointer_compare_type& comp = pointer_compare_type(),
157       const link_access_type& link_access = link_access_type(),
158       const allocator_type& alloc = allocator_type())
159       : rep_(comp, link_access, alloc) {}
160 
161   size_type size() const { return heap().size(); }
162 
163   bool empty() const { return heap().empty(); }
164 
165   // Return the top element, but don't remove it.
166   pointer top() const {
167     DCHECK(!empty());
168     return heap()[0];
169   }
170 
171   // Remove the top() pointer from the heap and return it.
172   pointer Pop() {
173     pointer t = top();
174     Remove(t);
175     return t;
176   }
177 
178   // Insert 't' into the heap.
179   void Push(pointer t) {
180     SetPositionOf(t, heap().size());
181     heap().push_back(t);
182     FixHeapUp(t);
183   }
184 
185   // Adjust the heap to accommodate changes in '*t'.
186   void Adjust(pointer t) {
187     DCHECK(Contains(t));
188     size_type h = GetPositionOf(t);
189     if (h != 0 && compare()(t, heap()[(h - 1) >> 1])) {
190       FixHeapUp(t);
191     } else {
192       FixHeapDown(t);
193     }
194   }
195 
196   // Remove the specified pointer from the heap.
197   void Remove(pointer t) {
198     DCHECK(Contains(t));
199     size_type h = GetPositionOf(t);
200     SetPositionOf(t, IntrusiveHeapLink::kNotMember);
201     if (h == heap().size() - 1) {
202       // Fast path for removing from back of heap.
203       heap().pop_back();
204       return;
205     }
206     // Move the element from the back of the heap to overwrite 't'.
207     pointer& elem = heap()[h];
208     elem = heap().back();
209     SetPositionOf(elem, h);  // Element has moved, so update its link.
210     heap().pop_back();
211     Adjust(elem);  // Restore the heap invariant.
212   }
213 
214   void Clear() { heap().clear(); }
215 
216   bool Contains(const_pointer t) const {
217     size_type h = GetPositionOf(t);
218     return (h != IntrusiveHeapLink::kNotMember) && (h < size()) &&
219            heap()[h] == t;
220   }
221 
222   void reserve(size_type n) { heap().reserve(n); }
223 
224   size_type capacity() const { return heap().capacity(); }
225 
226   allocator_type get_allocator() const { return rep_.heap_.get_allocator(); }
227 
228  private:
229   typedef std::vector<pointer, allocator_type> heap_type;
230 
231   // Empty base class optimization for pointer_compare and link_access.
232   // The heap_ data member retains a copy of the allocator, so it is not
233   // stored explicitly.
234   struct Rep : pointer_compare_type, link_access_type {
235     explicit Rep(const pointer_compare_type& cmp,
236                  const link_access_type& link_access,
237                  const allocator_type& alloc)
238         : pointer_compare_type(cmp),
239           link_access_type(link_access),
240           heap_(alloc) {}
241     heap_type heap_;  // NOLINT
242   };
243 
244   const pointer_compare_type& compare() const { return rep_; }
245 
246   const link_access_type& link_access() const { return rep_; }
247 
248   const heap_type& heap() const { return rep_.heap_; }
249   heap_type& heap() { return rep_.heap_; }
250 
251   size_type GetPositionOf(const_pointer t) const {
252     return link_access().Get(t).get();
253   }
254 
255   void SetPositionOf(pointer t, size_type pos) const {
256     return link_access().Set(t, IntrusiveHeapLink(pos));
257   }
258 
259   void FixHeapUp(pointer t) {
260     size_type h = GetPositionOf(t);
261     while (h != 0) {
262       size_type parent = (h - 1) >> 1;
263       if (compare()(heap()[parent], t)) {
264         break;
265       }
266       heap()[h] = heap()[parent];
267       SetPositionOf(heap()[h], h);
268       h = parent;
269     }
270     heap()[h] = t;
271     SetPositionOf(t, h);
272   }
273 
274   void FixHeapDown(pointer t) {
275     size_type h = GetPositionOf(t);
276     for (;;) {
277       size_type kid = (h << 1) + 1;
278       if (kid >= heap().size()) {
279         break;
280       }
281       if (kid + 1 < heap().size() && compare()(heap()[kid + 1], heap()[kid])) {
282         ++kid;
283       }
284       if (compare()(t, heap()[kid])) {
285         break;
286       }
287       heap()[h] = heap()[kid];
288       SetPositionOf(heap()[h], h);
289       h = kid;
290     }
291 
292     heap()[h] = t;
293     SetPositionOf(t, h);
294   }
295 
296   Rep rep_;
297 };
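
// Illustrative usage sketch of IntrusiveHeap (the `WorkItem` names below are
// hypothetical and not part of this pass): the element type embeds an
// IntrusiveHeapLink so the heap can record each element's position, which is
// what makes Adjust() and Remove() on arbitrary elements O(log n).
//
//   struct WorkItem {
//     int64 cost = 0;
//     IntrusiveHeapLink heap;  // Used by DefaultIntrusiveHeapLinkAccess.
//   };
//   struct WorkItemCompare {
//     // Returns true if `lhs` should sit closer to the top (min-heap by cost).
//     bool operator()(const WorkItem* lhs, const WorkItem* rhs) const {
//       return lhs->cost < rhs->cost;
//     }
//   };
//
//   IntrusiveHeap<WorkItem, WorkItemCompare> heap;
//   WorkItem a, b;
//   a.cost = 3;
//   b.cost = 1;
//   heap.Push(&a);
//   heap.Push(&b);
//   a.cost = 0;
//   heap.Adjust(&a);             // Restore the heap invariant after mutation.
//   WorkItem* top = heap.Pop();  // Returns &a, now the cheapest element.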
298 
299 string CoreDeviceLabel(int core) {
300   return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
301 }
302 
303 // Creates a unique node name with a particular prefix.
304 string UniqueNodeName(const StringPiece prefix, Graph* graph) {
305   return graph->NewName(strings::StrCat(prefix, "/_", internal::GetNodeId()));
306 }
307 
308 Status SetNodeDeviceForTPUCommunication(DeviceNameUtils::ParsedName device,
309                                         const string& target_device_type,
310                                         Node* node) {
311   TF_RET_CHECK(device.has_type && device.type == DEVICE_TPU_NODE);
312   TF_RET_CHECK(device.has_id);
313   TF_RET_CHECK(HasNodeAttr(node->def(), kXlaHasHostTransferAttrName));
314 
315   // Store the device instance as an attr on the Node.
316   TF_RETURN_IF_ERROR(SetDeviceOrdinalAttributeForNode(node, device.id));
317 
318   // Place the execute Op on the TPU_SYSTEM device so it can access the cache of
319   // compiled protos in the resource manager.
320   device.type = target_device_type;
321   device.id = 0;
322 
323   node->set_assigned_device_name(DeviceNameUtils::ParsedNameToString(device));
324   return Status::OK();
325 }
326 
327 // Iterate over the nodes in the original graph and find all the TPUReplicate
328 // nodes, and all the nodes that are part of outside_compilation clusters.
329 Status FindTaggedNodes(
330     Graph* graph, std::vector<Node*>* replicate_nodes,
331     std::map<string, DistributedTPURewritePass::OutsideCompilationNodeMap>*
332         outside_compilation_nodes,
333     std::map<string, std::vector<Node*>>* head_tail_outside_compilation_nodes) {
334   for (Node* node : graph->op_nodes()) {
335     if (node->type_string() == "_TPUReplicate") {
336       replicate_nodes->push_back(node);
337       const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
338       if (cluster_attr == nullptr) {
339         return errors::Internal("TPUReplicate node ", node->name(), " has no ",
340                                 kTPUReplicateAttr, " attr.");
341       } else {
342         const string& cluster = cluster_attr->s();
343         if (cluster.empty()) {
344           return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
345                                   node->name(), " has no string value.");
346         }
347         if (outside_compilation_nodes->find(cluster) !=
348             outside_compilation_nodes->end()) {
349           return errors::Internal(
350               "TPUReplicate node ", node->name(), " has ", kTPUReplicateAttr,
351               " attr value '", cluster,
352               "' which is a duplicate of another TPUReplicate node in the "
353               "graph.");
354         }
355         (*outside_compilation_nodes)[cluster] =
356             DistributedTPURewritePass::OutsideCompilationNodeMap();
357         (*head_tail_outside_compilation_nodes)[cluster] = std::vector<Node*>();
358       }
359     }
360   }
361   for (Node* node : graph->op_nodes()) {
362     if (node->type_string() != "_TPUReplicate") {
363       const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr);
364       const AttrValue* outside_compilation_attr =
365           node->attrs().Find(kOutsideCompilationAttr);
366       if (cluster_attr == nullptr) {
367         if (outside_compilation_attr != nullptr) {
368           return errors::Internal("Node ", node->name(), " has ",
369                                   kOutsideCompilationAttr, " attr but no ",
370                                   kTPUReplicateAttr, " attr.");
371         }
372       } else {
373         const string& cluster = cluster_attr->s();
374         if (cluster.empty()) {
375           return errors::Internal("Attr ", kTPUReplicateAttr, " on node ",
376                                   node->name(), " has no string value.");
377         }
378         const auto iter = outside_compilation_nodes->find(cluster);
379         if (iter == outside_compilation_nodes->end()) {
380           return errors::Internal(
381               "Attr ", kTPUReplicateAttr, " on node ", node->name(),
382               " does not correspond to a TPUReplicate node.");
383         }
384         if (outside_compilation_attr == nullptr) {
385           return errors::Internal("Node ", node->name(), " has ",
386                                   kTPUReplicateAttr, " attr but no ",
387                                   kOutsideCompilationAttr, " attr.");
388         }
389         const string& oc_cluster = outside_compilation_attr->s();
390         if (oc_cluster.empty()) {
391           return errors::Internal("Attr ", kOutsideCompilationAttr, " on node ",
392                                   node->name(), " has no string value.");
393         }
394 
395         // Outside compilation clusters at the head and tail of the TPU
396         // computation have already been moved to the host and replicated. As
397         // such, do not replicate outside compilation nodes with a replica id attribute.
398         int replica_id;
399         if (TryGetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id)) {
400           const AttrValue* head_attr =
401               node->attrs().Find("_xla_only_arg_or_oc_input");
402           const AttrValue* tail_attr =
403               node->attrs().Find("_xla_only_ret_or_oc_output");
404           if (((head_attr != nullptr) && (head_attr->b())) ||
405               ((tail_attr != nullptr) && (tail_attr->b()))) {
406             // This is safe as this has the same keys as
407             // outside_compilation_nodes which we already know has this key.
408             (*head_tail_outside_compilation_nodes)[cluster].push_back(node);
409           }
410           continue;
411         }
412         iter->second[oc_cluster].push_back(node);
413       }
414     }
415   }
416   return Status::OK();
417 }
418 
419 // Helper class to spread TPU computation arguments and return values
420 // across cores.
421 // If all shapes are fully defined, balance by their size.
422 // If some of them are not fully defined, the sizes of the undefined shapes
423 // will be estimated using the average size of the fully defined ones.
424 // If none are defined, fall back to round-robin.
425 class TensorDevicePlacer {
426  public:
427   // Creates a TensorDevicePlacer object to distribute arguments or
428   // return values to a set of num_devices devices, where the types and
429   // the inferred shapes of the inputs (arguments or return values) are
430   // passed in types and shapes.
431   TensorDevicePlacer(int64_t num_devices, const DataTypeVector& types,
432                      const std::vector<InferredShape>& shapes)
433       : index_nodes_(num_devices), sizes_(types.size()) {
434     int64_t total_size = 0;
435     int64_t num_defined = 0;
436     for (int64_t i = 0; i < types.size(); ++i) {
437       sizes_[i] = GetInferredShapeSize(shapes[i], types[i]);
438       if (sizes_[i] >= 0) {
439         total_size += sizes_[i];
440         ++num_defined;
441       }
442     }
443     // If a shape is undefined, select a size for it which is the average
444     // of the defined shapes. If no shapes are defined, assign 1 so that we
445     // get round-robin behavior.
446     int64_t undefined_shape_size =
447         (num_defined > 0) ? total_size / num_defined : 1;
448     for (int64_t i = 0; i < sizes_.size(); ++i) {
449       if (sizes_[i] < 0) {
450         sizes_[i] = undefined_shape_size;
451       }
452     }
453 
454     for (int64_t i = 0; i < num_devices; ++i) {
455       heap_.Push(&index_nodes_[i]);
456     }
457   }
458 
459   // Reports that the argument/return-value at index has been assigned
460   // by the user to a given device.
461   void ReportDeviceAssigned(int64_t device, int64_t index) {
462     if (device >= index_nodes_.size()) {
463       LOG(FATAL) << "Sharding assignment is out of bounds. "  // Crash OK
464                     "Check that the number of nodes is properly set.";
465     }
466     DeviceNode* node = &index_nodes_.at(device);
467     node->size += sizes_.at(index);
468     heap_.Adjust(node);
469   }
470 
471   // Retrieves the device to which the argument/return-value at index
472   // should be assigned.
473   int64 RetrieveAssignment(int64_t index) {
474     DeviceNode* node = heap_.top();
475     int64_t device = node - index_nodes_.data();
476     node->size += sizes_.at(index);
477     heap_.Adjust(node);
478     return device;
479   }
480 
481  private:
482   struct DeviceNode {
483     struct Compare {
484       // Compare functor to implement a min heap using the ::gtl::IntrusiveHeap
485       // infrastructure.
486       bool operator()(const DeviceNode* lhs, const DeviceNode* rhs) const {
487         return lhs->size < rhs->size;
488       }
489     };
490 
491     IntrusiveHeapLink heap;
492     int64 size = 0;
493   };
494 
495   static int64 GetInferredShapeSize(const InferredShape& ishape,
496                                     DataType dtype) {
497     return ishape.shape.IsFullyDefined()
498                ? ishape.shape.num_elements() * DataTypeSize(dtype)
499                : -1;
500   }
501 
502   std::vector<DeviceNode> index_nodes_;
503   IntrusiveHeap<DeviceNode, typename DeviceNode::Compare> heap_;
504   std::vector<int64> sizes_;
505 };
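
// Illustrative sketch of the balancing behavior (hypothetical sizes, not taken
// from this pass): with two devices and fully defined per-value sizes of
// {400, 100, 300} bytes, each RetrieveAssignment() call picks the currently
// least-loaded device, so the 400-byte value lands on one device and the 100-
// and 300-byte values on the other, leaving both devices with 400 bytes.
//
//   TensorDevicePlacer placer(/*num_devices=*/2, types, shapes);
//   int64 d0 = placer.RetrieveAssignment(0);  // Least-loaded device.
//   int64 d1 = placer.RetrieveAssignment(1);  // The other device.
//   int64 d2 = placer.RetrieveAssignment(2);  // Same device as d1.
//
// ReportDeviceAssigned() records a placement chosen by the user instead, so
// that subsequent assignments account for that load as well.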
506 
507 Status ValidateCoreNumber(int64_t core, int64_t num_cores_per_replica) {
508   if (core < 0 || core >= num_cores_per_replica) {
509     return tensorflow::errors::InvalidArgument("Invalid core ID: ", core,
510                                                ". The valid core IDs are [0..",
511                                                num_cores_per_replica, ")");
512   }
513   return Status::OK();
514 }
515 
516 Status FindHostComputeKeyPlaceholderNodes(
517     const Graph* graph, const std::vector<Node*>& replicate_nodes,
518     std::unordered_map<string, Node*>* host_compute_key_placeholder_map) {
519   host_compute_key_placeholder_map->clear();
520   for (const auto node : replicate_nodes) {
521     (*host_compute_key_placeholder_map)[node->name()] = nullptr;
522   }
523 
524   for (Node* node : graph->op_nodes()) {
525     if (node->type_string() == "Placeholder" &&
526         str_util::EndsWith(node->name(), "_key_placeholder")) {
527       const AttrValue* call_node_attr =
528           node->attrs().Find("_host_compute_call_node");
529       if (call_node_attr != nullptr) {
530         auto iter = host_compute_key_placeholder_map->find(call_node_attr->s());
531         if (iter == host_compute_key_placeholder_map->end()) {
532           return errors::InvalidArgument(
533               "Node ", node->name(), " has _host_compute_call_node attribute '",
534               call_node_attr->s(), "' that doesn't correspond to a call node");
535         }
536         if (iter->second != nullptr) {
537           return errors::InvalidArgument(
538               "Key placeholder node ", node->name(), " for call node ",
539               call_node_attr->s(), " previously found as ",
540               iter->second->name());
541         }
542         iter->second = node;
543       }
544     }
545   }
546 
547   return Status::OK();
548 }
549 
550 Status ReplaceCompilationResultNodeWithIdentity(Graph* graph, Node** node) {
551   Node* old_node = *node;
552   // We want to replace the node with an identity node with the same name.
553   const string& node_name = old_node->name();
554 
555   // Create identity node.
556   TF_ASSIGN_OR_RETURN(
557       Node * id_node,
558       BuildIdentityNode(graph, node_name, DT_STRING,
559                         /*input=*/nullptr, /*requested_device=*/""));
560 
561   // No incoming edges are copied as a new one will be added from compile node
562   // to id_node.
563 
564   // Copy outgoing edges to the id node.
565   std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
566                                      old_node->out_edges().end());
567   for (const Edge* edge : out_edges) {
568     Node* dst = edge->dst();
569     int src_output = edge->src_output();
570     int dst_input = edge->dst_input();
571 
572     if (src_output == Graph::kControlSlot) {
573       graph->AddControlEdge(id_node, dst);
574     } else {
575       graph->AddEdge(id_node, src_output, dst, dst_input);
576     }
577     graph->RemoveEdge(edge);
578   }
579   graph->RemoveNode(old_node);
580 
581   *node = id_node;
582   return Status::OK();
583 }
584 
585 Status GetStepMarkerLocation(const Node& replicate_node,
586                              xla::DebugOptions::StepMarkerLocation* location) {
587   string step_marker_location_attr;
588   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "step_marker_location",
589                                  &step_marker_location_attr));
590   if (step_marker_location_attr.empty()) {
591     *location = xla::DebugOptions::STEP_MARK_AT_ENTRY;
592   } else {
593     if (!xla::DebugOptions::StepMarkerLocation_Parse(step_marker_location_attr,
594                                                      location)) {
595       return errors::InvalidArgument("Malformed step_marker_location: ",
596                                      step_marker_location_attr);
597     }
598   }
599   return Status::OK();
600 }
601 
602 // Extracts a map from dimension index to number of splits for a tiled input
603 // from its xla sharding attribute.
604 Status GetDimensionIndicesAndNumSplitsFromSharding(
605     const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) {
606   int64_t tensor_tile_rank = sharding.tile_assignment_dimensions_size();
607   if (sharding.replicate_on_last_tile_dim()) {
608     tensor_tile_rank--;
609   }
610   for (int dim_index = 0; dim_index < tensor_tile_rank; dim_index++) {
611     if (sharding.tile_assignment_dimensions(dim_index) > 1) {
612       split_dimension_map->emplace(
613           dim_index, sharding.tile_assignment_dimensions(dim_index));
614     }
615   }
616 
617   if (split_dimension_map->empty()) {
618     return errors::InvalidArgument("Arg has unnecessary tiled sharding: ",
619                                    sharding.DebugString());
620   }
621   return Status::OK();
622 }
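
// Illustrative sketch (hypothetical sharding, not taken from this pass): for a
// tiled xla::OpSharding with tile_assignment_dimensions = [2, 1, 3], the map
// produced above is {0: 2, 2: 3}, i.e. dimension 0 is split 2 ways and
// dimension 2 is split 3 ways; dimensions with a single tile are omitted, and
// a trailing replication dimension (replicate_on_last_tile_dim) is skipped.
//
//   std::map<int, int> split_dims;
//   TF_RETURN_IF_ERROR(
//       GetDimensionIndicesAndNumSplitsFromSharding(sharding, &split_dims));
//   // split_dims == {{0, 2}, {2, 3}}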
623 
624 // Updates contents of the function with `function_name` in function library
625 // definition `flib_def` to `new_graph`. This is required when graph
626 // transformation happens inside a function call body.
627 Status UpdateFunctionLibDefinition(const Graph& new_graph,
628                                    const std::string& function_name,
629                                    FunctionLibraryDefinition* flib_def) {
630   FunctionDef graph_fdef;
631   TF_RETURN_IF_ERROR(GraphToFunctionDef(new_graph, function_name, &graph_fdef));
632   TF_RETURN_IF_ERROR(flib_def->ReplaceFunction(function_name, graph_fdef));
633   return Status::OK();
634 }
635 
636 struct NodeOut {
637   Node* node;
638   int index;
639 };
640 
641 struct ShardedInputIndex {
642   int replica_id;
643   int argument_index;
644 
645   bool operator<(const ShardedInputIndex& rhs) const {
646     return std::tie(replica_id, argument_index) <
647            std::tie(rhs.replica_id, rhs.argument_index);
648   }
649 };
650 
651 struct ShardedPerHostInputIndex {
652   string host_device;
653   int argument_index;
654   bool operator<(const ShardedPerHostInputIndex& rhs) const {
655     return std::tie(host_device, argument_index) <
656            std::tie(rhs.host_device, rhs.argument_index);
657   }
658   bool operator==(const ShardedPerHostInputIndex& rhs) const {
659     return (argument_index == rhs.argument_index) &&
660            (host_device == rhs.host_device);
661   }
662 };
663 
664 struct ShardedInputInfo {
665   // Split node that would be connected to tiled input Node.
666   Node* split_node;
667   // List of split nodes and output index of the split node from which sharded
668   // input will be connected to the TPUExecute node. The inputs are ordered by
669   // logical core ids.
670   std::vector<NodeOut> sharded_inputs;
671 };
672 
673 // Adds a Pad node after the split node to the graph, for unevenly sharded tiled inputs.
674 // |graph| owns the returned Node* instance.
675 xla::StatusOr<Node*> CreatePadNode(const int padding, const int num_dims,
676                                    const int split_dim, DataType dtype,
677                                    Node* control_predecessor, Node* split_node,
678                                    const int split_index, Graph* graph) {
679   // Add paddings node.
680   Status s;
681   NodeDef paddings_def;
682   paddings_def.set_name(
683       graph->NewName(absl::StrCat(split_node->name(), "/paddings")));
684   paddings_def.set_op("Const");
685   AddNodeAttr("dtype", DT_INT32, &paddings_def);
686   paddings_def.set_device(split_node->assigned_device_name());
687   TensorProto sizes_tensor_proto;
688   sizes_tensor_proto.set_dtype(DT_INT32);
689   for (int i = 0; i < num_dims; ++i) {
690     sizes_tensor_proto.add_int_val(0);
691     if (i == split_dim) {
692       sizes_tensor_proto.add_int_val(padding);
693     } else {
694       sizes_tensor_proto.add_int_val(0);
695     }
696   }
697   TensorShape sizes_shape({num_dims, 2});
698   sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
699   AddNodeAttr("value", sizes_tensor_proto, &paddings_def);
700   Node* paddings_node = graph->AddNode(paddings_def, &s);
701   TF_RETURN_IF_ERROR(s);
702 
703   // Add Pad node.
704   NodeDef pad_def;
705   pad_def.set_name(graph->NewName(
706       absl::StrCat(split_node->name(), "/pad_shard_", split_index)));
707   pad_def.set_op("Pad");
708   pad_def.set_device(split_node->assigned_device_name());
709   AddNodeAttr("T", dtype, &pad_def);
710   AddNodeAttr("Tpaddings", DT_INT32, &pad_def);
711   pad_def.add_input(absl::StrCat(split_node->name(), ":", split_index));
712   pad_def.add_input(absl::StrCat(paddings_node->name(), ":0"));
713   Node* pad_node = graph->AddNode(pad_def, &s);
714   TF_RETURN_IF_ERROR(s);
715   pad_node->set_assigned_device_name(split_node->assigned_device_name());
716   // Add edges for pad node.
717   graph->AddEdge(split_node, split_index, pad_node, 0);
718   graph->AddEdge(paddings_node, 0, pad_node, 1);
719   graph->AddControlEdge(control_predecessor, pad_node);
720   return pad_node;
721 }
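
// Illustrative sketch (hypothetical values): for num_dims = 2, split_dim = 0
// and padding = 2, the "paddings" Const built above holds the [num_dims, 2]
// tensor
//
//   [[0, 2],
//    [0, 0]]
//
// i.e. `padding` trailing elements are appended along the split dimension
// only, which is the layout the Pad op expects for its second input.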
722 
723 // Adds split node and split dimension node to graph for sharding tiled inputs.
724 // |graph| owns the returned Node* instance.
725 xla::StatusOr<Node*> CreateSplitNode(const int num_splits, const int dim,
726                                      const int num_dims, const int64_t padding,
727                                      const int orig_src_output, DataType dtype,
728                                      absl::string_view name_prefix,
729                                      Node* control_predecessor, Node* orig_src,
730                                      Graph* graph) {
731   const std::string input_assigned_device = orig_src->assigned_device_name();
732   Node* to_split_node = orig_src;
733   int to_split_index = orig_src_output;
734   if (padding > 0) {
735     TF_ASSIGN_OR_RETURN(
736         Node * pad_node,
737         CreatePadNode(padding, num_dims, dim, dtype, control_predecessor,
738                       orig_src, orig_src_output, graph));
739     to_split_node = pad_node;
740     to_split_index = 0;
741   }
742 
743   // Add a split dimension node.
744   NodeDef split_dim_def;
745   split_dim_def.set_name(
746       graph->NewName(absl::StrCat(name_prefix, "/split_dim")));
747   split_dim_def.set_op("Const");
748   split_dim_def.set_device(input_assigned_device);
749   AddNodeAttr("dtype", DT_INT32, &split_dim_def);
750   TensorProto tensor_proto;
751   tensor_proto.set_dtype(DT_INT32);
752   tensor_proto.add_int_val(dim);
753   TensorShape shape({});
754   shape.AsProto(tensor_proto.mutable_tensor_shape());
755   AddNodeAttr("value", tensor_proto, &split_dim_def);
756   Status s;
757   Node* split_dim_node = graph->AddNode(split_dim_def, &s);
758   TF_RETURN_IF_ERROR(s);
759   // Add a split node.
760   NodeDef split_def;
761   split_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/split")));
762   split_def.set_op("Split");
763   split_def.set_device(input_assigned_device);
764   AddNodeAttr("num_split", num_splits, &split_def);
765   AddNodeAttr("T", dtype, &split_def);
766   split_def.add_input(absl::StrCat(split_dim_node->name(), ":0"));
767   split_def.add_input(absl::StrCat(to_split_node->name(), ":", to_split_index));
768   Node* split_node = graph->AddNode(split_def, &s);
769   TF_RETURN_IF_ERROR(s);
770 
771   split_node->set_assigned_device_name(input_assigned_device);
772 
773   // Colocate the newly created split op with the source node of the input to
774   // the TPU computation.
775   split_node->AddAttr(kColocationAttrName,
776                       std::vector<string>{absl::StrCat(kColocationGroupPrefix,
777                                                        orig_src->name())});
778 
779   graph->AddEdge(split_dim_node, 0, split_node, 0);
780   graph->AddEdge(to_split_node, to_split_index, split_node, 1);
781 
782   // Add a control dependency from `control_predecessor` to newly created
783   // constant node. This ensures that newly added split/split dim
784   // nodes are placed inside correct while loop frames when TPUExecute
785   // node is inside a host training loop.
786   graph->AddControlEdge(control_predecessor, split_dim_node);
787   return split_node;
788 }
789 
790 int64 GetPadding(const int split_dim, const int num_splits,
791                  const PartialTensorShape& partial_tensor_shape) {
792   // If the split dimension is not defined, uneven sharding is not supported.
793   if (partial_tensor_shape.dim_size(split_dim) <= 0) {
794     return 0;
795   }
796   int64_t per_split_size = tensorflow::MathUtil::CeilOfRatio<int64>(
797       partial_tensor_shape.dim_size(split_dim), num_splits);
798   int64_t total_padding =
799       per_split_size * num_splits - partial_tensor_shape.dim_size(split_dim);
800   return total_padding;
801 }
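
// Worked example (hypothetical numbers): splitting a dimension of size 10 into
// num_splits = 4 gives per_split_size = ceil(10 / 4) = 3, so the returned
// padding is 3 * 4 - 10 = 2 and the input is padded to 12 before being split
// evenly. A dimension of unknown size returns 0, i.e. no uneven sharding.
//
//   PartialTensorShape shape({10, 8});
//   GetPadding(/*split_dim=*/0, /*num_splits=*/4, shape);  // == 2
//   GetPadding(/*split_dim=*/1, /*num_splits=*/2, shape);  // == 0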
802 
803 // Creates a set of split nodes that shard a tiled input node in the graph.
804 xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding(
805     const xla::OpSharding& sharding, int orig_arg_num, DataType dtype,
806     const PartialTensorShape& partial_tensor_shape, int replica_id,
807     int orig_src_output, Node* orig_src, Node* control_predecessor,
808     Graph* graph,
809     std::map<ShardedInputIndex, ShardedInputInfo>*
810         arg_index_to_sharded_input_map) {
811   ShardedInputIndex input_index{replica_id, orig_arg_num};
812   auto iter = arg_index_to_sharded_input_map->find(input_index);
813   if (iter != arg_index_to_sharded_input_map->end()) {
814     return iter->second;
815   }
816   // Maps each input dimension to the number of splits with which that
817   // dimension is sharded.
818   std::map<int, int> split_dimension_map;
819   TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
820       sharding, &split_dimension_map));
821   TF_RET_CHECK(!split_dimension_map.empty())
822       << "Unnecessary sharding attribute found.";
823 
824   // For v1 while loop, nodes inside the loop body must either
825   //  1) Have data edges from while loop input node.
826   //  or
827   //  2) Have direct control dependency from while loop input control
828   //     node.
829   //
830   // As such, if we are adding a Split node inside a while loop body,
831   // we must manually add a control dependency from a node inside
832   // the while loop (i.e. `control_predecessor`) to constant nodes
833   // without data in-edges, to make sure that the added split nodes
834   // have the correct frame name. Otherwise, the placer will complain
835   // when `BuildControlFlow()` is invoked.
836 
837   auto sharding_it = split_dimension_map.begin();
838   std::queue<Node*> split_nodes_for_dimension;
839   absl::flat_hash_map<Node*, int> node_to_split_dim;
840   int split_dimension = sharding_it->first;
841   int num_split = sharding_it->second;
842 
843   // Creates a tree of split nodes for sharding tiled inputs. Split nodes
844   // are created such that input data is sharded in row-major order.
845   // Split nodes at ith depth from the original input node represent nodes
846   // that split the input data at ith dimension.
847   TF_ASSIGN_OR_RETURN(
848       Node * root_split_node,
849       CreateSplitNode(
850           num_split, split_dimension, partial_tensor_shape.dims(),
851           GetPadding(split_dimension, num_split, partial_tensor_shape),
852           orig_src_output, dtype,
853           absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
854                        split_dimension),
855           control_predecessor, orig_src, graph));
856   sharding_it++;
857 
858   split_nodes_for_dimension.emplace(root_split_node);
859   node_to_split_dim[root_split_node] = split_dimension;
860 
861   while (sharding_it != split_dimension_map.end()) {
862     split_dimension = sharding_it->first;
863     num_split = sharding_it->second;
864     int num_split_nodes_in_dimension = split_nodes_for_dimension.size();
865     for (int i = 0; i < num_split_nodes_in_dimension; ++i) {
866       Node* input_split_node = split_nodes_for_dimension.front();
867       split_nodes_for_dimension.pop();
868       for (int src_output_index = 0;
869            src_output_index < input_split_node->num_outputs();
870            ++src_output_index) {
871         TF_ASSIGN_OR_RETURN(
872             Node * split_node,
873             CreateSplitNode(
874                 num_split, split_dimension, partial_tensor_shape.dims(),
875                 GetPadding(split_dimension, num_split, partial_tensor_shape),
876                 src_output_index, dtype,
877                 absl::StrCat("sharded_input/replica_", replica_id, "_dim_",
878                              split_dimension),
879                 control_predecessor, input_split_node, graph));
880         split_nodes_for_dimension.emplace(split_node);
881         node_to_split_dim[split_node] = split_dimension;
882       }
883     }
884     sharding_it++;
885   }
886 
887   // `split_nodes_for_dimension` now includes the final split nodes
888   // from which sharded data will be fed into TPUExecute nodes, sorted in
889   // row-major order.
890   std::vector<NodeOut> sharded_inputs_list(
891       sharding.tile_assignment_devices_size());
892   int64_t next_core_tile_index = 0;
893   while (!split_nodes_for_dimension.empty()) {
894     Node* split_node = split_nodes_for_dimension.front();
895     split_nodes_for_dimension.pop();
896     int num_splits;
897     TF_RETURN_IF_ERROR(
898         GetNodeAttr(split_node->def(), "num_split", &num_splits));
899     for (int out_index = 0; out_index < num_splits; ++out_index) {
900       int64_t repeat_count =
901           sharding.replicate_on_last_tile_dim()
902               ? *sharding.tile_assignment_dimensions().rbegin()
903               : 1;
904       for (int64_t i = 0; i < repeat_count; ++i) {
905         int64_t next_core =
906             sharding.tile_assignment_devices(next_core_tile_index++);
907         sharded_inputs_list[next_core] = NodeOut{split_node, out_index};
908       }
909     }
910   }
911 
912   ShardedInputInfo sharded_input_info{root_split_node,
913                                       std::move(sharded_inputs_list)};
914   (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
915   return sharded_input_info;
916 }
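
// Illustrative sketch (hypothetical 2x2 tiled sharding, not taken from this
// pass): for an input split 2 ways along dimension 0 and 2 ways along
// dimension 1, the function above builds a root Split over dimension 0 and,
// for each of its outputs, a second Split over dimension 1, yielding four
// shards in row-major order. With tile_assignment_devices = [0, 1, 2, 3],
// shard (0,0) feeds core 0, (0,1) core 1, (1,0) core 2 and (1,1) core 3, and
// ShardedInputInfo::sharded_inputs is indexed by logical core id.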
917 
918 // Creates an xla split node to shard an input, and adds that new node to a
919 // Graph.
920 StatusOr<Node*> CreateXlaSplitOp(absl::string_view node_name,
921                                  const bool is_resource, const NodeOut& input,
922                                  const PartialTensorShape& partial_tensor_shape,
923                                  const std::vector<Node*>& control_inputs,
924                                  const std::vector<Node*>& control_outputs,
925                                  const DataType dtype, const int num_shards,
926                                  const xla::OpSharding& sharding,
927                                  Graph* graph) {
928   const std::string& input_assigned_device = input.node->assigned_device_name();
929   NodeDef xla_split_def;
930   xla_split_def.set_name(graph->NewName(node_name));
931   xla_split_def.set_op(is_resource ? "ReadVariableXlaSplitND" : "XlaSplitND");
932   xla_split_def.set_device(input_assigned_device);
933   AddNodeAttr("T", dtype, &xla_split_def);
934   AddNodeAttr("N", num_shards, &xla_split_def);
935   const std::vector<int64> num_splits(
936       sharding.tile_assignment_dimensions().begin(),
937       sharding.replicate_on_last_tile_dim()
938           ? std::prev(sharding.tile_assignment_dimensions().end())
939           : sharding.tile_assignment_dimensions().end());
940   AddNodeAttr("num_splits", num_splits, &xla_split_def);
941   const int rank = sharding.replicate_on_last_tile_dim()
942                        ? sharding.tile_assignment_dimensions_size() - 1
943                        : sharding.tile_assignment_dimensions_size();
944   std::vector<int32> paddings;
945   paddings.reserve(rank);
946   for (int dim = 0; dim < rank; ++dim) {
947     paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
948                                   partial_tensor_shape));
949   }
950   AddNodeAttr("paddings", paddings, &xla_split_def);
951 
952   if (!is_resource) {
953     AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &xla_split_def);
954     AddNodeAttr(kColocationAttrName,
955                 std::vector<string>{
956                     absl::StrCat(kColocationGroupPrefix, input.node->name())},
957                 &xla_split_def);
958   }
959 
960   Status s;
961   Node* xla_split = graph->AddNode(xla_split_def, &s);
962   TF_RETURN_IF_ERROR(s);
963   if (is_resource) {
964     xla_split->set_requested_device(input.node->requested_device());
965   }
966   xla_split->set_assigned_device_name(input_assigned_device);
967   graph->AddEdge(input.node, input.index, xla_split, 0);
968   for (Node* control_input : control_inputs) {
969     graph->AddControlEdge(control_input, xla_split);
970   }
971   for (Node* control_output : control_outputs) {
972     graph->AddControlEdge(xla_split, control_output);
973   }
974   return xla_split;
975 }
976 
977 // Creates a sharded tensor list for all input shards of an input with sharding.
978 xla::StatusOr<std::vector<NodeOut>> ShardInputWithXlaSplitOp(
979     absl::string_view node_name, const bool is_resource, const NodeOut& input,
980     const PartialTensorShape& partial_tensor_shape,
981     const std::vector<Node*>& control_inputs,
982     const std::vector<Node*>& control_outputs, const DataType dtype,
983     const xla::OpSharding& sharding, Graph* graph) {
984   const int repeat = sharding.replicate_on_last_tile_dim()
985                          ? *sharding.tile_assignment_dimensions().rbegin()
986                          : 1;
987   const int num_shards = sharding.tile_assignment_devices_size() / repeat;
988 
989   TF_ASSIGN_OR_RETURN(
990       Node * xla_split,
991       CreateXlaSplitOp(node_name, is_resource, input, partial_tensor_shape,
992                        control_inputs, control_outputs, dtype, num_shards,
993                        sharding, graph));
994 
995   std::vector<NodeOut> sharded_inputs_list(
996       sharding.tile_assignment_devices_size());
997 
998   for (int i = 0; i < num_shards; ++i) {
999     for (int j = 0; j < repeat; ++j) {
1000       const int index = i * repeat + j;
1001       const int core = sharding.tile_assignment_devices(index);
1002       sharded_inputs_list[core] = {xla_split, i};
1003     }
1004   }
1005 
1006   return sharded_inputs_list;
1007 }
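
// Illustrative sketch (hypothetical sharding, not taken from this pass): with
// tile_assignment_dimensions = [2, 2] and replicate_on_last_tile_dim() set,
// repeat = 2 and num_shards = 4 / 2 = 2, so the XlaSplitND op produces two
// shards and each shard feeds two cores. With
// tile_assignment_devices = [0, 1, 2, 3] the returned list is
//
//   sharded_inputs_list[0] == {xla_split, 0}
//   sharded_inputs_list[1] == {xla_split, 0}
//   sharded_inputs_list[2] == {xla_split, 1}
//   sharded_inputs_list[3] == {xla_split, 1}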
1008 
1009 // Creates an XlaSplitND op to shard a per-replica arg.
1010 xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
1011     const xla::OpSharding& sharding, const int replica_id,
1012     const int orig_arg_num, DataType dtype,
1013     const PartialTensorShape& partial_tensor_shape, Node* orig_src,
1014     const int orig_src_output, Graph* graph,
1015     std::map<ShardedInputIndex, ShardedInputInfo>*
1016         arg_index_to_sharded_input_map) {
1017   ShardedInputIndex input_index{replica_id, orig_arg_num};
1018   auto iter = arg_index_to_sharded_input_map->find(input_index);
1019   if (iter != arg_index_to_sharded_input_map->end()) {
1020     return iter->second;
1021   }
1022 
1023   TF_ASSIGN_OR_RETURN(
1024       std::vector<NodeOut> sharded_inputs_list,
1025       ShardInputWithXlaSplitOp(
1026           absl::StrCat(orig_src->name(), "/replica_", replica_id, "_split"),
1027           /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
1028           partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
1029           dtype, sharding, graph));
1030 
1031   ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
1032   (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info;
1033   return sharded_input_info;
1034 }
1035 
1036 // Creates an XlaSplitND op to shard a distributed arg.
1037 xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForDistributedArg(
1038     const xla::OpSharding& sharding, const int num_replicas,
1039     const int replica_id, const int orig_arg_num, DataType dtype,
1040     const PartialTensorShape& partial_tensor_shape, Node* orig_src,
1041     const int orig_src_output, Graph* graph,
1042     std::map<ShardedInputIndex, ShardedInputInfo>*
1043         arg_index_to_sharded_input_map) {
1044   ShardedInputIndex input_index{replica_id, orig_arg_num};
1045   auto iter = arg_index_to_sharded_input_map->find(input_index);
1046   if (iter != arg_index_to_sharded_input_map->end()) {
1047     return iter->second;
1048   }
1049 
1050   TF_ASSIGN_OR_RETURN(
1051       std::vector<NodeOut> sharded_inputs_list,
1052       ShardInputWithXlaSplitOp(
1053           absl::StrCat(orig_src->name(), "/distributed_split"),
1054           /*is_resource=*/false, /*input=*/{orig_src, orig_src_output},
1055           partial_tensor_shape, /*control_inputs=*/{}, /*control_outputs=*/{},
1056           dtype, sharding, graph));
1057 
1058   ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
1059   for (int replica = 0; replica < num_replicas; ++replica) {
1060     (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
1061         sharded_input_info;
1062   }
1063   return sharded_input_info;
1064 }
1065 
1066 // Creates a ReadVariableXlaSplitND op to shard a variable arg.
1067 xla::StatusOr<ShardedInputInfo> CreateOrGetXlaSplitNodeForVariableArg(
1068     const xla::OpSharding& sharding, const int num_replicas,
1069     const int replica_id, const int orig_arg_num, DataType dtype,
1070     const PartialTensorShape& partial_tensor_shape, Node* orig_src,
1071     const int orig_src_output, Graph* graph,
1072     std::vector<Node*>* to_be_removed_nodes,
1073     std::map<ShardedInputIndex, ShardedInputInfo>*
1074         arg_index_to_sharded_input_map) {
1075   ShardedInputIndex input_index{replica_id, orig_arg_num};
1076   auto iter = arg_index_to_sharded_input_map->find(input_index);
1077   if (iter != arg_index_to_sharded_input_map->end()) {
1078     return iter->second;
1079   }
1080 
1081   DCHECK_EQ(orig_src->type_string(), "ReadVariableOp");
1082   std::vector<Node*> control_outputs;
1083   std::vector<const Edge*> edges_to_remove;
1084   for (const Edge* edge : orig_src->out_edges()) {
1085     if (edge->IsControlEdge()) {
1086       control_outputs.push_back(edge->dst());
1087     }
1088     edges_to_remove.push_back(edge);
1089   }
1090 
1091   to_be_removed_nodes->push_back(orig_src);
1092 
1093   const Edge* resource = nullptr;
1094   TF_RETURN_IF_ERROR(orig_src->input_edge(0, &resource));
1095 
1096   std::vector<Node*> control_inputs;
1097   for (const Edge* edge : orig_src->in_edges()) {
1098     if (edge->IsControlEdge()) {
1099       control_inputs.push_back(edge->src());
1100     }
1101   }
1102 
1103   TF_ASSIGN_OR_RETURN(
1104       std::vector<NodeOut> sharded_inputs_list,
1105       ShardInputWithXlaSplitOp(
1106           absl::StrCat(resource->src()->name(), "/read_variable_split"),
1107           /*is_resource=*/true,
1108           /*input=*/{resource->src(), resource->src_output()},
1109           partial_tensor_shape, control_inputs, control_outputs, dtype,
1110           sharding, graph));
1111 
1112   for (const Edge* edge : edges_to_remove) {
1113     graph->RemoveControlEdge(edge);
1114   }
1115 
1116   DCHECK(orig_src->out_edges().empty());
1117 
1118   ShardedInputInfo sharded_input_info{nullptr, std::move(sharded_inputs_list)};
1119   for (int replica = 0; replica < num_replicas; ++replica) {
1120     ShardedInputIndex idx{replica, orig_arg_num};
1121     // Refrain from overwriting, if dummy inputs were already placed instead.
1122     arg_index_to_sharded_input_map->insert({idx, sharded_input_info});
1123   }
1124   return sharded_input_info;
1125 }
1126 
1127 // Creates a concat node to be used for aggregating sharded retvals across
1128 // logical cores.
1129 xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype,
1130                                       absl::string_view name_prefix,
1131                                       const std::vector<NodeOut>& inputs,
1132                                       Graph* graph, absl::string_view device) {
1133   // Add a Concat dim node.
1134   NodeDef concat_dim_def;
1135   concat_dim_def.set_name(
1136       graph->NewName(absl::StrCat(name_prefix, "/concat_dim")));
1137   concat_dim_def.set_op("Const");
1138   AddNodeAttr("dtype", DT_INT32, &concat_dim_def);
1139   concat_dim_def.set_device(std::string(device));
1140   TensorProto tensor_proto;
1141   tensor_proto.set_dtype(DT_INT32);
1142   tensor_proto.add_int_val(dim);
1143   TensorShape shape({});
1144   shape.AsProto(tensor_proto.mutable_tensor_shape());
1145   AddNodeAttr("value", tensor_proto, &concat_dim_def);
1146   Status s;
1147   Node* concat_dim_node = graph->AddNode(concat_dim_def, &s);
1148   TF_RETURN_IF_ERROR(s);
1149 
1150   // Add a Concat node.
1151   NodeDef concat_def;
1152   concat_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/concat")));
1153   concat_def.set_op("Concat");
1154   AddNodeAttr("N", num_splits, &concat_def);
1155   AddNodeAttr("T", dtype, &concat_def);
1156   concat_def.add_input(absl::StrCat(concat_dim_node->name(), ":0"));
1157   concat_def.set_device(std::string(device));
1158   for (const auto& i : inputs) {
1159     concat_def.add_input(absl::StrCat(i.node->name(), ":", i.index));
1160   }
1161   Node* concat_node = graph->AddNode(concat_def, &s);
1162   TF_RETURN_IF_ERROR(s);
1163 
1164   graph->AddEdge(concat_dim_node, 0, concat_node, 0);
1165 
1166   // 0th input to concat node is a concat dim node. So we start from 1st input
1167   // and add all input edges.
1168   int dst_input = 1;
1169   for (const auto& i : inputs) {
1170     graph->AddEdge(i.node, i.index, concat_node, dst_input);
1171     ++dst_input;
1172   }
1173   return concat_node;
1174 }
1175 
1176 // Adds a Slice node after the Concat node to the graph, for unevenly sharded tiled inputs.
1177 xla::StatusOr<Node*> CreateSliceNode(DataType dtype,
1178                                      const PartialTensorShape& shape,
1179                                      Node* concat_node,
1180                                      const int concat_out_index, Graph* graph,
1181                                      absl::string_view device) {
1182   Status s;
1183   // Add begin node for concat.
1184   NodeDef begin_def;
1185   begin_def.set_name(
1186       graph->NewName(absl::StrCat(concat_node->name(), "/slice_begin")));
1187   begin_def.set_op("Const");
1188   AddNodeAttr("dtype", DT_INT32, &begin_def);
1189   begin_def.set_device(std::string(device));
1190   TensorProto begin_tensor_proto;
1191   begin_tensor_proto.set_dtype(DT_INT32);
1192   for (int i = 0; i < shape.dims(); ++i) {
1193     begin_tensor_proto.add_int_val(0);
1194   }
1195   TensorShape begin_shape({shape.dims()});
1196   begin_shape.AsProto(begin_tensor_proto.mutable_tensor_shape());
1197   AddNodeAttr("value", begin_tensor_proto, &begin_def);
1198   Node* begin_node = graph->AddNode(begin_def, &s);
1199   TF_RETURN_IF_ERROR(s);
1200 
1201   // Add size node.
1202   NodeDef size_def;
1203   size_def.set_name(
1204       graph->NewName(absl::StrCat(concat_node->name(), "/slice_size")));
1205   size_def.set_op("Const");
1206   AddNodeAttr("dtype", DT_INT32, &size_def);
1207   size_def.set_device(std::string(device));
1208   TensorProto sizes_tensor_proto;
1209   sizes_tensor_proto.set_dtype(DT_INT32);
1210   for (int i = 0; i < shape.dims(); ++i) {
1211     sizes_tensor_proto.add_int_val(shape.dim_size(i));
1212   }
1213   TensorShape sizes_shape({shape.dims()});
1214   sizes_shape.AsProto(sizes_tensor_proto.mutable_tensor_shape());
1215   AddNodeAttr("value", sizes_tensor_proto, &size_def);
1216   Node* size_node = graph->AddNode(size_def, &s);
1217   TF_RETURN_IF_ERROR(s);
1218 
1219   // Add Slice node.
1220   NodeDef slice_def;
1221   slice_def.set_name(
1222       graph->NewName(absl::StrCat(concat_node->name(), "/slice")));
1223   slice_def.set_op("Slice");
1224   slice_def.set_device(std::string(device));
1225   AddNodeAttr("T", dtype, &slice_def);
1226   AddNodeAttr("Index", DT_INT32, &slice_def);
1227   slice_def.add_input(absl::StrCat(concat_node->name(), ":", concat_out_index));
1228   slice_def.add_input(absl::StrCat(begin_node->name(), ":0"));
1229   slice_def.add_input(absl::StrCat(size_node->name(), ":0"));
1230   Node* slice_node = graph->AddNode(slice_def, &s);
1231   TF_RETURN_IF_ERROR(s);
1232   // Add edges for slice node.
1233   graph->AddEdge(concat_node, concat_out_index, slice_node, 0);
1234   graph->AddEdge(begin_node, 0, slice_node, 1);
1235   graph->AddEdge(size_node, 0, slice_node, 2);
1236   return slice_node;
1237 }
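//
// Illustrative example (assumed shapes, not taken from the original code): if
// a retval with inferred shape [5, 4] is tiled two ways along dimension 0,
// each shard is padded to [3, 4], so the concatenated result has shape [6, 4].
// The Slice node built above then recovers the original value with
// begin = [0, 0] and size = [5, 4].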
1238 
1239 // Creates a set of Concat nodes that aggregate sharded outputs from TPUExecute
1240 // nodes into a single output. Sharded outputs are concatenated in row-major
1241 // order; that is, the tiled output along the 0th dimension is concatenated last.
1242 xla::StatusOr<Node*> CreateConcatNodesForRetval(
1243     const xla::OpSharding& sharding, DataType dtype,
1244     const PartialTensorShape& inferred_shape, int replica_id,
1245     const std::vector<NodeOut>& orig_inputs, Graph* graph,
1246     absl::string_view device) {
1247   std::map<int, int> split_dimension_map;
1248   TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding(
1249       sharding, &split_dimension_map));
1250   std::vector<NodeOut> inputs_to_sharded_retval = orig_inputs;
1251   bool has_paddings = false;
1252 
1253   for (auto it = split_dimension_map.rbegin(); it != split_dimension_map.rend();
1254        it++) {
1255     auto dim = it->first;
1256     auto num_splits = it->second;
1257 
1258     int num_concat_nodes = inputs_to_sharded_retval.size() / num_splits;
1259     int input_index_to_concat_node = 0;
1260 
1261     std::vector<NodeOut> new_concat_nodes;
1262     for (int i = 0; i < num_concat_nodes; ++i) {
1263       auto concat_input_it =
1264           inputs_to_sharded_retval.begin() + input_index_to_concat_node;
1265       std::vector<NodeOut> inputs(concat_input_it,
1266                                   concat_input_it + num_splits);
1267       input_index_to_concat_node += num_splits;
1268 
1269       TF_ASSIGN_OR_RETURN(
1270           Node * concat_node,
1271           CreateConcatNode(
1272               dim, num_splits, dtype,
1273               absl::StrCat("sharded_output/replica_", replica_id, "_dim_", dim),
1274               inputs, graph, device));
1275       int64_t paddings = GetPadding(dim, num_splits, inferred_shape);
1276       has_paddings |= paddings > 0;
1277       new_concat_nodes.emplace_back(NodeOut{concat_node, 0});
1278     }
1279     inputs_to_sharded_retval = new_concat_nodes;
1280   }
1281 
1282   TF_RET_CHECK(inputs_to_sharded_retval.size() == 1);
1283   if (has_paddings) {
1284     TF_ASSIGN_OR_RETURN(Node * slice_node,
1285                         CreateSliceNode(dtype, inferred_shape,
1286                                         inputs_to_sharded_retval.at(0).node,
1287                                         /*concat_out_index*/ 0, graph, device));
1288     return slice_node;
1289   }
1290   return inputs_to_sharded_retval.at(0).node;
1291 }
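//
// Worked example (hypothetical 2x3 tiling, for exposition only): with
// split_dimension_map = {0: 2, 1: 3} and shards s0..s5 arriving in row-major
// order, the reverse iteration above first concatenates along dimension 1,
// producing [s0|s1|s2] and [s3|s4|s5], and then concatenates those two results
// along dimension 0; this is why the 0th dimension is concatenated last.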
1292 
1293 xla::StatusOr<Node*> CreateXlaConcatNode(
1294     const xla::OpSharding& sharding, const int replica_id, DataType dtype,
1295     const PartialTensorShape& partial_tensor_shape,
1296     const std::vector<NodeOut>& orig_inputs, absl::string_view device,
1297     Graph* graph) {
1298   NodeDef xla_concat_def;
1299   xla_concat_def.set_name(graph->NewName(
1300       absl::StrCat("sharded_output/replica_", replica_id, "_concat")));
1301   xla_concat_def.set_op("XlaConcatND");
1302   xla_concat_def.set_device(std::string(device));
1303   AddNodeAttr("T", dtype, &xla_concat_def);
1304   AddNodeAttr("N", static_cast<int64>(orig_inputs.size()), &xla_concat_def);
1305   const std::vector<int64> num_concats(
1306       sharding.tile_assignment_dimensions().begin(),
1307       sharding.replicate_on_last_tile_dim()
1308           ? std::prev(sharding.tile_assignment_dimensions().end())
1309           : sharding.tile_assignment_dimensions().end());
1310   AddNodeAttr("num_concats", num_concats, &xla_concat_def);
1311   const int rank = sharding.replicate_on_last_tile_dim()
1312                        ? sharding.tile_assignment_dimensions_size() - 1
1313                        : sharding.tile_assignment_dimensions_size();
1314   std::vector<int32> paddings;
1315   paddings.reserve(rank);
1316   for (int dim = 0; dim < rank; ++dim) {
1317     paddings.push_back(GetPadding(dim, sharding.tile_assignment_dimensions(dim),
1318                                   partial_tensor_shape));
1319   }
1320   AddNodeAttr("paddings", paddings, &xla_concat_def);
1321 
1322   Status s;
1323   Node* xla_concat = graph->AddNode(xla_concat_def, &s);
1324   TF_RETURN_IF_ERROR(s);
1325   for (int i = 0, e = orig_inputs.size(); i < e; ++i) {
1326     const NodeOut& input = orig_inputs[i];
1327     graph->AddEdge(input.node, input.index, xla_concat, i);
1328   }
1329   return xla_concat;
1330 }
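//
// Illustrative example (assumed sharding, not from the original code): for an
// OpSharding with tile_assignment_dimensions = [2, 2, 4] and
// replicate_on_last_tile_dim = true, the last dimension only replicates, so
// num_concats = [2, 2], rank = 2, and paddings[d] is whatever GetPadding()
// reports for splitting dimension d of the inferred shape into 2 pieces.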
1331 
1332 // Sets the padding ops to the same devices as the original inputs. If the
1333 // original inputs are on TPUs, the padding ops will also be placed on TPUs and
1334 // XLA on-demand mode will be triggered, so we don't need to copy the data back
1335 // to the host to do the padding.
1336 Status SetPaddingNodesDevices(Graph* graph) {
1337   for (Node* n : graph->op_nodes()) {
1338     bool tpu_padding_attr;
1339     if (n->type_string() == "Pad" &&
1340         GetNodeAttr(n->attrs(), kPostDeviceRewriteAttr, &tpu_padding_attr)
1341             .ok()) {
1342       Node* unpadded_input;
1343       TF_RETURN_IF_ERROR(n->input_node(0, &unpadded_input));
1344 
1345       const string& requested_device = unpadded_input->requested_device();
1346       const string& assigned_device = unpadded_input->assigned_device_name();
1347       if (!requested_device.empty() || !assigned_device.empty()) {
1348         // The output nodes of the original unpadded inputs include the padded
1349         // inputs and the real shapes of the inputs, so we assign those to the
1350         // same device as the original inputs.
1351         for (Node* out : unpadded_input->out_nodes()) {
1352           if (GetNodeAttr(out->attrs(), kPostDeviceRewriteAttr,
1353                           &tpu_padding_attr)
1354                   .ok()) {
1355             out->set_requested_device(requested_device);
1356             out->set_assigned_device_name(assigned_device);
1357           }
1358         }
1359         // There might be a tf.shape node added before TPUCompileOp, so we need
1360         // to set its device as well.
1361         for (Node* out : n->out_nodes()) {
1362           if (out->type_string() == "Shape") {
1363             out->set_requested_device(requested_device);
1364             out->set_assigned_device_name(assigned_device);
1365           }
1366         }
1367       }
1368     }
1369   }
1370   return Status::OK();
1371 }
1372 
1373 const string& AssignedOrRequestedDevice(const Node* node) {
1374   if (!node->assigned_device_name().empty()) {
1375     return node->assigned_device_name();
1376   }
1377   return node->requested_device();
1378 }
1379 
1380 bool IsTpuDevice(const string& device_string) {
1381   DeviceNameUtils::ParsedName device;
1382   return DeviceNameUtils::ParseFullName(device_string, &device) &&
1383          device.type == DEVICE_TPU_NODE;
1384 }
1385 
1386 // Returns the set of ops that can be placed on TPU. There is no strict rule of
1387 // thumb to decide which ops should be in the list, but empirically they are
1388 // mostly dummy ops like Identity-like ops or control-flow-related ops. However,
1389 // people can also add other ops like Pad to allow data to stay on TPU.
1390 const absl::flat_hash_set<std::string>& PlaceOnTPUOpList() {
1391   static const auto place_on_tpu_ops = new absl::flat_hash_set<std::string>(
1392       {"Identity", "IdentityN", "Enter", "Exit", "Switch", "Merge",
1393        "NextIteration", "Shape", "_Retval"});
1394   return *place_on_tpu_ops;
1395 }
1396 
1397 // If an op satisfies the following conditions, it will be placed on the same
1398 // TPU device as its inputs:
1399 //   (1) The op can be placed on TPU (it is in PlaceOnTPUOpList).
1400 //   (2) The op itself has no requested or assigned devices.
1401 //   (3) All the data inputs of this op are placed on the same TPU device.
1402 //       There are exceptions; for example, the LoopCond (predicate) input of a
1403 //       Switch node may stay on CPU as it is just a boolean.
1404 //
1405 // Returns true if the node device has been changed, otherwise returns false.
1406 bool PlaceOpsOnTPU(Node* node) {
1407   if (!AssignedOrRequestedDevice(node).empty() ||
1408       !PlaceOnTPUOpList().contains(node->type_string())) {
1409     return false;
1410   }
1411   string src_tpu_device = "";
1412   Node* src_node;
1413   for (const Edge* e : node->in_edges()) {
1414     if (e->IsControlEdge()) {
1415       continue;
1416     }
1417     Node* src = e->src();
1418     const string& src_device = AssignedOrRequestedDevice(src);
1419 
1420     // Make an exception: we don't force some inputs to be placed on TPUs.
1421     if (node->IsSwitch() && src->IsLoopCond()) {
1422       continue;
1423     }
1424 
1425     if (!IsTpuDevice(src_device) ||
1426         (!src_tpu_device.empty() && src_device != src_tpu_device)) {
1427       return false;
1428     }
1429     if (src_tpu_device.empty()) {
1430       src_tpu_device = src_device;
1431       src_node = src;
1432     }
1433   }
1434   node->set_assigned_device_name(src_node->assigned_device_name());
1435   node->set_requested_device(src_node->requested_device());
1436   return true;
1437 }
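//
// Illustrative example (hypothetical devices): an Identity node with no
// requested or assigned device whose only data input lives on "/device:TPU:0"
// is moved onto "/device:TPU:0" and PlaceOpsOnTPU() returns true. If its data
// inputs were split across TPU:0 and TPU:1, or any of them lived on CPU, the
// node would be left untouched and the function would return false.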
1438 
1439 xla::OpMetadata CreateOpMetadataFromNode(const Node& node) {
1440   xla::OpMetadata metadata;
1441   metadata.set_op_type(node.type_string());
1442   metadata.set_op_name(node.name());
1443   return metadata;
1444 }
1445 
1446 // Helper struct holding node (nullable) and associated sharding.
1447 struct NodeAndSharding {
1448   explicit NodeAndSharding(const Node* node, const xla::OpSharding& sharding)
1449       : node(node), sharding(sharding) {}
1450 
1451   const Node* node;
1452   xla::OpSharding sharding;
1453 };
1454 
1455 // Validate sharding configuration derived from XlaSharding attribute.
1456 // Infer the core id from the OpSharding, if necessary.
1457 Status ParseAndValidateSharding(const NodeAndSharding& node_and_sharding,
1458                                 const int num_cores_per_replica,
1459                                 int64* inferred_core_id,
1460                                 absl::optional<NodeAndSharding>* result) {
1461   if (node_and_sharding.sharding.type() == xla::OpSharding::MAXIMAL) {
1462     int64_t core_annotation =
1463         node_and_sharding.sharding.tile_assignment_devices(0);
1464     TF_RETURN_IF_ERROR(
1465         ValidateCoreNumber(core_annotation, num_cores_per_replica));
1466     if (*inferred_core_id == -1 || *inferred_core_id > core_annotation) {
1467       *inferred_core_id = core_annotation;
1468       result->emplace(node_and_sharding);
1469     }
1470   } else {
1471     if (node_and_sharding.sharding.type() == xla::OpSharding::OTHER) {
1472       for (int64_t core :
1473            node_and_sharding.sharding.tile_assignment_devices()) {
1474         TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
1475       }
1476     }
1477 
1478     if (!result->has_value()) {
1479       *result = node_and_sharding;
1480     } else {
1481       std::string result_value_serialized;
1482       xla::OpSharding result_value = result->value().sharding;
1483       result_value.clear_metadata();
1484       SerializeToStringDeterministic(result_value, &result_value_serialized);
1485 
1486       std::string sharding_serialized;
1487       xla::OpSharding sharding = node_and_sharding.sharding;
1488       sharding.clear_metadata();
1489       SerializeToStringDeterministic(sharding, &sharding_serialized);
1490 
1491       // TODO(lyandy): Choose the more granular sharding instead of always
1492       // assigning to core 0 (maximal).
1493       if (result_value_serialized != sharding_serialized) {
1494         // We see different shardings, assign to core 0.
1495         auto core_zero_sharding = xla::sharding_builder::AssignDevice(0);
1496         DCHECK_NE(node_and_sharding.node, nullptr);
1497         *core_zero_sharding.add_metadata() =
1498             CreateOpMetadataFromNode(*node_and_sharding.node);
1499         result->emplace(
1500             NodeAndSharding(node_and_sharding.node, core_zero_sharding));
1501       }
1502     }
1503   }
1504   return Status::OK();
1505 }
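//
// Illustrative example (assumed shardings): if one use of an argument is
// annotated MAXIMAL on core 2 and another MAXIMAL on core 1, the lower core
// wins and *inferred_core_id becomes 1. If two non-MAXIMAL shardings disagree
// (ignoring metadata), the argument falls back to maximal sharding on core 0,
// per the TODO above.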
1506 
1507 // As an XlaSharding node may be followed by a Cast op or an Identity op,
1508 // recursively walk the graph and aggregate nodes connected to
1509 // |input_node| or to a Cast/Identity op following |input_node|.
1510 void FindNodesMaybeContainingShardingInfo(const Node& input_node,
1511                                           std::vector<const Node*>* nodes) {
1512   if (input_node.IsIdentity() || input_node.type_string() == "Cast") {
1513     for (const Node* connected_node : input_node.out_nodes())
1514       FindNodesMaybeContainingShardingInfo(*connected_node, nodes);
1515   }
1516   nodes->emplace_back(&input_node);
1517 }
1518 
1519 // Parse sharding configuration from |node| or its adjacent nodes.
1520 // XlaSharding configuration may be derived from
1521 //   a) Connected Identity op node.
1522 //   b) Connected Cast op node.
1523 xla::StatusOr<absl::optional<NodeAndSharding>>
1524 ParseInputShardingFromAdjacentNode(const int num_cores_per_replica,
1525                                    const Node& node) {
1526   // If |node| has `device` attribute or is a XlaSharding op,
1527   // return the parsed OpSharding.
1528   TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
1529                       ParseShardingFromDevice(node, num_cores_per_replica,
1530                                               /*add_metadata=*/true));
1531   if (sharding.has_value()) {
1532     return absl::optional<NodeAndSharding>(NodeAndSharding(&node, *sharding));
1533   }
1534 
1535   // XlaShardingOp may be followed by an identity or followed by identity
1536   // and a Cast op.
1537   std::vector<const Node*> potential_nodes_with_input_sharding;
1538   FindNodesMaybeContainingShardingInfo(node,
1539                                        &potential_nodes_with_input_sharding);
1540   for (const Node* maybe_node_with_sharding_info :
1541        potential_nodes_with_input_sharding) {
1542     if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue;
1543 
1544     TF_ASSIGN_OR_RETURN(
1545         absl::optional<xla::OpSharding> sharding_config,
1546         ParseShardingFromDevice(*maybe_node_with_sharding_info,
1547                                 num_cores_per_replica, /*add_metadata=*/true));
1548     if (sharding_config.has_value()) {
1549       return absl::optional<NodeAndSharding>(
1550           NodeAndSharding(maybe_node_with_sharding_info, *sharding_config));
1551     }
1552   }
1553   return absl::optional<NodeAndSharding>();
1554 }
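//
// Illustrative pattern (hypothetical graph): for a neighbor that is an
// Identity (possibly followed by a Cast) whose outputs include an XlaSharding
// op, the walk above passes through the Identity/Cast ops, finds the
// XlaSharding node among their outputs, and returns its parsed OpSharding;
// otherwise an empty optional is returned.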
1555 
1556 // Walk the graph from an argument node to find OpSharding configuration
1557 // from its neighbor nodes. Sharding configuration may be inferred from
1558 //  1) Parsing XlaSharding attribute from neighboring node.
1559 //  2) If argument node is a resource, then by parsing adjacent nodes
1560 //     of the connected ReadVariable op.
1561 Status ParseAndValidateShardingFromNeighbors(
1562     const int num_cores_per_replica, const std::string& arg_node_name,
1563     const Node& neighbor_node, int64* inferred_core_id, bool* is_fast_mem,
1564     absl::optional<NodeAndSharding>* result) {
1565   if (neighbor_node.attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1566     *is_fast_mem = true;
1567     VLOG(2) << "place " << neighbor_node.name() << " on fast memory because "
1568             << arg_node_name << " has " << TPU_FAST_MEM_ATTR << " attribute";
1569   }
1570 
1571   // XlaSharding information may be encoded on node directly connected to the
1572   // argument node.
1573   TF_ASSIGN_OR_RETURN(
1574       absl::optional<NodeAndSharding> node_and_sharding,
1575       ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node));
1576   if (node_and_sharding.has_value()) {
1577     TF_RETURN_IF_ERROR(ParseAndValidateSharding(
1578         *node_and_sharding, num_cores_per_replica, inferred_core_id, result));
1579     return Status::OK();
1580   }
1581 
1582   // When we use a variable in a TPU computation, the ReadVariableOp is commonly
1583   // followed by an XlaSharding op. As such, parse the users of the
1584   // ReadVariableOp for potential sharding configuration.
1585   if (neighbor_node.type_string() == "ReadVariableOp") {
1586     for (const Edge* e : neighbor_node.out_edges()) {
1587       if (e->IsControlEdge()) continue;
1588 
1589       if (e->dst()->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) {
1590         *is_fast_mem = true;
1591         VLOG(2) << "place " << arg_node_name << " on fast memory because "
1592                 << e->dst()->name() << " has " << TPU_FAST_MEM_ATTR << " attribute";
1593       }
1594 
1595       TF_ASSIGN_OR_RETURN(
1596           absl::optional<NodeAndSharding> node_and_sharding,
1597           ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst()));
1598       if (node_and_sharding.has_value()) {
1599         TF_RETURN_IF_ERROR(ParseAndValidateSharding(*node_and_sharding,
1600                                                     num_cores_per_replica,
1601                                                     inferred_core_id, result));
1602         return Status::OK();
1603       }
1604     }
1605   }
1606   return Status::OK();
1607 }
1608 
1609 }  // namespace
1610 
1611 // Inputs:
1612 //   replication_spec_string: the device to which the TPUReplicate node was
1613 //     assigned.
1614 //   device_set: the set of TF devices.
1615 // Outputs:
1616 //   tpu_compilation_device: the name of the TPU compilation device.
1617 //   num_tpus_per_task: the number of TPUs in each task. Verifies that all tasks
1618 //     have the same number of TPU devices.
1619 //   tpu_devices: the TPU devices, indexed by [task][device].
1620 static Status GetTPUDeviceNames(
1621     const string& replication_spec_string, const DeviceSet& device_set,
1622     string* tpu_compilation_device, int* num_tpus_per_task,
1623     std::vector<std::vector<Device*>>* tpu_devices) {
1624   // TODO(b/110910013) GetSystemDevice parses the spec and returns the name of
1625   // the tpu_system device, which we replace by the cpu device. We do this
1626   // replacement because we want to place the TPUCompileOp (and the compile
1627   // assert op) explicitly on cpu devices on the same job as the tpu_system
1628   // device.
1629   DeviceNameUtils::ParsedName replication_spec;
1630   Device* replication_device;
1631   TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice(
1632       replication_spec_string, device_set, &replication_spec,
1633       &replication_device));
1634   *tpu_compilation_device =
1635       str_util::StringReplace(replication_device->name(), DEVICE_TPU_SYSTEM,
1636                               DEVICE_CPU, /*replace_all=*/true);
1637 
1638   // Finds the set of TPU devices attached to the tasks in the job.
1639   TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetTPUDevices(
1640       replication_spec, device_set, num_tpus_per_task, tpu_devices));
1641 
1642   return Status::OK();
1643 }
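//
// Illustrative example (hypothetical device names): if the TPUReplicate node
// was assigned to "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", the
// compilation device becomes "/job:worker/replica:0/task:0/device:CPU:0",
// i.e. the CPU on the same task as the TPU_SYSTEM device.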
1644 
1645 // Parses the topology attribute of TPUReplicate, and populates *topology with
1646 // a physical mesh coordinate to (task, device) mapping.
1647 static Status ParseTopologyAttr(const string& topology_attr,
1648                                 const tpu::TpuTopologyExternal& tpu_topology,
1649                                 int num_tasks, int num_tpus_per_task,
1650                                 xla::Array4D<std::pair<int, int>>* topology) {
1651   static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1652   tpu::TopologyProto proto;
1653   proto.ParseFromString(topology_attr);
1654   if (proto.mesh_shape_size() != kTPUTopologyRank) {
1655     return errors::InvalidArgument("TPU topology must be rank ",
1656                                    kTPUTopologyRank);
1657   }
1658   if (proto.num_tasks() != num_tasks) {
1659     return errors::InvalidArgument("Mismatched number of TPU tasks");
1660   }
1661   if (proto.num_tpu_devices_per_task() != num_tpus_per_task) {
1662     return errors::InvalidArgument("Mismatched number of TPUs per task (",
1663                                    proto.num_tpu_devices_per_task(),
1664                                    " != ", num_tpus_per_task, ").");
1665   }
1666   if (proto.device_coordinates_size() !=
1667       num_tasks * num_tpus_per_task * kTPUTopologyRank) {
1668     return errors::InvalidArgument(
1669         "device coordinates should be ", num_tasks, "x", num_tpus_per_task, "x",
1670         kTPUTopologyRank, "; got ", proto.device_coordinates_size());
1671   }
1672 
1673   int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1674   *topology = xla::Array4D<std::pair<int, int>>(
1675       tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1676       tpu_topology.chip_bounds().z, devices_per_chip, {-1, -1});
1677   int pos = 0;
1678   for (int task = 0; task < num_tasks; ++task) {
1679     for (int device = 0; device < num_tpus_per_task; ++device) {
1680       int32_t x = proto.device_coordinates(pos++);
1681       int32_t y = proto.device_coordinates(pos++);
1682       int32_t z = proto.device_coordinates(pos++);
1683       int32_t core = proto.device_coordinates(pos++);
1684 
1685       if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1686           core >= devices_per_chip) {
1687         return errors::InvalidArgument(
1688             "Mesh coordinates (", x, ",", y, ",", z, ",", core,
1689             ") are not valid for the current TPU topology");
1690       }
1691       if ((*topology)(x, y, z, core).first != -1) {
1692         return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1693                                        ",", z, ",", core, ") in TPU topology");
1694       }
1695       (*topology)(x, y, z, core) = {task, device};
1696     }
1697   }
1698   return Status::OK();
1699 }
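//
// Illustrative example (assumed values): for num_tasks = 2 and
// num_tpus_per_task = 2, device_coordinates must hold 2 * 2 * 4 = 16 ints,
// one (x, y, z, core) quadruple per device, listed task by task and device by
// device. The loop above records {task, device} at each mesh coordinate and
// rejects coordinates that are duplicated or fall outside the topology.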
1700 
1701 // Parses the value of the device_assignment attribute to TPUReplicate.
1702 // Populates *device_assignment; *device_assignment must be a 2D array with
1703 // shape (num_replicas, num_cores_per_replica).
1704 static Status ParseDeviceAssignmentAttr(
1705     absl::Span<const int> device_assignment_attr,
1706     const tpu::TpuTopologyExternal& tpu_topology, int num_replicas,
1707     int num_cores_per_replica,
1708     xla::Array2D<tpu::TpuCoreLocationExternal>* device_assignment) {
1709   static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4");
1710 
1711   const int64_t device_assignment_attr_size =
1712       num_replicas * num_cores_per_replica * kTPUTopologyRank;
1713   if (device_assignment_attr.size() != device_assignment_attr_size) {
1714     return errors::InvalidArgument(
1715         "Length of device_assignment attribute must be equal to num_replicas (",
1716         num_replicas, ") * num_cores_per_replica (", num_cores_per_replica,
1717         ") * ", kTPUTopologyRank, " got ", device_assignment_attr.size());
1718   }
1719   for (int core : device_assignment_attr) {
1720     if (core < 0 || core >= kTPUMaxTopologySize) {
1721       return errors::InvalidArgument(
1722           "Invalid core number in device assignment: ", core);
1723     }
1724   }
1725 
1726   *device_assignment = xla::Array2D<tpu::TpuCoreLocationExternal>(
1727       num_replicas, num_cores_per_replica);
1728   int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore);
1729   xla::Array4D<int> replica_assignment(
1730       tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y,
1731       tpu_topology.chip_bounds().z, devices_per_chip, -1);
1732   int pos = 0;
1733   for (int replica = 0; replica < num_replicas; ++replica) {
1734     for (int logical_core = 0; logical_core < num_cores_per_replica;
1735          ++logical_core) {
1736       int32_t x = device_assignment_attr[pos++];
1737       int32_t y = device_assignment_attr[pos++];
1738       int32_t z = device_assignment_attr[pos++];
1739       int32_t core = device_assignment_attr[pos++];
1740 
1741       if (!tpu_topology.HasChip(x, y, z) || core < 0 ||
1742           core >= devices_per_chip) {
1743         return errors::InvalidArgument(
1744             "Mesh coordinates (", x, ",", y, ",", core,
1745             ") are not valid for the current TPU topology");
1746       }
1747       tpu::TpuCoreLocationExternal core_location =
1748           tpu_topology.Core(kTensorCore, x, y, z, core);
1749 
1750       if (replica_assignment(x, y, z, core) != -1) {
1751         return errors::InvalidArgument("Duplicate coordinates (", x, ",", y,
1752                                        ",", z, ",", core,
1753                                        ") in TPU device assignment");
1754       }
1755       replica_assignment(x, y, z, core) = replica;
1756       (*device_assignment)(replica, logical_core) = core_location;
1757     }
1758   }
1759   return Status::OK();
1760 }
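//
// Illustrative example (assumed values): for num_replicas = 2 and
// num_cores_per_replica = 2, device_assignment_attr must hold 2 * 2 * 4 = 16
// ints, one (x, y, z, core) quadruple per (replica, logical core) pair in
// replica-major order; each physical core may appear at most once in the
// assignment.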
1761 
1762 // Builds TensorFlow device assignments for the special case of a single core
1763 // computation that is replicated to every core in the mesh.
1764 // LINT.IfChange
1765 static Status BuildFullMeshDeviceAssignment(
1766     int num_replicas, const std::vector<std::vector<Device*>>& tpu_devices,
1767     int num_tasks, int num_tpus_per_task,
1768     std::vector<std::vector<string>>* tf_device_assignment,
1769     std::vector<int>* devices_to_lock) {
1770   // Assign TensorFlow devices to replicas arbitrarily.
1771   for (int i = 0; i < num_replicas; ++i) {
1772     int task = i / num_tpus_per_task;
1773     int device = i % num_tpus_per_task;
1774     TF_RET_CHECK(task >= 0 && task < num_tasks);
1775     TF_RET_CHECK(device >= 0 && device < num_tpus_per_task);
1776 
1777     // We don't actually know which TF device corresponds to which physical
1778     // device, but it doesn't matter—they're all identical.
1779     (*tf_device_assignment)[i] = {tpu_devices[task][device]->name()};
1780     devices_to_lock->push_back(i);
1781   }
1782   return Status::OK();
1783 }
1784 // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
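//
// Illustrative arithmetic (assumed values): with num_tpus_per_task = 4,
// replica 5 maps to task 5 / 4 = 1 and device 5 % 4 = 1, so replicas simply
// fill the mesh in task-major order.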
1785 
1786 // Builds TensorFlow device assignments for a replicated computation and
1787 // converts device_assignment into xla_device_assignment.
1788 static Status BuildGeneralDeviceAssignment(
1789     int num_replicas, int num_cores_per_replica,
1790     const std::vector<std::vector<Device*>>& tpu_devices,
1791     const xla::Array2D<tpu::TpuCoreLocationExternal>& device_assignment,
1792     const xla::Array4D<std::pair<int, int>>& topology,
1793     std::vector<std::vector<string>>* tf_device_assignment,
1794     std::vector<int>* devices_to_lock,
1795     std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1796   // Assign TensorFlow devices to each computation's replicas according to
1797   // device_assignment and 'topology'.
1798   *xla_device_assignment = absl::make_unique<xla::DeviceAssignment>(
1799       num_replicas, num_cores_per_replica);
1800   for (int replica = 0; replica < num_replicas; ++replica) {
1801     for (int computation = 0; computation < num_cores_per_replica;
1802          ++computation) {
1803       const tpu::TpuCoreLocationExternal& core_location =
1804           device_assignment(replica, computation);
1805 
1806       int task;
1807       int device;
1808       std::tie(task, device) =
1809           topology(core_location.chip_coordinates().x,
1810                    core_location.chip_coordinates().y,
1811                    core_location.chip_coordinates().z, core_location.index());
1812 
1813       CHECK_LT(computation, num_cores_per_replica);
1814       (**xla_device_assignment)(replica, computation) = core_location.Id();
1815 
1816       // The communication pattern between replicas will be determined later by
1817       // BuildAllReduceRing.
1818       TF_RET_CHECK(task >= 0 && task < tpu_devices.size());
1819       TF_RET_CHECK(device >= 0 && device < tpu_devices[task].size());
1820       (*tf_device_assignment)[replica].push_back(
1821           tpu_devices[task][device]->name());
1822       devices_to_lock->push_back((task * tpu_devices[task].size()) + device);
1823     }
1824   }
1825   return Status::OK();
1826 }
1827 
1828 /*static*/ Status DistributedTPURewritePass::BuildDeviceAssignment(
1829     const tpu::TpuTopologyExternal& tpu_topology, int num_tpus_per_task,
1830     const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
1831     int num_cores_per_replica, const string& topology_attr,
1832     absl::Span<const int> device_assignment_attr,
1833     std::vector<std::vector<string>>* tf_device_assignment,
1834     std::vector<int>* devices_to_lock,
1835     std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) {
1836   const int num_tasks = tpu_devices.size();
1837   const int num_tpu_devices = num_tasks * num_tpus_per_task;
1838   VLOG(2) << "num_tasks=" << num_tasks
1839           << " num_tpus_per_task=" << num_tpus_per_task;
1840 
1841   // Checks num_replicas is sane first to avoid integer overflow.
1842   if (num_replicas > num_tpu_devices) {
1843 #ifdef PLATFORM_CLOUD_TPU
1844     return errors::InvalidArgument("Requested num_replicas=", num_replicas,
1845                                    " but there are only ", num_tpu_devices,
1846                                    " cores in the TPU topology.");
1847 #else
1848     return errors::InvalidArgument("Requested num_replicas=", num_replicas,
1849                                    " but there are only ", num_tpu_devices,
1850                                    " cores in the TPU topology.");
1851 #endif
1852   }
1853   if (num_replicas * num_cores_per_replica > num_tpu_devices) {
1854     return errors::InvalidArgument(
1855         "Requested num_replicas=", num_replicas, " with ",
1856         num_cores_per_replica, " cores per replica, but there are only ",
1857         num_tpu_devices, " cores in the TPU topology");
1858   }
1859 
1860   tf_device_assignment->clear();
1861   tf_device_assignment->resize(num_replicas);
1862 
1863   devices_to_lock->clear();
1864   devices_to_lock->reserve(num_replicas * num_cores_per_replica);
1865 
1866   // Special case: we allow the user to omit the topology and device assignment
1867   // information in two cases:
1868   // * there is only one replica and one core per replica. In this case, we
1869   //   don't need to know topology information because we don't communicate with
1870   //   other cores.
1871   // * the number of replicas is equal to the number of cores in the slice. In
1872   //   this case, all cores are running the same program so we don't need to
1873   //   know which is which.
1874   if (topology_attr.empty()) {
1875     // LINT.IfChange
1876     if (num_replicas != 1 && num_replicas != num_tpu_devices) {
1877       return errors::InvalidArgument(
1878           "TPUReplicate asked to create ", num_replicas,
1879           " replicas, but the number of cores in the TPU topology is ",
1880           num_tpu_devices,
1881           " and no TPU device assignment was supplied. "
1882           "A TPU device assignment is required if the number of replicas is "
1883           "not 1 or the number of cores in the topology (",
1884           num_tpu_devices, ")");
1885     }
1886 
1887     if (num_cores_per_replica != 1) {
1888       return errors::InvalidArgument(
1889           "A TPU topology must be provided if num_cores_per_replica != 1");
1890     }
1891 
1892     if (!device_assignment_attr.empty()) {
1893       return errors::InvalidArgument(
1894           "A TPU topology must be provided if device_assignment_attr is "
1895           "non-empty");
1896     }
1897 
1898     // If there is only one replica, assign the Tensorflow computation to task 0
1899     // device 0, and leave the XLA device assignment empty. We don't know which
1900     // core this is in the TPU topology, but it doesn't matter—we don't need to
1901     // communicate with any other cores.
1902     if (num_replicas == 1) {
1903       (*tf_device_assignment)[0] = {tpu_devices[0][0]->name()};
1904       devices_to_lock->push_back(0);
1905       return Status::OK();
1906     }
1907 
1908     // Otherwise, num_replicas is equal to the number of cores, and we build a
1909     // device assignment that covers the entire mesh. We do not need to know
1910     // the topology to do so because all cores are identical.
1911     return BuildFullMeshDeviceAssignment(num_replicas, tpu_devices, num_tasks,
1912                                          num_tpus_per_task,
1913                                          tf_device_assignment, devices_to_lock);
1914     // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
1915   }
1916 
1917   // Array that maps mesh coordinates to {TF task, TF TPU device #} pairs.
1918   xla::Array4D<std::pair<int, int>> topology;
1919   TF_RETURN_IF_ERROR(ParseTopologyAttr(topology_attr, tpu_topology, num_tasks,
1920                                        num_tpus_per_task, &topology));
1921 
1922   // Array that maps logical (replica, core) pairs to physical mesh coordinates.
1923   xla::Array2D<tpu::TpuCoreLocationExternal> device_assignment;
1924   TF_RETURN_IF_ERROR(ParseDeviceAssignmentAttr(
1925       device_assignment_attr, tpu_topology, num_replicas, num_cores_per_replica,
1926       &device_assignment));
1927 
1928   return BuildGeneralDeviceAssignment(
1929       num_replicas, num_cores_per_replica, tpu_devices, device_assignment,
1930       topology, tf_device_assignment, devices_to_lock, xla_device_assignment);
1931 }
1932 
1933 Status DistributedTPURewritePass::GetComputationForTPUReplicateOp(
1934     const NameAttrList& function, FunctionLibraryRuntime* flr,
1935     Graph* computation, DataTypeVector* arg_types,
1936     DataTypeVector* retval_types) {
1937   FunctionLibraryRuntime::Handle handle;
1938 
1939   TF_RETURN_IF_ERROR(
1940       flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
1941 
1942   const FunctionBody* fbody = flr->GetFunctionBody(handle);
1943 
1944   CopyGraph(*fbody->graph, computation);
1945   *arg_types = fbody->arg_types;
1946   *retval_types = fbody->ret_types;
1947   return Status::OK();
1948 }
1949 
1950 // Grab the InferredShape corresponding to an edge input.
1951 static Status GetEdgeShape(const GraphShapeInfo& shape_info, const Edge& edge,
1952                            const InferredShape** info) {
1953   auto it = shape_info.find(edge.src()->name());
1954   if (it == shape_info.end()) {
1955     return errors::InvalidArgument(
1956         "Input to replicated TPU computation is missing InferredShape: ",
1957         edge.src()->name());
1958   }
1959   TF_RET_CHECK(it->second.size() > edge.src_output());
1960   *info = &it->second[edge.src_output()];
1961   return Status::OK();
1962 }
1963 
1964 Status DistributedTPURewritePass::GetArgAndRetvalShapes(
1965     const GraphShapeInfo& shape_info, const Node& node,
1966     const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
1967     std::vector<InferredShape>* retval_shapes) {
1968   std::vector<const Edge*> input_edges;
1969   TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
1970 
1971   // If any replica's arg shape is unknown, we will mark the computation's arg
1972   // shape as unknown. If the shapes differ, the TPUExecute op will raise a
1973   // runtime error.
1974   std::vector<bool> any_replica_shape_unknown(
1975       params_info.NumInputsToEachReplica());
1976   arg_shapes->clear();
1977   arg_shapes->resize(params_info.NumInputsToEachReplica());
1978   TF_RET_CHECK(input_edges.size() == params_info.NumInputsFromHost());
1979   // Determines the shapes of the per-replica arguments and checks that all
1980   // replicas have identical shapes.
1981   int64_t edge_pos = 0;
1982   auto check_shape = [&](int input_index) -> Status {
1983     const InferredShape* info;
1984     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
1985     ++edge_pos;
1986 
1987     if ((info->handle_type == DT_INVALID && !info->shape.IsFullyDefined()) ||
1988         (info->handle_type != DT_INVALID &&
1989          !info->handle_shape.IsFullyDefined())) {
1990       any_replica_shape_unknown[input_index] = true;
1991     }
1992     xla::StatusOr<InferredShape> status =
1993         MergeInferredShapes((*arg_shapes)[input_index], *info);
1994     if (!status.ok()) {
1995       return errors::InvalidArgument(
1996           "Mismatched shapes for input ", input_index, ": ",
1997           (*arg_shapes)[input_index].shape.DebugString(), " vs. ",
1998           info->shape.DebugString());
1999     }
2000     (*arg_shapes)[input_index] = status.ValueOrDie();
2001     return Status::OK();
2002   };
2003 
2004   for (int64_t i = 0; i < params_info.NumReplicas(); ++i) {
2005     for (int64_t j = 0; j < params_info.NumPerReplicaArgs(); ++j) {
2006       TF_RETURN_IF_ERROR(check_shape(j));
2007     }
2008   }
2009 
2010   for (int64_t i = 0; i < params_info.NumDistributedArgs(); ++i) {
2011     TF_RETURN_IF_ERROR(check_shape(params_info.NumPerReplicaArgs() + i));
2012   }
2013 
2014   for (int64_t i = 0;
2015        i < params_info.NumPerReplicaArgs() + params_info.NumDistributedArgs();
2016        ++i) {
2017     if (any_replica_shape_unknown[i]) {
2018       (*arg_shapes)[i].shape = PartialTensorShape();
2019       (*arg_shapes)[i].handle_shape = PartialTensorShape();
2020     }
2021   }
2022 
2023   // Determines the shape of the broadcast arguments.
2024   for (int64_t i = 0; i < params_info.NumBroadcastArgs(); ++i) {
2025     TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
2026     const InferredShape* info;
2027     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
2028     (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
2029                   params_info.NumDistributedArgs()]
2030         .shape = info->shape;
2031     ++edge_pos;
2032   }
2033 
2034   // Determines the handle shape and handle type of the resource variable
2035   // arguments.
2036   for (int64_t i = 0; i < params_info.NumVariables(); ++i) {
2037     TF_RET_CHECK(node.input_type(edge_pos) == DT_RESOURCE);
2038     const InferredShape* info;
2039     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
2040     InferredShape& arg_shape =
2041         (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
2042                       params_info.NumDistributedArgs() +
2043                       params_info.NumBroadcastArgs()];
2044     arg_shape.shape = TensorShape();  // Variables are always scalars.
2045     arg_shape.handle_shape = info->handle_shape;
2046     arg_shape.handle_type = info->handle_type;
2047     TF_RET_CHECK(arg_shape.handle_type != DT_INVALID)
2048         << " input edge: " << input_edges[edge_pos]->DebugString();
2049     ++edge_pos;
2050   }
2051 
2052   // Determines the shape of the guaranteed constants.
2053   // TODO(vinuraja): Can be removed because they are not required for any
2054   // calculations. Leaving them here for symmetry with other structures like
2055   // arg_types, arg_sharding, etc.
2056   for (int64_t i = 0; i < params_info.NumGuaranteedConstants(); ++i) {
2057     TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE);
2058     const InferredShape* info;
2059     TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info));
2060     (*arg_shapes)[i + params_info.NumPerReplicaArgs() +
2061                   params_info.NumDistributedArgs() +
2062                   params_info.NumBroadcastArgs() + params_info.NumVariables()]
2063         .shape = info->shape;
2064     ++edge_pos;
2065   }
2066 
2067   // Extract the return value shapes.
2068   auto it = shape_info.find(node.name());
2069   retval_shapes->clear();
2070   if (it != shape_info.end()) {
2071     TF_RET_CHECK(it->second.size() >= node.num_outputs());
2072     retval_shapes->resize(node.num_outputs());
2073     for (int i = 0; i < node.num_outputs(); ++i) {
2074       (*retval_shapes)[i].shape = it->second[i].shape;
2075     }
2076   } else if (node.num_outputs() > 0) {
2077     return errors::InvalidArgument(
2078         "Replicated TPU computation is missing InferredShape: ",
2079         FormatNodeForError(node));
2080   }
2081   return Status::OK();
2082 }
2083 
2084 // Verifies that all nodes have legal sharding.
2085 static Status ValidateCoreNumbers(const Graph& graph,
2086                                   int num_cores_per_replica) {
2087   for (Node* n : graph.nodes()) {
2088     TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding,
2089                         ParseShardingFromDevice(*n, num_cores_per_replica,
2090                                                 /*add_metadata=*/true));
2091   }
2092   return Status::OK();
2093 }
2094 
2095 static Status InferXlaShardingFromNeighbors(
2096     const Node& n, int num_cores_per_replica, FunctionLibraryRuntime* flr,
2097     CachedFunctionHandles* cached_function_handles,
2098     absl::optional<NodeAndSharding>* output_node_and_sharding,
2099     bool* is_fast_mem) {
2100   int64_t core = -1;
2101   absl::optional<NodeAndSharding> result;
2102   // We assume the variable has been allocated on fast memory if any consuming
2103   // op has TPU_FAST_MEM_ATTR attribute. This is a protocol between runtime and
2104   // compiler.
2105   *is_fast_mem = false;
2106   for (const Edge* edge : n.out_edges()) {
2107     if (edge->IsControlEdge()) continue;
2108 
2109     TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
2110         num_cores_per_replica, n.name(), *edge->dst(), &core, is_fast_mem,
2111         &result));
2112 
2113     if (!flr) continue;
2114 
2115     // The nodes deciding this arg's device assignment might be in
2116     // FunctionDef. Instantiate FunctionDefs associated with this node
2117     // and check nodes using this arg.
2118     std::function<Status(const Edge* call_edge)> parse_sharding_from_function =
2119         [&](const Edge* call_edge) {
2120           auto associated_functions = GetAssociatedFunctions(
2121               *call_edge->dst(), flr->GetFunctionLibraryDefinition());
2122           for (auto& associated_function : associated_functions) {
2123             FunctionLibraryRuntime::Handle handle;
2124             TF_RETURN_IF_ERROR(cached_function_handles->GetOrInstantiate(
2125                 associated_function.func_name(),
2126                 AttrSlice(&associated_function.attrs()), &handle));
2127             const FunctionBody* body = flr->GetFunctionBody(handle);
2128             Graph* g = body->graph;
2129 
2130             for (Node* body_node : g->nodes()) {
2131               if (!body_node->IsArg()) continue;
2132 
2133               int index;
2134               TF_RETURN_IF_ERROR(
2135                   GetNodeAttr(body_node->attrs(), "index", &index));
2136               if (index != call_edge->dst_input()) continue;
2137 
2138               for (const Edge* out_edge : body_node->out_edges()) {
2139                 if (out_edge->IsControlEdge()) continue;
2140 
2141                 TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
2142                     num_cores_per_replica, n.name(), *out_edge->dst(), &core,
2143                     is_fast_mem, &result));
2144 
2145                 TF_RETURN_IF_ERROR(parse_sharding_from_function(out_edge));
2146               }
2147             }
2148           }
2149           return Status::OK();
2150         };
2151     TF_RETURN_IF_ERROR(parse_sharding_from_function(edge));
2152   }
2153   *output_node_and_sharding = result;
2154   return Status::OK();
2155 }
2156 
2157 bool UseSpmdForXlaPartitioning(const Node* replicate_node) {
2158   bool spmd_attr;
2159   if (!replicate_node ||
2160       !TryGetNodeAttr(replicate_node->attrs(), "use_spmd_for_xla_partitioning",
2161                       &spmd_attr)) {
2162     spmd_attr = false;
2163   }
2164   return spmd_attr;
2165 }
2166 
2167 std::string FormatNodeAndShardingMsg(
2168     const absl::optional<NodeAndSharding>& node_and_sharding) {
2169   DCHECK(node_and_sharding.has_value());
2170 
2171   xla::OpSharding sharding_no_metadata = node_and_sharding->sharding;
2172   sharding_no_metadata.clear_metadata();
2173   std::string escaped_sharding_str =
2174       absl::CEscape(sharding_no_metadata.SerializeAsString());
2175   if (node_and_sharding->node == nullptr) {
2176     return absl::StrCat(" via default sharding '", escaped_sharding_str, "'");
2177   }
2178 
2179   return absl::StrCat(" via node ", node_and_sharding->node->DebugString(),
2180                       " sharding '", escaped_sharding_str, "'");
2181 }
2182 
2183 Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
2184     int num_cores_per_replica, const ParameterInfo& params_info,
2185     const DataTypeVector& arg_types,
2186     const std::vector<InferredShape>& arg_shapes,
2187     const DataTypeVector& retval_types,
2188     const std::vector<InferredShape>& retval_shapes, const Graph& graph,
2189     const Node* replicate_node, FunctionLibraryRuntime* flr,
2190     bool allow_parameter_replication_for_spmd,
2191     std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem,
2192     std::vector<xla::OpSharding>* retval_sharding,
2193     std::vector<std::string>* arg_names) {
2194   // Builds vectors of the argument and return nodes.
2195   std::vector<Node*> args(arg_types.size());
2196   std::vector<Node*> retvals(retval_types.size());
2197   absl::flat_hash_map<int, Node*> partitioned_output_nodes;
2198   for (Node* node : graph.op_nodes()) {
2199     if (node->IsArg()) {
2200       int index;
2201       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
2202       TF_RET_CHECK(index >= 0 && index < args.size());
2203       args[index] = node;
2204     } else if (node->IsRetval()) {
2205       int index;
2206       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
2207       TF_RET_CHECK(index >= 0 && index < retvals.size());
2208       retvals[index] = node;
2209     }
2210   }
2211   for (const Edge* edge : replicate_node->out_edges()) {
2212     int num_partitioned_outputs = 0;
2213     for (const Edge* out_edge : edge->dst()->out_edges()) {
2214       if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
2215         partitioned_output_nodes[edge->src_output()] = out_edge->dst();
2216         num_partitioned_outputs++;
2217       }
2218     }
2219     if (num_partitioned_outputs > 1) {
2220       return errors::InvalidArgument(
2221           "More than one TPUPartitionedOutput per replicated output.");
2222     }
2223   }
2224 
2225   // Verifies there are no missing arguments/return values.
2226   for (int i = 0; i < args.size(); ++i) {
2227     if (args[i] == nullptr) {
2228       return errors::Internal("Missing function argument: ", i);
2229     }
2230   }
2231   for (int i = 0; i < retvals.size(); ++i) {
2232     if (retvals[i] == nullptr) {
2233       return errors::Internal("Missing function return value: ", i);
2234     }
2235   }
2236 
2237   // Assigns a core to each _Arg. Chooses the lowest-numbered core that
2238   // consumes the argument. We choose the lowest-numbered core so the
2239   // assignment is deterministic.
2240   TensorDevicePlacer args_device_selector(num_cores_per_replica, arg_types,
2241                                           arg_shapes);
2242   arg_sharding->resize(args.size());
2243   arg_names->resize(args.size());
2244   arg_fast_mem->resize(args.size());
2245   CachedFunctionHandles cached_function_handles(flr);
2246   const bool use_spmd = (UseSpmdForXlaPartitioning(replicate_node) ||
2247                          replicate_inputs_outputs_by_default_for_xla_spmd_) &&
2248                         allow_parameter_replication_for_spmd;
2249 
2250   // Offset _TPUReplicate non per replica argument indices by
2251   // (num_replicas - 1) * num_per_replica_args as _TPUReplicate nodes are
2252   // constructed with all per replica args across all replicas while the
2253   // encapsulated function only has 1 replica's per replica args. Per replica
2254   // args are ordered by replica first, so the index here does not require an
2255   // offset, and the first replica's input nodes are sufficient for determining
2256   // argument sharding.
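  //
  // Worked example (assumed values): with 2 replicas and 3 per-replica args,
  // index_offset = (2 - 1) * 3 = 3, so (when there are no distributed or
  // broadcast args) the encapsulated function's argument at index 3, the
  // first variable, corresponds to input 3 + 3 = 6 of the _TPUReplicate node.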
2257   const int index_offset =
2258       (params_info.NumReplicas() - 1) * params_info.NumPerReplicaArgs();
2259   for (int i = 0; i < args.size(); ++i) {
2260     const Node* n = args[i];
2261     absl::optional<int64> assigned_core;
2262     absl::optional<NodeAndSharding> node_and_sharding;
2263     bool is_fast_mem;
2264     TF_RETURN_IF_ERROR(InferXlaShardingFromNeighbors(
2265         *n, num_cores_per_replica, flr, &cached_function_handles,
2266         &node_and_sharding, &is_fast_mem));
2267 
2268     const bool is_per_replica_arg = params_info.IsPerReplicaArg(i);
2269     if (is_per_replica_arg || params_info.IsDistributedArg(i)) {
2270       Node* input_node;
2271       TF_RETURN_IF_ERROR(replicate_node->input_node(
2272           i + (is_per_replica_arg ? 0 : index_offset), &input_node));
2273       if (input_node->type_string() == kTPUPartitionedInput) {
2274         TF_ASSIGN_OR_RETURN(
2275             absl::optional<xla::OpSharding> parsed_sharding,
2276             GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2277         if (!parsed_sharding.has_value())
2278           return errors::InvalidArgument("Missing _XlaSharding attr from: ",
2279                                          input_node->DebugString());
2280         node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2281         VLOG(1) << "Arg " << i << " parsed sharding information from "
2282                 << input_node->DebugString() << " : "
2283                 << parsed_sharding->DebugString();
2284       }
2285     }
2286 
2287     if (params_info.IsVariableArg(i)) {
2288       Node* input_node;
2289       TF_RETURN_IF_ERROR(
2290           replicate_node->input_node(i + index_offset, &input_node));
2291       if (input_node->type_string() == kVarHandleOp) {
2292         TF_ASSIGN_OR_RETURN(
2293             absl::optional<xla::OpSharding> parsed_sharding,
2294             GetShardingFromNodeDef(input_node->def(), /*add_metadata=*/true));
2295         if (parsed_sharding.has_value()) {
2296           node_and_sharding = NodeAndSharding(input_node, *parsed_sharding);
2297           VLOG(1) << "Arg " << i << " parsed sharding information from "
2298                   << input_node->DebugString() << " : "
2299                   << parsed_sharding->DebugString();
2300         }
2301       }
2302     }
2303 
2304     if (node_and_sharding.has_value() && enable_automatic_model_parallelism_) {
2305       return tensorflow::errors::InvalidArgument(
2306           "Specifying manual sharding is not allowed when automatic "
2307           "model parallelism is enabled.",
2308           node_and_sharding->sharding.DebugString());
2309     }
2310 
2311     if (!node_and_sharding.has_value()) {
2312       if (use_spmd &&
2313           (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2314            ((params_info.IsPerReplicaArg(i) ||
2315              params_info.IsDistributedArg(i)) &&
2316             arg_types[i] != DT_RESOURCE))) {
2317         // Use replication for host variables or non-variable per-replica
2318         // inputs.
2319         node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2320                                             xla::sharding_builder::Replicate());
2321       } else {
2322         // TODO(dlibenzi): Distributing variables to cores other than 0 makes
2323         // learning/brain/research/babelfish/trainer:trainer_tpu_test fail.
2324         // For now distribute only per replica arguments, unless
2325         // tf_jf_distribute_vars is set, to allow debugging the issue.
2326         if (((params_info.IsPerReplicaArg(i) ||
2327               params_info.IsDistributedArg(i)) &&
2328              arg_types[i] != DT_RESOURCE) ||
2329             (distribute_vars_ && params_info.IsVariableArg(i))) {
2330           assigned_core = args_device_selector.RetrieveAssignment(i);
2331         } else {
2332           assigned_core = 0;
2333         }
2334         node_and_sharding = NodeAndSharding(
2335             /*node=*/nullptr,
2336             xla::sharding_builder::AssignDevice(*assigned_core));
2337       }
2338       *node_and_sharding->sharding.add_metadata() =
2339           CreateOpMetadataFromNode(*replicate_node);
2340     } else if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2341       assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
2342     } else if (node_and_sharding->sharding.type() !=
2343                    xla::OpSharding::REPLICATED &&
2344                node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2345       return tensorflow::errors::InvalidArgument(
2346           "Unsupported argument sharding (for arg ", n->DebugString(),
2347           "): ", node_and_sharding->sharding.DebugString());
2348     }
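    // Record the placement implied by the chosen sharding: MAXIMAL pins the
    // argument to a single core, OTHER (tiled) spreads it over the cores in
    // tile_assignment_devices, and REPLICATED places a copy on every core.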
2349     if (assigned_core.has_value()) {
2350       args_device_selector.ReportDeviceAssigned(*assigned_core, i);
2351       VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2352               << ") to core " << *assigned_core
2353               << FormatNodeAndShardingMsg(node_and_sharding);
2354       args[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2355     } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2356       for (int64_t core :
2357            node_and_sharding->sharding.tile_assignment_devices()) {
2358         TF_RET_CHECK(core >= 0 && core < num_cores_per_replica)
2359             << "core " << core << " should be between [0, "
2360             << num_cores_per_replica << "). sharding is "
2361             << node_and_sharding->sharding.DebugString();
2362         args_device_selector.ReportDeviceAssigned(core, i);
2363       }
2364       VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2365               << ") with tiled sharding to cores "
2366               << absl::StrJoin(
2367                      node_and_sharding->sharding.tile_assignment_devices(), ",")
2368               << " " << FormatNodeAndShardingMsg(node_and_sharding);
2369     } else {
2370       DCHECK_EQ(node_and_sharding->sharding.type(),
2371                 xla::OpSharding::REPLICATED);
2372       for (int64_t core = 0; core < num_cores_per_replica; ++core) {
2373         args_device_selector.ReportDeviceAssigned(core, i);
2374       }
2375       VLOG(3) << "Assigning argument " << i << " (" << n->DebugString()
2376               << ") to all cores"
2377               << FormatNodeAndShardingMsg(node_and_sharding);
2378     }
2379     (*arg_sharding)[i] = node_and_sharding->sharding;
2380     (*arg_fast_mem)[i] = is_fast_mem;
2381     (*arg_names)[i] = n->name();
2382     if (is_fast_mem) {
2383       VLOG(3) << "Add " << TPU_FAST_MEM_ATTR << " attribute to "
2384               << args[i]->name();
2385     }
2386     args[i]->AddAttr(kShardingAttribute,
2387                      node_and_sharding->sharding.SerializeAsString());
2388   }
2389   TF_RETURN_IF_ERROR(cached_function_handles.ReleaseAllHandles());
2390 
2391   // Assigns each _Retval node to the core that produces its value.
2392   TensorDevicePlacer retvals_device_selector(num_cores_per_replica,
2393                                              retval_types, retval_shapes);
2394   retval_sharding->resize(retvals.size());
2395   for (int i = 0; i < retvals.size(); ++i) {
2396     const Edge* edge;
2397     TF_RETURN_IF_ERROR(retvals[i]->input_edge(0, &edge));
2398 
2399     TF_ASSIGN_OR_RETURN(
2400         absl::optional<xla::OpSharding> edge_sharding,
2401         ParseShardingFromEdgeSource(*edge, num_cores_per_replica,
2402                                     /*add_metadata=*/true));
2403 
2404     absl::optional<NodeAndSharding> node_and_sharding;
2405     if (edge_sharding.has_value()) {
2406       node_and_sharding.emplace(NodeAndSharding(edge->src(), *edge_sharding));
2407     }
2408 
2409     if (partitioned_output_nodes.contains(i)) {
2410       Node* output_node = partitioned_output_nodes[i];
2411       TF_ASSIGN_OR_RETURN(
2412           absl::optional<xla::OpSharding> parsed_sharding,
2413           GetShardingFromNodeDef(output_node->def(), /*add_metadata=*/true));
2414       if (parsed_sharding.has_value()) {
2415         node_and_sharding = NodeAndSharding(output_node, *parsed_sharding);
2416         VLOG(1) << "Retval " << i << " parsed sharding information from "
2417                 << output_node->DebugString() << " : "
2418                 << parsed_sharding->DebugString();
2419       }
2420     }
2421     absl::optional<int64> assigned_core;
2422     if (node_and_sharding.has_value()) {
2423       if (enable_automatic_model_parallelism_) {
2424         return tensorflow::errors::InvalidArgument(
2425             "Specifying manual sharding is not allowed when automatic "
2426             "model parallelism is enabled.",
2427             node_and_sharding->sharding.DebugString());
2428       }
2429 
2430       if (node_and_sharding->sharding.type() == xla::OpSharding::MAXIMAL) {
2431         assigned_core = node_and_sharding->sharding.tile_assignment_devices(0);
2432         TF_RETURN_IF_ERROR(
2433             ValidateCoreNumber(*assigned_core, num_cores_per_replica));
2434       } else if (node_and_sharding->sharding.type() !=
2435                      xla::OpSharding::REPLICATED &&
2436                  node_and_sharding->sharding.type() != xla::OpSharding::OTHER) {
2437         return tensorflow::errors::InvalidArgument(
2438             "Unsupported argument sharding for retval ",
2439             retvals[i]->DebugString(), " edge=", edge->DebugString(), ": ",
2440             node_and_sharding->sharding.DebugString());
2441       }
2442     } else {
2443       if (use_spmd) {
2444         node_and_sharding = NodeAndSharding(/*node=*/nullptr,
2445                                             xla::sharding_builder::Replicate());
2446       } else {
2447         if (distribute_vars_) {
2448           assigned_core = retvals_device_selector.RetrieveAssignment(i);
2449         } else {
2450           assigned_core = 0;
2451         }
2452         node_and_sharding = NodeAndSharding(
2453             /*node=*/nullptr,
2454             xla::sharding_builder::AssignDevice(*assigned_core));
2455       }
2456       *node_and_sharding->sharding.add_metadata() =
2457           CreateOpMetadataFromNode(*replicate_node);
2458     }
2459     if (assigned_core.has_value()) {
2460       retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core));
2461       retvals_device_selector.ReportDeviceAssigned(*assigned_core, i);
2462       VLOG(3) << "Assigning return value " << i << " ("
2463               << retvals[i]->DebugString() << ") to core " << *assigned_core
2464               << FormatNodeAndShardingMsg(node_and_sharding);
2465     } else if (node_and_sharding->sharding.type() == xla::OpSharding::OTHER) {
2466       for (int64_t core :
2467            node_and_sharding->sharding.tile_assignment_devices()) {
2468         TF_RET_CHECK(core >= 0 && core < num_cores_per_replica)
2469             << "core " << core << " should be between [0, "
2470             << num_cores_per_replica << "). sharding is "
2471             << node_and_sharding->sharding.DebugString();
2472         retvals_device_selector.ReportDeviceAssigned(core, i);
2473       }
2474       VLOG(3) << "Assigning return value " << i << " ("
2475               << retvals[i]->DebugString() << ") with tiled sharding to cores "
2476               << absl::StrJoin(
2477                      node_and_sharding->sharding.tile_assignment_devices(), ",")
2478               << " " << FormatNodeAndShardingMsg(node_and_sharding);
2479     } else {
2480       DCHECK_EQ(node_and_sharding->sharding.type(),
2481                 xla::OpSharding::REPLICATED);
2482       for (int64_t core = 0; core < num_cores_per_replica; ++core) {
2483         retvals_device_selector.ReportDeviceAssigned(core, i);
2484       }
2485       VLOG(3) << "Assigning return value " << i << " ("
2486               << retvals[i]->DebugString() << ") to all cores"
2487               << FormatNodeAndShardingMsg(node_and_sharding);
2488     }
2489     retvals[i]->AddAttr(kShardingAttribute,
2490                         node_and_sharding->sharding.SerializeAsString());
2491     (*retval_sharding)[i] = node_and_sharding->sharding;
2492   }
2493   if (use_spmd &&
2494       (absl::c_any_of(*arg_sharding,
2495                       [](const xla::OpSharding& s) {
2496                         return s.type() == xla::OpSharding::MAXIMAL;
2497                       }) ||
2498        absl::c_any_of(*retval_sharding, [](const xla::OpSharding& s) {
2499          return s.type() == xla::OpSharding::MAXIMAL;
2500        }))) {
2501     LOG(WARNING) << "XLA SPMD only supports cases where all inputs/outputs "
2502                     "exist on every partition (sharded or replicated). Fall "
2503                     "back to MPMD.";
2504     return AssignArgsAndRetvalsToCores(
2505         num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
2506         retval_shapes, graph, replicate_node, flr,
2507         /*allow_parameter_replication_for_spmd=*/false, arg_sharding,
2508         arg_fast_mem, retval_sharding, arg_names);
2509   }
2510   return Status::OK();
2511 }
2512 
2513 // Builds Shape nodes that compute the shapes of arguments whose shapes are not
2514 // statically known.
2515 /* static */ Status DistributedTPURewritePass::BuildDynamicShapeNodes(
2516     const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
2517     const ParameterInfo& params_info, const std::vector<Node*>& variable_reads,
2518     Graph* graph, std::vector<Node*>* dynamic_shape_nodes) {
2519   dynamic_shape_nodes->clear();
2520 
2521   std::vector<const Edge*> replicate_input_edges;
2522   TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
2523 
2524   // The compiler determines the shape of each constant by inspecting the value
2525   // of its corresponding host-memory tensor; this happens when a step is run.
2526   // As a result, the shapes of constants are not needed at graph rewrite time.
2527   const int num_args = arg_shapes.size() - params_info.NumGuaranteedConstants();
2528   TF_RET_CHECK(num_args == params_info.NumPerReplicaArgs() +
2529                                params_info.NumDistributedArgs() +
2530                                params_info.NumBroadcastArgs() +
2531                                params_info.NumVariables());
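  // Roughly, the _TPUReplicate inputs are laid out as [per-replica args for
  // replica 0, ..., per-replica args for replica R-1, distributed args,
  // broadcast args, variables, guaranteed constants], while arg_shapes is
  // indexed by the encapsulated function's argument order, which contains
  // only one replica's worth of per-replica args.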
2532 
2533   for (int i = 0; i < num_args; ++i) {
2534     const PartialTensorShape* shape = arg_shapes[i].handle_type == DT_INVALID
2535                                           ? &arg_shapes[i].shape
2536                                           : &arg_shapes[i].handle_shape;
2537     if (!shape->IsFullyDefined()) {
2538       NodeDef def;
2539       Node* src;
2540       int src_output;
2541       std::vector<Node*> control_inputs;
2542 
2543       if (params_info.IsVariableArg(i)) {
2544         int64_t var_num = i - params_info.NumPerReplicaArgs() -
2545                           params_info.NumDistributedArgs() -
2546                           params_info.NumBroadcastArgs();
2547         TF_RET_CHECK(0 <= var_num && var_num < variable_reads.size());
2548         Node* read = variable_reads[var_num];
2549 
2550         DCHECK_EQ(read->type_string(), "ReadVariableOp");
2551 
2552         for (const Edge* edge : read->in_edges()) {
2553           if (edge->IsControlEdge()) {
2554             control_inputs.push_back(edge->src());
2555           }
2556         }
2557 
2558         const Edge* variable_input = nullptr;
2559         TF_RETURN_IF_ERROR(read->input_edge(/*idx=*/0, &variable_input));
2560         src = variable_input->src();
2561         src_output = variable_input->src_output();
2562 
2563         def.set_name(
2564             graph->NewName(strings::StrCat(src->name(), "/variable_shape")));
2565         def.set_op("VariableShape");
2566       } else {
2567         if (params_info.IsPerReplicaArg(i)) {
2568           TF_RET_CHECK(i < replicate_input_edges.size());
2569           // All replicas must have the same input shapes. Uses the shape of the
2570           // inputs from the first replica.
2571           src = replicate_input_edges[i]->src();
2572           src_output = replicate_input_edges[i]->src_output();
2573         } else {
2574           DCHECK(params_info.IsDistributedArg(i) ||
2575                  params_info.IsBroadcastArg(i));
2576           int64_t input_num =
2577               params_info.NumPerReplicaArgs() * params_info.NumReplicas() + i -
2578               params_info.NumPerReplicaArgs();
2579           TF_RET_CHECK(0 <= input_num &&
2580                        input_num < replicate_input_edges.size());
2581           src = replicate_input_edges[input_num]->src();
2582           src_output = replicate_input_edges[input_num]->src_output();
2583         }
2584 
2585         def.set_name(graph->NewName(strings::StrCat(src->name(), "/shape")));
2586         def.set_op("Shape");
2587         AddNodeAttr("T", src->output_type(src_output), &def);
2588       }
2589 
2590       def.set_device(src->assigned_device_name());
2591       AddNodeAttr("out_type", DT_INT64, &def);
2592       MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
2593 
2594       Status status;
2595       Node* shape_node = graph->AddNode(def, &status);
2596       if (!status.ok()) return status;
2597       dynamic_shape_nodes->push_back(shape_node);
2598 
2599       shape_node->set_assigned_device_name(src->assigned_device_name());
2600       graph->AddEdge(src, src_output, shape_node, 0);
2601       for (Node* control_input : control_inputs) {
2602         graph->AddControlEdge(control_input, shape_node);
2603       }
2604     }
2605   }
2606   return Status::OK();
2607 }
2608 
2609 namespace {
2610 
2611 bool XlaBroadcastTypeSupported(const DataType dtype) {
2612   return (dtype == DT_FLOAT || dtype == DT_BFLOAT16 || dtype == DT_INT32 ||
2613           dtype == DT_BOOL);
2614 }
2615 
2616 bool XlaBroadcastKindSupported(
2617     const DistributedTPURewritePass::ParameterInfo& params_info,
2618     int param_num) {
2619   // NOTE: This is intended to cover non-sharded data-parallel variables, for
2620   // training only. Is it correct to just check if the arg_type is
2621   // DT_RESOURCE?
2622   return params_info.IsVariableArg(param_num) &&
2623          !(params_info.IsPerReplicaArg(param_num) ||
2624            params_info.IsDistributedArg(param_num) ||
2625            params_info.IsBroadcastArg(param_num) ||
2626            params_info.IsConstantArg(param_num));
2627 }
2628 
2629 bool EnableXlaParamBroadcast(
2630     bool enable_xla_param_broadcast,
2631     const DistributedTPURewritePass::ParameterInfo& params_info, int param_num,
2632     DataType dtype) {
2633   // Conditions necessary to use XLA collectives for arg broadcast:
2634   // 1. Globally enabled via enable_xla_param_broadcast.
2635   // 2. DataType must be supported.
2636   // 3. Parameter must be a variable, and not distributed or broadcasted.
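  // For example (illustrative only): a DT_FLOAT weight passed as a plain
  // variable arg qualifies, while a DT_INT64 variable or a per-replica input
  // batch does not.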
2637   return enable_xla_param_broadcast && XlaBroadcastTypeSupported(dtype) &&
2638          XlaBroadcastKindSupported(params_info, param_num);
2639 }
2640 
2641 }  // namespace
2642 
2643 // Builds a TPUCompile node that compiles the body of the computation
2644 // `function`.
2645 Status DistributedTPURewritePass::BuildCompileNode(
2646     const Node* replicate_node, const NameAttrList& function,
2647     uint64 library_fingerprint, const ParameterInfo& params_info,
2648     const std::vector<InferredShape>& arg_shapes,
2649     const DataTypeVector& arg_types,
2650     const std::vector<Node*>& guaranteed_constant_nodes,
2651     const string& session_handle,
2652     const std::vector<xla::OpSharding>& arg_sharding,
2653     const std::vector<bool>& arg_fast_mem,
2654     const std::vector<std::string>& arg_names,
2655     const std::vector<xla::OpSharding>& retval_sharding,
2656     int num_cores_per_replica, const string& compile_device,
2657     const xla::DeviceAssignment* xla_device_assignment,
2658     const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
2659     Node** compile_node, int64_t autotuner_thresh) {
2660   VLOG(1) << "BuildCompileNode";
2661 
2662   tpu::TPUCompileMetadataProto proto;
2663   proto.set_num_replicas(params_info.NumReplicas());
2664   proto.set_num_cores_per_replica(num_cores_per_replica);
2665   proto.set_function_library_fingerprint(library_fingerprint);
2666   proto.set_enable_automatic_model_parallelism(
2667       enable_cross_replica_sharding_mirrored_variables_);
2668   const bool use_spmd =
2669       UseSpmdForXlaPartitioning(replicate_node) && allow_xla_spmd_partition_ &&
2670       !absl::c_any_of(arg_sharding,
2671                       [](const xla::OpSharding& s) {
2672                         return s.type() == xla::OpSharding::MAXIMAL;
2673                       }) &&
2674       !absl::c_any_of(retval_sharding, [](const xla::OpSharding& s) {
2675         return s.type() == xla::OpSharding::MAXIMAL;
2676       });
2677   proto.set_use_spmd_for_xla_partitioning(use_spmd);
2678 
2679   // Set the step marker location from the replicate node, if present.
2680   if (replicate_node != nullptr) {
2681     xla::DebugOptions::StepMarkerLocation location;
2682     TF_RETURN_IF_ERROR(GetStepMarkerLocation(*replicate_node, &location));
2683     proto.set_step_marker_location(location);
2684   }
2685 
2686   if (xla_device_assignment != nullptr) {
2687     TF_RETURN_IF_ERROR(
2688         xla_device_assignment->Serialize(proto.mutable_device_assignment()));
2689   }
2690 
2691   const int num_args = arg_types.size();
2692   const int num_guaranteed_constants = guaranteed_constant_nodes.size();
2693   const int guaranteed_const_start_index = num_args - num_guaranteed_constants;
2694   TF_RET_CHECK(num_args == arg_shapes.size());
2695   TF_RET_CHECK(num_args == arg_sharding.size())
2696       << num_args << " != " << arg_sharding.size();
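  // For illustration: with num_args == 10 and 2 guaranteed constants,
  // guaranteed_const_start_index == 8, so args 8 and 9 are tagged
  // GUARANTEED_CONSTANT below and the rest are VARIABLE or PARAMETER.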
2697 
2698   for (int i = 0; i < num_args; ++i) {
2699     tpu::TPUCompileMetadataProto::Arg* arg = proto.add_args();
2700     DataType type = arg_types[i];
2701     const InferredShape& arg_shape = arg_shapes[i];
2702     arg->set_name(arg_names[i]);
2703     if (type == DT_RESOURCE) {
2704       TF_RET_CHECK(arg_shape.handle_type != DT_INVALID) << i;
2705       arg->set_dtype(arg_shape.handle_type);
2706       arg_shape.handle_shape.AsProto(arg->mutable_shape());
2707       arg->set_kind(tpu::TPUCompileMetadataProto::Arg::VARIABLE);
2708       arg->set_fast_mem(arg_fast_mem[i]);
2709     } else {
2710       arg->set_dtype(type);
2711       arg_shape.shape.AsProto(arg->mutable_shape());
2712       if (i >= guaranteed_const_start_index) {
2713         const DataType edge_type =
2714             guaranteed_constant_nodes[i - guaranteed_const_start_index]
2715                 ->output_type(0);
2716         TF_RET_CHECK(type == edge_type)
2717             << "Arg type: " << type << " but edge type: " << edge_type;
2718         arg->set_kind(tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT);
2719       } else {
2720         arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER);
2721       }
2722     }
2723 
2724     // Use XLA collective primitives to distribute variables to all replicas.
2725     arg->set_requires_xla_broadcast(
2726         params_info.NumReplicas() > 1 &&
2727         EnableXlaParamBroadcast(enable_xla_param_broadcast_, params_info, i,
2728                                 arg_shape.handle_type /*arg.dtype?*/));
2729 
2730     // As long as the argument is not a per-replica one, it should have the same
2731     // value for all replicas. For clarity, we keep the (redundant) checks for
2732     // variable, broadcast and constant types, to prevent bugs in case new types
2733     // with different semantics are introduced in the future.
2734     arg->set_is_same_data_across_replicas(
2735         !params_info.IsPerReplicaArg(i) && !params_info.IsDistributedArg(i) &&
2736         (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) ||
2737          params_info.IsConstantArg(i)));
2738     if (params_info.mirrored_variable_indices().count(i) > 0) {
2739       CHECK_EQ(type, DT_RESOURCE);
2740       arg->set_is_same_data_across_replicas(true);
2741       // 64-bit types are not shardable by XLA:TPU yet.
2742       bool sharding_enabled = (arg_shape.handle_type != DT_COMPLEX64 &&
2743                                arg_shape.handle_type != DT_INT64 &&
2744                                arg_shape.handle_type != DT_UINT64 &&
2745                                arg_shape.handle_type != DT_DOUBLE);
2746       arg->set_enable_xla_sharding(
2747           sharding_enabled ? tpu::TPUCompileMetadataProto::Arg::TENTATIVE
2748                            : tpu::TPUCompileMetadataProto::Arg::DISALLOWED);
2749     }
2750     *arg->mutable_sharding() = arg_sharding[i];
2751   }
2752 
2753   const int num_retvals = retval_sharding.size();
2754   for (int i = 0; i < num_retvals; ++i) {
2755     *proto.add_retvals()->mutable_sharding() = retval_sharding[i];
2756   }
2757   proto.set_session_handle(session_handle);
2758 
2759   DataTypeVector constant_arg_types;
2760   constant_arg_types.reserve(num_guaranteed_constants);
2761   for (int i = 0; i < num_guaranteed_constants; ++i) {
2762     constant_arg_types.push_back(arg_types[guaranteed_const_start_index + i]);
2763   }
2764   proto.set_xla_fusion_autotuner_thresh(autotuner_thresh);
2765 
2766   string metadata;
2767   proto.SerializeToString(&metadata);
2768 
2769   NodeDef def;
2770   def.set_name(UniqueNodeName("TPUReplicate/_compile", graph));
2771   def.set_op("TPUCompile");
2772   def.set_device(compile_device);
2773   if (replicate_node) {
2774     MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
2775   }
2776 
2777   AddNodeAttr("function", function, &def);
2778   AddNodeAttr("num_computations", num_cores_per_replica, &def);
2779   AddNodeAttr("NumDynamicShapes", static_cast<int>(dynamic_shape_nodes.size()),
2780               &def);
2781   AddNodeAttr("metadata", metadata, &def);
2782   AddNodeAttr("Tguaranteed_constants", constant_arg_types, &def);
2783 
2784   Status status;
2785   *compile_node = graph->AddNode(def, &status);
2786   TF_RETURN_IF_ERROR(status);
2787 
2788   (*compile_node)->set_assigned_device_name(compile_device);
2789 
2790   for (int i = 0; i < dynamic_shape_nodes.size(); ++i) {
2791     graph->AddEdge(dynamic_shape_nodes[i], 0, *compile_node, i);
2792   }
2793 
2794   for (int i = 0; i < num_guaranteed_constants; ++i) {
2795     graph->AddEdge(guaranteed_constant_nodes[i], 0, *compile_node,
2796                    dynamic_shape_nodes.size() + i);
2797   }
2798   VLOG(1) << "BuildCompileNode(): " << status;
2799   return status;
2800 }
2801 
2802 Status DistributedTPURewritePass::FindGuaranteedConstantInputs(
2803     const Node& node, const NameRangeMap& input_range_map,
2804     std::vector<Node*>* guaranteed_constants) {
2805   std::vector<const Edge*> input_edges;
2806   TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2807   std::pair<int, int> variables_limits =
2808       input_range_map.at("guaranteed_constants");
2809   for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2810     guaranteed_constants->push_back(input_edges[i]->src());
2811   }
2812   return Status::OK();
2813 }
2814 
2815 Status DistributedTPURewritePass::FindVariableInputs(
2816     const Node& node, const NameRangeMap& input_range_map,
2817     std::vector<VariableInput>* variables) {
2818   std::vector<const Edge*> input_edges;
2819   TF_RETURN_IF_ERROR(node.input_edges(&input_edges));
2820   std::pair<int, int> variables_limits = input_range_map.at("variables");
2821   for (int i = variables_limits.first; i < variables_limits.second; ++i) {
2822     Node* node = input_edges[i]->src();
2823 
2824     // Find the type of the VarHandleOp that feeds this node, looking through
2825     // any wrapping Enter or Switch nodes.
2826     while (node->IsEnter() || node->IsSwitch()) {
2827       TF_RETURN_IF_ERROR(node->input_node(0, &node));
2828     }
2829     // Fix the variable device assignment if it is requested with a full name.
2830     if (!node->has_assigned_device_name() &&
2831         !node->requested_device().empty()) {
2832       DeviceNameUtils::ParsedName var_device;
2833       TF_RET_CHECK(DeviceNameUtils::ParseFullName(node->requested_device(),
2834                                                   &var_device));
2835       if (var_device.has_job && var_device.has_replica && var_device.has_task &&
2836           var_device.has_type && var_device.has_id) {
2837         node->set_assigned_device_name(node->requested_device());
2838         if (node != input_edges[i]->src() &&
2839             !input_edges[i]->src()->has_assigned_device_name()) {
2840           input_edges[i]->src()->set_assigned_device_name(
2841               node->requested_device());
2842         }
2843       }
2844     }
2845     if (node->type_string() == kVarHandleOp) {
2846       DataType dtype;
2847       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "dtype", &dtype));
2848       variables->push_back(VariableInput{input_edges[i]->src(),
2849                                          input_edges[i]->src_output(), dtype});
2850     } else if (node->type_string() == "_Arg") {
2851       std::vector<DataType> dtypes;
2852       TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "_handle_dtypes", &dtypes));
2853       if (dtypes.empty()) {
2854         return errors::Internal(
2855             "_Arg node with resource output must have non-empty _handle_dtypes "
2856             "attribute: ",
2857             node->DebugString());
2858       }
2859       variables->push_back(VariableInput{
2860           input_edges[i]->src(), input_edges[i]->src_output(), dtypes[0]});
2861     } else {
2862       return errors::Internal(
2863           "Cannot handle variable input with node type other than VarHandleOp "
2864           "and _Arg: ",
2865           node->DebugString());
2866     }
2867   }
2868   return Status::OK();
2869 }
2870 
2871 // Builds a NoOp node, used for building control dependencies.
2872 static Status BuildNoopNode(const Node& source, StringPiece name,
2873                             const string& device, Graph* graph, Node** node) {
2874   NodeDefBuilder builder(name, "NoOp", NodeDebugInfo(source));
2875   if (!device.empty()) {
2876     builder.Device(device);
2877   }
2878   NodeDef def;
2879   TF_RETURN_IF_ERROR(builder.Finalize(&def));
2880 
2881   Status status;
2882   *node = graph->AddNode(def, &status);
2883   if (!device.empty()) {
2884     (*node)->set_assigned_device_name(device);
2885   }
2886   return status;
2887 }
2888 
2889 Status DistributedTPURewritePass::ConnectHostComputeNodes(
2890     Node* compile_node, Node* key_placeholder_node, Graph* graph) {
2891   // First find all the downstream nodes of the key placeholder node, since we
2892   // want to delete the connecting edges from key_placeholder_node, which would
2893   // invalidate the out_nodes iterator.
2894   std::vector<Node*> host_transfer_nodes;
2895   for (Node* node : key_placeholder_node->out_nodes()) {
2896     host_transfer_nodes.push_back(node);
2897   }
2898   for (Node* node : host_transfer_nodes) {
2899     int input_index = -1;
2900     for (int i = 0; i < node->num_inputs(); i++) {
2901       const Edge* e;
2902       TF_RETURN_IF_ERROR(node->input_edge(i, &e));
2903       if (e->src() == key_placeholder_node) {
2904         if (input_index != -1) {
2905           return errors::Internal(
2906               "Node ", node->name(),
2907               " has multiple input edges from key placeholder node");
2908         }
2909         input_index = e->dst_input();
2910       }
2911     }
2912     if (input_index == -1) {
2913       return errors::Internal("Node ", node->name(),
2914                               " has no input edge from key placeholder node");
2915     }
2916     const Edge* key_edge;
2917     TF_RETURN_IF_ERROR(node->input_edge(input_index, &key_edge));
2918     graph->RemoveEdge(key_edge);
2919     graph->AddEdge(compile_node, 1, node, input_index);
2920   }
2921   graph->RemoveNode(key_placeholder_node);
2922   return Status::OK();
2923 }
2924 
2925 Status DistributedTPURewritePass::BuildVariableReads(
2926     absl::Span<const VariableInput> variables, Node* control_predecessor,
2927     Graph* graph, std::vector<Node*>* variable_reads) {
2928   variable_reads->resize(variables.size());
2929   for (int i = 0; i < variables.size(); ++i) {
2930     string name =
2931         graph->NewName(strings::StrCat(variables[i].node->name(), "/read"));
2932     NodeDefBuilder builder(name, "ReadVariableOp",
2933                            NodeDebugInfo(*variables[i].node));
2934 
2935     builder.Attr("dtype", variables[i].dtype);
2936     builder.Device(variables[i].node->assigned_device_name());
2937     builder.Input(variables[i].node->name(), 0, DT_RESOURCE);
2938     NodeDef def;
2939     TF_RETURN_IF_ERROR(builder.Finalize(&def));
2940 
2941     Status status;
2942     Node* read_node;
2943     (*variable_reads)[i] = read_node = graph->AddNode(def, &status);
2944     if (!status.ok()) return status;
2945 
2946     read_node->set_requested_device(variables[i].node->requested_device());
2947     read_node->set_assigned_device_name(
2948         variables[i].node->assigned_device_name());
2949     graph->AddEdge(variables[i].node, variables[i].index, read_node, 0);
2950 
2951     graph->AddControlEdge(control_predecessor, read_node);
2952   }
2953   return Status::OK();
2954 }
2955 
2956 bool DistributedTPURewritePass::ContainsResourceWriteOp(
2957     const Graph& graph, const FunctionLibraryDefinition& fld) {
2958   for (const Node* n : graph.nodes()) {
2959     const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string());
2960     if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2961       VLOG(2) << "Found write resource op inside computation";
2962       return true;
2963     }
2964   }
2965   for (const string& func_name : fld.ListFunctionNames()) {
2966     const FunctionDef* func_def = fld.Find(func_name);
2967     for (const NodeDef& n : func_def->node_def()) {
2968       const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.op());
2969       if (op_info && op_info->kind() != XlaResourceOpKind::kRead) {
2970         VLOG(2) << "Found write resource op inside " << func_name;
2971         return true;
2972       }
2973     }
2974   }
2975   return false;
2976 }
2977 
2978 Status DistributedTPURewritePass::BuildVariableWrites(
2979     absl::Span<const VariableInput> variables, Node* control_successor,
2980     absl::Span<const VariableWrite> variable_writes, Graph* graph) {
2981   CHECK_EQ(variables.size(), variable_writes.size());
2982   for (int i = 0; i < variables.size(); ++i) {
2983     const VariableWrite& write = variable_writes[i];
2984     NodeDebugInfo debug_info(*variables[i].node);
2985 
2986     auto name = [&](string suffix) {
2987       return graph->NewName(
2988           strings::StrCat(variables[i].node->name(), "/", suffix));
2989     };
2990 
2991     Node* write_node;
2992     TF_RETURN_IF_ERROR(
2993         IncompleteNodeDefBuilder(name("assign"), "AssignVariableOp", debug_info)
2994             .AddAttr("dtype", variables[i].dtype)
2995             .Device(variables[i].node->assigned_device_name())
2996             .Build(graph, &write_node));
2997 
2998     // Colocate the control flow with the variable.
2999     CondBuilder cb(variables[i].node->name(),
3000                    variables[i].node->assigned_device_name(), debug_info,
3001                    graph);
3002 
3003     // Inputs to conditional.
3004     Node* switch_val;
3005     TF_RETURN_IF_ERROR(
3006         cb.AddInput("switch_val", variables[i].dtype,
3007                     /*device=*/write.value->assigned_device_name(), debug_info,
3008                     &switch_val));
3009     Node* switch_var;
3010     TF_RETURN_IF_ERROR(
3011         cb.AddInput("switch_var", DT_RESOURCE,
3012                     /*device=*/variables[i].node->assigned_device_name(),
3013                     debug_info, &switch_var));
3014     // Conditionally write the value back.
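    // (The generated subgraph is, roughly: write.predicate drives the
    // CondBuilder's Switch nodes; on the then-branch an AssignVariableOp
    // consumes the resource handle and the new value, and control_successor
    // is reached through the CondBuilder only after the write has run.)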
3015     graph->AddEdge(variables[i].node, variables[i].index, switch_var, 0);
3016     graph->AddEdge(switch_var, CondBuilder::kThenBranch, write_node, 0);
3017     graph->AddEdge(switch_val, CondBuilder::kThenBranch, write_node, 1);
3018     // Add a control edge from the write to the value that will be merged. There
3019     // is no output from the write, so this control edge ensures the write completes.
3020     graph->AddControlEdge(write_node, cb.switch_t());
3021 
3022     graph->AddControlEdge(cb.control_successor(), control_successor);
3023 
3024     graph->AddEdge(write.predicate, write.predicate_output, cb.pred(), 0);
3025     graph->AddEdge(write.value, write.value_output, switch_val, 0);
3026   }
3027   return Status::OK();
3028 }
3029 
3030 namespace {
3031 
3032 // Computes the shape of the sharded tensor and modifies it in place.
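// For illustration: a fully defined [8, 10] shape with
// tile_assignment_dimensions {2, 2} becomes [4, 5] per shard (CeilOfRatio
// along each tiled dimension); a trailing replicated dimension
// (replicate_on_last_tile_dim) is excluded from the division.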
3033 Status ComputeShardedArgShapes(TensorShape* shape,
3034                                const xla::OpSharding& sharding) {
3035   if (sharding.type() != xla::OpSharding::OTHER) {
3036     return Status::OK();
3037   }
3038   if (!shape->IsFullyDefined()) {
3039     return errors::Internal(
3040         "Arg shape must be fully defined before sharded shape inference.");
3041   }
3042   int sharded_rank = sharding.tile_assignment_dimensions_size();
3043   if (sharding.replicate_on_last_tile_dim()) {
3044     sharded_rank--;
3045   }
3046   for (int dim_idx = 0; dim_idx < sharded_rank; ++dim_idx) {
3047     auto sharded_dim = tensorflow::MathUtil::CeilOfRatio<int64>(
3048         shape->dim_size(dim_idx), sharding.tile_assignment_dimensions(dim_idx));
3049     shape->set_dim(dim_idx, sharded_dim);
3050   }
3051   if (sharded_rank != shape->dims()) {
3052     LOG(WARNING) << "Rank of sharded arg should match sharding spec.  Rank: "
3053                  << sharded_rank << ", tiled shape: " << shape->DebugString()
3054                  << ", sharding: " << sharding.DebugString();
3055   }
3056 
3057   return Status::OK();
3058 }
3059 
3060 // Creates nodes for zero-initialized dummy arguments for TPUExecute nodes.
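// The emitted subgraph is roughly: a Const holding var_shape as an int32
// vector and a scalar zero Const of `dtype`, both on `host_cpu_device`,
// feeding a Fill node, i.e. a host-side zeros(var_shape, dtype) whose output
// can stand in for a variable read on non-primary replicas.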
3061 xla::StatusOr<Node*> CreateTpuExecuteDummyArg(const TensorShape& var_shape,
3062                                               const DataType& dtype,
3063                                               const string& host_cpu_device,
3064                                               Node* var_read, int replica_id,
3065                                               Graph* graph) {
3066   Status status;
3067 
3068   // Const - shape_as_tensor
3069   const std::string name_prefix = strings::StrCat(
3070       var_read->name(), absl::StrFormat("/dummy_%d", replica_id));
3071   NodeDef shape_tensor_def;
3072   shape_tensor_def.set_op("Const");
3073   shape_tensor_def.set_name(graph->NewName(
3074       strings::StrCat(name_prefix, "/Initializer/zeros/shape_as_tensor")));
3075   shape_tensor_def.set_device(host_cpu_device);
3076   AddNodeAttr("dtype", DT_INT32, &shape_tensor_def);
3077   TensorProto tensorshape_proto;
3078   tensorshape_proto.set_dtype(DT_INT32);
3079   for (int i = 0; i < var_shape.dims(); ++i) {
3080     tensorshape_proto.add_int_val(var_shape.dim_size(i));
3081   }
3082   TensorShape shape_shape({var_shape.dims()});
3083   shape_shape.AsProto(tensorshape_proto.mutable_tensor_shape());
3084   AddNodeAttr("value", tensorshape_proto, &shape_tensor_def);
3085   Node* shape_as_tensor_node = graph->AddNode(shape_tensor_def, &status);
3086   TF_RETURN_IF_ERROR(status);
3087 
3088   // Const - initializer value
3089   NodeDef init_val_def;
3090   init_val_def.set_op("Const");
3091   init_val_def.set_name(graph->NewName(
3092       strings::StrCat(name_prefix, "/Initializer/zeros/const_val")));
3093   init_val_def.set_device(host_cpu_device);
3094   TensorProto tensor_proto;
3095   tensor_proto.set_dtype(dtype);
3096   if (dtype == DT_FLOAT) {
3097     tensor_proto.add_float_val(0.0f);
3098   } else if (dtype == DT_BFLOAT16) {
3099     tensor_proto.add_half_val(0);
3100   } else if (dtype == DT_INT32) {
3101     tensor_proto.add_int_val(0);
3102   } else if (dtype == DT_BOOL) {
3103     tensor_proto.add_bool_val(false);
3104   } else {
3105     return errors::Internal(
3106         "Unable to create zero-init dummy arg tensor for type ", dtype);
3107   }
3108   TensorShape scalar_shape({});
3109   scalar_shape.AsProto(tensor_proto.mutable_tensor_shape());
3110   AddNodeAttr("value", tensor_proto, &init_val_def);
3111   AddNodeAttr("dtype", dtype, &init_val_def);
3112   Node* init_val_node = graph->AddNode(init_val_def, &status);
3113   TF_RETURN_IF_ERROR(status);
3114 
3115   // Fill node
3116   NodeDef fill_def;
3117   fill_def.set_op("Fill");
3118   fill_def.set_device(host_cpu_device);
3119   fill_def.set_name(
3120       graph->NewName(strings::StrCat(name_prefix, "/Initializer/zeros")));
3121   AddNodeAttr("T", dtype, &fill_def);
3122   AddNodeAttr("index_type", DT_INT32, &fill_def);
3123   Node* fill_node = graph->AddNode(fill_def, &status);
3124   TF_RETURN_IF_ERROR(status);
3125   graph->AddEdge(shape_as_tensor_node, 0, fill_node, 0);
3126   graph->AddEdge(init_val_node, 0, fill_node, 1);
3127 
3128   return fill_node;
3129 }
3130 
3131 // Creates dummy inputs for partitioned variables that use XLA broadcast
3132 // for inputs.
3133 Status CreatePartitionedDummyVarArgs(
3134     const xla::OpSharding& sharding, const int num_replicas,
3135     const int replica_id, const InferredShape& raw_shape, Node* orig_var_read,
3136     const int orig_arg_num, DataType dtype, const string& device, Graph* graph,
3137     const std::vector<std::vector<string>>& tpu_device_names,
3138     absl::btree_map<ShardedPerHostInputIndex, Node*>* per_host_index,
3139     std::map<ShardedInputIndex, ShardedInputInfo>*
3140         arg_index_to_sharded_input_map) {
3141   ShardedInputIndex input_index{replica_id, orig_arg_num};
3142   auto iter = arg_index_to_sharded_input_map->find(input_index);
3143   if (iter != arg_index_to_sharded_input_map->end()) {
3144     return Status::OK();
3145   }
3146   const int repeat = sharding.replicate_on_last_tile_dim()
3147                          ? *sharding.tile_assignment_dimensions().rbegin()
3148                          : 1;
3149   const int num_shards = sharding.tile_assignment_devices_size() / repeat;
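  // For illustration: with 8 entries in tile_assignment_devices and
  // replicate_on_last_tile_dim set with a last tile dimension of 2,
  // repeat == 2 and num_shards == 4, i.e. each of the 4 shards is placed on
  // 2 cores.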
3150 
3151   TensorShape var_shape;
3152   if (!raw_shape.handle_shape.AsTensorShape(&var_shape) &&
3153       !raw_shape.shape.AsTensorShape(&var_shape)) {
3154     return errors::FailedPrecondition("Failed to read arg shape.");
3155   }
3156   TF_RETURN_IF_ERROR(ComputeShardedArgShapes(&var_shape, sharding));
3157 
3158   for (int replica = 1; replica < num_replicas; ++replica) {
3159     std::vector<NodeOut> sharded_inputs_list(
3160         sharding.tile_assignment_devices_size());
3161     for (int i = 0; i < num_shards; ++i) {
3162       for (int j = 0; j < repeat; ++j) {
3163         const int index = i * repeat + j;
3164         const int core = sharding.tile_assignment_devices(index);
3165         string host_device;
3166         TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
3167             tpu_device_names[replica][core], &host_device));
3168         ShardedPerHostInputIndex idx{host_device, orig_arg_num};
3169         if (!per_host_index->contains(idx)) {
3170           TF_ASSIGN_OR_RETURN(
3171               auto dummy_node,
3172               CreateTpuExecuteDummyArg(var_shape, dtype, host_device,
3173                                        orig_var_read, replica, graph));
3174           (*per_host_index)[idx] = dummy_node;
3175         }
3176         sharded_inputs_list[core] = {(*per_host_index)[idx], /*index=*/0};
3177       }
3178     }
3179     ShardedInputInfo sharded_input_info{nullptr,
3180                                         std::move(sharded_inputs_list)};
3181     (*arg_index_to_sharded_input_map)[{replica, orig_arg_num}] =
3182         sharded_input_info;
3183   }
3184 
3185   return Status::OK();
3186 }
3187 
3188 // Helper that creates an IdentityN node containing all of the variables'
3189 // values on CPU device 'device', except for those that will be split across
3190 // cores. (For split variables, this may cause additional cross-host data
3191 // transfers if more than one device shares the same variable partition on a
3192 // remote host.)
3193 //
3194 // A previous iteration of this code built one Identity node per TPU core per
3195 // variable, but this can rapidly become hundreds of thousands of nodes. This
3196 // formulation creates a single IdentityN node containing all of the variables
3197 // on each host. This may cause some unnecessary variable copies if only a
3198 // subset of hosts consume a given variable, but has the virtue of being
3199 // simple, and most models use pure replication where all cores want all the
3200 // variables.
3201 //
3202 // If enable_xla_param_broadcast is set to true, then per-host dummy
3203 // tensor args are created on all hosts except for the primary host. In this
3204 // scheme, the dummy args feed the IdentityN node on their local host. All
3205 // are zero-initialized.
3206 //
3207 // Returns the node and its output index to be consumed by TPUExecute for the
3208 // requested variable index.
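// For illustration: if the (hypothetical) host /job:worker/task:0/device:CPU:0
// feeds variables v0 and v1, a single IdentityN is created on that host and
// every TPUExecute on that host reads v0 and v1 from outputs 0 and 1 of that
// IdentityN rather than from per-core copies.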
3209 xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy(
3210     const string& host_cpu_device, int64_t var_index,
3211     const std::vector<Node*>& variable_reads,
3212     const DistributedTPURewritePass::ParameterInfo& params_info,
3213     const std::vector<xla::OpSharding>& arg_shardings,
3214     const Node& replicate_node, const bool enable_xla_param_broadcast,
3215     const int num_cores_per_replica, int replica_id,
3216     const std::vector<InferredShape>& arg_shapes,
3217     absl::flat_hash_map<string, std::vector<NodeOut>>* per_host_var_copies,
3218     Graph* graph) {
3219   auto it = per_host_var_copies->find(host_cpu_device);
3220   if (it != per_host_var_copies->end()) {
3221     return it->second[var_index];
3222   }
3223 
3224   DataTypeVector dtypes;
3225   // Per-variable data source for TPUExecute.
3226   std::vector<NodeOut> index_mapping;
3227   index_mapping.reserve(variable_reads.size());
3228   dtypes.reserve(variable_reads.size());
3229   for (int64_t i = 0; i < variable_reads.size(); ++i) {
3230     Node* read = variable_reads[i];
3231     int64_t orig_arg_num = i + params_info.NumPerReplicaArgs() +
3232                            params_info.NumDistributedArgs() +
3233                            params_info.NumBroadcastArgs();
3234     if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) {
3235       // We haven't built the IdentityN node yet, so temporarily use nullptr.
3236       index_mapping.push_back(
3237           NodeOut{nullptr, static_cast<int>(dtypes.size())});
3238       dtypes.push_back(read->output_type(0));
3239     } else {
3240       // Do not copy the full tensor of partitioned variables.
3241       index_mapping.push_back(NodeOut{read, 0});
3242     }
3243   }
3244   NodeDef ndef;
3245   ndef.set_name(graph->NewName(
3246       absl::StrCat(replicate_node.name(), "/", kTpuExecuteStagingNodeName)));
3247   ndef.set_op(kTpuExecuteStagingOp);
3248   ndef.set_device(host_cpu_device);
3249   AddNodeAttr("T", dtypes, &ndef);
3250   // TF meta-optimizer should skip this node for constant folding.
3251   AddNodeAttr("_tpu_avoid_constant_fold", "not_used", &ndef);
3252   Status s;
3253   Node* id_node = graph->AddNode(ndef, &s);
3254   TF_RETURN_IF_ERROR(s);
3255   id_node->set_assigned_device_name(host_cpu_device);
3256 
3257   for (int64_t i = 0; i < variable_reads.size(); ++i) {
3258     Node* read = variable_reads[i];
3259     int64_t orig_arg_num = i + params_info.NumPerReplicaArgs() +
3260                            params_info.NumDistributedArgs() +
3261                            params_info.NumBroadcastArgs();
3262     DataType dtype = read->output_type(0);
3263     bool use_xla_broadcast =
3264         EnableXlaParamBroadcast(enable_xla_param_broadcast, params_info,
3265                                 orig_arg_num, dtype) &&
3266         replica_id != 0;
3267     if (index_mapping[i].node == nullptr) {
3268       // Fill index_mapping with the actual IdentityN node.
3269       index_mapping[i].node = id_node;
3270       if (!use_xla_broadcast) {
3271         // Add the variable read edge to id_node.
3272         graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index);
3273       } else {
3274         // XLA param broadcast mode is enabled.  Create zero-valued dummy
3275         // tensors to use as variable args in the TPUExecuteOp, instead of
3276         // original variable reads.
3277         TensorShape var_shape;
3278         auto inferred_shape = arg_shapes[orig_arg_num];
3279         if (!inferred_shape.handle_shape.AsTensorShape(&var_shape) &&
3280             !inferred_shape.shape.AsTensorShape(&var_shape)) {
3281           return errors::FailedPrecondition("Failed to read arg shape.");
3282         }
3283         TF_ASSIGN_OR_RETURN(
3284             Node * dummy_read,
3285             CreateTpuExecuteDummyArg(var_shape, dtype, host_cpu_device,
3286                                      variable_reads[i], replica_id, graph));
3287         graph->AddEdge(dummy_read, 0, id_node, index_mapping[i].index);
3288       }
3289     }
3290   }
3291 
3292   auto result = index_mapping[var_index];
3293   (*per_host_var_copies)[host_cpu_device] = std::move(index_mapping);
3294   return result;
3295 }
3296 
3297 }  // namespace
3298 
3299 Status DistributedTPURewritePass::BuildExecuteNodes(
3300     const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica,
3301     const Node& replicate_node, const std::vector<std::string>& arg_names,
3302     const DataTypeVector& arg_types,
3303     const std::vector<InferredShape>& arg_shapes,
3304     const DataTypeVector& retval_types,
3305     const std::vector<xla::OpSharding>& arg_shardings,
3306     const std::vector<xla::OpSharding>& retval_shardings,
3307     const std::vector<std::vector<string>>& tpu_device_names,
3308     Node* compile_node, const std::vector<Node*>& variable_reads,
3309     Node* control_predecessor, Node* control_successor, Node* multilock_acquire,
3310     std::vector<VariableWrite>* variable_writes, Graph* graph) {
3311   VLOG(1) << "BuildExecuteNodes " << replicate_node.DebugString();
3312   TF_RET_CHECK(params_info.NumReplicas() == tpu_device_names.size());
3313 
3314   const int num_variables = variable_reads.size();
3315   const int num_retvals_per_replica = retval_types.size();
3316 
3317   variable_writes->resize(num_variables);
3318 
3319   std::vector<const Edge*> replicate_input_edges;
3320   TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges));
3321 
3322   // Map from replicate input index to the fan-in nodes.
3323   absl::flat_hash_map<int, std::vector<NodeAndPort>>
3324       replicate_input_fan_in_nodes;
3325   absl::flat_hash_map<int, std::vector<Node*>> replicate_output_fan_out_nodes;
3326   absl::flat_hash_map<int, std::vector<int>>
3327       replicate_output_fan_out_dst_inputs;
3328   std::vector<Node*> to_be_removed_nodes;
3329 
3330   for (const Edge* e : replicate_input_edges) {
3331     if (e->src()->type_string() == kTPUPartitionedInput) {
3332       int num_users = 0;
3333       for (const auto& ue : e->src()->out_edges()) {
3334         if (!ue->IsControlEdge()) ++num_users;
3335       }
3336       if (num_users != 1) {
3337         return tensorflow::errors::InvalidArgument(
3338             e->src()->name(), " must only have one user. Found ", num_users);
3339       }
3340       to_be_removed_nodes.push_back(e->src());
3341       std::vector<NodeAndPort>& nodes =
3342           replicate_input_fan_in_nodes[e->dst_input()];
3343       nodes.resize(num_cores_per_replica, NodeAndPort(nullptr, 0));
3344       VLOG(2) << "allocate " << num_cores_per_replica
3345               << " for replicate_input_fan_in_nodes[" << e->dst_input() << "]";
3346       std::vector<const Edge*> fan_in_edges;
3347       TF_RETURN_IF_ERROR(e->src()->input_edges(&fan_in_edges));
3348       TF_RET_CHECK(fan_in_edges.size() == num_cores_per_replica);
3349 
3350       for (const Edge* fe : fan_in_edges) {
3351         nodes[fe->dst_input()].node = fe->src();
3352         nodes[fe->dst_input()].port = fe->src_output();
3353         VLOG(2) << "replicate_input_fan_in_nodes[" << e->dst_input() << "]["
3354                 << fe->dst_input() << "] = " << fe->src()->name();
3355       }
3356     }
3357   }
3358 
3359   // Replicate output edges are sorted by replica id and then by outputs for
3360   // each replica. For example, if TPU Computation has outputs (output_1,
3361   // output_2, and output_3) and number of replicas is 2, then
3362   // replicate_output_edges order would be:
3363   // output_1_replica_1, output_2_replica_1, output_3_replica_1,
3364   // output_1_replica_2, output_2_replica_2, output_3_replica_2.
3365   std::vector<const Edge*> replicate_output_edges(replicate_node.num_outputs(),
3366                                                   nullptr);
3367   for (const Edge* edge : replicate_node.out_edges()) {
3368     if (edge->IsControlEdge()) continue;
3369 
3370     int num_partitioned_outputs = 0;
3371 
3372     for (const Edge* out_edge : edge->dst()->out_edges()) {
3373       if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
3374         num_partitioned_outputs++;
3375         // Paths between replicate_node and replicate_output_fan_out_nodes:
3376         // ReplicateNode->TpuOutIdentity->kTPUPartitionedOutput->fan-out-nodes
3377         TF_RET_CHECK(edge->dst()->out_edges().size() == 1);
3378         to_be_removed_nodes.push_back(edge->dst());
3379         to_be_removed_nodes.push_back(out_edge->dst());
3380         // Get the right replicated id from the replicate_output_edge.
3381         std::vector<Node*>& nodes =
3382             replicate_output_fan_out_nodes[edge->src_output()];
3383         std::vector<int>& dst_inputs =
3384             replicate_output_fan_out_dst_inputs[edge->src_output()];
3385         nodes.resize(num_cores_per_replica, nullptr);
3386         dst_inputs.resize(num_cores_per_replica, 0);
3387         TF_RET_CHECK(out_edge->dst()->out_edges().size() ==
3388                      num_cores_per_replica);
3389 
3390         for (const Edge* fe : out_edge->dst()->out_edges()) {
3391           nodes[fe->src_output()] = fe->dst();
3392           dst_inputs[fe->src_output()] = fe->dst_input();
3393           VLOG(2) << "replicate_output_fan_out_nodes[" << out_edge->src_output()
3394                   << "][" << fe->src_output()
3395                   << "] = " << fe->dst()->DebugString() << " with dst_input "
3396                   << fe->dst_input();
3397         }
3398       }
3399     }
3400     replicate_output_edges[edge->src_output()] = edge;
3401     if (num_partitioned_outputs > 1) {
3402       return errors::InvalidArgument(
3403           "More than one TPUPartitionedOutput per replicated output.");
3404     }
3405   }
3406 
3407   const int num_execute_args =
3408       arg_shardings.size() - params_info.NumGuaranteedConstants();
3409   // Inverts the arg_shardings and retval_shardings mappings to
3410   // form core -> {argument number} maps.
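     // For example, with two cores per replica, an arg with MAXIMAL sharding on
     // core 1 appears only in core_arg_nums[1], while a REPLICATED arg appears
     // in both core_arg_nums[0] and core_arg_nums[1].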
3411   std::vector<std::vector<int>> core_arg_nums(num_cores_per_replica);
3412   for (int i = 0; i < num_execute_args; ++i) {
3413     const auto& sharding = arg_shardings[i];
3414     if (sharding.type() == xla::OpSharding::MAXIMAL) {
3415       int core = sharding.tile_assignment_devices(0);
3416       TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3417       core_arg_nums[core].push_back(i);
3418     } else if (sharding.type() == xla::OpSharding::OTHER) {
3419       for (int64_t core : sharding.tile_assignment_devices()) {
3420         core_arg_nums[core].push_back(i);
3421       }
3422     } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3423       for (int core = 0; core < num_cores_per_replica; ++core) {
3424         core_arg_nums[core].push_back(i);
3425       }
3426     } else {
3427       return tensorflow::errors::InvalidArgument(
3428           "Unsupported argument sharding for arg=", arg_names[i],
3429           " shape=", arg_shapes[i].shape.DebugString(), ": ",
3430           sharding.DebugString());
3431     }
3432   }
3433   std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica);
3434   for (int i = 0; i < retval_shardings.size(); ++i) {
3435     const auto& sharding = retval_shardings[i];
3436     if (sharding.type() == xla::OpSharding::MAXIMAL) {
3437       int core = sharding.tile_assignment_devices(0);
3438       TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica));
3439       core_retval_nums[core].push_back(i);
3440     } else if (sharding.type() == xla::OpSharding::REPLICATED) {
3441       for (int core = 0; core < num_cores_per_replica; ++core) {
3442         core_retval_nums[core].push_back(i);
3443       }
3444     } else if (sharding.type() == xla::OpSharding::OTHER) {
3445       for (int64_t core : sharding.tile_assignment_devices()) {
3446         core_retval_nums[core].push_back(i);
3447       }
3448     } else {
3449       return tensorflow::errors::InvalidArgument(
3450           "Unsupported argument sharding: ", sharding.DebugString());
3451     }
3452   }
3453 
3454   // Maps host device name to a list of per-variable pairs (variable_copy_node,
3455   // output_index_of_copy_node).
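     // Caching one copy per host lets every core on that host read the staged
     // value instead of pulling it across the network separately.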
3456   absl::flat_hash_map<string, std::vector<NodeOut>> per_host_var_copies;
3457 
3458   Node* execute_successor = control_successor;
3459 
3460   int num_total_cores = params_info.NumReplicas() * num_cores_per_replica;
3461   if (enable_multicore_locking_ && num_total_cores > 1) {
3462     // Add a node to release exclusive access once all the cores have finished
3463     // execution.
3464     NodeDef lock_def;
3465     lock_def.set_name(graph->NewName(
3466         strings::StrCat(compile_node->name(), "/", "tpu_release_multilock")));
3467     lock_def.set_op("ConsumeTpuMultilock");
3468     MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &lock_def);
3469     Status status;
3470     Node* multilock_release = graph->AddNode(lock_def, &status);
3471     TF_RETURN_IF_ERROR(status);
3472     multilock_release->set_assigned_device_name(
3473         compile_node->assigned_device_name());
3474     TF_RET_CHECK(multilock_acquire != nullptr);
3475     graph->AddEdge(multilock_acquire, 0, multilock_release, 0);
3476     graph->AddControlEdge(multilock_release, control_successor);
3477     // Make sure all execute Ops happen before the multilock_release.
3478     execute_successor = multilock_release;
3479   }
3480 
3481   // Maps an original resource arg number to a second-level map from core id
3482   // to the output index of the updated variable value.
3483   absl::flat_hash_map<int, absl::flat_hash_map<int, int>>
3484       orig_arg_num_to_output_index_mapping;
3485   // Maps a retval index to a second-level map from core id to the output
3486   // index of the sharded output value.
3487   std::unordered_map<int, std::unordered_map<int, int>>
3488       retval_index_to_output_index_mapping;
3489 
3490   // Maps the argument index of each sharded input to a TPUExecute node to the
3491   // corresponding Split node and the output index from which the sharded
3492   // input is fed into that TPUExecute node.
3493   std::map<ShardedInputIndex, ShardedInputInfo> input_index_to_sharded_inputs;
3494 
3495   // Additional map of {host, arg_num} to dummy input. Per-task copies of the
3496   // inputs reduce cross-task communication and allow sharing across replicas.
3497   absl::btree_map<ShardedPerHostInputIndex, Node*> sharded_per_host_index;
3498 
3499   // Builds one TPUExecute node per core per replica.
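     // execute_nodes[replica][core] will hold the TPUExecute node for that
     // (replica, core) pair; it is consulted later when concatenating tiled
     // outputs and variable updates.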
3500   std::vector<std::vector<Node*>> execute_nodes(params_info.NumReplicas());
3501   for (int core = 0; core < num_cores_per_replica; ++core) {
3502     DataTypeVector core_retval_types;
3503     for (int output : core_retval_nums[core]) {
3504       core_retval_types.push_back(retval_types[output]);
3505     }
3506     DataTypeVector core_arg_types;
3507     std::vector<int> core_variable_writes;
3508     for (int input : core_arg_nums[core]) {
3509       // Resource variables can be passed either by reference (as a DT_RESOURCE
3510       // tensor) or by value (as the variable's current value). Per-replica or
3511       // distributed resource arguments are always passed by reference and
3512       // broadcast variables are always passed by value.
3513       if (arg_types[input] == DT_RESOURCE &&
3514           !params_info.IsPerReplicaArg(input) &&
3515           !params_info.IsDistributedArg(input)) {
3516         DataType handle_type = arg_shapes[input].handle_type;
3517         TF_RET_CHECK(handle_type != DT_INVALID) << DataTypeString(handle_type);
3518         core_arg_types.push_back(handle_type);
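             // 'base' is this variable's index among the variable arguments, i.e.
             // its position after the per-replica, distributed, and broadcast args.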
3519         int base = input - params_info.NumPerReplicaArgs() -
3520                    params_info.NumDistributedArgs() -
3521                    params_info.NumBroadcastArgs();
3522         // Variables passed by value will have a corresponding additional output
3523         // containing an updated value for the variable.
3524         core_variable_writes.push_back(base);
3525         core_retval_types.push_back(handle_type);
3526       } else {
3527         core_arg_types.push_back(arg_types[input]);
3528       }
3529     }
3530 
3531     NodeDef def;
3532     def.set_op("TPUExecute");
3533     MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def);
3534     AddNodeAttr("Targs", core_arg_types, &def);
3535     AddNodeAttr("Tresults", core_retval_types, &def);
3536 
3537     for (int64_t replica = 0; replica < params_info.NumReplicas(); ++replica) {
3538       def.set_name(strings::StrCat(replicate_node.name(), "/_execute_", replica,
3539                                    "_", core));
3540 
3541       Status status;
3542       Node* node = graph->AddNode(def, &status);
3543       if (!status.ok()) return status;
3544       execute_nodes[replica].push_back(node);
3545 
3546       node->set_assigned_device_name(tpu_device_names[replica][core]);
3547 
3548       // Add control edges to ensure that execution happens after
3549       // `control_predecessor`, happens before `execute_successor`, and is
3550       // triggered by evaluating any operator that depends on the original
3551       // TPUReplicate operator. See the comment at the top of the header file
3552       // for more details.
3553       graph->AddControlEdge(control_predecessor, node);
3554       graph->AddControlEdge(node, execute_successor);
3555 
3556       // Add data input edges.
3557       for (int64_t i = 0; i < core_arg_nums[core].size(); ++i) {
3558         int64_t orig_arg_num = core_arg_nums[core][i];
3559         VLOG(2) << " replica " << replica << " core " << core << " i " << i
3560                 << " orig_arg_num " << orig_arg_num;
3561         const bool is_per_replica_arg =
3562             params_info.IsPerReplicaArg(orig_arg_num);
3563         if (is_per_replica_arg || params_info.IsDistributedArg(orig_arg_num)) {
3564           // Per-replica or distributed input.
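               // Per-replica inputs are laid out replica-major in
               // replicate_input_edges ([replica0 args..., replica1 args...]);
               // the distributed args follow once, after all per-replica args.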
3565           const int64_t input_num =
3566               is_per_replica_arg ? replica * params_info.NumPerReplicaArgs() +
3567                                        core_arg_nums[core][i]
3568                                  : params_info.NumReplicas() *
3569                                            params_info.NumPerReplicaArgs() +
3570                                        core_arg_nums[core][i] -
3571                                        params_info.NumPerReplicaArgs();
3572 
3573           const Edge* edge = replicate_input_edges[input_num];
3574           VLOG(2) << "replicate_input_edges[" << input_num << "]";
3575           DataType dtype = edge->src()->output_type(edge->src_output());
3576           if (dtype == DT_RESOURCE) {
3577             DataType handle_dtype = arg_shapes[orig_arg_num].handle_type;
3578             if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(),
3579                           handle_dtype) == kTpuAllTypes.end()) {
3580               return errors::InvalidArgument(
3581                   "Unsupported resource variable data type for TPU: ",
3582                   DataTypeString(handle_dtype), ", caused by output ",
3583                   edge->src()->name(), ":", edge->src_output());
3584             }
3585           } else {
3586             if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3587                 kTpuAllTypes.end()) {
3588               return errors::InvalidArgument(
3589                   "Unsupported data type for TPU: ", DataTypeString(dtype),
3590                   ", caused by output ", edge->src()->name(), ":",
3591                   edge->src_output());
3592             }
3593           }
3594           if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3595             // Don't automatically add a split node when the input node is
3596             // kTPUPartitionedInput.
3597             if (edge->src()->type_string() == kTPUPartitionedInput) {
3598               VLOG(2)
3599                   << "Connect "
3600                   << replicate_input_fan_in_nodes[input_num][core].node->name()
3601                   << " to " << node->name() << " at " << i;
3602               graph->AddEdge(replicate_input_fan_in_nodes[input_num][core].node,
3603                              replicate_input_fan_in_nodes[input_num][core].port,
3604                              node, i);
3605             } else {
3606               if (dtype == DT_RESOURCE) {
3607                 return errors::InvalidArgument(
3608                     "Tiled sharding for per-replica DT_RESOURCE input must",
3609                     "be TPUPartitionedInput. Here got ",
3610                     edge->src()->type_string());
3611               }
3612               const xla::OpSharding& sharding = arg_shardings[orig_arg_num];
3613 
3614               ShardedInputInfo sharded_input_info;
3615               if (use_nd_sharding_ops_ && is_per_replica_arg) {
3616                 TF_ASSIGN_OR_RETURN(
3617                     sharded_input_info,
3618                     CreateOrGetXlaSplitNodeForShardedPerReplicaArg(
3619                         sharding, replica, orig_arg_num, dtype,
3620                         PartialTensorShape(), edge->src(), edge->src_output(),
3621                         graph, &input_index_to_sharded_inputs));
3622               } else if (use_nd_sharding_ops_) {
3623                 TF_ASSIGN_OR_RETURN(
3624                     sharded_input_info,
3625                     CreateOrGetXlaSplitNodeForDistributedArg(
3626                         sharding, params_info.NumReplicas(), replica,
3627                         orig_arg_num, dtype, PartialTensorShape(), edge->src(),
3628                         edge->src_output(), graph,
3629                         &input_index_to_sharded_inputs));
3630               } else {
3631                 TF_ASSIGN_OR_RETURN(
3632                     sharded_input_info,
3633                     CreateOrGetSplitNodesForInputSharding(
3634                         sharding, orig_arg_num, dtype, PartialTensorShape(),
3635                         replica, edge->src_output(), edge->src(),
3636                         control_predecessor, graph,
3637                         &input_index_to_sharded_inputs));
3638               }
3639 
3640               NodeOut split_node_and_index =
3641                   sharded_input_info.sharded_inputs.at(core);
3642               // Connect with Split node output.
3643               graph->AddEdge(split_node_and_index.node,
3644                              split_node_and_index.index, node, i);
3645             }
3646           } else if (edge->src()->type_string() == kTPUPartitionedInput &&
3647                      arg_shardings[orig_arg_num].type() ==
3648                          xla::OpSharding::REPLICATED) {
3649             graph->AddEdge(replicate_input_fan_in_nodes[input_num][core].node,
3650                            replicate_input_fan_in_nodes[input_num][core].port,
3651                            node, i);
3652           } else {
3653             graph->AddEdge(edge->src(), edge->src_output(), node, i);
3654           }
3655         } else if (params_info.IsBroadcastArg(orig_arg_num)) {
3656           // Broadcast input.
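               // Broadcast inputs appear once and are shared by all replicas, so
               // the same source edge feeds every TPUExecute node.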
3657           int64_t input_num = params_info.FirstBroadcastArgFromHost() +
3658                               core_arg_nums[core][i] -
3659                               params_info.NumPerReplicaArgs() -
3660                               params_info.NumDistributedArgs();
3661           const Edge* edge = replicate_input_edges[input_num];
3662           DataType dtype = edge->src()->output_type(edge->src_output());
3663           if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3664               kTpuAllTypes.end()) {
3665             return errors::InvalidArgument(
3666                 "Unsupported data type for TPU: ", DataTypeString(dtype),
3667                 ", caused by output ", edge->src()->name(), ":",
3668                 edge->src_output());
3669           }
3670           graph->AddEdge(edge->src(), edge->src_output(), node, i);
3671         } else {
3672           // Variable input.
3673           int64_t variable_num =
3674               orig_arg_num - params_info.NumPerReplicaArgs() -
3675               params_info.NumDistributedArgs() - params_info.NumBroadcastArgs();
3676           TF_RET_CHECK(variable_num < num_variables);
3677 
3678           Node* variable_read = variable_reads[variable_num];
3679           DataType dtype = variable_read->output_type(0);
3680           if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) ==
3681               kTpuAllTypes.end()) {
3682             return errors::InvalidArgument(
3683                 "Unsupported resource variable data type for TPU: ",
3684                 DataTypeString(dtype), ", caused by ReadVariableOp ",
3685                 variable_read->DebugString());
3686           }
3687           DeviceNameUtils::ParsedName requested_device;
3688           string requested = variable_read->requested_device();
3689           TF_RET_CHECK(
3690               DeviceNameUtils::ParseFullName(requested, &requested_device));
3691           if (requested_device.type != "TPU") {
3692             // Stage the value via the CPU device on the remote host. The graph
3693             // partitioner will introduce an intermediate copy rather than
3694             // copying the same tensor multiple times across the network, and we
3695             // would prefer that intermediate copy to be in host memory to avoid
3696             // running out of memory if the TPUExecute op on the staging device
3697             // starts running before the _Send ops to the other TPU devices on
3698             // the same host complete. We don't do this if the variables are
3699             // already placed on TPU, otherwise it will cause an unnecessary
3700             // round trip copy.
3701             // TODO(b/79580121): give each replica its own on-device copy of the
3702             // variable and then delete this code.
3703             string device;
3704             TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
3705                 tpu_device_names[replica][core], &device));
3706             TF_ASSIGN_OR_RETURN(
3707                 auto var_data,
3708                 CreateOrGetPerHostVariableCopy(
3709                     device, variable_num, variable_reads, params_info,
3710                     arg_shardings, replicate_node, enable_xla_param_broadcast_,
3711                     num_cores_per_replica, replica, arg_shapes,
3712                     &per_host_var_copies, graph));
3713 
3714             if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) {
3715               ShardedInputInfo sharded_input_info;
3716 
3717               if (EnableXlaParamBroadcast(enable_xla_param_broadcast_,
3718                                           params_info, orig_arg_num, dtype)) {
3719                 // Populates the sharded dummy vars for non-zero replicas.
3720                 TF_RETURN_IF_ERROR(CreatePartitionedDummyVarArgs(
3721                     arg_shardings[orig_arg_num], params_info.NumReplicas(),
3722                     replica, arg_shapes[orig_arg_num], var_data.node,
3723                     orig_arg_num, dtype, device, graph, tpu_device_names,
3724                     &sharded_per_host_index, &input_index_to_sharded_inputs));
3725               }
3726 
3727               if (use_nd_sharding_ops_) {
3728                 TF_ASSIGN_OR_RETURN(
3729                     sharded_input_info,
3730                     CreateOrGetXlaSplitNodeForVariableArg(
3731                         arg_shardings[orig_arg_num], params_info.NumReplicas(),
3732                         replica, orig_arg_num,
3733                         arg_shapes[orig_arg_num].handle_type,
3734                         arg_shapes[orig_arg_num].handle_shape, var_data.node,
3735                         var_data.index, graph, &to_be_removed_nodes,
3736                         &input_index_to_sharded_inputs));
3737               } else {
3738                 TF_ASSIGN_OR_RETURN(
3739                     sharded_input_info,
3740                     CreateOrGetSplitNodesForInputSharding(
3741                         arg_shardings[orig_arg_num], orig_arg_num,
3742                         arg_shapes[orig_arg_num].handle_type,
3743                         arg_shapes[orig_arg_num].handle_shape, replica,
3744                         var_data.index, var_data.node, control_predecessor,
3745                         graph, &input_index_to_sharded_inputs));
3746               }
3747 
3748               NodeOut split_node_and_index =
3749                   sharded_input_info.sharded_inputs[core];
3750               // Connect with Split node output.
3751               graph->AddEdge(split_node_and_index.node,
3752                              split_node_and_index.index, node, i);
3753 
3754             } else {
3755               graph->AddEdge(var_data.node, var_data.index, node, i);
3756             }
3757           } else {
3758             graph->AddEdge(variable_reads[variable_num], 0, node, i);
3759           }
3760         }
3761       }
3762 
3763       // Adds a program input edge from the compile node.
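           // Output 0 of the compile node is the compilation status; output
           // core + 1 carries the compiled program for this logical core.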
3764       graph->AddEdge(compile_node, core + 1, node, node->num_inputs() - 1);
3765 
3766       // Add data output edges.
3767       int num_outputs = core_retval_nums[core].size();
3768       for (int i = 0; i < num_outputs; ++i) {
3769         int output_num =
3770             replica * num_retvals_per_replica + core_retval_nums[core][i];
3771         const auto& sharding = retval_shardings[core_retval_nums[core][i]];
3772         if (sharding.type() == xla::OpSharding::OTHER) {
3773           int retval_index = core_retval_nums[core][i];
3774           retval_index_to_output_index_mapping[retval_index][core] = i;
3775           bool is_last_core =
3776               core ==
3777               *std::max_element(sharding.tile_assignment_devices().begin(),
3778                                 sharding.tile_assignment_devices().end());
3779           bool is_partitioned_output_node = false;
3780 
3781           const Edge* e = replicate_output_edges[output_num];
3782           const Edge* e_out = nullptr;
3783           for (const Edge* out_edge : e->dst()->out_edges()) {
3784             if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
3785               is_partitioned_output_node = true;
3786               e_out = out_edge;
3787             }
3788           }
3789           if (is_partitioned_output_node) {
3790             graph->AddEdge(
3791                 node, i, replicate_output_fan_out_nodes[output_num][core],
3792                 replicate_output_fan_out_dst_inputs[output_num][core]);
3793             VLOG(2) << "Connect " << node->name() << " at " << i << " to "
3794                     << replicate_output_fan_out_nodes[output_num][core]->name()
3795                     << " at "
3796                     << replicate_output_fan_out_dst_inputs[output_num][core];
3797             if (is_last_core) {
3798               graph->RemoveEdge(e);
3799               graph->RemoveEdge(e_out);
3800             }
3801             continue;
3802           }
3803 
3804           // Do this in the iteration of the last core in the tile assignment, so
3805           // that all TPUExecute nodes have already been created.
3806           if (!is_last_core) {
3807             continue;
3808           }
3809 
3810           // Add a Concat node.
3811           std::vector<NodeOut> orig_inputs;
3812           for (int64_t tile_index = 0;
3813                tile_index < sharding.tile_assignment_devices_size();
3814                ++tile_index) {
3815             int64_t last_tile_dim_size =
3816                 *sharding.tile_assignment_dimensions().rbegin();
3817             if (sharding.replicate_on_last_tile_dim() &&
3818                 tile_index % last_tile_dim_size != 0) {
3819               continue;
3820             }
3821             int64_t core_id = sharding.tile_assignment_devices(tile_index);
3822             int core_retval_index =
3823                 retval_index_to_output_index_mapping[retval_index][core_id];
3824             orig_inputs.push_back(
3825                 NodeOut{execute_nodes[replica][core_id],
3826                         static_cast<int>(
3827                             core_retval_nums[core_id][core_retval_index])});
3828           }
3829           DataType dtype = e->src()->output_type(e->src_output());
3830           Node* concat_node = nullptr;
3831           if (use_nd_sharding_ops_) {
3832             TF_ASSIGN_OR_RETURN(
3833                 concat_node, CreateXlaConcatNode(
3834                                  sharding, replica, dtype,
3835                                  /*partial_tensor_shape=*/PartialTensorShape(),
3836                                  orig_inputs, /*device=*/"", graph));
3837           } else {
3838             TF_ASSIGN_OR_RETURN(
3839                 concat_node,
3840                 CreateConcatNodesForRetval(
3841                     sharding, dtype, /*inferred_shape=*/PartialTensorShape(),
3842                     replica, orig_inputs, graph, /*device=*/""));
3843           }
3844 
3845           const Edge* edge = replicate_output_edges[output_num];
3846           Node* dst = edge->dst();
3847           int dst_input = edge->dst_input();
3848           graph->RemoveEdge(edge);
3849           graph->AddEdge(concat_node, 0, dst, dst_input);
3850 
3851           continue;
3852         }
3853 
3854         // If this is a replicated output, outputs on all cores will be the
3855         // same, and we only take the output from core 0.
3856         if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3857           continue;
3858         }
3859 
3860         // If output has maximal sharding, make sure we only use output from
3861         // TPUExecute node with logical core id equal to core id defined by the
3862         // xla sharding.
3863         if (sharding.type() == xla::OpSharding::MAXIMAL &&
3864             core != sharding.tile_assignment_devices(0)) {
3865           continue;
3866         }
3867 
3868         const Edge* replicate_edge_to_replace =
3869             replicate_output_edges[output_num];
3870         Node* dst = replicate_edge_to_replace->dst();
3871         int dst_input = replicate_edge_to_replace->dst_input();
3872         graph->RemoveEdge(replicate_edge_to_replace);
3873         graph->AddEdge(node, i, dst, dst_input);
3874       }
3875 
3876       // Feed the updated variable values from the first replica to the
3877       // variable write nodes.
3878       if (replica == 0) {
3879         for (int i = 0; i < core_variable_writes.size(); ++i) {
3880           int orig_arg_num =
3881               core_variable_writes[i] + params_info.NumPerReplicaArgs() +
3882               params_info.NumDistributedArgs() + params_info.NumBroadcastArgs();
3883           const auto& sharding = arg_shardings[orig_arg_num];
3884           // If this is a tile-sharded variable, concat the variable updates from
3885           // all cores.
3886           if (sharding.type() == xla::OpSharding::OTHER) {
3887             orig_arg_num_to_output_index_mapping[orig_arg_num][core] = i;
3888 
3889             // Do this in the iteration of the last core in the tile assignment,
3890             // so that all TPUExecute nodes have already been created.
3891             if (core !=
3892                 *std::max_element(sharding.tile_assignment_devices().begin(),
3893                                   sharding.tile_assignment_devices().end())) {
3894               continue;
3895             }
3896 
3897             // Add a Concat node.
3898             std::vector<NodeOut> orig_inputs;
3899             for (int64_t tile_index = 0;
3900                  tile_index < sharding.tile_assignment_devices_size();
3901                  ++tile_index) {
3902               int64_t last_tile_dim_size =
3903                   *sharding.tile_assignment_dimensions().rbegin();
3904               if (sharding.replicate_on_last_tile_dim() &&
3905                   tile_index % last_tile_dim_size != 0) {
3906                 continue;
3907               }
3908               int64_t core_id = sharding.tile_assignment_devices(tile_index);
3909               int core_retval_num =
3910                   orig_arg_num_to_output_index_mapping[orig_arg_num][core_id];
3911               orig_inputs.push_back(
3912                   NodeOut{execute_nodes[0][core_id],
3913                           static_cast<int>(core_retval_nums[core_id].size() +
3914                                            core_retval_num)});
3915             }
3916 
3917             // Use the variable read's device for the concat. The concat and the
3918             // read should both be colocated with the variable.
3919             absl::string_view device =
3920                 variable_reads[core_variable_writes[i]]->assigned_device_name();
3921             Node* concat_node = nullptr;
3922             if (use_nd_sharding_ops_) {
3923               TF_ASSIGN_OR_RETURN(
3924                   concat_node,
3925                   CreateXlaConcatNode(sharding, replica,
3926                                       arg_shapes[orig_arg_num].handle_type,
3927                                       arg_shapes[orig_arg_num].handle_shape,
3928                                       orig_inputs, device, graph));
3929             } else {
3930               TF_ASSIGN_OR_RETURN(
3931                   concat_node,
3932                   CreateConcatNodesForRetval(
3933                       sharding, arg_shapes[orig_arg_num].handle_type,
3934                       arg_shapes[orig_arg_num].handle_shape, replica,
3935                       orig_inputs, graph, device));
3936             }
3937             // Populate VariableWrite.
3938             VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3939             write.value = concat_node;
3940             write.value_output = 0;
3941             write.predicate = compile_node;
3942             write.predicate_output = num_cores_per_replica + core + 1;
3943 
3944             continue;
3945           }
3946 
3947           // If this is a replicated variable, outputs on all cores will be the
3948           // same, and we only take the output from core 0 for the variable
3949           // update.
3950           if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) {
3951             continue;
3952           }
3953           VariableWrite& write = variable_writes->at(core_variable_writes[i]);
3954           write.value = node;
3955           write.value_output = num_outputs + i;
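               // The write is predicated on the compile node's per-core output that
               // follows the program outputs (apparently the per-core "may modify
               // variables" indicator), at index num_cores_per_replica + core + 1.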
3956           write.predicate = compile_node;
3957           write.predicate_output = num_cores_per_replica + core + 1;
3958         }
3959       }
3960     }
3961   }
3962 
3963   for (Node* node : to_be_removed_nodes) {
3964     graph->RemoveNode(node);
3965   }
3966   return Status::OK();
3967 }  // NOLINT(readability/fn_size)
3968 
3969 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationNodes(
3970     int replica_index, const std::vector<Node*>& outside_compilation_nodes,
3971     const DeviceNameUtils::ParsedName& tpu_device,
3972     const DeviceNameUtils::ParsedName& partial_device,
3973     NodeToNodeReplicasMap* node_images, Graph* graph) {
3974   for (Node* node : outside_compilation_nodes) {
3975     NodeDef image_def = node->def();
3976     MergeDebugInfo(NodeDebugInfo(node->def()), &image_def);
3977     const string suffix = strings::StrCat("/R", replica_index);
3978     // In addition to the node name, make the frame name unique to avoid
3979     // multiple LoopCond nodes in one frame.
3980     TF_RETURN_IF_ERROR(
3981         AddPrefixAndSuffixToNode("" /* prefix */, suffix, &image_def));
3982     Status status;
3983     Node* image = graph->AddNode(image_def, &status);
3984     TF_RETURN_IF_ERROR(status);
3985     image->AddAttr(kXlaReplicaIdAttrName, replica_index);
3986     if (HasNodeAttr(image->def(), kXlaHasHostTransferAttrName)) {
3987       TF_RETURN_IF_ERROR(
3988           SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, image));
3989     } else {
3990       const string& original_device_string =
3991           node->assigned_device_name().empty() ? node->requested_device()
3992                                                : node->assigned_device_name();
3993       DeviceNameUtils::ParsedName device;
3994       TF_RET_CHECK(
3995           DeviceNameUtils::ParseFullName(original_device_string, &device));
3996       // If the requested device can be merged with the replica's host device,
3997       // then do so. For example, if the requested device is "/CPU:0" or
3998       // "/GPU:0" then it will be placed on the CPU/GPU of the host where this
3999       // replica is running. But if the requested device is
4000       // "/task:3/replica:2/CPU:0" then it will be placed on that task/replica.
4001       if (DeviceNameUtils::IsSpecification(device, partial_device)) {
4002         TF_RETURN_IF_ERROR(
4003             DeviceNameUtils::MergeDevNames(&device, partial_device));
4004       }
4005       image->set_requested_device(DeviceNameUtils::ParsedNameToString(device));
4006     }
4007     std::vector<Node*>& node_image_vector = (*node_images)[node];
4008     node_image_vector.resize(replica_index + 1);
4009     node_image_vector[replica_index] = image;
4010   }
4011   return Status::OK();
4012 }
4013 
4014 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationNodes(
4015     const std::vector<std::vector<string>>& tf_device_assignment,
4016     const HostComputeCoreMap& host_compute_core,
4017     const OutsideCompilationNodeMap& outside_compilation_nodes,
4018     NodeToNodeReplicasMap* node_images, Graph* graph) {
4019   // Iterate over replicas.
4020   for (int i = 0; i < tf_device_assignment.size(); ++i) {
4021     const auto& core_devices = tf_device_assignment[i];
4022     for (const auto& oc_cluster_iter : outside_compilation_nodes) {
4023       const string& oc_cluster_name = oc_cluster_iter.first;
4024       const auto& oc_cluster_nodes = oc_cluster_iter.second;
4025       // We previously validated that host_compute_core contains an entry for
4026       // each cluster.
4027       int core = host_compute_core.at(oc_cluster_name);
4028       TF_RET_CHECK(core >= 0 && core < core_devices.size());
4029       // tpu_device is the device the HostCompute XLA Op for this cluster runs
4030       // on.
4031       DeviceNameUtils::ParsedName tpu_device;
4032       TF_RET_CHECK(
4033           DeviceNameUtils::ParseFullName(core_devices[core], &tpu_device));
4034       // partial_device contains the replica and task but not the type.
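           // For example (hypothetical name), "/job:worker/replica:0/task:1/
           // device:TPU:0" yields the partial name "/job:worker/replica:0/task:1".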
4035       DeviceNameUtils::ParsedName partial_device = tpu_device;
4036       partial_device.has_type = false;
4037       partial_device.has_id = false;
4038 
4039       if (tf_device_assignment.size() == 1) {
4040         // With a single replica, don't copy any nodes; just put the original
4041         // nodes into the image map. We leave the device placement alone, except
4042         // that we have to fill in the correct core for the host send and
4043         // receive nodes.
4044         for (Node* node : oc_cluster_nodes) {
4045           (*node_images)[node] = {node};
4046           node->AddAttr(kXlaReplicaIdAttrName, 0);
4047           if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) {
4048             TF_RETURN_IF_ERROR(
4049                 SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, node));
4050           }
4051         }
4052       } else {
4053         // Iterate over outside_compilation clusters in this computation, adding
4054         // all the nodes with appropriate device assignments.
4055         TF_RETURN_IF_ERROR(
4056             CopyOutsideCompilationNodes(i, oc_cluster_nodes, tpu_device,
4057                                         partial_device, node_images, graph));
4058       }
4059     }
4060   }
4061   return Status::OK();
4062 }
4063 
4064 /* static */ Status DistributedTPURewritePass::CopyOutsideCompilationEdges(
4065     const std::vector<Node*>& outside_compilation_nodes,
4066     const NodeToNodeReplicasMap& node_images,
4067     const std::unordered_map<string, Node*> outside_compilation_inputs,
4068     Graph* graph) {
4069   for (Node* node : outside_compilation_nodes) {
4070     const auto& images = node_images.at(node);
4071     // Make a copy of all edges and iterate on "in_edges", because we might
4072     // remove edges while iterating through them.
4073     std::vector<const Edge*> in_edges(node->in_edges().begin(),
4074                                       node->in_edges().end());
4075     for (const Edge* edge : in_edges) {
4076       Node* src = edge->src();
4077       const auto iter = node_images.find(src);
4078       if (iter == node_images.end()) {
4079         if (images.size() > 1) {
4080           // The source node is a 'normal' node not part of any
4081           // rewrite. Broadcast the value to all replicas. (If images.size() ==
4082           // 1 the cluster is not replicated and we can leave the original edge
4083           // in place.)
4084           for (Node* dst : images) {
4085             graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
4086           }
4087         }
4088         continue;
4089       }
4090 
4091       // The source node is a replicated outside_compilation node.
4092       const auto& src_images = iter->second;
4093       if (src_images.size() != images.size()) {
4094         return errors::InvalidArgument(
4095             "Graph contains an edge from node ", src->name(),
4096             " in an outside_compilation block replicated ", src_images.size(),
4097             " ways to node ", node->name(),
4098             " in an outside_compilation block replicated ", images.size(),
4099             " ways. Replication factors must match. Leave a comment on "
4100             "tracking bug b/76419636 if you need this to be supported.");
4101       }
4102       bool is_lifted_arg;
4103       string outside_compilation_cluster;
4104       if (GetNodeAttr(src->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg)
4105               .ok() &&
4106           GetNodeAttr(src->def(), kOutsideCompilationAttr,
4107                       &outside_compilation_cluster)
4108               .ok()) {
4109         const auto input_iter =
4110             outside_compilation_inputs.find(outside_compilation_cluster);
4111         TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
4112         TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
4113         int dst_input = edge->dst_input();
4114         if (src_images.size() == 1) {
4115           graph->RemoveEdge(edge);
4116         }
4117         for (int i = 0; i < src_images.size(); ++i) {
4118           graph->AddEdge(input_iter->second, i, images[i], dst_input);
4119         }
4120         continue;
4121       }
4122 
4123       bool is_placeholder_for_arg;
4124       string outside_compilation_input_attr;
4125       if (GetNodeAttr(src->def(), kXlaIsPlaceholderForArg,
4126                       &is_placeholder_for_arg)
4127               .ok() &&
4128           GetNodeAttr(src->def(), kXlaOutsideCompilationInputsAttrName,
4129                       &outside_compilation_input_attr)
4130               .ok()) {
4131         const auto input_iter =
4132             outside_compilation_inputs.find(outside_compilation_input_attr);
4133         TF_RET_CHECK(input_iter != outside_compilation_inputs.end());
4134         TF_RET_CHECK(input_iter->second->type_string() == "IdentityN");
4135         int dst_input = edge->dst_input();
4136         if (src_images.size() == 1) {
4137           graph->RemoveEdge(edge);
4138         }
4139         for (int i = 0; i < src_images.size(); ++i) {
4140           graph->AddEdge(input_iter->second, i, images[i], dst_input);
4141         }
4142         continue;
4143       }
4144 
4145       if (images.size() > 1) {
4146         // If images.size() == 1 neither cluster is replicated and we can
4147         // leave the original edges in place.
4148         for (int i = 0; i < src_images.size(); ++i) {
4149           graph->AddEdge(src_images[i], edge->src_output(), images[i],
4150                          edge->dst_input());
4151         }
4152       }
4153     }
4154     for (const Edge* edge : node->out_edges()) {
4155       Node* dst = edge->dst();
4156       const auto iter = node_images.find(dst);
4157       if (iter == node_images.end()) {
4158         // The destination node is a 'normal' node not part of any rewrite.
4159         if (edge->IsControlEdge()) {
4160           // Make the dst node have a control dependency on every replica.
4161           if (images.size() > 1) {
4162             for (int i = 0; i < images.size(); ++i) {
4163               graph->AddControlEdge(images[i], dst);
4164             }
4165           }
4166           // else the cluster is not replicated so we can leave the original
4167           // edge in place.
4168         } else {
4169           // The edge is only valid if the outside_compilation block is not
4170           // replicated.
4171           if (images.size() > 1) {
4172             return errors::InvalidArgument(
4173                 "Graph contains an edge from node ", node->name(),
4174                 " in an outside_compilation block replicated ", images.size(),
4175                 " ways to node ", dst->name(),
4176                 " that is not part of an outside_compilation block. Edges from "
4177                 "outside_compilation to regular graph nodes are only supported "
4178                 "for replication factors of 1. Leave a comment on tracking bug "
4179                 "b/76419636 if you need this to be supported.");
4180           }
4181           // else the cluster is not replicated so we can leave the original
4182           // edge in place.
4183         }
4184       }
4185       // The case where src and dst are both in node_images is covered elsewhere
4186       // when iterating over in_edges of dst.
4187     }
4188   }
4189   return Status::OK();
4190 }
4191 
4192 /* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationEdges(
4193     const OutsideCompilationNodeMap& outside_compilation_nodes,
4194     const NodeToNodeReplicasMap& node_images,
4195     const std::unordered_map<string, Node*> outside_compilation_inputs,
4196     Graph* graph) {
4197   for (const auto& oc_cluster_iter : outside_compilation_nodes) {
4198     TF_RETURN_IF_ERROR(
4199         CopyOutsideCompilationEdges(oc_cluster_iter.second, node_images,
4200                                     outside_compilation_inputs, graph));
4201   }
4202   return Status::OK();
4203 }
4204 
4205 /* static */ Status DistributedTPURewritePass::RemoveOutsideCompilationNodes(
4206     const NodeToNodeReplicasMap& node_images, Graph* graph) {
4207   for (const auto& iter : node_images) {
4208     if (iter.second.size() > 1) {
4209       // The cluster was replicated so remove the original node.
4210       Node* node = iter.first;
4211       graph->RemoveNode(node);
4212     }
4213   }
4214   return Status::OK();
4215 }
4216 
4217 /* static */ Status
4218 DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes(
4219     Graph* g, FunctionLibraryDefinition& flib_def,
4220     const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping) {
4221   bool modified = false;
4222   do {
4223     std::vector<Node*> nodes_to_lower;
4224     for (Node* n : g->op_nodes()) {
4225       if (!HasNodeAttr(n->def(), kOutsideCompilationAttr)) {
4226         continue;
4227       }
4228 
4229       if (n->IsWhileNode() || n->IsIfNode() || IsFunctionCall(flib_def, *n)) {
4230         // Only lower functional ops with DT_RESOURCE input, because otherwise
4231         // the placer will complain. In the normal case, lowering causes slowdowns
4232         // when the functions involved are huge (b/139037679).
4233         bool has_resource_input = false;
4234         for (const Edge* e : n->in_edges()) {
4235           if (!e->IsControlEdge() &&
4236               e->src()->output_type(e->src_output()) == DT_RESOURCE) {
4237             has_resource_input = true;
4238             break;
4239           }
4240         }
4241         if (has_resource_input) {
4242           nodes_to_lower.push_back(n);
4243         }
4244       }
4245     }
4246 
4247     modified = !nodes_to_lower.empty();
4248 
4249     auto lower_functional_node = [&flib_def, &g](Node* n) -> Status {
4250       // Clear the device assignment. Otherwise all lowered nodes would inherit
4251       // this device assignment, which is not what we want.
4252       n->set_requested_device("");
4253 
4254       int replica_id;
4255       TF_RETURN_IF_ERROR(
4256           GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
4257 
4258       string outside_compilation_attr;
4259       TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kOutsideCompilationAttr,
4260                                      &outside_compilation_attr));
4261 
4262       // There are two different kinds of functional outside compilation nodes:
4263       // 1. Nodes that are in outside compilation blocks already. They are
4264       //    generated by FunctionalizeControlFlowForXlaPass, and only have
4265       //    attribute kOutsideCompilationAttr.
4266       // 2. Mirrored control flow built for outside compilation in functional
4267       //    nodes. They are generated by ExtractOutsideCompilationPass, and have
4268       //    both kOutsideCompilationAttr and kXlaHasHostTransferAttrName.
4269       // When lowering them, they need to be treated differently.
4270       // For 1), their body functions are always V1 functions written by users,
4271       // and their "control outputs" are control inputs of _Retval nodes. They
4272       // should be lowered as V1 functions.
4273       // For 2), we always add necessary "control outputs"
4274       // (_XlaRecvAtHost/_XlaSendAtHost nodes) to "control_ret" field in their
4275       // FunctionDef's. They should be lowered as V2 functions.
4276       bool is_host_side_mirrored_control_flow =
4277           HasNodeAttr(n->def(), kXlaHasHostTransferAttrName);
4278 
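           // Record the current node-id watermark so that, after lowering, we can
           // visit only the nodes created by the rewrite below.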
4279       int num_node_ids = g->num_node_ids();
4280       bool is_call_node = IsFunctionCall(flib_def, *n);
4281       if (n->IsWhileNode()) {
4282         TF_RETURN_IF_ERROR(RewriteWhileNode(n, g, &flib_def,
4283                                             /*keep_node_fetchable=*/false));
4284       } else if (n->IsIfNode()) {
4285         TF_RETURN_IF_ERROR(RewriteIfNode(n, g, /*keep_node_fetchable=*/false));
4286       } else {
4287         TF_RET_CHECK(is_call_node);
4288         // See comments for "is_host_side_mirrored_control_flow" above.
4289         // If this is a node that's in outside compilation block, lower it as
4290         // V1 function. This is controlled by removing
4291         // kLowerAsMultiDeviceFunctionAttr from the node.
4292         if (!is_host_side_mirrored_control_flow) {
4293           n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
4294         } else {
4295           n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr);
4296           n->AddAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr,
4297                      true);
4298         }
4299         TF_RETURN_IF_ERROR(
4300             RewriteFunctionCallNode(n, g, flib_def,
4301                                     /*keep_caller_fetchable=*/false));
4302       }
4303 
4304       for (int i = num_node_ids; i < g->num_node_ids(); i++) {
4305         Node* node = g->FindNodeId(i);
4306         if (!node) {
4307           continue;
4308         }
4309 
4310         if (!is_call_node && is_host_side_mirrored_control_flow &&
4311             IsFunctionCall(flib_def, *node)) {
4312           // For If/While nodes, if they are host-side mirrored control flow,
4313           // mark their body function calls with the kXlaHasHostTransferAttrName
4314           // attribute to make sure we lower them as V2 functions.
4315           node->AddAttr(kXlaHasHostTransferAttrName, true);
4316         }
4317 
4318         if (IsFunctionCall(flib_def, *node) || node->IsWhileNode() ||
4319             node->IsIfNode()) {
4320           // Set kOutsideCompilationAttr attribute so we lower these
4321           // nested function call nodes later.
4322           node->AddAttr(kOutsideCompilationAttr, outside_compilation_attr);
4323           // Set kXlaReplicaIdAttrName attribute so we know replica id when we
4324           // lower this function call node.
4325           node->AddAttr(kXlaReplicaIdAttrName, replica_id);
4326         } else if (node->type_string() == "_XlaRecvAtHost" ||
4327                    node->type_string() == "_XlaSendFromHost") {
4328           // For "_XlaRecvAtHost" and "_XlaSendFromHost" nodes, make sure they
4329           // have kXlaReplicaIdAttrName attribute so later we know which host
4330           // device to assign.
4331           node->AddAttr(kXlaReplicaIdAttrName, replica_id);
4332         }
4333       }
4334       return Status::OK();
4335     };
4336 
4337     for (Node* n : nodes_to_lower) {
4338       TF_RETURN_IF_ERROR(lower_functional_node(n));
4339     }
4340   } while (modified);
4341 
4342   // Set device for all _XlaRecvAtHost and _XlaSendFromHost nodes.
4343   for (Node* n : g->op_nodes()) {
4344     if (n->type_string() != "_XlaRecvAtHost" &&
4345         n->type_string() != "_XlaSendFromHost") {
4346       continue;
4347     }
4348 
4349     string replicate;
4350     TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &replicate));
4351     auto iter = tpu_replicate_device_names_mapping.find(replicate);
4352     TF_RET_CHECK(iter != tpu_replicate_device_names_mapping.end());
4353     const auto& tpu_device_names = iter->second;
4354 
4355     int replica_id;
4356     TF_RETURN_IF_ERROR(
4357         GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id));
4358     TF_RET_CHECK(replica_id < tpu_device_names.size());
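         // Place the host transfer op on the CPU of the host that owns this
         // replica's first core.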
4359     const string& tpu_device_name = tpu_device_names[replica_id][0];
4360     string host_device_name;
4361     TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
4362         tpu_device_name, &host_device_name));
4363     n->set_assigned_device_name(host_device_name);
4364     // We may run TPU rewrite passes again on the subgraphs of the resulting
4365     // graph. Clear kTPUReplicateAttr and kOutsideCompilationAttr for
4366     // "_XlaRecvAtHost" nodes and "_XlaSendFromHost" nodes, in order to make
4367     // sure that TPU rewrite passes take no effect on host-side subgraphs for
4368     // outside compilation.
4369     n->ClearAttr(kTPUReplicateAttr);
4370     n->ClearAttr(kOutsideCompilationAttr);
4371   }
4372 
4373   // Remove IdentityN nodes generated for outside compilation. IdentityN is
4374   // exempt from resource edge colocation, but here we do need the inputs and
4375   // outputs of these IdentityN nodes to be colocated.
4376   std::vector<Node*> identityn_nodes;
4377   for (Node* n : g->op_nodes()) {
4378     if (n->type_string() == "IdentityN" &&
4379         HasNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName)) {
4380       identityn_nodes.push_back(n);
4381     }
4382   }
4383   for (Node* n : identityn_nodes) {
4384     std::vector<const Edge*> out_edges(n->out_edges().begin(),
4385                                        n->out_edges().end());
4386     for (const Edge* e : out_edges) {
4387       if (e->IsControlEdge()) {
4388         continue;
4389       }
4390 
4391       int src_output = e->src_output();
4392       const Edge* input_edge;
4393       TF_RETURN_IF_ERROR(n->input_edge(src_output, &input_edge));
4394       Node* dst = e->dst();
4395       int dst_input = e->dst_input();
4396       g->RemoveEdge(e);
4397       g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input);
4398     }
4399     g->RemoveNode(n);
4400   }
4401 
4402   return Status::OK();
4403 }
4404 
4405 /* static */ Status DistributedTPURewritePass::ParseHostComputeCores(
4406     const Node& replicate_node,
4407     const OutsideCompilationNodeMap& outside_compilation_nodes,
4408     HostComputeCoreMap* host_compute_core) {
4409   std::vector<string> hc_core_string;
4410   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "host_compute_core",
4411                                  &hc_core_string));
4412   TF_RETURN_IF_ERROR(
4413       ParseHostComputeCoreList(hc_core_string, host_compute_core));
4414   for (const auto& iter : outside_compilation_nodes) {
4415     const string& oc_cluster_name = iter.first;
4416     if (host_compute_core->find(oc_cluster_name) == host_compute_core->end()) {
4417       // By default put host compute Ops on replicated core 0.
4418       (*host_compute_core)[oc_cluster_name] = 0;
4419     }
4420   }
4421   return Status::OK();
4422 }
4423 
4424 /* static */ Status DistributedTPURewritePass::GetDeviceTopology(
4425     const DeviceSet& device_set, const Node& replicate_node, int* num_replicas,
4426     int* num_cores_per_replica, int* num_tasks,
4427     std::vector<std::vector<string>>* tf_device_assignment,
4428     std::vector<int>* devices_to_lock,
4429     std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
4430     string* tpu_compilation_device) {
4431   TF_RETURN_IF_ERROR(
4432       GetNodeAttr(replicate_node.attrs(), "num_replicas", num_replicas));
4433   if (*num_replicas < 1) {
4434     return errors::InvalidArgument("num_replicas must be >= 1, got ",
4435                                    *num_replicas);
4436   }
4437 
4438   // Find the set of TPU devices in the TF job.
4439   // Indexed by [task number][tpu device number].
4440   std::vector<std::vector<Device*>> tpu_devices;
4441   int num_tpus_per_task;
4442   TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(),
4443                                        device_set, tpu_compilation_device,
4444                                        &num_tpus_per_task, &tpu_devices));
4445   *num_tasks = tpu_devices.size();
4446 
4447   string topology;
4448   TF_RETURN_IF_ERROR(
4449       GetNodeAttr(replicate_node.attrs(), "topology", &topology));
4450   TF_RETURN_IF_ERROR(GetNodeAttr(
4451       replicate_node.attrs(), "num_cores_per_replica", num_cores_per_replica));
4452   std::vector<int> device_assignment;
4453   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "device_assignment",
4454                                  &device_assignment));
4455 
4456   // TODO(cwhipkey): since we can control multiple pods of different shapes
4457   // from a single worker, it may be desirable to propagate the remote device
4458   // information around (e.g., in DeviceAttributes). This can lead to the mesh
4459   // topology proto being leaked to cloud TPU users (e.g. through GetStatus
4460   // calls); this may be okay, but to be conservative, just assume that the
4461   // master session has the proper flags set.
4462 
4463   // We do not initialize platform right now, but we can still retrieve the
4464   // TPU topology even with an uninitialized platform.
4465   auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(
4466       /*initialize_platform=*/false);
4467   TF_RET_CHECK(tpu_platform);
4468   tpu::TpuTopologyExternal tpu_topology(tpu_platform->GetTopologyPtr());
4469   TF_RET_CHECK(num_tpus_per_task ==
4470                tpu_topology.LogicalDevicesPerHost(kTensorCore));
4471   TF_RETURN_IF_ERROR(BuildDeviceAssignment(
4472       tpu_topology, num_tpus_per_task, tpu_devices, *num_replicas,
4473       *num_cores_per_replica, topology, device_assignment, tf_device_assignment,
4474       devices_to_lock, xla_device_assignment));
4475 
4476   return Status::OK();
4477 }
4478 
4479 /* static */ Status DistributedTPURewritePass::GetIOTypes(
4480     int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
4481     Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
4482     std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
4483     DataTypeVector* retval_types, ParameterInfo* params_info) {
4484   DataTypeVector input_types, broadcast_input_types, guaranteed_constant_types;
4485   TF_RETURN_IF_ERROR(
4486       GetNodeAttr(replicate_node.attrs(), "Tinputs", &input_types));
4487   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "Tbroadcast_inputs",
4488                                  &broadcast_input_types));
4489   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4490                                  "Tguaranteed_constants",
4491                                  &guaranteed_constant_types));
4492   int num_distributed_vars;
4493   TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4494                                  "num_distributed_variables",
4495                                  &num_distributed_vars));
4496   const int num_per_replica_inputs = input_types.size() - num_distributed_vars;
4497 
4498   if (num_per_replica_inputs % num_replicas != 0) {
4499     return errors::InvalidArgument(
4500         "Number of inputs to TPUReplicate (", num_per_replica_inputs,
4501         ") is not divisible by the number of replicas (", num_replicas, ").");
4502   }
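  // For example, with 8 per-replica inputs and num_replicas == 2, each replica
  // receives 8 / 2 == 4 per-replica inputs.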
4503 
4504   int num_variables;
4505   TF_RETURN_IF_ERROR(
4506       GetNodeAttr(replicate_node.attrs(), "NumVariables", &num_variables));
4507 
4508   NameRangeMap output_name_map;
4509   TF_RETURN_IF_ERROR(NameRangesForNode(replicate_node, replicate_node.op_def(),
4510                                        input_name_map, &output_name_map));
4511 
4512   TF_RETURN_IF_ERROR(
4513       GetNodeAttr(replicate_node.attrs(), "computation", function));
4514 
4515   *computation = absl::make_unique<Graph>(graph->op_registry());
4516   TF_RETURN_IF_ERROR(GetComputationForTPUReplicateOp(
4517       **function, flr, computation->get(), arg_types, retval_types));
4518 
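  // ParameterInfo records how the replicate node's inputs are partitioned
  // (per-replica inputs, distributed variables, broadcast inputs, variables,
  // and guaranteed constants) along with the number of return values.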
4519   *params_info = ParameterInfo(
4520       num_replicas, num_per_replica_inputs / num_replicas, num_distributed_vars,
4521       broadcast_input_types.size(), num_variables,
4522       guaranteed_constant_types.size(), retval_types->size());
4523 
4524   if (arg_types->size() != params_info->NumInputsToEachReplica()) {
4525     return errors::InvalidArgument(
4526         "Computation argument to TPUReplicate has wrong number of "
4527         "arguments. Expected ",
4528         params_info->NumInputsToEachReplica(), " inputs, got ",
4529         arg_types->size());
4530   }
4531   if (replicate_node.num_outputs() != params_info->NumOutputsToHost()) {
4532     return errors::InvalidArgument(
4533         "Wrong number of outputs from TPUReplicate. Expected ",
4534         params_info->NumOutputsToHost(), " outputs, got ",
4535         replicate_node.num_outputs());
4536   }
4537   if (enable_cross_replica_sharding_mirrored_variables_) {
4538     std::vector<int> mirrored_variable_indices;
4539     TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(),
4540                                    TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR,
4541                                    &mirrored_variable_indices));
4542     for (int index : mirrored_variable_indices) {
4543       TF_RET_CHECK(params_info->IsPerReplicaArg(index) ||
4544                    params_info->IsDistributedArg(index))
4545           << "Mirrored variables not categorized as per-replica arguments, "
4546              "index: "
4547           << index;
4548       params_info->mutable_mirrored_variable_indices()->insert(index);
4549     }
4550   }
4551   return Status::OK();
4552 }
4553 
4554 /* static */ Status DistributedTPURewritePass::BuildSequencingNodes(
4555     const string& tpu_compilation_device, const Node& replicate_node,
4556     Graph* graph, Node** host_transfer_sequencer, Node** control_before,
4557     Node** control_after) {
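  // `control_before` gathers the replicate node's incoming control edges
  // (except the host transfer sequencer, which must instead run after the
  // execute nodes), while `control_after` is wired to the replicate node's
  // successors so that they wait for the expanded cluster to finish.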
4558   *host_transfer_sequencer = nullptr;
4559 
4560   TF_RETURN_IF_ERROR(
4561       BuildNoopNode(replicate_node,
4562                     graph->NewName(strings::StrCat(replicate_node.name(), "/",
4563                                                    "control_before")),
4564                     /*device=*/"", graph, control_before));
4565   for (const Edge* e : replicate_node.in_edges()) {
4566     if (!e->IsControlEdge()) {
4567       continue;
4568     }
4569     Node* predecessor = e->src();
4570     if (predecessor->IsSource()) continue;
4571     if (predecessor->type_string() == "NoOp" &&
4572         predecessor->attrs().Find("_xla_host_transfer_sequencer") != nullptr) {
4573       // The node is the sequencer for host transfer operations. Its control
4574       // dependency needs to be placed after the execute node, not before.
4575       if (*host_transfer_sequencer != nullptr) {
4576         return errors::Internal("Replicate node ", replicate_node.name(),
4577                                 " has two transfer sequencer nodes: ",
4578                                 (*host_transfer_sequencer)->name(), " and ",
4579                                 predecessor->name());
4580       }
4581       // Set the correct device to match the other sequencing nodes.
4582       predecessor->set_assigned_device_name(tpu_compilation_device);
4583       *host_transfer_sequencer = predecessor;
4584     } else {
4585       graph->AddControlEdge(predecessor, *control_before);
4586     }
4587   }
4588 
4589   TF_RETURN_IF_ERROR(
4590       BuildNoopNode(replicate_node,
4591                     graph->NewName(strings::StrCat(replicate_node.name(), "/",
4592                                                    "control_after")),
4593                     /*device=*/tpu_compilation_device, graph, control_after));
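  // Successors tagged as tail outside-compilation nodes feed into
  // `control_after` (they are part of the cluster's work); all other
  // successors are sequenced after it.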
4594   for (Node* successor : replicate_node.out_nodes()) {
4595     if (successor->attrs().Find("_xla_tail_outside_compilation") != nullptr) {
4596       graph->AddControlEdge(successor, *control_after);
4597     } else {
4598       graph->AddControlEdge(*control_after, successor);
4599     }
4600   }
4601   return Status::OK();
4602 }
4603 
4604 /* static */ Status DistributedTPURewritePass::DealWithConstantsAndVariables(
4605     const Node& replicate_node, const NameRangeMap& input_name_map,
4606     Graph* graph, Node* host_transfer_sequencer, Node* control_before,
4607     Node* control_after, absl::Span<const VariableInput> variable_nodes,
4608     std::vector<Node*>* guaranteed_constant_nodes,
4609     std::vector<Node*>* variable_reads) {
4610   TF_RETURN_IF_ERROR(FindGuaranteedConstantInputs(
4611       replicate_node, input_name_map, guaranteed_constant_nodes));
4612 
4613   TF_RETURN_IF_ERROR(BuildVariableReads(variable_nodes, control_before, graph,
4614                                         variable_reads));
4615   // Add the control dependency from host transfer nodes.
4616   if (host_transfer_sequencer != nullptr) {
4617     graph->AddControlEdge(host_transfer_sequencer, control_after);
4618   }
4619   return Status::OK();
4620 }
4621 
4622 /* static */ Status
4623 DistributedTPURewritePass::BuildCompilationStatusReturnNodes(
4624     Node* replicate_node, Node* compile_node,
4625     absl::Span<const int> devices_to_lock, Node** control_after_compilation,
4626     Node** multilock_acquire, Graph* graph) {
4627   const Edge* compilation_edge = nullptr;
4628   for (const auto* e : replicate_node->out_edges()) {
4629     if (e->IsControlEdge() &&
4630         e->dst()->type_string() == "TPUCompilationResult") {
4631       TF_RET_CHECK(compilation_edge == nullptr)
4632           << "Multiple compilation result nodes attached to the same replicate "
4633              "cluster.";
4634       compilation_edge = e;
4635     }
4636   }
4637 
4638   // TODO(jpienaar): This should be checked by default; the current tests not
4639   // using this are the ones that use the "abort upon successful compile" flag,
4640   // which will be removed. Leaving this in until then.
4641   if (compilation_edge != nullptr) {
4642     Node* compilation_status = compilation_edge->dst();
4643     const AttrValue* compile_status_cluster_attr =
4644         compilation_status->attrs().Find(kTPUCompilationResultAttr);
4645     TF_RET_CHECK(compile_status_cluster_attr != nullptr);
4646     const string& compile_status_cluster = compile_status_cluster_attr->s();
4647     TF_RET_CHECK(!compile_status_cluster.empty());
4648     const AttrValue* replicate_cluster_attr =
4649         replicate_node->attrs().Find(kTPUReplicateAttr);
4650     TF_RET_CHECK(replicate_cluster_attr != nullptr);
4651     const string& replicate_cluster = replicate_cluster_attr->s();
4652     TF_RET_CHECK(!replicate_cluster.empty());
4653     TF_RET_CHECK(compile_status_cluster == replicate_cluster);
4654 
4655     TF_RETURN_IF_ERROR(
4656         ReplaceCompilationResultNodeWithIdentity(graph, &compilation_status));
4657     graph->AddEdge(compile_node, 0, compilation_status, 0);
4658   }
4659 
4660   NodeDef def;
4661   def.set_name(UniqueNodeName("tpu_compile_succeeded_assert", graph));
4662   // Create an op to assert that compilation succeeded. The alternative would
4663   // have been to have each execute op check and return an error.
4664   def.set_op("TPUCompileSucceededAssert");
4665   MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def);
4666   Status status;
4667   Node* compile_succeeded = graph->AddNode(def, &status);
4668   compile_succeeded->set_assigned_device_name(
4669       compile_node->assigned_device_name());
4670   TF_RETURN_IF_ERROR(status);
4671   graph->AddEdge(compile_node, 0, compile_succeeded, 0);
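  // Output 0 of the TPUCompile node carries the compilation status; it feeds
  // both the TPUCompilationResult identity (when present) and the assert node
  // above.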
4672 
4673   Node* last_node_before_sequencer = compile_succeeded;
4674 
4675   if (enable_multicore_locking_ && devices_to_lock.size() > 1) {
4676     // Add a lock node to acquire exclusive access to all the cores that will
4677     // execute this program. The lock is required to prevent deadlock or
4678     // incorrect results when running concurrent multi-core programs in the
4679     // same distributed runtime when there is no direct graph dependency
4680     // between the programs (either because they are run from different sessions
4681     // or because they are in the same graph, but have no control or data
4682     // dependencies to sequence them). Consider the case of two multi-core
4683     // computations A and B whose cores overlap and include cores X and Y. With
4684     // no locking and no graph dependencies it is possible that A's program
4685     // gets enqueued before B's on core X, while B's program gets enqueued
4686     // before A's on core Y. This will lead either to deadlock or to
4687     // incorrect results, since the runtime has no mechanism to re-sequence
4688     // the programs on the cores. By adding a multi-lock acquisition for all
4689     // the cores before any TPUExecute ops are run, and releasing it after they
4690     // complete, we ensure that the programs are enqueued on the cores in a
4691     // consistent order.
4692     //
4693     // There is a risk when computations are in the same graph, and include a
4694     // data dependency, that the lock acquisition could provoke deadlock.
4695     // Suppose that A must happen before B because B's input depends on A's
4696     // output. Then it is obviously necessary that A's lock acquisition must
4697     // happen before B's lock acquisition, and so we must ensure that there is
4698     // a graph dependency causing B's lock acquisition to be sequenced after A's
4699     // lock acquisition. Right now that dependency is satisfied because the
4700     // shape inference code cannot determine the shape of A's outputs, and so
4701     // B's compilation, which precedes B's lock acquisition, is always sequenced
4702     // after A's execution. If the shape inference is improved it will be
4703     // necessary to add an explicit control edge between dependent lock
4704     // acquisition ops.
4705     NodeDef lock_def;
4706     lock_def.set_name(graph->NewName(
4707         strings::StrCat(compile_node->name(), "/", "tpu_acquire_multilock")));
4708     lock_def.set_op("TpuMultilock");
4709     AddNodeAttr("lock_list", devices_to_lock, &lock_def);
4710     MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &lock_def);
4711     Status status;
4712     *multilock_acquire = graph->AddNode(lock_def, &status);
4713     TF_RETURN_IF_ERROR(status);
4714     (*multilock_acquire)
4715         ->set_assigned_device_name(compile_node->assigned_device_name());
4716     graph->AddControlEdge(compile_succeeded, *multilock_acquire);
4717     last_node_before_sequencer = *multilock_acquire;
4718   } else {
4719     *multilock_acquire = nullptr;
4720   }
4721 
4722   // Build a sequencing node for when compilation has completed.
4723   TF_RETURN_IF_ERROR(
4724       BuildNoopNode(*replicate_node,
4725                     graph->NewName(strings::StrCat(compile_node->name(), "/",
4726                                                    "after_compilation")),
4727                     /*device=*/"", graph, control_after_compilation));
4728   graph->AddControlEdge(last_node_before_sequencer, *control_after_compilation);
4729 
4730   return Status::OK();
4731 }
4732 
4733 // Updates the head and tail outside-compilation nodes so that they are assigned
4734 // the correct device, and removes the replication and outside-compilation
4735 // attributes so that these nodes do not trigger further graph optimization passes.
4736 /* static */ Status DistributedTPURewritePass::UpdateHeadTailOutsideCompilation(
4737     const std::vector<std::vector<string>>& tf_device_assignment,
4738     const std::vector<Node*>& head_tail_outside_compilation_nodes) {
4739   for (Node* node : head_tail_outside_compilation_nodes) {
4740     int replica_id;
4741     TF_RETURN_IF_ERROR(
4742         GetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id));
4743     // Since we set the device, this will now run on a task other than 0. We
4744     // clear the following two attributes so that we don't trigger encapsulation
4745     // again on the remote host (which would fail due to a missing
4746     // _TPUReplicateMetadata node for the cluster).
4747     for (const Edge* e : node->in_edges()) {
4748       // Resource-consuming ops should be colocated with their resource inputs.
4749       if (e->src()->IsArg() &&
4750           e->src()->output_type(e->src_output()) == DT_RESOURCE) {
4751         node->set_requested_device(tf_device_assignment[replica_id][0]);
4752       }
4753     }
4754     if (node->requested_device().empty()) {
4755       string cpu_device;
4756       TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
4757           tf_device_assignment[replica_id][0], &cpu_device));
4758       node->set_requested_device(cpu_device);
4759     }
4760     node->ClearAttr(kTPUReplicateAttr);
4761     node->ClearAttr(kOutsideCompilationAttr);
4762   }
4763   return Status::OK();
4764 }
4765 
4766 // Performs the rewrite on a single TPUReplicate node.
4767 /* static */ Status DistributedTPURewritePass::RewriteTPUReplicateNode(
4768     const string& session_handle, const DeviceSet& device_set,
4769     Node* replicate_node, FunctionLibraryDefinition* flib_def,
4770     FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
4771     const OutsideCompilationNodeMap& outside_compilation_nodes,
4772     const std::vector<Node*>& head_tail_outside_compilation_nodes,
4773     NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
4774     const GraphShapeInfo& shape_info,
4775     TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
4776     int64_t autotuner_thresh) {
4777   VLOG(2) << "Rewriting node " << replicate_node->name();
4778 
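  // High-level flow: resolve the device topology, extract the computation's
  // input/output types and shapes, assign shardings, build sequencing and
  // compilation nodes, emit per-replica execute nodes and variable writes,
  // replicate outside-compilation nodes, and finally remove the original
  // replicate node.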
4779   // num_replicas and num_cores_per_replica are the 'virtual' replicas (copies
4780   // of the computation) and cores (virtual cores within computations) specified
4781   // by the user. They will be mapped to physical TPU cores below.
4782   int num_replicas;
4783   int num_cores_per_replica;
4784   int num_tasks;
4785   std::vector<std::vector<string>> tf_device_assignment;
4786   std::vector<int> devices_to_lock;
4787   std::unique_ptr<xla::DeviceAssignment> xla_device_assignment;
4788   string tpu_compilation_device;
4789   TF_RETURN_IF_ERROR(GetDeviceTopology(
4790       device_set, *replicate_node, &num_replicas, &num_cores_per_replica,
4791       &num_tasks, &tf_device_assignment, &devices_to_lock,
4792       &xla_device_assignment, &tpu_compilation_device));
4793 
4794   TF_RETURN_IF_ERROR(UpdateHeadTailOutsideCompilation(
4795       tf_device_assignment, head_tail_outside_compilation_nodes));
4796 
4797   string replicate;
4798   TF_RETURN_IF_ERROR(
4799       GetNodeAttr(replicate_node->def(), kTPUReplicateAttr, &replicate));
4800   tpu_replicate_device_names_mapping->emplace(replicate, tf_device_assignment);
4801 
4802   NameRangeMap input_name_map;
4803   const NameAttrList* function;
4804   std::unique_ptr<Graph> computation;
4805   DataTypeVector arg_types, retval_types;
4806   ParameterInfo params_info;
4807   TF_RETURN_IF_ERROR(GetIOTypes(num_replicas, *replicate_node, flr, graph,
4808                                 &input_name_map, &function, &computation,
4809                                 &arg_types, &retval_types, &params_info));
4810 
4811   std::vector<InferredShape> arg_shapes, retval_shapes;
4812   TF_RETURN_IF_ERROR(GetArgAndRetvalShapes(
4813       shape_info, *replicate_node, params_info, &arg_shapes, &retval_shapes));
4814 
4815   TF_RETURN_IF_ERROR(ValidateCoreNumbers(*computation, num_cores_per_replica));
4816 
4817   std::vector<xla::OpSharding> arg_sharding;
4818   std::vector<bool> arg_fast_mem;
4819   std::vector<std::string> arg_names;
4820   std::vector<xla::OpSharding> retval_sharding;
4821   TF_RETURN_IF_ERROR(AssignArgsAndRetvalsToCores(
4822       num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types,
4823       retval_shapes, *computation, replicate_node, flr,
4824       allow_xla_spmd_partition_, &arg_sharding, &arg_fast_mem, &retval_sharding,
4825       &arg_names));
4826 
4827   VLOG(1) << DumpGraphToFile("distributed_tpu_graph_to_replicate", *computation,
4828                              flib_def);
4829 
4830   GraphDef graph_def;
4831   graph->ToGraphDef(&graph_def);
4832   FunctionLibraryDefinition reachable_functions =
4833       flib_def->ReachableDefinitions(graph_def);
4834   uint64 library_fingerprint;
4835 
4836   TF_RETURN_IF_ERROR(
4837       FingerprintFunctionLibrary(reachable_functions, &library_fingerprint));
4838   VLOG(1) << "Fingerprint functions: "
4839           << absl::StrJoin(reachable_functions.ListFunctionNames(), ", ");
4840   VLOG(1) << "library_fingerprint: " << library_fingerprint;
4841 
4842   // Builds trigger nodes that put barriers around the expansion of
4843   // TPUReplicate. In particular, we must guarantee:
4844   // a) variable reads happen after all predecessors of the original
4845   //    TPUReplicate.
4846   // b) variable writes happen before all successors of the original
4847   //    TPUReplicate.
4848   // c) all replicas execute, even if output tensors are only requested from
4849   //    a subset of replicas. This is necessary both to ensure that variable
4850   //    updates happen, and because Send/Recv will deadlock if only one half of
4851   //    the communicating pair runs.
4852   Node* host_transfer_sequencer;
4853   Node* control_before;
4854   Node* control_after;
4855   TF_RETURN_IF_ERROR(BuildSequencingNodes(
4856       tpu_compilation_device, *replicate_node, graph, &host_transfer_sequencer,
4857       &control_before, &control_after));
4858 
4859   // Build a vector of variable nodes that are inputs.
4860   std::vector<VariableInput> variable_inputs;
4861   TF_RETURN_IF_ERROR(
4862       FindVariableInputs(*replicate_node, input_name_map, &variable_inputs));
4863 
4864   std::vector<Node*> guaranteed_constant_nodes;
4865   std::vector<Node*> variable_reads;
4866   TF_RETURN_IF_ERROR(DealWithConstantsAndVariables(
4867       *replicate_node, input_name_map, graph, host_transfer_sequencer,
4868       control_before, control_after, variable_inputs,
4869       &guaranteed_constant_nodes, &variable_reads));
4870 
4871   // Builds Shape nodes that compute the dynamic shapes of arguments whose
4872   // shapes are not statically known.
4873   std::vector<Node*> dynamic_shape_nodes;
4874   TF_RETURN_IF_ERROR(BuildDynamicShapeNodes(*replicate_node, arg_shapes,
4875                                             params_info, variable_reads, graph,
4876                                             &dynamic_shape_nodes));
4877 
4878   // Builds a TPUCompile node that compiles the cluster on `compile_device`.
4879   Node* compile_node;
4880   TF_RETURN_IF_ERROR(BuildCompileNode(
4881       replicate_node, *function, library_fingerprint, params_info, arg_shapes,
4882       arg_types, guaranteed_constant_nodes, session_handle, arg_sharding,
4883       arg_fast_mem, arg_names, retval_sharding, num_cores_per_replica,
4884       /*compile_device=*/tpu_compilation_device, xla_device_assignment.get(),
4885       dynamic_shape_nodes, graph, &compile_node, autotuner_thresh));
4886 
4887   // Compilation must be sequenced after the control node if the TPU computation
4888   // is inside a control-flow construct, such as a loop.
4889   graph->AddControlEdge(control_before, compile_node);
4890 
4891   Node* control_after_compilation;
4892   Node* multilock_acquire;
4893   TF_RETURN_IF_ERROR(BuildCompilationStatusReturnNodes(
4894       replicate_node, compile_node, devices_to_lock, &control_after_compilation,
4895       &multilock_acquire, graph));
4896 
4897   std::vector<VariableWrite> variable_writes;
4898   TF_RETURN_IF_ERROR(BuildExecuteNodes(
4899       params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_names,
4900       arg_types, arg_shapes, retval_types, arg_sharding, retval_sharding,
4901       tf_device_assignment, compile_node, variable_reads,
4902       control_after_compilation, control_after, multilock_acquire,
4903       &variable_writes, graph));
4904   bool contains_resource_write_op =
4905       ContainsResourceWriteOp(*graph, reachable_functions);
4906 
4907   VLOG(2) << "contains_resource_write_op: " << contains_resource_write_op;
4908   // Skip the conditional variable writes if there is no resource-writing op
4909   // inside the TPU computation.
4910   if (contains_resource_write_op) {
4911     TF_RETURN_IF_ERROR(BuildVariableWrites(variable_inputs, control_after,
4912                                            variable_writes, graph));
4913   }
4914 
4915   if (host_compute_key_placeholder_node != nullptr) {
4916     TF_RETURN_IF_ERROR(ConnectHostComputeNodes(
4917         compile_node, host_compute_key_placeholder_node, graph));
4918   }
4919 
4920   HostComputeCoreMap host_compute_core;
4921   TF_RETURN_IF_ERROR(ParseHostComputeCores(
4922       *replicate_node, outside_compilation_nodes, &host_compute_core));
4923   TF_RETURN_IF_ERROR(ReplicateOutsideCompilationNodes(
4924       tf_device_assignment, host_compute_core, outside_compilation_nodes,
4925       outside_compilation_node_images, graph));
4926 
4927   graph->RemoveNode(replicate_node);
4928   return Status::OK();
4929 }
4930 
4931 // Adds sharded weight update optimization for each host training loop.
4932 //
4933 // For any host training loop found in the graph, TPUVariableReshard ops
4934 // are inserted to match the best layout chosen by XLA.
4935 /* static */ Status
4936 DistributedTPURewritePass::PerformHostTrainingLoopOptimization(
4937     Graph* graph, FunctionLibraryDefinition* flib_def,
4938     FunctionLibraryRuntime* flr) {
4939   std::vector<tpu::HostTrainingLoopInfo> host_training_loops_info;
4940   Status s = tpu::DetectHostTrainingLoop(
4941       /*current_function_name=*/nullptr,
4942       /*current_function_attr=*/nullptr, flib_def, graph, flr,
4943       &host_training_loops_info);
4944   if (!s.ok()) {
4945     VLOG(2) << "No valid host training loop found. Skipping sharded weight "
4946             << "update optimization.";
4947     return Status::OK();
4948   }
4949 
4950   for (const auto& host_loop : host_training_loops_info) {
4951     const auto& function_name = host_loop.encapsulating_function_name;
4952     // `function_name` has a value when the host training loop is inside a
4953     // function call node. In that case, in addition to adding
4954     // TPUVariableReshard ops, the function library definition needs to be
4955     // updated as well.
4956     if (function_name.has_value()) {
4957       const auto& function_attr = host_loop.encapsulating_function_attrs;
4958       TF_RET_CHECK(function_attr.has_value())
4959           << "Unable to find function attribute for function: "
4960           << *function_name;
4961 
4962       const FunctionDef* function_def = flib_def->Find(*function_name);
4963       TF_RET_CHECK(function_def)
4964           << "Unable to find function : " << *function_name;
4965 
4966       std::unique_ptr<FunctionBody> fbody;
4967       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
4968           *function_def, AttrSlice(&function_attr.value()), flib_def, &fbody));
4969       Graph* function_graph = fbody->graph;
4970       TF_RETURN_IF_ERROR(tpu::AddReshardOp(function_graph, host_loop));
4971       TF_RETURN_IF_ERROR(UpdateFunctionLibDefinition(*function_graph,
4972                                                      *function_name, flib_def));
4973     } else {
4974       TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, host_loop));
4975     }
4976   }
4977   return Status::OK();
4978 }
4979 
4980 Status DistributedTPURewritePass::PlaceUnassignedDeviceNodesOnTPUIfPossible(
4981     Graph* graph) {
4982   ReverseDFS(*graph, {}, PlaceOpsOnTPU);
4983   return Status::OK();
4984 }
4985 
4986 Status DistributedTPURewritePass::Run(
4987     const GraphOptimizationPassOptions& options) {
4988   VLOG(1) << "DistributedTPURewritePass::Run";
4989 
4990   Graph* graph = options.graph->get();
4991 
4992   VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_before", *graph,
4993                              options.flib_def);
4994 
4995   const auto* config = &options.session_options->config;
4996   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
4997       new ProcessFunctionLibraryRuntime(
4998           nullptr, options.session_options->env, config,
4999           graph->versions().producer(), options.flib_def,
5000           config ? config->graph_options().optimizer_options()
5001                  : OptimizerOptions()));
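  // The function library runtime below is used to instantiate the TPU
  // computation bodies when extracting argument/return types and shardings
  // during the per-node rewrite.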
5002 
5003   FunctionLibraryRuntime* flr =
5004       pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
5005 
5006   // This pass can only run in the session master, which should fill
5007   // in the device_set field of the options.
5008   TF_RET_CHECK(options.device_set != nullptr);
5009 
5010   // Find all the replicate nodes before mutating the graph.
5011   std::vector<Node*> replicate_nodes;
5012   // Map from compiled subgraph cluster name to the outside_compilation nodes in
5013   // that cluster.
5014   std::map<string, OutsideCompilationNodeMap> outside_compilation_nodes;
5015   std::map<string, std::vector<Node*>> head_tail_outside_compilation_nodes;
5016   TF_RETURN_IF_ERROR(FindTaggedNodes(graph, &replicate_nodes,
5017                                      &outside_compilation_nodes,
5018                                      &head_tail_outside_compilation_nodes));
5019 
5020   if (replicate_nodes.empty()) {
5021     // Remove unused TPUPartitionedInput nodes.
5022     for (Node* n : graph->nodes()) {
5023       if (n->type_string() == kTPUPartitionedInput) graph->RemoveNode(n);
5024     }
5025     VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
5026                                options.flib_def);
5027     VLOG(1) << "Replicate nodes are empty. DistributedTPURewritePass::Run() "
5028                "finished";
5029     return Status::OK();
5030   }
5031 
5032   std::unordered_map<string, Node*> host_compute_key_placeholder_map;
5033   TF_RETURN_IF_ERROR(FindHostComputeKeyPlaceholderNodes(
5034       graph, replicate_nodes, &host_compute_key_placeholder_map));
5035 
5036   // This shape inference pass does not compute the shapes of outputs of
5037   // TPU computations. The concurrent multi-core locking implementation
5038   // *relies* on this behavior because it ensures that, if TPU computation B's
5039   // inputs depend on TPU computation A's outputs, then computation B's
5040   // compilation will be sequenced after A's execution, and this ensures that
5041   // locks are acquired in the correct order. If the shape inference is improved
5042   // to compute shapes of TPU computation outputs, it will be necessary to add
5043   // an explicit control edge between lock acquisitions for dependent
5044   // computations in order to avoid deadlock.
5045   GraphShapeInfo shape_info;
5046   TF_RETURN_IF_ERROR(InferShapes(graph, /*arg_shapes=*/{},
5047                                  flr->GetFunctionLibraryDefinition(),
5048                                  &shape_info));
5049   int64_t autotuner_thresh = options.session_options->config.experimental()
5050                                  .xla_fusion_autotuner_thresh();
5051 
5052   NodeToNodeReplicasMap outside_compilation_node_images;
5053   TPUReplicateDeviceNamesMapping tpu_replicate_device_names_mapping;
5054   for (Node* node : replicate_nodes) {
5055     TF_RETURN_IF_ERROR(RewriteTPUReplicateNode(
5056         options.session_handle, *options.device_set, node, options.flib_def,
5057         flr, host_compute_key_placeholder_map[node->name()],
5058         outside_compilation_nodes[node->name()],
5059         head_tail_outside_compilation_nodes[node->name()],
5060         &outside_compilation_node_images, graph, shape_info,
5061         &tpu_replicate_device_names_mapping, autotuner_thresh));
5062   }
5063 
5064   // Place the padding nodes generated by the dynamic padder on the correct devices.
5065   // TODO(rxsang): Place padding ops on TPUs in
5066   // PlaceUnassignedDeviceNodesOnTPUIfPossible function.
5067   TF_RETURN_IF_ERROR(SetPaddingNodesDevices(graph));
5068 
5069   std::unordered_map<string, Node*> outside_compilation_inputs;
5070   for (Node* n : graph->op_nodes()) {
5071     string lifted_arg_inputs_attr;
5072     if (n->type_string() == "IdentityN" &&
5073         GetNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName,
5074                     &lifted_arg_inputs_attr)
5075             .ok()) {
5076       outside_compilation_inputs[lifted_arg_inputs_attr] = n;
5077     }
5078   }
5079   for (const auto& iter : outside_compilation_nodes) {
5080     TF_RETURN_IF_ERROR(ReplicateOutsideCompilationEdges(
5081         iter.second, outside_compilation_node_images,
5082         outside_compilation_inputs, graph));
5083   }
5084   TF_RETURN_IF_ERROR(
5085       RemoveOutsideCompilationNodes(outside_compilation_node_images, graph));
5086   TF_RETURN_IF_ERROR(LowerOutsideCompilationFunctionalNodes(
5087       graph, *options.flib_def, tpu_replicate_device_names_mapping));
5088 
5089   TF_RETURN_IF_ERROR(PlaceUnassignedDeviceNodesOnTPUIfPossible(graph));
5090   VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph,
5091                              options.flib_def);
5092   VLOG(1) << "DistributedTPURewritePass::Run() finished";
5093 
5094   if (enable_cross_replica_sharding_mirrored_variables_) {
5095     VLOG(1) << "Starting host training loop optimization.";
5096     VLOG(1) << DumpGraphToFile("host_loop_optimization_before", *graph,
5097                                options.flib_def);
5098     TF_RETURN_IF_ERROR(
5099         PerformHostTrainingLoopOptimization(graph, options.flib_def, flr));
5100     VLOG(1) << DumpGraphToFile("host_loop_optimization_after", *graph,
5101                                options.flib_def);
5102     VLOG(1) << "Host training loop optimization finished.";
5103   }
5104 
5105   return Status::OK();
5106 }
5107 
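// Default values for the pass options; they can be overridden via
// SetDistributedTpuRewritePassOptions below.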
5108 bool DistributedTPURewritePass::distribute_vars_ = false;
5109 bool DistributedTPURewritePass::allow_xla_spmd_partition_ = true;
5110 bool DistributedTPURewritePass::
5111     replicate_inputs_outputs_by_default_for_xla_spmd_ = false;
5112 bool DistributedTPURewritePass::
5113     enable_cross_replica_sharding_mirrored_variables_ = true;
5114 bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false;
5115 bool DistributedTPURewritePass::enable_xla_param_broadcast_ = false;
5116 bool DistributedTPURewritePass::enable_multicore_locking_ = false;
5117 bool DistributedTPURewritePass::use_nd_sharding_ops_ = false;
5118 
5119 /*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
5120     bool distribute_vars, bool allow_xla_spmd_partition,
5121     bool replicate_inputs_outputs_by_default_for_xla_spmd,
5122     bool enable_cross_replica_sharding_mirrored_variables,
5123     bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast,
5124     bool enable_multicore_locking, bool use_nd_sharding_ops) {
5125   distribute_vars_ = distribute_vars;
5126   allow_xla_spmd_partition_ = allow_xla_spmd_partition;
5127   replicate_inputs_outputs_by_default_for_xla_spmd_ =
5128       replicate_inputs_outputs_by_default_for_xla_spmd;
5129   enable_cross_replica_sharding_mirrored_variables_ =
5130       enable_cross_replica_sharding_mirrored_variables;
5131   enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
5132   enable_xla_param_broadcast_ = enable_xla_param_broadcast;
5133   enable_multicore_locking_ = enable_multicore_locking;
5134   use_nd_sharding_ops_ = use_nd_sharding_ops;
5135 }
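// Example (hypothetical call site, not part of this pass): configuration code
// could override the defaults above before the pass runs, e.g.
//   DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
//       /*distribute_vars=*/true, /*allow_xla_spmd_partition=*/true,
//       /*replicate_inputs_outputs_by_default_for_xla_spmd=*/false,
//       /*enable_cross_replica_sharding_mirrored_variables=*/true,
//       /*enable_automatic_model_parallelism=*/false,
//       /*enable_xla_param_broadcast=*/false,
//       /*enable_multicore_locking=*/false,
//       /*use_nd_sharding_ops=*/false);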
5136 
5137 }  // namespace tensorflow
5138