1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
17 
18 #include <float.h>
19 
20 #include <functional>
21 #include <memory>
22 #include <unordered_map>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/types/optional.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/compiler/xla/comparison_util.h"
33 #include "tensorflow/compiler/xla/literal_util.h"
34 #include "tensorflow/compiler/xla/protobuf_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
37 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
38 #include "tensorflow/compiler/xla/service/hlo_computation.h"
39 #include "tensorflow/compiler/xla/service/hlo_cse.h"
40 #include "tensorflow/compiler/xla/service/hlo_dce.h"
41 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
42 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
43 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
44 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
45 #include "tensorflow/compiler/xla/service/hlo_query.h"
46 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
47 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
48 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
49 #include "tensorflow/compiler/xla/service/shape_inference.h"
50 #include "tensorflow/compiler/xla/service/spmd/custom_call_handler.h"
51 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
52 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
53 #include "tensorflow/compiler/xla/shape_util.h"
54 #include "tensorflow/compiler/xla/util.h"
55 #include "tensorflow/compiler/xla/window_util.h"
56 #include "tensorflow/compiler/xla/xla_data.pb.h"
57 #include "tensorflow/core/platform/numbers.h"
58 
59 namespace xla {
60 namespace spmd {
61 
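// Renders the recorded log entries as a report sorted by memory size (largest
// first), truncated to report_instruction_count_ entries.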
62 string SpmdLogger::MakeReport() {
63   string report;
64   absl::StrAppend(&report,
65                   "\n\n***** SPMD memory during transformation *****\n");
66 
67   std::sort(entries_.begin(), entries_.end(),
68             [](auto const& entry0, auto const& entry1) {
69               return entry0.first > entry1.first;
70             });
71   for (int64_t i = 0;
72        i < std::min<int64>(report_instruction_count_, entries_.size()); ++i) {
73     absl::StrAppend(
74         &report, "\n  ",
75         tensorflow::strings::HumanReadableNumBytes(entries_[i].first), " : ",
76         entries_[i].second, "\n");
77   }
78 
79   return report;
80 }
81 
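// Records a log entry for `hlo`, keyed by the largest array-shaped buffer size
// among the instructions in `group`.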
82 void SpmdLogger::RegisterLogEntry(HloInstruction* hlo,
83                                   const std::vector<HloInstruction*>& group) {
84   string report = hlo->ToString();
85   int64_t max_value = -1;
86   for (HloInstruction* inst : group) {
87     if (!inst->shape().IsArray()) {
88       continue;
89     }
90     max_value = std::max<int64>(max_value, ShapeSizeInBytes(inst->shape()));
91     absl::StrAppend(&report, "     * ", inst->ToString(), "\n");
92   }
93   entries_.push_back(std::make_pair(max_value, report));
94 }
95 
96 /* static */ string SpmdLogger::ReportBeforePartition(
97     const HloModule& module, int64_t report_instruction_count) {
98   string report;
99   absl::StrAppend(&report,
100                   "\n\n***** SPMD memory usage before partition *****\n");
101   absl::StrAppend(&report, "\n  ** Replicated instructions\n");
102   absl::StrAppend(&report, ReportMemoryUsage(
103                                module,
104                                [](const HloInstruction* hlo) {
105                                  return !hlo->has_sharding() ||
106                                         hlo->sharding().IsReplicated();
107                                },
108                                report_instruction_count));
109   absl::StrAppend(&report, "\n  ** All instructions\n");
110   absl::StrAppend(&report,
111                   ReportMemoryUsage(
112                       module, [](const HloInstruction* hlo) { return true; },
113                       report_instruction_count));
114   return report;
115 }
116 
117 /* static */ string SpmdLogger::ReportAfterPartition(
118     const HloModule& module, int64_t report_instruction_count) {
119   string report;
120   absl::StrAppend(&report,
121                   "\n\n***** SPMD memory usage after partition *****\n");
122   absl::StrAppend(&report,
123                   ReportMemoryUsage(
124                       module, [](const HloInstruction* hlo) { return true; },
125                       report_instruction_count));
126   return report;
127 }
128 
129 template <typename F>
130 /* static */ string SpmdLogger::ReportMemoryUsage(
131     const HloModule& module, const F& filter,
132     int64_t report_instruction_count) {
133   string report;
134   std::vector<HloInstruction*> instructions;
135   instructions.reserve(module.instruction_count());
136 
137   for (auto computation : module.computations()) {
138     if (computation->IsFusionComputation()) {
139       continue;
140     }
141     for (auto hlo : computation->instructions()) {
142       if (!hlo->shape().IsArray() ||
143           ShapeUtil::IsEffectiveScalar(hlo->shape())) {
144         continue;
145       }
146       if (filter(hlo)) {
147         instructions.push_back(hlo);
148       }
149     }
150   }
151 
152   const auto add_report = [&](std::vector<HloInstruction*>* insts) {
153     std::sort(insts->begin(), insts->end(),
154               [](const HloInstruction* inst0, const HloInstruction* inst1) {
155                 return ShapeSizeInBytes(inst0->shape()) >
156                        ShapeSizeInBytes(inst1->shape());
157               });
158     for (int64_t i = 0;
159          i < std::min<int64>(report_instruction_count, insts->size()); ++i) {
160       absl::StrAppend(&report, "  ",
161                       tensorflow::strings::HumanReadableNumBytes(
162                           ShapeSizeInBytes((*insts)[i]->shape())),
163                       " : ", (*insts)[i]->ToString(), "\n");
164     }
165   };
166 
167   add_report(&instructions);
168   return report;
169 }
170 
171 namespace {
172 
173 // Clears all sharding attributes from instructions in the module. This must be
174 // called only after all SPMD transformations are complete.
175 Status ClearShardingAttributes(HloModule* module) {
176   for (HloComputation* computation : module->computations()) {
177     for (HloInstruction* hlo : computation->instructions()) {
178       // Keep sharding annotation on Infeed and entry parameters since they're
179       // used by HloReplicationAnalysis later (for ArCrsCombiner).
180       if (hlo->opcode() == HloOpcode::kInfeed) {
181         continue;
182       }
183       if (hlo->opcode() == HloOpcode::kParameter &&
184           computation == module->entry_computation()) {
185         continue;
186       }
187       hlo->clear_sharding();
188     }
189   }
190   return Status::OK();
191 }
192 
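// Groups partition ids by their tile-assignment coordinates with
// `replication_dims` ignored, so that partitions which differ only in those
// dimensions end up in the same group.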
193 std::vector<std::vector<int64>> GetPartitionGroupsForReplication(
194     const HloSharding& sharding, absl::Span<const int64> replication_dims) {
195   int64_t group_size = 1;
196   for (int64_t i : replication_dims) {
197     group_size *= sharding.tile_assignment().dim(i);
198   }
199   std::vector<std::vector<int64>> partition_groups(
200       sharding.tile_assignment().num_elements() / group_size);
201   sharding.tile_assignment().Each(
202       [&](absl::Span<const int64> indices, int64_t partition) {
203         int64_t group_id = 0;
204         for (int64_t i = 0; i < indices.size(); ++i) {
205           if (!absl::c_linear_search(replication_dims, i)) {
206             group_id *= sharding.tile_assignment().dim(i);
207             group_id += indices[i];
208           }
209         }
210         partition_groups[group_id].push_back(partition);
211       });
212   return partition_groups;
213 }
214 
215 }  // namespace
216 
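// In addition to adding the instruction, this tracks per-HLO "broadcast dims":
// output dimensions along which the data is known to be identical because it
// originated from a broadcast. The set is propagated through elementwise ops,
// transposes, reshapes, slices/dynamic-slices, and pads that add no padding on
// the dimension.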
217 HloInstruction* SpmdBuilder::AddInstruction(
218     std::unique_ptr<HloInstruction> instruction) {
219   HloInstruction* hlo =
220       HloComputation::Builder::AddInstruction(std::move(instruction));
221   if (visiting_hlo_) {
222     hlo->set_metadata(visiting_hlo_->metadata());
223     instructions_[visiting_hlo_].push_back(hlo);
224   }
225   if (hlo->opcode() == HloOpcode::kBroadcast) {
226     for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
227       if (!absl::c_linear_search(hlo->dimensions(), i)) {
228         broadcast_dims_[hlo].insert(i);
229       }
230     }
231   }
232   if (hlo->IsElementwise() && hlo->operand_count() > 0) {
233     absl::flat_hash_set<int64> broadcast_dims;
234     for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
235       broadcast_dims.insert(i);
236     }
237     for (int64_t i = 0; i < hlo->operand_count(); ++i) {
238       auto it = broadcast_dims_.find(hlo->operand(i));
239       if (it == broadcast_dims_.end()) {
240         broadcast_dims.clear();
241         break;
242       }
243       for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
244         if (!it->second.contains(i)) {
245           broadcast_dims.erase(i);
246         }
247       }
248     }
249     if (!broadcast_dims.empty()) {
250       broadcast_dims_[hlo] = std::move(broadcast_dims);
251     }
252   }
253   if (hlo->opcode() == HloOpcode::kTranspose) {
254     auto it = broadcast_dims_.find(hlo->operand(0));
255     if (it != broadcast_dims_.end()) {
256       absl::flat_hash_set<int64> xpose_broadcast_dims;
257       std::vector<int64> reverse_map(hlo->shape().rank());
258       for (int64_t i = 0; i < reverse_map.size(); ++i) {
259         reverse_map[hlo->dimensions(i)] = i;
260       }
261       for (int64_t dim : it->second) {
262         xpose_broadcast_dims.insert(reverse_map[dim]);
263       }
264       broadcast_dims_[hlo] = std::move(xpose_broadcast_dims);
265     }
266   }
267   if (hlo->opcode() == HloOpcode::kReshape &&
268       Product(hlo->shape().dimensions()) > 0) {
269     auto it = broadcast_dims_.find(hlo->operand(0));
270     if (it != broadcast_dims_.end()) {
271       absl::flat_hash_set<int64> reshape_broadcast_dims;
272       for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
273         reshape_broadcast_dims.insert(i);
274       }
275       std::vector<int64> before_dim_size_stack;
276       std::vector<int64> after_dim_size_stack;
277       for (int64_t i = hlo->operand(0)->shape().rank() - 1; i >= 0; --i) {
278         before_dim_size_stack.push_back(hlo->operand(0)->shape().dimensions(i));
279       }
280       for (int64_t i = hlo->shape().rank() - 1; i >= 0; --i) {
281         after_dim_size_stack.push_back(hlo->shape().dimensions(i));
282       }
283       while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) {
284         int64_t before_size = before_dim_size_stack.back();
285         int64_t after_size = after_dim_size_stack.back();
286         int64_t current_before_dim =
287             hlo->operand(0)->shape().rank() - before_dim_size_stack.size();
288         int64_t current_after_dim =
289             hlo->shape().rank() - after_dim_size_stack.size();
290         before_dim_size_stack.pop_back();
291         after_dim_size_stack.pop_back();
292         if (!it->second.contains(current_before_dim)) {
293           reshape_broadcast_dims.erase(current_after_dim);
294         }
295         if (before_size == after_size) {
296           continue;
297         }
298         if (before_size % after_size == 0) {
299           // Split dim.
300           before_dim_size_stack.push_back(before_size / after_size);
301         } else if (after_size % before_size == 0) {
302           // Merge dim.
303           after_dim_size_stack.push_back(after_size / before_size);
304         } else {
305           // Other cases, mark all remaining dims as non-broadcast.
306           for (int64_t i = current_after_dim; i < hlo->shape().rank(); ++i) {
307             reshape_broadcast_dims.erase(i);
308           }
309           break;
310         }
311       }
312       if (!before_dim_size_stack.empty() || !after_dim_size_stack.empty()) {
313         reshape_broadcast_dims.clear();
314       }
315       if (!reshape_broadcast_dims.empty()) {
316         broadcast_dims_[hlo] = std::move(reshape_broadcast_dims);
317       }
318     }
319   }
320   if (hlo->opcode() == HloOpcode::kSlice ||
321       hlo->opcode() == HloOpcode::kDynamicSlice) {
322     auto it = broadcast_dims_.find(hlo->operand(0));
323     if (it != broadcast_dims_.end()) {
324       auto dims = it->second;
325       broadcast_dims_[hlo] = std::move(dims);
326     }
327   }
328   if (hlo->opcode() == HloOpcode::kPad) {
329     auto it = broadcast_dims_.find(hlo->operand(0));
330     if (it != broadcast_dims_.end()) {
331       absl::flat_hash_set<int64> pad_broadcast_dims;
332       for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
333         const auto& dim = hlo->padding_config().dimensions(i);
334         if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 &&
335             dim.interior_padding() == 0 && it->second.contains(i)) {
336           pad_broadcast_dims.insert(i);
337         }
338       }
339       if (!pad_broadcast_dims.empty()) {
340         broadcast_dims_[hlo] = std::move(pad_broadcast_dims);
341       }
342     }
343   }
344   return hlo;
345 }
346 
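// Reshards to `target`, consulting the per-HLO reshard cache first. Reshards
// that reduce the number of tiles (all-gather-like) are only cached when the
// partitioner's cache_all_gather option is enabled.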
347 PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) {
348   if (sharding() == target) {
349     return *this;
350   }
351   auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
352   const bool is_to_replicate =
353       hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles();
354   const bool use_cache =
355       !is_to_replicate || state_.partitioner->options().cache_all_gather;
356   if (use_cache) {
357     for (auto& entry : cache) {
358       if (entry.first == target) {
359         return entry.second;
360       }
361     }
362   }
363   auto resharded = ReshardNoCache(target);
364   state_.reshard_cache->per_hlo_cache[resharded.hlo()]
365       .reshard_cache.emplace_back(sharding(), *this);
366   if (use_cache) {
367     // Get the cache again as it might be invalidated by the insertion above.
368     auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
369     cache.emplace_back(target, std::move(resharded));
370     return cache.back().second;
371   }
372   return resharded;
373 }
374 
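// Uncached resharding. Tuples are resharded leaf-by-leaf; otherwise this tries
// collective-permute, all-to-all and partial-replication-specific paths, and
// finally falls back to replicating and then re-slicing to the target.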
375 PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
376   VLOG(2) << "Resharding " << hlo_->ToString() << " from "
377           << hlo_->sharding().ToString() << " to " << target.ToString();
378   const Shape& shape = hlo_->shape();
379   if (shape.element_type() == TOKEN) {
380     return *this;
381   }
382   CHECK(shape.IsTuple() || !target.IsTuple());
383 
384   // Tuple shape instructions may have non-tuple sharding, which means that the
385   // same sharding applies to all the leaves.
386   if (shape.IsTuple() && !target.IsTuple()) {
387     return Reshard(target.GetTupleSharding(shape).ValueOrDie());
388   }
389 
390   // For a tuple shape, recursively apply Reshard to all the leaves and return
391   // a tuple instruction.
392   if (shape.IsTuple()) {
393     std::vector<HloInstruction*> elements;
394     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
395       auto subshape = ShapeUtil::GetTupleElementShape(shape, i);
396       auto element = state_.b->AddInstruction(
397           HloInstruction::CreateGetTupleElement(subshape, hlo(), i));
398       element->set_sharding(sharding().GetSubSharding(shape, {i}));
399       elements.push_back(
400           PartitionedHlo(
401               element, ShapeUtil::GetTupleElementShape(base_shape_, i), state_)
402               .Reshard(target.GetSubSharding(shape, {i}))
403               .hlo());
404     }
405     auto tuple =
406         state_.b->AddInstruction(HloInstruction::CreateTuple(elements));
407     tuple->set_sharding(target);
408     return PartitionedHlo(tuple, base_shape_, state_);
409   }
410 
411   if (sharding() == target) {
412     return *this;
413   }
414 
415   if (CanReshardWithCollectivePermute(sharding(), target)) {
416     return ReshardWithCollectivePermute(target);
417   }
418 
419   if (auto src_tgt_dims =
420           GetReshardAllToAllSourceTargetDims(sharding(), target)) {
421     return ReshardWithAllToAll(target, *src_tgt_dims);
422   }
423 
424   if (!target.IsTileMaximal() && sharding().ReplicateOnLastTileDim()) {
425     auto try_reshard = ReshardFromPartialReplicateWithDynamicSlice(target);
426     if (try_reshard.has_value()) {
427       return try_reshard.value();
428     }
429     try_reshard = ReshardPartialReplicateWithAllToAll(target);
430     if (try_reshard.has_value()) {
431       return try_reshard.value();
432     }
433   }
434 
435   if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
436     auto try_reshard = ReshardToPartialReplicateWithAllGather(target);
437     if (try_reshard.has_value()) {
438       return try_reshard.value();
439     }
440     try_reshard = ReshardPartialReplicateWithAllToAll(target);
441     if (try_reshard.has_value()) {
442       return try_reshard.value();
443     }
444   }
445 
446   // If not replicated yet, first replicate and then reshard to use one of the
447   // two implementations below.
448   if (!sharding().IsReplicated()) {
449     return Replicate().Reshard(target);
450   }
451 
452   // 'Replicated' to 'SingleDevice'.
453   if (target.IsTileMaximal()) {
454     auto copy = state_.b->AddInstruction(
455         HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_));
456     copy->set_sharding(target);
457     return PartitionedHlo(copy, base_shape_, state_);
458   }
459 
460   // 'Replicated' to partial replicated.
461   if (target.ReplicateOnLastTileDim()) {
462     std::vector<int64> group_dims(target.tile_assignment().num_dimensions() -
463                                   1);
464     std::iota(group_dims.begin(), group_dims.end(), 0);
465     auto target_grouped = GroupShardingOnDims(target, group_dims);
466     auto partially_sharded = PerGroupSliceFromReplicated(
467         hlo_, state_.partition_id, target_grouped.device_groups, group_dims,
468         target_grouped.group_dim_sizes, state_.b);
469     partially_sharded->set_sharding(target);
470     return PartitionedHlo(partially_sharded, base_shape(), state_);
471   }
472 
473   // 'Replicated' to 'Tiled'.
474   auto padded_hlo =
475       PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);
476   auto shard_shape = MakePartitionedShape(shape, target);
477   auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
478       shard_shape, padded_hlo,
479       MakePartitionOffsets(shape, target, state_.partition_id, state_.b),
480       shard_shape.dimensions()));
481   slice->set_sharding(target);
482   return PartitionedHlo(slice, base_shape_, state_);
483 }
484 
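// Returns a copy where elements lying outside the base shape (present only
// because of uneven tiling) are replaced with `pad_value`, using per-partition
// offsets to build a validity mask. Dimensions in `skipped_dims` are left
// untouched.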
485 PartitionedHlo PartitionedHlo::PadWithValue(
486     HloInstruction* pad_value, absl::Span<const int64> left_padded_dims,
487     absl::Span<const int64> skipped_dims) const {
488   const HloSharding& sharding = hlo_->sharding();
489   const Shape& shape = hlo_->shape();
490   CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
491   if (sharding.IsReplicated() || EvenlyPartitions(base_shape_, sharding)) {
492     return *this;
493   }
494   CHECK(!sharding.IsTileMaximal());
495   auto index_shape = ShapeUtil::ChangeElementType(shape, S32);
496   auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
497   auto get_mask_for_dim = [&](int64_t dim, HloInstruction* start_index) {
498     // Comparison: iota + start_index < valid_size
499     auto iota =
500         state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
501     auto broadcast_start_index = state_.b->AddInstruction(
502         HloInstruction::CreateBroadcast(index_shape, start_index, {}));
503     auto index_in_full_shape =
504         state_.b->AddInstruction(HloInstruction::CreateBinary(
505             index_shape, HloOpcode::kAdd, iota, broadcast_start_index));
506     ComparisonDirection direction = ComparisonDirection::kLt;
507     int64_t index_limit = base_shape_.dimensions(dim);
508     if (absl::c_linear_search(left_padded_dims, dim)) {
509       direction = ComparisonDirection::kGe;
510       index_limit =
511           index_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) -
512           index_limit;
513     }
514     auto limit = state_.b->AddInstruction(HloInstruction::CreateConstant(
515         LiteralUtil::CreateR0<int32>(index_limit)));
516     auto broadcast_limit = state_.b->AddInstruction(
517         HloInstruction::CreateBroadcast(index_shape, limit, {}));
518     return state_.b->AddInstruction(HloInstruction::CreateCompare(
519         mask_shape, index_in_full_shape, broadcast_limit, direction));
520   };
521 
522   HloInstruction* mask = nullptr;
523   auto offsets = MakePartitionOffsets(base_shape_, sharding,
524                                       state_.partition_id, state_.b);
525   for (int64_t i = 0; i < shape.rank(); ++i) {
526     if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 ||
527         absl::c_linear_search(skipped_dims, i)) {
528       continue;
529     }
530     if (mask == nullptr) {
531       mask = get_mask_for_dim(i, offsets[i]);
532     } else {
533       mask = state_.b->AddInstruction(
534           HloInstruction::CreateBinary(mask->shape(), HloOpcode::kAnd, mask,
535                                        get_mask_for_dim(i, offsets[i])));
536     }
537   }
538 
539   if (mask == nullptr) {
540     return *this;
541   }
542 
543   auto broadcast_pad_value = state_.b->AddInstruction(
544       HloInstruction::CreateBroadcast(shape, pad_value, {}));
545   auto result = state_.b->AddInstruction(HloInstruction::CreateTernary(
546       shape, HloOpcode::kSelect, mask, hlo_, broadcast_pad_value));
547   result->set_sharding(sharding);
548   return PartitionedHlo(result, base_shape_, state_);
549 }
550 
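// Reshards this value so that each partition holds the input region needed by
// its share of windows from `window`: it computes per-shard padding, performs
// halo exchanges on sharded dimensions, and caches the result per
// (target, window) pair. Returns nullopt for some unsupported window/dilation
// combinations.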
551 absl::optional<PartitionedHlo::WindowedInputShardReturnValue>
552 PartitionedHlo::ReshardAsWindowedInput(const Window& window,
553                                        const HloSharding& target,
554                                        HloInstruction* pad_value,
555                                        bool mask_invalid_region) {
556   auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache;
557   for (auto& entry : cache) {
558     if (std::get<0>(entry) == target &&
559         protobuf_util::ProtobufEquals(std::get<1>(entry), window)) {
560       return std::get<2>(entry);
561     }
562   }
563   auto update_cache = [&](WindowedInputShardReturnValue result) {
564     cache.emplace_back(target, window, std::move(result));
565     return std::get<2>(cache.back());
566   };
567   VLOG(2) << "ReshardAsWindowedInput()\n"
568           << "\twindow:" << window_util::ToString(window)
569           << "\ttarget sharding:" << target.ToString();
570 
571   CHECK(!target.IsTileMaximal());
572   auto partition_ordinals =
573       MakeTiledPartitionOrdinals(target, state_.partition_id, state_.b);
574   auto shard_shape = base_shape_;
575 
576   std::vector<MultiplyAddDivideOffsetCalculation> start_on_padded_calculations(
577       base_shape_.rank());
578   std::vector<MultiplyAddDivideOffsetCalculation> limit_on_padded_calculations(
579       base_shape_.rank());
580   std::vector<HloInstruction*> dynamic_slice_offset_on_output(
581       base_shape_.rank(), nullptr);
582 
583   Window shard_window = window;
584   Shape padded_shape = base_shape_;
585   std::vector<HloInstruction*> offsets_on_padded_shape(base_shape_.rank());
586   std::vector<int64> per_shard_window_counts(base_shape_.rank());
587   std::vector<int64> explicit_left_padding(base_shape_.rank());
588   for (int64_t i = 0; i < base_shape_.rank(); ++i) {
589     // Do not pad non-partitioned dimensions.
590     int64_t shard_count = target.tile_assignment().dim(i);
591     if (shard_count == 1) {
592       offsets_on_padded_shape[i] = state_.b->AddInstruction(
593           HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
594       continue;
595     }
596     const WindowDimension& wd = window.dimensions(i);
597     const int64_t dilated_size = 1 + (wd.size() - 1) * wd.window_dilation();
598     const int64_t full_size =
599         1 + (base_shape_.dimensions(i) - 1) * wd.base_dilation() +
600         wd.padding_high() + wd.padding_low();
601     if (full_size < dilated_size) {
602       VLOG(2) << "Failed to reshard window operand because the window size is "
603                  "larger than padded base size";
604       return absl::nullopt;
605     }
606     int64_t window_count = (full_size - dilated_size) / wd.stride() + 1;
607     per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
608     if (wd.stride() != 1 &&
609         (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
610       // TODO(yuanzx): Support this case.
611       VLOG(2) << "Failed to reshard window operand due to non-trivial dilation";
612       return absl::nullopt;
613     }
614 
615     // We use explicit padding for full dilations, then use padding_low and
616     // padding_high on the sharded op for the remaining. padding_low and
617     // padding_high are now given initial values, which will be later updated if
618     // dilation is not 1.
619     WindowDimension* swd = shard_window.mutable_dimensions(i);
620     explicit_left_padding[i] = wd.padding_low() / wd.base_dilation();
621     swd->set_padding_low(wd.padding_low() % wd.base_dilation());
622     swd->set_padding_high(0);
623 
624     // Calculation for the first element needed on the 'padded-but-not-dilated'
625     // shape. The start on the dilated shape could be a hole, so we add
626     // wd.base_dilation() - 1 to the constant term to skip the leading holes.
627     start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
628         wd.stride() * per_shard_window_counts[i],
629         wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
630     int64_t dilated_shard_size =
631         wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
632     limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
633         wd.stride() * per_shard_window_counts[i],
634         dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
635         wd.base_dilation());
636 
637     offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate(
638         partition_ordinals[i], state_.b);
639 
640     auto shard_size_function =
641         limit_on_padded_calculations[i] - start_on_padded_calculations[i];
642     int64_t max_shard_size = shard_size_function.MaxInRange(0, shard_count);
643     shard_shape.set_dimensions(i, max_shard_size);
644     padded_shape.set_dimensions(
645         i, limit_on_padded_calculations[i].Calculate(shard_count - 1));
646 
647     // For base dilation, calculate the needed padding_low and padding_high, as
648     // well as the offset for the output if a dynamic slice is needed after the
649     // sharded op.
650     if (wd.base_dilation() != 1) {
651       // Returns the offset of a shard's first valid element in the dilated
652       // shard.
653       auto get_first_valid_element_offset_on_dilated_shard =
654           [&](int64_t shard_ordinal) {
655             return start_on_padded_calculations[i].Calculate(shard_ordinal) *
656                        wd.base_dilation() +
657                    swd->padding_low() -
658                    wd.stride() * per_shard_window_counts[i] * shard_ordinal;
659           };
660       CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0),
661                swd->padding_low());
662 
663       // Determine swd->padding_high.
664       for (int64_t shard_ordinal = 0; shard_ordinal < shard_count;
665            ++shard_ordinal) {
666         int64_t wanted_limit_on_dilated_shard =
667             wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
668         int64_t actual_limit_on_dilated_shard_without_pad_high =
669             get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
670             (max_shard_size - 1) * wd.base_dilation() + 1;
671         swd->set_padding_high(std::max<int64>(
672             swd->padding_high(),
673             wanted_limit_on_dilated_shard -
674                 actual_limit_on_dilated_shard_without_pad_high));
675       }
676 
677       // Determine swd->padding_low and output dynamic slice index.
678       if (wd.stride() == 1) {
679         int64_t max_pad_low =
680             get_first_valid_element_offset_on_dilated_shard(0);
681         bool all_same = true;
682         for (int64_t shard_ordinal = 1; shard_ordinal < shard_count;
683              ++shard_ordinal) {
684           int64_t start =
685               get_first_valid_element_offset_on_dilated_shard(shard_ordinal);
686           if (start != swd->padding_low()) {
687             all_same = false;
688           }
689           max_pad_low = std::max(max_pad_low, start);
690         }
691         if (!all_same) {
692           auto start_on_padded_input =
693               start_on_padded_calculations[i].Calculate(partition_ordinals[i],
694                                                         state_.b);
695           // We will calculate
696           //   max_pad_low - (first_window - required_first_window)
697           // which equals
698           //   required_first_window - (first_window - max_pad_low)
699           auto first_window_minus_max_pad_low =
700               MultiplyAddDivideOffsetCalculation(
701                   wd.base_dilation(), swd->padding_low() - max_pad_low, 1)
702                   .Calculate(start_on_padded_input, state_.b);
703           auto required_first_window =
704               MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0,
705                                                  1)
706                   .Calculate(partition_ordinals[i], state_.b);
707           dynamic_slice_offset_on_output[i] =
708               state_.b->AddInstruction(HloInstruction::CreateBinary(
709                   required_first_window->shape(), HloOpcode::kSubtract,
710                   required_first_window, first_window_minus_max_pad_low));
711         }
712         swd->set_padding_low(max_pad_low);
713       } else {
714         if ((wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() !=
715             0) {
716           // General base dilation not yet implemented.
717           return absl::nullopt;
718         }
719         // padding_low on all shards should equal the initially assigned
720         // swd->padding_low(), i.e., the padding_low() on the original window.
721       }
722     }
723   }
724 
725   // Returns the output dynamic slice offset when needed, and absl::nullopt
726   // otherwise.
727   auto get_dynamic_slice_offset_on_output_if_needed =
728       [&]() -> absl::optional<std::vector<HloInstruction*>> {
729     if (absl::c_all_of(
730             dynamic_slice_offset_on_output,
731             [](HloInstruction* offset) { return offset == nullptr; })) {
732       return absl::nullopt;
733     }
734     auto zero = state_.b->AddInstruction(
735         HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
736     for (int64_t i = 0; i < dynamic_slice_offset_on_output.size(); ++i) {
737       if (dynamic_slice_offset_on_output[i] == nullptr) {
738         dynamic_slice_offset_on_output[i] = zero;
739       }
740     }
741     return dynamic_slice_offset_on_output;
742   };
743 
744   // If the current HLO is replicated, pad then slice.
745   if (sharding().IsReplicated()) {
746     PaddingConfig padding_config;
747     for (int64_t i = 0; i < base_shape_.rank(); ++i) {
748       auto padding_config_dim = padding_config.add_dimensions();
749       padding_config_dim->set_interior_padding(0);
750       // Do not pad non-partitioned dimensions.
751       if (target.tile_assignment().dim(i) == 1) {
752         padding_config_dim->set_edge_padding_low(0);
753         padding_config_dim->set_edge_padding_high(0);
754         continue;
755       }
756       padding_config_dim->set_edge_padding_low(explicit_left_padding[i]);
757       padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
758                                                 explicit_left_padding[i] -
759                                                 base_shape_.dimensions(i));
760     }
761     auto padded_hlo = ShapeUtil::Compatible(padded_shape, base_shape_)
762                           ? hlo_
763                           : state_.b->AddInstruction(HloInstruction::CreatePad(
764                                 padded_shape, hlo_, pad_value, padding_config));
765     auto sharded_input =
766         state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
767             shard_shape, padded_hlo, offsets_on_padded_shape,
768             shard_shape.dimensions()));
769     return update_cache(WindowedInputShardReturnValue{
770         sharded_input, shard_window,
771         get_dynamic_slice_offset_on_output_if_needed()});
772   }
773 
774   if (target != sharding()) {
775     return Reshard(target).ReshardAsWindowedInput(window, target, pad_value);
776   }
777 
778   // Halo exchange.
779   HloInstruction* visiting_hlo = hlo_;
780   auto original_shard_shape = MakePartitionedShape(base_shape_, target);
781 
782   std::vector<OffsetCalculation> left_halo_size_functions(base_shape_.rank());
783   std::vector<OffsetCalculation> right_halo_size_functions(base_shape_.rank());
784   // TODO(yuanzx): We are concatenating on each sharded dimension one at a
785   // time, and in the second dimension (and beyond) we create halos by slicing
786   // the concat in the previous dimension, which is not optimal. We should
787   // generate halos only by concatenating slices, instead of slicing concats.
788   for (int dim = 0; dim < base_shape_.rank(); ++dim) {
789     int64_t shard_count = target.tile_assignment().dim(dim);
790     if (shard_count == 1) {
791       continue;
792     }
793     int64_t input_shard_size =
794         CeilOfRatio(base_shape_.dimensions(dim), shard_count);
795 
796     // Left halo. The size of the halo is derived by subtracting the first read
797     // element offset of the i'th partition from the limit of the (i-1)'th
798     // partition.
799     MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded(
800         input_shard_size, explicit_left_padding[dim], 1);
801     left_halo_size_functions[dim] =
802         shard_limit_of_previous_on_padded - start_on_padded_calculations[dim];
803 
804     // Right halo.
805     MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded(
806         input_shard_size, input_shard_size + explicit_left_padding[dim], 1);
807     right_halo_size_functions[dim] =
808         limit_on_padded_calculations[dim] - shard_start_of_next_on_padded;
809 
810     auto resharded = ExchangeHaloAndGetValidData(
811         visiting_hlo, base_shape_, left_halo_size_functions[dim],
812         right_halo_size_functions[dim], explicit_left_padding[dim],
813         padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, target,
814         offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim],
815         state_.collective_ops_creator, state_.next_channel_id, state_.b,
816         mask_invalid_region);
817     if (!resharded) {
818       VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo "
819                  "is beyond the neighbor.";
820       return Replicate().ReshardAsWindowedInput(window, target, pad_value);
821     }
822     visiting_hlo = *resharded;
823   }
824   return update_cache(WindowedInputShardReturnValue{
825       visiting_hlo, shard_window,
826       get_dynamic_slice_offset_on_output_if_needed()});
827 }
828 
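// Fully replicates the value, reusing a cached replicated copy when available
// and caching the result if the cache_all_gather option is enabled.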
829 PartitionedHlo PartitionedHlo::Replicate() {
830   auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
831   if (state_.partitioner->options().cache_all_gather) {
832     for (auto& entry : cache) {
833       if (entry.first.IsReplicated()) {
834         return entry.second;
835       }
836     }
837   }
838   // Do not use a reference as the HLO's sharding can be temporarily replaced.
839   const HloSharding sharding = hlo_->sharding();
840   const Shape& shape = hlo_->shape();
841   CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
842 
843   if (sharding.IsReplicated()) {
844     return *this;
845   }
846   for (auto& entry : cache) {
847     if (entry.first.IsReplicated()) {
848       return entry.second;
849     }
850   }
851   auto update_cache = [&](PartitionedHlo resharded) {
852     state_.reshard_cache->per_hlo_cache[resharded.hlo()]
853         .reshard_cache.emplace_back(sharding, *this);
854     // Get the cache again as it might be invalidated by the insertion above.
855     auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
856     if (state_.partitioner->options().cache_all_gather) {
857       cache.emplace_back(HloSharding::Replicate(), std::move(resharded));
858       return cache.back().second;
859     }
860     return resharded;
861   };
862   // 'Single Device' to 'Replicated'.
863   if (sharding.IsTileMaximal()) {
864     return update_cache(Broadcast());
865   }
866 
867   // 'Tiled' to 'Replicated'.
868   std::vector<int64> all_dims(shape.rank());
869   std::iota(all_dims.begin(), all_dims.end(), 0);
870   HloInstruction* result = ReplicatePartial(all_dims);
871   result->set_sharding(HloSharding::Replicate());
872   return update_cache(PartitionedHlo(result, base_shape_, state_));
873 }
874 
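// Replicates the value along `dims` only. Dimensions whose shard already spans
// the full base size are handled with a broadcast-based reshard; the remaining
// dimensions use all-gather, or a dynamic-update-slice into zeros followed by
// an all-reduce when no all-gather creator is available.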
875 HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span<const int64> dims) {
876   CHECK(!sharding().IsTileMaximal());
877   const Shape& shard_shape = hlo()->shape();
878   Shape target_shape = shard_shape;
879   Shape padded_target_shape = shard_shape;
880   std::vector<int64> broadcast_dims;
881   std::vector<int64> ag_dims;
882   // Find dimensions that can be replicated with Broadcast() (shard size 1) and
883   // others that need all-gather.
884   for (int64_t i : dims) {
885     if (sharding().tile_assignment().dim(i) == 1) {
886       continue;
887     }
888     target_shape.set_dimensions(i, base_shape().dimensions(i));
889     if (target_shape.dimensions(i) == shard_shape.dimensions(i)) {
890       broadcast_dims.push_back(i);
891     } else {
892       padded_target_shape.set_dimensions(
893           i, shard_shape.dimensions(i) * sharding().tile_assignment().dim(i));
894       ag_dims.push_back(i);
895     }
896   }
897 
898   HloInstruction* broadcast = hlo_;
899   if (!broadcast_dims.empty()) {
900     std::vector<int64> other_dims;
901     for (int64_t i = 0; i < sharding().tile_assignment().num_dimensions();
902          ++i) {
903       if (!absl::c_linear_search(broadcast_dims, i)) {
904         other_dims.push_back(i);
905       }
906     }
907     HloSharding original_sharding = sharding();
908     auto grouped = GroupShardingOnDims(original_sharding, other_dims);
909     std::vector<int64> dev_indices(
910         grouped.sharding.tile_assignment().num_dimensions(), 0);
911     hlo_->set_sharding(HloSharding::AssignDevice(
912         grouped.sharding.tile_assignment()(dev_indices)));
913     auto per_group_partitioner_state = CreatePerGroupPartitioningState(
914         state(), grouped.device_groups, state().b);
915     auto partial_replicate_hlo =
916         PartitionedHlo(hlo_, shard_shape, per_group_partitioner_state)
917             .Broadcast();
918     hlo_->set_sharding(original_sharding);
919     partial_replicate_hlo.hlo()->clear_sharding();
920     broadcast = partial_replicate_hlo.hlo();
921   }
922 
923   if (ag_dims.empty()) {
924     return broadcast;
925   }
926 
927   HloInstruction* result = nullptr;
928   if (state_.collective_ops_creator.create_cross_partition_all_gather) {
929     result = state_.partitioner->AllGatherShards(
930         state_.b, broadcast, sharding(), state_.next_channel_id, ag_dims,
931         state_.collective_ops_creator);
932   }
933   if (result == nullptr) {
934     auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant(
935         LiteralUtil::Zero(shard_shape.element_type())));
936     auto zero_bcast = state_.b->AddInstruction(
937         HloInstruction::CreateBroadcast(padded_target_shape, zero, {}));
938     auto offsets = MakePartitionOffsets(padded_target_shape, sharding(),
939                                         state_.partition_id, state_.b, ag_dims);
940     auto dus =
941         state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
942             padded_target_shape, zero_bcast, broadcast, offsets));
943     HloComputation* reduction =
944         MakeBinaryAdd(shard_shape.element_type(), state_.module);
945     result = state_.partitioner->AllReduceAlongShardingDims(
946         state_.b, dus, sharding(), state_.next_channel_id, ag_dims,
947         state_.collective_ops_creator, reduction);
948   }
949   if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) {
950     std::vector<int64> start_indices(target_shape.rank(), 0);
951     std::vector<int64> strides(target_shape.rank(), 1);
952     result = state_.b->AddInstruction(
953         HloInstruction::CreateSlice(target_shape, result, start_indices,
954                                     target_shape.dimensions(), strides));
955   }
956   return result;
957 }
958 
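// Reshards to a partially replicated `target` by first moving to a compatible
// sharding (collective-permute if needed), then replicating along the
// dimensions whose tiling is being removed, preceded by a halo exchange when
// the tiling is uneven.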
959 absl::optional<PartitionedHlo>
960 PartitionedHlo::ReshardToPartialReplicateWithAllGather(
961     const HloSharding& target) {
962   if (!target.ReplicateOnLastTileDim()) {
963     return absl::nullopt;
964   }
965   // Tiled/partial replicate to partial replicate
966   // Get a sharding compatible with the target for resharding by all-reduce.
967   auto compatible_sharding =
968       PartialReplicateReshardCompatibleSharding(target, sharding());
969   if (!compatible_sharding.has_value()) {
970     return absl::nullopt;
971   }
972 
973   const auto& temp_sharding = compatible_sharding.value();
974   auto partitioned_hlo = *this;
975   // Use collective permute to adjust device assignment if needed.
976   if (CanReshardWithCollectivePermute(sharding(), temp_sharding)) {
977     partitioned_hlo =
978         partitioned_hlo.ReshardWithCollectivePermute(temp_sharding);
979   }
980 
981   // Get replicate dims and replicate factor of each dimensions.
982   int64_t rank = hlo_->shape().rank();
983   std::vector<int64> replicate_dims;
984   std::vector<int64> replicate_factors;
985   for (int64_t dim = 0; dim < rank; dim++) {
986     int64_t replicate_factor = temp_sharding.tile_assignment().dim(dim) /
987                                target.tile_assignment().dim(dim);
988     if (replicate_factor > 1) {
989       replicate_dims.emplace_back(dim);
990       replicate_factors.emplace_back(replicate_factor);
991     }
992   }
993 
994   // Do left halo exchange if all-reduce directly will remove useful data
995   // from the source.
996   auto halo_exchange = TileToPartialReplicateHaloExchange(
997       partitioned_hlo.hlo_, base_shape_, temp_sharding, target, replicate_dims,
998       partitioned_hlo.state().collective_ops_creator,
999       partitioned_hlo.state().next_channel_id,
1000       partitioned_hlo.state().partition_id, partitioned_hlo.state().b);
1001   if (!halo_exchange.has_value()) {
1002     return absl::nullopt;
1003   }
1004   auto halo_exchange_hlo = halo_exchange.value();
1005   // Grouped on replicate dimensions.
1006   auto sharding_grouped =
1007       GroupShardingOnDims(temp_sharding, replicate_dims, replicate_factors);
1008   auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1009       partitioned_hlo.state(), sharding_grouped.device_groups,
1010       partitioned_hlo.state().b);
1011   auto base_shape = MakePartitionedShape(base_shape_, target);
1012   // It's possible that halo_exchange_hlo == hlo.hlo().
1013   // Record the sharding of hlo here, and reset it before returning.
1014   auto original_sharding = partitioned_hlo.sharding();
1015   halo_exchange_hlo->set_sharding(sharding_grouped.sharding);
1016   auto partial_replicate_hlo = PartitionedHlo(halo_exchange_hlo, base_shape,
1017                                               per_group_partitioner_state);
1018   HloInstruction* result =
1019       partial_replicate_hlo.ReplicatePartial(replicate_dims);
1020   partitioned_hlo.hlo()->set_sharding(original_sharding);
1021   result->set_sharding(target);
1022   return PartitionedHlo(result, base_shape_, partitioned_hlo.state());
1023 }
1024 
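// Reshards from a partially replicated sharding to `target` by dynamic-slicing
// each partition's tile out of the (partially) replicated data, followed by a
// collective permute if the device assignment still differs from `target`.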
1025 absl::optional<PartitionedHlo>
1026 PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice(
1027     const HloSharding& target) {
1028   if (!sharding().ReplicateOnLastTileDim()) {
1029     return absl::nullopt;
1030   }
1031 
1032   // Get the temp sharding target from partial replicate to target tile dims.
1033   // target_compatible_sharding has the same tile_assignment dimensions
1034   // as the target and can reshard to target by collective permute.
1035   // target_compatible_sharding could have a different device assignment
1036   // than target. sharding() can reshard to target_compatible_sharding by
1037   // dynamic slice.
1038   auto target_compatible_sharding =
1039       PartialReplicateReshardCompatibleSharding(sharding(), target);
1040   // Reshard to target_compatible_sharding by dynamic slice.
1041   if (!target_compatible_sharding.has_value()) {
1042     return absl::nullopt;
1043   }
1044   std::vector<int64> expand_tile_dims;
1045   std::vector<int64> tiling_dim_factors;
1046   int64_t rank = hlo_->shape().rank();
1047   tiling_dim_factors.reserve(target.tile_assignment().num_dimensions());
1048   const auto& temp_target_sharding = target_compatible_sharding.value();
1049   for (int64_t dim = 0; dim < rank; dim++) {
1050     if (temp_target_sharding.tile_assignment().dim(dim) >
1051         sharding().tile_assignment().dim(dim)) {
1052       expand_tile_dims.push_back(dim);
1053     }
1054     tiling_dim_factors.emplace_back(
1055         temp_target_sharding.tile_assignment().dim(dim) /
1056         sharding().tile_assignment().dim(dim));
1057   }
1058 
1059   // Add another dimension in tiling_dim_factors if target is partial replicate.
1060   if (target.ReplicateOnLastTileDim()) {
1061     tiling_dim_factors.emplace_back(
1062         target.tile_assignment().dimensions().back());
1063   }
1064 
1065   // 2. Get the padded_hlo, do right halo exchange if needed.
1066   auto padded_hlo = PadFromPartialReplicateShape(
1067       hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims,
1068       state_.collective_ops_creator, state_.next_channel_id,
1069       state_.partition_id, state_.b);
1070   if (!padded_hlo.has_value()) {
1071     return absl::nullopt;
1072   }
1073   // 3. Slice out the tile from replicate ones.
1074   auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding);
1075   // Since we are just slicing, we can just use the differences between the new
1076   // and old offsets in the full shape as the dynamic-slice offsets.
1077   auto padded_base_shape = shard_shape;
1078   for (int64_t i = 0; i < padded_base_shape.rank(); ++i) {
1079     padded_base_shape.set_dimensions(
1080         i, padded_base_shape.dimensions(i) *
1081                temp_target_sharding.tile_assignment().dim(i));
1082   }
1083   auto offsets = MakePartitionOffsets(padded_base_shape, temp_target_sharding,
1084                                       state_.partition_id, state_.b);
1085   auto old_offsets = MakePartitionOffsets(padded_base_shape, sharding(),
1086                                           state_.partition_id, state_.b);
1087   for (int64_t i = 0; i < offsets.size(); ++i) {
1088     offsets[i] = state_.b->AddInstruction(HloInstruction::CreateBinary(
1089         offsets[i]->shape(), HloOpcode::kSubtract, offsets[i], old_offsets[i]));
1090   }
1091   auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
1092       shard_shape, padded_hlo.value(), offsets, shard_shape.dimensions()));
1093   slice->set_sharding(temp_target_sharding);
1094   auto result = PartitionedHlo(slice, base_shape_, state_);
1095   // If temp_target_sharding's device assignment is different from target,
1096   // use collective permute to reshard.
1097   if (CanReshardWithCollectivePermute(temp_target_sharding, target)) {
1098     return result.ReshardWithCollectivePermute(target);
1099   }
1100   // If device assignment in temp_target_sharding and target are the same,
1101   // return result directly.
1102   return result;
1103 }
1104 
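// Broadcasts a single-device value to all partitions: non-source partitions
// contribute zeros, and a cross-partition all-reduce (sum) materializes the
// value everywhere.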
1105 PartitionedHlo PartitionedHlo::Broadcast() const {
1106   const Shape& shape = hlo_->shape();
1107   const HloSharding& sharding = hlo_->sharding();
1108   CHECK(sharding.HasUniqueDevice());
1109   CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
1110 
1111   auto src_core_id = state_.b->AddInstruction(HloInstruction::CreateConstant(
1112       LiteralUtil::CreateR0<uint32>(sharding.GetUniqueDevice())));
1113   Shape bcast_shape = ShapeUtil::ChangeElementType(shape, PRED);
1114   auto is_src_core = state_.b->AddInstruction(HloInstruction::CreateBroadcast(
1115       bcast_shape,
1116       state_.b->AddInstruction(HloInstruction::CreateCompare(
1117           ShapeUtil::MakeShape(PRED, {}), state_.partition_id, src_core_id,
1118           ComparisonDirection::kEq)),
1119       {}));
1120 
1121   auto zero = state_.b->AddInstruction(
1122       HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
1123   auto zero_bcast = state_.b->AddInstruction(
1124       HloInstruction::CreateBroadcast(shape, zero, {}));
1125   auto operand = state_.b->AddInstruction(HloInstruction::CreateTernary(
1126       shape, HloOpcode::kSelect, is_src_core, hlo(), zero_bcast));
1127   HloComputation* reduction =
1128       MakeBinaryAdd(shape.element_type(), state_.module);
1129 
1130   auto result = state_.collective_ops_creator.create_cross_partition_all_reduce(
1131       state_.b, operand, reduction, {}, NewChannel());
1132   result->set_sharding(HloSharding::Replicate());
1133   return PartitionedHlo(result, base_shape_, state_);
1134 }
1135 
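// Reshards by moving sharding between dimensions with all-to-all, one
// (source_dim, target_dim) pair at a time, then recursing on the remaining
// pairs; a final collective permute fixes the device order if needed.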
1136 PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
1137     const HloSharding& target,
1138     absl::Span<const std::pair<int64, int64>> source_target_dims) const {
1139   if (source_target_dims.empty()) {
1140     if (target == sharding()) {
1141       return *this;
1142     }
1143     // If the device order is different in the target, fix the order with
1144     // ReshardWithCollectivePermute.
1145     return ReshardWithCollectivePermute(target);
1146   }
1147 
1148   // Swap one pair of dimensions.
1149   int64_t source_dim = source_target_dims[0].first;
1150   int64_t target_dim = source_target_dims[0].second;
1151   const int64_t group_size = sharding().tile_assignment().dim(source_dim) /
1152                              sharding().tile_assignment().dim(target_dim);
1153 
1154   auto temp_target_tile = sharding().tile_assignment();
1155   {
1156     std::vector<int64> reshape_tile_dims(temp_target_tile.num_dimensions() + 2);
1157     int64_t i = 0;
1158     int64_t added_source_dim = -1;
1159     int64_t added_target_dim = -1;
1160     for (int64_t j = 0; j < temp_target_tile.num_dimensions(); ++j) {
1161       if (source_dim == j) {
1162         reshape_tile_dims[i] = temp_target_tile.dim(j) / group_size;
1163         reshape_tile_dims[++i] = group_size;
1164         added_source_dim = i;
1165       } else if (target_dim == j) {
1166         reshape_tile_dims[i] = temp_target_tile.dim(j);
1167         reshape_tile_dims[++i] = 1;
1168         added_target_dim = i;
1169       } else {
1170         reshape_tile_dims[i] = temp_target_tile.dim(j);
1171       }
1172       ++i;
1173     }
1174     temp_target_tile.Reshape(reshape_tile_dims);
1175     std::vector<int64> xpose_dims(temp_target_tile.num_dimensions());
1176     std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
1177     xpose_dims[added_source_dim] = added_target_dim;
1178     xpose_dims[added_target_dim] = added_source_dim;
1179     temp_target_tile = hlo_sharding_util::TransposeSharding(
1180                            HloSharding::Tile(temp_target_tile), xpose_dims)
1181                            .tile_assignment();
1182     auto temp_target_tile_dims = sharding().tile_assignment().dimensions();
1183     temp_target_tile_dims[source_dim] =
1184         sharding().tile_assignment().dim(target_dim);
1185     temp_target_tile_dims[target_dim] =
1186         sharding().tile_assignment().dim(source_dim);
1187     temp_target_tile.Reshape(temp_target_tile_dims);
1188   }
1189   auto temp_target = target.ReplicateOnLastTileDim()
1190                          ? HloSharding::PartialTile(temp_target_tile)
1191                          : HloSharding::Tile(temp_target_tile);
1192   auto padded_shape = hlo_->shape();
1193   padded_shape.set_dimensions(
1194       target_dim,
1195       RoundUpToNearest(padded_shape.dimensions(target_dim),
1196                        temp_target.tile_assignment().dim(target_dim)));
1197   auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);
1198 
1199   // The order of ids in the group must follow the temp_target sharding.
1200   std::vector<std::vector<int64>> groups(
1201       temp_target.tile_assignment().num_elements() / group_size);
1202   temp_target.tile_assignment().Each(
1203       [&](absl::Span<const int64> indices, int64_t device) {
1204         int64_t group_id = 0;
1205         for (int64_t dim = 0; dim < indices.size(); ++dim) {
1206           if (dim == target_dim) {
1207             group_id *= temp_target.tile_assignment().dim(dim) / group_size;
1208             group_id += indices[dim] / group_size;
1209           } else {
1210             group_id *= temp_target.tile_assignment().dim(dim);
1211             group_id += indices[dim];
1212           }
1213         }
1214         groups[group_id].push_back(device);
1215       });
1216 
1217   HloInstruction* result = nullptr;
1218 
1219   // Split along the split dimension (target_dim) of the all-to-all
1220   // output.
1221   std::vector<int64> dimensions;
1222   for (int64_t i = 0; i < base_shape_.rank(); ++i) {
1223     if (i == target_dim) {
1224       dimensions.push_back(group_size);
1225       dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size);
1226     } else {
1227       dimensions.push_back(padded_hlo->shape().dimensions(i));
1228     }
1229   }
1230   auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape(
1231       ShapeUtil::MakeShape(base_shape_.element_type(), dimensions),
1232       padded_hlo));
1233   // After the reshape, it is guaranteed to have at least 3 dimensions.
1234   auto all_to_all =
1235       state_.collective_ops_creator.create_cross_partition_all_to_all(
1236           state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim);
1237 
1238   // Reorder the split dimension of the reshape to be located in front of the
1239   // input partition dimension, so the two dimensions can be combined.
1240   int64_t new_source_dim =
1241       (target_dim < source_dim) ? source_dim + 1 : source_dim;
1242   std::vector<int64> permutation;
1243   for (int64_t i = 0; i < all_to_all->shape().rank(); ++i) {
1244     if (i == target_dim) {
1245       continue;
1246     }
1247     if (i == new_source_dim) {
1248       permutation.push_back(target_dim);
1249     }
1250     permutation.push_back(i);
1251   }
1252   auto transpose = state_.b->AddInstruction(HloInstruction::CreateTranspose(
1253       ShapeInference::InferTransposeShape(all_to_all->shape(), permutation)
1254           .ValueOrDie(),
1255       all_to_all, permutation));
1256 
1257   // Combine the split dimension and the input partition dimension.
1258   auto new_shape = ShapeInference::InferAllToAllShape(
1259                        padded_hlo->shape(), target_dim, source_dim, group_size)
1260                        .ValueOrDie();
1261   result = state_.b->AddInstruction(
1262       HloInstruction::CreateReshape(new_shape, transpose));
1263 
1264   const Shape result_shape = MakePartitionedShape(base_shape_, temp_target);
1265   if (result_shape != result->shape()) {
1266     result = state_.b->AddInstruction(HloInstruction::CreateSlice(
1267         result_shape, result, std::vector<int64>(result_shape.rank(), 0),
1268         result_shape.dimensions(), std::vector<int64>(result_shape.rank(), 1)));
1269   }
1270   result->set_sharding(temp_target);
1271   auto remaining_source_target_dims = source_target_dims;
1272   remaining_source_target_dims.remove_prefix(1);
1273   return PartitionedHlo(result, base_shape_, state_)
1274       .ReshardWithAllToAll(target, remaining_source_target_dims);
1275 }
1276 
1277 absl::optional<PartitionedHlo>
1278 PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
1279   bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
1280   const auto& partial_replicate_sharding =
1281       source_is_partial_replicate ? sharding() : target;
1282   // If neither the source nor the target is partial replicate, return null.
1283   if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
1284     return absl::nullopt;
1285   }
1286   const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
1287   // If both source and target are partial replicate, this case should already
1288   // be handled by Reshard with AllToAll.
1289   if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
1290     return absl::nullopt;
1291   }
1292 
1293   // Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
1294   // to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
1295   // the last tile dim will be replicated first before the all-to-all.
1296   // Or resharding from
1297   // sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
1298   // to sharding={devices=[2,3]0,1,2,3,4,5}, where
1299   // the last tile dim will be sharded after all-to-all.
1300   const int num_replicas =
1301       partial_replicate_sharding.tile_assignment().dimensions().back();
1302   if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
1303        partial_replicate_sharding.tile_assignment().num_dimensions()) ||
1304       (partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
1305     return absl::nullopt;
1306   }
1307   int to_replicate_dim = -1;
1308   for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
1309        --i) {
1310     if (tile_sharding.tile_assignment().dim(i) > 1 &&
1311         (to_replicate_dim == -1)) {
1312       if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
1313         return absl::nullopt;
1314       }
1315       to_replicate_dim = i;
1316     }
1317 
1318     if (tile_sharding.tile_assignment().dim(i) !=
1319         partial_replicate_sharding.tile_assignment().dim(i + 1)) {
1320       return absl::nullopt;
1321     }
1322   }
1323 
1324   if (to_replicate_dim == -1) {
1325     return absl::nullopt;
1326   }
1327 
1328   // Check if core assignments for source and the target are the same.
1329   auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
1330   reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
1331   if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
1332     return absl::nullopt;
1333   }
1334 
1335   auto tmp_tile_assignment = tile_sharding.tile_assignment();
1336   auto tmp_tile_assignment_dimensions =
1337       tile_sharding.tile_assignment().dimensions();
1338   tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
1339   tmp_tile_assignment_dimensions.push_back(num_replicas);
1340   tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
1341   auto tmp_partial_replicate_sharding =
1342       HloSharding::PartialTile(tmp_tile_assignment);
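  // Illustrative example: between {devices=[2,3]0,1,2,3,4,5} and
  // {devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, num_replicas is 3,
  // to_replicate_dim is 1, and the intermediate sharding built here is
  // {devices=[2,1,3]0,1,2,3,4,5 last_tile_dim_replicate}, which bridges the
  // tiled and partially replicated forms so the remaining step can be done
  // with all-to-all resharding.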
1343 
1344   if (source_is_partial_replicate) {
1345     if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1346             sharding(), tmp_partial_replicate_sharding)) {
1347       auto partitioned_hlo =
1348           ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
1349       return partitioned_hlo.Reshard(target);
1350     }
1351   } else {
1352     auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
1353 
1354     if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1355             partitioned_hlo.sharding(), target)) {
1356       return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
1357     }
1358   }
1359 
1360   return absl::nullopt;
1361 }
1362 
1363 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
1364     const HloSharding& target) const {
1365   CHECK(CanReshardWithCollectivePermute(sharding(), target))
1366       << sharding().ToString() << " to " << target.ToString();
1367   if (auto broadcast_dims = state_.b->BroadcastDimsForCreatedHlo(hlo())) {
1368     if (!(*broadcast_dims)->empty()) {
1369       // If hlo() has broadcast dims, check if data is already the same between
1370       // source/destination pairs.
1371       std::vector<int64> broadcast_dims_vector;
1372       for (int64_t i = 0; i < hlo()->shape().rank(); ++i) {
1373         if ((*broadcast_dims)->contains(i)) {
1374           broadcast_dims_vector.push_back(i);
1375         }
1376       }
1377       if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1378               sharding(), broadcast_dims_vector) ==
1379           hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1380               target, broadcast_dims_vector)) {
1381         auto copy = state_.b->AddInstruction(HloInstruction::CreateUnary(
1382             hlo()->shape(), HloOpcode::kCopy, hlo()));
1383         copy->set_sharding(target);
1384         return PartitionedHlo(copy, base_shape_, state_);
1385       }
1386     }
1387   }
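  // Each device sends its shard to the device that owns the same tile indices
  // under the target sharding, so the pairs below line up the source and
  // target tile assignments element by element.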
1388   std::vector<std::pair<int64, int64>> src_dst_pairs;
1389   sharding().tile_assignment().Each(
1390       [&](absl::Span<const int64> indices, int64_t src_device) {
1391         int64_t dst_device = target.tile_assignment()(indices);
1392         src_dst_pairs.emplace_back(src_device, dst_device);
1393       });
1394   auto cp =
1395       state_.collective_ops_creator.create_cross_partition_collective_permute(
1396           state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++);
1397   cp->set_sharding(target);
1398   return PartitionedHlo(cp, base_shape_, state_);
1399 }
1400 
1401 SpmdPartitioningVisitor::SpmdPartitioningVisitor(
1402     HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
1403     const SPMDCollectiveOpsCreator& collective_ops_creator,
1404     int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options,
1405     SpmdPartitioner* partitioner)
1406     : changed_(false),
1407       module_(computation->parent()),
1408       num_partitions_(num_partitions),
1409       num_replicas_(num_replicas),
1410       collective_ops_creator_(collective_ops_creator),
1411       next_channel_id_(next_channel_id),
1412       b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)),
1413       partition_id_(collective_ops_creator_.create_partition_id(&b_)),
1414       logger_(logger),
1415       options_(std::move(options)),
1416       partitioner_(partitioner) {}
1417 
1418 PartitionedHlo::PartitioningState
1419 SpmdPartitioningVisitor::MakePartitioningState() {
1420   PartitionedHlo::PartitioningState state;
1421   state.b = &b_;
1422   state.module = module_;
1423   state.num_replicas = num_replicas_;
1424   state.partition_id = partition_id_;
1425   state.collective_ops_creator = collective_ops_creator_;
1426   state.next_channel_id = next_channel_id_;
1427   state.reshard_cache = &reshard_cache_;
1428   state.partitioner = partitioner_;
1429   if (!device_groups_.empty()) {
1430     return CreatePerGroupPartitioningState(state, device_groups_, &b_);
1431   }
1432   return state;
1433 }
1434 
1435 Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
1436   if (hlo->HasSideEffect()) {
1437     return Unimplemented("Side-effect ops cannot be replicated: %s",
1438                          hlo->ToString());
1439   }
1440 
1441   if (hlo->IsElementwise() && hlo->operand_count() > 0) {
1442     return HandleElementwise(hlo);
1443   }
1444 
1445   if (!hlo->sharding().IsTileMaximal()) {
1446     VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):"
1447             << hlo->ToString();
1448     for (int64_t i = 0; i < hlo->operand_count(); ++i) {
1449       VLOG(1) << "  operand " << i
1450               << " sharding:" << hlo->operand(i)->sharding().ToString();
1451     }
1452   }
1453 
1454   HloSharding sharding = hlo->sharding().HasUniqueDevice()
1455                              ? hlo->sharding()
1456                              : HloSharding::Replicate();
1457 
1458   // If the instruction cannot be partitioned, replicate the instruction unless
1459   // the instruction has a side effect.
1460   std::vector<HloInstruction*> new_operands;
1461   for (HloInstruction* operand : hlo->operands()) {
1462     new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
1463   }
1464   auto clone =
1465       b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands));
1466   clone->set_sharding(sharding);
1467   SetPartitionedHlo(hlo,
1468                     PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
1469                         .Reshard(hlo->sharding()));
1470   return Status::OK();
1471 }
1472 
1473 Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
1474   visiting_hlo_ = hlo;
1475   b_.set_visiting_hlo(hlo);
1476   // Temporarily replace manual sharding with one-device sharding so that the
1477   // partitioner will not change the HLOs.
1478   auto manual_to_onedevice = [&](const Shape& shape,
1479                                  const HloSharding& sharding) {
1480     // If a tuple's elements are all manual, then sharding.IsManual() is true,
1481     // so we test whether it is a tuple first.
1482     if (sharding.IsTuple()) {
1483       std::vector<HloSharding> subshardings = sharding.tuple_elements();
1484       for (HloSharding& subsharding : subshardings) {
1485         if (subsharding.IsManual()) {
1486           subsharding = HloSharding::AssignDevice(0);
1487         }
1488       }
1489       return HloSharding::Tuple(shape, subshardings);
1490     }
1491     if (sharding.IsManual()) {
1492       return HloSharding::AssignDevice(0);
1493     }
1494     return sharding;
1495   };
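  // A manually sharded op already carries per-partition shapes, so pinning it
  // to a single device keeps the partitioner from rewriting it; the original
  // shardings are restored in Postprocess().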
1496 
1497   if (hlo->opcode() != HloOpcode::kConditional &&
1498       hlo->opcode() != HloOpcode::kWhile && hlo->opcode() != HloOpcode::kRng) {
1499     const bool has_manual_sharding =
1500         hlo->sharding().IsManual() ||
1501         (hlo->sharding().IsTuple() &&
1502          absl::c_any_of(
1503              hlo->sharding().tuple_elements(),
1504              [](const HloSharding& sharding) { return sharding.IsManual(); }));
1505     if (has_manual_sharding && !hlo->IsCustomCall("SPMDFullToShardShape")) {
1506       visiting_hlo_sharding_ = hlo->sharding();
1507       hlo->set_sharding(
1508           manual_to_onedevice(hlo->shape(), *visiting_hlo_sharding_));
1509 
1510       visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
1511       for (auto operand : hlo->operands()) {
1512         visiting_hlo_operand_shardings_.push_back(operand->sharding());
1513         operand->set_sharding(
1514             manual_to_onedevice(operand->shape(), operand->sharding()));
1515         GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
1516       }
1517     } else if (hlo->sharding().IsManualSubgroup()) {
1518       GroupedSharding group_sharding =
1519           GetManualSubgroupSharding(hlo->sharding());
1520       // Update sharding.
1521       visiting_hlo_sharding_ = hlo->sharding();
1522       hlo->set_sharding(group_sharding.sharding);
1523       // Update device_groups and num_partitions.
1524       device_groups_ = group_sharding.device_groups;
1525       visiting_num_partitions_ = num_partitions_;
1526       num_partitions_ = num_partitions_ / group_sharding.device_groups.size();
1527 
1528       // Update sharding for the operands.
1529       visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
1530       visiting_state_.reserve(hlo->operand_count());
1531       for (auto operand : hlo->operands()) {
1532         visiting_hlo_operand_shardings_.push_back(operand->sharding());
1533         GroupedSharding group_sharding =
1534             GetManualSubgroupSharding(operand->sharding());
1535         operand->set_sharding(group_sharding.sharding);
1536         GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
1537         auto old_state = GetPartitionedHlo(operand).state();
1538         visiting_state_.push_back(old_state);
1539         auto group_state = CreatePerGroupPartitioningState(
1540             old_state, group_sharding.device_groups, &b_);
1541         GetPartitionedHlo(operand).set_state(group_state);
1542       }
1543     }
1544   }
1545   return Status::OK();
1546 }
1547 
1548 Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) {
1549   logger_->RegisterLogEntry(GetPartitionedHlo(hlo).hlo(),
1550                             b_.derived_instructions(hlo));
1551   visiting_hlo_ = nullptr;
1552   b_.set_visiting_hlo(nullptr);
1553   // Revert fake one-device shardings for manually partitioned ops.
1554   if (visiting_hlo_sharding_) {
1555     hlo->set_sharding(*visiting_hlo_sharding_);
1556     GetPartitionedHlo(hlo).hlo()->set_sharding(*visiting_hlo_sharding_);
1557     for (int64_t i = 0; i < hlo->operand_count(); ++i) {
1558       auto operand = hlo->mutable_operand(i);
1559       operand->set_sharding(visiting_hlo_operand_shardings_[i]);
1560       GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
1561     }
1562     visiting_hlo_sharding_.reset();
1563     visiting_hlo_operand_shardings_.clear();
1564   }
1565 
1566   if (!device_groups_.empty()) {
1567     device_groups_.clear();
1568     GetPartitionedHlo(hlo).set_state(MakePartitioningState());
1569     num_partitions_ = *visiting_num_partitions_;
1570     visiting_num_partitions_.reset();
1571   }
1572 
1573   if (!visiting_state_.empty()) {
1574     for (int64_t i = 0; i < hlo->operand_count(); ++i) {
1575       const HloInstruction* operand = hlo->operand(i);
1576       GetPartitionedHlo(operand).set_state(visiting_state_[i]);
1577     }
1578     visiting_state_.clear();
1579   }
1580 
1581   return Status::OK();
1582 }
1583 
1584 Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
1585   std::vector<HloInstruction*> new_operands;
1586   for (HloInstruction* operand : hlo->operands()) {
1587     new_operands.push_back(
1588         GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo());
1589   }
1590   SetPartitionedHlo(hlo, [&] {
1591     return b_.AddInstruction(hlo->CloneWithNewOperands(
1592         MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
1593   });
1594   return Status::OK();
1595 }
1596 
1597 Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
1598   const HloSharding& sharding = hlo->sharding();
1599   if (sharding.IsTileMaximal()) {
1600     return DefaultAction(hlo);
1601   }
1602 
1603   const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
1604   const int64_t dimension = hlo->concatenate_dimension();
1605   if (sharding.tile_assignment().dim(dimension) == 1) {
1606     std::vector<HloInstruction*> new_operands;
1607     for (HloInstruction* operand : hlo->operands()) {
1608       new_operands.push_back(
1609           GetPartitionedHlo(operand).Reshard(sharding).hlo());
1610     }
1611     SetPartitionedHlo(hlo, [&] {
1612       return b_.AddInstruction(
1613           hlo->CloneWithNewOperands(shard_shape, new_operands));
1614     });
1615     return Status::OK();
1616   }
1617 
1618   // If the concatenate dimension is along one of the partitioned dimensions,
1619   // allocate the full output shape, each partition updates its owned region,
1620   // all-reduce across partitions, and then slice its output region.
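  // Illustrative example: concatenating two [4] operands into an [8] result
  // over 2 partitions along dim 0. temp_output is an [8] buffer of zeros on
  // each partition; partition p writes its [2] shard of operand 0 at offset
  // 2*p and its shard of operand 1 at offset 4 + 2*p. The all-reduce sums the
  // per-partition buffers into the full concatenation, and each partition
  // then dynamic-slices out its own [4] shard of the result.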
1621 
1622   // temp_output_shape is the output shape where the concatenate dimension is
1623   // changed to the full dimension size, padded up to a multiple of the shard count.
1624   auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding);
1625   auto last_operand_padded_shape =
1626       MakePartitionedShape(hlo->operands().back()->shape(), sharding);
1627   // If the last operand has more padding than the temp_output padding, we need
1628   // to add extra padding to avoid the dynamic-update-slice going out of bounds.
1629   int last_operand_padding =
1630       last_operand_padded_shape.dimensions(dimension) *
1631           sharding.tile_assignment().dim(dimension) -
1632       hlo->operands().back()->shape().dimensions(dimension);
1633   int temp_output_padding = temp_output_shape.dimensions(dimension) *
1634                                 sharding.tile_assignment().dim(dimension) -
1635                             hlo->shape().dimensions(dimension);
1636   int padding_for_last_operand =
1637       last_operand_padding < temp_output_padding
1638           ? 0
1639           : last_operand_padding - temp_output_padding;
1640   temp_output_shape.set_dimensions(
1641       dimension, temp_output_shape.dimensions(dimension) *
1642                          sharding.tile_assignment().dim(dimension) +
1643                      padding_for_last_operand);
1644   auto temp_output = CreateZero(temp_output_shape, &b_);
1645 
1646   // Offset of each operand along the concatenate dimension.
1647   int64_t offset = 0;
1648   auto state = MakePartitioningState();
1649   for (HloInstruction* operand : hlo->operands()) {
1650     auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo();
1651     std::vector<HloInstruction*> start_indices(
1652         hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant(
1653                                  LiteralUtil::Zero(S32))));
1654     start_indices[dimension] =
1655         MultiplyAddDivideOffsetCalculation(
1656             spmd_operand->shape().dimensions(dimension), offset, 1)
1657             .Calculate(MakeTiledPartitionOrdinals(sharding, state.partition_id,
1658                                                   &b_)[dimension],
1659                        &b_);
1660     temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1661         temp_output_shape, temp_output, spmd_operand, start_indices));
1662     offset += operand->shape().dimensions(dimension);
1663   }
1664   std::vector<int64> non_concat_dims;
1665   non_concat_dims.reserve(hlo->shape().rank() - 1);
1666   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
1667     if (i != dimension) {
1668       non_concat_dims.push_back(i);
1669     }
1670   }
1671   auto grouped = GroupShardingOnDims(sharding, non_concat_dims);
1672   auto per_group_partitioner_state =
1673       CreatePerGroupPartitioningState(state, grouped.device_groups, &b_);
1674   auto all_reduce = per_group_partitioner_state.collective_ops_creator
1675                         .create_cross_partition_all_reduce(
1676                             &b_, temp_output,
1677                             MakeBinaryAdd(hlo->shape().element_type(), module_),
1678                             {}, NewChannel());
1679   SetPartitionedHlo(hlo, [&] {
1680     auto start_indices = MakeTiledPartitionOrdinals(
1681         grouped.sharding, per_group_partitioner_state.partition_id, &b_);
1682     start_indices[dimension] = MultiplyAddDivideOffsetCalculation(
1683                                    shard_shape.dimensions(dimension), 0, 1)
1684                                    .Calculate(start_indices[dimension], &b_);
1685     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
1686         shard_shape, all_reduce, start_indices, shard_shape.dimensions()));
1687   });
1688 
1689   return Status::OK();
1690 }
1691 
1692 Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
1693   const HloSharding& sharding = hlo->sharding();
1694   if (sharding.IsTileMaximal()) {
1695     return DefaultAction(hlo);
1696   }
1697 
1698   auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
1699 
1700   const int64_t rank = hlo->shape().rank();
1701   // Create a window config to represent the slice.
1702   Window window;
1703   for (int64_t i = 0; i < rank; ++i) {
1704     WindowDimension* dim = window.add_dimensions();
1705     dim->set_size(1);
1706     dim->set_stride(hlo->slice_strides(i));
1707     dim->set_window_dilation(1);
1708     dim->set_window_reversal(false);
1709     dim->set_padding_low(-hlo->slice_starts(i));
1710     dim->set_padding_high(hlo->slice_limits(i) -
1711                           operand.base_shape().dimensions(i));
1712     dim->set_base_dilation(1);
1713   }
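  // Illustrative example: slicing [2:10) with stride 2 from a dimension of
  // size 12 becomes a size-1 window with stride 2, padding_low -2 and
  // padding_high 10 - 12 = -2; the negative padding trims the elements
  // outside the slice, and ReshardAsWindowedInput handles any data that
  // crosses shard boundaries.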
1714 
1715   auto reshard_operand = operand.ReshardAsWindowedInput(
1716       window, sharding,
1717       CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
1718       /*mask_invalid_region=*/false);
1719   if (!reshard_operand.has_value()) {
1720     return DefaultAction(hlo);
1721   }
1722   TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
1723   const Shape& operand_shape = reshard_operand->sharded_input->shape();
1724 
1725   std::vector<int64> start_indices(rank);
1726   std::vector<int64> limit_indices(rank);
1727   const std::vector<int64>& strides = hlo->slice_strides();
1728   bool need_slice = false;
1729   for (int64_t i = 0; i < rank; ++i) {
1730     auto dim = reshard_operand->shard_window.dimensions(i);
1731     start_indices[i] = -dim.padding_low();
1732     limit_indices[i] = operand_shape.dimensions(i) + dim.padding_high();
1733     if (start_indices[i] != 0 || strides[i] != 1 ||
1734         limit_indices[i] != operand_shape.dimensions(i)) {
1735       need_slice = true;
1736     }
1737   }
1738 
1739   SetPartitionedHlo(hlo, [&] {
1740     if (need_slice) {
1741       auto shard_shape = MakePartitionedShape(hlo->shape(), sharding);
1742       return b_.AddInstruction(HloInstruction::CreateSlice(
1743           shard_shape, reshard_operand->sharded_input, start_indices,
1744           limit_indices, strides));
1745     }
1746     auto data = reshard_operand->sharded_input;
1747     // Create a copy so that it will not share the resharding cache.
1748     return b_.AddInstruction(
1749         HloInstruction::CreateUnary(data->shape(), HloOpcode::kCopy, data));
1750   });
1751 
1752   return Status::OK();
1753 }
1754 
1755 Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
1756   HloSharding sharding = hlo->sharding();
1757   if (sharding.HasUniqueDevice()) {
1758     return DefaultAction(hlo);
1759   }
1760   // Special handling for sort in TopK when the first operand is partitioned at
1761   // the sort dimension.
1762   auto k = GetKValueInTopKWhenPartitionSortDim(hlo);
1763   if (k.has_value()) {
1764     // When the first operand is partitioned at the sort dimension:
1765     //   1. Partition sort computation to different partitions;
1766     //   2. Slice TopK value and index from different partitions;
1767     //   3. Gather and replicate value and index from different partitions,
1768     //      the shape of replicated value and index will be
1769     //      [batch_size, ..., partition_count * k, ...];
1770     //   4. Final sort uses replicated value and index from different partitions
1771     //      as input.
1772     // GetTupleElement and Slice after the non-partitioned sort won't change
1773     // at this point, as HandleGetTupleElement and HandleSlice will update them.
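    // Illustrative example: a [batch, 1024] input sharded 4 ways on the sort
    // dimension with k = 8. Each partition sorts its [batch, 256] shard and
    // slices out its local top 8; the 4 * 8 = 32 candidates are then
    // replicated across partitions, and a final replicated sort over
    // [batch, 32] produces the global top-k.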
1774     HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
1775     const int64_t sort_dim = sort->sort_dimension();
1776     auto input = hlo->operand(0);
1777     auto index = hlo->operand(1);
1778     const HloSharding& input_sharding = input->sharding();
1779     const int64_t partition_count =
1780         input_sharding.tile_assignment().dim(sort_dim);
1781     const int64_t input_size = input->shape().dimensions(sort_dim);
1782     const int64_t per_partition_size = CeilOfRatio(input_size, partition_count);
1783     const auto element_type = input->shape().element_type();
1784     const auto index_type = index->shape().element_type();
1785 
1786     // Partition and pad input and index.
1787     // Pad input with minimal value.
1788     auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
1789         CreateFirstWithType(element_type, &b_));
1790     // Pad index with max value.
1791     auto partitioned_index =
1792         GetPartitionedHlo(index)
1793             .Reshard(input_sharding)
1794             .PadWithValue(CreateLastWithType(index_type, &b_));
1795 
1796     // Each partition needs to do TopK separately, thus the base shape
1797     // becomes the padded shape.
1798     std::vector<int64> replicated_dimensions(
1799         input->shape().dimensions().begin(), input->shape().dimensions().end());
1800     replicated_dimensions[sort_dim] = per_partition_size * partition_count;
1801     const Shape replicated_shape = ShapeUtil::MakeTupleShape(
1802         {ShapeUtil::MakeShape(element_type, replicated_dimensions),
1803          ShapeUtil::MakeShape(index_type, replicated_dimensions)});
1804 
1805     // Partition original topk to different shards.
1806     auto topk_sharding =
1807         input_sharding.GetTupleSharding(replicated_shape).ValueOrDie();
1808     auto shard_shape = MakePartitionedShape(replicated_shape, topk_sharding);
1809     auto topk = b_.AddInstruction(hlo->CloneWithNewOperands(
1810         shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()}));
1811 
1812     // Get value and index from the first sort.
1813     HloInstruction* value_gte =
1814         b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1815             topk->shape().tuple_shapes(0), topk, 0));
1816     HloInstruction* index_gte =
1817         b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1818             topk->shape().tuple_shapes(1), topk, 1));
1819 
1820     // Slice top K value from the first partitioned sort.
1821     replicated_dimensions[sort_dim] = k.value() * partition_count;
1822     auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value());
1823     slice_input->set_sharding(input_sharding);
1824     PartitionedHlo partitioned_slice_input(
1825         slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions),
1826         MakePartitioningState());
1827     // Reshard value to be replicated.
1828     auto replicated_slice_input =
1829         partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo();
1830 
1831     // Slice top K index from the first partitioned sort.
1832     auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value());
1833     slice_index->set_sharding(input_sharding);
1834     PartitionedHlo partitioned_slice_index(
1835         slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions),
1836         MakePartitioningState());
1837     // Reshard index to be replicated.
1838     auto replicated_slice_index =
1839         partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo();
1840 
1841     // Create a replicated sort to do TopK; the inputs are the value and index
1842     // pairs from all the partitions.
1843     const Shape final_topk_shape = ShapeUtil::MakeTupleShape(
1844         {ShapeUtil::MakeShape(element_type, replicated_dimensions),
1845          ShapeUtil::MakeShape(index_type, replicated_dimensions)});
1846     HloInstruction* final_sort = b_.AddInstruction(HloInstruction::CreateSort(
1847         final_topk_shape, sort_dim,
1848         {replicated_slice_input, replicated_slice_index}, sort->to_apply(),
1849         sort->is_stable()));
1850     final_sort->set_sharding(HloSharding::Replicate()
1851                                  .GetTupleSharding(final_sort->shape())
1852                                  .ValueOrDie());
1853     PartitionedHlo replicated_sort(final_sort, final_sort->shape(),
1854                                    MakePartitioningState());
1855     SetPartitionedHlo(hlo, replicated_sort.Reshard(hlo->sharding()));
1856 
1857     return Status::OK();
1858   }
1859 
1860   if (hlo->shape().IsTuple()) {
1861     // Check that all elements are sharded in the same way.
1862     if (hlo->shape().tuple_shapes_size() == 0) {
1863       return DefaultAction(hlo);
1864     }
1865     sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
1866     for (int64_t i = 1; i < hlo->operand_count(); ++i) {
1867       if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) {
1868         return DefaultAction(hlo);
1869       }
1870     }
1871   }
1872   if (sharding.IsTileMaximal()) {
1873     return DefaultAction(hlo);
1874   }
1875   for (int64_t dim : hlo->dimensions()) {
1876     if (sharding.tile_assignment().dim(dim) > 1) {
1877       return DefaultAction(hlo);
1878     }
1879   }
1880   // Reshard operands to the same sharding as the output.
1881   std::vector<HloInstruction*> new_operands;
1882   for (HloInstruction* operand : hlo->operands()) {
1883     new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
1884   }
1885   SetPartitionedHlo(hlo, [&] {
1886     return b_.AddInstruction(hlo->CloneWithNewOperands(
1887         MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
1888   });
1889   return Status::OK();
1890 }
1891 
1892 Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) {
1893   const HloSharding& sharding = hlo->sharding();
1894   if (sharding.IsTileMaximal()) {
1895     return DefaultAction(hlo);
1896   }
1897 
1898   std::vector<int64> inverse_dimensions(hlo->shape().rank());
1899   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
1900     inverse_dimensions[hlo->dimensions(i)] = i;
1901   }
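  // inverse_dimensions maps each operand dimension to the output dimension it
  // ends up in, so transposing the output sharding by it yields the matching
  // operand sharding (e.g. for a 2D transpose the tile dims simply swap).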
1902   auto desired_operand_sharding =
1903       hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions);
1904 
1905   auto operand = GetPartitionedHlo(hlo->operand(0))
1906                      .Reshard(desired_operand_sharding)
1907                      .hlo();
1908   SetPartitionedHlo(hlo, [&] {
1909     return b_.AddInstruction(hlo->CloneWithNewOperands(
1910         MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand}));
1911   });
1912   return Status::OK();
1913 }
1914 
1915 Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
1916   const HloSharding& sharding = hlo->sharding();
1917   if (sharding.IsTileMaximal()) {
1918     return DefaultAction(hlo);
1919   }
1920 
1921   auto operand = GetPartitionedHlo(hlo->operand(0));
1922   // To derive an aligned sharding for the operand, the output shape acts as the
1923   // source and the operand shape as the target.
1924   absl::optional<HloSharding> desired_operand_sharding =
1925       hlo_sharding_util::ReshapeSharding(hlo->shape(), hlo->operand(0)->shape(),
1926                                          hlo->sharding());
1927   if (desired_operand_sharding.has_value()) {
1928     auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
1929     SetPartitionedHlo(hlo, [&] {
1930       return b_.AddInstruction(hlo->CloneWithNewOperands(
1931           MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo}));
1932     });
1933     return Status::OK();
1934   }
1935   absl::optional<HloSharding> desired_output_sharding =
1936       hlo_sharding_util::ReshapeSharding(hlo->operand(0)->shape(), hlo->shape(),
1937                                          operand.sharding());
1938   if (desired_output_sharding.has_value()) {
1939     auto reshape = b_.AddInstruction(hlo->CloneWithNewOperands(
1940         MakePartitionedShape(hlo->shape(), *desired_output_sharding),
1941         {operand.hlo()}));
1942     reshape->set_sharding(*desired_output_sharding);
1943     SetPartitionedHlo(hlo, [&] {
1944       return PartitionedHlo(reshape, hlo->shape(), MakePartitioningState())
1945           .Reshard(sharding)
1946           .hlo();
1947     });
1948     return Status::OK();
1949   }
1950 
1951   // Check that the operand sharding and the output sharding are both tiled or
1952   // both partially replicated; if partially replicated, the replication counts must match.
1953   if (operand.sharding().ReplicateOnLastTileDim() !=
1954           sharding.ReplicateOnLastTileDim() ||
1955       (sharding.ReplicateOnLastTileDim() &&
1956        (operand.sharding().tile_assignment().dimensions().back() !=
1957         sharding.tile_assignment().dimensions().back()))) {
1958     return DefaultAction(hlo);
1959   }
1960 
1961   // Try to use halo exchange for certain split-dim/merge-dims cases.
1962   // ReshapeSharding failed in these cases probably due to uneven partitioning,
1963   // where halo exchange could help. Specifically we check the following
1964   // conditions to detect supported cases:
1965   // 1) Both input and output are partitioned on one dimension.
1966   // 2) The combined size of dimensions before the partitioned dimension are the
1967   // same on input and output. This means we don't need to consider the major
1968   // dimensions.
1969   // 3) Let A = the input size on the partitioned dimension, and
1970   //        B = the output size on the partitioned dimension; then
1971   //    either A % B == 0 (split dim) or B % A == 0 (merge dims).
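  // Illustrative split-dim example: reshaping [6] to [2,3] with both sides
  // sharded 2 ways or more on dimension 0 can partition the input unevenly,
  // so the input is padded and halo-exchanged until every partition holds a
  // contiguous (possibly padded) block of split_factor * output_shard_size
  // elements, which then reshapes locally into its output shard.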
1972   auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding());
1973   auto maybe_output_sharded_dim = UniqueTiledDim(sharding);
1974   if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) {
1975     return DefaultAction(hlo);
1976   }
1977   int64_t input_sharded_dim = *maybe_input_sharded_dim;
1978   int64_t output_sharded_dim = *maybe_output_sharded_dim;
1979   // Check that the major dims before the sharded dim have the same total size
1980   // for input and output.
1981   int64_t input_major_dims_size = 1;
1982   for (int64_t i = 0; i < input_sharded_dim; ++i) {
1983     input_major_dims_size *= operand.base_shape().dimensions(i);
1984   }
1985   int64_t output_major_dims_size = 1;
1986   for (int64_t i = 0; i < output_sharded_dim; ++i) {
1987     output_major_dims_size *= hlo->shape().dimensions(i);
1988   }
1989   if (input_major_dims_size != output_major_dims_size) {
1990     return DefaultAction(hlo);
1991   }
1992   // Fix potential device ordering mismatch in tile assignment.
1993   Array<int64> new_input_tile_assignment = sharding.tile_assignment();
1994   new_input_tile_assignment.Reshape(
1995       operand.sharding().tile_assignment().dimensions());
1996   auto aligned_sharding =
1997       sharding.ReplicateOnLastTileDim()
1998           ? HloSharding::PartialTile(new_input_tile_assignment)
1999           : HloSharding::Tile(new_input_tile_assignment);
2000   operand = operand.Reshard(aligned_sharding);
2001   auto replication_count = sharding.ReplicateOnLastTileDim()
2002                                ? sharding.tile_assignment().dimensions().back()
2003                                : 1;
2004 
2005   int64_t input_dim_size = operand.base_shape().dimensions(input_sharded_dim);
2006   int64_t output_dim_size = hlo->shape().dimensions(output_sharded_dim);
2007   auto input_shard_shape =
2008       MakePartitionedShape(operand.base_shape(), operand.sharding());
2009   auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding);
2010   if (input_dim_size % output_dim_size == 0) {
2011     // Split dim.
2012     int64_t split_factor = input_dim_size / output_dim_size;
2013     int64_t output_shard_size =
2014         output_shard_shape.dimensions(output_sharded_dim);
2015     // Use halo exchange to fix misaligned data.
2016     Window window;
2017     for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2018       WindowDimension* dim = window.add_dimensions();
2019       dim->set_size(1);
2020       dim->set_stride(1);
2021       dim->set_window_dilation(1);
2022       dim->set_window_reversal(false);
2023       dim->set_base_dilation(1);
2024       dim->set_padding_low(0);
2025       if (i == input_sharded_dim) {
2026         dim->set_padding_high(output_shard_size * split_factor *
2027                                   num_partitions_ / replication_count -
2028                               input_dim_size);
2029       } else {
2030         dim->set_padding_high(0);
2031       }
2032     }
2033 
2034     auto reshard_operand = operand.ReshardAsWindowedInput(
2035         window, operand.sharding(),
2036         CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
2037         /*mask_invalid_region=*/false);
2038     if (!reshard_operand.has_value()) {
2039       return DefaultAction(hlo);
2040     }
2041     TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
2042     CHECK_EQ(
2043         reshard_operand->sharded_input->shape().dimensions(input_sharded_dim),
2044         output_shard_size * split_factor);
2045     SetPartitionedHlo(hlo, [&] {
2046       // Do a local reshape.
2047       return b_.AddInstruction(HloInstruction::CreateReshape(
2048           output_shard_shape, reshard_operand->sharded_input));
2049     });
2050     return Status::OK();
2051   } else if (output_dim_size % input_dim_size == 0) {
2052     // Merge dims.
2053     int64_t merge_factor = output_dim_size / input_dim_size;
2054     // First reshape locally. (The sharded dimension could include padded data.)
2055     auto tmp_shard_shape = output_shard_shape;
2056     tmp_shard_shape.set_dimensions(
2057         output_sharded_dim,
2058         input_shard_shape.dimensions(input_sharded_dim) * merge_factor);
2059     auto tmp_reshape = b_.AddInstruction(
2060         HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo()));
2061     tmp_reshape->set_sharding(hlo->sharding());
2062     auto tmp_full_shape = tmp_shard_shape;
2063     tmp_full_shape.set_dimensions(
2064         output_sharded_dim, tmp_shard_shape.dimensions(output_sharded_dim) *
2065                                 num_partitions_ / replication_count);
2066     auto tmp_output =
2067         PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState());
2068 
2069     // Use halo exchange to fix misaligned data.
2070     Window window;
2071     for (int64_t i = 0; i < tmp_shard_shape.rank(); ++i) {
2072       WindowDimension* dim = window.add_dimensions();
2073       dim->set_size(1);
2074       dim->set_stride(1);
2075       dim->set_window_dilation(1);
2076       dim->set_window_reversal(false);
2077       dim->set_base_dilation(1);
2078       dim->set_padding_low(0);
2079       if (i == output_sharded_dim) {
2080         dim->set_padding_high(output_dim_size -
2081                               tmp_shard_shape.dimensions(output_sharded_dim) *
2082                                   num_partitions_ / replication_count);
2083       } else {
2084         dim->set_padding_high(0);
2085       }
2086     }
2087 
2088     auto reshard_output = tmp_output.ReshardAsWindowedInput(
2089         window, sharding,
2090         CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
2091         /*mask_invalid_region=*/false);
2092     if (!reshard_output.has_value()) {
2093       return DefaultAction(hlo);
2094     }
2095     TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value());
2096     CHECK_EQ(
2097         reshard_output->sharded_input->shape().dimensions(output_sharded_dim),
2098         output_shard_shape.dimensions(output_sharded_dim));
2099     SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; });
2100     return Status::OK();
2101   }
2102   return DefaultAction(hlo);
2103 }
2104 
2105 Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) {
2106   const HloSharding& sharding = hlo->sharding();
2107   if (sharding.IsTileMaximal()) {
2108     return DefaultAction(hlo);
2109   }
2110 
2111   SetPartitionedHlo(hlo, [&] {
2112     int64_t dimension = Cast<HloIotaInstruction>(hlo)->iota_dimension();
2113     auto iota = b_.AddInstruction(HloInstruction::CreateIota(
2114         MakePartitionedShape(hlo->shape(), sharding), dimension));
2115 
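    // The local iota counts from 0 within the shard; when the iota dimension
    // is partitioned, each partition adds partition_ordinal * shard_size, so
    // an 8-element iota over 4 partitions yields {0,1}, {2,3}, {4,5}, {6,7}
    // on partitions 0..3.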
2116     if (sharding.tile_assignment().dim(dimension) > 1) {
2117       auto partition_ordinals = MakeTiledPartitionOrdinals(
2118           sharding, MakePartitioningState().partition_id, &b_);
2119       auto multiplier = b_.AddInstruction(HloInstruction::CreateConstant(
2120           LiteralUtil::CreateR0<int32>(iota->shape().dimensions(dimension))));
2121       auto offset = b_.AddInstruction(HloInstruction::CreateBinary(
2122           ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
2123           partition_ordinals[dimension], multiplier));
2124       if (iota->shape().element_type() != S32) {
2125         offset = b_.AddInstruction(HloInstruction::CreateConvert(
2126             ShapeUtil::MakeShape(iota->shape().element_type(), {}), offset));
2127       }
2128       auto broadcast = b_.AddInstruction(
2129           HloInstruction::CreateBroadcast(iota->shape(), offset, {}));
2130       return b_.AddInstruction(HloInstruction::CreateBinary(
2131           iota->shape(), HloOpcode::kAdd, iota, broadcast));
2132     }
2133 
2134     return iota;
2135   });
2136 
2137   return Status::OK();
2138 }
2139 
2140 Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) {
2141   TF_RET_CHECK(hlo->sharding().HasUniqueDevice());
2142   int64_t device = hlo->sharding().GetUniqueDevice();
2143   const HloSharding sharding = HloSharding::AssignDevice(device);
2144 
2145   std::vector<HloInstruction*> operands;
2146   std::vector<Shape> operand_shapes;
2147   for (const HloInstruction* operand : hlo->operands()) {
2148     operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
2149     operand_shapes.push_back(operand->shape());
2150   }
2151   auto operand = b_.AddInstruction(HloInstruction::CreateTuple(operands));
2152   auto operand_shape = ShapeUtil::MakeTupleShape(operand_shapes);
2153 
2154   auto on_device = b_.AddInstruction(
2155       HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(device)));
2156   auto pred = b_.AddInstruction(HloInstruction::CreateCompare(
2157       ShapeUtil::MakeShape(PRED, {}), MakePartitioningState().partition_id,
2158       on_device, ComparisonDirection::kEq));
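  // Only the partition that owns the assigned device executes the original op
  // (the true branch); all other partitions take the false branch and produce
  // a zero-filled result of the same shape.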
2159 
2160   SpmdBuilder true_b("true_computation", visiting_hlo_);
2161   HloComputation* true_computation;
2162   {
2163     auto param = true_b.AddInstruction(HloInstruction::CreateParameter(
2164         /*parameter_number=*/0, operand_shape, "true_branch_param"));
2165     std::vector<HloInstruction*> new_operands;
2166     for (int64_t i = 0; i < operands.size(); ++i) {
2167       new_operands.push_back(true_b.AddInstruction(
2168           HloInstruction::CreateGetTupleElement(operand_shapes[i], param, i)));
2169     }
2170     auto root = true_b.AddInstruction(
2171         hlo->CloneWithNewOperands(hlo->shape(), new_operands));
2172     true_computation = module_->AddEmbeddedComputation(true_b.Build(root));
2173   }
2174 
2175   SpmdBuilder false_b("false_computation", visiting_hlo_);
2176   HloComputation* false_computation;
2177   {
2178     false_b.AddInstruction(HloInstruction::CreateParameter(
2179         /*parameter_number=*/0, operand_shape, "false_branch_param"));
2180     auto root = CreateZero(hlo->shape(), &false_b);
2181     false_computation = module_->AddEmbeddedComputation(false_b.Build(root));
2182   }
2183 
2184   SetPartitionedHlo(hlo, [&]() {
2185     return b_.AddInstruction(HloInstruction::CreateConditional(
2186         hlo->shape(), pred, operand, true_computation, operand,
2187         false_computation));
2188   });
2189   return Status::OK();
2190 }
2191 
2192 Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) {
2193   if (hlo->IsCrossReplicaAllReduce() && hlo->operand_count() == 1) {
2194     return HandleElementwise(hlo);
2195   }
2196   return DefaultAction(hlo);
2197 }
2198 
2199 Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) {
2200   if (hlo->sharding().IsTileMaximal()) {
2201     return DefaultAction(hlo);
2202   }
2203 
2204   auto& operand = GetPartitionedHlo(hlo->operand(0));
2205 
2206   // Tiled output.
2207   std::vector<int64> new_dims;
2208   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2209     if (!absl::c_linear_search(hlo->dimensions(), i)) {
2210       new_dims.push_back(i);
2211     }
2212   }
2213   auto desired_input_sharding = hlo_sharding_util::RemoveShapeDimensions(
2214       hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(),
2215                                                                new_dims),
2216       new_dims);
2217   auto input = operand.Reshard(desired_input_sharding).hlo();
2218   auto output_shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2219   SetPartitionedHlo(hlo, [&] {
2220     return b_.AddInstruction(
2221         hlo->CloneWithNewOperands(output_shard_shape, {input}));
2222   });
2223   return Status::OK();
2224 }
2225 
2226 Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) {
2227   const Literal& literal = hlo->literal();
2228   if (literal.shape().IsTuple() ||
2229       (!hlo->sharding().IsTileMaximal() &&
2230        (!EvenlyPartitions(hlo->shape(), hlo->sharding()) ||
2231         !literal.IsAllFirst()))) {
2232     return DefaultAction(hlo);
2233   }
2234 
2235   SetPartitionedHlo(hlo, [&]() {
2236     auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2237     std::vector<int64> start_indices(hlo->shape().rank(), 0);
2238     auto constant = b_.AddInstruction(HloInstruction::CreateConstant(
2239         literal.Slice(start_indices, shard_shape.dimensions())));
2240     *constant->mutable_shape() = shard_shape;
2241     return constant;
2242   });
2243   return Status::OK();
2244 }
2245 
2246 Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) {
2247   if (hlo->sharding().IsTileMaximal()) {
2248     return DefaultAction(hlo);
2249   }
2250   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2251     if (hlo->sharding().tile_assignment().dim(i) != 1 &&
2252         (hlo->dynamic_slice_sizes()[i] != hlo->shape().dimensions(i) ||
2253          !hlo->operand(i + 1)->IsConstant() ||
2254          !hlo->operand(i + 1)->literal().IsZero({}))) {
2255       // We currently do not partition the sliced dimensions.
2256       return DefaultAction(hlo);
2257     }
2258   }
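  // At this point only dimensions that are not actually sliced are
  // partitioned, and their start indices are constant zero, so each partition
  // can apply the same dynamic-slice to its local shard using replicated
  // indices.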
2259   std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2260   auto new_input =
2261       GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
2262   for (int64_t i = 0; i < new_indices.size(); ++i) {
2263     // Replicate the indices.
2264     new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1))
2265                          .Reshard(HloSharding::Replicate())
2266                          .hlo();
2267   }
2268   SetPartitionedHlo(hlo, [&]() {
2269     auto partitioned_shape =
2270         MakePartitionedShape(hlo->shape(), hlo->sharding());
2271     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2272         partitioned_shape, new_input, new_indices,
2273         partitioned_shape.dimensions()));
2274   });
2275   return Status::OK();
2276 }
2277 
2278 Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) {
2279   if (hlo->sharding().IsTileMaximal()) {
2280     return DefaultAction(hlo);
2281   }
2282 
2283   std::vector<int64> partitioned_slice_dims;
2284   std::vector<int64> slice_dims;
2285   std::vector<int64> partitioned_non_slice_dims;
2286   std::vector<int64> partitioned_slice_offsets;
2287   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2288     if (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i)) {
2289       slice_dims.push_back(i);
2290       int64_t slice_size = hlo->operand(1)->shape().dimensions(i);
2291       if (hlo->sharding().tile_assignment().dim(i) != 1) {
2292         if (!hlo->operand(i + 2)->IsConstant() && slice_size != 1) {
2293           return DefaultAction(hlo);
2294         }
2295         partitioned_slice_dims.push_back(i);
2296         // Set partitioned_slice_offsets to -1 when slice_size is 1.
2297         if (slice_size == 1) {
2298           partitioned_slice_offsets.push_back(-1);
2299         } else {
2300           partitioned_slice_offsets.push_back(
2301               hlo->operand(i + 2)->literal().Get<int>({}));
2302         }
2303       }
2304     } else if (hlo->sharding().tile_assignment().dim(i) != 1) {
2305       if (!hlo->operand(i + 2)->IsConstant() ||
2306           !hlo->operand(i + 2)->literal().IsZero({})) {
2307         return DefaultAction(hlo);
2308       }
2309       partitioned_non_slice_dims.push_back(i);
2310     }
2311   }
2312 
2313   // Handle the case where a slice dimension is partitioned.
2314   if (!partitioned_slice_dims.empty()) {
2315     auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
2316       return b_.AddInstruction(std::move(to_add));
2317     };
2318     std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2319     for (int64_t i = 0; i < new_indices.size(); ++i) {
2320       // Replicate the indices.
2321       new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
2322                            .Reshard(HloSharding::Replicate())
2323                            .hlo();
2324     }
2325 
2326     // Get partitioned input.
2327     const auto& dus_sharding = hlo->sharding();
2328     const auto& partitioned_input =
2329         GetPartitionedHlo(hlo->operand(0)).Reshard(dus_sharding).hlo();
2330 
2331     // Get the replicated update.
2332     auto update_sharding = HloSharding::Replicate();
2333     if (!partitioned_non_slice_dims.empty()) {
2334       // Partially replicate the update if non-slice dims are partitioned.
2335       update_sharding =
2336           hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding,
2337                                                                    slice_dims);
2338     }
2339 
2340     // TODO(wangtao): use collective permute for sharded update.
2341     HloInstruction* replicate_update =
2342         GetPartitionedHlo(hlo->operand(1)).Reshard(update_sharding).hlo();
2343 
2344     const auto& update_shape = replicate_update->shape();
2345     const auto& partitioned_shape = partitioned_input->shape();
2346     auto partition_ordinals = MakeTiledPartitionOrdinals(
2347         hlo->sharding(), MakePartitioningState().partition_id, &b_);
2348     HloInstruction* all_dims_within_partition = add_hlo(
2349         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
2350 
2351     for (int i = 0; i < partitioned_slice_dims.size(); ++i) {
2352       int dim = partitioned_slice_dims[i];
2353       // Calculate per partition size.
2354       const int64_t per_partition_size = partitioned_shape.dimensions(dim);
2355 
2356       // Only an update that falls within a single partition is supported.
2357       // This check is skipped when the slice size is 1, in which case
2358       // partitioned_slice_offsets[i] is -1.
2359       if ((partitioned_slice_offsets[i] != -1) &&
2360           (partitioned_slice_offsets[i] / per_partition_size) !=
2361               ((partitioned_slice_offsets[i] + update_shape.dimensions(dim) -
2362                 1) /
2363                per_partition_size)) {
2364         return DefaultAction(hlo);
2365       }
2366 
2367       // within_partition = (offset >= partition_id * per_partition_size) &&
2368       //                    (offset < (partition_id + 1) * per_partition_size)
2369       const Shape& compare_shape =
2370           ShapeUtil::ChangeElementType(partition_id_->shape(), PRED);
2371       auto per_partition_size_hlo = add_hlo(HloInstruction::CreateConstant(
2372           LiteralUtil::CreateR0<int>(per_partition_size)));
2373       const Shape& offset_shape = per_partition_size_hlo->shape();
2374       auto partition_offset = add_hlo(HloInstruction::CreateBinary(
2375           offset_shape, HloOpcode::kMultiply, partition_ordinals[dim],
2376           per_partition_size_hlo));
2377       // offset >= partition_id * per_partition_size
2378       auto offset_ge = add_hlo(HloInstruction::CreateCompare(
2379           compare_shape, new_indices[dim], partition_offset,
2380           ComparisonDirection::kGe));
2381       // offset < (partition_id + 1) * per_partition_size
2382       auto offset_lt = add_hlo(HloInstruction::CreateCompare(
2383           compare_shape, new_indices[dim],
2384           add_hlo(HloInstruction::CreateBinary(
2385               offset_shape, HloOpcode::kMultiply,
2386               add_hlo(HloInstruction::CreateBinary(
2387                   offset_shape, HloOpcode::kAdd, partition_ordinals[dim],
2388                   add_hlo(HloInstruction::CreateConstant(
2389                       LiteralUtil::CreateR0<int>(1))))),
2390               per_partition_size_hlo)),
2391           ComparisonDirection::kLt));
2392       auto update_within_partition = add_hlo(HloInstruction::CreateBinary(
2393           compare_shape, HloOpcode::kAnd, offset_ge, offset_lt));
2394 
2395       all_dims_within_partition = add_hlo(HloInstruction::CreateBinary(
2396           compare_shape, HloOpcode::kAnd, all_dims_within_partition,
2397           update_within_partition));
2398 
2399       // Calculate offset.
2400       // slice dim offset =
2401       //  within_partition ?
2402       //  offset - partition_id * per_partition_size : 0
2403       new_indices[dim] = add_hlo(HloInstruction::CreateTernary(
2404           new_indices[dim]->shape(), HloOpcode::kSelect,
2405           update_within_partition,
2406           add_hlo(HloInstruction::CreateBinary(
2407               new_indices[dim]->shape(), HloOpcode::kSubtract, new_indices[dim],
2408               partition_offset)),
2409           add_hlo(
2410               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)))));
2411     }
2412 
2413     // Create dynamic update slice.
2414     auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice(
2415         partitioned_shape, partitioned_input, replicate_update, new_indices));
2416     SetPartitionedHlo(hlo, [&]() {
2417       // Select if update is needed.
2418       return add_hlo(HloInstruction::CreateTernary(
2419           dus->shape(), HloOpcode::kSelect,
2420           add_hlo(HloInstruction::CreateBroadcast(
2421               ShapeUtil::ChangeElementType(dus->shape(), PRED),
2422               all_dims_within_partition, {})),
2423           dus, partitioned_input));
2424     });
2425     return Status::OK();
2426   }
2427 
2428   // Partition only the non-slice dims.
2429   std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2430   auto new_input =
2431       GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
2432   auto new_update =
2433       GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo();
2434   for (int64_t i = 0; i < new_indices.size(); ++i) {
2435     // Replicate the indices.
2436     new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
2437                          .Reshard(HloSharding::Replicate())
2438                          .hlo();
2439   }
2440   SetPartitionedHlo(hlo, [&]() {
2441     auto partitioned_shape =
2442         MakePartitionedShape(hlo->shape(), hlo->sharding());
2443     return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2444         partitioned_shape, new_input, new_update, new_indices));
2445   });
2446   return Status::OK();
2447 }
2448 
2449 Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) {
2450   const auto& tuple = GetPartitionedHlo(hlo->operand(0));
2451   auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
2452       ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), hlo->tuple_index()),
2453       tuple.hlo(), hlo->tuple_index()));
2454   const auto source_sharding =
2455       tuple.sharding().GetSubSharding(tuple.base_shape(), {hlo->tuple_index()});
2456   gte->set_sharding(source_sharding);
2457   PartitionedHlo source_partitioned_gte(
2458       gte, tuple.base_shape().tuple_shapes(hlo->tuple_index()),
2459       MakePartitioningState());
2460   source_partitioned_gte = source_partitioned_gte.Reshard(hlo->sharding());
2461   SetPartitionedHlo(hlo, source_partitioned_gte);
2462   return Status::OK();
2463 }
2464 
2465 Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
2466   const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0);
2467   auto token = GetPartitionedHlo(hlo->operand(0)).hlo();
2468   if (ShapeUtil::GetLeafCount(shape) == 0) {
2469     // TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it
2470     // requires one element for an empty tuple, but a leaf-count number of
2471     // elements for a non-empty tuple. So if it has a nested empty tuple, we
2472     // cannot invoke GetSubSharding() since it expects a sharding for the empty
2473     // tuple. This is a workaround for that case.
2474     SetPartitionedHlo(hlo, [&]() {
2475       return b_.AddInstruction(
2476           HloInstruction::CreateInfeed(shape, token, hlo->infeed_config()));
2477     });
2478     return Status::OK();
2479   }
2480   auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
2481   auto shard_shape = MakePartitionedShape(shape, sharding);
2482   if (EvenlyPartitions(shape, sharding)) {
2483     SetPartitionedHlo(hlo, [&]() {
2484       return b_.AddInstruction(HloInstruction::CreateInfeed(
2485           shard_shape, token, hlo->infeed_config()));
2486     });
2487     return Status::OK();
2488   }
2489 
2490   if (hlo->sharding().HasUniqueDevice()) {
2491     return HandleSingleDevice(hlo);
2492   }
2493 
2494   // Create a branch for each unique partitioned shape.
2495   std::vector<Shape> per_branch_partitioned_shapes;
2496   std::vector<int32> conditional_branch_indices(num_partitions_);
2497   for (int64_t i = 0; i < num_partitions_; ++i) {
2498     auto partitioned_shape =
2499         MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
2500     int64_t matching_existing_index = 0;
2501     for (; matching_existing_index < per_branch_partitioned_shapes.size();
2502          ++matching_existing_index) {
2503       if (ShapeUtil::Compatible(
2504               partitioned_shape,
2505               per_branch_partitioned_shapes[matching_existing_index])) {
2506         break;
2507       }
2508     }
2509     if (matching_existing_index < per_branch_partitioned_shapes.size()) {
2510       conditional_branch_indices[i] = matching_existing_index;
2511     } else {
2512       conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
2513       per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
2514     }
2515   }
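  // Illustrative example (hypothetical): with 4 partitions where partitions
  // 0-2 receive shape A and partition 3 receives a smaller shape B (uneven
  // split), per_branch_partitioned_shapes is {A, B} and
  // conditional_branch_indices is {0, 0, 0, 1}; the constant table below then
  // maps each partition id to its branch.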
2516 
2517   HloInstruction* branch_index;
2518   auto state = MakePartitioningState();
2519   if (per_branch_partitioned_shapes.size() == num_partitions_) {
2520     // Use partition ID as the branch index if each partition has its own
2521     // branch.
2522     branch_index = state.partition_id;
2523     // PartitionId's output is U32 but conditional requires S32.
2524     if (branch_index->shape().element_type() != S32) {
2525       branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
2526           ShapeUtil::ChangeElementType(branch_index->shape(), S32),
2527           branch_index));
2528     }
2529   } else {
2530     // Otherwise, use a constant table to look up the branch index.
2531     auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
2532         LiteralUtil::CreateR1<int32>(conditional_branch_indices)));
2533     branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2534         ShapeUtil::MakeShape(S32, {1}), branch_index_table,
2535         {state.partition_id}, {1}));
2536     branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
2537         ShapeUtil::MakeShape(S32, {}), branch_index));
2538   }
2539 
2540   std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
2541   for (int64_t i = 0; i < branches.size(); ++i) {
2542     SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_);
2543     auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
2544         /*parameter_number=*/0, token->shape(), "infeed_token_param"));
2545     auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed(
2546         per_branch_partitioned_shapes[i], param, hlo->infeed_config()));
2547     if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
2548       std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
2549           pad_infeed = [&](const ShapeIndex& index,
2550                            HloInstruction* infeed_element) -> HloInstruction* {
2551         if (index == ShapeIndex({1})) {
2552           // Token.
2553           return infeed_element;
2554         }
2555         const Shape& element_shape =
2556             ShapeUtil::GetSubshape(infeed->shape(), index);
2557         if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
2558           std::vector<HloInstruction*> padded_elements(
2559               element_shape.tuple_shapes_size());
2560           for (int64_t i = 0; i < padded_elements.size(); ++i) {
2561             auto sub_index = index;
2562             sub_index.push_back(i);
2563             padded_elements[i] = pad_infeed(
2564                 sub_index,
2565                 branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
2566                     ShapeUtil::GetSubshape(element_shape, {i}), infeed_element,
2567                     i)));
2568           }
2569           return branch_b.AddInstruction(
2570               HloInstruction::CreateTuple(padded_elements));
2571         }
2572         const Shape& pad_shape =
2573             ShapeUtil::GetSubshape(shard_shape, ShapeIndexView(index, 1));
2574         if (ShapeUtil::Compatible(element_shape, pad_shape)) {
2575           return infeed_element;
2576         }
2577         if (element_shape.IsArray()) {
2578           CHECK(pad_shape.IsArray());
2579           return PadToShape(infeed_element, pad_shape, &branch_b);
2580         }
2581         CHECK(element_shape.IsTuple());
2582         CHECK(element_shape.tuple_shapes().empty());
2583         return CreateZero(pad_shape, &branch_b);
2584       };
2585       pad_infeed({}, infeed);
2586     }
2587     branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
2588   }
2589   SetPartitionedHlo(hlo, [&]() {
2590     return b_.AddInstruction(HloInstruction::CreateConditional(
2591         ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index,
2592         branches, std::vector<HloInstruction*>(branches.size(), token)));
2593   });
2594   return Status::OK();
2595 }
2596 
2597 Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) {
2598   if (hlo->sharding().IsTileMaximal()) {
2599     return DefaultAction(hlo);
2600   }
2601   auto lhs = GetPartitionedHlo(hlo->operand(0));
2602   // Create a window config to represent the pad.
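  // Illustrative mapping (hypothetical numbers): a padding config with
  // edge_padding_low=1, edge_padding_high=2, interior_padding=1 becomes a
  // window dimension with size=1, stride=1, padding_low=1, padding_high=2,
  // base_dilation=2, so ReshardAsWindowedInput can treat the pad like any
  // other windowed operand.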
2603   Window window;
2604   bool needs_masking = false;
2605   const bool pad_value_is_zero =
2606       hlo->operand(1)->IsConstant() && hlo->operand(1)->literal().IsZero({});
2607   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2608     const auto& pd = hlo->padding_config().dimensions(i);
2609     WindowDimension* dim = window.add_dimensions();
2610     dim->set_size(1);
2611     dim->set_stride(1);
2612     dim->set_window_dilation(1);
2613     dim->set_window_reversal(false);
2614     dim->set_padding_low(pd.edge_padding_low());
2615     dim->set_padding_high(pd.edge_padding_high());
2616     dim->set_base_dilation(pd.interior_padding() + 1);
2617     const int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
2618     // Masking is needed only if the padding value is non-zero or the operand
2619     // is unevenly partitioned: halo exchange fills 0 into the collective
2620     // permute result for non-destination cores.
2621     needs_masking |=
2622         shard_count > 1 &&
2623         (pd.edge_padding_low() > 0 || pd.edge_padding_high() > 0 ||
2624          pd.interior_padding() > 0) &&
2625         (!pad_value_is_zero ||
2626          hlo->operand(0)->shape().dimensions(i) % shard_count != 0);
2627   }
2628 
2629   auto replicated_rhs = GetPartitionedHlo(hlo->operand(1))
2630                             .Reshard(HloSharding::Replicate())
2631                             .hlo();
2632   auto reshard_operand =
2633       lhs.ReshardAsWindowedInput(window, hlo->sharding(), replicated_rhs,
2634                                  /*mask_invalid_region=*/needs_masking);
2635   if (!reshard_operand.has_value()) {
2636     return DefaultAction(hlo);
2637   }
2638   PaddingConfig sharded_padding_config;
2639   bool need_pad = false;
2640   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2641     auto dim = sharded_padding_config.add_dimensions();
2642     const auto& wd = reshard_operand->shard_window.dimensions(i);
2643     dim->set_edge_padding_low(wd.padding_low());
2644     dim->set_edge_padding_high(wd.padding_high());
2645     dim->set_interior_padding(wd.base_dilation() - 1);
2646     if (wd.padding_low() != 0 || wd.padding_high() != 0 ||
2647         wd.base_dilation() != 1) {
2648       need_pad = true;
2649     }
2650   }
2651   auto sharded_pad = reshard_operand->sharded_input;
2652   if (need_pad) {
2653     TF_ASSIGN_OR_RETURN(auto sharded_pad_shape,
2654                         ShapeInference::InferPadShape(sharded_pad->shape(),
2655                                                       replicated_rhs->shape(),
2656                                                       sharded_padding_config));
2657     sharded_pad = b_.AddInstruction(hlo->CreatePad(sharded_pad_shape,
2658                                                    sharded_pad, replicated_rhs,
2659                                                    sharded_padding_config));
2660   }
2661 
2662   SetPartitionedHlo(hlo, [&]() {
2663     if (!reshard_operand->dynamic_slice_index_on_output) {
2664       return sharded_pad;
2665     }
2666     auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2667     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2668         shard_shape, sharded_pad,
2669         *reshard_operand->dynamic_slice_index_on_output,
2670         shard_shape.dimensions()));
2671   });
2672   return Status::OK();
2673 }
2674 
2675 Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) {
2676   SetPartitionedHlo(hlo, [&]() {
2677     auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2678     auto new_param = b_.AddInstruction(HloInstruction::CreateParameter(
2679         hlo->parameter_number(), shard_shape, "param"));
2680     if (hlo->parameter_replicated_at_leaf_buffers()) {
2681       new_param->set_parameter_replicated_at_leaf_buffers(
2682           *hlo->parameter_replicated_at_leaf_buffers());
2683     }
2684     return new_param;
2685   });
2686   return Status::OK();
2687 }
2688 
2689 Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
2690   if (hlo->sharding().HasUniqueDevice()) {
2691     return DefaultAction(hlo);
2692   }
2693   int64_t input_count = 1;
2694   auto per_input_sharding = hlo->sharding();
2695   if (hlo->shape().IsTuple()) {
2696     input_count = hlo->shape().tuple_shapes_size();
2697     CHECK_GT(input_count, 0);
2698     per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
2699   }
2700 
2701   std::vector<PartitionedHlo> inputs;
2702   std::vector<HloInstruction*> inits;
2703   std::vector<int64> preserved_dims;
2704   for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
2705     if (!absl::c_linear_search(hlo->dimensions(), i)) {
2706       preserved_dims.push_back(i);
2707     }
2708   }
2709 
2710   for (int64_t operand_id = 0; operand_id < input_count; ++operand_id) {
2711     inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count))
2712                         .Reshard(HloSharding::Replicate())
2713                         .hlo());
2714     inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id)));
2715     if (operand_id > 0) {
2716       // Make sure all operands are sharded in the same way.
2717       inputs.back() = inputs.back().Reshard(inputs[0].sharding());
2718     }
2719     if (!inputs[0].sharding().IsTileMaximal()) {
2720       inputs.back() =
2721           inputs.back().PadWithValue(inits[operand_id], /*left_padded_dims=*/{},
2722                                      /*skipped_dims=*/preserved_dims);
2723     }
2724   }
2725 
2726   std::vector<Shape*> new_operand_shapes(input_count * 2);
2727   for (int64_t i = 0; i < input_count; ++i) {
2728     new_operand_shapes[i] = inputs[i].hlo()->mutable_shape();
2729     new_operand_shapes[i + input_count] = inits[i]->mutable_shape();
2730   }
2731   // Create the shard shape of the reduce result.
2732   TF_ASSIGN_OR_RETURN(
2733       auto reduce_shape,
2734       ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(),
2735                                        hlo->to_apply()->ComputeProgramShape()));
2736 
2737   std::vector<HloInstruction*> input_hlos(input_count);
2738   for (int64_t i = 0; i < input_count; ++i) {
2739     input_hlos[i] = inputs[i].hlo();
2740   }
2741   auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce(
2742       reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply()));
2743 
2744   SetPartitionedHlo(hlo, [&]() {
2745     HloInstruction* reduce = local_reduce;
2746     const bool reduce_sharded_dimension =
2747         !inputs[0].sharding().IsTileMaximal() &&
2748         absl::c_any_of(hlo->dimensions(), [&](int64_t i) {
2749           return inputs[0].sharding().tile_assignment().dim(i) > 1;
2750         });
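    // Illustrative example (hypothetical): an f32[8] input tiled across 2
    // partitions and reduced to a scalar. Each partition's local_reduce sums
    // its f32[4] shard, and the partial results are then combined across
    // partitions: an all-reduce for array-shaped results, or an all-gather
    // plus a re-reduce for tuple-shaped (variadic) reduces, as handled below.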
2751     if (reduce_sharded_dimension) {
2752       if (inputs[0].sharding().ReplicateOnLastTileDim()) {
2753         preserved_dims.push_back(inputs[0].base_shape().rank());
2754       }
2755       if (local_reduce->shape().IsArray()) {
2756         reduce = partitioner_->AllReduceAlongShardingDims(
2757             &b_, local_reduce, inputs[0].sharding(), next_channel_id_,
2758             hlo->dimensions(), collective_ops_creator_, hlo->to_apply());
2759       } else {
2760         auto grouped =
2761             GroupShardingOnDims(inputs[0].sharding(), preserved_dims);
2762         auto grouped_state = CreatePerGroupPartitioningState(
2763             inputs[0].state(), grouped.device_groups, &b_);
2764         std::vector<HloInstruction*> all_gathered_partial_results(input_count);
2765         for (int64_t i = 0; i < input_count; ++i) {
2766           auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
2767               ShapeUtil::GetTupleElementShape(reduce_shape, i), local_reduce,
2768               i));
2769           auto expanded_shape = input_hlos[i]->shape();
2770           auto all_gather_shape = input_hlos[i]->shape();
2771           for (int64_t dim : hlo->dimensions()) {
2772             expanded_shape.set_dimensions(dim, 1);
2773             all_gather_shape.set_dimensions(
2774                 dim, inputs[0].sharding().tile_assignment().dim(dim));
2775           }
2776           auto reshape = b_.AddInstruction(
2777               HloInstruction::CreateReshape(expanded_shape, gte));
2778           // Replicate per group.
2779           reshape->set_sharding(grouped.sharding);
2780           all_gathered_partial_results[i] =
2781               PartitionedHlo(reshape, all_gather_shape, grouped_state)
2782                   .Replicate()
2783                   .hlo();
2784         }
2785         reduce = b_.AddInstruction(HloInstruction::CreateReduce(
2786             reduce_shape, all_gathered_partial_results, inits,
2787             hlo->dimensions(), hlo->to_apply()));
2788       }
2789     }
2790     auto sharding = hlo_sharding_util::RemoveShapeDimensions(
2791         hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2792             inputs[0].sharding(), hlo->dimensions()),
2793         hlo->dimensions());
2794     if (local_reduce->shape().IsArray()) {
2795       reduce->set_sharding(sharding);
2796     } else {
2797       reduce->set_sharding(HloSharding::Tuple(
2798           reduce->shape(), std::vector<HloSharding>(input_count, sharding)));
2799     }
2800     return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState())
2801         .Reshard(hlo->sharding())
2802         .hlo();
2803   });
2804   return Status::OK();
2805 }
2806 
2807 Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
2808   auto reverse = Cast<HloReverseInstruction>(hlo);
2809   if (reverse->sharding().IsTileMaximal()) {
2810     return DefaultAction(hlo);
2811   }
2812   auto operand = GetPartitionedHlo(reverse->operand(0))
2813                      .Reshard(hlo_sharding_util::ReverseSharding(
2814                          reverse->sharding(), reverse->dimensions()));
2815   auto left_padded_operand =
2816       HaloExchangeToPadOnLeft(operand, reverse->dimensions());
2817   if (!left_padded_operand) {
2818     return DefaultAction(hlo);
2819   }
2820   SetPartitionedHlo(hlo, [&] {
2821     return b_.AddInstruction(hlo->CloneWithNewOperands(
2822         left_padded_operand->shape(), {left_padded_operand}));
2823   });
2824   return Status::OK();
2825 }
2826 
2827 Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
2828   const HloSharding& sharding = hlo->sharding();
2829 
2830   // Shardings for the body parameter, body root, and cond parameter must be
2831   // the same, and the condition root must be replicated so that all partitions
2832   // follow the same control flow.
2833   hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding);
2834   hlo->while_body()->parameter_instruction(0)->set_sharding(sharding);
2835   const HloSharding& cond_root_sharding =
2836       hlo->while_condition()->root_instruction()->sharding();
2837   TF_RETURN_IF_ERROR(partitioner_
2838                          ->PartitionComputation(hlo->while_condition(),
2839                                                 cond_root_sharding.IsManual()
2840                                                     ? cond_root_sharding
2841                                                     : HloSharding::Replicate(),
2842                                                 next_channel_id_, logger_)
2843                          .status());
2844   TF_RETURN_IF_ERROR(partitioner_
2845                          ->PartitionComputation(hlo->while_body(), sharding,
2846                                                 next_channel_id_, logger_)
2847                          .status());
2848   SetPartitionedHlo(hlo, [&] {
2849     return b_.AddInstruction(HloInstruction::CreateWhile(
2850         MakePartitionedShape(hlo->shape(), sharding), hlo->while_condition(),
2851         hlo->while_body(),
2852         GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo()));
2853   });
2854   return Status::OK();
2855 }
2856 
2857 Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) {
2858   std::vector<HloInstruction*> branch_args;
2859   for (int64_t i = 0; i < hlo->branch_count(); ++i) {
2860     HloComputation* computation = hlo->branch_computation(i);
2861 
2862     // Shardings of the branch computation parameter and its argument must be
2863     // the same.
2864     computation->parameter_instruction(0)->set_sharding(
2865         hlo->operand(i + 1)->sharding());
2866     branch_args.push_back(GetPartitionedHlo(hlo->operand(i + 1)).hlo());
2867   }
2868 
2869   // The root of the branch computations must follow the sharding of the
2870   // conditional instruction.
2871   for (int64_t i = 0; i < hlo->branch_count(); ++i) {
2872     HloComputation* computation = hlo->branch_computation(i);
2873     TF_RETURN_IF_ERROR(partitioner_
2874                            ->PartitionComputation(computation, hlo->sharding(),
2875                                                   next_channel_id_, logger_)
2876                            .status());
2877   }
2878   SetPartitionedHlo(hlo, [&] {
2879     HloInstruction* cond = GetPartitionedHlo(hlo->operand(0)).hlo();
2880     if (!hlo->operand(0)->sharding().IsManual()) {
2881       // We replicate the predicate of the conditional (the first operand) so
2882       // that all partitions follow the same control flow.
2883       cond = GetPartitionedHlo(hlo->operand(0))
2884                  .Reshard(HloSharding::Replicate())
2885                  .hlo();
2886     }
2887     return b_.AddInstruction(HloInstruction::CreateConditional(
2888         MakePartitionedShape(hlo->shape(), hlo->sharding()), cond,
2889         hlo->called_computations(), branch_args));
2890   });
2891   return Status::OK();
2892 }
2893 
2894 Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
2895   if (hlo->sharding().HasUniqueDevice()) {
2896     return HandleSingleDevice(hlo);
2897   }
2898 
2899   const auto& sharding = hlo->sharding();
2900   const Shape& shape = hlo->operand(0)->shape();
2901   auto partitioned_operand =
2902       GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
2903   const auto& shard_shape = partitioned_operand.hlo()->shape();
2904   const auto& operand = partitioned_operand.hlo();
2905   auto token = GetPartitionedHlo(hlo->operand(1)).hlo();
2906 
2907   if (EvenlyPartitions(shape, sharding)) {
2908     Shape outfeed_shape = operand->shape();
2909     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(hlo->outfeed_shape(),
2910                                                            &outfeed_shape));
2911     SetPartitionedHlo(hlo, [&]() {
2912       return b_.AddInstruction(HloInstruction::CreateOutfeed(
2913           outfeed_shape, operand, token, hlo->outfeed_config()));
2914     });
2915     return Status::OK();
2916   }
2917 
2918   // Create a branch for each unique partitioned shape.
2919   std::vector<Shape> per_branch_partitioned_shapes;
2920   std::vector<int32> conditional_branch_indices(num_partitions_);
2921   for (int64_t i = 0; i < num_partitions_; ++i) {
2922     auto partitioned_shape =
2923         MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
2924     int64_t matching_existing_index = 0;
2925     for (; matching_existing_index < per_branch_partitioned_shapes.size();
2926          ++matching_existing_index) {
2927       if (ShapeUtil::Compatible(
2928               partitioned_shape,
2929               per_branch_partitioned_shapes[matching_existing_index])) {
2930         break;
2931       }
2932     }
2933     if (matching_existing_index < per_branch_partitioned_shapes.size()) {
2934       conditional_branch_indices[i] = matching_existing_index;
2935     } else {
2936       conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
2937       per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
2938     }
2939   }
2940 
2941   // Get branch index for this partition.
2942   HloInstruction* branch_index;
2943   auto state = MakePartitioningState();
2944   if (per_branch_partitioned_shapes.size() == num_partitions_) {
2945     // Use partition ID as the branch index if each partition has its own
2946     // branch.
2947     branch_index = state.partition_id;
2948     // PartitionId's output is U32 but conditional requires S32.
2949     if (branch_index->shape().element_type() != S32) {
2950       branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
2951           ShapeUtil::ChangeElementType(branch_index->shape(), S32),
2952           branch_index));
2953     }
2954   } else {
2955     // Otherwise, use a constant table to look up the branch index.
2956     auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
2957         LiteralUtil::CreateR1<int32>(conditional_branch_indices)));
2958     branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2959         ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_},
2960         {1}));
2961     branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
2962         ShapeUtil::MakeShape(S32, {}), branch_index));
2963   }
2964 
2965   // Create conditional for the outfeed.
2966   std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
2967   for (int64_t i = 0; i < branches.size(); ++i) {
2968     SpmdBuilder branch_b(absl::StrCat("outfeed_branch_", i), visiting_hlo_);
2969     // Create tuple param within the branch.
2970     auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
2971         /*parameter_number=*/0,
2972         ShapeUtil::MakeTupleShape({operand->shape(), token->shape()}),
2973         "outfeed_token_param"));
2974     auto outfeed_data = branch_b.AddInstruction(
2975         HloInstruction::CreateGetTupleElement(operand->shape(), param, 0));
2976     auto outfeed_token = branch_b.AddInstruction(
2977         HloInstruction::CreateGetTupleElement(token->shape(), param, 1));
2978     if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
2979       std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
2980           slice_outfeed =
2981               [&](const ShapeIndex& index,
2982                   HloInstruction* outfeed_operand) -> HloInstruction* {
2983         // Get outfeed element shape.
2984         const Shape& element_shape =
2985             ShapeUtil::GetSubshape(outfeed_data->shape(), index);
2986         // Recursively call slice_outfeed for tuple shapes.
2987         if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
2988           std::vector<HloInstruction*> slice_elements(
2989               element_shape.tuple_shapes_size());
2990           for (int64_t i = 0; i < slice_elements.size(); ++i) {
2991             auto sub_index = index;
2992             sub_index.push_back(i);
2993             slice_elements[i] = slice_outfeed(
2994                 sub_index,
2995                 branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
2996                     ShapeUtil::GetSubshape(element_shape, {i}), outfeed_operand,
2997                     i)));
2998           }
2999           return branch_b.AddInstruction(
3000               HloInstruction::CreateTuple(slice_elements));
3001         }
3002         // Get the slice shape.
3003         const Shape& slice_shape = ShapeUtil::GetSubshape(
3004             per_branch_partitioned_shapes[i], ShapeIndexView(index));
3005         if (ShapeUtil::Compatible(element_shape, slice_shape)) {
3006           return outfeed_operand;
3007         }
3008         // Slice out useful data.
3009         if (element_shape.IsArray()) {
3010           CHECK(slice_shape.IsArray());
3011           std::vector<int64> start_indices(slice_shape.rank(), 0);
3012           std::vector<int64> slice_strides(slice_shape.rank(), 1);
3013           return branch_b.AddInstruction(HloInstruction::CreateSlice(
3014               slice_shape, outfeed_operand, start_indices,
3015               slice_shape.dimensions(), slice_strides));
3016         }
3017         CHECK(element_shape.IsTuple());
3018         CHECK(element_shape.tuple_shapes().empty());
3019         return outfeed_operand;
3020       };
3021       outfeed_data = slice_outfeed({}, outfeed_data);
3022     }
3023     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3024         hlo->outfeed_shape(), &per_branch_partitioned_shapes[i]));
3025     branch_b.AddInstruction(HloInstruction::CreateOutfeed(
3026         per_branch_partitioned_shapes[i], outfeed_data, outfeed_token,
3027         hlo->outfeed_config()));
3028     branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
3029   }
3030   SetPartitionedHlo(hlo, [&]() {
3031     return b_.AddInstruction(HloInstruction::CreateConditional(
3032         token->shape(), branch_index, branches,
3033         std::vector<HloInstruction*>(
3034             branches.size(),
3035             b_.AddInstruction(HloInstruction::CreateTuple({operand, token})))));
3036   });
3037   return Status::OK();
3038 }
3039 
3040 Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
3041   if (hlo->sharding().HasUniqueDevice()) {
3042     return HandleSingleDevice(hlo);
3043   }
3044   auto clone_from_original = [&](const HloSharding& shared_sharding) {
3045     std::vector<HloInstruction*> new_operands;
3046     for (int64_t i = 0; i < hlo->operand_count(); ++i) {
3047       new_operands.push_back(
3048           GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo());
3049     }
3050     auto clone = b_.AddInstruction(
3051         hlo->CloneWithNewOperands(hlo->shape(), new_operands));
3052     clone->set_sharding(shared_sharding);
3053     return clone;
3054   };
3055 
3056   if (hlo->sharding().IsManual()) {
3057     SetPartitionedHlo(hlo,
3058                       [&] { return clone_from_original(hlo->sharding()); });
3059     return Status::OK();
3060   }
3061 
3062   if (hlo->sharding().IsReplicated()) {
3063     SetPartitionedHlo(hlo, [&] {
3064       // Run on a single device (0) and distribute the data to all other cores.
3065       auto clone = clone_from_original(HloSharding::AssignDevice(0));
3066       return PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
3067           .Reshard(HloSharding::Replicate())
3068           .hlo();
3069     });
3070     return Status::OK();
3071   }
3072 
3073   TF_RET_CHECK(!hlo->sharding().IsTileMaximal());
3074   // Replicate the operands and run partitioned Rng on all devices.
3075   std::vector<HloInstruction*> new_operands;
3076   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
3077     new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
3078                                .Reshard(HloSharding::Replicate())
3079                                .hlo());
3080   }
3081 
3082   if (!hlo->sharding().ReplicateOnLastTileDim()) {
3083     SetPartitionedHlo(hlo, [&] {
3084       return b_.AddInstruction(HloInstruction::CreateRng(
3085           MakePartitionedShape(hlo->shape(), hlo->sharding()),
3086           hlo->random_distribution(), new_operands));
3087     });
3088   } else {
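    // Illustrative example (hypothetical): a sharding with tile_assignment
    // shape [2,2] and ReplicateOnLastTileDim gives group_dims = {0}, i.e. two
    // groups of two devices each. The rng below is built on every device, but
    // within each group only the first device's result is kept and broadcast,
    // so replicas of the same shard see identical random values.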
3089     std::vector<int64> group_dims(
3090         hlo->sharding().tile_assignment().num_dimensions() - 1);
3091     std::iota(group_dims.begin(), group_dims.end(), 0);
3092     auto sharding_grouped = GroupShardingOnDims(hlo->sharding(), group_dims);
3093     auto per_group_state = CreatePerGroupPartitioningState(
3094         MakePartitioningState(), sharding_grouped.device_groups, &b_);
3095     auto rng = b_.AddInstruction(HloInstruction::CreateRng(
3096         MakePartitionedShape(hlo->shape(), hlo->sharding()),
3097         hlo->random_distribution(), new_operands));
3098     rng->set_sharding(HloSharding::AssignDevice(0));
3099     SetPartitionedHlo(hlo, [&]() {
3100       return PartitionedHlo(rng, rng->shape(), per_group_state)
3101           .Replicate()
3102           .hlo();
3103     });
3104   }
3105   return Status::OK();
3106 }
3107 
3108 Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
3109   if (hlo->sharding().IsTileMaximal()) {
3110     return DefaultAction(hlo);
3111   }
3112   HloReduceWindowInstruction* reduce_window =
3113       Cast<HloReduceWindowInstruction>(hlo);
3114   absl::Span<HloInstruction* const> input_arrays = reduce_window->inputs();
3115   absl::Span<HloInstruction* const> init_values = reduce_window->init_values();
3116   int64_t input_idx = 0;
3117   absl::InlinedVector<PartitionedHlo::WindowedInputShardReturnValue, 2>
3118       sharded_results;
3119   absl::InlinedVector<const Shape*, 2> sharded_input_shapes,
3120       replicated_init_shapes;
3121   absl::InlinedVector<HloInstruction*, 2> sharded_inputs, replicated_inits;
3122   for (const HloInstruction* input_array : input_arrays) {
3123     PartitionedHlo& operand = GetPartitionedHlo(input_array);
3124     // Replicate init
3125     PartitionedHlo replicated_init = GetPartitionedHlo(init_values[input_idx++])
3126                                          .Reshard(HloSharding::Replicate());
3127     auto resharded_operand_and_window = operand.ReshardAsWindowedInput(
3128         hlo->window(), hlo->sharding(), replicated_init.hlo());
3129     if (!resharded_operand_and_window.has_value()) {
3130       return DefaultAction(hlo);
3131     }
3132     sharded_results.push_back(resharded_operand_and_window.value());
3133     sharded_inputs.push_back(resharded_operand_and_window->sharded_input);
3134     sharded_input_shapes.push_back(&sharded_inputs.back()->shape());
3135     replicated_inits.push_back(replicated_init.hlo());
3136     replicated_init_shapes.push_back(&replicated_inits.back()->shape());
3137   }
3138   TF_ASSIGN_OR_RETURN(Shape sharded_rw_shape,
3139                       ShapeInference::InferReduceWindowShape(
3140                           sharded_input_shapes, replicated_init_shapes,
3141                           sharded_results[0].shard_window,
3142                           hlo->to_apply()->ComputeProgramShape()));
3143   HloSharding result_sharding =
3144       (hlo->shape().IsTuple())
3145           ? hlo->sharding().GetTupleSharding(hlo->shape()).ValueOrDie()
3146           : hlo->sharding();
3147   Shape shard_shape = MakePartitionedShape(hlo->shape(), result_sharding);
3148   *sharded_rw_shape.mutable_layout() = shard_shape.layout();
3149   SetPartitionedHlo(hlo, [&]() {
3150     HloInstruction* sharded_rw =
3151         b_.AddInstruction(HloInstruction::CreateReduceWindow(
3152             sharded_rw_shape, sharded_inputs, replicated_inits,
3153             sharded_results[0].shard_window, hlo->to_apply()));
3154     if (!sharded_results[0].dynamic_slice_index_on_output.has_value()) {
3155       CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape()))
3156           << shard_shape << " vs " << sharded_rw->shape() << "\n";
3157       return sharded_rw;
3158     }
3159     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3160         shard_shape, sharded_rw,
3161         *sharded_results[0].dynamic_slice_index_on_output,
3162         shard_shape.dimensions()));
3163   });
3164   return Status::OK();
3165 }
3166 
3167 Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) {
3168   if (hlo->sharding().IsTileMaximal()) {
3169     return DefaultAction(hlo);
3170   }
3171   auto operand = GetPartitionedHlo(hlo->operand(0));
3172   auto source = GetPartitionedHlo(hlo->mutable_operand(1));
3173   if (hlo->sharding() != operand.sharding()) {
3174     operand = operand.Reshard(hlo->sharding());
3175   }
3176   if (hlo->sharding() != source.sharding()) {
3177     source = source.Reshard(hlo->sharding());
3178   }
3179 
3180   // For F32 and BF16 types, we can work around the issue with low/high padding
3181   // by padding with a value that never wins the select comparison (see below).
3182   if (hlo->shape().element_type() != F32 &&
3183       hlo->shape().element_type() != BF16) {
3184     return DefaultAction(hlo);
3185   }
3186 
3187   auto select = hlo->called_computations()[0];
3188   auto select_root = select->root_instruction();
3189   if (select_root->opcode() != HloOpcode::kCompare ||
3190       select_root->operand(0)->opcode() != HloOpcode::kParameter ||
3191       select_root->operand(1)->opcode() != HloOpcode::kParameter ||
3192       select_root->operand(0)->parameter_number() +
3193               select_root->operand(1)->parameter_number() !=
3194           1) {
3195     return DefaultAction(hlo);
3196   }
3197 
3198   float float_pad_value;
3199   if (select_root->comparison_direction() == ComparisonDirection::kGe ||
3200       select_root->comparison_direction() == ComparisonDirection::kGt) {
3201     if (select_root->operand(0)->parameter_number() == 0) {
3202       float_pad_value = -std::numeric_limits<float>::infinity();
3203     } else {
3204       float_pad_value = std::numeric_limits<float>::infinity();
3205     }
3206   } else if (select_root->comparison_direction() == ComparisonDirection::kLe ||
3207              select_root->comparison_direction() == ComparisonDirection::kLt) {
3208     if (select_root->operand(0)->parameter_number() == 0) {
3209       float_pad_value = std::numeric_limits<float>::infinity();
3210     } else {
3211       float_pad_value = -std::numeric_limits<float>::infinity();
3212     }
3213   } else {
3214     return DefaultAction(hlo);
3215   }
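  // Illustrative example: for a max-pooling style select (kGe with parameter 0
  // on the left-hand side), the pad value below is -inf, so padded elements can
  // never win the selection and therefore never receive scattered values.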
3216 
3217   auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant(
3218       hlo->shape().element_type() == BF16
3219           ? LiteralUtil::CreateR0<bfloat16>(
3220                 static_cast<bfloat16>(float_pad_value))
3221           : LiteralUtil::CreateR0<float>(float_pad_value)));
3222 
3223   // Replicate init
3224   auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
3225                              .Reshard(HloSharding::Replicate());
3226 
3227   auto state = MakePartitioningState();
3228   auto partition_ordinals =
3229       MakeTiledPartitionOrdinals(hlo->sharding(), state.partition_id, &b_);
3230 
3231   // The first window for each dimension that overlaps with the shard area.
3232   std::vector<MultiplyAddDivideOffsetCalculation> first_window(
3233       hlo->shape().rank());
3234   // The first window for each dimension that goes beyond the shard area.
3235   std::vector<MultiplyAddDivideOffsetCalculation> limit_window(
3236       hlo->shape().rank());
3237   std::vector<OffsetCalculation> data_left_halo_sizes(hlo->shape().rank());
3238   std::vector<OffsetCalculation> data_right_halo_sizes(hlo->shape().rank());
3239   std::vector<OffsetCalculation> source_left_halo_sizes(hlo->shape().rank());
3240   std::vector<OffsetCalculation> source_right_halo_sizes(hlo->shape().rank());
3241   auto unpadded_data_shard_shape =
3242       MakePartitionedShape(hlo->shape(), hlo->sharding());
3243   auto unpadded_source_shard_shape =
3244       MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding());
3245   auto source_shard_hlo = source.hlo();
3246   auto data_shard_hlo = operand.hlo();
3247   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
3248     int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
3249     if (shard_count == 1) {
3250       continue;
3251     }
3252     // If stride > window_size, there will be gaps between windows. These gaps
3253     // will also exist in the output, so we keep them during halo exchange.
3254     //
3255     // TODO(yuanzx): This could introduce overhead if partitions start at
3256     // different offsets in a gap.
3257     auto wd = hlo->window().dimensions(i);
3258     if (wd.stride() > wd.size()) {
3259       wd.set_size(wd.stride());
3260     }
3261     // shard_size * i < stride * k - pad_low + window_size  =>
3262     //   k > (shard_size * i + pad_low - window_size) / stride  =>
3263     //   first_k == (shard_size * i + pad_low - window_size + stride) / stride
3264     first_window[i] = MultiplyAddDivideOffsetCalculation(
3265         unpadded_data_shard_shape.dimensions(i),
3266         wd.padding_low() - wd.size() + wd.stride(), wd.stride());
3267     // shard_size * (i + 1) <= stride * k - pad_low  =>
3268     //   k >= (shard_size * i + shard_size + pad_low) / stride  =>
3269     //   limit_k == (shard_size * i + shard_size + pad_low + stride - 1) /
3270     //     stride
3271     limit_window[i] = MultiplyAddDivideOffsetCalculation(
3272         unpadded_data_shard_shape.dimensions(i),
3273         unpadded_data_shard_shape.dimensions(i) + wd.padding_low() +
3274             wd.stride() - 1,
3275         wd.stride());
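    // Illustrative example (hypothetical numbers): shard size 4, window size 3,
    // stride 2, padding_low 1. Then first_window(i) = (4*i + 0) / 2 = 2*i and
    // limit_window(i) = (4*i + 6) / 2 = 2*i + 3, so shard 0 handles windows
    // 0..2 and shard 1 handles windows 2..4; window 2 straddles the shard
    // boundary and is made available via the halo exchange below.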
3276     source_left_halo_sizes[i] =
3277         MultiplyAddDivideOffsetCalculation(
3278             unpadded_source_shard_shape.dimensions(i), 0, 1) -
3279         first_window[i];
3280     source_right_halo_sizes[i] =
3281         limit_window[i] - MultiplyAddDivideOffsetCalculation(
3282                               unpadded_source_shard_shape.dimensions(i),
3283                               unpadded_source_shard_shape.dimensions(i), 1);
3284     data_left_halo_sizes[i] =
3285         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
3286             unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) -
3287         OffsetCalculation(
3288             HloOpcode::kMultiply, first_window[i],
3289             MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1));
3290     data_right_halo_sizes[i] =
3291         OffsetCalculation(
3292             HloOpcode::kMultiply, limit_window[i],
3293             MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) -
3294         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
3295             unpadded_data_shard_shape.dimensions(i),
3296             unpadded_data_shard_shape.dimensions(i) + wd.stride() +
3297                 wd.padding_low() - wd.size(),
3298             1));
3299 
3300     int64_t max_windows =
3301         (limit_window[i] - first_window[i]).MaxInRange(0, shard_count);
3302     auto first_window_hlo =
3303         first_window[i].Calculate(partition_ordinals[i], &b_);
3304     // Padding on the source is filled with the init value so it does not change
3305     // the data in overlapping windows.
3306     auto resharded_source = ExchangeHaloAndGetValidData(
3307         source_shard_hlo, source.base_shape(), source_left_halo_sizes[i],
3308         source_right_halo_sizes[i], 0,
3309         limit_window[i].Calculate(shard_count - 1), max_windows, i,
3310         hlo->sharding(), first_window_hlo, replicated_init.hlo(),
3311         partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
3312     if (!resharded_source) {
3313       return DefaultAction(hlo);
3314     }
3315     source_shard_hlo = *resharded_source;
3316 
3317     auto offset_start_in_data =
3318         MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1)
3319             .Calculate(first_window_hlo, &b_);
3320     int64_t padded_data_size =
3321         (limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() +
3322         wd.size();
3323     int64_t data_shard_size = (max_windows - 1) * wd.stride() + wd.size();
3324     auto resharded_data = ExchangeHaloAndGetValidData(
3325         data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i],
3326         data_right_halo_sizes[i], wd.padding_low(), padded_data_size,
3327         data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value,
3328         partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
3329     if (!resharded_data) {
3330       return DefaultAction(hlo);
3331     }
3332     data_shard_hlo = *resharded_data;
3333   }
3334 
3335   Window window_on_shard = hlo->window();
3336   for (int64_t i = 0; i < window_on_shard.dimensions_size(); ++i) {
3337     int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
3338     if (shard_count == 1) {
3339       continue;
3340     }
3341     auto reshard_wd = window_on_shard.mutable_dimensions(i);
3342     // The shards are already explicitly padded.
3343     reshard_wd->set_padding_low(0);
3344     reshard_wd->set_padding_high(0);
3345   }
3346 
3347   auto sharded_select_and_scatter =
3348       b_.AddInstruction(HloInstruction::CreateSelectAndScatter(
3349           data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard,
3350           source_shard_hlo, replicated_init.hlo(),
3351           hlo->called_computations()[1]));
3352   SetPartitionedHlo(hlo, [&]() {
3353     auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
3354     if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(),
3355                               shard_shape)) {
3356       return sharded_select_and_scatter;
3357     }
3358     auto zero = b_.AddInstruction(
3359         HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
3360     std::vector<HloInstruction*> slice_offsets(shard_shape.rank(), zero);
3361     for (int64_t i = 0; i < window_on_shard.dimensions_size(); ++i) {
3362       if (hlo->sharding().tile_assignment().dim(i) == 1) {
3363         continue;
3364       }
3365       int64_t pad_low = hlo->window().dimensions(i).padding_low();
3366       auto left_halo_size =
3367           data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_);
3368       if (data_left_halo_sizes[i].Calculate(0) == pad_low) {
3369         slice_offsets[i] = left_halo_size;
3370       } else {
3371         auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare(
3372             ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i],
3373             ComparisonDirection::kEq));
3374         auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant(
3375             LiteralUtil::CreateR0<int32>(pad_low)));
3376         slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary(
3377             zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo,
3378             left_halo_size));
3379       }
3380     }
3381     return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3382         shard_shape, sharded_select_and_scatter, slice_offsets,
3383         shard_shape.dimensions()));
3384   });
3385   return Status::OK();
3386 }
3387 
3388 Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
3389   std::vector<HloInstruction*> new_operands;
3390   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
3391     new_operands.push_back(
3392         GetPartitionedHlo(hlo->operand(i))
3393             .Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i}))
3394             .hlo());
3395   }
3396   SetPartitionedHlo(hlo, [&]() {
3397     return b_.AddInstruction(HloInstruction::CreateTuple(new_operands));
3398   });
3399   return Status::OK();
3400 }
3401 
3402 StatusOr<bool> SpmdPartitioningVisitor::DoPartition(
3403     HloComputation* computation, const HloSharding& root_sharding,
3404     const SpmdPartitionerOptions& options) {
3405   VLOG(2) << "Partitioning computation " << computation->name() << " for "
3406           << num_replicas_ << " replicas and " << num_partitions_
3407           << " partitions";
3408   TF_RETURN_IF_ERROR(computation->Accept(this));
3409 
3410   HloModule* module = computation->parent();
3411   auto new_root =
3412       GetPartitionedHlo(computation->root_instruction()).Reshard(root_sharding);
3413   auto new_computation =
3414       module->AddEmbeddedComputation(b_.Build(new_root.hlo()));
3415   TF_RETURN_IF_ERROR(
3416       DoCodeMotionForWindowedDotGeneralLoops(new_computation, options));
3417 
3418   // Replace the original computation with the new SPMD computation.
3419   absl::flat_hash_map<HloComputation*, HloComputation*> replacement;
3420   replacement[computation] = new_computation;
3421   module->ReplaceComputations(replacement);
3422   return changed_;
3423 }
3424 
3425 Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) {
3426   return Unimplemented(
3427       "PartitionId instruction is not supported for SPMD partitioning since "
3428       "the meaning is ambiguous -- whether the instruction is replicated or "
3429       "the data is replicated, and if the latter which data is replicated.");
3430 }
3431 
3432 SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
3433                                                         int64_t num_replicas) {
3434   return {
3435       [](SpmdBuilder* b) {
3436         return b->AddInstruction(HloInstruction::CreatePartitionId());
3437       },
3438       [num_replicas, num_partitions](
3439           SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
3440           const std::vector<std::vector<int64>>& partition_subgroups,
3441           int64_t channel_id) {
3442         if (partition_subgroups.size() <= 1) {
3443           std::vector<ReplicaGroup> groups(num_replicas);
3444           // TODO(yuanzx): Unify subgroup definition with AllToAll.
3445           for (int64_t i = 0; i < num_replicas; ++i) {
3446             groups[i].add_replica_ids(i);
3447           }
3448           return b->AddInstruction(HloInstruction::CreateAllReduce(
3449               operand->shape(), {operand}, reduction, groups,
3450               /*constrain_layout=*/false, channel_id,
3451               /*use_global_device_ids=*/false));
3452         }
3453 
3454         std::vector<ReplicaGroup> device_groups;
3455         device_groups.reserve(partition_subgroups.size() * num_replicas);
3456         for (int64_t i = 0; i < num_replicas; ++i) {
3457           for (const auto& pgroup : partition_subgroups) {
3458             device_groups.emplace_back();
3459             for (int64_t pid : pgroup) {
3460               device_groups.back().add_replica_ids(i * num_partitions + pid);
3461             }
3462           }
3463         }
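        // Illustrative example (hypothetical): num_replicas = 2,
        // num_partitions = 4 and partition_subgroups = {{0,1},{2,3}} yield
        // global-id groups {0,1},{2,3},{4,5},{6,7}, since each flattened id is
        // replica * num_partitions + partition.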
3464         return b->AddInstruction(HloInstruction::CreateAllReduce(
3465             operand->shape(), {operand}, reduction, device_groups,
3466             /*constrain_layout=*/false, channel_id,
3467             /*use_global_device_ids=*/true));
3468       },
3469       [num_partitions](SpmdBuilder* b, HloInstruction* operand,
3470                        std::vector<std::pair<int64, int64>>& src_dst_pairs,
3471                        int64_t channel_id) {
3472         /* optimize trivial collective permute */
3473         if (src_dst_pairs.empty()) {
3474           // If the src/dst pairs are empty, then the collective permute just
3475           // initializes the output to zero.
3476           return CreateZero(operand->shape(), b);
3477         } else {
3478           // A collective-permute is a copy if all pairs are "identity" and
3479           // all partitions are listed.
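          // Illustrative example: with num_partitions = 2 and src_dst_pairs =
          // {{0,0},{1,1}}, the permute is an identity copy, so the operand is
          // returned directly instead of emitting a collective.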
3480           bool is_copy =
3481               src_dst_pairs.size() == num_partitions &&
3482               absl::c_all_of(src_dst_pairs,
3483                              [](const std::pair<int64, int64>& pair) {
3484                                return pair.first == pair.second;
3485                              });
3486           if (is_copy) {
3487             return operand;
3488           } else {
3489             return b->AddInstruction(HloInstruction::CreateCollectivePermute(
3490                 operand->shape(), operand, src_dst_pairs, channel_id));
3491           }
3492         }
3493       },
3494       [](SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
3495          const std::vector<std::vector<int64>>& partition_subgroups,
3496          int64_t channel_id, absl::optional<int64> split_dimension) {
3497         std::vector<Shape> shapes(operands.size(), operands[0]->shape());
3498         const Shape output_shape = (shapes.size() == 1)
3499                                        ? shapes[0]
3500                                        : ShapeUtil::MakeTupleShape(shapes);
3501         std::vector<ReplicaGroup> groups(partition_subgroups.size());
3502         for (int64_t i = 0; i < groups.size(); ++i) {
3503           for (int64_t id : partition_subgroups[i]) {
3504             groups[i].add_replica_ids(id);
3505           }
3506         }
3507         return b->AddInstruction(HloInstruction::CreateAllToAll(
3508             output_shape, operands, groups,
3509             /*constrain_layout=*/false, channel_id, split_dimension));
3510       },
3511       [num_replicas, num_partitions](
3512           SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
3513           const std::vector<std::vector<int64>>& partition_subgroups,
3514           int64_t channel_id, int64_t all_gather_dimension) {
3515         std::vector<ReplicaGroup> device_groups;
3516         device_groups.reserve(partition_subgroups.size() * num_replicas);
3517         for (int64_t i = 0; i < num_replicas; ++i) {
3518           for (const auto& pgroup : partition_subgroups) {
3519             device_groups.emplace_back();
3520             for (int64_t pid : pgroup) {
3521               device_groups.back().add_replica_ids(i * num_partitions + pid);
3522             }
3523           }
3524         }
3525         return b->AddInstruction(HloInstruction::CreateAllGather(
3526             ag_shape, {operand}, all_gather_dimension, device_groups,
3527             /*constrain_layout=*/false, channel_id,
3528             /*use_global_device_ids=*/true));
3529       },
3530   };
3531 }
3532 
3533 SpmdPartitioner::SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
3534                                  SpmdPartitionerOptions options)
3535     : SpmdPartitioner(
3536           num_partitions, num_replicas, std::move(options),
3537           GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {}
3538 
3539 HloInstruction* SpmdPartitioner::AllGatherShards(
3540     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3541     int64* next_channel_id, absl::Span<const int64> selected_dims,
3542     const SPMDCollectiveOpsCreator& collectives_creator) {
3543   return AllGatherShardsInternal(b, operand, sharding, next_channel_id,
3544                                  selected_dims, collectives_creator,
3545                                  /*per_dim_ag=*/true);
3546 }
3547 
3548 HloInstruction* SpmdPartitioner::AllGatherShardsInternal(
3549     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3550     int64* next_channel_id, absl::Span<const int64> selected_dims,
3551     const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag) {
3552   if (selected_dims.empty()) {
3553     return operand;
3554   }
3555   CHECK(!sharding.IsTileMaximal());
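       // Overall flow: reshape the per-shard value [d0, d1, ...] into
       // [1, d0, d1, ...], all-gather along the new leading dimension (once per
       // selected dim, or in a single pass when per_dim_ag is false), split the
       // leading dimension when more than one dim was gathered, transpose each
       // gathered dim next to the data dim it belongs to, and reshape to the
       // full gathered shape. For example, a [4,8] shard under a [2,2] tile
       // assignment over dims {0,1} becomes the full [8,16] value.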
3556   // Add one leading dimension to gather all partitions.
3557   std::vector<int64> shape;
3558   shape.push_back(1);
3559   for (int64_t dim : operand->shape().dimensions()) {
3560     shape.push_back(dim);
3561   }
3562   auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
3563       ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
3564   HloInstruction* result = reshape;
3565   if (per_dim_ag) {
3566     for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
3567       if (sharding.tile_assignment().dim(*it) == 1) {
3568         continue;
3569       }
3570       auto partition_subgroups =
3571           GetPartitionGroupsForReplication(sharding, {*it});
3572       shape[0] *= partition_subgroups[0].size();
3573       result = collectives_creator.create_cross_partition_all_gather(
3574           b, result,
3575           ShapeUtil::MakeShape(operand->shape().element_type(), shape),
3576           partition_subgroups, (*next_channel_id)++,
3577           /*all_gather_dimension=*/0);
3578     }
3579   } else {
3580     auto partition_subgroups =
3581         GetPartitionGroupsForReplication(sharding, selected_dims);
3582     shape[0] *= partition_subgroups[0].size();
3583     result = collectives_creator.create_cross_partition_all_gather(
3584         b, result, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
3585         partition_subgroups, (*next_channel_id)++,
3586         /*all_gather_dimension=*/0);
3587   }
3588   // If n > 1 dimensions are partitioned, split the leading dimension into n dims.
3589   std::vector<int64> tiled_dims;
3590   for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
3591     if (sharding.tile_assignment().dim(i) > 1 &&
3592         absl::c_linear_search(selected_dims, i)) {
3593       tiled_dims.push_back(i);
3594     }
3595   }
3596   if (tiled_dims.size() > 1) {
3597     std::vector<int64> split_dim_shape;
3598     split_dim_shape.reserve(tiled_dims.size() + operand->shape().rank());
3599     for (int64_t i : tiled_dims) {
3600       split_dim_shape.push_back(sharding.tile_assignment().dim(i));
3601     }
3602     for (int64_t dim : operand->shape().dimensions()) {
3603       split_dim_shape.push_back(dim);
3604     }
3605     result = b->AddInstruction(HloInstruction::CreateReshape(
3606         ShapeUtil::MakeShape(operand->shape().element_type(), split_dim_shape),
3607         result));
3608   }
3609   // Transpose each gathered dimension so that it sits next to its
3610   // corresponding partitioned dimension.
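       // For example, with a [4,8] shard and dims {0,1} each split 2 ways, the
       // gathered value has shape [2,2,4,8]; the permutation {0,2,1,3} produces
       // [2,4,2,8], and the final reshape merges it into the full [8,16] shape.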
3611   std::vector<int64> xpose_permutation(result->shape().rank());
3612   int64_t split_dims_added = 0;
3613   for (int64_t i = 0; i < xpose_permutation.size(); ++i) {
3614     if (sharding.tile_assignment().dim(i - split_dims_added) == 1 ||
3615         !absl::c_linear_search(selected_dims, i - split_dims_added)) {
3616       xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
3617     } else {
3618       xpose_permutation[i] = split_dims_added;
3619       xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
3620       split_dims_added++;
3621       i++;
3622     }
3623   }
3624   result = b->AddInstruction(HloInstruction::CreateTranspose(
3625       ShapeInference::InferTransposeShape(result->shape(), xpose_permutation)
3626           .ValueOrDie(),
3627       result, xpose_permutation));
3628   // Reshape to the desired shape.
3629   auto ag_shape = operand->shape();
3630   for (int64_t i : tiled_dims) {
3631     ag_shape.set_dimensions(
3632         i, ag_shape.dimensions(i) * sharding.tile_assignment().dim(i));
3633   }
3634   result = b->AddInstruction(HloInstruction::CreateReshape(ag_shape, result));
3635   return result;
3636 }
3637 
3638 HloInstruction* SpmdPartitioner::AllReduceAlongShardingDims(
3639     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3640     int64* next_channel_id, absl::Span<const int64> selected_dims,
3641     const SPMDCollectiveOpsCreator& collectives_creator,
3642     HloComputation* reduction) {
3643   return AllReduceAlongShardingDimsInternal(
3644       b, operand, sharding, next_channel_id, selected_dims, collectives_creator,
3645       reduction, /*per_dim_ar=*/true);
3646 }
3647 
3648 HloInstruction* SpmdPartitioner::AllReduceAlongShardingDimsInternal(
3649     SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3650     int64* next_channel_id, absl::Span<const int64> selected_dims,
3651     const SPMDCollectiveOpsCreator& collectives_creator,
3652     HloComputation* reduction, bool per_dim_ar) {
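       // With per_dim_ar=false, a single all-reduce covers the product of all
       // selected sharding dimensions; with per_dim_ar=true, one all-reduce is
       // issued per selected dimension, using smaller groups but more steps.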
3653   if (!per_dim_ar) {
3654     auto partition_subgroups =
3655         GetPartitionGroupsForReplication(sharding, selected_dims);
3656     return collectives_creator.create_cross_partition_all_reduce(
3657         b, operand, reduction, partition_subgroups, (*next_channel_id)++);
3658   }
3659   auto result = operand;
3660   for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
3661     if (sharding.tile_assignment().dim(*it) == 1) {
3662       continue;
3663     }
3664     auto partition_subgroups =
3665         GetPartitionGroupsForReplication(sharding, {*it});
3666     result = collectives_creator.create_cross_partition_all_reduce(
3667         b, result, reduction, partition_subgroups, (*next_channel_id)++);
3668   }
3669   return result;
3670 }
3671 
3672 StatusOr<bool> SpmdPartitioner::PartitionComputation(
3673     HloComputation* computation, const HloSharding& root_sharding,
3674     int64* next_channel_id, SpmdLogger* logger) {
3675   auto visitor =
3676       CreateVisitor(computation, num_partitions_, num_replicas_,
3677                     collective_ops_creator_, next_channel_id, logger, options_);
3678   return visitor->DoPartition(computation, root_sharding, options_);
3679 }
3680 
3681 std::unique_ptr<SpmdPartitioningVisitor> SpmdPartitioner::CreateVisitor(
3682     HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
3683     const SPMDCollectiveOpsCreator& collective_ops_creator,
3684     int64* next_channel_id, SpmdLogger* logger,
3685     SpmdPartitionerOptions options) {
3686   return absl::make_unique<SpmdPartitioningVisitor>(
3687       computation, num_partitions, num_replicas, collective_ops_creator,
3688       next_channel_id, logger, std::move(options), this);
3689 }
3690 
3691 StatusOr<bool> SpmdPartitioner::Run(HloModule* module) {
3692   TF_RETURN_IF_ERROR(PreprocessSharding(module));
3693   TF_RETURN_IF_ERROR(PreprocessHlos(module));
3694 
3695   XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition(
3696                         *module, options_.report_instruction_count));
3697 
3698   // Add the parameters' and output's shardings to the module.
3699   std::vector<HloSharding> entry_params_shardings;
3700   for (int64_t i = 0; i < module->entry_computation()->num_parameters(); ++i) {
3701     auto param = module->entry_computation()->parameter_instruction(i);
3702     CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i;
3703     entry_params_shardings.push_back(param->sharding());
3704   }
3705   module->set_spmd_parameters_shardings(entry_params_shardings);
3706   auto entry_root = module->entry_computation()->root_instruction();
3707   CHECK(entry_root->has_sharding()) << "Missing sharding in entry root.";
3708   module->set_spmd_output_sharding(entry_root->sharding());
3709 
3710   FlattenCallGraph flatten;
3711   TF_ASSIGN_OR_RETURN(auto changed, flatten.Run(module));
3712 
3713   SpmdLogger logger(options_.report_instruction_count);
3714   auto program_shape = module->entry_computation()->ComputeProgramShape();
3715   int64_t next_channel_id = hlo_query::NextChannelId(*module);
3716   // Copy the root sharding since the partitioner visitor may temporarily change
3717   // the sharding to work around manual sharding.
3718   HloSharding root_sharding = entry_root->sharding();
3719   TF_ASSIGN_OR_RETURN(
3720       bool partition_changed,
3721       PartitionComputation(module->entry_computation(), root_sharding,
3722                            &next_channel_id, &logger));
3723   changed |= partition_changed;
3724 
3725   // For the entry computation, make sure that the root instruction and the
3726   // parameters preserve their signatures.
3727   auto new_program_shape = module->entry_computation()->ComputeProgramShape();
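       // If the signature must stay fixed, verify that the result and parameter
       // shapes are unchanged; otherwise copy the original entry layouts onto
       // the new (per-shard) shapes and update the module config accordingly.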
3728   if (!options_.allow_module_signature_change) {
3729     TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
3730         program_shape.result(), new_program_shape.result()))
3731         << "Result shape changed for the entry computation";
3732     TF_RET_CHECK(program_shape.parameters_size() ==
3733                  new_program_shape.parameters_size())
3734         << "Parameter count changed for the entry computation";
3735     for (int64_t i = 0; i < program_shape.parameters_size(); ++i) {
3736       TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
3737           program_shape.parameters(i), new_program_shape.parameters(i)))
3738           << "Parameter shape changed for the entry computation";
3739     }
3740   } else {
3741     const auto& old_entry_layout = module->entry_computation_layout();
3742     // Shapes can change but the layout should still remain the same.
3743     for (int64_t i = 0; i < new_program_shape.parameters_size(); ++i) {
3744       TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3745           old_entry_layout.parameter_shape(i),
3746           new_program_shape.mutable_parameters(i)));
3747     }
3748     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3749         old_entry_layout.result_shape(), new_program_shape.mutable_result()));
3750 
3751     HloModuleConfig config = module->config();
3752     *config.mutable_entry_computation_layout() =
3753         ComputationLayout(new_program_shape, /*ignore_layouts=*/false);
3754     module->set_config(config);
3755   }
3756 
3757   XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition(
3758                         *module, options_.report_instruction_count));
3759   XLA_VLOG_LINES(1, logger.MakeReport());
3760 
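       // Clean up after partitioning: remove dead instructions, simplify
       // tuples, CSE duplicated instructions, and re-flatten the call graph.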
3761   if (changed) {
3762     HloPassPipeline pass("spmd-cleanup");
3763     pass.AddPass<HloDCE>();
3764     pass.AddPass<TupleSimplifier>();
3765     pass.AddPass<HloDCE>();
3766     pass.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
3767     pass.AddPass<FlattenCallGraph>();
3768     TF_RETURN_IF_ERROR(pass.Run(module).status());
3769   }
3770 
3771   TF_RETURN_IF_ERROR(ClearShardingAttributes(module));
3772   return changed;
3773 }
3774 
3775 Status SpmdPartitioner::PreprocessSharding(HloModule* module) {
3776   for (HloComputation* computation : module->computations()) {
3777     for (HloInstruction* hlo : computation->instructions()) {
3778       if (hlo->HasSideEffectNoRecurse() && hlo->opcode() != HloOpcode::kRng) {
3779         TF_RET_CHECK(hlo->has_sharding())
3780             << "Side-effect HLO must have sharding: " << hlo->ToString();
3781         TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) ||
3782                      CanSideEffectingHaveReplicatedSharding(hlo))
3783             << "side-effect HLO cannot have a replicated sharding: "
3784             << hlo->ToString();
3785       }
3786 
3787       // For unassigned HLOs, annotate with replicated sharding.
3788       //
3789       // Among side-effecting ops, only Rng is allowed to omit the annotation.
3790       // In that case, we currently force it to run on core 0, since we don't
3791       // support partitioning or replicating the Rng op (the values depend on
3792       // the seed provided to each device).
3793       //
3794       // TODO(hyouklee): Should we also convert single-device shardings (without
3795       // side-effects) into replicated?
3796       if (!hlo->has_sharding()) {
3797         if (hlo->opcode() == HloOpcode::kRng) {
3798           hlo->set_sharding(HloSharding::AssignDevice(0));
3799         } else {
3800           hlo->set_sharding(
3801               HloSharding::Single(hlo->shape(), HloSharding::Replicate()));
3802         }
3803       } else if (!hlo->sharding().IsTileMaximal() &&
3804                  !hlo->sharding().IsManual()) {
3805         std::vector<int64> available(num_partitions_);
3806         std::iota(available.begin(), available.end(), 0);
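             // Every partition must appear in the tile assignment. For example,
             // with num_partitions_ = 8, a tile sharding that only uses devices
             // {0, 1, 2, 3} is rejected here.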
3807         TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding(
3808                                             hlo->sharding(), available)
3809                                             .size())
3810             << "num_partitions:" << num_partitions_ << "\n"
3811             << "SPMD partitioner only supports tile sharding that includes all "
3812                "partitions. If you did not add this sharding annotation to the "
3813                "model, please file a bug against the XLA team.\n"
3814             << hlo->ToString();
3815       }
3816     }
3817   }
3818 
3819   // Entry computation's parameter and root sharding must be either all
3820   // replicated or all on a single device.
3821   if (!options_.allow_module_signature_change) {
3822     const HloComputation* entry = module->entry_computation();
3823     TF_RET_CHECK(entry->root_instruction()->has_sharding());
3824     const HloSharding& root_sharding = entry->root_instruction()->sharding();
3825     if (!root_sharding.UniqueDevice().has_value()) {
3826       if (root_sharding.IsTuple()) {
3827         TF_RET_CHECK(absl::c_all_of(root_sharding.tuple_elements(),
3828                                     [](const HloSharding& s) {
3829                                       return s.IsReplicated() || s.IsManual();
3830                                     }))
3831             << "Unsupported entry root sharding: " << root_sharding.ToString();
3832 
3833       } else {
3834         TF_RET_CHECK(root_sharding.IsReplicated() || root_sharding.IsManual())
3835             << "Unsupported entry root sharding: " << root_sharding.ToString();
3836       }
3837     }
3838 
3839     for (const HloInstruction* param : entry->parameter_instructions()) {
3840       TF_RET_CHECK(param->has_sharding());
3841       TF_RET_CHECK(param->sharding().IsReplicated() ||
3842                    param->sharding().UniqueDevice().has_value())
3843           << "Unsupported entry parameter sharding:"
3844           << param->sharding().ToString();
3845     }
3846   }
3847 
3848   return Status::OK();
3849 }
3850 
3851 Status SpmdPartitioner::PreprocessHlos(HloModule* module) {
3852   auto skip_copy_operands = [](HloInstruction* operand,
3853                                bool check_single_use =
3854                                    true) -> HloInstruction* {
3855     while (operand->user_count() == 1 &&
3856            operand->opcode() == HloOpcode::kCopy) {
3857       operand = operand->mutable_operand(0);
3858     }
3859     if (check_single_use && operand->user_count() != 1) {
3860       return nullptr;
3861     }
3862     return operand;
3863   };
3864 
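       // The loop below canonicalizes a few patterns before partitioning: a pad
       // followed by a slice is merged into a single pad, and concatenations of
       // slices of the same input ("rotate right" and "pad with wrap") are
       // rewritten in terms of the internal SPMD rotate-right custom call.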
3865   for (HloComputation* computation : module->computations()) {
3866     for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
3867       if (hlo->sharding().IsTileMaximal() || hlo->sharding().IsManual()) {
3868         // No need to optimize for tile-maximal or manual sharding.
3869         continue;
3870       }
3871 
3872       if (hlo->opcode() == HloOpcode::kSlice) {
3873         HloInstruction* operand = skip_copy_operands(hlo->mutable_operand(0));
3874         if (operand == nullptr || operand->sharding() != hlo->sharding()) {
3875           continue;
3876         }
3877 
3878         // Merge pad->slice to avoid multiple halo exchanges.
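             // For example, pad(x, low=2, high=2) followed by slice [1, n+3) on
             // an input of size n can be merged into a single pad(x, low=1,
             // high=1).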
3879         if (operand->opcode() == HloOpcode::kPad) {
3880           absl::optional<PaddingConfig> merged_padding =
3881               operand->padding_config();
3882           bool may_have_multi_halo_exchanges = false;
3883           for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
3884             const auto& dim = operand->padding_config().dimensions(i);
3885             if (dim.interior_padding() != 0 || hlo->slice_strides(i) != 1) {
3886               merged_padding = absl::nullopt;
3887               break;
3888             }
3889             if (hlo->sharding().tile_assignment().dim(i) != 1 &&
3890                 (dim.edge_padding_low() != 0 || dim.edge_padding_high() != 0) &&
3891                 hlo->shape().dimensions(i) != operand->shape().dimensions(i)) {
3892               // Padding, slicing, and sharding all apply to this dim.
3893               may_have_multi_halo_exchanges = true;
3894             }
3895 
3896             auto* merged_dim = merged_padding->mutable_dimensions(i);
3897             merged_dim->set_edge_padding_low(dim.edge_padding_low() -
3898                                              hlo->slice_starts(i));
3899             merged_dim->set_edge_padding_high(
3900                 hlo->slice_limits(i) - dim.edge_padding_low() -
3901                 operand->operand(0)->shape().dimensions(i));
3902           }
3903           if (merged_padding.has_value() && may_have_multi_halo_exchanges) {
3904             // Rewrite to a single Pad.
3905             HloInstruction* new_pad =
3906                 computation->AddInstruction(HloInstruction::CreatePad(
3907                     hlo->shape(), operand->mutable_operand(0),
3908                     operand->mutable_operand(1), *merged_padding));
3909             new_pad->set_metadata(operand->metadata());
3910             new_pad->set_sharding(hlo->sharding());
3911             TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_pad));
3912             TF_RETURN_IF_ERROR(
3913                 computation->RemoveInstructionAndUnusedOperands(hlo));
3914           }
3915         }
3916       }
3917       if (hlo->opcode() == HloOpcode::kConcatenate) {
3918         const int64_t dim = hlo->concatenate_dimension();
3919         if (hlo->sharding().tile_assignment().dim(dim) == 1) {
3920           continue;
3921         }
3922         if (hlo->operand_count() == 2) {
3923           // Find a pattern of "rotate right on one dimension":
3924           // concat(slice(input), slice(input)).
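               // For example, concat(x[n-k:n], x[0:n-k]) along this dimension
               // is a rotate right by k elements.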
3925           HloInstruction* lhs = skip_copy_operands(hlo->mutable_operand(0));
3926           HloInstruction* rhs = skip_copy_operands(hlo->mutable_operand(1));
3927           if (lhs == nullptr || rhs == nullptr) {
3928             continue;
3929           }
3930           const int64_t amount = FindRotateRightPattern(hlo, lhs, rhs);
3931           if (amount < 0) {
3932             continue;
3933           }
3934           HloInstruction* to_rotate = lhs->mutable_operand(0);
3935           HloInstruction* rotate = computation->AddInstruction(
3936               CreateCustomCallSPMDInternal_RotateRight(to_rotate, dim, amount));
3937           rotate->set_metadata(hlo->metadata());
3938           rotate->set_sharding(hlo->sharding());
3939           TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rotate));
3940           TF_RETURN_IF_ERROR(
3941               computation->RemoveInstructionAndUnusedOperands(hlo));
3942         } else if (hlo->operand_count() == 3) {
3943           // Find the pattern for "pad with wrap": concat(slice(x), x, slice(x))
3944           // All involved values with same sharding.
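               // For example, concat(x[n-1:n], x, x[0:1]) wraps one element of
               // x around each edge (circular padding by 1).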
3945           HloInstruction* lhs = skip_copy_operands(hlo->mutable_operand(0));
3946           HloInstruction* mid = skip_copy_operands(hlo->mutable_operand(1),
3947                                                    /*check_single_use=*/false);
3948           HloInstruction* rhs = skip_copy_operands(hlo->mutable_operand(2));
3949           absl::optional<PadWithWrapPattern> pad_pattern =
3950               FindPadWithWrapPattern(hlo, lhs, mid, rhs);
3951           if (!pad_pattern) {
3952             continue;
3953           }
3954 
3955           // Since concat requires all operands to have the same size along the
3956           // non-concat dimensions, the lhs/rhs slices must be slicing along the
3957           // concat dimension.
3958 
3959           // Step 1: Pad the mid operand to the final size. The low padding is
3960           // the size of the lhs shape, and high padding is size of rhs shape.
3961           PaddingConfig padding_config =
3962               MakeNoPaddingConfig(hlo->shape().rank());
3963           auto* padding_config_dim = padding_config.mutable_dimensions(dim);
3964           const int64_t low_pad = lhs->shape().dimensions(dim);
3965           const int64_t high_pad = rhs->shape().dimensions(dim);
3966           padding_config_dim->set_edge_padding_low(low_pad);
3967           padding_config_dim->set_edge_padding_high(high_pad);
3968           HloInstruction* zero =
3969               computation->AddInstruction(HloInstruction::CreateConstant(
3970                   LiteralUtil::Zero(hlo->shape().element_type())));
3971           zero->set_sharding(HloSharding::Replicate());
3972           HloInstruction* pad =
3973               computation->AddInstruction(HloInstruction::CreatePad(
3974                   hlo->shape(), mid, zero, padding_config));
3975           pad->set_metadata(hlo->metadata());
3976           pad->set_sharding(hlo->sharding());
3977 
3978           // Step 2: rotate the padded value so that the lhs slice aligns with the
3979           // low end of the padded shape.
3980           //  padded_operand = low_pad | mid | high_pad.
3981           //  slice_start in padded_operand = lhs->slice_start + low_pad.
3982           //  Rotate left by (lhs->slice_start + low_pad)
3983           //  i.e., rotate right = padded_size - (lhs_slice_start + low_pad).
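               // For example, with a mid of size 6 and one element of wrap on
               // each side (low_pad = high_pad = 1, padded_size = 8),
               // lhs_slice_start is 5, so the rotate amount is 8 - (5 + 1) = 2,
               // which moves element x[5] to position 0.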
3984           const int64_t padded_size = hlo->shape().dimensions(dim);
3985           const int64_t rotate_lhs_amount =
3986               padded_size - (pad_pattern->lhs_slice_start + low_pad);
3987           HloInstruction* rotate_lhs = computation->AddInstruction(
3988               CreateCustomCallSPMDInternal_RotateRight(pad, dim,
3989                                                        rotate_lhs_amount));
3990           rotate_lhs->set_metadata(hlo->metadata());
3991           rotate_lhs->set_sharding(hlo->sharding());
3992 
3993           auto apply_modifiers =
3994               [&](HloInstruction* inst,
3995                   const std::vector<const HloInstruction*>& modifiers) {
3996                 // Apply the modifiers in the reverse order.
3997                 for (auto it = modifiers.crbegin(), end = modifiers.crend();
3998                      it != end; ++it) {
3999                   const HloInstruction* modifier = *it;
4000                   // New shape has same element type as the modifier, but dims
4001                   // as inst.
4002                   Shape new_shape = ShapeUtil::ChangeElementType(
4003                       inst->shape(), modifier->shape().element_type());
4004                   inst = computation->AddInstruction(
4005                       modifier->CloneWithNewOperands(new_shape, {inst}));
4006                 }
4007                 return inst;
4008               };
4009           rotate_lhs = apply_modifiers(rotate_lhs, pad_pattern->lhs_modifiers);
4010 
4011           // Step 3: rotate the padded value so that the rhs slice aligns with the
4012           // high end of the padded shape.
4013           //  padded_operand = low_pad | mid | high_pad.
4014           //  slice_start in padded_operand = rhs->slice_start + low_pad.
4015           //  slice_end in padded_operand = rhs->slice_start + low_pad +
4016           //  high_pad; Rotate right by padded_size - (rhs->slice_start +
4017           //  low_pad + high_pad)
4018           const int64_t rotate_rhs_amount =
4019               padded_size - (pad_pattern->rhs_slice_start + low_pad + high_pad);
4020           HloInstruction* rotate_rhs = computation->AddInstruction(
4021               CreateCustomCallSPMDInternal_RotateRight(pad, dim,
4022                                                        rotate_rhs_amount));
4023           rotate_rhs->set_metadata(hlo->metadata());
4024           rotate_rhs->set_sharding(hlo->sharding());
4025           rotate_rhs = apply_modifiers(rotate_rhs, pad_pattern->rhs_modifiers);
4026 
4027           // Now merge the 3 results using appropriate selects.
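               // The iota along the concat dimension gives each element its
               // final index: indices < low_pad take their value from
               // rotate_lhs, indices >= padded_size - high_pad take it from
               // rotate_rhs, and everything else keeps the padded middle.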
4028           const Shape iota_shape =
4029               ShapeUtil::ChangeElementType(hlo->shape(), U32);
4030           HloInstruction* iota = computation->AddInstruction(
4031               HloInstruction::CreateIota(iota_shape, dim));
4032           iota->set_metadata(hlo->metadata());
4033           iota->set_sharding(hlo->sharding());
4034 
4035           struct SelectSpec {
4036             int64 limit;
4037             HloInstruction* hlo;
4038             Comparison::Direction cmp;
4039           };
4040           const std::array<SelectSpec, 2> selects = {
4041               {// All elements < low_pad come from rotate_lhs.
4042                {low_pad, rotate_lhs, Comparison::Direction::kLt},
4043                // All elements >= padded_size - high_pad come from rotate_rhs
4044                {padded_size - high_pad, rotate_rhs,
4045                 Comparison::Direction::kGe}}};
4046 
4047           Shape pred_shape = ShapeUtil::ChangeElementType(hlo->shape(), PRED);
4048 
4049           HloInstruction* merged = pad;
4050           for (const SelectSpec& select_spec : selects) {
4051             HloInstruction* limit =
4052                 computation->AddInstruction(HloInstruction::CreateConstant(
4053                     LiteralUtil::CreateR0<uint32_t>(select_spec.limit)));
4054             limit->set_sharding(HloSharding::Replicate());
4055             HloInstruction* limit_bcast = computation->AddInstruction(
4056                 HloInstruction::CreateBroadcast(iota_shape, limit, {}));
4057             limit_bcast->set_metadata(hlo->metadata());
4058             limit_bcast->set_sharding(hlo->sharding());
4059             HloInstruction* compare =
4060                 computation->AddInstruction(HloInstruction::CreateCompare(
4061                     pred_shape, iota, limit_bcast, select_spec.cmp));
4062             compare->set_metadata(hlo->metadata());
4063             compare->set_sharding(hlo->sharding());
4064             merged = computation->AddInstruction(HloInstruction::CreateTernary(
4065                 hlo->shape(), HloOpcode::kSelect, compare, select_spec.hlo,
4066                 merged));
4067             merged->set_metadata(hlo->metadata());
4068             merged->set_sharding(hlo->sharding());
4069           }
4070 
4071           TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(merged));
4072           TF_RETURN_IF_ERROR(
4073               computation->RemoveInstructionAndUnusedOperands(hlo));
4074         }
4075       }
4076     }
4077   }
4078   return Status::OK();
4079 }
4080 
4081 }  // namespace spmd
4082 }  // namespace xla
4083