1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
17
18 #include <float.h>
19
20 #include <functional>
21 #include <memory>
22 #include <unordered_map>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/types/optional.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/compiler/xla/comparison_util.h"
33 #include "tensorflow/compiler/xla/literal_util.h"
34 #include "tensorflow/compiler/xla/protobuf_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
37 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
38 #include "tensorflow/compiler/xla/service/hlo_computation.h"
39 #include "tensorflow/compiler/xla/service/hlo_cse.h"
40 #include "tensorflow/compiler/xla/service/hlo_dce.h"
41 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
42 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
43 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
44 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
45 #include "tensorflow/compiler/xla/service/hlo_query.h"
46 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
47 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
48 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
49 #include "tensorflow/compiler/xla/service/shape_inference.h"
50 #include "tensorflow/compiler/xla/service/spmd/custom_call_handler.h"
51 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
52 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
53 #include "tensorflow/compiler/xla/shape_util.h"
54 #include "tensorflow/compiler/xla/util.h"
55 #include "tensorflow/compiler/xla/window_util.h"
56 #include "tensorflow/compiler/xla/xla_data.pb.h"
57 #include "tensorflow/core/platform/numbers.h"
58
59 namespace xla {
60 namespace spmd {
61
62 string SpmdLogger::MakeReport() {
63 string report;
64 absl::StrAppend(&report,
65 "\n\n***** SPMD memory during transformation *****\n");
66
67 std::sort(entries_.begin(), entries_.end(),
68 [](auto const& entry0, auto const& entry1) {
69 return entry0.first > entry1.first;
70 });
71 for (int64_t i = 0;
72 i < std::min<int64>(report_instruction_count_, entries_.size()); ++i) {
73 absl::StrAppend(
74 &report, "\n ",
75 tensorflow::strings::HumanReadableNumBytes(entries_[i].first), " : ",
76 entries_[i].second, "\n");
77 }
78
79 return report;
80 }
81
82 void SpmdLogger::RegisterLogEntry(HloInstruction* hlo,
83 const std::vector<HloInstruction*>& group) {
84 string report = hlo->ToString();
85 int64_t max_value = -1;
86 for (HloInstruction* inst : group) {
87 if (!inst->shape().IsArray()) {
88 continue;
89 }
90 max_value = std::max<int64>(max_value, ShapeSizeInBytes(inst->shape()));
91 absl::StrAppend(&report, " * ", inst->ToString(), "\n");
92 }
93 entries_.push_back(std::make_pair(max_value, report));
94 }
95
96 /* static */ string SpmdLogger::ReportBeforePartition(
97 const HloModule& module, int64_t report_instruction_count) {
98 string report;
99 absl::StrAppend(&report,
100 "\n\n***** SPMD memory usage before partition *****\n");
101 absl::StrAppend(&report, "\n ** Replicated instructions\n");
102 absl::StrAppend(&report, ReportMemoryUsage(
103 module,
104 [](const HloInstruction* hlo) {
105 return !hlo->has_sharding() ||
106 hlo->sharding().IsReplicated();
107 },
108 report_instruction_count));
109 absl::StrAppend(&report, "\n ** All instructions\n");
110 absl::StrAppend(&report,
111 ReportMemoryUsage(
112 module, [](const HloInstruction* hlo) { return true; },
113 report_instruction_count));
114 return report;
115 }
116
117 /* static */ string SpmdLogger::ReportAfterPartition(
118 const HloModule& module, int64_t report_instruction_count) {
119 string report;
120 absl::StrAppend(&report,
121 "\n\n***** SPMD memory usage after partition *****\n");
122 absl::StrAppend(&report,
123 ReportMemoryUsage(
124 module, [](const HloInstruction* hlo) { return true; },
125 report_instruction_count));
126 return report;
127 }
128
129 template <typename F>
130 /* static */ string SpmdLogger::ReportMemoryUsage(
131 const HloModule& module, const F& filter,
132 int64_t report_instruction_count) {
133 string report;
134 std::vector<HloInstruction*> instructions;
135 instructions.reserve(module.instruction_count());
136
137 for (auto computation : module.computations()) {
138 if (computation->IsFusionComputation()) {
139 continue;
140 }
141 for (auto hlo : computation->instructions()) {
142 if (!hlo->shape().IsArray() ||
143 ShapeUtil::IsEffectiveScalar(hlo->shape())) {
144 continue;
145 }
146 if (filter(hlo)) {
147 instructions.push_back(hlo);
148 }
149 }
150 }
151
152 const auto add_report = [&](std::vector<HloInstruction*>* insts) {
153 std::sort(insts->begin(), insts->end(),
154 [](const HloInstruction* inst0, const HloInstruction* inst1) {
155 return ShapeSizeInBytes(inst0->shape()) >
156 ShapeSizeInBytes(inst1->shape());
157 });
158 for (int64_t i = 0;
159 i < std::min<int64>(report_instruction_count, insts->size()); ++i) {
160 absl::StrAppend(&report, " ",
161 tensorflow::strings::HumanReadableNumBytes(
162 ShapeSizeInBytes((*insts)[i]->shape())),
163 " : ", (*insts)[i]->ToString(), "\n");
164 }
165 };
166
167 add_report(&instructions);
168 return report;
169 }
170
171 namespace {
172
173 // Clears all sharding attributes from instructions in the module. This must be
174 // called only after all SPMD transformations are complete.
175 Status ClearShardingAttributes(HloModule* module) {
176 for (HloComputation* computation : module->computations()) {
177 for (HloInstruction* hlo : computation->instructions()) {
178 // Keep sharding annotation on Infeed and entry parameters since they're
179 // used by HloReplicationAnalysis later (for ArCrsCombiner).
180 if (hlo->opcode() == HloOpcode::kInfeed) {
181 continue;
182 }
183 if (hlo->opcode() == HloOpcode::kParameter &&
184 computation == module->entry_computation()) {
185 continue;
186 }
187 hlo->clear_sharding();
188 }
189 }
190 return Status::OK();
191 }
192
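// Returns the groups of partitions that are replicas of each other along
// `replication_dims`. Illustrative example: for a 2x2 tile assignment
//   {{0, 1},
//    {2, 3}}
// and replication_dims = {1}, the groups are {0, 1} and {2, 3}; with
// replication_dims = {0}, the groups are {0, 2} and {1, 3}.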
193 std::vector<std::vector<int64>> GetPartitionGroupsForReplication(
194 const HloSharding& sharding, absl::Span<const int64> replication_dims) {
195 int64_t group_size = 1;
196 for (int64_t i : replication_dims) {
197 group_size *= sharding.tile_assignment().dim(i);
198 }
199 std::vector<std::vector<int64>> partition_groups(
200 sharding.tile_assignment().num_elements() / group_size);
201 sharding.tile_assignment().Each(
202 [&](absl::Span<const int64> indices, int64_t partition) {
203 int64_t group_id = 0;
204 for (int64_t i = 0; i < indices.size(); ++i) {
205 if (!absl::c_linear_search(replication_dims, i)) {
206 group_id *= sharding.tile_assignment().dim(i);
207 group_id += indices[i];
208 }
209 }
210 partition_groups[group_id].push_back(partition);
211 });
212 return partition_groups;
213 }
214
215 } // namespace
216
217 HloInstruction* SpmdBuilder::AddInstruction(
218 std::unique_ptr<HloInstruction> instruction) {
219 HloInstruction* hlo =
220 HloComputation::Builder::AddInstruction(std::move(instruction));
221 if (visiting_hlo_) {
222 hlo->set_metadata(visiting_hlo_->metadata());
223 instructions_[visiting_hlo_].push_back(hlo);
224 }
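// The remainder of this function tracks, per created HLO, the set of output
// dimensions along which the data is known to be identical across indices
// (broadcast_dims_). Illustrative example: a broadcast of f32[8] into f32[4,8]
// with dimensions={1} records {0} as its broadcast dims; later reshards may
// then skip communication along such dimensions.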
225 if (hlo->opcode() == HloOpcode::kBroadcast) {
226 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
227 if (!absl::c_linear_search(hlo->dimensions(), i)) {
228 broadcast_dims_[hlo].insert(i);
229 }
230 }
231 }
232 if (hlo->IsElementwise() && hlo->operand_count() > 0) {
233 absl::flat_hash_set<int64> broadcast_dims;
234 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
235 broadcast_dims.insert(i);
236 }
237 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
238 auto it = broadcast_dims_.find(hlo->operand(i));
239 if (it == broadcast_dims_.end()) {
240 broadcast_dims.clear();
241 break;
242 }
243 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
244 if (!it->second.contains(i)) {
245 broadcast_dims.erase(i);
246 }
247 }
248 }
249 if (!broadcast_dims.empty()) {
250 broadcast_dims_[hlo] = std::move(broadcast_dims);
251 }
252 }
253 if (hlo->opcode() == HloOpcode::kTranspose) {
254 auto it = broadcast_dims_.find(hlo->operand(0));
255 if (it != broadcast_dims_.end()) {
256 absl::flat_hash_set<int64> xpose_broadcast_dims;
257 std::vector<int64> reverse_map(hlo->shape().rank());
258 for (int64_t i = 0; i < reverse_map.size(); ++i) {
259 reverse_map[hlo->dimensions(i)] = i;
260 }
261 for (int64_t dim : it->second) {
262 xpose_broadcast_dims.insert(reverse_map[dim]);
263 }
264 broadcast_dims_[hlo] = std::move(xpose_broadcast_dims);
265 }
266 }
267 if (hlo->opcode() == HloOpcode::kReshape &&
268 Product(hlo->shape().dimensions()) > 0) {
269 auto it = broadcast_dims_.find(hlo->operand(0));
270 if (it != broadcast_dims_.end()) {
271 absl::flat_hash_set<int64> reshape_broadcast_dims;
272 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
273 reshape_broadcast_dims.insert(i);
274 }
275 std::vector<int64> before_dim_size_stack;
276 std::vector<int64> after_dim_size_stack;
277 for (int64_t i = hlo->operand(0)->shape().rank() - 1; i >= 0; --i) {
278 before_dim_size_stack.push_back(hlo->operand(0)->shape().dimensions(i));
279 }
280 for (int64_t i = hlo->shape().rank() - 1; i >= 0; --i) {
281 after_dim_size_stack.push_back(hlo->shape().dimensions(i));
282 }
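// Walk the operand and result dimensions as two stacks, splitting or merging
// sizes to match them up. Illustrative example: reshaping [6, 4] -> [2, 3, 4]
// splits operand dim 0 into result dims 0 and 1, so result dims 0 and 1 remain
// broadcast dims only if operand dim 0 is one.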
283 while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) {
284 int64_t before_size = before_dim_size_stack.back();
285 int64_t after_size = after_dim_size_stack.back();
286 int64_t current_before_dim =
287 hlo->operand(0)->shape().rank() - before_dim_size_stack.size();
288 int64_t current_after_dim =
289 hlo->shape().rank() - after_dim_size_stack.size();
290 before_dim_size_stack.pop_back();
291 after_dim_size_stack.pop_back();
292 if (!it->second.contains(current_before_dim)) {
293 reshape_broadcast_dims.erase(current_after_dim);
294 }
295 if (before_size == after_size) {
296 continue;
297 }
298 if (before_size % after_size == 0) {
299 // Split dim.
300 before_dim_size_stack.push_back(before_size / after_size);
301 } else if (after_size % before_size == 0) {
302 // Merge dim.
303 after_dim_size_stack.push_back(after_size / before_size);
304 } else {
305 // Other cases, mark all remaining dims as non-broadcast.
306 for (int64_t i = current_after_dim; i < hlo->shape().rank(); ++i) {
307 reshape_broadcast_dims.erase(i);
308 }
309 break;
310 }
311 }
312 if (!before_dim_size_stack.empty() || !after_dim_size_stack.empty()) {
313 reshape_broadcast_dims.clear();
314 }
315 if (!reshape_broadcast_dims.empty()) {
316 broadcast_dims_[hlo] = std::move(reshape_broadcast_dims);
317 }
318 }
319 }
320 if (hlo->opcode() == HloOpcode::kSlice ||
321 hlo->opcode() == HloOpcode::kDynamicSlice) {
322 auto it = broadcast_dims_.find(hlo->operand(0));
323 if (it != broadcast_dims_.end()) {
324 auto dims = it->second;
325 broadcast_dims_[hlo] = std::move(dims);
326 }
327 }
328 if (hlo->opcode() == HloOpcode::kPad) {
329 auto it = broadcast_dims_.find(hlo->operand(0));
330 if (it != broadcast_dims_.end()) {
331 absl::flat_hash_set<int64> pad_broadcast_dims;
332 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
333 const auto& dim = hlo->padding_config().dimensions(i);
334 if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 &&
335 dim.interior_padding() == 0 && it->second.contains(i)) {
336 pad_broadcast_dims.insert(i);
337 }
338 }
339 if (!pad_broadcast_dims.empty()) {
340 broadcast_dims_[hlo] = std::move(pad_broadcast_dims);
341 }
342 }
343 }
344 return hlo;
345 }
346
347 PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) {
348 if (sharding() == target) {
349 return *this;
350 }
351 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
352 const bool is_to_replicate =
353 hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles();
354 const bool use_cache =
355 !is_to_replicate || state_.partitioner->options().cache_all_gather;
356 if (use_cache) {
357 for (auto& entry : cache) {
358 if (entry.first == target) {
359 return entry.second;
360 }
361 }
362 }
363 auto resharded = ReshardNoCache(target);
364 state_.reshard_cache->per_hlo_cache[resharded.hlo()]
365 .reshard_cache.emplace_back(sharding(), *this);
366 if (use_cache) {
367 // Get the cache again as it might be invalidated by the insertion above.
368 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
369 cache.emplace_back(target, std::move(resharded));
370 return cache.back().second;
371 }
372 return resharded;
373 }
374
375 PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
376 VLOG(2) << "Resharding " << hlo_->ToString() << " from "
377 << hlo_->sharding().ToString() << " to " << target.ToString();
378 const Shape& shape = hlo_->shape();
379 if (shape.element_type() == TOKEN) {
380 return *this;
381 }
382 CHECK(shape.IsTuple() || !target.IsTuple());
383
384 // Tuple shape instructions may have non-tuple sharding, which means that the
385 // same sharding applies to all the leaves.
386 if (shape.IsTuple() && !target.IsTuple()) {
387 return Reshard(target.GetTupleSharding(shape).ValueOrDie());
388 }
389
390 // For a tuple shape, recursively apply Reshard to all the leaves and return
391 // a tuple instruction.
392 if (shape.IsTuple()) {
393 std::vector<HloInstruction*> elements;
394 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
395 auto subshape = ShapeUtil::GetTupleElementShape(shape, i);
396 auto element = state_.b->AddInstruction(
397 HloInstruction::CreateGetTupleElement(subshape, hlo(), i));
398 element->set_sharding(sharding().GetSubSharding(shape, {i}));
399 elements.push_back(
400 PartitionedHlo(
401 element, ShapeUtil::GetTupleElementShape(base_shape_, i), state_)
402 .Reshard(target.GetSubSharding(shape, {i}))
403 .hlo());
404 }
405 auto tuple =
406 state_.b->AddInstruction(HloInstruction::CreateTuple(elements));
407 tuple->set_sharding(target);
408 return PartitionedHlo(tuple, base_shape_, state_);
409 }
410
411 if (sharding() == target) {
412 return *this;
413 }
414
415 if (CanReshardWithCollectivePermute(sharding(), target)) {
416 return ReshardWithCollectivePermute(target);
417 }
418
419 if (auto src_tgt_dims =
420 GetReshardAllToAllSourceTargetDims(sharding(), target)) {
421 return ReshardWithAllToAll(target, *src_tgt_dims);
422 }
423
424 if (!target.IsTileMaximal() && sharding().ReplicateOnLastTileDim()) {
425 auto try_reshard = ReshardFromPartialReplicateWithDynamicSlice(target);
426 if (try_reshard.has_value()) {
427 return try_reshard.value();
428 }
429 try_reshard = ReshardPartialReplicateWithAllToAll(target);
430 if (try_reshard.has_value()) {
431 return try_reshard.value();
432 }
433 }
434
435 if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) {
436 auto try_reshard = ReshardToPartialReplicateWithAllGather(target);
437 if (try_reshard.has_value()) {
438 return try_reshard.value();
439 }
440 try_reshard = ReshardPartialReplicateWithAllToAll(target);
441 if (try_reshard.has_value()) {
442 return try_reshard.value();
443 }
444 }
445
446 // If not replicated yet, first replicate and then reshard to use one of the
447 // two implementations below.
448 if (!sharding().IsReplicated()) {
449 return Replicate().Reshard(target);
450 }
451
452 // 'Replicated' to 'SingleDevice'.
453 if (target.IsTileMaximal()) {
454 auto copy = state_.b->AddInstruction(
455 HloInstruction::CreateUnary(hlo_->shape(), HloOpcode::kCopy, hlo_));
456 copy->set_sharding(target);
457 return PartitionedHlo(copy, base_shape_, state_);
458 }
459
460 // 'Replicated' to partially replicated.
461 if (target.ReplicateOnLastTileDim()) {
462 std::vector<int64> group_dims(target.tile_assignment().num_dimensions() -
463 1);
464 std::iota(group_dims.begin(), group_dims.end(), 0);
465 auto target_grouped = GroupShardingOnDims(target, group_dims);
466 auto partially_sharded = PerGroupSliceFromReplicated(
467 hlo_, state_.partition_id, target_grouped.device_groups, group_dims,
468 target_grouped.group_dim_sizes, state_.b);
469 partially_sharded->set_sharding(target);
470 return PartitionedHlo(partially_sharded, base_shape(), state_);
471 }
472
473 // 'Replicated' to 'Tiled'.
474 auto padded_hlo =
475 PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);
476 auto shard_shape = MakePartitionedShape(shape, target);
477 auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
478 shard_shape, padded_hlo,
479 MakePartitionOffsets(shape, target, state_.partition_id, state_.b),
480 shard_shape.dimensions()));
481 slice->set_sharding(target);
482 return PartitionedHlo(slice, base_shape_, state_);
483 }
484
485 PartitionedHlo PartitionedHlo::PadWithValue(
486 HloInstruction* pad_value, absl::Span<const int64> left_padded_dims,
487 absl::Span<const int64> skipped_dims) const {
488 const HloSharding& sharding = hlo_->sharding();
489 const Shape& shape = hlo_->shape();
490 CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
491 if (sharding.IsReplicated() || EvenlyPartitions(base_shape_, sharding)) {
492 return *this;
493 }
494 CHECK(!sharding.IsTileMaximal());
495 auto index_shape = ShapeUtil::ChangeElementType(shape, S32);
496 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
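// Build a mask that is true exactly for elements inside the valid (unpadded)
// region of each shard. Illustrative example: a dimension of size 10 split
// into 4 shards of size 3 places indices 9, 10, 11 on the last shard; the
// comparison below keeps index 9 and replaces positions 10 and 11 with
// pad_value.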
497 auto get_mask_for_dim = [&](int64_t dim, HloInstruction* start_index) {
498 // Comparison: iota + start_index < valid_size
499 auto iota =
500 state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
501 auto broadcast_start_index = state_.b->AddInstruction(
502 HloInstruction::CreateBroadcast(index_shape, start_index, {}));
503 auto index_in_full_shape =
504 state_.b->AddInstruction(HloInstruction::CreateBinary(
505 index_shape, HloOpcode::kAdd, iota, broadcast_start_index));
506 ComparisonDirection direction = ComparisonDirection::kLt;
507 int64_t index_limit = base_shape_.dimensions(dim);
508 if (absl::c_linear_search(left_padded_dims, dim)) {
509 direction = ComparisonDirection::kGe;
510 index_limit =
511 index_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) -
512 index_limit;
513 }
514 auto limit = state_.b->AddInstruction(HloInstruction::CreateConstant(
515 LiteralUtil::CreateR0<int32>(index_limit)));
516 auto broadcast_limit = state_.b->AddInstruction(
517 HloInstruction::CreateBroadcast(index_shape, limit, {}));
518 return state_.b->AddInstruction(HloInstruction::CreateCompare(
519 mask_shape, index_in_full_shape, broadcast_limit, direction));
520 };
521
522 HloInstruction* mask = nullptr;
523 auto offsets = MakePartitionOffsets(base_shape_, sharding,
524 state_.partition_id, state_.b);
525 for (int64_t i = 0; i < shape.rank(); ++i) {
526 if (base_shape_.dimensions(i) % sharding.tile_assignment().dim(i) == 0 ||
527 absl::c_linear_search(skipped_dims, i)) {
528 continue;
529 }
530 if (mask == nullptr) {
531 mask = get_mask_for_dim(i, offsets[i]);
532 } else {
533 mask = state_.b->AddInstruction(
534 HloInstruction::CreateBinary(mask->shape(), HloOpcode::kAnd, mask,
535 get_mask_for_dim(i, offsets[i])));
536 }
537 }
538
539 if (mask == nullptr) {
540 return *this;
541 }
542
543 auto broadcast_pad_value = state_.b->AddInstruction(
544 HloInstruction::CreateBroadcast(shape, pad_value, {}));
545 auto result = state_.b->AddInstruction(HloInstruction::CreateTernary(
546 shape, HloOpcode::kSelect, mask, hlo_, broadcast_pad_value));
547 result->set_sharding(sharding);
548 return PartitionedHlo(result, base_shape_, state_);
549 }
550
551 absl::optional<PartitionedHlo::WindowedInputShardReturnValue>
552 PartitionedHlo::ReshardAsWindowedInput(const Window& window,
553 const HloSharding& target,
554 HloInstruction* pad_value,
555 bool mask_invalid_region) {
556 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].window_reshard_cache;
557 for (auto& entry : cache) {
558 if (std::get<0>(entry) == target &&
559 protobuf_util::ProtobufEquals(std::get<1>(entry), window)) {
560 return std::get<2>(entry);
561 }
562 }
563 auto update_cache = [&](WindowedInputShardReturnValue result) {
564 cache.emplace_back(target, window, std::move(result));
565 return std::get<2>(cache.back());
566 };
567 VLOG(2) << "ReshardAsWindowedInput()\n"
568 << "\twindow:" << window_util::ToString(window)
569 << "\ttarget sharding:" << target.ToString();
570
571 CHECK(!target.IsTileMaximal());
572 auto partition_ordinals =
573 MakeTiledPartitionOrdinals(target, state_.partition_id, state_.b);
574 auto shard_shape = base_shape_;
575
576 std::vector<MultiplyAddDivideOffsetCalculation> start_on_padded_calculations(
577 base_shape_.rank());
578 std::vector<MultiplyAddDivideOffsetCalculation> limit_on_padded_calculations(
579 base_shape_.rank());
580 std::vector<HloInstruction*> dynamic_slice_offset_on_output(
581 base_shape_.rank(), nullptr);
582
583 Window shard_window = window;
584 Shape padded_shape = base_shape_;
585 std::vector<HloInstruction*> offsets_on_padded_shape(base_shape_.rank());
586 std::vector<int64> per_shard_window_counts(base_shape_.rank());
587 std::vector<int64> explicit_left_padding(base_shape_.rank());
588 for (int64_t i = 0; i < base_shape_.rank(); ++i) {
589 // Do not pad non-partitioned dimensions.
590 int64_t shard_count = target.tile_assignment().dim(i);
591 if (shard_count == 1) {
592 offsets_on_padded_shape[i] = state_.b->AddInstruction(
593 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
594 continue;
595 }
596 const WindowDimension& wd = window.dimensions(i);
597 const int64_t dilated_size = 1 + (wd.size() - 1) * wd.window_dilation();
598 const int64_t full_size =
599 1 + (base_shape_.dimensions(i) - 1) * wd.base_dilation() +
600 wd.padding_high() + wd.padding_low();
601 if (full_size < dilated_size) {
602 VLOG(2) << "Failed to reshard window operand because the window size is "
603 "larger than padded base size";
604 return absl::nullopt;
605 }
606 int64_t window_count = (full_size - dilated_size) / wd.stride() + 1;
607 per_shard_window_counts[i] = CeilOfRatio(window_count, shard_count);
608 if (wd.stride() != 1 &&
609 (wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() != 0) {
610 // TODO(yuanzx): Support this case.
611 VLOG(2) << "Failed to reshard window operand due to non-trivial dilation";
612 return absl::nullopt;
613 }
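// Illustrative example (no padding or dilation): base size 15, window size 3,
// stride 1, 2 shards. window_count = 13 and per_shard_window_counts[i] = 7, so
// each shard computes 7 windows and needs 7 + 3 - 1 = 9 input elements; the
// per-shard start offsets below become 0 and 7, and the padded size becomes 16.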
614
615 // We use explicit padding for full dilations, then use padding_low and
616 // padding_high on the sharded op for the remainder. padding_low and
617 // padding_high are given initial values here, which will be updated later if
618 // the base dilation is not 1.
619 WindowDimension* swd = shard_window.mutable_dimensions(i);
620 explicit_left_padding[i] = wd.padding_low() / wd.base_dilation();
621 swd->set_padding_low(wd.padding_low() % wd.base_dilation());
622 swd->set_padding_high(0);
623
624 // Calculation for the first element needed on the 'padded-but-not-dilated'
625 // shape. The start on the dilated shape could be a hole, so we add
626 // wd.base_dilation() - 1 to the constant term to skip the leading holes.
627 start_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
628 wd.stride() * per_shard_window_counts[i],
629 wd.base_dilation() - 1 - swd->padding_low(), wd.base_dilation());
630 int64_t dilated_shard_size =
631 wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
632 limit_on_padded_calculations[i] = MultiplyAddDivideOffsetCalculation(
633 wd.stride() * per_shard_window_counts[i],
634 dilated_shard_size + wd.base_dilation() - 1 - swd->padding_low(),
635 wd.base_dilation());
636
637 offsets_on_padded_shape[i] = start_on_padded_calculations[i].Calculate(
638 partition_ordinals[i], state_.b);
639
640 auto shard_size_function =
641 limit_on_padded_calculations[i] - start_on_padded_calculations[i];
642 int64_t max_shard_size = shard_size_function.MaxInRange(0, shard_count);
643 shard_shape.set_dimensions(i, max_shard_size);
644 padded_shape.set_dimensions(
645 i, limit_on_padded_calculations[i].Calculate(shard_count - 1));
646
647 // For base dilation, calculate the needed padding_low and padding_high, as
648 // well as the offset for the output if a dynamic slice is needed after the
649 // sharded op.
650 if (wd.base_dilation() != 1) {
651 // Returns the offset of a shard's first valid element in the dilated
652 // shard.
653 auto get_first_valid_element_offset_on_dilated_shard =
654 [&](int64_t shard_ordinal) {
655 return start_on_padded_calculations[i].Calculate(shard_ordinal) *
656 wd.base_dilation() +
657 swd->padding_low() -
658 wd.stride() * per_shard_window_counts[i] * shard_ordinal;
659 };
660 CHECK_EQ(get_first_valid_element_offset_on_dilated_shard(0),
661 swd->padding_low());
662
663 // Determine swd->padding_high.
664 for (int64_t shard_ordinal = 0; shard_ordinal < shard_count;
665 ++shard_ordinal) {
666 int64_t wanted_limit_on_dilated_shard =
667 wd.stride() * (per_shard_window_counts[i] - 1) + dilated_size;
668 int64_t actual_limit_on_dilated_shard_without_pad_high =
669 get_first_valid_element_offset_on_dilated_shard(shard_ordinal) +
670 (max_shard_size - 1) * wd.base_dilation() + 1;
671 swd->set_padding_high(std::max<int64>(
672 swd->padding_high(),
673 wanted_limit_on_dilated_shard -
674 actual_limit_on_dilated_shard_without_pad_high));
675 }
676
677 // Determine swd->padding_low and output dynamic slice index.
678 if (wd.stride() == 1) {
679 int64_t max_pad_low =
680 get_first_valid_element_offset_on_dilated_shard(0);
681 bool all_same = true;
682 for (int64_t shard_ordinal = 1; shard_ordinal < shard_count;
683 ++shard_ordinal) {
684 int64_t start =
685 get_first_valid_element_offset_on_dilated_shard(shard_ordinal);
686 if (start != swd->padding_low()) {
687 all_same = false;
688 }
689 max_pad_low = std::max(max_pad_low, start);
690 }
691 if (!all_same) {
692 auto start_on_padded_input =
693 start_on_padded_calculations[i].Calculate(partition_ordinals[i],
694 state_.b);
695 // We will calculate
696 // max_pad_low - (first_window - required_first_window)
697 // which equals
698 // required_first_window - (first_window - max_pad_low)
699 auto first_window_minus_max_pad_low =
700 MultiplyAddDivideOffsetCalculation(
701 wd.base_dilation(), swd->padding_low() - max_pad_low, 1)
702 .Calculate(start_on_padded_input, state_.b);
703 auto required_first_window =
704 MultiplyAddDivideOffsetCalculation(per_shard_window_counts[i], 0,
705 1)
706 .Calculate(partition_ordinals[i], state_.b);
707 dynamic_slice_offset_on_output[i] =
708 state_.b->AddInstruction(HloInstruction::CreateBinary(
709 required_first_window->shape(), HloOpcode::kSubtract,
710 required_first_window, first_window_minus_max_pad_low));
711 }
712 swd->set_padding_low(max_pad_low);
713 } else {
714 if ((wd.stride() * per_shard_window_counts[i]) % wd.base_dilation() !=
715 0) {
716 // General base dilation not yet implemented.
717 return absl::nullopt;
718 }
719 // padding_low on all shards should equal the initially assigned
720 // swd->padding_low(), i.e., the padding_low() on the original window.
721 }
722 }
723 }
724
725 // Returns the output dynamic slice offset when needed, and absl::nullopt
726 // otherwise.
727 auto get_dynamic_slice_offset_on_output_if_needed =
728 [&]() -> absl::optional<std::vector<HloInstruction*>> {
729 if (absl::c_all_of(
730 dynamic_slice_offset_on_output,
731 [](HloInstruction* offset) { return offset == nullptr; })) {
732 return absl::nullopt;
733 }
734 auto zero = state_.b->AddInstruction(
735 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
736 for (int64_t i = 0; i < dynamic_slice_offset_on_output.size(); ++i) {
737 if (dynamic_slice_offset_on_output[i] == nullptr) {
738 dynamic_slice_offset_on_output[i] = zero;
739 }
740 }
741 return dynamic_slice_offset_on_output;
742 };
743
744 // If the current HLO is replicated, pad then slice.
745 if (sharding().IsReplicated()) {
746 PaddingConfig padding_config;
747 for (int64_t i = 0; i < base_shape_.rank(); ++i) {
748 auto padding_config_dim = padding_config.add_dimensions();
749 padding_config_dim->set_interior_padding(0);
750 // Do not pad non-partitioned dimensions.
751 if (target.tile_assignment().dim(i) == 1) {
752 padding_config_dim->set_edge_padding_low(0);
753 padding_config_dim->set_edge_padding_high(0);
754 continue;
755 }
756 padding_config_dim->set_edge_padding_low(explicit_left_padding[i]);
757 padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
758 explicit_left_padding[i] -
759 base_shape_.dimensions(i));
760 }
761 auto padded_hlo = ShapeUtil::Compatible(padded_shape, base_shape_)
762 ? hlo_
763 : state_.b->AddInstruction(HloInstruction::CreatePad(
764 padded_shape, hlo_, pad_value, padding_config));
765 auto sharded_input =
766 state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
767 shard_shape, padded_hlo, offsets_on_padded_shape,
768 shard_shape.dimensions()));
769 return update_cache(WindowedInputShardReturnValue{
770 sharded_input, shard_window,
771 get_dynamic_slice_offset_on_output_if_needed()});
772 }
773
774 if (target != sharding()) {
775 return Reshard(target).ReshardAsWindowedInput(window, target, pad_value);
776 }
777
778 // Halo exchange.
779 HloInstruction* visiting_hlo = hlo_;
780 auto original_shard_shape = MakePartitionedShape(base_shape_, target);
781
782 std::vector<OffsetCalculation> left_halo_size_functions(base_shape_.rank());
783 std::vector<OffsetCalculation> right_halo_size_functions(base_shape_.rank());
784 // TODO(yuanzx): We are concatenating on each sharded dimension one at a time,
785 // and in the second dimension (and beyond) we create halos by slicing the
786 // concat in the previous dimension, which is not optimal. We should generate
787 // halos by concatenating slices, instead of slicing concats.
788 for (int dim = 0; dim < base_shape_.rank(); ++dim) {
789 int64_t shard_count = target.tile_assignment().dim(dim);
790 if (shard_count == 1) {
791 continue;
792 }
793 int64_t input_shard_size =
794 CeilOfRatio(base_shape_.dimensions(dim), shard_count);
795
796 // Left halo. The size of the halo is derived by subtracting the first read
797 // element offset of the i'th partition from the limit of the (i-1)'th
798 // partition.
799 MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded(
800 input_shard_size, explicit_left_padding[dim], 1);
801 left_halo_size_functions[dim] =
802 shard_limit_of_previous_on_padded - start_on_padded_calculations[dim];
803
804 // Right halo.
805 MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded(
806 input_shard_size, input_shard_size + explicit_left_padding[dim], 1);
807 right_halo_size_functions[dim] =
808 limit_on_padded_calculations[dim] - shard_start_of_next_on_padded;
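// Illustrative example (continuing the one above): base size 15, 2 shards,
// window size 3, stride 1 gives input_shard_size = 8, start_on_padded(i) = 7*i
// and limit_on_padded(i) = 7*i + 9, so left_halo(i) = 8*i - 7*i = i and
// right_halo(i) = (7*i + 9) - (8*i + 8) = 1 - i: shard 1 needs one element from
// its left neighbor and shard 0 needs one from its right neighbor.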
809
810 auto resharded = ExchangeHaloAndGetValidData(
811 visiting_hlo, base_shape_, left_halo_size_functions[dim],
812 right_halo_size_functions[dim], explicit_left_padding[dim],
813 padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, target,
814 offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim],
815 state_.collective_ops_creator, state_.next_channel_id, state_.b,
816 mask_invalid_region);
817 if (!resharded) {
818 VLOG(1) << "ReshardAsWindowedInput failed without replicate first: halo "
819 "is beyond the neighbor.";
820 return Replicate().ReshardAsWindowedInput(window, target, pad_value);
821 }
822 visiting_hlo = *resharded;
823 }
824 return update_cache(WindowedInputShardReturnValue{
825 visiting_hlo, shard_window,
826 get_dynamic_slice_offset_on_output_if_needed()});
827 }
828
829 PartitionedHlo PartitionedHlo::Replicate() {
830 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
831 if (state_.partitioner->options().cache_all_gather) {
832 for (auto& entry : cache) {
833 if (entry.first.IsReplicated()) {
834 return entry.second;
835 }
836 }
837 }
838 // Do not use a reference as the HLO's sharding can be temporarily replaced.
839 const HloSharding sharding = hlo_->sharding();
840 const Shape& shape = hlo_->shape();
841 CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
842
843 if (sharding.IsReplicated()) {
844 return *this;
845 }
846 for (auto& entry : cache) {
847 if (entry.first.IsReplicated()) {
848 return entry.second;
849 }
850 }
851 auto update_cache = [&](PartitionedHlo resharded) {
852 state_.reshard_cache->per_hlo_cache[resharded.hlo()]
853 .reshard_cache.emplace_back(sharding, *this);
854 // Get the cache again as it might be invalidated by the insertion above.
855 auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache;
856 if (state_.partitioner->options().cache_all_gather) {
857 cache.emplace_back(HloSharding::Replicate(), std::move(resharded));
858 return cache.back().second;
859 }
860 return resharded;
861 };
862 // 'Single Device' to 'Replicated'.
863 if (sharding.IsTileMaximal()) {
864 return update_cache(Broadcast());
865 }
866
867 // 'Tiled' to 'Replicated'.
868 std::vector<int64> all_dims(shape.rank());
869 std::iota(all_dims.begin(), all_dims.end(), 0);
870 HloInstruction* result = ReplicatePartial(all_dims);
871 result->set_sharding(HloSharding::Replicate());
872 return update_cache(PartitionedHlo(result, base_shape_, state_));
873 }
874
875 HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span<const int64> dims) {
876 CHECK(!sharding().IsTileMaximal());
877 const Shape& shard_shape = hlo()->shape();
878 Shape target_shape = shard_shape;
879 Shape padded_target_shape = shard_shape;
880 std::vector<int64> broadcast_dims;
881 std::vector<int64> ag_dims;
882 // Find dimensions that can be replicated with Broadcast() (shard size 1) and
883 // others that need all-gather.
884 for (int64_t i : dims) {
885 if (sharding().tile_assignment().dim(i) == 1) {
886 continue;
887 }
888 target_shape.set_dimensions(i, base_shape().dimensions(i));
889 if (target_shape.dimensions(i) == shard_shape.dimensions(i)) {
890 broadcast_dims.push_back(i);
891 } else {
892 padded_target_shape.set_dimensions(
893 i, shard_shape.dimensions(i) * sharding().tile_assignment().dim(i));
894 ag_dims.push_back(i);
895 }
896 }
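// Illustrative example: base shape f32[1,8] sharded as [2,2] tiles (shard
// shape f32[1,4]). Dim 0 already matches the full size, so it only needs a
// broadcast within its replication group, while dim 1 (4 -> 8) requires an
// all-gather (or the dynamic-update-slice + all-reduce fallback below).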
897
898 HloInstruction* broadcast = hlo_;
899 if (!broadcast_dims.empty()) {
900 std::vector<int64> other_dims;
901 for (int64_t i = 0; i < sharding().tile_assignment().num_dimensions();
902 ++i) {
903 if (!absl::c_linear_search(broadcast_dims, i)) {
904 other_dims.push_back(i);
905 }
906 }
907 HloSharding original_sharding = sharding();
908 auto grouped = GroupShardingOnDims(original_sharding, other_dims);
909 std::vector<int64> dev_indices(
910 grouped.sharding.tile_assignment().num_dimensions(), 0);
911 hlo_->set_sharding(HloSharding::AssignDevice(
912 grouped.sharding.tile_assignment()(dev_indices)));
913 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
914 state(), grouped.device_groups, state().b);
915 auto partial_replicate_hlo =
916 PartitionedHlo(hlo_, shard_shape, per_group_partitioner_state)
917 .Broadcast();
918 hlo_->set_sharding(original_sharding);
919 partial_replicate_hlo.hlo()->clear_sharding();
920 broadcast = partial_replicate_hlo.hlo();
921 }
922
923 if (ag_dims.empty()) {
924 return broadcast;
925 }
926
927 HloInstruction* result = nullptr;
928 if (state_.collective_ops_creator.create_cross_partition_all_gather) {
929 result = state_.partitioner->AllGatherShards(
930 state_.b, broadcast, sharding(), state_.next_channel_id, ag_dims,
931 state_.collective_ops_creator);
932 }
933 if (result == nullptr) {
934 auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant(
935 LiteralUtil::Zero(shard_shape.element_type())));
936 auto zero_bcast = state_.b->AddInstruction(
937 HloInstruction::CreateBroadcast(padded_target_shape, zero, {}));
938 auto offsets = MakePartitionOffsets(padded_target_shape, sharding(),
939 state_.partition_id, state_.b, ag_dims);
940 auto dus =
941 state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
942 padded_target_shape, zero_bcast, broadcast, offsets));
943 HloComputation* reduction =
944 MakeBinaryAdd(shard_shape.element_type(), state_.module);
945 result = state_.partitioner->AllReduceAlongShardingDims(
946 state_.b, dus, sharding(), state_.next_channel_id, ag_dims,
947 state_.collective_ops_creator, reduction);
948 }
949 if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) {
950 std::vector<int64> start_indices(target_shape.rank(), 0);
951 std::vector<int64> strides(target_shape.rank(), 1);
952 result = state_.b->AddInstruction(
953 HloInstruction::CreateSlice(target_shape, result, start_indices,
954 target_shape.dimensions(), strides));
955 }
956 return result;
957 }
958
959 absl::optional<PartitionedHlo>
960 PartitionedHlo::ReshardToPartialReplicateWithAllGather(
961 const HloSharding& target) {
962 if (!target.ReplicateOnLastTileDim()) {
963 return absl::nullopt;
964 }
965 // Tiled/partial replicate to partial replicate resharding.
966 // Get a sharding compatible with the target that we can reach by all-reduce resharding.
967 auto compatible_sharding =
968 PartialReplicateReshardCompatibleSharding(target, sharding());
969 if (!compatible_sharding.has_value()) {
970 return absl::nullopt;
971 }
972
973 const auto& temp_sharding = compatible_sharding.value();
974 auto partitioned_hlo = *this;
975 // Use collective permute to adjust device assignment if needed.
976 if (CanReshardWithCollectivePermute(sharding(), temp_sharding)) {
977 partitioned_hlo =
978 partitioned_hlo.ReshardWithCollectivePermute(temp_sharding);
979 }
980
981 // Get the dimensions to replicate and the replication factor for each of them.
982 int64_t rank = hlo_->shape().rank();
983 std::vector<int64> replicate_dims;
984 std::vector<int64> replicate_factors;
985 for (int64_t dim = 0; dim < rank; dim++) {
986 int64_t replicate_factor = temp_sharding.tile_assignment().dim(dim) /
987 target.tile_assignment().dim(dim);
988 if (replicate_factor > 1) {
989 replicate_dims.emplace_back(dim);
990 replicate_factors.emplace_back(replicate_factor);
991 }
992 }
993
994 // Do a left halo exchange if a direct all-reduce would discard useful data
995 // from the source.
996 auto halo_exchange = TileToPartialReplicateHaloExchange(
997 partitioned_hlo.hlo_, base_shape_, temp_sharding, target, replicate_dims,
998 partitioned_hlo.state().collective_ops_creator,
999 partitioned_hlo.state().next_channel_id,
1000 partitioned_hlo.state().partition_id, partitioned_hlo.state().b);
1001 if (!halo_exchange.has_value()) {
1002 return absl::nullopt;
1003 }
1004 auto halo_exchange_hlo = halo_exchange.value();
1005 // Grouped on replicate dimensions.
1006 auto sharding_grouped =
1007 GroupShardingOnDims(temp_sharding, replicate_dims, replicate_factors);
1008 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1009 partitioned_hlo.state(), sharding_grouped.device_groups,
1010 partitioned_hlo.state().b);
1011 auto base_shape = MakePartitionedShape(base_shape_, target);
1012 // It's possible that halo_exchange_hlo == hlo.hlo().
1013 // Record the sharding of hlo here, and reset it before return.
1014 auto original_sharding = partitioned_hlo.sharding();
1015 halo_exchange_hlo->set_sharding(sharding_grouped.sharding);
1016 auto partial_replicate_hlo = PartitionedHlo(halo_exchange_hlo, base_shape,
1017 per_group_partitioner_state);
1018 HloInstruction* result =
1019 partial_replicate_hlo.ReplicatePartial(replicate_dims);
1020 partitioned_hlo.hlo()->set_sharding(original_sharding);
1021 result->set_sharding(target);
1022 return PartitionedHlo(result, base_shape_, partitioned_hlo.state());
1023 }
1024
1025 absl::optional<PartitionedHlo>
1026 PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice(
1027 const HloSharding& target) {
1028 if (!sharding().ReplicateOnLastTileDim()) {
1029 return absl::nullopt;
1030 }
1031
1032 // Get the temp sharding target from partial replicate to target tile dims.
1033 // target_compatible_sharding has the same tile_assignment dimensions
1034 // as the target and can reshard to target by collective permute.
1035 // target_compatible_sharding could have a different device assignment than
1036 // the target. sharding() can reshard to target_compatible_sharding by
1037 // dynamic slice.
1038 auto target_compatible_sharding =
1039 PartialReplicateReshardCompatibleSharding(sharding(), target);
1040 // Reshard to target_compatible_sharding by dynamic slice.
1041 if (!target_compatible_sharding.has_value()) {
1042 return absl::nullopt;
1043 }
1044 std::vector<int64> expand_tile_dims;
1045 std::vector<int64> tiling_dim_factors;
1046 int64_t rank = hlo_->shape().rank();
1047 tiling_dim_factors.reserve(target.tile_assignment().num_dimensions());
1048 const auto& temp_target_sharding = target_compatible_sharding.value();
1049 for (int64_t dim = 0; dim < rank; dim++) {
1050 if (temp_target_sharding.tile_assignment().dim(dim) >
1051 sharding().tile_assignment().dim(dim)) {
1052 expand_tile_dims.push_back(dim);
1053 }
1054 tiling_dim_factors.emplace_back(
1055 temp_target_sharding.tile_assignment().dim(dim) /
1056 sharding().tile_assignment().dim(dim));
1057 }
1058
1059 // Add another dimension to tiling_dim_factors if the target is partially replicated.
1060 if (target.ReplicateOnLastTileDim()) {
1061 tiling_dim_factors.emplace_back(
1062 target.tile_assignment().dimensions().back());
1063 }
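// Illustrative example: resharding from
// sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} to
// sharding={devices=[2,2]0,1,2,3}: dim 1 expands from 1 to 2 partitions, so
// expand_tile_dims = {1}, tiling_dim_factors = {1, 2}, and each device
// dynamic-slices its half of dim 1 out of its replicated copy.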
1064
1065 // 2. Get the padded_hlo, do right halo exchange if needed.
1066 auto padded_hlo = PadFromPartialReplicateShape(
1067 hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims,
1068 state_.collective_ops_creator, state_.next_channel_id,
1069 state_.partition_id, state_.b);
1070 if (!padded_hlo.has_value()) {
1071 return absl::nullopt;
1072 }
1073 // 3. Slice out the tile from the replicated data.
1074 auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding);
1075 // Since we are just slicing, we can just use the differences between the new
1076 // and old offsets in the full shape as the dynamic-slice offsets.
1077 auto padded_base_shape = shard_shape;
1078 for (int64_t i = 0; i < padded_base_shape.rank(); ++i) {
1079 padded_base_shape.set_dimensions(
1080 i, padded_base_shape.dimensions(i) *
1081 temp_target_sharding.tile_assignment().dim(i));
1082 }
1083 auto offsets = MakePartitionOffsets(padded_base_shape, temp_target_sharding,
1084 state_.partition_id, state_.b);
1085 auto old_offsets = MakePartitionOffsets(padded_base_shape, sharding(),
1086 state_.partition_id, state_.b);
1087 for (int64_t i = 0; i < offsets.size(); ++i) {
1088 offsets[i] = state_.b->AddInstruction(HloInstruction::CreateBinary(
1089 offsets[i]->shape(), HloOpcode::kSubtract, offsets[i], old_offsets[i]));
1090 }
1091 auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice(
1092 shard_shape, padded_hlo.value(), offsets, shard_shape.dimensions()));
1093 slice->set_sharding(temp_target_sharding);
1094 auto result = PartitionedHlo(slice, base_shape_, state_);
1095 // If temp_target_sharding's device assignment is different from target,
1096 // use collective permute to reshard.
1097 if (CanReshardWithCollectivePermute(temp_target_sharding, target)) {
1098 return result.ReshardWithCollectivePermute(target);
1099 }
1100 // If device assignment in temp_target_sharding and target are the same,
1101 // return result directly.
1102 return result;
1103 }
1104
1105 PartitionedHlo PartitionedHlo::Broadcast() const {
1106 const Shape& shape = hlo_->shape();
1107 const HloSharding& sharding = hlo_->sharding();
1108 CHECK(sharding.HasUniqueDevice());
1109 CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
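// Broadcast from the unique owning device to all partitions: every other
// partition contributes zeros, and a cross-partition all-reduce (sum) then
// reproduces the owner's data on every partition.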
1110
1111 auto src_core_id = state_.b->AddInstruction(HloInstruction::CreateConstant(
1112 LiteralUtil::CreateR0<uint32>(sharding.GetUniqueDevice())));
1113 Shape bcast_shape = ShapeUtil::ChangeElementType(shape, PRED);
1114 auto is_src_core = state_.b->AddInstruction(HloInstruction::CreateBroadcast(
1115 bcast_shape,
1116 state_.b->AddInstruction(HloInstruction::CreateCompare(
1117 ShapeUtil::MakeShape(PRED, {}), state_.partition_id, src_core_id,
1118 ComparisonDirection::kEq)),
1119 {}));
1120
1121 auto zero = state_.b->AddInstruction(
1122 HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
1123 auto zero_bcast = state_.b->AddInstruction(
1124 HloInstruction::CreateBroadcast(shape, zero, {}));
1125 auto operand = state_.b->AddInstruction(HloInstruction::CreateTernary(
1126 shape, HloOpcode::kSelect, is_src_core, hlo(), zero_bcast));
1127 HloComputation* reduction =
1128 MakeBinaryAdd(shape.element_type(), state_.module);
1129
1130 auto result = state_.collective_ops_creator.create_cross_partition_all_reduce(
1131 state_.b, operand, reduction, {}, NewChannel());
1132 result->set_sharding(HloSharding::Replicate());
1133 return PartitionedHlo(result, base_shape_, state_);
1134 }
1135
1136 PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
1137 const HloSharding& target,
1138 absl::Span<const std::pair<int64, int64>> source_target_dims) const {
1139 if (source_target_dims.empty()) {
1140 if (target == sharding()) {
1141 return *this;
1142 }
1143 // If the device order is different in the target, fix the order with
1144 // ReshardWithCollectivePermute.
1145 return ReshardWithCollectivePermute(target);
1146 }
1147
1148 // Swap one pair of dimensions.
1149 int64_t source_dim = source_target_dims[0].first;
1150 int64_t target_dim = source_target_dims[0].second;
1151 const int64_t group_size = sharding().tile_assignment().dim(source_dim) /
1152 sharding().tile_assignment().dim(target_dim);
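// Illustrative example: resharding from sharding={devices=[4,1]0,1,2,3} to
// sharding={devices=[1,4]0,1,2,3} has source_dim = 0, target_dim = 1 and
// group_size = 4; all four devices form one all-to-all group in which each
// device splits its row block into four column pieces and receives the four
// pieces that make up its column block.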
1153
1154 auto temp_target_tile = sharding().tile_assignment();
1155 {
1156 std::vector<int64> reshape_tile_dims(temp_target_tile.num_dimensions() + 2);
1157 int64_t i = 0;
1158 int64_t added_source_dim = -1;
1159 int64_t added_target_dim = -1;
1160 for (int64_t j = 0; j < temp_target_tile.num_dimensions(); ++j) {
1161 if (source_dim == j) {
1162 reshape_tile_dims[i] = temp_target_tile.dim(j) / group_size;
1163 reshape_tile_dims[++i] = group_size;
1164 added_source_dim = i;
1165 } else if (target_dim == j) {
1166 reshape_tile_dims[i] = temp_target_tile.dim(j);
1167 reshape_tile_dims[++i] = 1;
1168 added_target_dim = i;
1169 } else {
1170 reshape_tile_dims[i] = temp_target_tile.dim(j);
1171 }
1172 ++i;
1173 }
1174 temp_target_tile.Reshape(reshape_tile_dims);
1175 std::vector<int64> xpose_dims(temp_target_tile.num_dimensions());
1176 std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
1177 xpose_dims[added_source_dim] = added_target_dim;
1178 xpose_dims[added_target_dim] = added_source_dim;
1179 temp_target_tile = hlo_sharding_util::TransposeSharding(
1180 HloSharding::Tile(temp_target_tile), xpose_dims)
1181 .tile_assignment();
1182 auto temp_target_tile_dims = sharding().tile_assignment().dimensions();
1183 temp_target_tile_dims[source_dim] =
1184 sharding().tile_assignment().dim(target_dim);
1185 temp_target_tile_dims[target_dim] =
1186 sharding().tile_assignment().dim(source_dim);
1187 temp_target_tile.Reshape(temp_target_tile_dims);
1188 }
1189 auto temp_target = target.ReplicateOnLastTileDim()
1190 ? HloSharding::PartialTile(temp_target_tile)
1191 : HloSharding::Tile(temp_target_tile);
1192 auto padded_shape = hlo_->shape();
1193 padded_shape.set_dimensions(
1194 target_dim,
1195 RoundUpToNearest(padded_shape.dimensions(target_dim),
1196 temp_target.tile_assignment().dim(target_dim)));
1197 auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b);
1198
1199 // The order of ids in the group must follow the temp_target sharding.
1200 std::vector<std::vector<int64>> groups(
1201 temp_target.tile_assignment().num_elements() / group_size);
1202 temp_target.tile_assignment().Each(
1203 [&](absl::Span<const int64> indices, int64_t device) {
1204 int64_t group_id = 0;
1205 for (int64_t dim = 0; dim < indices.size(); ++dim) {
1206 if (dim == target_dim) {
1207 group_id *= temp_target.tile_assignment().dim(dim) / group_size;
1208 group_id += indices[dim] / group_size;
1209 } else {
1210 group_id *= temp_target.tile_assignment().dim(dim);
1211 group_id += indices[dim];
1212 }
1213 }
1214 groups[group_id].push_back(device);
1215 });
1216
1217 HloInstruction* result = nullptr;
1218
1219 // Split along the split dimension (target_dim) of the all-to-all
1220 // output.
1221 std::vector<int64> dimensions;
1222 for (int64_t i = 0; i < base_shape_.rank(); ++i) {
1223 if (i == target_dim) {
1224 dimensions.push_back(group_size);
1225 dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size);
1226 } else {
1227 dimensions.push_back(padded_hlo->shape().dimensions(i));
1228 }
1229 }
1230 auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape(
1231 ShapeUtil::MakeShape(base_shape_.element_type(), dimensions),
1232 padded_hlo));
1233 // After the reshape, it is guaranteed to have at least 3 dimensions.
1234 auto all_to_all =
1235 state_.collective_ops_creator.create_cross_partition_all_to_all(
1236 state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim);
1237
1238 // Reorder the split dimension of the reshape to be located in front of the
1239 // input partition dimension, so the two dimensions can be combined.
1240 int64_t new_source_dim =
1241 (target_dim < source_dim) ? source_dim + 1 : source_dim;
1242 std::vector<int64> permutation;
1243 for (int64_t i = 0; i < all_to_all->shape().rank(); ++i) {
1244 if (i == target_dim) {
1245 continue;
1246 }
1247 if (i == new_source_dim) {
1248 permutation.push_back(target_dim);
1249 }
1250 permutation.push_back(i);
1251 }
1252 auto transpose = state_.b->AddInstruction(HloInstruction::CreateTranspose(
1253 ShapeInference::InferTransposeShape(all_to_all->shape(), permutation)
1254 .ValueOrDie(),
1255 all_to_all, permutation));
1256
1257 // Combine the split dimension and the input partition dimension.
1258 auto new_shape = ShapeInference::InferAllToAllShape(
1259 padded_hlo->shape(), target_dim, source_dim, group_size)
1260 .ValueOrDie();
1261 result = state_.b->AddInstruction(
1262 HloInstruction::CreateReshape(new_shape, transpose));
1263
1264 const Shape result_shape = MakePartitionedShape(base_shape_, temp_target);
1265 if (result_shape != result->shape()) {
1266 result = state_.b->AddInstruction(HloInstruction::CreateSlice(
1267 result_shape, result, std::vector<int64>(result_shape.rank(), 0),
1268 result_shape.dimensions(), std::vector<int64>(result_shape.rank(), 1)));
1269 }
1270 result->set_sharding(temp_target);
1271 auto remaining_source_target_dims = source_target_dims;
1272 remaining_source_target_dims.remove_prefix(1);
1273 return PartitionedHlo(result, base_shape_, state_)
1274 .ReshardWithAllToAll(target, remaining_source_target_dims);
1275 }
1276
1277 absl::optional<PartitionedHlo>
1278 PartitionedHlo::ReshardPartialReplicateWithAllToAll(const HloSharding& target) {
1279 bool source_is_partial_replicate = sharding().ReplicateOnLastTileDim();
1280 const auto& partial_replicate_sharding =
1281 source_is_partial_replicate ? sharding() : target;
1282 // If neither the source nor the target is partial replicate, return null.
1283 if (!partial_replicate_sharding.ReplicateOnLastTileDim()) {
1284 return absl::nullopt;
1285 }
1286 const auto& tile_sharding = source_is_partial_replicate ? target : sharding();
1287 // If both source and target are partial replicate, this should already be
1288 // supported by Reshard with AllToAll.
1289 if (tile_sharding.ReplicateOnLastTileDim() || tile_sharding.IsTileMaximal()) {
1290 return absl::nullopt;
1291 }
1292
1293 // Only support resharding from sharding={devices=[2,3]0,1,2,3,4,5}
1294 // to sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}, where
1295 // the last tile dim will be replicate first before all-to-all.
1296 // Or resharding from
1297 // sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
1298 // to sharding={devices=[2,3]0,1,2,3,4,5}, where
1299 // the last tile dim will be sharded after all-to-all.
1300 const int num_replicas =
1301 partial_replicate_sharding.tile_assignment().dimensions().back();
1302 if (((tile_sharding.tile_assignment().num_dimensions() + 1) !=
1303 partial_replicate_sharding.tile_assignment().num_dimensions()) ||
1304 (partial_replicate_sharding.tile_assignment().dim(0) != 1)) {
1305 return absl::nullopt;
1306 }
1307 int to_replicate_dim = -1;
1308 for (int i = tile_sharding.tile_assignment().num_dimensions() - 1; i >= 0;
1309 --i) {
1310 if (tile_sharding.tile_assignment().dim(i) > 1 &&
1311 (to_replicate_dim == -1)) {
1312 if (tile_sharding.tile_assignment().dim(i) != num_replicas) {
1313 return absl::nullopt;
1314 }
1315 to_replicate_dim = i;
1316 }
1317
1318 if (tile_sharding.tile_assignment().dim(i) !=
1319 partial_replicate_sharding.tile_assignment().dim(i + 1)) {
1320 return absl::nullopt;
1321 }
1322 }
1323
1324 if (to_replicate_dim == -1) {
1325 return absl::nullopt;
1326 }
1327
1328 // Check if core assignments for source and the target are the same.
1329 auto reshape_tile_assignment = partial_replicate_sharding.tile_assignment();
1330 reshape_tile_assignment.Reshape(tile_sharding.tile_assignment().dimensions());
1331 if (reshape_tile_assignment != tile_sharding.tile_assignment()) {
1332 return absl::nullopt;
1333 }
1334
1335 auto tmp_tile_assignment = tile_sharding.tile_assignment();
1336 auto tmp_tile_assignment_dimensions =
1337 tile_sharding.tile_assignment().dimensions();
1338 tmp_tile_assignment_dimensions[to_replicate_dim] = 1;
1339 tmp_tile_assignment_dimensions.push_back(num_replicas);
1340 tmp_tile_assignment.Reshape(tmp_tile_assignment_dimensions);
1341 auto tmp_partial_replicate_sharding =
1342 HloSharding::PartialTile(tmp_tile_assignment);
1343
1344 if (source_is_partial_replicate) {
1345 if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1346 sharding(), tmp_partial_replicate_sharding)) {
1347 auto partitioned_hlo =
1348 ReshardWithAllToAll(tmp_partial_replicate_sharding, *src_tgt_dims);
1349 return partitioned_hlo.Reshard(target);
1350 }
1351 } else {
1352 auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding);
1353
1354 if (auto src_tgt_dims = GetReshardAllToAllSourceTargetDims(
1355 partitioned_hlo.sharding(), target)) {
1356 return partitioned_hlo.ReshardWithAllToAll(target, *src_tgt_dims);
1357 }
1358 }
1359
1360 return absl::nullopt;
1361 }
1362
1363 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
1364 const HloSharding& target) const {
1365 CHECK(CanReshardWithCollectivePermute(sharding(), target))
1366 << sharding().ToString() << " to " << target.ToString();
1367 if (auto broadcast_dims = state_.b->BroadcastDimsForCreatedHlo(hlo())) {
1368 if (!(*broadcast_dims)->empty()) {
1369 // If hlo() has broadcast dims, check if data is already the same between
1370 // source/destination pairs.
1371 std::vector<int64> broadcast_dims_vector;
1372 for (int64_t i = 0; i < hlo()->shape().rank(); ++i) {
1373 if ((*broadcast_dims)->contains(i)) {
1374 broadcast_dims_vector.push_back(i);
1375 }
1376 }
1377 if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1378 sharding(), broadcast_dims_vector) ==
1379 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1380 target, broadcast_dims_vector)) {
1381 auto copy = state_.b->AddInstruction(HloInstruction::CreateUnary(
1382 hlo()->shape(), HloOpcode::kCopy, hlo()));
1383 copy->set_sharding(target);
1384 return PartitionedHlo(copy, base_shape_, state_);
1385 }
1386 }
1387 }
1388 std::vector<std::pair<int64, int64>> src_dst_pairs;
1389 sharding().tile_assignment().Each(
1390 [&](absl::Span<const int64> indices, int64_t src_device) {
1391 int64_t dst_device = target.tile_assignment()(indices);
1392 src_dst_pairs.emplace_back(src_device, dst_device);
1393 });
1394 auto cp =
1395 state_.collective_ops_creator.create_cross_partition_collective_permute(
1396 state_.b, hlo(), src_dst_pairs, (*state_.next_channel_id)++);
1397 cp->set_sharding(target);
1398 return PartitionedHlo(cp, base_shape_, state_);
1399 }
1400
1401 SpmdPartitioningVisitor::SpmdPartitioningVisitor(
1402 HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
1403 const SPMDCollectiveOpsCreator& collective_ops_creator,
1404 int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options,
1405 SpmdPartitioner* partitioner)
1406 : changed_(false),
1407 module_(computation->parent()),
1408 num_partitions_(num_partitions),
1409 num_replicas_(num_replicas),
1410 collective_ops_creator_(collective_ops_creator),
1411 next_channel_id_(next_channel_id),
1412 b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)),
1413 partition_id_(collective_ops_creator_.create_partition_id(&b_)),
1414 logger_(logger),
1415 options_(std::move(options)),
1416 partitioner_(partitioner) {}
1417
1418 PartitionedHlo::PartitioningState
1419 SpmdPartitioningVisitor::MakePartitioningState() {
1420 PartitionedHlo::PartitioningState state;
1421 state.b = &b_;
1422 state.module = module_;
1423 state.num_replicas = num_replicas_;
1424 state.partition_id = partition_id_;
1425 state.collective_ops_creator = collective_ops_creator_;
1426 state.next_channel_id = next_channel_id_;
1427 state.reshard_cache = &reshard_cache_;
1428 state.partitioner = partitioner_;
1429 if (!device_groups_.empty()) {
1430 return CreatePerGroupPartitioningState(state, device_groups_, &b_);
1431 }
1432 return state;
1433 }
1434
1435 Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
1436 if (hlo->HasSideEffect()) {
1437 return Unimplemented("Side-effect ops cannot be replicated: %s",
1438 hlo->ToString());
1439 }
1440
1441 if (hlo->IsElementwise() && hlo->operand_count() > 0) {
1442 return HandleElementwise(hlo);
1443 }
1444
1445 if (!hlo->sharding().IsTileMaximal()) {
1446 VLOG(1) << "Not partitioned in SPMD mode (DefaultAction):"
1447 << hlo->ToString();
1448 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
1449 VLOG(1) << " operand " << i
1450 << " sharding:" << hlo->operand(i)->sharding().ToString();
1451 }
1452 }
1453
1454 HloSharding sharding = hlo->sharding().HasUniqueDevice()
1455 ? hlo->sharding()
1456 : HloSharding::Replicate();
1457
1458 // If the instruction cannot be partitioned, replicate the instruction unless
1459 // the instruction has side effects.
1460 std::vector<HloInstruction*> new_operands;
1461 for (HloInstruction* operand : hlo->operands()) {
1462 new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
1463 }
1464 auto clone =
1465 b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands));
1466 clone->set_sharding(sharding);
1467 SetPartitionedHlo(hlo,
1468 PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
1469 .Reshard(hlo->sharding()));
1470 return Status::OK();
1471 }
1472
1473 Status SpmdPartitioningVisitor::Preprocess(HloInstruction* hlo) {
1474 visiting_hlo_ = hlo;
1475 b_.set_visiting_hlo(hlo);
1476 // Temporarily replace manual sharding with one-device sharding so that the
1477 // partitioner will not change the HLOs.
1478 auto manual_to_onedevice = [&](const Shape& shape,
1479 const HloSharding& sharding) {
1480 // If a tuple's elements are all manual, then sharding.IsManual() is true,
1481 // so we test whether it is a tuple first.
1482 if (sharding.IsTuple()) {
1483 std::vector<HloSharding> subshardings = sharding.tuple_elements();
1484 for (HloSharding& subsharding : subshardings) {
1485 if (subsharding.IsManual()) {
1486 subsharding = HloSharding::AssignDevice(0);
1487 }
1488 }
1489 return HloSharding::Tuple(shape, subshardings);
1490 }
1491 if (sharding.IsManual()) {
1492 return HloSharding::AssignDevice(0);
1493 }
1494 return sharding;
1495 };
1496
1497 if (hlo->opcode() != HloOpcode::kConditional &&
1498 hlo->opcode() != HloOpcode::kWhile && hlo->opcode() != HloOpcode::kRng) {
1499 const bool has_manual_sharding =
1500 hlo->sharding().IsManual() ||
1501 (hlo->sharding().IsTuple() &&
1502 absl::c_any_of(
1503 hlo->sharding().tuple_elements(),
1504 [](const HloSharding& sharding) { return sharding.IsManual(); }));
1505 if (has_manual_sharding && !hlo->IsCustomCall("SPMDFullToShardShape")) {
1506 visiting_hlo_sharding_ = hlo->sharding();
1507 hlo->set_sharding(
1508 manual_to_onedevice(hlo->shape(), *visiting_hlo_sharding_));
1509
1510 visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
1511 for (auto operand : hlo->operands()) {
1512 visiting_hlo_operand_shardings_.push_back(operand->sharding());
1513 operand->set_sharding(
1514 manual_to_onedevice(operand->shape(), operand->sharding()));
1515 GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
1516 }
1517 } else if (hlo->sharding().IsManualSubgroup()) {
1518 GroupedSharding group_sharding =
1519 GetManualSubgroupSharding(hlo->sharding());
1520 // Update sharding.
1521 visiting_hlo_sharding_ = hlo->sharding();
1522 hlo->set_sharding(group_sharding.sharding);
1523 // Update device_groups and num_partitions.
1524 device_groups_ = group_sharding.device_groups;
1525 visiting_num_partitions_ = num_partitions_;
1526 num_partitions_ = num_partitions_ / group_sharding.device_groups.size();
1527
1528 // Update sharding for the operands.
1529 visiting_hlo_operand_shardings_.reserve(hlo->operand_count());
1530 visiting_state_.reserve(hlo->operand_count());
1531 for (auto operand : hlo->operands()) {
1532 visiting_hlo_operand_shardings_.push_back(operand->sharding());
1533 GroupedSharding group_sharding =
1534 GetManualSubgroupSharding(operand->sharding());
1535 operand->set_sharding(group_sharding.sharding);
1536 GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
1537 auto old_state = GetPartitionedHlo(operand).state();
1538 visiting_state_.push_back(old_state);
1539 auto group_state = CreatePerGroupPartitioningState(
1540 old_state, group_sharding.device_groups, &b_);
1541 GetPartitionedHlo(operand).set_state(group_state);
1542 }
1543 }
1544 }
1545 return Status::OK();
1546 }
1547
1548 Status SpmdPartitioningVisitor::Postprocess(HloInstruction* hlo) {
1549 logger_->RegisterLogEntry(GetPartitionedHlo(hlo).hlo(),
1550 b_.derived_instructions(hlo));
1551 visiting_hlo_ = nullptr;
1552 b_.set_visiting_hlo(nullptr);
1553 // Revert fake one-device shardings for manually partitioned ops.
1554 if (visiting_hlo_sharding_) {
1555 hlo->set_sharding(*visiting_hlo_sharding_);
1556 GetPartitionedHlo(hlo).hlo()->set_sharding(*visiting_hlo_sharding_);
1557 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
1558 auto operand = hlo->mutable_operand(i);
1559 operand->set_sharding(visiting_hlo_operand_shardings_[i]);
1560 GetPartitionedHlo(operand).hlo()->set_sharding(operand->sharding());
1561 }
1562 visiting_hlo_sharding_.reset();
1563 visiting_hlo_operand_shardings_.clear();
1564 }
1565
1566 if (!device_groups_.empty()) {
1567 device_groups_.clear();
1568 GetPartitionedHlo(hlo).set_state(MakePartitioningState());
1569 num_partitions_ = *visiting_num_partitions_;
1570 visiting_num_partitions_.reset();
1571 }
1572
1573 if (!visiting_state_.empty()) {
1574 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
1575 const HloInstruction* operand = hlo->operand(i);
1576 GetPartitionedHlo(operand).set_state(visiting_state_[i]);
1577 }
1578 visiting_state_.clear();
1579 }
1580
1581 return Status::OK();
1582 }
1583
1584 Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
1585 std::vector<HloInstruction*> new_operands;
1586 for (HloInstruction* operand : hlo->operands()) {
1587 new_operands.push_back(
1588 GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo());
1589 }
1590 SetPartitionedHlo(hlo, [&] {
1591 return b_.AddInstruction(hlo->CloneWithNewOperands(
1592 MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
1593 });
1594 return Status::OK();
1595 }
1596
1597 Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
1598 const HloSharding& sharding = hlo->sharding();
1599 if (sharding.IsTileMaximal()) {
1600 return DefaultAction(hlo);
1601 }
1602
1603 const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
1604 const int64_t dimension = hlo->concatenate_dimension();
1605 if (sharding.tile_assignment().dim(dimension) == 1) {
1606 std::vector<HloInstruction*> new_operands;
1607 for (HloInstruction* operand : hlo->operands()) {
1608 new_operands.push_back(
1609 GetPartitionedHlo(operand).Reshard(sharding).hlo());
1610 }
1611 SetPartitionedHlo(hlo, [&] {
1612 return b_.AddInstruction(
1613 hlo->CloneWithNewOperands(shard_shape, new_operands));
1614 });
1615 return Status::OK();
1616 }
1617
1618 // If the concatenate dimension is along one of the partitioned dimensions,
1619 // allocate the full output shape, each partition updates its owned region,
1620 // all-reduce across partitions, and then slice its output region.
1621
1622 // temp_output_shape is the output shape where the concatenate dimension is
1623 // changed to the full dimension size, padded to a multiple of the shard count.
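  // Illustrative example (added commentary, not in the original source):
  // concatenating two shape [4] operands along a dimension tiled across 2
  // partitions gives a [4] output shard and an [8] temp_output; partition p
  // writes its two [2] slices at offsets 2*p and 4 + 2*p, the all-reduce sums
  // the zero-initialized buffers across partitions, and each partition then
  // dynamic-slices its own [4] shard at offset 4*p.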
1624 auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding);
1625 auto last_operand_padded_shape =
1626 MakePartitionedShape(hlo->operands().back()->shape(), sharding);
1627 // If the last operand has more padding than the temp_output padding, we need
1628 // to add extra padding to avoid the dynamic-update-slice going out of bounds.
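  // Worked example (added commentary, not in the original source): when
  // concatenating two [3] operands across 2 partitions, the last operand's
  // padded shard size is 2, so last_operand_padding = 2*2-3 = 1 while
  // temp_output_padding = 3*2-6 = 0; one extra element of padding keeps the
  // last partition's dynamic-update-slice (offset 3 + 2*1 = 5, size 2) within
  // the temp buffer.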
1629 int last_operand_padding =
1630 last_operand_padded_shape.dimensions(dimension) *
1631 sharding.tile_assignment().dim(dimension) -
1632 hlo->operands().back()->shape().dimensions(dimension);
1633 int temp_output_padding = temp_output_shape.dimensions(dimension) *
1634 sharding.tile_assignment().dim(dimension) -
1635 hlo->shape().dimensions(dimension);
1636 int padding_for_last_operand =
1637 last_operand_padding < temp_output_padding
1638 ? 0
1639 : last_operand_padding - temp_output_padding;
1640 temp_output_shape.set_dimensions(
1641 dimension, temp_output_shape.dimensions(dimension) *
1642 sharding.tile_assignment().dim(dimension) +
1643 padding_for_last_operand);
1644 auto temp_output = CreateZero(temp_output_shape, &b_);
1645
1646 // Offset of each operand along the concatenate dimension.
1647 int64_t offset = 0;
1648 auto state = MakePartitioningState();
1649 for (HloInstruction* operand : hlo->operands()) {
1650 auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo();
1651 std::vector<HloInstruction*> start_indices(
1652 hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant(
1653 LiteralUtil::Zero(S32))));
1654 start_indices[dimension] =
1655 MultiplyAddDivideOffsetCalculation(
1656 spmd_operand->shape().dimensions(dimension), offset, 1)
1657 .Calculate(MakeTiledPartitionOrdinals(sharding, state.partition_id,
1658 &b_)[dimension],
1659 &b_);
1660 temp_output = b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1661 temp_output_shape, temp_output, spmd_operand, start_indices));
1662 offset += operand->shape().dimensions(dimension);
1663 }
1664 std::vector<int64> non_concat_dims;
1665 non_concat_dims.reserve(hlo->shape().rank() - 1);
1666 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
1667 if (i != dimension) {
1668 non_concat_dims.push_back(i);
1669 }
1670 }
1671 auto grouped = GroupShardingOnDims(sharding, non_concat_dims);
1672 auto per_group_partitioner_state =
1673 CreatePerGroupPartitioningState(state, grouped.device_groups, &b_);
1674 auto all_reduce = per_group_partitioner_state.collective_ops_creator
1675 .create_cross_partition_all_reduce(
1676 &b_, temp_output,
1677 MakeBinaryAdd(hlo->shape().element_type(), module_),
1678 {}, NewChannel());
1679 SetPartitionedHlo(hlo, [&] {
1680 auto start_indices = MakeTiledPartitionOrdinals(
1681 grouped.sharding, per_group_partitioner_state.partition_id, &b_);
1682 start_indices[dimension] = MultiplyAddDivideOffsetCalculation(
1683 shard_shape.dimensions(dimension), 0, 1)
1684 .Calculate(start_indices[dimension], &b_);
1685 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
1686 shard_shape, all_reduce, start_indices, shard_shape.dimensions()));
1687 });
1688
1689 return Status::OK();
1690 }
1691
1692 Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
1693 const HloSharding& sharding = hlo->sharding();
1694 if (sharding.IsTileMaximal()) {
1695 return DefaultAction(hlo);
1696 }
1697
1698 auto operand = GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
1699
1700 const int64_t rank = hlo->shape().rank();
1701 // Create a window config to represent the slice.
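  // Illustrative mapping (added commentary, not in the original source):
  // slicing a size-10 dimension with start=2, limit=8, stride=1 becomes a
  // window dimension with padding_low = -2 and padding_high = 8 - 10 = -2;
  // the negative padding trims the excluded elements on both sides.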
1702 Window window;
1703 for (int64_t i = 0; i < rank; ++i) {
1704 WindowDimension* dim = window.add_dimensions();
1705 dim->set_size(1);
1706 dim->set_stride(hlo->slice_strides(i));
1707 dim->set_window_dilation(1);
1708 dim->set_window_reversal(false);
1709 dim->set_padding_low(-hlo->slice_starts(i));
1710 dim->set_padding_high(hlo->slice_limits(i) -
1711 operand.base_shape().dimensions(i));
1712 dim->set_base_dilation(1);
1713 }
1714
1715 auto reshard_operand = operand.ReshardAsWindowedInput(
1716 window, sharding,
1717 CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
1718 /*mask_invalid_region=*/false);
1719 if (!reshard_operand.has_value()) {
1720 return DefaultAction(hlo);
1721 }
1722 TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
1723 const Shape& operand_shape = reshard_operand->sharded_input->shape();
1724
1725 std::vector<int64> start_indices(rank);
1726 std::vector<int64> limit_indices(rank);
1727 const std::vector<int64>& strides = hlo->slice_strides();
1728 bool need_slice = false;
1729 for (int64_t i = 0; i < rank; ++i) {
1730 auto dim = reshard_operand->shard_window.dimensions(i);
1731 start_indices[i] = -dim.padding_low();
1732 limit_indices[i] = operand_shape.dimensions(i) + dim.padding_high();
1733 if (start_indices[i] != 0 || strides[i] != 1 ||
1734 limit_indices[i] != operand_shape.dimensions(i)) {
1735 need_slice = true;
1736 }
1737 }
1738
1739 SetPartitionedHlo(hlo, [&] {
1740 if (need_slice) {
1741 auto shard_shape = MakePartitionedShape(hlo->shape(), sharding);
1742 return b_.AddInstruction(HloInstruction::CreateSlice(
1743 shard_shape, reshard_operand->sharded_input, start_indices,
1744 limit_indices, strides));
1745 }
1746 auto data = reshard_operand->sharded_input;
1747 // Create a copy so that it will not share the resharding cache.
1748 return b_.AddInstruction(
1749 HloInstruction::CreateUnary(data->shape(), HloOpcode::kCopy, data));
1750 });
1751
1752 return Status::OK();
1753 }
1754
1755 Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
1756 HloSharding sharding = hlo->sharding();
1757 if (sharding.HasUniqueDevice()) {
1758 return DefaultAction(hlo);
1759 }
1760 // Special handling for sort in TopK when the first operand is partitioned at
1761 // the sort dimension.
1762 auto k = GetKValueInTopKWhenPartitionSortDim(hlo);
1763 if (k.has_value()) {
1764 // When the first operand is partitioned at the sort dimension:
1765 // 1. Partition sort computation to different partitions;
1766 // 2. Slice TopK value and index from different partitions;
1767 // 3. Gather and replicate value and index from different partitions,
1768 // the shape of replicated value and index will be
1769 // [batch_size, ..., partition_count * k, ...];
1770 // 4. Final sort uses replicated value and index from different partitions
1771 // as input.
1772 // GetTupleElement and Slice after the non-partitioned sort won't change
1773 // at this point, as HandleGetTupleElement and HandleSlice will update them.
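    // Worked example (added commentary, not in the original source): for a
    // TopK with k=2 over a sort dimension of size 16 tiled across 4
    // partitions, each partition sorts its 4 elements and slices its local
    // top 2; the 4*2 = 8 candidates are replicated to every partition, and
    // the final replicated sort orders them so the downstream slice can take
    // the global top 2.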
1774 HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
1775 const int64_t sort_dim = sort->sort_dimension();
1776 auto input = hlo->operand(0);
1777 auto index = hlo->operand(1);
1778 const HloSharding& input_sharding = input->sharding();
1779 const int64_t partition_count =
1780 input_sharding.tile_assignment().dim(sort_dim);
1781 const int64_t input_size = input->shape().dimensions(sort_dim);
1782 const int64_t per_partition_size = CeilOfRatio(input_size, partition_count);
1783 const auto element_type = input->shape().element_type();
1784 const auto index_type = index->shape().element_type();
1785
1786 // Partition and pad input and index.
1787 // Pad input with minimal value.
1788 auto partitioned_input = GetPartitionedHlo(input).PadWithValue(
1789 CreateFirstWithType(element_type, &b_));
1790 // Pad index with max value.
1791 auto partitioned_index =
1792 GetPartitionedHlo(index)
1793 .Reshard(input_sharding)
1794 .PadWithValue(CreateLastWithType(index_type, &b_));
1795
1796 // Each partition needs to do TopK separately, thus the base shape
1797 // becomes the padded shape.
1798 std::vector<int64> replicated_dimensions(
1799 input->shape().dimensions().begin(), input->shape().dimensions().end());
1800 replicated_dimensions[sort_dim] = per_partition_size * partition_count;
1801 const Shape replicated_shape = ShapeUtil::MakeTupleShape(
1802 {ShapeUtil::MakeShape(element_type, replicated_dimensions),
1803 ShapeUtil::MakeShape(index_type, replicated_dimensions)});
1804
1805 // Partition original topk to different shards.
1806 auto topk_sharding =
1807 input_sharding.GetTupleSharding(replicated_shape).ValueOrDie();
1808 auto shard_shape = MakePartitionedShape(replicated_shape, topk_sharding);
1809 auto topk = b_.AddInstruction(hlo->CloneWithNewOperands(
1810 shard_shape, {partitioned_input.hlo(), partitioned_index.hlo()}));
1811
1812 // Get value from first sort.
1813 HloInstruction* value_gte =
1814 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1815 topk->shape().tuple_shapes(0), topk, 0));
1816 HloInstruction* index_gte =
1817 b_.AddInstruction(HloInstruction::CreateGetTupleElement(
1818 topk->shape().tuple_shapes(1), topk, 1));
1819
1820 // Slice top K value from the first partitioned sort.
1821 replicated_dimensions[sort_dim] = k.value() * partition_count;
1822 auto slice_input = SliceFirstK(value_gte, &b_, sort_dim, k.value());
1823 slice_input->set_sharding(input_sharding);
1824 PartitionedHlo partitioned_slice_input(
1825 slice_input, ShapeUtil::MakeShape(element_type, replicated_dimensions),
1826 MakePartitioningState());
1827 // Reshard value to be replicated.
1828 auto replicated_slice_input =
1829 partitioned_slice_input.Reshard(HloSharding::Replicate()).hlo();
1830
1831 // Slice top K index from the first partitioned sort.
1832 auto slice_index = SliceFirstK(index_gte, &b_, sort_dim, k.value());
1833 slice_index->set_sharding(input_sharding);
1834 PartitionedHlo partitioned_slice_index(
1835 slice_index, ShapeUtil::MakeShape(index_type, replicated_dimensions),
1836 MakePartitioningState());
1837 // Reshard index to be replicated.
1838 auto replicated_slice_index =
1839 partitioned_slice_index.Reshard(HloSharding::Replicate()).hlo();
1840
1841 // Creates a replicated sort to do TopK; the inputs are the value and index
1842 // pairs from all the partitions.
1843 const Shape final_topk_shape = ShapeUtil::MakeTupleShape(
1844 {ShapeUtil::MakeShape(element_type, replicated_dimensions),
1845 ShapeUtil::MakeShape(index_type, replicated_dimensions)});
1846 HloInstruction* final_sort = b_.AddInstruction(HloInstruction::CreateSort(
1847 final_topk_shape, sort_dim,
1848 {replicated_slice_input, replicated_slice_index}, sort->to_apply(),
1849 sort->is_stable()));
1850 final_sort->set_sharding(HloSharding::Replicate()
1851 .GetTupleSharding(final_sort->shape())
1852 .ValueOrDie());
1853 PartitionedHlo replicated_sort(final_sort, final_sort->shape(),
1854 MakePartitioningState());
1855 SetPartitionedHlo(hlo, replicated_sort.Reshard(hlo->sharding()));
1856
1857 return Status::OK();
1858 }
1859
1860 if (hlo->shape().IsTuple()) {
1861 // Check that all elements are sharded in the same way.
1862 if (hlo->shape().tuple_shapes_size() == 0) {
1863 return DefaultAction(hlo);
1864 }
1865 sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
1866 for (int64_t i = 1; i < hlo->operand_count(); ++i) {
1867 if (sharding != hlo->sharding().GetSubSharding(hlo->shape(), {i})) {
1868 return DefaultAction(hlo);
1869 }
1870 }
1871 }
1872 if (sharding.IsTileMaximal()) {
1873 return DefaultAction(hlo);
1874 }
1875 for (int64_t dim : hlo->dimensions()) {
1876 if (sharding.tile_assignment().dim(dim) > 1) {
1877 return DefaultAction(hlo);
1878 }
1879 }
1880 // Reshard operands to the same as the output.
1881 std::vector<HloInstruction*> new_operands;
1882 for (HloInstruction* operand : hlo->operands()) {
1883 new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
1884 }
1885 SetPartitionedHlo(hlo, [&] {
1886 return b_.AddInstruction(hlo->CloneWithNewOperands(
1887 MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
1888 });
1889 return Status::OK();
1890 }
1891
1892 Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) {
1893 const HloSharding& sharding = hlo->sharding();
1894 if (sharding.IsTileMaximal()) {
1895 return DefaultAction(hlo);
1896 }
1897
1898 std::vector<int64> inverse_dimensions(hlo->shape().rank());
1899 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
1900 inverse_dimensions[hlo->dimensions(i)] = i;
1901 }
1902 auto desired_operand_sharding =
1903 hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions);
1904
1905 auto operand = GetPartitionedHlo(hlo->operand(0))
1906 .Reshard(desired_operand_sharding)
1907 .hlo();
1908 SetPartitionedHlo(hlo, [&] {
1909 return b_.AddInstruction(hlo->CloneWithNewOperands(
1910 MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand}));
1911 });
1912 return Status::OK();
1913 }
1914
1915 Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
1916 const HloSharding& sharding = hlo->sharding();
1917 if (sharding.IsTileMaximal()) {
1918 return DefaultAction(hlo);
1919 }
1920
1921 auto operand = GetPartitionedHlo(hlo->operand(0));
1922 // The output shape is the source and the operand shape is the target to get
1923 // the aligned sharding for the operand.
1924 absl::optional<HloSharding> desired_operand_sharding =
1925 hlo_sharding_util::ReshapeSharding(hlo->shape(), hlo->operand(0)->shape(),
1926 hlo->sharding());
1927 if (desired_operand_sharding.has_value()) {
1928 auto operand_hlo = operand.Reshard(*desired_operand_sharding).hlo();
1929 SetPartitionedHlo(hlo, [&] {
1930 return b_.AddInstruction(hlo->CloneWithNewOperands(
1931 MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand_hlo}));
1932 });
1933 return Status::OK();
1934 }
1935 absl::optional<HloSharding> desired_output_sharding =
1936 hlo_sharding_util::ReshapeSharding(hlo->operand(0)->shape(), hlo->shape(),
1937 operand.sharding());
1938 if (desired_output_sharding.has_value()) {
1939 auto reshape = b_.AddInstruction(hlo->CloneWithNewOperands(
1940 MakePartitionedShape(hlo->shape(), *desired_output_sharding),
1941 {operand.hlo()}));
1942 reshape->set_sharding(*desired_output_sharding);
1943 SetPartitionedHlo(hlo, [&] {
1944 return PartitionedHlo(reshape, hlo->shape(), MakePartitioningState())
1945 .Reshard(sharding)
1946 .hlo();
1947 });
1948 return Status::OK();
1949 }
1950
1951 // Check that the operand sharding and this sharding are either both tiled or
1952 // both partially replicated, with matching replication counts if the latter.
1953 if (operand.sharding().ReplicateOnLastTileDim() !=
1954 sharding.ReplicateOnLastTileDim() ||
1955 (sharding.ReplicateOnLastTileDim() &&
1956 (operand.sharding().tile_assignment().dimensions().back() !=
1957 sharding.tile_assignment().dimensions().back()))) {
1958 return DefaultAction(hlo);
1959 }
1960
1961 // Try use halo exchange for certain split-dim/merge-dims cases.
1962 // ReshapeSharding failed in these cases probably due to uneven partitioning,
1963 // where halo exchange could help. Specifically we check the following
1964 // conditions to detect supported cases:
1965 // 1) Both input and output are partitioned on one dimension.
1966 // 2) The combined size of dimensions before the partitioned dimension is the
1967 // same on input and output. This means we don't need to consider the major
1968 // dimensions.
1969 // 3) Let A = the input size on the partitioned dimension, and
1970 // B = the output size on the partitioned dimension; then
1971 // either A % B == 0 (split dim) or B % A == 0 (merge dims).
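  // Illustrative example (added commentary, not in the original source):
  // reshaping [6,4] to [24] with the input tiled 4 ways on dimension 0 and
  // the output tiled 4 ways on its single dimension satisfies the conditions:
  // both sides are partitioned on one dimension, the major-dim products are
  // both 1, and with A=6 and B=24 we have B % A == 0, i.e. the merge-dims
  // case handled below.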
1972 auto maybe_input_sharded_dim = UniqueTiledDim(operand.sharding());
1973 auto maybe_output_sharded_dim = UniqueTiledDim(sharding);
1974 if (!maybe_input_sharded_dim || !maybe_output_sharded_dim) {
1975 return DefaultAction(hlo);
1976 }
1977 int64_t input_sharded_dim = *maybe_input_sharded_dim;
1978 int64_t output_sharded_dim = *maybe_output_sharded_dim;
1979 // Check that the major dims before the sharded dim have the same total size
1980 // for input and output.
1981 int64_t input_major_dims_size = 1;
1982 for (int64_t i = 0; i < input_sharded_dim; ++i) {
1983 input_major_dims_size *= operand.base_shape().dimensions(i);
1984 }
1985 int64_t output_major_dims_size = 1;
1986 for (int64_t i = 0; i < output_sharded_dim; ++i) {
1987 output_major_dims_size *= hlo->shape().dimensions(i);
1988 }
1989 if (input_major_dims_size != output_major_dims_size) {
1990 return DefaultAction(hlo);
1991 }
1992 // Fix potential device ordering mismatch in tile assignment.
1993 Array<int64> new_input_tile_assignment = sharding.tile_assignment();
1994 new_input_tile_assignment.Reshape(
1995 operand.sharding().tile_assignment().dimensions());
1996 auto aligned_sharding =
1997 sharding.ReplicateOnLastTileDim()
1998 ? HloSharding::PartialTile(new_input_tile_assignment)
1999 : HloSharding::Tile(new_input_tile_assignment);
2000 operand = operand.Reshard(aligned_sharding);
2001 auto replication_count = sharding.ReplicateOnLastTileDim()
2002 ? sharding.tile_assignment().dimensions().back()
2003 : 1;
2004
2005 int64_t input_dim_size = operand.base_shape().dimensions(input_sharded_dim);
2006 int64_t output_dim_size = hlo->shape().dimensions(output_sharded_dim);
2007 auto input_shard_shape =
2008 MakePartitionedShape(operand.base_shape(), operand.sharding());
2009 auto output_shard_shape = MakePartitionedShape(hlo->shape(), sharding);
2010 if (input_dim_size % output_dim_size == 0) {
2011 // Split dim.
2012 int64_t split_factor = input_dim_size / output_dim_size;
2013 int64_t output_shard_size =
2014 output_shard_shape.dimensions(output_sharded_dim);
2015 // Use halo exchange to fix misaligned data.
2016 Window window;
2017 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2018 WindowDimension* dim = window.add_dimensions();
2019 dim->set_size(1);
2020 dim->set_stride(1);
2021 dim->set_window_dilation(1);
2022 dim->set_window_reversal(false);
2023 dim->set_base_dilation(1);
2024 dim->set_padding_low(0);
2025 if (i == input_sharded_dim) {
2026 dim->set_padding_high(output_shard_size * split_factor *
2027 num_partitions_ / replication_count -
2028 input_dim_size);
2029 } else {
2030 dim->set_padding_high(0);
2031 }
2032 }
2033
2034 auto reshard_operand = operand.ReshardAsWindowedInput(
2035 window, operand.sharding(),
2036 CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
2037 /*mask_invalid_region=*/false);
2038 if (!reshard_operand.has_value()) {
2039 return DefaultAction(hlo);
2040 }
2041 TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
2042 CHECK_EQ(
2043 reshard_operand->sharded_input->shape().dimensions(input_sharded_dim),
2044 output_shard_size * split_factor);
2045 SetPartitionedHlo(hlo, [&] {
2046 // Do a local reshape.
2047 return b_.AddInstruction(HloInstruction::CreateReshape(
2048 output_shard_shape, reshard_operand->sharded_input));
2049 });
2050 return Status::OK();
2051 } else if (output_dim_size % input_dim_size == 0) {
2052 // Merge dims.
2053 int64_t merge_factor = output_dim_size / input_dim_size;
2054 // First reshape locally. (The sharded dimension could include padded data.)
2055 auto tmp_shard_shape = output_shard_shape;
2056 tmp_shard_shape.set_dimensions(
2057 output_sharded_dim,
2058 input_shard_shape.dimensions(input_sharded_dim) * merge_factor);
2059 auto tmp_reshape = b_.AddInstruction(
2060 HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo()));
2061 tmp_reshape->set_sharding(hlo->sharding());
2062 auto tmp_full_shape = tmp_shard_shape;
2063 tmp_full_shape.set_dimensions(
2064 output_sharded_dim, tmp_shard_shape.dimensions(output_sharded_dim) *
2065 num_partitions_ / replication_count);
2066 auto tmp_output =
2067 PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState());
2068
2069 // Use halo exchange to fix misaligned data.
2070 Window window;
2071 for (int64_t i = 0; i < tmp_shard_shape.rank(); ++i) {
2072 WindowDimension* dim = window.add_dimensions();
2073 dim->set_size(1);
2074 dim->set_stride(1);
2075 dim->set_window_dilation(1);
2076 dim->set_window_reversal(false);
2077 dim->set_base_dilation(1);
2078 dim->set_padding_low(0);
2079 if (i == output_sharded_dim) {
2080 dim->set_padding_high(output_dim_size -
2081 tmp_shard_shape.dimensions(output_sharded_dim) *
2082 num_partitions_ / replication_count);
2083 } else {
2084 dim->set_padding_high(0);
2085 }
2086 }
2087
2088 auto reshard_output = tmp_output.ReshardAsWindowedInput(
2089 window, sharding,
2090 CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
2091 /*mask_invalid_region=*/false);
2092 if (!reshard_output.has_value()) {
2093 return DefaultAction(hlo);
2094 }
2095 TF_RET_CHECK(!reshard_output->dynamic_slice_index_on_output.has_value());
2096 CHECK_EQ(
2097 reshard_output->sharded_input->shape().dimensions(output_sharded_dim),
2098 output_shard_shape.dimensions(output_sharded_dim));
2099 SetPartitionedHlo(hlo, [&] { return reshard_output->sharded_input; });
2100 return Status::OK();
2101 }
2102 return DefaultAction(hlo);
2103 }
2104
2105 Status SpmdPartitioningVisitor::HandleIota(HloInstruction* hlo) {
2106 const HloSharding& sharding = hlo->sharding();
2107 if (sharding.IsTileMaximal()) {
2108 return DefaultAction(hlo);
2109 }
2110
2111 SetPartitionedHlo(hlo, [&] {
2112 int64_t dimension = Cast<HloIotaInstruction>(hlo)->iota_dimension();
2113 auto iota = b_.AddInstruction(HloInstruction::CreateIota(
2114 MakePartitionedShape(hlo->shape(), sharding), dimension));
2115
2116 if (sharding.tile_assignment().dim(dimension) > 1) {
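      // The iota dimension is partitioned: each partition materializes a
      // local iota of the shard size and adds partition_ordinal * shard_size
      // to recover the global values. Illustrative example (added commentary,
      // not in the original source): an iota of size 8 tiled across 2
      // partitions yields 0..3 on partition 0 and 0..3 + 4 = 4..7 on
      // partition 1.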
2117 auto partition_ordinals = MakeTiledPartitionOrdinals(
2118 sharding, MakePartitioningState().partition_id, &b_);
2119 auto multiplier = b_.AddInstruction(HloInstruction::CreateConstant(
2120 LiteralUtil::CreateR0<int32>(iota->shape().dimensions(dimension))));
2121 auto offset = b_.AddInstruction(HloInstruction::CreateBinary(
2122 ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
2123 partition_ordinals[dimension], multiplier));
2124 if (iota->shape().element_type() != S32) {
2125 offset = b_.AddInstruction(HloInstruction::CreateConvert(
2126 ShapeUtil::MakeShape(iota->shape().element_type(), {}), offset));
2127 }
2128 auto broadcast = b_.AddInstruction(
2129 HloInstruction::CreateBroadcast(iota->shape(), offset, {}));
2130 return b_.AddInstruction(HloInstruction::CreateBinary(
2131 iota->shape(), HloOpcode::kAdd, iota, broadcast));
2132 }
2133
2134 return iota;
2135 });
2136
2137 return Status::OK();
2138 }
2139
2140 Status SpmdPartitioningVisitor::HandleSingleDevice(const HloInstruction* hlo) {
2141 TF_RET_CHECK(hlo->sharding().HasUniqueDevice());
2142 int64_t device = hlo->sharding().GetUniqueDevice();
2143 const HloSharding sharding = HloSharding::AssignDevice(device);
2144
2145 std::vector<HloInstruction*> operands;
2146 std::vector<Shape> operand_shapes;
2147 for (const HloInstruction* operand : hlo->operands()) {
2148 operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
2149 operand_shapes.push_back(operand->shape());
2150 }
2151 auto operand = b_.AddInstruction(HloInstruction::CreateTuple(operands));
2152 auto operand_shape = ShapeUtil::MakeTupleShape(operand_shapes);
2153
2154 auto on_device = b_.AddInstruction(
2155 HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(device)));
2156 auto pred = b_.AddInstruction(HloInstruction::CreateCompare(
2157 ShapeUtil::MakeShape(PRED, {}), MakePartitioningState().partition_id,
2158 on_device, ComparisonDirection::kEq));
2159
2160 SpmdBuilder true_b("true_computation", visiting_hlo_);
2161 HloComputation* true_computation;
2162 {
2163 auto param = true_b.AddInstruction(HloInstruction::CreateParameter(
2164 /*parameter_number=*/0, operand_shape, "true_branch_param"));
2165 std::vector<HloInstruction*> new_operands;
2166 for (int64_t i = 0; i < operands.size(); ++i) {
2167 new_operands.push_back(true_b.AddInstruction(
2168 HloInstruction::CreateGetTupleElement(operand_shapes[i], param, i)));
2169 }
2170 auto root = true_b.AddInstruction(
2171 hlo->CloneWithNewOperands(hlo->shape(), new_operands));
2172 true_computation = module_->AddEmbeddedComputation(true_b.Build(root));
2173 }
2174
2175 SpmdBuilder false_b("false_computation", visiting_hlo_);
2176 HloComputation* false_computation;
2177 {
2178 false_b.AddInstruction(HloInstruction::CreateParameter(
2179 /*parameter_number=*/0, operand_shape, "false_branch_param"));
2180 auto root = CreateZero(hlo->shape(), &false_b);
2181 false_computation = module_->AddEmbeddedComputation(false_b.Build(root));
2182 }
2183
2184 SetPartitionedHlo(hlo, [&]() {
2185 return b_.AddInstruction(HloInstruction::CreateConditional(
2186 hlo->shape(), pred, operand, true_computation, operand,
2187 false_computation));
2188 });
2189 return Status::OK();
2190 }
2191
2192 Status SpmdPartitioningVisitor::HandleAllReduce(HloInstruction* hlo) {
2193 if (hlo->IsCrossReplicaAllReduce() && hlo->operand_count() == 1) {
2194 return HandleElementwise(hlo);
2195 }
2196 return DefaultAction(hlo);
2197 }
2198
2199 Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) {
2200 if (hlo->sharding().IsTileMaximal()) {
2201 return DefaultAction(hlo);
2202 }
2203
2204 auto& operand = GetPartitionedHlo(hlo->operand(0));
2205
2206 // Tiled output.
2207 std::vector<int64> new_dims;
2208 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2209 if (!absl::c_linear_search(hlo->dimensions(), i)) {
2210 new_dims.push_back(i);
2211 }
2212 }
2213 auto desired_input_sharding = hlo_sharding_util::RemoveShapeDimensions(
2214 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(),
2215 new_dims),
2216 new_dims);
2217 auto input = operand.Reshard(desired_input_sharding).hlo();
2218 auto output_shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2219 SetPartitionedHlo(hlo, [&] {
2220 return b_.AddInstruction(
2221 hlo->CloneWithNewOperands(output_shard_shape, {input}));
2222 });
2223 return Status::OK();
2224 }
2225
2226 Status SpmdPartitioningVisitor::HandleConstant(HloInstruction* hlo) {
2227 const Literal& literal = hlo->literal();
2228 if (literal.shape().IsTuple() ||
2229 (!hlo->sharding().IsTileMaximal() &&
2230 (!EvenlyPartitions(hlo->shape(), hlo->sharding()) ||
2231 !literal.IsAllFirst()))) {
2232 return DefaultAction(hlo);
2233 }
2234
2235 SetPartitionedHlo(hlo, [&]() {
2236 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2237 std::vector<int64> start_indices(hlo->shape().rank(), 0);
2238 auto constant = b_.AddInstruction(HloInstruction::CreateConstant(
2239 literal.Slice(start_indices, shard_shape.dimensions())));
2240 *constant->mutable_shape() = shard_shape;
2241 return constant;
2242 });
2243 return Status::OK();
2244 }
2245
2246 Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) {
2247 if (hlo->sharding().IsTileMaximal()) {
2248 return DefaultAction(hlo);
2249 }
2250 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2251 if (hlo->sharding().tile_assignment().dim(i) != 1 &&
2252 (hlo->dynamic_slice_sizes()[i] != hlo->shape().dimensions(i) ||
2253 !hlo->operand(i + 1)->IsConstant() ||
2254 !hlo->operand(i + 1)->literal().IsZero({}))) {
2255 // We currently do not partition the sliced dimensions.
2256 return DefaultAction(hlo);
2257 }
2258 }
2259 std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2260 auto new_input =
2261 GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
2262 for (int64_t i = 0; i < new_indices.size(); ++i) {
2263 // Replicate the indices.
2264 new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1))
2265 .Reshard(HloSharding::Replicate())
2266 .hlo();
2267 }
2268 SetPartitionedHlo(hlo, [&]() {
2269 auto partitioned_shape =
2270 MakePartitionedShape(hlo->shape(), hlo->sharding());
2271 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2272 partitioned_shape, new_input, new_indices,
2273 partitioned_shape.dimensions()));
2274 });
2275 return Status::OK();
2276 }
2277
2278 Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) {
2279 if (hlo->sharding().IsTileMaximal()) {
2280 return DefaultAction(hlo);
2281 }
2282
2283 std::vector<int64> partitioned_slice_dims;
2284 std::vector<int64> slice_dims;
2285 std::vector<int64> partitioned_non_slice_dims;
2286 std::vector<int64> partitioned_slice_offsets;
2287 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2288 if (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i)) {
2289 slice_dims.push_back(i);
2290 int64_t slice_size = hlo->operand(1)->shape().dimensions(i);
2291 if (hlo->sharding().tile_assignment().dim(i) != 1) {
2292 if (!hlo->operand(i + 2)->IsConstant() && slice_size != 1) {
2293 return DefaultAction(hlo);
2294 }
2295 partitioned_slice_dims.push_back(i);
2296 // Set partitioned_slice_offsets to -1 when slice_size is 1.
2297 if (slice_size == 1) {
2298 partitioned_slice_offsets.push_back(-1);
2299 } else {
2300 partitioned_slice_offsets.push_back(
2301 hlo->operand(i + 2)->literal().Get<int>({}));
2302 }
2303 }
2304 } else if (hlo->sharding().tile_assignment().dim(i) != 1) {
2305 if (!hlo->operand(i + 2)->IsConstant() ||
2306 !hlo->operand(i + 2)->literal().IsZero({})) {
2307 return DefaultAction(hlo);
2308 }
2309 partitioned_non_slice_dims.push_back(i);
2310 }
2311 }
2312
2313 // Handle when there is slice dim partitioned.
2314 if (!partitioned_slice_dims.empty()) {
2315 auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
2316 return b_.AddInstruction(std::move(to_add));
2317 };
2318 std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2319 for (int64_t i = 0; i < new_indices.size(); ++i) {
2320 // Replicate the indices.
2321 new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
2322 .Reshard(HloSharding::Replicate())
2323 .hlo();
2324 }
2325
2326 // Get partitioned input.
2327 const auto& dus_sharding = hlo->sharding();
2328 const auto& partitioned_input =
2329 GetPartitionedHlo(hlo->operand(0)).Reshard(dus_sharding).hlo();
2330
2331 // Get replicate update.
2332 auto update_sharding = HloSharding::Replicate();
2333 if (!partitioned_non_slice_dims.empty()) {
2334 // Partially replicate the update if non-slice dims are partitioned.
2335 update_sharding =
2336 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding,
2337 slice_dims);
2338 }
2339
2340 // TODO(wangtao): use collective permute for sharded update.
2341 HloInstruction* replicate_update =
2342 GetPartitionedHlo(hlo->operand(1)).Reshard(update_sharding).hlo();
2343
2344 const auto& update_shape = replicate_update->shape();
2345 const auto& partitioned_shape = partitioned_input->shape();
2346 auto partition_ordinals = MakeTiledPartitionOrdinals(
2347 hlo->sharding(), MakePartitioningState().partition_id, &b_);
2348 HloInstruction* all_dims_within_partition = add_hlo(
2349 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
2350
2351 for (int i = 0; i < partitioned_slice_dims.size(); ++i) {
2352 int dim = partitioned_slice_dims[i];
2353 // Calculate per partition size.
2354 const int64_t per_partition_size = partitioned_shape.dimensions(dim);
2355
2356 // Only an update that stays within a single partition is supported. This
2357 // check is skipped when the slice size is 1, in which case
2358 // partitioned_slice_offsets[i] is -1.
2359 if ((partitioned_slice_offsets[i] != -1) &&
2360 (partitioned_slice_offsets[i] / per_partition_size) !=
2361 ((partitioned_slice_offsets[i] + update_shape.dimensions(dim) -
2362 1) /
2363 per_partition_size)) {
2364 return DefaultAction(hlo);
2365 }
2366
2367 // within_partition = (offset >= partition_id * per_partition_size) &&
2368 // (offset < (partition_id + 1) * per_partition_size)
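      // Worked example (added commentary, not in the original source): with
      // per_partition_size = 4 and a replicated start offset of 6, only
      // partition 1 satisfies 4 <= 6 < 8; its local start index becomes
      // 6 - 1*4 = 2, while the other partitions clamp their index to 0 and
      // keep their original data through the final kSelect below.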
2369 const Shape& compare_shape =
2370 ShapeUtil::ChangeElementType(partition_id_->shape(), PRED);
2371 auto per_partition_size_hlo = add_hlo(HloInstruction::CreateConstant(
2372 LiteralUtil::CreateR0<int>(per_partition_size)));
2373 const Shape& offset_shape = per_partition_size_hlo->shape();
2374 auto partition_offset = add_hlo(HloInstruction::CreateBinary(
2375 offset_shape, HloOpcode::kMultiply, partition_ordinals[dim],
2376 per_partition_size_hlo));
2377 // offset >= partition_id * per_partition_size
2378 auto offset_ge = add_hlo(HloInstruction::CreateCompare(
2379 compare_shape, new_indices[dim], partition_offset,
2380 ComparisonDirection::kGe));
2381 // offset < (partition_id + 1) * per_partition_size
2382 auto offset_lt = add_hlo(HloInstruction::CreateCompare(
2383 compare_shape, new_indices[dim],
2384 add_hlo(HloInstruction::CreateBinary(
2385 offset_shape, HloOpcode::kMultiply,
2386 add_hlo(HloInstruction::CreateBinary(
2387 offset_shape, HloOpcode::kAdd, partition_ordinals[dim],
2388 add_hlo(HloInstruction::CreateConstant(
2389 LiteralUtil::CreateR0<int>(1))))),
2390 per_partition_size_hlo)),
2391 ComparisonDirection::kLt));
2392 auto update_within_partition = add_hlo(HloInstruction::CreateBinary(
2393 compare_shape, HloOpcode::kAnd, offset_ge, offset_lt));
2394
2395 all_dims_within_partition = add_hlo(HloInstruction::CreateBinary(
2396 compare_shape, HloOpcode::kAnd, all_dims_within_partition,
2397 update_within_partition));
2398
2399 // Calculate offset.
2400 // slice dim offset =
2401 // within_partition ?
2402 // offset - partition_id * per_partition_size : 0
2403 new_indices[dim] = add_hlo(HloInstruction::CreateTernary(
2404 new_indices[dim]->shape(), HloOpcode::kSelect,
2405 update_within_partition,
2406 add_hlo(HloInstruction::CreateBinary(
2407 new_indices[dim]->shape(), HloOpcode::kSubtract, new_indices[dim],
2408 partition_offset)),
2409 add_hlo(
2410 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)))));
2411 }
2412
2413 // Create dynamic update slice.
2414 auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice(
2415 partitioned_shape, partitioned_input, replicate_update, new_indices));
2416 SetPartitionedHlo(hlo, [&]() {
2417 // Select if update is needed.
2418 return add_hlo(HloInstruction::CreateTernary(
2419 dus->shape(), HloOpcode::kSelect,
2420 add_hlo(HloInstruction::CreateBroadcast(
2421 ShapeUtil::ChangeElementType(dus->shape(), PRED),
2422 all_dims_within_partition, {})),
2423 dus, partitioned_input));
2424 });
2425 return Status::OK();
2426 }
2427
2428 // Partition non slice dims only.
2429 std::vector<HloInstruction*> new_indices(hlo->shape().rank());
2430 auto new_input =
2431 GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
2432 auto new_update =
2433 GetPartitionedHlo(hlo->operand(1)).Reshard(hlo->sharding()).hlo();
2434 for (int64_t i = 0; i < new_indices.size(); ++i) {
2435 // Replicate the indices.
2436 new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
2437 .Reshard(HloSharding::Replicate())
2438 .hlo();
2439 }
2440 SetPartitionedHlo(hlo, [&]() {
2441 auto partitioned_shape =
2442 MakePartitionedShape(hlo->shape(), hlo->sharding());
2443 return b_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2444 partitioned_shape, new_input, new_update, new_indices));
2445 });
2446 return Status::OK();
2447 }
2448
2449 Status SpmdPartitioningVisitor::HandleGetTupleElement(HloInstruction* hlo) {
2450 const auto& tuple = GetPartitionedHlo(hlo->operand(0));
2451 auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
2452 ShapeUtil::GetTupleElementShape(tuple.hlo()->shape(), hlo->tuple_index()),
2453 tuple.hlo(), hlo->tuple_index()));
2454 const auto source_sharding =
2455 tuple.sharding().GetSubSharding(tuple.base_shape(), {hlo->tuple_index()});
2456 gte->set_sharding(source_sharding);
2457 PartitionedHlo source_partitioned_gte(
2458 gte, tuple.base_shape().tuple_shapes(hlo->tuple_index()),
2459 MakePartitioningState());
2460 source_partitioned_gte = source_partitioned_gte.Reshard(hlo->sharding());
2461 SetPartitionedHlo(hlo, source_partitioned_gte);
2462 return Status::OK();
2463 }
2464
2465 Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
2466 const Shape& shape = ShapeUtil::GetTupleElementShape(hlo->shape(), 0);
2467 auto token = GetPartitionedHlo(hlo->operand(0)).hlo();
2468 if (ShapeUtil::GetLeafCount(shape) == 0) {
2469 // TODO(b/155819021): HloSharding has issues with tuple-shaped sharding: it
2470 // requires one element for an empty tuple, but leaf-count number of
2471 // elements for non-empty tuple. So if it has a nested empty tuple, we
2472 // cannot invoke GetSubSharding() since it expects a sharding for the empty
2473 // tuple. This is a workaround for that case.
2474 SetPartitionedHlo(hlo, [&]() {
2475 return b_.AddInstruction(
2476 HloInstruction::CreateInfeed(shape, token, hlo->infeed_config()));
2477 });
2478 return Status::OK();
2479 }
2480 auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
2481 auto shard_shape = MakePartitionedShape(shape, sharding);
2482 if (EvenlyPartitions(shape, sharding)) {
2483 SetPartitionedHlo(hlo, [&]() {
2484 return b_.AddInstruction(HloInstruction::CreateInfeed(
2485 shard_shape, token, hlo->infeed_config()));
2486 });
2487 return Status::OK();
2488 }
2489
2490 if (hlo->sharding().HasUniqueDevice()) {
2491 return HandleSingleDevice(hlo);
2492 }
2493
2494 // Create a branch for each unique partitioned shape.
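  // Illustrative example (added commentary, not in the original source): a
  // [10] infeed tiled across 4 partitions has per-partition shapes
  // [3],[3],[3],[1], so two branches are built and
  // conditional_branch_indices becomes {0,0,0,1}.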
2495 std::vector<Shape> per_branch_partitioned_shapes;
2496 std::vector<int32> conditional_branch_indices(num_partitions_);
2497 for (int64_t i = 0; i < num_partitions_; ++i) {
2498 auto partitioned_shape =
2499 MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
2500 int64_t matching_existing_index = 0;
2501 for (; matching_existing_index < per_branch_partitioned_shapes.size();
2502 ++matching_existing_index) {
2503 if (ShapeUtil::Compatible(
2504 partitioned_shape,
2505 per_branch_partitioned_shapes[matching_existing_index])) {
2506 break;
2507 }
2508 }
2509 if (matching_existing_index < per_branch_partitioned_shapes.size()) {
2510 conditional_branch_indices[i] = matching_existing_index;
2511 } else {
2512 conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
2513 per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
2514 }
2515 }
2516
2517 HloInstruction* branch_index;
2518 auto state = MakePartitioningState();
2519 if (per_branch_partitioned_shapes.size() == num_partitions_) {
2520 // Use partition ID as the branch index if each partition has its own
2521 // branch.
2522 branch_index = state.partition_id;
2523 // PartitionId's output is U32 but conditional requires S32.
2524 if (branch_index->shape().element_type() != S32) {
2525 branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
2526 ShapeUtil::ChangeElementType(branch_index->shape(), S32),
2527 branch_index));
2528 }
2529 } else {
2530 // Otherwise, use a constant table to look up the branch index.
2531 auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
2532 LiteralUtil::CreateR1<int32>(conditional_branch_indices)));
2533 branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2534 ShapeUtil::MakeShape(S32, {1}), branch_index_table,
2535 {state.partition_id}, {1}));
2536 branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
2537 ShapeUtil::MakeShape(S32, {}), branch_index));
2538 }
2539
2540 std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
2541 for (int64_t i = 0; i < branches.size(); ++i) {
2542 SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_);
2543 auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
2544 /*parameter_number=*/0, token->shape(), "infeed_token_param"));
2545 auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed(
2546 per_branch_partitioned_shapes[i], param, hlo->infeed_config()));
2547 if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
2548 std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
2549 pad_infeed = [&](const ShapeIndex& index,
2550 HloInstruction* infeed_element) -> HloInstruction* {
2551 if (index == ShapeIndex({1})) {
2552 // Token.
2553 return infeed_element;
2554 }
2555 const Shape& element_shape =
2556 ShapeUtil::GetSubshape(infeed->shape(), index);
2557 if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
2558 std::vector<HloInstruction*> padded_elements(
2559 element_shape.tuple_shapes_size());
2560 for (int64_t i = 0; i < padded_elements.size(); ++i) {
2561 auto sub_index = index;
2562 sub_index.push_back(i);
2563 padded_elements[i] = pad_infeed(
2564 sub_index,
2565 branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
2566 ShapeUtil::GetSubshape(element_shape, {i}), infeed_element,
2567 i)));
2568 }
2569 return branch_b.AddInstruction(
2570 HloInstruction::CreateTuple(padded_elements));
2571 }
2572 const Shape& pad_shape =
2573 ShapeUtil::GetSubshape(shard_shape, ShapeIndexView(index, 1));
2574 if (ShapeUtil::Compatible(element_shape, pad_shape)) {
2575 return infeed_element;
2576 }
2577 if (element_shape.IsArray()) {
2578 CHECK(pad_shape.IsArray());
2579 return PadToShape(infeed_element, pad_shape, &branch_b);
2580 }
2581 CHECK(element_shape.IsTuple());
2582 CHECK(element_shape.tuple_shapes().empty());
2583 return CreateZero(pad_shape, &branch_b);
2584 };
2585 pad_infeed({}, infeed);
2586 }
2587 branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
2588 }
2589 SetPartitionedHlo(hlo, [&]() {
2590 return b_.AddInstruction(HloInstruction::CreateConditional(
2591 ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index,
2592 branches, std::vector<HloInstruction*>(branches.size(), token)));
2593 });
2594 return Status::OK();
2595 }
2596
2597 Status SpmdPartitioningVisitor::HandlePad(HloInstruction* hlo) {
2598 if (hlo->sharding().IsTileMaximal()) {
2599 return DefaultAction(hlo);
2600 }
2601 auto lhs = GetPartitionedHlo(hlo->operand(0));
2602 // Create a window config to represent the pad.
2603 Window window;
2604 bool needs_masking = false;
2605 const bool pad_value_is_zero =
2606 hlo->operand(1)->IsConstant() && hlo->operand(1)->literal().IsZero({});
2607 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2608 const auto& pd = hlo->padding_config().dimensions(i);
2609 WindowDimension* dim = window.add_dimensions();
2610 dim->set_size(1);
2611 dim->set_stride(1);
2612 dim->set_window_dilation(1);
2613 dim->set_window_reversal(false);
2614 dim->set_padding_low(pd.edge_padding_low());
2615 dim->set_padding_high(pd.edge_padding_high());
2616 dim->set_base_dilation(pd.interior_padding() + 1);
2617 const int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
2618 // Masking is needed only if there is a non-zero padding value or the operand
2619 // is unevenly partitioned. Halo exchange fills 0 in the collective-permute
2620 // result for non-destination cores.
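    // Illustrative example (added commentary, not in the original source):
    // padding a size-5 dimension tiled across 2 partitions is uneven
    // (5 % 2 != 0), so any padding on it requires masking even with a zero
    // pad value; padding a size-6 dimension across 2 partitions with a zero
    // pad value does not.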
2621 needs_masking |=
2622 shard_count > 1 &&
2623 (pd.edge_padding_low() > 0 || pd.edge_padding_high() > 0 ||
2624 pd.interior_padding() > 0) &&
2625 (!pad_value_is_zero ||
2626 hlo->operand(0)->shape().dimensions(i) % shard_count != 0);
2627 }
2628
2629 auto replicated_rhs = GetPartitionedHlo(hlo->operand(1))
2630 .Reshard(HloSharding::Replicate())
2631 .hlo();
2632 auto reshard_operand =
2633 lhs.ReshardAsWindowedInput(window, hlo->sharding(), replicated_rhs,
2634 /*mask_invalid_region=*/needs_masking);
2635 if (!reshard_operand.has_value()) {
2636 return DefaultAction(hlo);
2637 }
2638 PaddingConfig sharded_padding_config;
2639 bool need_pad = false;
2640 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
2641 auto dim = sharded_padding_config.add_dimensions();
2642 const auto& wd = reshard_operand->shard_window.dimensions(i);
2643 dim->set_edge_padding_low(wd.padding_low());
2644 dim->set_edge_padding_high(wd.padding_high());
2645 dim->set_interior_padding(wd.base_dilation() - 1);
2646 if (wd.padding_low() != 0 || wd.padding_high() != 0 ||
2647 wd.base_dilation() != 1) {
2648 need_pad = true;
2649 }
2650 }
2651 auto sharded_pad = reshard_operand->sharded_input;
2652 if (need_pad) {
2653 TF_ASSIGN_OR_RETURN(auto sharded_pad_shape,
2654 ShapeInference::InferPadShape(sharded_pad->shape(),
2655 replicated_rhs->shape(),
2656 sharded_padding_config));
2657 sharded_pad = b_.AddInstruction(hlo->CreatePad(sharded_pad_shape,
2658 sharded_pad, replicated_rhs,
2659 sharded_padding_config));
2660 }
2661
2662 SetPartitionedHlo(hlo, [&]() {
2663 if (!reshard_operand->dynamic_slice_index_on_output) {
2664 return sharded_pad;
2665 }
2666 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2667 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2668 shard_shape, sharded_pad,
2669 *reshard_operand->dynamic_slice_index_on_output,
2670 shard_shape.dimensions()));
2671 });
2672 return Status::OK();
2673 }
2674
2675 Status SpmdPartitioningVisitor::HandleParameter(HloInstruction* hlo) {
2676 SetPartitionedHlo(hlo, [&]() {
2677 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
2678 auto new_param = b_.AddInstruction(HloInstruction::CreateParameter(
2679 hlo->parameter_number(), shard_shape, "param"));
2680 if (hlo->parameter_replicated_at_leaf_buffers()) {
2681 new_param->set_parameter_replicated_at_leaf_buffers(
2682 *hlo->parameter_replicated_at_leaf_buffers());
2683 }
2684 return new_param;
2685 });
2686 return Status::OK();
2687 }
2688
2689 Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
2690 if (hlo->sharding().HasUniqueDevice()) {
2691 return DefaultAction(hlo);
2692 }
2693 int64_t input_count = 1;
2694 auto per_input_sharding = hlo->sharding();
2695 if (hlo->shape().IsTuple()) {
2696 input_count = hlo->shape().tuple_shapes_size();
2697 CHECK_GT(input_count, 0);
2698 per_input_sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
2699 }
2700
2701 std::vector<PartitionedHlo> inputs;
2702 std::vector<HloInstruction*> inits;
2703 std::vector<int64> preserved_dims;
2704 for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
2705 if (!absl::c_linear_search(hlo->dimensions(), i)) {
2706 preserved_dims.push_back(i);
2707 }
2708 }
2709
2710 for (int64_t operand_id = 0; operand_id < input_count; ++operand_id) {
2711 inits.push_back(GetPartitionedHlo(hlo->operand(operand_id + input_count))
2712 .Reshard(HloSharding::Replicate())
2713 .hlo());
2714 inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id)));
2715 if (operand_id > 0) {
2716 // Make sure all operands are sharded in the same way.
2717 inputs.back() = inputs.back().Reshard(inputs[0].sharding());
2718 }
2719 if (!inputs[0].sharding().IsTileMaximal()) {
2720 inputs.back() =
2721 inputs.back().PadWithValue(inits[operand_id], /*left_padded_dims=*/{},
2722 /*skipped_dims=*/preserved_dims);
2723 }
2724 }
2725
2726 std::vector<Shape*> new_operand_shapes(input_count * 2);
2727 for (int64_t i = 0; i < input_count; ++i) {
2728 new_operand_shapes[i] = inputs[i].hlo()->mutable_shape();
2729 new_operand_shapes[i + input_count] = inits[i]->mutable_shape();
2730 }
2731 // Create the shard shape of the reduce result.
2732 TF_ASSIGN_OR_RETURN(
2733 auto reduce_shape,
2734 ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(),
2735 hlo->to_apply()->ComputeProgramShape()));
2736
2737 std::vector<HloInstruction*> input_hlos(input_count);
2738 for (int64_t i = 0; i < input_count; ++i) {
2739 input_hlos[i] = inputs[i].hlo();
2740 }
2741 auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce(
2742 reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply()));
2743
2744 SetPartitionedHlo(hlo, [&]() {
2745 HloInstruction* reduce = local_reduce;
2746 const bool reduce_sharded_dimension =
2747 !inputs[0].sharding().IsTileMaximal() &&
2748 absl::c_any_of(hlo->dimensions(), [&](int64_t i) {
2749 return inputs[0].sharding().tile_assignment().dim(i) > 1;
2750 });
2751 if (reduce_sharded_dimension) {
2752 if (inputs[0].sharding().ReplicateOnLastTileDim()) {
2753 preserved_dims.push_back(inputs[0].base_shape().rank());
2754 }
2755 if (local_reduce->shape().IsArray()) {
2756 reduce = partitioner_->AllReduceAlongShardingDims(
2757 &b_, local_reduce, inputs[0].sharding(), next_channel_id_,
2758 hlo->dimensions(), collective_ops_creator_, hlo->to_apply());
2759 } else {
2760 auto grouped =
2761 GroupShardingOnDims(inputs[0].sharding(), preserved_dims);
2762 auto grouped_state = CreatePerGroupPartitioningState(
2763 inputs[0].state(), grouped.device_groups, &b_);
2764 std::vector<HloInstruction*> all_gathered_partial_results(input_count);
2765 for (int64_t i = 0; i < input_count; ++i) {
2766 auto gte = b_.AddInstruction(HloInstruction::CreateGetTupleElement(
2767 ShapeUtil::GetTupleElementShape(reduce_shape, i), local_reduce,
2768 i));
2769 auto expanded_shape = input_hlos[i]->shape();
2770 auto all_gather_shape = input_hlos[i]->shape();
2771 for (int64_t dim : hlo->dimensions()) {
2772 expanded_shape.set_dimensions(dim, 1);
2773 all_gather_shape.set_dimensions(
2774 dim, inputs[0].sharding().tile_assignment().dim(dim));
2775 }
2776 auto reshape = b_.AddInstruction(
2777 HloInstruction::CreateReshape(expanded_shape, gte));
2778 // Replicate per group.
2779 reshape->set_sharding(grouped.sharding);
2780 all_gathered_partial_results[i] =
2781 PartitionedHlo(reshape, all_gather_shape, grouped_state)
2782 .Replicate()
2783 .hlo();
2784 }
2785 reduce = b_.AddInstruction(HloInstruction::CreateReduce(
2786 reduce_shape, all_gathered_partial_results, inits,
2787 hlo->dimensions(), hlo->to_apply()));
2788 }
2789 }
2790 auto sharding = hlo_sharding_util::RemoveShapeDimensions(
2791 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2792 inputs[0].sharding(), hlo->dimensions()),
2793 hlo->dimensions());
2794 if (local_reduce->shape().IsArray()) {
2795 reduce->set_sharding(sharding);
2796 } else {
2797 reduce->set_sharding(HloSharding::Tuple(
2798 reduce->shape(), std::vector<HloSharding>(input_count, sharding)));
2799 }
2800 return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState())
2801 .Reshard(hlo->sharding())
2802 .hlo();
2803 });
2804 return Status::OK();
2805 }
2806
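// HandleReverse reshards the operand to the reverse of the output sharding
// so that each partition can reverse its local shard, then uses a halo
// exchange (HaloExchangeToPadOnLeft) so that any padding introduced by
// uneven partitioning ends up on the left, where the reversed data expects
// it. If that halo exchange is not possible, it falls back to DefaultAction.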
2807 Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
2808 auto reverse = Cast<HloReverseInstruction>(hlo);
2809 if (reverse->sharding().IsTileMaximal()) {
2810 return DefaultAction(hlo);
2811 }
2812 auto operand = GetPartitionedHlo(reverse->operand(0))
2813 .Reshard(hlo_sharding_util::ReverseSharding(
2814 reverse->sharding(), reverse->dimensions()));
2815 auto left_padded_operand =
2816 HaloExchangeToPadOnLeft(operand, reverse->dimensions());
2817 if (!left_padded_operand) {
2818 return DefaultAction(hlo);
2819 }
2820 SetPartitionedHlo(hlo, [&] {
2821 return b_.AddInstruction(hlo->CloneWithNewOperands(
2822 left_padded_operand->shape(), {left_padded_operand}));
2823 });
2824 return Status::OK();
2825 }
2826
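// HandleWhile partitions the loop body with the while's sharding and the
// condition with a replicated (or manual, if so annotated) root, so every
// partition evaluates the same predicate and runs the same number of
// iterations. A sketch of the resulting structure, assuming a tiled sharding
// S on the loop state:
//   body: parameter and root sharded as S
//   cond: parameter sharded as S, root replicated
//   while: operates on the per-partition shard shape of S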
2827 Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
2828 const HloSharding& sharding = hlo->sharding();
2829
2830 // Shardings for the body parameter, body root, and cond parameter must be
2831 // the same, and the condition root must be replicated so that all partitions
2832 // follow the same control flow.
2833 hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding);
2834 hlo->while_body()->parameter_instruction(0)->set_sharding(sharding);
2835 const HloSharding& cond_root_sharding =
2836 hlo->while_condition()->root_instruction()->sharding();
2837 TF_RETURN_IF_ERROR(partitioner_
2838 ->PartitionComputation(hlo->while_condition(),
2839 cond_root_sharding.IsManual()
2840 ? cond_root_sharding
2841 : HloSharding::Replicate(),
2842 next_channel_id_, logger_)
2843 .status());
2844 TF_RETURN_IF_ERROR(partitioner_
2845 ->PartitionComputation(hlo->while_body(), sharding,
2846 next_channel_id_, logger_)
2847 .status());
2848 SetPartitionedHlo(hlo, [&] {
2849 return b_.AddInstruction(HloInstruction::CreateWhile(
2850 MakePartitionedShape(hlo->shape(), sharding), hlo->while_condition(),
2851 hlo->while_body(),
2852 GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo()));
2853 });
2854 return Status::OK();
2855 }
2856
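// HandleConditional keeps each branch argument in its existing sharding and
// partitions every branch computation with the conditional's sharding. The
// predicate (operand 0) is replicated, unless it is manually sharded, so
// that all partitions select the same branch.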
2857 Status SpmdPartitioningVisitor::HandleConditional(HloInstruction* hlo) {
2858 std::vector<HloInstruction*> branch_args;
2859 for (int64_t i = 0; i < hlo->branch_count(); ++i) {
2860 HloComputation* computation = hlo->branch_computation(i);
2861
2862 // Shardings of the branch computation parameter and its argument must be
2863 // the same.
2864 computation->parameter_instruction(0)->set_sharding(
2865 hlo->operand(i + 1)->sharding());
2866 branch_args.push_back(GetPartitionedHlo(hlo->operand(i + 1)).hlo());
2867 }
2868
2869 // The root of the branch computations must follow the sharding of the
2870 // conditional instruction.
2871 for (int64_t i = 0; i < hlo->branch_count(); ++i) {
2872 HloComputation* computation = hlo->branch_computation(i);
2873 TF_RETURN_IF_ERROR(partitioner_
2874 ->PartitionComputation(computation, hlo->sharding(),
2875 next_channel_id_, logger_)
2876 .status());
2877 }
2878 SetPartitionedHlo(hlo, [&] {
2879 HloInstruction* cond = GetPartitionedHlo(hlo->operand(0)).hlo();
2880 if (!hlo->operand(0)->sharding().IsManual()) {
2881 // We replicate the predicate of the conditional (the first operand) so
2882 // that all partitions follow the same control flow.
2883 cond = GetPartitionedHlo(hlo->operand(0))
2884 .Reshard(HloSharding::Replicate())
2885 .hlo();
2886 }
2887 return b_.AddInstruction(HloInstruction::CreateConditional(
2888 MakePartitionedShape(hlo->shape(), hlo->sharding()), cond,
2889 hlo->called_computations(), branch_args));
2890 });
2891 return Status::OK();
2892 }
2893
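// HandleOutfeed emits the partitioned operand directly when the shape is
// evenly partitioned. Otherwise it builds one branch computation per unique
// per-partition (non-padded) shape and selects the branch at run time.
// Illustrative example (shapes chosen for exposition): f32[5] tiled across
// 2 partitions has shard shapes f32[3] and f32[2]; since every partition has
// a distinct shape, the partition id itself (converted to S32) serves as the
// branch index. When several partitions share a shape, a constant table
// indexed by partition id supplies the branch index instead.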
2894 Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
2895 if (hlo->sharding().HasUniqueDevice()) {
2896 return HandleSingleDevice(hlo);
2897 }
2898
2899 const auto& sharding = hlo->sharding();
2900 const Shape& shape = hlo->operand(0)->shape();
2901 auto partitioned_operand =
2902 GetPartitionedHlo(hlo->operand(0)).Reshard(sharding);
2903 const auto& shard_shape = partitioned_operand.hlo()->shape();
2904 const auto& operand = partitioned_operand.hlo();
2905 auto token = GetPartitionedHlo(hlo->operand(1)).hlo();
2906
2907 if (EvenlyPartitions(shape, sharding)) {
2908 Shape outfeed_shape = operand->shape();
2909 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(hlo->outfeed_shape(),
2910 &outfeed_shape));
2911 SetPartitionedHlo(hlo, [&]() {
2912 return b_.AddInstruction(HloInstruction::CreateOutfeed(
2913 outfeed_shape, operand, token, hlo->outfeed_config()));
2914 });
2915 return Status::OK();
2916 }
2917
2918 // Create a branch for each unique partitioned shape.
2919 std::vector<Shape> per_branch_partitioned_shapes;
2920 std::vector<int32> conditional_branch_indices(num_partitions_);
2921 for (int64_t i = 0; i < num_partitions_; ++i) {
2922 auto partitioned_shape =
2923 MakeNonPaddedShapeForGivenPartition(shape, sharding, i);
2924 int64_t matching_existing_index = 0;
2925 for (; matching_existing_index < per_branch_partitioned_shapes.size();
2926 ++matching_existing_index) {
2927 if (ShapeUtil::Compatible(
2928 partitioned_shape,
2929 per_branch_partitioned_shapes[matching_existing_index])) {
2930 break;
2931 }
2932 }
2933 if (matching_existing_index < per_branch_partitioned_shapes.size()) {
2934 conditional_branch_indices[i] = matching_existing_index;
2935 } else {
2936 conditional_branch_indices[i] = per_branch_partitioned_shapes.size();
2937 per_branch_partitioned_shapes.push_back(std::move(partitioned_shape));
2938 }
2939 }
2940
2941 // Get branch index for this partition.
2942 HloInstruction* branch_index;
2943 auto state = MakePartitioningState();
2944 if (per_branch_partitioned_shapes.size() == num_partitions_) {
2945 // Use partition ID as the branch index if each partition has its own
2946 // branch.
2947 branch_index = state.partition_id;
2948 // PartitionId's output is U32 but conditional requires S32.
2949 if (branch_index->shape().element_type() != S32) {
2950 branch_index = b_.AddInstruction(HloInstruction::CreateConvert(
2951 ShapeUtil::ChangeElementType(branch_index->shape(), S32),
2952 branch_index));
2953 }
2954 } else {
2955 // Otherwise, use a constant table to look up the branch index.
2956 auto branch_index_table = b_.AddInstruction(HloInstruction::CreateConstant(
2957 LiteralUtil::CreateR1<int32>(conditional_branch_indices)));
2958 branch_index = b_.AddInstruction(HloInstruction::CreateDynamicSlice(
2959 ShapeUtil::MakeShape(S32, {1}), branch_index_table, {partition_id_},
2960 {1}));
2961 branch_index = b_.AddInstruction(HloInstruction::CreateReshape(
2962 ShapeUtil::MakeShape(S32, {}), branch_index));
2963 }
2964
2965 // Create conditional for the outfeed.
2966 std::vector<HloComputation*> branches(per_branch_partitioned_shapes.size());
2967 for (int64_t i = 0; i < branches.size(); ++i) {
2968 SpmdBuilder branch_b(absl::StrCat("outfeed_branch_", i), visiting_hlo_);
2969 // Create tuple param within the branch.
2970 auto param = branch_b.AddInstruction(HloInstruction::CreateParameter(
2971 /*parameter_number=*/0,
2972 ShapeUtil::MakeTupleShape({operand->shape(), token->shape()}),
2973 "outfeed_token_param"));
2974 auto outfeed_data = branch_b.AddInstruction(
2975 HloInstruction::CreateGetTupleElement(operand->shape(), param, 0));
2976 auto outfeed_token = branch_b.AddInstruction(
2977 HloInstruction::CreateGetTupleElement(token->shape(), param, 1));
2978 if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) {
2979 std::function<HloInstruction*(const ShapeIndex&, HloInstruction*)>
2980 slice_outfeed =
2981 [&](const ShapeIndex& index,
2982 HloInstruction* outfeed_operand) -> HloInstruction* {
2983 // Get outfeed element shape.
2984 const Shape& element_shape =
2985 ShapeUtil::GetSubshape(outfeed_data->shape(), index);
2986 // Recursively call slice_outfeed for tuple shapes.
2987 if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) {
2988 std::vector<HloInstruction*> slice_elements(
2989 element_shape.tuple_shapes_size());
2990 for (int64_t i = 0; i < slice_elements.size(); ++i) {
2991 auto sub_index = index;
2992 sub_index.push_back(i);
2993 slice_elements[i] = slice_outfeed(
2994 sub_index,
2995 branch_b.AddInstruction(HloInstruction::CreateGetTupleElement(
2996 ShapeUtil::GetSubshape(element_shape, {i}), outfeed_operand,
2997 i)));
2998 }
2999 return branch_b.AddInstruction(
3000 HloInstruction::CreateTuple(slice_elements));
3001 }
3002 // Get the slice shape.
3003 const Shape& slice_shape = ShapeUtil::GetSubshape(
3004 per_branch_partitioned_shapes[i], ShapeIndexView(index));
3005 if (ShapeUtil::Compatible(element_shape, slice_shape)) {
3006 return outfeed_operand;
3007 }
3008 // Slice out useful data.
3009 if (element_shape.IsArray()) {
3010 CHECK(slice_shape.IsArray());
3011 std::vector<int64> start_indices(slice_shape.rank(), 0);
3012 std::vector<int64> slice_strides(slice_shape.rank(), 1);
3013 return branch_b.AddInstruction(HloInstruction::CreateSlice(
3014 slice_shape, outfeed_operand, start_indices,
3015 slice_shape.dimensions(), slice_strides));
3016 }
3017 CHECK(element_shape.IsTuple());
3018 CHECK(element_shape.tuple_shapes().empty());
3019 return outfeed_operand;
3020 };
3021 outfeed_data = slice_outfeed({}, outfeed_data);
3022 }
3023 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3024 hlo->outfeed_shape(), &per_branch_partitioned_shapes[i]));
3025 branch_b.AddInstruction(HloInstruction::CreateOutfeed(
3026 per_branch_partitioned_shapes[i], outfeed_data, outfeed_token,
3027 hlo->outfeed_config()));
3028 branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
3029 }
3030 SetPartitionedHlo(hlo, [&]() {
3031 return b_.AddInstruction(HloInstruction::CreateConditional(
3032 token->shape(), branch_index, branches,
3033 std::vector<HloInstruction*>(
3034 branches.size(),
3035 b_.AddInstruction(HloInstruction::CreateTuple({operand, token})))));
3036 });
3037 return Status::OK();
3038 }
3039
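// HandleRng treats the random op according to its sharding: manual shardings
// clone the op as-is, replicated shardings run it on device 0 and then
// broadcast the result, and tiled shardings replicate the operands and let
// every partition generate its own shard (cross-partition consistency is not
// required here). With a partially replicated (last-tile-dim) sharding, the
// result is additionally replicated within each partition group.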
3040 Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
3041 if (hlo->sharding().HasUniqueDevice()) {
3042 return HandleSingleDevice(hlo);
3043 }
3044 auto clone_from_original = [&](const HloSharding& shared_sharding) {
3045 std::vector<HloInstruction*> new_operands;
3046 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
3047 new_operands.push_back(
3048 GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo());
3049 }
3050 auto clone = b_.AddInstruction(
3051 hlo->CloneWithNewOperands(hlo->shape(), new_operands));
3052 clone->set_sharding(shared_sharding);
3053 return clone;
3054 };
3055
3056 if (hlo->sharding().IsManual()) {
3057 SetPartitionedHlo(hlo,
3058 [&] { return clone_from_original(hlo->sharding()); });
3059 return Status::OK();
3060 }
3061
3062 if (hlo->sharding().IsReplicated()) {
3063 SetPartitionedHlo(hlo, [&] {
3064 // Run on a single device (0) and distribute the data to all other cores.
3065 auto clone = clone_from_original(HloSharding::AssignDevice(0));
3066 return PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
3067 .Reshard(HloSharding::Replicate())
3068 .hlo();
3069 });
3070 return Status::OK();
3071 }
3072
3073 TF_RET_CHECK(!hlo->sharding().IsTileMaximal());
3074 // Replicate the operands and run partitioned Rng on all devices.
3075 std::vector<HloInstruction*> new_operands;
3076 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
3077 new_operands.push_back(GetPartitionedHlo(hlo->operand(i))
3078 .Reshard(HloSharding::Replicate())
3079 .hlo());
3080 }
3081
3082 if (!hlo->sharding().ReplicateOnLastTileDim()) {
3083 SetPartitionedHlo(hlo, [&] {
3084 return b_.AddInstruction(HloInstruction::CreateRng(
3085 MakePartitionedShape(hlo->shape(), hlo->sharding()),
3086 hlo->random_distribution(), new_operands));
3087 });
3088 } else {
3089 std::vector<int64> group_dims(
3090 hlo->sharding().tile_assignment().num_dimensions() - 1);
3091 std::iota(group_dims.begin(), group_dims.end(), 0);
3092 auto sharding_grouped = GroupShardingOnDims(hlo->sharding(), group_dims);
3093 auto per_group_state = CreatePerGroupPartitioningState(
3094 MakePartitioningState(), sharding_grouped.device_groups, &b_);
3095 auto rng = b_.AddInstruction(HloInstruction::CreateRng(
3096 MakePartitionedShape(hlo->shape(), hlo->sharding()),
3097 hlo->random_distribution(), new_operands));
3098 rng->set_sharding(HloSharding::AssignDevice(0));
3099 SetPartitionedHlo(hlo, [&]() {
3100 return PartitionedHlo(rng, rng->shape(), per_group_state)
3101 .Replicate()
3102 .hlo();
3103 });
3104 }
3105 return Status::OK();
3106 }
3107
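// HandleReduceWindow reshards each input as a windowed input (performing any
// needed halo exchange, masking out-of-range elements with the replicated
// init value), runs reduce-window locally on every shard, and finally
// dynamic-slices the result when the local output is larger than the shard
// of the final shape. Illustrative example (chosen for exposition): an
// f32[8] input tiled across 2 partitions with a size-3, stride-1 window
// requires each partition to fetch one halo element from its neighbor before
// the local reduce-window so that the per-partition output lines up with the
// corresponding output shard.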
3108 Status SpmdPartitioningVisitor::HandleReduceWindow(HloInstruction* hlo) {
3109 if (hlo->sharding().IsTileMaximal()) {
3110 return DefaultAction(hlo);
3111 }
3112 HloReduceWindowInstruction* reduce_window =
3113 Cast<HloReduceWindowInstruction>(hlo);
3114 absl::Span<HloInstruction* const> input_arrays = reduce_window->inputs();
3115 absl::Span<HloInstruction* const> init_values = reduce_window->init_values();
3116 int64_t input_idx = 0;
3117 absl::InlinedVector<PartitionedHlo::WindowedInputShardReturnValue, 2>
3118 sharded_results;
3119 absl::InlinedVector<const Shape*, 2> sharded_input_shapes,
3120 replicated_init_shapes;
3121 absl::InlinedVector<HloInstruction*, 2> sharded_inputs, replicated_inits;
3122 for (const HloInstruction* input_array : input_arrays) {
3123 PartitionedHlo& operand = GetPartitionedHlo(input_array);
3124 // Replicate init
3125 PartitionedHlo replicated_init = GetPartitionedHlo(init_values[input_idx++])
3126 .Reshard(HloSharding::Replicate());
3127 auto resharded_operand_and_window = operand.ReshardAsWindowedInput(
3128 hlo->window(), hlo->sharding(), replicated_init.hlo());
3129 if (!resharded_operand_and_window.has_value()) {
3130 return DefaultAction(hlo);
3131 }
3132 sharded_results.push_back(resharded_operand_and_window.value());
3133 sharded_inputs.push_back(resharded_operand_and_window->sharded_input);
3134 sharded_input_shapes.push_back(&sharded_inputs.back()->shape());
3135 replicated_inits.push_back(replicated_init.hlo());
3136 replicated_init_shapes.push_back(&replicated_inits.back()->shape());
3137 }
3138 TF_ASSIGN_OR_RETURN(Shape sharded_rw_shape,
3139 ShapeInference::InferReduceWindowShape(
3140 sharded_input_shapes, replicated_init_shapes,
3141 sharded_results[0].shard_window,
3142 hlo->to_apply()->ComputeProgramShape()));
3143 HloSharding result_sharding =
3144 (hlo->shape().IsTuple())
3145 ? hlo->sharding().GetTupleSharding(hlo->shape()).ValueOrDie()
3146 : hlo->sharding();
3147 Shape shard_shape = MakePartitionedShape(hlo->shape(), result_sharding);
3148 *sharded_rw_shape.mutable_layout() = shard_shape.layout();
3149 SetPartitionedHlo(hlo, [&]() {
3150 HloInstruction* sharded_rw =
3151 b_.AddInstruction(HloInstruction::CreateReduceWindow(
3152 sharded_rw_shape, sharded_inputs, replicated_inits,
3153 sharded_results[0].shard_window, hlo->to_apply()));
3154 if (!sharded_results[0].dynamic_slice_index_on_output.has_value()) {
3155 CHECK(ShapeUtil::Compatible(shard_shape, sharded_rw->shape()))
3156 << shard_shape << " vs " << sharded_rw->shape() << "\n";
3157 return sharded_rw;
3158 }
3159 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3160 shard_shape, sharded_rw,
3161 *sharded_results[0].dynamic_slice_index_on_output,
3162 shard_shape.dimensions()));
3163 });
3164 return Status::OK();
3165 }
3166
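// HandleSelectAndScatter pads out-of-shard elements with a value that can
// never win the select (-inf or +inf, chosen from the comparison direction,
// which is why only F32/BF16 with a simple min/max-style select is handled),
// and sizes the halo exchanges so that every partition sees all windows that
// overlap its data shard. Worked example of the window bookkeeping below
// (numbers chosen for exposition): with per-shard data size 4, window size 3,
// stride 2 and low padding 1, shard j overlaps windows
//   first_window = (4*j + 1 - 3 + 2) / 2 = 2*j
//   limit_window = (4*j + 4 + 1 + 2 - 1) / 2 = 2*j + 3
// so shard 0 owns windows [0, 3) and shard 1 owns windows [2, 5); window 2
// straddles the shard boundary and is covered via the halo regions.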
3167 Status SpmdPartitioningVisitor::HandleSelectAndScatter(HloInstruction* hlo) {
3168 if (hlo->sharding().IsTileMaximal()) {
3169 return DefaultAction(hlo);
3170 }
3171 auto operand = GetPartitionedHlo(hlo->operand(0));
3172 auto source = GetPartitionedHlo(hlo->mutable_operand(1));
3173 if (hlo->sharding() != operand.sharding()) {
3174 operand = operand.Reshard(hlo->sharding());
3175 }
3176 if (hlo->sharding() != source.sharding()) {
3177 source = source.Reshard(hlo->sharding());
3178 }
3179
3180 // For F32 and BF16 types, we can use NaN padding to work around the issue with
3181 // low/high padding, since the comparison returns false for NaN inputs.
3182 if (hlo->shape().element_type() != F32 &&
3183 hlo->shape().element_type() != BF16) {
3184 return DefaultAction(hlo);
3185 }
3186
3187 auto select = hlo->called_computations()[0];
3188 auto select_root = select->root_instruction();
3189 if (select_root->opcode() != HloOpcode::kCompare ||
3190 select_root->operand(0)->opcode() != HloOpcode::kParameter ||
3191 select_root->operand(1)->opcode() != HloOpcode::kParameter ||
3192 select_root->operand(0)->parameter_number() +
3193 select_root->operand(1)->parameter_number() !=
3194 1) {
3195 return DefaultAction(hlo);
3196 }
3197
3198 float float_pad_value;
3199 if (select_root->comparison_direction() == ComparisonDirection::kGe ||
3200 select_root->comparison_direction() == ComparisonDirection::kGt) {
3201 if (select_root->operand(0)->parameter_number() == 0) {
3202 float_pad_value = -std::numeric_limits<float>::infinity();
3203 } else {
3204 float_pad_value = std::numeric_limits<float>::infinity();
3205 }
3206 } else if (select_root->comparison_direction() == ComparisonDirection::kLe ||
3207 select_root->comparison_direction() == ComparisonDirection::kLt) {
3208 if (select_root->operand(0)->parameter_number() == 0) {
3209 float_pad_value = std::numeric_limits<float>::infinity();
3210 } else {
3211 float_pad_value = -std::numeric_limits<float>::infinity();
3212 }
3213 } else {
3214 return DefaultAction(hlo);
3215 }
3216
3217 auto pad_value = b_.AddInstruction(HloInstruction::CreateConstant(
3218 hlo->shape().element_type() == BF16
3219 ? LiteralUtil::CreateR0<bfloat16>(
3220 static_cast<bfloat16>(float_pad_value))
3221 : LiteralUtil::CreateR0<float>(float_pad_value)));
3222
3223 // Replicate init
3224 auto replicated_init = GetPartitionedHlo(hlo->mutable_operand(2))
3225 .Reshard(HloSharding::Replicate());
3226
3227 auto state = MakePartitioningState();
3228 auto partition_ordinals =
3229 MakeTiledPartitionOrdinals(hlo->sharding(), state.partition_id, &b_);
3230
3231 // The first window for each dimension that overlaps with the shard area.
3232 std::vector<MultiplyAddDivideOffsetCalculation> first_window(
3233 hlo->shape().rank());
3234 // The first window for each dimension that goes beyond the shard area.
3235 std::vector<MultiplyAddDivideOffsetCalculation> limit_window(
3236 hlo->shape().rank());
3237 std::vector<OffsetCalculation> data_left_halo_sizes(hlo->shape().rank());
3238 std::vector<OffsetCalculation> data_right_halo_sizes(hlo->shape().rank());
3239 std::vector<OffsetCalculation> source_left_halo_sizes(hlo->shape().rank());
3240 std::vector<OffsetCalculation> source_right_halo_sizes(hlo->shape().rank());
3241 auto unpadded_data_shard_shape =
3242 MakePartitionedShape(hlo->shape(), hlo->sharding());
3243 auto unpadded_source_shard_shape =
3244 MakePartitionedShape(hlo->operand(1)->shape(), hlo->sharding());
3245 auto source_shard_hlo = source.hlo();
3246 auto data_shard_hlo = operand.hlo();
3247 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
3248 int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
3249 if (shard_count == 1) {
3250 continue;
3251 }
3252 // If stride > window_size, there will be gaps between windows. These gaps
3253 // will also exist in the output, so we keep them during halo exchange.
3254 //
3255 // TODO(yuanzx): This could introduce overhead if partitions start at
3256 // different offsets in a gap.
3257 auto wd = hlo->window().dimensions(i);
3258 if (wd.stride() > wd.size()) {
3259 wd.set_size(wd.stride());
3260 }
3261 // shard_size * i < stride * k - pad_low + window_size =>
3262 // k > (shard_size * i + pad_low - window_size) / stride =>
3263 // first_k == (shard_size * i + pad_low - window_size + stride) / stride
3264 first_window[i] = MultiplyAddDivideOffsetCalculation(
3265 unpadded_data_shard_shape.dimensions(i),
3266 wd.padding_low() - wd.size() + wd.stride(), wd.stride());
3267 // shard_size * (i + 1) <= stride * k - pad_low =>
3268 // k >= (shard_size * i + shard_size + pad_low) / stride =>
3269 // limit_k == (shard_size * i + shard_size + pad_low + stride - 1) /
3270 // stride
3271 limit_window[i] = MultiplyAddDivideOffsetCalculation(
3272 unpadded_data_shard_shape.dimensions(i),
3273 unpadded_data_shard_shape.dimensions(i) + wd.padding_low() +
3274 wd.stride() - 1,
3275 wd.stride());
3276 source_left_halo_sizes[i] =
3277 MultiplyAddDivideOffsetCalculation(
3278 unpadded_source_shard_shape.dimensions(i), 0, 1) -
3279 first_window[i];
3280 source_right_halo_sizes[i] =
3281 limit_window[i] - MultiplyAddDivideOffsetCalculation(
3282 unpadded_source_shard_shape.dimensions(i),
3283 unpadded_source_shard_shape.dimensions(i), 1);
3284 data_left_halo_sizes[i] =
3285 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
3286 unpadded_data_shard_shape.dimensions(i), wd.padding_low(), 1)) -
3287 OffsetCalculation(
3288 HloOpcode::kMultiply, first_window[i],
3289 MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1));
3290 data_right_halo_sizes[i] =
3291 OffsetCalculation(
3292 HloOpcode::kMultiply, limit_window[i],
3293 MultiplyAddDivideOffsetCalculation(0, wd.stride(), 1)) -
3294 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
3295 unpadded_data_shard_shape.dimensions(i),
3296 unpadded_data_shard_shape.dimensions(i) + wd.stride() +
3297 wd.padding_low() - wd.size(),
3298 1));
3299
3300 int64_t max_windows =
3301 (limit_window[i] - first_window[i]).MaxInRange(0, shard_count);
3302 auto first_window_hlo =
3303 first_window[i].Calculate(partition_ordinals[i], &b_);
3304 // Padding on the source is filled with the init value so it does not change
3305 // the data in overlapping windows.
3306 auto resharded_source = ExchangeHaloAndGetValidData(
3307 source_shard_hlo, source.base_shape(), source_left_halo_sizes[i],
3308 source_right_halo_sizes[i], 0,
3309 limit_window[i].Calculate(shard_count - 1), max_windows, i,
3310 hlo->sharding(), first_window_hlo, replicated_init.hlo(),
3311 partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
3312 if (!resharded_source) {
3313 return DefaultAction(hlo);
3314 }
3315 source_shard_hlo = *resharded_source;
3316
3317 auto offset_start_in_data =
3318 MultiplyAddDivideOffsetCalculation(wd.stride(), 0, 1)
3319 .Calculate(first_window_hlo, &b_);
3320 int64_t padded_data_size =
3321 (limit_window[i].Calculate(shard_count - 1) - 1) * wd.stride() +
3322 wd.size();
3323 int64_t data_shard_size = (max_windows - 1) * wd.stride() + wd.size();
3324 auto resharded_data = ExchangeHaloAndGetValidData(
3325 data_shard_hlo, operand.base_shape(), data_left_halo_sizes[i],
3326 data_right_halo_sizes[i], wd.padding_low(), padded_data_size,
3327 data_shard_size, i, hlo->sharding(), offset_start_in_data, pad_value,
3328 partition_ordinals[i], collective_ops_creator_, next_channel_id_, &b_);
3329 if (!resharded_data) {
3330 return DefaultAction(hlo);
3331 }
3332 data_shard_hlo = *resharded_data;
3333 }
3334
3335 Window window_on_shard = hlo->window();
3336 for (int64_t i = 0; i < window_on_shard.dimensions_size(); ++i) {
3337 int64_t shard_count = hlo->sharding().tile_assignment().dim(i);
3338 if (shard_count == 1) {
3339 continue;
3340 }
3341 auto reshard_wd = window_on_shard.mutable_dimensions(i);
3342 // The shards are already explicitly padded.
3343 reshard_wd->set_padding_low(0);
3344 reshard_wd->set_padding_high(0);
3345 }
3346
3347 auto sharded_select_and_scatter =
3348 b_.AddInstruction(HloInstruction::CreateSelectAndScatter(
3349 data_shard_hlo->shape(), data_shard_hlo, select, window_on_shard,
3350 source_shard_hlo, replicated_init.hlo(),
3351 hlo->called_computations()[1]));
3352 SetPartitionedHlo(hlo, [&]() {
3353 auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
3354 if (ShapeUtil::Compatible(sharded_select_and_scatter->shape(),
3355 shard_shape)) {
3356 return sharded_select_and_scatter;
3357 }
3358 auto zero = b_.AddInstruction(
3359 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
3360 std::vector<HloInstruction*> slice_offsets(shard_shape.rank(), zero);
3361 for (int64_t i = 0; i < window_on_shard.dimensions_size(); ++i) {
3362 if (hlo->sharding().tile_assignment().dim(i) == 1) {
3363 continue;
3364 }
3365 int64_t pad_low = hlo->window().dimensions(i).padding_low();
3366 auto left_halo_size =
3367 data_left_halo_sizes[i].Calculate(partition_ordinals[i], &b_);
3368 if (data_left_halo_sizes[i].Calculate(0) == pad_low) {
3369 slice_offsets[i] = left_halo_size;
3370 } else {
3371 auto is_shard0 = b_.AddInstruction(HloInstruction::CreateCompare(
3372 ShapeUtil::MakeShape(PRED, {}), zero, partition_ordinals[i],
3373 ComparisonDirection::kEq));
3374 auto pad_low_hlo = b_.AddInstruction(HloInstruction::CreateConstant(
3375 LiteralUtil::CreateR0<int32>(pad_low)));
3376 slice_offsets[i] = b_.AddInstruction(HloInstruction::CreateTernary(
3377 zero->shape(), HloOpcode::kSelect, is_shard0, pad_low_hlo,
3378 left_halo_size));
3379 }
3380 }
3381 return b_.AddInstruction(HloInstruction::CreateDynamicSlice(
3382 shard_shape, sharded_select_and_scatter, slice_offsets,
3383 shard_shape.dimensions()));
3384 });
3385 return Status::OK();
3386 }
3387
3388 Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
3389 std::vector<HloInstruction*> new_operands;
3390 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
3391 new_operands.push_back(
3392 GetPartitionedHlo(hlo->operand(i))
3393 .Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i}))
3394 .hlo());
3395 }
3396 SetPartitionedHlo(hlo, [&]() {
3397 return b_.AddInstruction(HloInstruction::CreateTuple(new_operands));
3398 });
3399 return Status::OK();
3400 }
3401
3402 StatusOr<bool> SpmdPartitioningVisitor::DoPartition(
3403 HloComputation* computation, const HloSharding& root_sharding,
3404 const SpmdPartitionerOptions& options) {
3405 VLOG(2) << "Partitioning computation " << computation->name() << " for "
3406 << num_replicas_ << " replicas and " << num_partitions_
3407 << " partitions";
3408 TF_RETURN_IF_ERROR(computation->Accept(this));
3409
3410 HloModule* module = computation->parent();
3411 auto new_root =
3412 GetPartitionedHlo(computation->root_instruction()).Reshard(root_sharding);
3413 auto new_computation =
3414 module->AddEmbeddedComputation(b_.Build(new_root.hlo()));
3415 TF_RETURN_IF_ERROR(
3416 DoCodeMotionForWindowedDotGeneralLoops(new_computation, options));
3417
3418 // Replace the original computation with the new SPMD computation.
3419 absl::flat_hash_map<HloComputation*, HloComputation*> replacement;
3420 replacement[computation] = new_computation;
3421 module->ReplaceComputations(replacement);
3422 return changed_;
3423 }
3424
3425 Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) {
3426 return Unimplemented(
3427 "PartitionId instruction is not supported for SPMD partitioning since "
3428 "the meaning is ambiguous -- whether the instruction is replicated or "
3429 "the data is replicated, and if the latter which data is replicated.");
3430 }
3431
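// GetDefaultCollectiveOpsCreator builds cross-partition collectives using
// global device ids, where device id = replica * num_partitions + partition.
// Illustrative example (numbers chosen for exposition): with 2 replicas and
// 4 partitions, partition subgroups {{0,1},{2,3}} expand to the replica
// groups {0,1}, {2,3}, {4,5}, {6,7}, and the collective is created with
// use_global_device_ids=true so each replica reduces only within its own
// copy of the subgroup.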
3432 SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
3433 int64_t num_replicas) {
3434 return {
3435 [](SpmdBuilder* b) {
3436 return b->AddInstruction(HloInstruction::CreatePartitionId());
3437 },
3438 [num_replicas, num_partitions](
3439 SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
3440 const std::vector<std::vector<int64>>& partition_subgroups,
3441 int64_t channel_id) {
3442 if (partition_subgroups.size() <= 1) {
3443 std::vector<ReplicaGroup> groups(num_replicas);
3444 // TODO(yuanzx): Unify subgroup definition with AllToAll.
3445 for (int64_t i = 0; i < num_replicas; ++i) {
3446 groups[i].add_replica_ids(i);
3447 }
3448 return b->AddInstruction(HloInstruction::CreateAllReduce(
3449 operand->shape(), {operand}, reduction, groups,
3450 /*constrain_layout=*/false, channel_id,
3451 /*use_global_device_ids=*/false));
3452 }
3453
3454 std::vector<ReplicaGroup> device_groups;
3455 device_groups.reserve(partition_subgroups.size() * num_replicas);
3456 for (int64_t i = 0; i < num_replicas; ++i) {
3457 for (const auto& pgroup : partition_subgroups) {
3458 device_groups.emplace_back();
3459 for (int64_t pid : pgroup) {
3460 device_groups.back().add_replica_ids(i * num_partitions + pid);
3461 }
3462 }
3463 }
3464 return b->AddInstruction(HloInstruction::CreateAllReduce(
3465 operand->shape(), {operand}, reduction, device_groups,
3466 /*constrain_layout=*/false, channel_id,
3467 /*use_global_device_ids=*/true));
3468 },
3469 [num_partitions](SpmdBuilder* b, HloInstruction* operand,
3470 std::vector<std::pair<int64, int64>>& src_dst_pairs,
3471 int64_t channel_id) {
3472 /* optimize trivial collective permute */
3473 if (src_dst_pairs.empty()) {
3474 // If the src/dst pairs are empty, then the collective permute just
3475 // initializes the output to zero.
3476 return CreateZero(operand->shape(), b);
3477 } else {
3478 // A collective-permute is a copy if all pairs are "identity" and
3479 // all partitions are listed.
3480 bool is_copy =
3481 src_dst_pairs.size() == num_partitions &&
3482 absl::c_all_of(src_dst_pairs,
3483 [](const std::pair<int64, int64>& pair) {
3484 return pair.first == pair.second;
3485 });
3486 if (is_copy) {
3487 return operand;
3488 } else {
3489 return b->AddInstruction(HloInstruction::CreateCollectivePermute(
3490 operand->shape(), operand, src_dst_pairs, channel_id));
3491 }
3492 }
3493 },
3494 [](SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
3495 const std::vector<std::vector<int64>>& partition_subgroups,
3496 int64_t channel_id, absl::optional<int64> split_dimension) {
3497 std::vector<Shape> shapes(operands.size(), operands[0]->shape());
3498 const Shape output_shape = (shapes.size() == 1)
3499 ? shapes[0]
3500 : ShapeUtil::MakeTupleShape(shapes);
3501 std::vector<ReplicaGroup> groups(partition_subgroups.size());
3502 for (int64_t i = 0; i < groups.size(); ++i) {
3503 for (int64_t id : partition_subgroups[i]) {
3504 groups[i].add_replica_ids(id);
3505 }
3506 }
3507 return b->AddInstruction(HloInstruction::CreateAllToAll(
3508 output_shape, operands, groups,
3509 /*constrain_layout=*/false, channel_id, split_dimension));
3510 },
3511 [num_replicas, num_partitions](
3512 SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
3513 const std::vector<std::vector<int64>>& partition_subgroups,
3514 int64_t channel_id, int64_t all_gather_dimension) {
3515 std::vector<ReplicaGroup> device_groups;
3516 device_groups.reserve(partition_subgroups.size() * num_replicas);
3517 for (int64_t i = 0; i < num_replicas; ++i) {
3518 for (const auto& pgroup : partition_subgroups) {
3519 device_groups.emplace_back();
3520 for (int64_t pid : pgroup) {
3521 device_groups.back().add_replica_ids(i * num_partitions + pid);
3522 }
3523 }
3524 }
3525 return b->AddInstruction(HloInstruction::CreateAllGather(
3526 ag_shape, {operand}, all_gather_dimension, device_groups,
3527 /*constrain_layout=*/false, channel_id,
3528 /*use_global_device_ids=*/true));
3529 },
3530 };
3531 }
3532
3533 SpmdPartitioner::SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
3534 SpmdPartitionerOptions options)
3535 : SpmdPartitioner(
3536 num_partitions, num_replicas, std::move(options),
3537 GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {}
3538
3539 HloInstruction* SpmdPartitioner::AllGatherShards(
3540 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3541 int64* next_channel_id, absl::Span<const int64> selected_dims,
3542 const SPMDCollectiveOpsCreator& collectives_creator) {
3543 return AllGatherShardsInternal(b, operand, sharding, next_channel_id,
3544 selected_dims, collectives_creator,
3545 /*per_dim_ag=*/true);
3546 }
3547
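// AllGatherShardsInternal reassembles data from per-partition shards along
// the selected dimensions: it prepends a size-1 dimension, all-gathers along
// that dimension (once per selected dim if per_dim_ag, otherwise once in
// total), splits the gathered dimension per tiled dim, transposes each
// gathered piece next to its partitioned dim, and reshapes back.
// Illustrative shape walk-through (chosen for exposition): a shard f32[4,16]
// with tile assignment [2,1] and selected_dims={0} becomes f32[1,4,16]
// -> all-gather -> f32[2,4,16] -> transpose/reshape -> f32[8,16].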
3548 HloInstruction* SpmdPartitioner::AllGatherShardsInternal(
3549 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3550 int64* next_channel_id, absl::Span<const int64> selected_dims,
3551 const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag) {
3552 if (selected_dims.empty()) {
3553 return operand;
3554 }
3555 CHECK(!sharding.IsTileMaximal());
3556 // Add one leading dimension to gather all partitions.
3557 std::vector<int64> shape;
3558 shape.push_back(1);
3559 for (int64_t dim : operand->shape().dimensions()) {
3560 shape.push_back(dim);
3561 }
3562 auto reshape = b->AddInstruction(HloInstruction::CreateReshape(
3563 ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand));
3564 HloInstruction* result = reshape;
3565 if (per_dim_ag) {
3566 for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
3567 if (sharding.tile_assignment().dim(*it) == 1) {
3568 continue;
3569 }
3570 auto partition_subgroups =
3571 GetPartitionGroupsForReplication(sharding, {*it});
3572 shape[0] *= partition_subgroups[0].size();
3573 result = collectives_creator.create_cross_partition_all_gather(
3574 b, result,
3575 ShapeUtil::MakeShape(operand->shape().element_type(), shape),
3576 partition_subgroups, (*next_channel_id)++,
3577 /*all_gather_dimension=*/0);
3578 }
3579 } else {
3580 auto partition_subgroups =
3581 GetPartitionGroupsForReplication(sharding, selected_dims);
3582 shape[0] *= partition_subgroups[0].size();
3583 result = collectives_creator.create_cross_partition_all_gather(
3584 b, result, ShapeUtil::MakeShape(operand->shape().element_type(), shape),
3585 partition_subgroups, (*next_channel_id)++,
3586 /*all_gather_dimension=*/0);
3587 }
3588 // If n > 1 dimensions are partitioned, split the leading dimension into n.
3589 std::vector<int64> tiled_dims;
3590 for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
3591 if (sharding.tile_assignment().dim(i) > 1 &&
3592 absl::c_linear_search(selected_dims, i)) {
3593 tiled_dims.push_back(i);
3594 }
3595 }
3596 if (tiled_dims.size() > 1) {
3597 std::vector<int64> split_dim_shape;
3598 split_dim_shape.reserve(tiled_dims.size() + operand->shape().rank());
3599 for (int64_t i : tiled_dims) {
3600 split_dim_shape.push_back(sharding.tile_assignment().dim(i));
3601 }
3602 for (int64_t dim : operand->shape().dimensions()) {
3603 split_dim_shape.push_back(dim);
3604 }
3605 result = b->AddInstruction(HloInstruction::CreateReshape(
3606 ShapeUtil::MakeShape(operand->shape().element_type(), split_dim_shape),
3607 result));
3608 }
3609 // Transpose the gathered dimensions so that each is adjacent to its
3610 // corresponding partitioned dimension.
3611 std::vector<int64> xpose_permutation(result->shape().rank());
3612 int64_t split_dims_added = 0;
3613 for (int64_t i = 0; i < xpose_permutation.size(); ++i) {
3614 if (sharding.tile_assignment().dim(i - split_dims_added) == 1 ||
3615 !absl::c_linear_search(selected_dims, i - split_dims_added)) {
3616 xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
3617 } else {
3618 xpose_permutation[i] = split_dims_added;
3619 xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
3620 split_dims_added++;
3621 i++;
3622 }
3623 }
3624 result = b->AddInstruction(HloInstruction::CreateTranspose(
3625 ShapeInference::InferTransposeShape(result->shape(), xpose_permutation)
3626 .ValueOrDie(),
3627 result, xpose_permutation));
3628 // Reshape to the desired shape.
3629 auto ag_shape = operand->shape();
3630 for (int64_t i : tiled_dims) {
3631 ag_shape.set_dimensions(
3632 i, ag_shape.dimensions(i) * sharding.tile_assignment().dim(i));
3633 }
3634 result = b->AddInstruction(HloInstruction::CreateReshape(ag_shape, result));
3635 return result;
3636 }
3637
3638 HloInstruction* SpmdPartitioner::AllReduceAlongShardingDims(
3639 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3640 int64* next_channel_id, absl::Span<const int64> selected_dims,
3641 const SPMDCollectiveOpsCreator& collectives_creator,
3642 HloComputation* reduction) {
3643 return AllReduceAlongShardingDimsInternal(
3644 b, operand, sharding, next_channel_id, selected_dims, collectives_creator,
3645 reduction, /*per_dim_ar=*/true);
3646 }
3647
3648 HloInstruction* SpmdPartitioner::AllReduceAlongShardingDimsInternal(
3649 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
3650 int64* next_channel_id, absl::Span<const int64> selected_dims,
3651 const SPMDCollectiveOpsCreator& collectives_creator,
3652 HloComputation* reduction, bool per_dim_ar) {
3653 if (!per_dim_ar) {
3654 auto partition_subgroups =
3655 GetPartitionGroupsForReplication(sharding, selected_dims);
3656 return collectives_creator.create_cross_partition_all_reduce(
3657 b, operand, reduction, partition_subgroups, (*next_channel_id)++);
3658 }
3659 auto result = operand;
3660 for (auto it = selected_dims.rbegin(); it != selected_dims.rend(); ++it) {
3661 if (sharding.tile_assignment().dim(*it) == 1) {
3662 continue;
3663 }
3664 auto partition_subgroups =
3665 GetPartitionGroupsForReplication(sharding, {*it});
3666 result = collectives_creator.create_cross_partition_all_reduce(
3667 b, result, reduction, partition_subgroups, (*next_channel_id)++);
3668 }
3669 return result;
3670 }
3671
3672 StatusOr<bool> SpmdPartitioner::PartitionComputation(
3673 HloComputation* computation, const HloSharding& root_sharding,
3674 int64* next_channel_id, SpmdLogger* logger) {
3675 auto visitor =
3676 CreateVisitor(computation, num_partitions_, num_replicas_,
3677 collective_ops_creator_, next_channel_id, logger, options_);
3678 return visitor->DoPartition(computation, root_sharding, options_);
3679 }
3680
3681 std::unique_ptr<SpmdPartitioningVisitor> SpmdPartitioner::CreateVisitor(
3682 HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
3683 const SPMDCollectiveOpsCreator& collective_ops_creator,
3684 int64* next_channel_id, SpmdLogger* logger,
3685 SpmdPartitionerOptions options) {
3686 return absl::make_unique<SpmdPartitioningVisitor>(
3687 computation, num_partitions, num_replicas, collective_ops_creator,
3688 next_channel_id, logger, std::move(options), this);
3689 }
3690
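// Run drives the whole pass: preprocess shardings and HLOs, record the entry
// parameter/output shardings on the module, flatten the call graph,
// partition the entry computation, check (or update) the entry signature,
// and clean up with DCE/TupleSimplifier/CSE. A minimal usage sketch, assuming
// a module whose instructions already carry sharding annotations and
// default-constructed options:
//   SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1,
//                               SpmdPartitionerOptions());
//   TF_ASSIGN_OR_RETURN(bool changed, partitioner.Run(module));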
3691 StatusOr<bool> SpmdPartitioner::Run(HloModule* module) {
3692 TF_RETURN_IF_ERROR(PreprocessSharding(module));
3693 TF_RETURN_IF_ERROR(PreprocessHlos(module));
3694
3695 XLA_VLOG_LINES(1, SpmdLogger::ReportBeforePartition(
3696 *module, options_.report_instruction_count));
3697
3698 // Add the parameters' and output's shardings to the module.
3699 std::vector<HloSharding> entry_params_shardings;
3700 for (int64_t i = 0; i < module->entry_computation()->num_parameters(); ++i) {
3701 auto param = module->entry_computation()->parameter_instruction(i);
3702 CHECK(param->has_sharding()) << "Missing sharding in entry parameter " << i;
3703 entry_params_shardings.push_back(param->sharding());
3704 }
3705 module->set_spmd_parameters_shardings(entry_params_shardings);
3706 auto entry_root = module->entry_computation()->root_instruction();
3707 CHECK(entry_root->has_sharding()) << "Missing sharding in entry root.";
3708 module->set_spmd_output_sharding(entry_root->sharding());
3709
3710 FlattenCallGraph flatten;
3711 TF_ASSIGN_OR_RETURN(auto changed, flatten.Run(module));
3712
3713 SpmdLogger logger(options_.report_instruction_count);
3714 auto program_shape = module->entry_computation()->ComputeProgramShape();
3715 int64_t next_channel_id = hlo_query::NextChannelId(*module);
3716 // Copy the root sharding since the partitioner visitor may temporarily change
3717 // the sharding to work around manual sharding.
3718 HloSharding root_sharding = entry_root->sharding();
3719 TF_ASSIGN_OR_RETURN(
3720 bool partition_changed,
3721 PartitionComputation(module->entry_computation(), root_sharding,
3722 &next_channel_id, &logger));
3723 changed |= partition_changed;
3724
3725 // For the entry computation, make sure that the root instruction and the
3726 // parameters preserve their signatures.
3727 auto new_program_shape = module->entry_computation()->ComputeProgramShape();
3728 if (!options_.allow_module_signature_change) {
3729 TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
3730 program_shape.result(), new_program_shape.result()))
3731 << "Result shape changed for the entry computation";
3732 TF_RET_CHECK(program_shape.parameters_size() ==
3733 new_program_shape.parameters_size())
3734 << "Parameter count changed for the entry computation";
3735 for (int64_t i = 0; i < program_shape.parameters_size(); ++i) {
3736 TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
3737 program_shape.parameters(i), new_program_shape.parameters(i)))
3738 << "Parameter shape changed for the entry computation";
3739 }
3740 } else {
3741 const auto& old_entry_layout = module->entry_computation_layout();
3742 // Shapes can change but the layout should still remain the same.
3743 for (int64_t i = 0; i < new_program_shape.parameters_size(); ++i) {
3744 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3745 old_entry_layout.parameter_shape(i),
3746 new_program_shape.mutable_parameters(i)));
3747 }
3748 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3749 old_entry_layout.result_shape(), new_program_shape.mutable_result()));
3750
3751 HloModuleConfig config = module->config();
3752 *config.mutable_entry_computation_layout() =
3753 ComputationLayout(new_program_shape, /*ignore_layouts=*/false);
3754 module->set_config(config);
3755 }
3756
3757 XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition(
3758 *module, options_.report_instruction_count));
3759 XLA_VLOG_LINES(1, logger.MakeReport());
3760
3761 if (changed) {
3762 HloPassPipeline pass("spmd-cleanup");
3763 pass.AddPass<HloDCE>();
3764 pass.AddPass<TupleSimplifier>();
3765 pass.AddPass<HloDCE>();
3766 pass.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
3767 pass.AddPass<FlattenCallGraph>();
3768 TF_RETURN_IF_ERROR(pass.Run(module).status());
3769 }
3770
3771 TF_RETURN_IF_ERROR(ClearShardingAttributes(module));
3772 return changed;
3773 }
3774
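// PreprocessSharding fills in default shardings and validates existing ones:
// side-effecting ops (other than Rng) must be explicitly sharded and may be
// replicated only when that is explicitly allowed, unannotated Rng ops are
// pinned to device 0, all other unannotated ops default to replicated, and
// tiled shardings must cover every partition. It also restricts the entry
// parameter/root shardings when the module signature may not change.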
3775 Status SpmdPartitioner::PreprocessSharding(HloModule* module) {
3776 for (HloComputation* computation : module->computations()) {
3777 for (HloInstruction* hlo : computation->instructions()) {
3778 if (hlo->HasSideEffectNoRecurse() && hlo->opcode() != HloOpcode::kRng) {
3779 TF_RET_CHECK(hlo->has_sharding())
3780 << "Side-effect HLO must have sharding: " << hlo->ToString();
3781 TF_RET_CHECK(!HasReplicatedSharding(hlo->sharding()) ||
3782 CanSideEffectingHaveReplicatedSharding(hlo))
3783 << "side-effect HLO cannot have a replicated sharding: "
3784 << hlo->ToString();
3785 }
3786
3787 // For unassigned HLOs, annotate with replicated sharding.
3788 //
3789 // Among side-effecting ops, only Rng is allowed to omit the annotation.
3790 // In that case, we currently force it to run on core 0, since we don't
3791 // support partitioning or replicating the Rng op (the values depend on
3792 // the seed provided to each device).
3793 //
3794 // TODO(hyouklee): Should we also convert single-device shardings (without
3795 // side-effects) into replicated?
3796 if (!hlo->has_sharding()) {
3797 if (hlo->opcode() == HloOpcode::kRng) {
3798 hlo->set_sharding(HloSharding::AssignDevice(0));
3799 } else {
3800 hlo->set_sharding(
3801 HloSharding::Single(hlo->shape(), HloSharding::Replicate()));
3802 }
3803 } else if (!hlo->sharding().IsTileMaximal() &&
3804 !hlo->sharding().IsManual()) {
3805 std::vector<int64> available(num_partitions_);
3806 std::iota(available.begin(), available.end(), 0);
3807 TF_RET_CHECK(num_partitions_ == hlo_sharding_util::DevicesForSharding(
3808 hlo->sharding(), available)
3809 .size())
3810 << "num_partitions:" << num_partitions_ << "\n"
3811 << "SPMD partitioner only supports tile sharding that includes all "
3812 "partitions. If you didn't add this sharding annotation in the "
3813 "model, please file a bug to XLA team.\n"
3814 << hlo->ToString();
3815 }
3816 }
3817 }
3818
3819 // Entry computation's parameter and root sharding must be either all
3820 // replicated or all on a single device.
3821 if (!options_.allow_module_signature_change) {
3822 const HloComputation* entry = module->entry_computation();
3823 TF_RET_CHECK(entry->root_instruction()->has_sharding());
3824 const HloSharding& root_sharding = entry->root_instruction()->sharding();
3825 if (!root_sharding.UniqueDevice().has_value()) {
3826 if (root_sharding.IsTuple()) {
3827 TF_RET_CHECK(absl::c_all_of(root_sharding.tuple_elements(),
3828 [](const HloSharding& s) {
3829 return s.IsReplicated() || s.IsManual();
3830 }))
3831 << "Unsupported entry root sharding: " << root_sharding.ToString();
3832
3833 } else {
3834 TF_RET_CHECK(root_sharding.IsReplicated() || root_sharding.IsManual())
3835 << "Unsupported entry root sharding: " << root_sharding.ToString();
3836 }
3837 }
3838
3839 for (const HloInstruction* param : entry->parameter_instructions()) {
3840 TF_RET_CHECK(param->has_sharding());
3841 TF_RET_CHECK(param->sharding().IsReplicated() ||
3842 param->sharding().UniqueDevice().has_value())
3843 << "Unsupported entry parameter sharding:"
3844 << param->sharding().ToString();
3845 }
3846 }
3847
3848 return Status::OK();
3849 }
3850
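// PreprocessHlos rewrites a few patterns that would otherwise trigger
// multiple halo exchanges: a pad feeding a slice is merged into a single
// pad, and certain concatenates are turned into the internal rotate-right
// custom call. Illustrative example (values chosen for exposition): for
// x = f32[8],
//   concat(slice(x, [6:8]), slice(x, [0:6]))
// is a rotate-right by 2, and concat(slice(x), x, slice(x)) is recognized as
// "pad with wrap" and lowered to a pad plus two rotates merged with selects.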
3851 Status SpmdPartitioner::PreprocessHlos(HloModule* module) {
3852 auto skip_copy_operands = [](HloInstruction* operand,
3853 bool check_single_use =
3854 true) -> HloInstruction* {
3855 while (operand->user_count() == 1 &&
3856 operand->opcode() == HloOpcode::kCopy) {
3857 operand = operand->mutable_operand(0);
3858 }
3859 if (check_single_use && operand->user_count() != 1) {
3860 return nullptr;
3861 }
3862 return operand;
3863 };
3864
3865 for (HloComputation* computation : module->computations()) {
3866 for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
3867 if (hlo->sharding().IsTileMaximal() || hlo->sharding().IsManual()) {
3868 // No need to optimize for tile-maximal or manual sharding.
3869 continue;
3870 }
3871
3872 if (hlo->opcode() == HloOpcode::kSlice) {
3873 HloInstruction* operand = skip_copy_operands(hlo->mutable_operand(0));
3874 if (operand == nullptr || operand->sharding() != hlo->sharding()) {
3875 continue;
3876 }
3877
3878 // Merge pad->slice to avoid multiple halo exchanges.
3879 if (operand->opcode() == HloOpcode::kPad) {
3880 absl::optional<PaddingConfig> merged_padding =
3881 operand->padding_config();
3882 bool may_have_multi_halo_exchanges = false;
3883 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
3884 const auto& dim = operand->padding_config().dimensions(i);
3885 if (dim.interior_padding() != 0 || hlo->slice_strides(i) != 1) {
3886 merged_padding = absl::nullopt;
3887 break;
3888 }
3889 if (hlo->sharding().tile_assignment().dim(i) != 1 &&
3890 (dim.edge_padding_low() != 0 || dim.edge_padding_high() != 0) &&
3891 hlo->shape().dimensions(i) != operand->shape().dimensions(i)) {
3892 // This dim has padding, slicing, and sharding.
3893 may_have_multi_halo_exchanges = true;
3894 }
3895
3896 auto* merged_dim = merged_padding->mutable_dimensions(i);
3897 merged_dim->set_edge_padding_low(dim.edge_padding_low() -
3898 hlo->slice_starts(i));
3899 merged_dim->set_edge_padding_high(hlo->slice_limits(i) -
3900 dim.edge_padding_low() -
3901 operand->shape().dimensions(i));
3902 }
3903 if (merged_padding.has_value() && may_have_multi_halo_exchanges) {
3904 // Rewrite to a single Pad.
3905 HloInstruction* new_pad =
3906 computation->AddInstruction(HloInstruction::CreatePad(
3907 hlo->shape(), operand->mutable_operand(0),
3908 operand->mutable_operand(1), *merged_padding));
3909 new_pad->set_metadata(operand->metadata());
3910 new_pad->set_sharding(hlo->sharding());
3911 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_pad));
3912 TF_RETURN_IF_ERROR(
3913 computation->RemoveInstructionAndUnusedOperands(hlo));
3914 }
3915 }
3916 }
3917 if (hlo->opcode() == HloOpcode::kConcatenate) {
3918 const int64_t dim = hlo->concatenate_dimension();
3919 if (hlo->sharding().tile_assignment().dim(dim) == 1) {
3920 continue;
3921 }
3922 if (hlo->operand_count() == 2) {
3923 // Find a pattern of "rotate right on one dimension":
3924 // concat(slice(input), slice(input)).
3925 HloInstruction* lhs = skip_copy_operands(hlo->mutable_operand(0));
3926 HloInstruction* rhs = skip_copy_operands(hlo->mutable_operand(1));
3927 if (lhs == nullptr || rhs == nullptr) {
3928 continue;
3929 }
3930 const int64_t amount = FindRotateRightPattern(hlo, lhs, rhs);
3931 if (amount < 0) {
3932 continue;
3933 }
3934 HloInstruction* to_rotate = lhs->mutable_operand(0);
3935 HloInstruction* rotate = computation->AddInstruction(
3936 CreateCustomCallSPMDInternal_RotateRight(to_rotate, dim, amount));
3937 rotate->set_metadata(hlo->metadata());
3938 rotate->set_sharding(hlo->sharding());
3939 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rotate));
3940 TF_RETURN_IF_ERROR(
3941 computation->RemoveInstructionAndUnusedOperands(hlo));
3942 } else if (hlo->operand_count() == 3) {
3943 // Find the pattern for "pad with wrap": concat(slice(x), x, slice(x)).
3944 // All involved values must have the same sharding.
3945 HloInstruction* lhs = skip_copy_operands(hlo->mutable_operand(0));
3946 HloInstruction* mid = skip_copy_operands(hlo->mutable_operand(1),
3947 /*check_single_use=*/false);
3948 HloInstruction* rhs = skip_copy_operands(hlo->mutable_operand(2));
3949 absl::optional<PadWithWrapPattern> pad_pattern =
3950 FindPadWithWrapPattern(hlo, lhs, mid, rhs);
3951 if (!pad_pattern) {
3952 continue;
3953 }
3954
3955 // Since the concat requires that all operands have the same size along
3956 // the non-concat dimensions, the lhs/rhs slices must be slicing along
3957 // the concat dim.
3958
3959 // Step 1: Pad the mid operand to the final size. The low padding is the
3960 // size of the lhs shape, and the high padding is the size of the rhs shape.
3961 PaddingConfig padding_config =
3962 MakeNoPaddingConfig(hlo->shape().rank());
3963 auto* padding_config_dim = padding_config.mutable_dimensions(dim);
3964 const int64_t low_pad = lhs->shape().dimensions(dim);
3965 const int64_t high_pad = rhs->shape().dimensions(dim);
3966 padding_config_dim->set_edge_padding_low(low_pad);
3967 padding_config_dim->set_edge_padding_high(high_pad);
3968 HloInstruction* zero =
3969 computation->AddInstruction(HloInstruction::CreateConstant(
3970 LiteralUtil::Zero(hlo->shape().element_type())));
3971 zero->set_sharding(HloSharding::Replicate());
3972 HloInstruction* pad =
3973 computation->AddInstruction(HloInstruction::CreatePad(
3974 hlo->shape(), mid, zero, padding_config));
3975 pad->set_metadata(hlo->metadata());
3976 pad->set_sharding(hlo->sharding());
3977
3978 // Step 2: rotate the padded value so that the lhs slice aligns to the
3979 // low of the padded size.
3980 // padded_operand = low_pad | mid | high_pad.
3981 // slice_start in padded_operand = lhs->slice_start + low_pad.
3982 // Rotate left by (lhs->slice_start + low_pad)
3983 // i.e., rotate right = padded_size - (lhs_slice_start + low_pad).
3984 const int64_t padded_size = hlo->shape().dimensions(dim);
3985 const int rotate_lhs_amount =
3986 padded_size - (pad_pattern->lhs_slice_start + low_pad);
3987 HloInstruction* rotate_lhs = computation->AddInstruction(
3988 CreateCustomCallSPMDInternal_RotateRight(pad, dim,
3989 rotate_lhs_amount));
3990 rotate_lhs->set_metadata(hlo->metadata());
3991 rotate_lhs->set_sharding(hlo->sharding());
3992
3993 auto apply_modifiers =
3994 [&](HloInstruction* inst,
3995 const std::vector<const HloInstruction*>& modifiers) {
3996 // Apply the modifiers in the reverse order.
3997 for (auto it = modifiers.crbegin(), end = modifiers.crend();
3998 it != end; ++it) {
3999 const HloInstruction* modifier = *it;
4000 // The new shape has the same element type as the modifier, but the
4001 // dimensions of inst.
4002 Shape new_shape = ShapeUtil::ChangeElementType(
4003 inst->shape(), modifier->shape().element_type());
4004 inst = computation->AddInstruction(
4005 modifier->CloneWithNewOperands(new_shape, {inst}));
4006 }
4007 return inst;
4008 };
4009 rotate_lhs = apply_modifiers(rotate_lhs, pad_pattern->lhs_modifiers);
4010
4011 // Step 3: rotate the padded value so that the rhs slice aligns to
4012 // high of the padded size.
4013 // padded_operand = low_pad | mid | high_pad.
4014 // slice_start in padded_operand = rhs->slice_start + low_pad.
4015 // slice_end in padded_operand = rhs->slice_start + low_pad +
4016 // high_pad; Rotate right by padded_size - (rhs->slice_start +
4017 // low_pad + high_pad)
4018 const int64_t rotate_rhs_amount =
4019 padded_size - (pad_pattern->rhs_slice_start + low_pad + high_pad);
4020 HloInstruction* rotate_rhs = computation->AddInstruction(
4021 CreateCustomCallSPMDInternal_RotateRight(pad, dim,
4022 rotate_rhs_amount));
4023 rotate_rhs->set_metadata(hlo->metadata());
4024 rotate_rhs->set_sharding(hlo->sharding());
4025 rotate_rhs = apply_modifiers(rotate_rhs, pad_pattern->rhs_modifiers);
4026
4027 // Now merge the 3 results using appropriate selects.
4028 const Shape iota_shape =
4029 ShapeUtil::ChangeElementType(hlo->shape(), U32);
4030 HloInstruction* iota = computation->AddInstruction(
4031 HloInstruction::CreateIota(iota_shape, dim));
4032 iota->set_metadata(hlo->metadata());
4033 iota->set_sharding(hlo->sharding());
4034
4035 struct SelectSpec {
4036 int64 limit;
4037 HloInstruction* hlo;
4038 Comparison::Direction cmp;
4039 };
4040 const std::array<SelectSpec, 2> selects = {
4041 {// All elements < low_pad come from rotate_lhs.
4042 {low_pad, rotate_lhs, Comparison::Direction::kLt},
4043 // All elements >= padded_size - high_pad come from rotate_rhs
4044 {padded_size - high_pad, rotate_rhs,
4045 Comparison::Direction::kGe}}};
4046
4047 Shape pred_shape = ShapeUtil::ChangeElementType(hlo->shape(), PRED);
4048
4049 HloInstruction* merged = pad;
4050 for (const SelectSpec& select_spec : selects) {
4051 HloInstruction* limit =
4052 computation->AddInstruction(HloInstruction::CreateConstant(
4053 LiteralUtil::CreateR0<uint32_t>(select_spec.limit)));
4054 limit->set_sharding(HloSharding::Replicate());
4055 HloInstruction* limit_bcast = computation->AddInstruction(
4056 HloInstruction::CreateBroadcast(iota_shape, limit, {}));
4057 limit_bcast->set_metadata(hlo->metadata());
4058 limit_bcast->set_sharding(hlo->sharding());
4059 HloInstruction* compare =
4060 computation->AddInstruction(HloInstruction::CreateCompare(
4061 pred_shape, iota, limit_bcast, select_spec.cmp));
4062 compare->set_metadata(hlo->metadata());
4063 compare->set_sharding(hlo->sharding());
4064 merged = computation->AddInstruction(HloInstruction::CreateTernary(
4065 hlo->shape(), HloOpcode::kSelect, compare, select_spec.hlo,
4066 merged));
4067 merged->set_metadata(hlo->metadata());
4068 merged->set_sharding(hlo->sharding());
4069 }
4070
4071 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(merged));
4072 TF_RETURN_IF_ERROR(
4073 computation->RemoveInstructionAndUnusedOperands(hlo));
4074 }
4075 }
4076 }
4077 }
4078 return Status::OK();
4079 }
4080
4081 } // namespace spmd
4082 } // namespace xla
4083