1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
17
18 #include <algorithm>
19 #include <memory>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
32 #include "tensorflow/compiler/xla/service/hlo_module.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
35 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
36 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
37 #include "tensorflow/compiler/xla/service/shape_inference.h"
38 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/compiler/xla/window_util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43
44 namespace xla {
45 namespace spmd {
46
bool HasReplicatedSharding(const HloSharding& sharding) {
48 if (sharding.IsTuple()) {
49 return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding);
50 }
51 return sharding.IsReplicated();
52 }
53
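// Creates a constant of `shape` whose elements are all equal to the scalar
// `value`, broadcasting as needed; tuple shapes are handled element by
// element.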
HloInstruction* CreateConstant(const Shape& shape, Literal value,
                               SpmdBuilder* b) {
56 if (shape.IsTuple()) {
57 std::vector<HloInstruction*> elements;
58 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
59 elements.push_back(CreateConstant(
60 ShapeUtil::GetTupleElementShape(shape, i), value.Clone(), b));
61 }
62 return b->AddInstruction(HloInstruction::CreateTuple(elements));
63 }
64
65 CHECK(
66 ShapeUtil::IsScalarWithElementType(value.shape(), shape.element_type()));
67 auto c = b->AddInstruction(HloInstruction::CreateConstant(std::move(value)));
68 return b->AddInstruction(HloInstruction::CreateBroadcast(shape, c, {}));
69 }
70
HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) {
72 if (shape.IsTuple()) {
73 std::vector<HloInstruction*> elements;
74 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
75 elements.push_back(
76 CreateZero(ShapeUtil::GetTupleElementShape(shape, i), b));
77 }
78 return b->AddInstruction(HloInstruction::CreateTuple(elements));
79 }
80
81 if (shape.IsToken()) {
82 return b->AddInstruction(HloInstruction::CreateToken());
83 }
84 auto zero = b->AddInstruction(
85 HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
86 if (shape.rank() == 0) {
87 return zero;
88 }
89 return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
90 }
91
HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b) {
93 if (shape.IsTuple()) {
94 std::vector<HloInstruction*> elements;
95 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
96 elements.push_back(
97 CreateOne(ShapeUtil::GetTupleElementShape(shape, i), b));
98 }
99 return b->AddInstruction(HloInstruction::CreateTuple(elements));
100 }
101
102 if (shape.IsToken()) {
103 return b->AddInstruction(HloInstruction::CreateToken());
104 }
105 auto one = b->AddInstruction(
106 HloInstruction::CreateConstant(LiteralUtil::One(shape.element_type())));
107 return b->AddInstruction(HloInstruction::CreateBroadcast(shape, one, {}));
108 }
109
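// Builds a scalar x + y reduction computation and adds it to `module`; for
// PRED the reduction uses logical OR instead of kAdd.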
HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
111 HloComputation::Builder sum_b("add");
112 auto x = sum_b.AddInstruction(HloInstruction::CreateParameter(
113 /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x"));
114 auto y = sum_b.AddInstruction(HloInstruction::CreateParameter(
115 /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y"));
116 if (type == PRED) {
117 sum_b.AddInstruction(HloInstruction::CreateBinary(
118 ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y));
119 } else {
120 sum_b.AddInstruction(HloInstruction::CreateBinary(
121 ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
122 }
123 HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build());
124 return reduction;
125 }
126
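// Returns true if every dimension of `shape` is evenly divisible by the
// number of tiles `sharding` assigns to it; a tile-maximal sharding counts as
// even only when it is replicated. Tuples are checked elementwise.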
bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) {
128 if (sharding.IsTuple()) {
129 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
130 if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i),
131 sharding.GetSubSharding(shape, {i}))) {
132 return false;
133 }
134 }
135 }
136
137 if (sharding.IsTileMaximal()) {
138 return sharding.IsReplicated();
139 }
140 for (int64_t i = 0; i < shape.dimensions_size(); ++i) {
141 if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) {
142 return false;
143 }
144 }
145 return true;
146 }
147
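// Returns the per-partition (shard) shape of `shape` under `sharding`,
// recursing into tuple elements.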
Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) {
149 if (sharding.IsTuple()) {
150 std::vector<Shape> subshapes;
151 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
152 subshapes.push_back(
153 MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i),
154 sharding.GetSubSharding(shape, {i})));
155 }
156 return ShapeUtil::MakeTupleShape(subshapes);
157 }
158 return sharding.TileShape(shape);
159 }
160
int64 ShapeSizeInBytes(const Shape& shape) {
162 return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
163 ShapeUtil::ElementsIn(shape);
164 }
165
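// Returns the exact, unpadded shape of the data owned by `partition_id`:
// tiled dimensions are sized from the partition's tile offset/limit, a
// partition that the sharding does not use gets zero-sized dimensions, and
// under a maximal-tile sharding non-owning partitions get an empty tuple
// shape.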
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
                                          const HloSharding& sharding,
                                          int64_t partition_id) {
169 if (sharding.IsTuple()) {
170 std::vector<Shape> subshapes;
171 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
172 subshapes.push_back(MakeNonPaddedShapeForGivenPartition(
173 ShapeUtil::GetTupleElementShape(shape, i),
174 sharding.GetSubSharding(shape, {i}), partition_id));
175 }
176 return ShapeUtil::MakeTupleShape(subshapes);
177 }
178
179 if (sharding.IsReplicated()) {
180 return shape;
181 }
182 if (sharding.IsTileMaximal()) {
183 if (partition_id == *sharding.UniqueDevice()) {
184 return shape;
185 }
186 return ShapeUtil::MakeTupleShape({});
187 }
188
189 auto partition_shape = shape;
190 std::vector<int64> tile_offset =
191 sharding.TileOffsetForDevice(shape, partition_id);
192 std::vector<int64> tile_limit =
193 sharding.TileLimitForDevice(shape, partition_id);
194 for (int64_t i = 0; i < tile_offset.size(); ++i) {
195 if (sharding.UsesDevice(partition_id)) {
196 partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]);
197 } else {
198 partition_shape.set_dimensions(i, 0);
199 }
200 }
201 return partition_shape;
202 }
203
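// Returns, for each dimension, the start offset of this partition's shard
// within the (padded) full shape as a scalar S32 instruction. The offsets are
// looked up in a table indexed by `partition_id`; e.g. with a 2x2 tile
// assignment over f32[8,8], the partition at tile indices {1,1} gets offsets
// {4, 4}. If `dims` is non-empty, only the listed dimensions may get non-zero
// offsets.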
std::vector<HloInstruction*> MakePartitionOffsets(
    const Shape& shape, const HloSharding& sharding,
    HloInstruction* partition_id, SpmdBuilder* b,
    absl::Span<const int64> dims) {
208 CHECK(!shape.IsTuple());
209
210 std::vector<std::vector<int32>> offset_arrays(shape.rank());
211 for (int64_t i = 0; i < shape.rank(); ++i) {
212 offset_arrays[i].resize(sharding.tile_assignment().num_elements());
213 }
214 auto shard_shape = MakePartitionedShape(shape, sharding);
215 sharding.tile_assignment().Each(
216 [&](absl::Span<const int64> indices, int64_t device) {
217 for (int64_t i = 0; i < shape.rank(); ++i) {
218 offset_arrays[i][device] = indices[i] * shard_shape.dimensions(i);
219 }
220 });
221 std::vector<HloInstruction*> offsets;
222 for (int64_t i = 0; i < shape.rank(); ++i) {
223 if (sharding.tile_assignment().dim(i) == 1 ||
224 (!dims.empty() && !absl::c_linear_search(dims, i))) {
225 offsets.push_back(b->AddInstruction(
226 HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
227 } else {
228 auto offset_table = b->AddInstruction(HloInstruction::CreateConstant(
229 LiteralUtil::CreateR1<int32>(offset_arrays[i])));
230 auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice(
231 ShapeUtil::MakeShape(S32, {1}), offset_table, {partition_id}, {1}));
232 offsets.push_back(b->AddInstruction(
233 HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index)));
234 }
235 }
236 return offsets;
237 }
238
std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
    const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) {
241 CHECK(!sharding.IsTileMaximal());
242 auto dimensions = sharding.tile_assignment().dimensions();
243 if (sharding.ReplicateOnLastTileDim()) {
244 dimensions.pop_back();
245 }
246 auto table_shape = ShapeUtil::MakeShape(S32, dimensions);
247 return MakePartitionOffsets(table_shape, sharding, partition_id, b);
248 }
249
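// Pads `hlo` with zeros at the high end of each dimension so that it matches
// `padded_shape`. The new instructions are added to `b` if it is non-null,
// otherwise to `computation`.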
HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape,
                           SpmdBuilder* b, HloComputation* computation) {
252 CHECK(b == nullptr || computation == nullptr);
253 if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) {
254 return hlo;
255 }
256 PaddingConfig padding_config;
257 for (int64_t i = 0; i < padded_shape.rank(); ++i) {
258 auto padding_config_dim = padding_config.add_dimensions();
259 padding_config_dim->set_edge_padding_low(0);
260 padding_config_dim->set_interior_padding(0);
261 padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
262 hlo->shape().dimensions(i));
263 }
264 auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
265 if (b == nullptr) {
266 return computation->AddInstruction(std::move(to_add));
267 }
268 return b->AddInstruction(std::move(to_add));
269 };
270 auto zero = add_hlo(HloInstruction::CreateConstant(
271 LiteralUtil::Zero(hlo->shape().element_type())));
272 return add_hlo(
273 HloInstruction::CreatePad(padded_shape, hlo, zero, padding_config));
274 }
275
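// Returns `base_shape` rounded up so that every dimension is a multiple of
// its per-shard size under `sharding`, e.g. f32[6] tiled 4 ways becomes
// f32[8].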
Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
                                          const HloSharding& sharding) {
278 if (sharding.IsTileMaximal()) {
279 return base_shape;
280 }
281 if (EvenlyPartitions(base_shape, sharding)) {
282 return base_shape;
283 }
284 auto shard_shape = MakePartitionedShape(base_shape, sharding);
285 Shape padded_base_shape = base_shape;
286 for (int64_t i = 0; i < padded_base_shape.rank(); ++i) {
287 padded_base_shape.set_dimensions(
288 i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i));
289 }
290 return padded_base_shape;
291 }
292
HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
    HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) {
295 auto padded_base_shape =
296 GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding);
297 if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) {
298 return hlo;
299 }
300 return PadToShape(hlo, padded_base_shape, b);
301 }
302
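// Given a partially replicated `partial_sharding` and a `target_sharding`,
// tries to construct a sharding with the target's tile dimensions in which
// every device stays inside its original replication group, so that the
// reshard only requires reshaping/transposing the tile assignment. Returns
// nullopt if the two shardings are not compatible in this sense.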
absl::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
    const HloSharding& partial_sharding, const HloSharding& target_sharding) {
305 if (!partial_sharding.ReplicateOnLastTileDim()) {
306 return absl::nullopt;
307 }
308 int64_t rank = partial_sharding.tile_assignment().num_dimensions() - 1;
309 int64_t target_rank = target_sharding.tile_assignment().num_dimensions() -
310 (target_sharding.ReplicateOnLastTileDim() ? 1 : 0);
311 if (target_rank != rank) {
312 return absl::nullopt;
313 }
314
315 absl::flat_hash_map<int64, int64> device_to_replication_group;
316 partial_sharding.tile_assignment().Each(
317 [&](absl::Span<const int64> indices, int64_t device) {
318 int64_t gid = 0;
319 for (int64_t i = 0; i < rank; ++i) {
320 gid *= partial_sharding.tile_assignment().dim(i);
321 gid += indices[i];
322 }
323 device_to_replication_group[device] = gid;
324 });
325
  // A dimension is expanded when target_tile_size > partial_tile_size and
  // target_tile_size % partial_tile_size == 0.
  // expand_tile_dims_indices[dim] is the index of dim among the expanded
  // dims, or -1 if dim is not expanded.
329 std::vector<int64> expand_tile_dims_indices(rank, -1);
330 // expand_tile_size = target_tile_size / partial_tile_size.
331 std::vector<int64> expand_tile_sizes;
332 int num_expand_dims = 0;
333 for (int64_t dim = 0; dim < rank; dim++) {
334 int64_t partial_tile_size = partial_sharding.tile_assignment().dim(dim);
335 int64_t target_tile_size = target_sharding.tile_assignment().dim(dim);
336 if (target_tile_size % partial_tile_size != 0 ||
337 target_tile_size < partial_tile_size) {
338 return absl::nullopt;
339 }
340
341 if (target_tile_size > partial_tile_size) {
342 expand_tile_dims_indices[dim] = num_expand_dims++;
343 expand_tile_sizes.emplace_back(target_tile_size / partial_tile_size);
344 }
345 }
346
347 // Reshape the partial replicate tile_dimensions.
348 int64_t num_target_replication = 1;
349 if (target_sharding.ReplicateOnLastTileDim()) {
350 num_target_replication =
351 target_sharding.tile_assignment().dimensions().back();
352 }
353 auto reshape_dimensions = partial_sharding.tile_assignment().dimensions();
354 int64_t num_replication = reshape_dimensions.back();
355 if (num_replication / num_target_replication != Product(expand_tile_sizes) ||
356 num_replication % num_target_replication != 0) {
357 return absl::nullopt;
358 }
359
360 reshape_dimensions.pop_back();
361 reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(),
362 expand_tile_sizes.end());
363
364 if (target_sharding.ReplicateOnLastTileDim()) {
365 reshape_dimensions.push_back(num_target_replication);
366 }
367
368 auto reshape_tile_assignment = partial_sharding.tile_assignment();
369 reshape_tile_assignment.Reshape(reshape_dimensions);
370
371 // Transpose.
372 std::vector<int64> perm;
373 perm.reserve(rank + expand_tile_sizes.size());
374 for (int64_t dim = 0; dim < rank; dim++) {
375 perm.emplace_back(dim);
376 if (expand_tile_dims_indices[dim] > -1) {
377 perm.emplace_back(expand_tile_dims_indices[dim] + rank);
378 }
379 }
380 auto transpose_sharding = hlo_sharding_util::TransposeSharding(
381 target_sharding.ReplicateOnLastTileDim()
382 ? HloSharding::PartialTile(reshape_tile_assignment)
383 : HloSharding::Tile(reshape_tile_assignment),
384 perm);
385
386 // Reshape to target shape
387 auto transpose_tile_assignment = transpose_sharding.tile_assignment();
388 transpose_tile_assignment.Reshape(
389 target_sharding.tile_assignment().dimensions());
390
391 bool groups_matching = true;
392 target_sharding.tile_assignment().Each(
393 [&](absl::Span<const int64> indices, int64_t device) {
394 if (device_to_replication_group[device] !=
395 device_to_replication_group[transpose_tile_assignment(indices)]) {
396 groups_matching = false;
397 }
398 });
399
400 if (groups_matching) {
401 return target_sharding;
402 }
403 return target_sharding.ReplicateOnLastTileDim()
404 ? HloSharding::PartialTile(transpose_tile_assignment)
405 : HloSharding::Tile(transpose_tile_assignment);
406 }
407
absl::optional<HloInstruction*> TileToPartialReplicateHaloExchange(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64>& replicate_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
414 // Source is tile sharding.
415 auto padded_src_shape =
416 GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
417 // Target is partial replicate.
418 auto padded_dst_shape =
419 GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
420 if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
421 return hlo;
422 }
423
424 auto partition_ordinals =
425 MakeTiledPartitionOrdinals(dst_sharding, partition_id, b);
426
427 auto result = hlo;
428 auto hlo_shape = hlo->shape();
429 for (auto dim : replicate_dims) {
430 int64_t dst_shard_count = dst_sharding.tile_assignment().dim(dim);
431 int64_t src_per_shard_size =
432 padded_src_shape.dimensions(dim) / dst_shard_count;
433 // Calculate per shard size using the sharding to compare if dst_sharding
434 // needs more padding at the end.
435 int64_t dst_per_shard_size =
436 padded_dst_shape.dimensions(dim) / dst_shard_count;
437
    // Skip if the src per-shard data has no redundant elements.
439 if (src_per_shard_size <= dst_per_shard_size || dst_shard_count == 1) {
440 continue;
441 }
442
    // If src_per_shard * replicate_factor > dst_per_shard, we need to
    // re-distribute the data between the shards using collective permute. For
    // example, if the dimension size is 6, sharded 4 ways in the src but 2
    // ways in the dst: the 4-way sharding has 2 elements in each shard while
    // the 2-way sharding has 3, so the last element in the first shard would
    // be sliced out and re-distribution is needed.
449 //
450 // 1. Calculate left_halo size.
451 // left-halo size is
452 // (src_per_shard_size - dst_per_shard_size) * i / replicate_factor
453 int64_t replicate_factor = src_sharding.tile_assignment().dim(dim) /
454 dst_sharding.tile_assignment().dim(dim);
455 OffsetCalculation left_halo_size_function =
456 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
457 src_per_shard_size - dst_per_shard_size, 0, replicate_factor));
458
459 // 2. Calculate right_halo size.
460 // right-halo size is 0
461 OffsetCalculation right_halo_size_function =
462 OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));
463
464 auto concat = result;
465 // 3. Halo exchange.
466 auto halo_exchange_result = ExchangeHalo(
467 result, left_halo_size_function, right_halo_size_function, dim,
468 src_sharding, collective_ops_creator, next_channel_id, b);
469
470 if (halo_exchange_result.has_value()) {
471 concat = halo_exchange_result.value();
472 } else {
473 return absl::nullopt;
474 }
475
476 // 4. Slice the valid result.
477 // Slice offset is
478 // (dst_shard_count - i - 1) *
479 // (src_per_shard_size - dst_per_shard_size)
    // i is the index in dst_sharding.
481 auto zero_s32 = b->AddInstruction(
482 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
483 OffsetCalculation start_offset_on_padded_concat_calculation =
484 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
485 dst_per_shard_size - src_per_shard_size,
486 (src_per_shard_size - dst_per_shard_size) * (dst_shard_count - 1),
487 1));
488 auto slice_shape = concat->shape();
489 slice_shape.set_dimensions(dim,
490 padded_src_shape.dimensions(dim) /
491 src_sharding.tile_assignment().dim(dim));
492 std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
493 zero_s32);
494 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
495 partition_ordinals[dim], b);
496 result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
497 slice_shape, concat, slice_offsets, slice_shape.dimensions()));
498 }
499 return result;
500 }
501
absl::optional<HloInstruction*> PadFromPartialReplicateShape(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64>& expand_tile_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
508 auto padded_src_shape =
509 GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
510 auto padded_dst_shape =
511 GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
512 if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
513 return hlo;
514 }
515
516 auto partition_ordinals =
517 MakeTiledPartitionOrdinals(src_sharding, partition_id, b);
518
519 HloInstruction* result = hlo;
520 auto zero = b->AddInstruction(HloInstruction::CreateConstant(
521 LiteralUtil::Zero(hlo->shape().element_type())));
522 std::vector<int64> expand_dims_without_halo_exchange;
  // Pad the dimensions that need halo exchange and record the padded dims
  // that won't need halo exchange.
525 for (auto dim : expand_tile_dims) {
526 int64_t src_shard_count = src_sharding.tile_assignment().dim(dim);
527 int64_t src_per_shard_size =
528 padded_src_shape.dimensions(dim) / src_shard_count;
529 // Calculate per shard size using the sharding to compare if dst_sharding
530 // needs more padding at the end.
531 int64_t dst_per_shard_size =
532 padded_dst_shape.dimensions(dim) / src_shard_count;
533
534 // If dst_sharding doesn't need more padding at the end.
535 if (src_per_shard_size >= dst_per_shard_size) {
536 continue;
537 }
    // If src sharding at this dimension is not partitioned, simply pad to
    // the desired shape.
540 if (src_shard_count == 1) {
541 expand_dims_without_halo_exchange.emplace_back(dim);
542 continue;
543 }
544
    // If dst_sharding needs more padding at the end, we need to re-distribute
    // the data between the shards using collective permute. For example, if
    // the dimension size is 6, sharded 2 ways in the src but 4 ways in the
    // dst: the 4-way sharding needs 2 zeros of padding at the end and has 2
    // elements in each shard, while the 2-way sharding has 3 elements in each
    // shard, so re-distribution is needed.
551 //
552 // 1. Calculate left_halo size.
553 // left-halo size is 0
554 OffsetCalculation left_halo_size_function =
555 OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));
556
557 // 2. Calculate right_halo size.
558 // right-halo size is D * (i + 1) - S * (i + 1) = (D - S) * i + (D - S)
559 OffsetCalculation right_halo_size_function =
560 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
561 dst_per_shard_size - src_per_shard_size,
562 dst_per_shard_size - src_per_shard_size, 1));
563
564 auto concat = result;
565 // 3. Halo exchange.
566 auto halo_exchange_result = ExchangeHalo(
567 result, left_halo_size_function, right_halo_size_function, dim,
568 src_sharding, collective_ops_creator, next_channel_id, b);
569
570 if (halo_exchange_result.has_value()) {
571 concat = halo_exchange_result.value();
572 } else {
573 return absl::nullopt;
574 }
575
576 // 4. Pad.
577 std::vector<int64> zero_padding(concat->shape().rank());
578 PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
579 pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
580 int64_t max_right_halo_size =
581 right_halo_size_function.MaxInRange(0, src_shard_count - 1);
582 pad_config.mutable_dimensions(dim)->set_edge_padding_high(std::max(
583 int64{0}, padded_dst_shape.dimensions(dim) -
584 padded_src_shape.dimensions(dim) - max_right_halo_size));
585 auto padded_concat_shape = ShapeInference::InferPadShape(
586 concat->shape(), zero->shape(), pad_config)
587 .ValueOrDie();
588 concat = b->AddInstruction(HloInstruction::CreatePad(
589 padded_concat_shape, concat, zero, pad_config));
590
591 // 5. Slice the valid result.
592 // Slice offset is (D-S) * i
593 auto zero_s32 = b->AddInstruction(
594 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
595 OffsetCalculation start_offset_on_padded_concat_calculation =
596 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
597 dst_per_shard_size - src_per_shard_size, 0, 1));
598 auto slice_shape = concat->shape();
599 slice_shape.set_dimensions(dim, dst_per_shard_size);
600 std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
601 zero_s32);
602 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
603 partition_ordinals[dim], b);
604 result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
605 slice_shape, concat, slice_offsets, slice_shape.dimensions()));
606 }
607
608 // Pad other dimensions that won't need halo exchange with a single pad.
609 if (!expand_dims_without_halo_exchange.empty()) {
610 std::vector<int64> zero_padding(result->shape().rank());
611 PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
612
613 auto padded_shape = result->shape();
614 for (auto dim : expand_dims_without_halo_exchange) {
615 pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
616 pad_config.mutable_dimensions(dim)->set_edge_padding_high(
617 padded_dst_shape.dimensions(dim) - padded_src_shape.dimensions(dim));
618 padded_shape.set_dimensions(dim, result->shape().dimensions(dim) +
619 padded_dst_shape.dimensions(dim) -
620 padded_src_shape.dimensions(dim));
621 }
622 result = b->AddInstruction(
623 HloInstruction::CreatePad(padded_shape, result, zero, pad_config));
624 }
625
626 return result;
627 }
628
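// Returns the single dimension along which `sharding` is tiled across more
// than one partition, or nullopt if the sharding is tile-maximal or tiles
// more than one dimension.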
absl::optional<int64> UniqueTiledDim(const HloSharding& sharding) {
630 if (sharding.IsTileMaximal()) {
631 return absl::nullopt;
632 }
633 int64_t dim = -1;
634 int64_t rank = sharding.ReplicateOnLastTileDim()
635 ? sharding.tile_assignment().num_dimensions() - 1
636 : sharding.tile_assignment().num_dimensions();
637 for (int64_t i = 0; i < rank; ++i) {
638 if (sharding.tile_assignment().dim(i) > 1) {
639 if (dim != -1) {
640 return absl::nullopt;
641 }
642 dim = i;
643 }
644 }
645 CHECK_NE(dim, -1);
646 return dim;
647 }
648
MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation(
    int64_t multiplier, int64_t offset, int64_t divisor)
    : multiplier_(multiplier), offset_(offset), divisor_(divisor) {
652 CHECK_GT(divisor_, 0);
653 Simplify();
654 }
655
OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-(
    const MultiplyAddDivideOffsetCalculation& other) const {
658 if (divisor_ == 1 && other.divisor_ == 1) {
659 return OffsetCalculation(MultiplyAddDivideOffsetCalculation(
660 multiplier_ - other.multiplier_, offset_ - other.offset_, 1));
661 }
662 return OffsetCalculation(HloOpcode::kSubtract, *this, other);
663 }
664
void MultiplyAddDivideOffsetCalculation::Simplify() {
666 // We could simplify the calculation when multiplier is a multiple of
667 // divisor_. However, when offset_ is not a multiple of divisor_, we must
668 // make sure that offset_ and multiplier_ are both non-negative or both
669 // non-positive. E.g., (3 * i - 1) / 3 is not equivalent to i or i - 1.
670 if (divisor_ != 1 && multiplier_ % divisor_ == 0 &&
671 (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) {
672 multiplier_ /= divisor_;
673 offset_ /= divisor_;
674 divisor_ = 1;
675 }
676 }
677
int64 MultiplyAddDivideOffsetCalculation::Calculate(
    int64_t shard_ordinal) const {
680 return (shard_ordinal * multiplier_ + offset_) / divisor_;
681 }
682
HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate(
    HloInstruction* shard_ordinal, SpmdBuilder* b) const {
685 auto scalar_shape = ShapeUtil::MakeShape(S32, {});
686 if (multiplier_ == 0) {
687 return b->AddInstruction(HloInstruction::CreateConstant(
688 LiteralUtil::CreateR0<int32>(offset_ / divisor_)));
689 }
690 HloInstruction* result = shard_ordinal;
691 if (multiplier_ != 1) {
692 result = b->AddInstruction(HloInstruction::CreateBinary(
693 scalar_shape, HloOpcode::kMultiply, shard_ordinal,
694 b->AddInstruction(HloInstruction::CreateConstant(
695 LiteralUtil::CreateR0<int32>(multiplier_)))));
696 }
697 if (offset_ != 0) {
698 auto offset = b->AddInstruction(
699 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(offset_)));
700 result = b->AddInstruction(HloInstruction::CreateBinary(
701 scalar_shape, HloOpcode::kAdd, result, offset));
702 }
703 if (divisor_ != 1) {
704 auto divisor = b->AddInstruction(
705 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(divisor_)));
706 result = b->AddInstruction(HloInstruction::CreateBinary(
707 scalar_shape, HloOpcode::kDivide, result, divisor));
708 }
709 return result;
710 }
711
int64 MultiplyAddDivideOffsetCalculation::MaxInRange(
    int64_t start_ordinal, int64_t limit_ordinal) const {
714 int64_t max = Calculate(start_ordinal);
715 for (int64_t i = start_ordinal + 1; i < limit_ordinal; ++i) {
716 max = std::max(max, Calculate(i));
717 }
718 return max;
719 }
720
OffsetCalculation& OffsetCalculation::operator=(
    const OffsetCalculation& other) {
723 opcode_ = other.opcode_;
724 copy_from_ = other.copy_from_;
725 if (opcode_ != HloOpcode::kCopy) {
726 lhs_ = absl::make_unique<OffsetCalculation>(*other.lhs_);
727 rhs_ = absl::make_unique<OffsetCalculation>(*other.rhs_);
728 }
729 return *this;
730 }
731
bool OffsetCalculation::IsConstant() const {
733 if (opcode_ == HloOpcode::kCopy) {
734 return copy_from_.IsConstant();
735 }
736 if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) {
737 return true;
738 }
739 return lhs_->IsConstant() && rhs_->IsConstant();
740 }
741
OffsetCalculation OffsetCalculation::operator-(
    const OffsetCalculation& other) const {
744 if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) {
745 return copy_from_ - other.copy_from_;
746 }
747 return OffsetCalculation(HloOpcode::kSubtract, *this, other);
748 }
749
bool OffsetCalculation::operator==(const OffsetCalculation& other) const {
751 if (opcode_ != other.opcode_) {
752 return false;
753 }
754 if (opcode_ == HloOpcode::kCopy) {
755 return copy_from_ == other.copy_from_;
756 }
757 return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_;
758 }
759
int64 OffsetCalculation::Calculate(int64_t shard_ordinal) const {
761 switch (opcode_) {
762 case HloOpcode::kCopy:
763 return copy_from_.Calculate(shard_ordinal);
764 case HloOpcode::kSubtract:
765 return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal);
766 case HloOpcode::kMultiply:
767 return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal);
768 default:
769 LOG(FATAL) << "Should not happen";
770 }
771 }
772
HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal,
                                             SpmdBuilder* b) const {
775 if (opcode_ == HloOpcode::kCopy) {
776 return copy_from_.Calculate(shard_ordinal, b);
777 }
778 auto lhs = lhs_->Calculate(shard_ordinal, b);
779 auto rhs = rhs_->Calculate(shard_ordinal, b);
780 return b->AddInstruction(
781 HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs));
782 }
783
int64 OffsetCalculation::MaxInRange(int64_t start_ordinal,
                                    int64_t limit_ordinal) const {
786 if (IsConstant()) {
787 return Calculate(start_ordinal);
788 }
789 if (opcode_ == HloOpcode::kCopy) {
790 return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1));
791 }
792 int64_t max = Calculate(start_ordinal);
793 for (int64_t i = start_ordinal + 1; i < limit_ordinal; ++i) {
794 max = std::max(max, Calculate(i));
795 }
796 return max;
797 }
798
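// Exchanges halos along `dim` using collective permutes and returns the
// concatenation of (left halo, valid part of the local shard, right halo),
// where the halo sizes are given by the halo-size functions evaluated at each
// shard ordinal. Returns nullopt when the requested halos are too large for
// this exchange pattern to handle.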
absl::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function, int64_t dim,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b) {
805 int64_t input_shard_size = hlo->shape().dimensions(dim);
806 int64_t shard_count = target.tile_assignment().dim(dim);
807
808 std::vector<HloInstruction*> concat_pieces;
809
810 int64_t max_left_halo_size =
811 left_halo_size_function.MaxInRange(1, shard_count);
812 int64_t max_right_halo_size =
813 right_halo_size_function.MaxInRange(0, shard_count - 1);
814 if (max_left_halo_size + max_right_halo_size + input_shard_size >=
815 input_shard_size * shard_count &&
816 (max_left_halo_size > input_shard_size ||
817 max_right_halo_size > input_shard_size)) {
818 return absl::nullopt;
819 }
820 // Since max halo sizes could be negative, we only need to include data within
821 // certain bounds. Useful region is [left_bound, right_bound).
822 const int64_t left_bound =
823 -left_halo_size_function.MaxInRange(0, shard_count);
824 const int64_t right_bound =
825 input_shard_size + right_halo_size_function.MaxInRange(0, shard_count);
826 if (left_bound >= right_bound) {
827 return absl::nullopt;
828 }
829 // Left halo.
830 for (int64_t i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1;
831 i >= 0 && (-i - 1) * input_shard_size < right_bound; --i) {
832 std::vector<std::pair<int64, int64>> source_target_pairs;
833 target.tile_assignment().Each(
834 [&](absl::Span<const int64> indices, int64_t device) {
835 if (indices[dim] > i) {
836 std::vector<int64> source_indices(indices.begin(), indices.end());
837 source_indices[dim] -= i + 1;
838 source_target_pairs.emplace_back(
839 target.tile_assignment()(source_indices), device);
840 }
841 });
842 int64_t halo_size_including_skips =
843 std::min(max_left_halo_size - input_shard_size * i, input_shard_size);
844 int64_t halo_right_skips =
845 std::max<int64>(-i * input_shard_size - right_bound, 0);
846 int64_t halo_size = halo_size_including_skips - halo_right_skips;
847 auto halo_shape = hlo->shape();
848 auto source_halo_slice = hlo;
849 if (halo_size != hlo->shape().dimensions(dim)) {
850 halo_shape.set_dimensions(dim, halo_size);
851 std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
852 halo_start_indices[dim] =
853 hlo->shape().dimensions(dim) - halo_size_including_skips;
854 std::vector<int64> halo_limit_indices(hlo->shape().dimensions().begin(),
855 hlo->shape().dimensions().end());
856 halo_limit_indices[dim] -= halo_right_skips;
857 std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
858 source_halo_slice = b->AddInstruction(
859 HloInstruction::CreateSlice(halo_shape, hlo, halo_start_indices,
860 halo_limit_indices, halo_slice_strides));
861 }
862 auto left_halo =
863 collective_ops_creator.create_cross_partition_collective_permute(
864 b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
865 concat_pieces.push_back(left_halo);
866 }
867
868 if (left_bound < input_shard_size && right_bound > 0) {
869 int64_t self_start = std::max<int64>(0, left_bound);
870 int64_t self_limit = std::min<int64>(input_shard_size, right_bound);
871 if (self_start == 0 && self_limit == input_shard_size) {
872 concat_pieces.push_back(hlo);
873 } else {
874 auto self_shape = hlo->shape();
875 self_shape.set_dimensions(dim, self_limit - self_start);
876 std::vector<int64> start_indices(self_shape.rank(), 0);
877 start_indices[dim] = self_start;
878 std::vector<int64> limit_indices(hlo->shape().dimensions().begin(),
879 hlo->shape().dimensions().end());
880 limit_indices[dim] = self_limit;
881 std::vector<int64> slice_strides(self_shape.rank(), 1);
882 concat_pieces.push_back(b->AddInstruction(HloInstruction::CreateSlice(
883 self_shape, hlo, start_indices, limit_indices, slice_strides)));
884 }
885 }
886
887 int64_t skipped_right_halos =
888 std::min<int64>(std::max<int64>(left_bound - input_shard_size, 0),
889 std::max<int64>(max_right_halo_size, 0)) /
890 input_shard_size;
891 // Right halo.
892 for (int64_t i = skipped_right_halos;
893 i < CeilOfRatio(max_right_halo_size, input_shard_size); ++i) {
894 std::vector<std::pair<int64, int64>> source_target_pairs;
895 target.tile_assignment().Each(
896 [&](absl::Span<const int64> indices, int64_t device) {
897 if (indices[dim] > i) {
898 std::vector<int64> target_indices(indices.begin(), indices.end());
899 target_indices[dim] -= i + 1;
900 source_target_pairs.emplace_back(
901 device, target.tile_assignment()(target_indices));
902 }
903 });
904 int64_t halo_size_including_skips =
905 std::min(max_right_halo_size - input_shard_size * i, input_shard_size);
906 int64_t halo_left_skips =
907 std::max<int64>(left_bound - (i + 1) * input_shard_size, 0);
908 int64_t halo_size = halo_size_including_skips - halo_left_skips;
909 auto halo_shape = hlo->shape();
910 HloInstruction* source_halo_slice = hlo;
911 if (halo_size != halo_shape.dimensions(dim)) {
912 halo_shape.set_dimensions(dim, halo_size);
913 std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
914 halo_start_indices[dim] = halo_left_skips;
915 std::vector<int64> halo_limit_indices(halo_shape.dimensions().begin(),
916 halo_shape.dimensions().end());
917 halo_limit_indices[dim] += halo_left_skips;
918 std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
919 source_halo_slice = b->AddInstruction(
920 HloInstruction::CreateSlice(halo_shape, hlo, halo_start_indices,
921 halo_limit_indices, halo_slice_strides));
922 }
923 auto right_halo =
924 collective_ops_creator.create_cross_partition_collective_permute(
925 b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
926 concat_pieces.push_back(right_halo);
927 }
928
929 auto concat = concat_pieces[0];
930 // Concat with halos/padding.
931 if (concat_pieces.size() > 1) {
932 auto concat_shape = hlo->shape();
933 int64_t concat_dim_size = 0;
934 for (auto piece : concat_pieces) {
935 concat_dim_size += piece->shape().dimensions(dim);
936 }
937 concat_shape.set_dimensions(dim, concat_dim_size);
938 concat = b->AddInstruction(
939 HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim));
940 }
941
942 return concat;
943 }
944
absl::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo,
    std::vector<OffsetCalculation> left_halo_size_functions,
    std::vector<OffsetCalculation> right_halo_size_functions,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b) {
952 CHECK(left_halo_size_functions.size() == hlo->shape().rank());
953 CHECK(right_halo_size_functions.size() == hlo->shape().rank());
954
955 HloInstruction* visiting_hlo = hlo;
956 for (int dim = 0; dim < hlo->shape().rank(); ++dim) {
957 auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim],
958 right_halo_size_functions[dim], dim, target,
959 collective_ops_creator, next_channel_id, b);
960 if (!concat) {
961 return absl::nullopt;
962 }
963 visiting_hlo = *concat;
964 }
965 return visiting_hlo;
966 }
967
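// Performs the halo exchange along `dim`, then pads with `pad_value` and
// dynamic-slices a per-shard window of size `shard_size_with_halo`. If
// `mask_invalid_region` is true, elements that fall in the explicit left
// padding or in the right padding of the full shape are replaced with
// `pad_value`.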
absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
    HloInstruction* hlo, const Shape& base_shape,
    const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function,
    int64_t explicit_left_padding_on_full_shape, int64_t padded_full_shape_size,
    int64_t shard_size_with_halo, int64_t dim, const HloSharding& target,
    HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
    HloInstruction* partition_ordinal,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) {
978 auto halo_exchange_result =
979 ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim,
980 target, collective_ops_creator, next_channel_id, b);
981 if (!halo_exchange_result) {
982 return absl::nullopt;
983 }
984 auto concat = *halo_exchange_result;
985 int64_t shard_count = target.tile_assignment().dim(dim);
986 int64_t max_left_halo_size =
987 left_halo_size_function.MaxInRange(1, shard_count);
988
989 // Now we determine if we need extra padding after the concat.
990 //
991 // The max of halo size or the first shard's explicit left padding.
992 int64_t max_left_halo_or_padding_size =
993 std::max(max_left_halo_size, explicit_left_padding_on_full_shape);
994 // The calculation that returns the dynamic slice index for a shard on the
995 // padded concat, which is the difference between
996 // max_left_halo_or_padding_size and its left halo size.
997 auto start_offset_on_padded_concat_calculation =
998 OffsetCalculation(MultiplyAddDivideOffsetCalculation(
999 0, max_left_halo_or_padding_size, 1)) -
1000 left_halo_size_function;
1001
1002 // See if we need to pad the concat before dynamic slice.
1003 int64_t extra_left_padding =
1004 std::max(int64{0}, max_left_halo_or_padding_size -
1005 std::max(int64{0}, max_left_halo_size));
1006 int64_t extra_right_padding =
1007 start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) +
1008 shard_size_with_halo - concat->shape().dimensions(dim) -
1009 extra_left_padding;
1010 extra_right_padding = std::max(int64{0}, extra_right_padding);
1011 if (extra_left_padding > 0 || extra_right_padding > 0) {
1012 PaddingConfig padding_config;
1013 auto padded_concat_shape = concat->shape();
1014 for (int64_t i = 0; i < base_shape.rank(); ++i) {
1015 auto padding_config_dim = padding_config.add_dimensions();
1016 padding_config_dim->set_interior_padding(0);
1017 padding_config_dim->set_edge_padding_low(0);
1018 padding_config_dim->set_edge_padding_high(0);
1019 if (i != dim) {
1020 continue;
1021 }
1022 padding_config_dim->set_edge_padding_low(extra_left_padding);
1023 padding_config_dim->set_edge_padding_high(extra_right_padding);
1024 padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) +
1025 extra_left_padding +
1026 extra_right_padding);
1027 }
1028 concat = b->AddInstruction(HloInstruction::CreatePad(
1029 padded_concat_shape, concat, pad_value, padding_config));
1030 }
1031
1032 auto valid_slice = concat;
1033 if (shard_size_with_halo != concat->shape().dimensions(dim)) {
1034 // Concat is bigger than the shard shape, so we need a dynamic slice.
1035 CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim));
1036 auto slice_shape = concat->shape();
1037 slice_shape.set_dimensions(dim, shard_size_with_halo);
1038
1039 if (left_halo_size_function.IsConstant() &&
1040 left_halo_size_function.Calculate(0) ==
1041 explicit_left_padding_on_full_shape) {
1042 std::vector<int64> start_indices(slice_shape.rank(), 0);
1043 std::vector<int64> strides(slice_shape.rank(), 1);
1044 valid_slice = b->AddInstruction(
1045 HloInstruction::CreateSlice(slice_shape, concat, start_indices,
1046 slice_shape.dimensions(), strides));
1047 } else {
1048 auto zero = b->AddInstruction(
1049 HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
1050 std::vector<HloInstruction*> slice_offsets(base_shape.rank(), zero);
1051 slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
1052 partition_ordinal, b);
1053 valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice(
1054 slice_shape, concat, slice_offsets, slice_shape.dimensions()));
1055 }
1056 }
1057
1058 if (!mask_invalid_region) {
1059 return valid_slice;
1060 }
1061
1062 int64_t total_right_padding = padded_full_shape_size -
1063 base_shape.dimensions(dim) -
1064 explicit_left_padding_on_full_shape;
1065 // Mask off garbage data due to uneven partition or low/high padding.
1066 if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) {
1067 auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32);
1068 auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
1069 auto broadcast_start_index_in_padded_shape =
1070 b->AddInstruction(HloInstruction::CreateBroadcast(
1071 index_shape, offset_on_padded_shape, {}));
1072 auto index_in_padded_shape = b->AddInstruction(
1073 HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota,
1074 broadcast_start_index_in_padded_shape));
1075 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
1076 std::vector<HloInstruction*> predicates;
1077 if (explicit_left_padding_on_full_shape > 0) {
1078 auto valid_index_start =
1079 b->AddInstruction(HloInstruction::CreateBroadcast(
1080 index_shape,
1081 b->AddInstruction(
1082 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1083 explicit_left_padding_on_full_shape))),
1084 {}));
1085 predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
1086 mask_shape, index_in_padded_shape, valid_index_start,
1087 ComparisonDirection::kGe)));
1088 }
1089 if (total_right_padding > 0) {
1090 auto valid_index_limit =
1091 b->AddInstruction(HloInstruction::CreateBroadcast(
1092 index_shape,
1093 b->AddInstruction(
1094 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1095 base_shape.dimensions(dim) +
1096 explicit_left_padding_on_full_shape))),
1097 {}));
1098 predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
1099 mask_shape, index_in_padded_shape, valid_index_limit,
1100 ComparisonDirection::kLt)));
1101 }
1102 CHECK(!predicates.empty());
1103 auto is_valid =
1104 predicates.size() == 2
1105 ? b->AddInstruction(HloInstruction::CreateBinary(
1106 mask_shape, HloOpcode::kAnd, predicates[0], predicates[1]))
1107 : predicates[0];
1108 auto masking_value = b->AddInstruction(
1109 HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {}));
1110 valid_slice = b->AddInstruction(
1111 HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect,
1112 is_valid, valid_slice, masking_value));
1113 }
1114 return valid_slice;
1115 }
1116
HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
                                        absl::Span<const int64> dims) {
1119 if (original.sharding().IsTileMaximal()) {
1120 return original.hlo();
1121 }
1122 // Create a window config to halo exchange for unevenly partitioned reverse
1123 // dimensions.
1124 Window window;
1125 for (int64_t i = 0; i < original.base_shape().rank(); ++i) {
1126 WindowDimension* dim = window.add_dimensions();
1127 dim->set_size(1);
1128 dim->set_stride(1);
1129 dim->set_window_dilation(1);
1130 dim->set_window_reversal(false);
1131 int64_t low_padding = 0;
1132 if (absl::c_linear_search(dims, i)) {
1133 low_padding =
1134 RoundUpToNearest(original.base_shape().dimensions(i),
1135 original.sharding().tile_assignment().dim(i)) -
1136 original.base_shape().dimensions(i);
1137 }
1138 dim->set_padding_low(low_padding);
1139 dim->set_padding_high(0);
1140 dim->set_base_dilation(1);
1141 }
1142
1143 auto reshard_window = original.ReshardAsWindowedInput(
1144 window, original.sharding(),
1145 CreateZero(ShapeUtil::MakeShape(original.base_shape().element_type(), {}),
1146 original.state().b),
1147 /*mask_invalid_region=*/false);
1148 if (!reshard_window.has_value()) {
1149 return nullptr;
1150 }
1151 CHECK(!reshard_window->dynamic_slice_index_on_output.has_value());
1152 return reshard_window->sharded_input;
1153 }
1154
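// Returns true if `comp` matches the NaN-safe greater-than comparator pattern
// used by TopK-style sorts, where f32 (or bf16-converted-to-f32) keys are
// bitcast to totally ordered integers before comparison. A kSelect root that
// breaks ties by index is also accepted.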
bool IsNanSafeGt(HloComputation* comp) {
1156 namespace m = match;
1157 auto match_bitcast_f32 = [](int64_t parameter_number) {
1158 auto param = m::Parameter(parameter_number)
1159 .WithShape(m::Shape().WithElementType(F32));
1160 auto param_s32 =
1161 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
1162 auto param_u32 =
1163 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
1164 return m::Select(
1165 m::Lt(param_s32, m::ConstantScalar(0)),
1166 m::BitcastConvert(
1167 m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
1168 param_u32))
1169 .WithShape(m::Shape().WithElementType(S32)),
1170 param_s32);
1171 };
1172 auto match_bitcast_bf16 = [](int64_t parameter_number) {
1173 auto param = m::Convert(m::Parameter(parameter_number)
1174 .WithShape(m::Shape().WithElementType(BF16)))
1175 .WithShape(m::Shape().WithElementType(F32));
1176 auto param_s32 =
1177 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
1178 auto param_u32 =
1179 m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
1180 return m::Select(
1181 m::Lt(param_s32, m::ConstantScalar(0)),
1182 m::BitcastConvert(
1183 m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
1184 param_u32))
1185 .WithShape(m::Shape().WithElementType(S32)),
1186 param_s32);
1187 };
  // If the root instruction is a kSelect, the comparator falls back to
  // comparing indices when the values are equal.
1189 if (comp->root_instruction()->opcode() == HloOpcode::kSelect) {
1190 return Match(comp->root_instruction()->operand(2),
1191 m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
1192 Match(comp->root_instruction()->operand(2),
1193 m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
1194 }
1195 return Match(comp->root_instruction(),
1196 m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
1197 Match(comp->root_instruction(),
1198 m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
1199 }
1200
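// If `hlo` is a sort-based TopK pattern (a sort of (data, iota) whose users
// slice the first k elements along the sort dimension), returns k, provided
// that the sort dimension is the only partitioned dimension and k is smaller
// than the per-partition size; otherwise returns nullopt.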
absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo) {
1202 HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
1203 if (sort == nullptr || sort->operand_count() != 2) {
1204 return absl::nullopt;
1205 }
1206 if (!IsNanSafeGt(sort->to_apply())) {
1207 return absl::nullopt;
1208 }
1209 HloInstruction* data = sort->mutable_operand(0);
1210 HloIotaInstruction* iota =
1211 DynCast<HloIotaInstruction>(sort->mutable_operand(1));
1212 const PrimitiveType element_type = data->shape().element_type();
1213 if (iota == nullptr || iota->shape().element_type() != S32 ||
1214 iota->opcode() != HloOpcode::kIota ||
1215 iota->iota_dimension() != sort->sort_dimension()) {
1216 return absl::nullopt;
1217 }
1218
1219 const int64_t sort_dim = sort->sort_dimension();
1220
1221 if (element_type != F32 && element_type != BF16 && element_type != S32 &&
1222 element_type != U32) {
1223 return absl::nullopt;
1224 }
1225
1226 bool supported = true;
1227 absl::optional<int64> k;
1228 for (HloInstruction* gte : sort->users()) {
1229 if (gte->opcode() != HloOpcode::kGetTupleElement) {
1230 supported = false;
1231 break;
1232 }
1233
1234 const HloInstruction* slice = gte->users()[0];
1235 if (slice->opcode() != HloOpcode::kSlice) {
      // A non-slice user means we are not doing a TopK.
1237 supported = false;
1238 break;
1239 }
1240 if (absl::c_any_of(slice->slice_starts(), [](int x) { return x != 0; }) ||
1241 absl::c_any_of(slice->slice_strides(), [](int x) { return x != 1; })) {
      // Strided slices or slices that do not start at 0 aren't supported.
1243 supported = false;
1244 break;
1245 }
1246 for (int64_t dim = 0; dim < data->shape().dimensions_size(); dim++) {
1247 if (dim == sort_dim) {
1248 continue;
1249 }
1250 if (slice->slice_limits(dim) !=
1251 slice->operand(0)->shape().dimensions(dim)) {
1252 // Slicing along the other dimension isn't supported.
1253 supported = false;
1254 break;
1255 }
1256 }
1257 if (!k.has_value()) {
1258 k = slice->slice_limits(sort_dim);
1259 } else if (k != slice->slice_limits(sort_dim)) {
1260 // Different k for the different operands isn't supported.
1261 supported = false;
1262 break;
1263 }
1264 }
1265 if (k == absl::nullopt || !supported) {
1266 return absl::nullopt;
1267 }
1268
1269 // Only support when sort dim is sharded.
1270 if (!data->has_sharding()) {
1271 return absl::nullopt;
1272 }
1273 const HloSharding& sharding = sort->operand(0)->sharding();
1274
1275 if (sharding.IsTileMaximal()) {
1276 return absl::nullopt;
1277 }
1278
1279 // Check if partitioned at sort dimension.
1280 for (int64_t dim = 0; dim < sort->shape().tuple_shapes(0).dimensions_size();
1281 ++dim) {
1282 if (sharding.tile_assignment().dim(dim) > 1) {
1283 if (dim != sort_dim) {
1284 return absl::nullopt;
1285 }
1286 }
1287 }
1288
1289 // Checks if partition size is smaller than k.
1290 const int64_t shard_count = sharding.tile_assignment().dim(sort_dim);
1291
1292 if (shard_count <= 1) {
1293 return absl::nullopt;
1294 }
1295
1296 const int64_t input_size = hlo->operand(0)->shape().dimensions(sort_dim);
1297 const int64_t per_partition_size = CeilOfRatio(input_size, shard_count);
1298
1299 if (k.value() >= per_partition_size) {
1300 return absl::nullopt;
1301 }
1302
1303 return k;
1304 }
1305
1306 // Slice first k elements from sort_dim.
HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
                            int64_t slice_dim, int64_t k) {
1309 const Shape& hlo_shape = hlo->shape();
1310 auto hlo_dims = hlo_shape.dimensions();
1311 std::vector<int64> start_indices(hlo_shape.dimensions_size(), 0);
1312 std::vector<int64> limit_indices(hlo_dims.begin(), hlo_dims.end());
1313 std::vector<int64> strides(hlo_shape.dimensions_size(), 1);
1314 limit_indices[slice_dim] = k;
1315 auto output_shape = hlo_shape;
1316 output_shape.set_dimensions(slice_dim, k);
1317 return builder->AddInstruction(HloInstruction::CreateSlice(
1318 output_shape, hlo, start_indices, limit_indices, strides));
1319 }
1320
// Returns the number of shards along the given dimension (1 if not sharded).
int64 ShardCountAtDim(const HloSharding& sharding, int64_t dim) {
1323 if (sharding.IsTileMaximal()) {
1324 return 1;
1325 }
1326 return sharding.tile_assignment().dim(dim);
1327 }
1328
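// Decides whether resharding from `source` to `target` can be expressed as a
// sequence of all-to-all operations, each swapping the partition counts of a
// pair of dimensions. Returns the ordered dimension pairs to exchange, or
// nullopt if no such sequence exists.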
absl::optional<std::vector<std::pair<int64, int64>>>
GetReshardAllToAllSourceTargetDims(const HloSharding& source,
                                   const HloSharding& target) {
1332 if (source.IsTileMaximal() || target.IsTileMaximal() ||
1333 source.tile_assignment().num_dimensions() !=
1334 target.tile_assignment().num_dimensions() ||
1335 source.NumTiles() != target.NumTiles()) {
1336 return absl::nullopt;
1337 }
  // For dimensions that have different partition counts on source and target,
  // record a map from partition count to the dimension indices with that
  // count.
1340 std::map<int64, std::vector<int64>> source_size_to_dim;
1341 std::map<int64, std::vector<int64>> target_size_to_dim;
1342 for (int64_t i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
1343 if (source.tile_assignment().dim(i) == target.tile_assignment().dim(i)) {
1344 continue;
1345 }
1346 source_size_to_dim[source.tile_assignment().dim(i)].push_back(i);
1347 target_size_to_dim[target.tile_assignment().dim(i)].push_back(i);
1348 }
1349   // To reshard via AllToAll, source_size_to_dim and target_size_to_dim must
1350   // contain the same partition counts, with the same number of dims per count.
1351 if (source_size_to_dim.empty() ||
1352 source_size_to_dim.size() != target_size_to_dim.size()) {
1353 return absl::nullopt;
1354 }
1355 for (const auto& entry : source_size_to_dim) {
1356 auto target_it = target_size_to_dim.find(entry.first);
1357 if (target_it == target_size_to_dim.end() ||
1358 target_it->second.size() != entry.second.size()) {
1359 return absl::nullopt;
1360 }
1361 }
1362 std::vector<std::pair<int64, int64>> result;
1363 auto remove_entry = [](int64_t size, int64_t dim,
1364 std::map<int64, std::vector<int64>>& size_to_dim) {
1365 size_to_dim[size].erase(
1366 std::remove_if(size_to_dim[size].begin(), size_to_dim[size].end(),
1367 [dim](int64_t a) { return a == dim; }),
1368 size_to_dim[size].end());
1369 if (size_to_dim[size].empty()) {
1370 size_to_dim.erase(size);
1371 }
1372 };
1373 // Find one pair of dimensions to swap at a time.
1374 while (!source_size_to_dim.empty()) {
1375 int64_t source_size = source_size_to_dim.begin()->first;
1376 int64_t i = source_size_to_dim.begin()->second.back();
1377 int64_t target_i_size = target.tile_assignment().dim(i);
1378 if (target_i_size == source_size) {
1379 remove_entry(source_size, i, source_size_to_dim);
1380 remove_entry(source_size, i, target_size_to_dim);
1381 continue;
1382 }
1383 auto j_it = source_size_to_dim[target_i_size].begin();
1384 int64_t j = *j_it;
1385 if (source_size == 1) {
1386 // If possible, find a j where the target partition count is not one, so
1387 // that when we swap, the resulting size-1 dimension will still be useful
1388 // to other dimensions.
1389 while (target.tile_assignment().dim(j) == 1) {
1390 if (++j_it == source_size_to_dim[target_i_size].end()) {
1391 break;
1392 }
1393 j = *j_it;
1394 }
1395 } else if (target_i_size % source_size == 0) {
1396 // If possible, find a j where the target partition count is source_size,
1397 // so that we can do a single swap.
1398 while (target.tile_assignment().dim(j) != source_size) {
1399 if (++j_it == source_size_to_dim[target_i_size].end()) {
1400 break;
1401 }
1402 j = *j_it;
1403 }
1404 } else {
1405 return absl::nullopt;
1406 }
1407 result.emplace_back(j, i);
1408 remove_entry(target_i_size, i, target_size_to_dim);
1409 source_size_to_dim.begin()->second.back() = j;
1410 remove_entry(target_i_size, j, source_size_to_dim);
1411 }
1412 return result;
1413 }
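// Worked example (illustrative shardings): resharding a tile assignment of
// shape [4, 1, 2] into one of shape [1, 4, 2] keeps the tile count and rank,
// and only dims 0 and 1 differ, so the function returns the single pair
// {0, 1}: one all-to-all moves the partitioning between those dimensions.
// For [4, 1] -> [2, 2] the differing partition counts cannot be matched up,
// so absl::nullopt is returned and a different reshard path must be used.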
1414
1415 bool CanReshardWithCollectivePermute(const HloSharding& source,
1416 const HloSharding& target) {
1417 return !source.IsTileMaximal() && !target.IsTileMaximal() &&
1418 source.tile_assignment().dimensions() ==
1419 target.tile_assignment().dimensions() &&
1420 source.ReplicateOnLastTileDim() == target.ReplicateOnLastTileDim() &&
1421 source.tile_assignment() != target.tile_assignment();
1422 }
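// Example: two shardings with identical tile-assignment shapes and the same
// replication flag but a different device order (say devices [0, 1] versus
// [1, 0] over one dimension) can be resharded with a single
// collective-permute; shardings whose tiling shapes differ cannot.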
1423
1424 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
1425 absl::Span<const int64> group_dims) {
1426 std::vector<int64> group_dim_shards(group_dims.size(), 1);
1427 return GroupShardingOnDims(sharding, group_dims, group_dim_shards);
1428 }
1429
1430 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
1431 absl::Span<const int64> group_dims,
1432 absl::Span<const int64> group_dim_shards) {
1433 CHECK(!sharding.IsTileMaximal());
1434 std::vector<int64> grouped_tiling_dims =
1435 sharding.tile_assignment().dimensions();
1436 std::vector<int64> group_dim_sizes(group_dims.size());
1437 for (int64_t i = 0; i < group_dims.size(); ++i) {
1438 CHECK_EQ(grouped_tiling_dims[group_dims[i]] % group_dim_shards[i], 0);
1439 group_dim_sizes[i] =
1440 grouped_tiling_dims[group_dims[i]] / group_dim_shards[i];
1441 grouped_tiling_dims[group_dims[i]] = group_dim_shards[i];
1442 }
1443
1444 std::vector<std::vector<int64>> device_groups(Product(group_dim_sizes));
1445 sharding.tile_assignment().Each([&](absl::Span<const int64> indices,
1446 int64_t device) {
1447 int64_t group_id = 0;
1448 for (int64_t i = 0; i < group_dims.size(); ++i) {
1449 group_id *=
1450 sharding.tile_assignment().dim(group_dims[i]) / group_dim_shards[i];
1451 group_id += indices[group_dims[i]] / group_dim_shards[i];
1452 }
1453 device_groups[group_id].push_back(device);
1454 });
1455 auto grouped = GroupedSharding(
1456 std::move(device_groups),
1457 std::vector<int64>(group_dims.begin(), group_dims.end()),
1458 std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(),
1459 HloSharding::Replicate());
1460 if (sharding.ReplicateOnLastTileDim()) {
1461 grouped.data_rank--;
1462 }
1463 if (Product(grouped_tiling_dims) == 1 ||
1464 (sharding.ReplicateOnLastTileDim() &&
1465 Product(grouped_tiling_dims) == grouped_tiling_dims.back())) {
1466 return grouped;
1467 }
1468 if ((sharding.ReplicateOnLastTileDim() || sharding.IsManualSubgroup()) &&
1469 grouped_tiling_dims.back() == 1) {
1470 grouped_tiling_dims.pop_back();
1471 }
1472 Array<int64> grouped_tiling(grouped_tiling_dims);
1473 grouped_tiling.FillIota(0);
1474 grouped.sharding = sharding.ReplicateOnLastTileDim() &&
1475 grouped_tiling_dims.size() ==
1476 sharding.tile_assignment().num_dimensions()
1477 ? HloSharding::PartialTile(grouped_tiling)
1478 : HloSharding::Tile(grouped_tiling);
1479 return grouped;
1480 }
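// Worked example (illustrative sketch): for a sharding whose tile assignment
// has shape [4, 2] (8 devices) and group_dims = {0},
//
//   GroupedSharding grouped = GroupShardingOnDims(sharding, {0});
//
// produces 4 device groups of 2 devices each, group_dim_sizes = {4}, and a
// per-group sharding with tile assignment shape [1, 2], i.e. each group still
// splits dimension 1 in two. Passing group_dim_shards = {2} instead forms 2
// groups and keeps a residual factor of 2 along dimension 0 within each group.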
1481
1482 HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) {
1483 std::vector<int64> tiling_dims;
1484 bool partial_sharding = false;
1485 auto grouped_tiling = grouped_sharding.sharding.tile_assignment();
1486 if (grouped_sharding.sharding.IsTileMaximal()) {
1487 tiling_dims = std::vector<int64>(grouped_sharding.data_rank, 1);
1488 if (grouped_sharding.device_groups[0].size() != 1) {
1489 // This is partial sharding.
1490 tiling_dims.push_back(grouped_sharding.device_groups[0].size());
1491 partial_sharding = true;
1492 }
1493 grouped_tiling = Array<int64>(tiling_dims);
1494 grouped_tiling.FillIota(0);
1495 } else {
1496 partial_sharding = grouped_sharding.sharding.ReplicateOnLastTileDim();
1497 tiling_dims = grouped_sharding.sharding.tile_assignment().dimensions();
1498 if (absl::c_linear_search(grouped_sharding.group_dims,
1499 tiling_dims.size())) {
1500 tiling_dims.push_back(1);
1501 grouped_tiling.Reshape(tiling_dims);
1502 partial_sharding = true;
1503 }
1504 }
1505 for (int64_t i = 0; i < grouped_sharding.group_dims.size(); ++i) {
1506 int64_t dim = grouped_sharding.group_dims[i];
1507 tiling_dims[dim] *= grouped_sharding.group_dim_sizes[i];
1508 }
1509 Array<int64> tiling(tiling_dims);
1510 grouped_tiling.Each([&](absl::Span<const int64> indices, int64_t device) {
1511 std::vector<int64> ungrouped_inds(indices.begin(), indices.end());
1512 for (int64_t g = 0; g < grouped_sharding.device_groups.size(); ++g) {
1513 int64_t remaining_group_index = g;
1514 for (int64_t i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) {
1515 int64_t dim = grouped_sharding.group_dims[i];
1516 int64_t groups_in_this_dim = grouped_sharding.group_dim_sizes[i];
1517 ungrouped_inds[dim] = (remaining_group_index % groups_in_this_dim) *
1518 grouped_tiling.dim(dim) +
1519 indices[dim];
1520 remaining_group_index /= groups_in_this_dim;
1521 }
1522 tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device];
1523 }
1524 });
1525 return partial_sharding ? HloSharding::PartialTile(tiling)
1526 : HloSharding::Tile(tiling);
1527 }
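// Sketch of the round trip with the example above: ungrouping a
// GroupedSharding with 4 groups over dimension 0 and a per-group tile
// assignment of shape [1, 2] rebuilds a tile assignment of shape [4, 2],
// placing device_groups[g][d] at tile index (g, d).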
1528
1529 GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
1530 const GroupedSharding& reference,
1531 bool ignore_group_order) {
1532 // Returns src -> dst index mapping.
1533 auto get_permutation = [](absl::Span<const int64> src,
1534 absl::Span<const int64> dst) {
1535 CHECK_EQ(src.size(), dst.size());
1536 absl::flat_hash_map<int64, int64> dst_reverse_map;
1537 for (int64_t i = 0; i < dst.size(); ++i) {
1538 dst_reverse_map[dst[i]] = i;
1539 }
1540 std::vector<int64> permutation(src.size());
1541 for (int64_t i = 0; i < src.size(); ++i) {
1542 auto it = dst_reverse_map.find(src[i]);
1543 CHECK(it != dst_reverse_map.end());
1544 permutation[i] = it->second;
1545 }
1546 return permutation;
1547 };
1548 CHECK_EQ(grouped_sharding.device_groups.size(),
1549 reference.device_groups.size());
1550 absl::flat_hash_map<int64, int64> device_to_ref_group;
1551 for (int64_t g = 0; g < reference.device_groups.size(); ++g) {
1552 for (int64_t device : reference.device_groups[g]) {
1553 device_to_ref_group[device] = g;
1554 }
1555 }
1556 auto unique_ref_dev_group = [&](absl::Span<const int64> devices) -> int64 {
1557 int64_t ref_g = -1;
1558 for (int64_t device : devices) {
1559 if (ref_g == -1) {
1560 ref_g = device_to_ref_group[device];
1561 } else if (ref_g != device_to_ref_group[device]) {
1562 return -1;
1563 }
1564 }
1565 return ref_g;
1566 };
1567 bool matching_groups = true;
1568 std::vector<int64> original_src_to_ref_permutation;
1569 for (int64_t g = 0; g < grouped_sharding.device_groups.size(); ++g) {
1570 int64_t ref_g = unique_ref_dev_group(grouped_sharding.device_groups[g]);
1571 if (ref_g < 0 || (!ignore_group_order && g != ref_g)) {
1572 matching_groups = false;
1573 break;
1574 }
1575 if (g == 0) {
1576 original_src_to_ref_permutation = get_permutation(
1577 grouped_sharding.device_groups[g], reference.device_groups[ref_g]);
1578 }
1579 }
1580 if (matching_groups && !grouped_sharding.sharding.IsTileMaximal()) {
1581 auto tiles = grouped_sharding.sharding.tile_assignment();
1582 tiles.Each([&](absl::Span<const int64> indices, int64* device) {
1583 *device = original_src_to_ref_permutation[*device];
1584 });
1585 grouped_sharding.sharding =
1586 grouped_sharding.sharding.ReplicateOnLastTileDim()
1587 ? HloSharding::PartialTile(tiles)
1588 : HloSharding::Tile(tiles);
1589 }
1590 grouped_sharding.device_groups = std::move(reference.device_groups);
1591 return grouped_sharding;
1592 }
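// Example (illustrative): if grouped_sharding has device groups
// {{2, 3}, {0, 1}} and the reference has {{0, 1}, {2, 3}}, the groups only
// match when ignore_group_order is true; the aligned result then adopts the
// reference's device groups, re-indexing in-group device ids where the
// within-group order differs.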
1593
1594 HloSharding AlignShardingOnDims(const HloSharding& sharding,
1595 absl::Span<const int64> sharding_dims,
1596 const HloSharding& reference,
1597 absl::Span<const int64> reference_dims) {
1598 auto sharding_grouped = GroupShardingOnDims(sharding, sharding_dims);
1599 auto reference_grouped = GroupShardingOnDims(reference, reference_dims);
1600 return UngroupSharding(AlignGroupsWith(sharding_grouped, reference_grouped));
1601 }
1602
1603 Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
1604 const Shape& original_base_shape) {
1605 auto result = original_base_shape;
1606 for (int64_t i = 0; i < grouped_sharding.group_dims.size(); ++i) {
1607 int64_t dim = grouped_sharding.group_dims[i];
1608 if (dim >= original_base_shape.rank()) {
1609 continue;
1610 }
1611 int64_t groups = grouped_sharding.group_dim_sizes[i];
1612 result.set_dimensions(dim, CeilOfRatio(result.dimensions(dim), groups));
1613 }
1614 return result;
1615 }
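// Example: with an original base shape of f32[8,16], group_dims = {0} and
// group_dim_sizes = {4}, the per-group base shape is f32[2,16]
// (CeilOfRatio(8, 4) = 2); group dimensions at or beyond the shape's rank,
// such as a trailing replication dimension, are skipped.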
1616
1617 namespace {
1618
1619 HloInstruction* GetInGroupPartitionId(
1620 HloInstruction* partition_id,
1621 const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
1622 int64_t total_devices = device_groups.size() * device_groups[0].size();
1623 std::vector<uint32> in_group_ids(total_devices);
1624 for (uint32 i = 0; i < device_groups.size(); ++i) {
1625 for (uint32 j = 0; j < device_groups[i].size(); ++j) {
1626 in_group_ids[device_groups[i][j]] = j;
1627 }
1628 }
1629 auto id_table = b->AddInstruction(HloInstruction::CreateConstant(
1630 LiteralUtil::CreateR1<uint32>(in_group_ids)));
1631 return b->AddInstruction(HloInstruction::CreateReshape(
1632 ShapeUtil::MakeScalarShape(U32),
1633 b->AddInstruction(HloInstruction::CreateDynamicSlice(
1634 ShapeUtil::MakeShape(U32, {1}), id_table, {partition_id}, {1}))));
1635 }
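// Example: for device_groups = {{0, 2}, {1, 3}} the table built above is
// [0, 0, 1, 1] indexed by global partition id, so partitions 0 and 1 map to
// in-group id 0 while partitions 2 and 3 map to in-group id 1.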
1636
1637 SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator(
1638 const SPMDCollectiveOpsCreator& creator,
1639 const std::vector<std::vector<int64>>& device_groups) {
1640 SPMDCollectiveOpsCreator result;
1641 result.create_partition_id = [creator, device_groups](SpmdBuilder* b) {
1642 return GetInGroupPartitionId(creator.create_partition_id(b), device_groups,
1643 b);
1644 };
1645 auto expand_partition_groups =
1646 [device_groups](
1647 const std::vector<std::vector<int64>>& partition_subgroups) {
1648 if (partition_subgroups.empty()) {
1649 return device_groups;
1650 }
1651 std::vector<std::vector<int64>> result(partition_subgroups.size() *
1652 device_groups.size());
1653 for (int64_t g = 0; g < device_groups.size(); ++g) {
1654 for (int64_t i = 0; i < partition_subgroups.size(); ++i) {
1655 result[g * partition_subgroups.size() + i].resize(
1656 partition_subgroups[i].size());
1657 for (int64_t j = 0; j < partition_subgroups[i].size(); ++j) {
1658 result[g * partition_subgroups.size() + i][j] =
1659 device_groups[g][partition_subgroups[i][j]];
1660 }
1661 }
1662 }
1663 return result;
1664 };
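  // Example of the expansion above (illustrative): with device_groups =
  // {{0, 2}, {1, 3}} and partition_subgroups = {{0, 1}} expressed in in-group
  // ids, the collectives below run over the expanded replica groups
  // {{0, 2}, {1, 3}}, i.e. one group per device group, with in-group ids
  // translated back to global partition ids.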
1665 result.create_cross_partition_all_reduce =
1666 [creator, expand_partition_groups](
1667 SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
1668 const std::vector<std::vector<int64>>& partition_subgroups,
1669 int64_t channel_id) {
1670 return creator.create_cross_partition_all_reduce(
1671 b, operand, reduction, expand_partition_groups(partition_subgroups),
1672 channel_id);
1673 };
1674 result.create_cross_partition_collective_permute =
1675 [creator, device_groups](
1676 SpmdBuilder* b, HloInstruction* operand,
1677 std::vector<std::pair<int64, int64>>& src_dst_pairs,
1678 int64_t next_channel_id) {
1679 std::vector<std::pair<int64, int64>> expanded_pairs(
1680 src_dst_pairs.size() * device_groups.size());
1681 for (int64_t g = 0; g < device_groups.size(); ++g) {
1682 for (int64_t i = 0; i < src_dst_pairs.size(); ++i) {
1683 expanded_pairs[g * src_dst_pairs.size() + i] =
1684 std::pair<int64, int64>{
1685 device_groups[g][src_dst_pairs[i].first],
1686 device_groups[g][src_dst_pairs[i].second]};
1687 }
1688 }
1689 return creator.create_cross_partition_collective_permute(
1690 b, operand, expanded_pairs, next_channel_id);
1691 };
1692 result.create_cross_partition_all_to_all =
1693 [creator, expand_partition_groups](
1694 SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
1695 const std::vector<std::vector<int64>>& partition_subgroups,
1696 int64_t channel_id, absl::optional<int64> split_dimension) {
1697 return creator.create_cross_partition_all_to_all(
1698 b, operands, expand_partition_groups(partition_subgroups),
1699 channel_id, split_dimension);
1700 };
1701 if (creator.create_cross_partition_all_gather) {
1702 result.create_cross_partition_all_gather =
1703 [creator, expand_partition_groups](
1704 SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
1705 const std::vector<std::vector<int64>>& partition_subgroups,
1706 int64_t channel_id, int64_t all_gather_dimension) {
1707 return creator.create_cross_partition_all_gather(
1708 b, operand, ag_shape,
1709 expand_partition_groups(partition_subgroups), channel_id,
1710 all_gather_dimension);
1711 };
1712 }
1713 return result;
1714 }
1715
1716 } // namespace
1717
1718 PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
1719 const PartitionedHlo::PartitioningState& state,
1720 const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
1721 auto result = state;
1722 result.collective_ops_creator = GetPerGroupCollectiveOpsCreator(
1723 state.collective_ops_creator, device_groups);
1724 result.partition_id =
1725 GetInGroupPartitionId(state.partition_id, device_groups, b);
1726 // Create a string key for the groups.
1727 std::vector<std::string> per_group_strings(device_groups.size());
1728 for (int64_t i = 0; i < per_group_strings.size(); ++i) {
1729 per_group_strings[i] = absl::StrJoin(device_groups[i], ",");
1730 }
1731 auto& grouped_cache =
1732 state.reshard_cache->groupd_caches[absl::StrJoin(per_group_strings, ";")];
1733 if (!grouped_cache) {
1734 grouped_cache = absl::make_unique<PartitionedHlo::ReshardCache>();
1735 }
1736 result.reshard_cache = grouped_cache.get();
1737 return result;
1738 }
1739
1740 HloInstruction* PerGroupSliceFromReplicated(
1741 HloInstruction* replicated, HloInstruction* partition_id,
1742 const std::vector<std::vector<int64>>& device_groups,
1743 absl::Span<const int64> group_dims, absl::Span<const int64> group_dim_sizes,
1744 SpmdBuilder* b) {
1745 std::vector<uint32> group_ids(device_groups.size() * device_groups[0].size());
1746 for (int64_t g = 0; g < device_groups.size(); ++g) {
1747 for (int64_t device : device_groups[g]) {
1748 group_ids[device] = g;
1749 }
1750 }
1751 auto group_id_table = b->AddInstruction(
1752 HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint32>(group_ids)));
1753 auto group_id = b->AddInstruction(HloInstruction::CreateReshape(
1754 ShapeUtil::MakeScalarShape(U32),
1755 b->AddInstruction(HloInstruction::CreateDynamicSlice(
1756 ShapeUtil::MakeShape(U32, {1}), group_id_table, {partition_id},
1757 {1}))));
1758 std::vector<int64> group_level_tile_dims(replicated->shape().rank(), 1);
1759 for (int64_t i = 0; i < group_dims.size(); ++i) {
1760 group_level_tile_dims[group_dims[i]] = group_dim_sizes[i];
1761 }
1762 Array<int64> group_level_tile(group_level_tile_dims);
1763 group_level_tile.Each([&](absl::Span<const int64> indices, int64* group) {
1764 *group = 0;
1765 for (int64_t dim : group_dims) {
1766 *group *= group_level_tile.dim(dim);
1767 *group += indices[dim];
1768 }
1769 });
1770 auto group_level_sharding = HloSharding::Tile(group_level_tile);
1771 auto padded_hlo = PadBaseShapeBeforeUnevenTiledSharding(
1772 replicated, group_level_sharding, b);
1773 auto shard_shape =
1774 MakePartitionedShape(replicated->shape(), group_level_sharding);
1775 return b->AddInstruction(HloInstruction::CreateDynamicSlice(
1776 shard_shape, padded_hlo,
1777 MakePartitionOffsets(replicated->shape(), group_level_sharding, group_id,
1778 b),
1779 shard_shape.dimensions()));
1780 }
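// Worked example (illustrative): with device_groups = {{0, 1}, {2, 3}},
// group_dims = {0} and group_dim_sizes = {2} on a replicated f32[8,16]
// operand, the group-id table is [0, 0, 1, 1], the group-level sharding tiles
// dimension 0 in two, and each device dynamic-slices its group's f32[4,16]
// shard out of the (padded) replicated value.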
1781
1782 absl::optional<HloOpcode> ParseReductionComputation(
1783 const HloComputation* reduction_comp) {
1784 if (reduction_comp->num_parameters() != 2) {
1785 return absl::nullopt;
1786 }
1787 auto root = reduction_comp->root_instruction();
1788 if (!root->IsElementwiseBinary()) {
1789 return absl::nullopt;
1790 }
1791 if (!absl::c_linear_search(root->operands(),
1792 reduction_comp->parameter_instruction(0)) ||
1793 !absl::c_linear_search(root->operands(),
1794 reduction_comp->parameter_instruction(1))) {
1795 return absl::nullopt;
1796 }
1797 return root->opcode();
1798 }
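// Example: a two-parameter reduction computation whose root is
// add(param0, param1) yields HloOpcode::kAdd; a root such as
// add(param0, constant), or any non-elementwise-binary root, yields
// absl::nullopt.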
1799
1800 absl::optional<std::vector<int64>> FindMatchingPartitionedDimsForGrouping(
1801 const HloSharding& sharding,
1802 const std::vector<std::vector<int64>>& device_groups) {
1803 if (sharding.NumTiles() < device_groups.size() || device_groups.size() < 2 ||
1804 device_groups[0].size() < 2) {
1805 return absl::nullopt;
1806 }
1807 int64_t rank = sharding.tile_assignment().num_dimensions();
1808 if (sharding.ReplicateOnLastTileDim()) {
1809 rank--;
1810 }
1811 absl::flat_hash_map<int64, std::vector<int64>> device_to_index;
1812 sharding.tile_assignment().Each(
1813 [&](absl::Span<const int64> index, int64_t device) {
1814 device_to_index[device] =
1815 std::vector<int64>(index.begin(), index.begin() + rank);
1816 });
1817 std::vector<int64> dims;
1818 int64_t group_count = 1;
1819 for (int64_t i = 0; i < rank; ++i) {
1820 if (device_to_index[device_groups[0][0]][i] ==
1821 device_to_index[device_groups[0][1]][i]) {
1822 dims.push_back(i);
1823 group_count *= sharding.tile_assignment().dim(i);
1824 }
1825 }
1826 if (group_count != device_groups.size()) {
1827 return absl::nullopt;
1828 }
1829 for (const auto& group : device_groups) {
1830 for (int64_t i = 1; i < group.size(); ++i) {
1831 if (absl::c_any_of(dims, [&](const int64_t dim) {
1832 return device_to_index[group[i]][dim] !=
1833 device_to_index[group[0]][dim];
1834 })) {
1835 return absl::nullopt;
1836 }
1837 }
1838 }
1839 return dims;
1840 }
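// Worked example (illustrative): for a tile assignment of shape [2, 2] with
// devices [[0, 1], [2, 3]] and device_groups = {{0, 1}, {2, 3}}, devices in a
// group share the same index along dimension 0, so the function returns {0};
// with device_groups = {{0, 3}, {1, 2}} no dimension is constant within a
// group and absl::nullopt is returned.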
1841
1842 HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
1843 const HloSharding& source_sharding,
1844 absl::Span<const int64> target_dims,
1845 absl::Span<const int64> source_dims) {
1846 CHECK(target_dims.size() == source_dims.size())
1847 << "Expected 1:1 match between parallel dimensions";
1848 if (source_sharding.IsReplicated()) {
1849 return HloSharding::Replicate();
1850 }
1851 absl::InlinedVector<int64, 4> tile_dims(target_shape.dimensions_size(), 1);
1852 int num_tiles = 1;
1853 for (int i = 0, end = target_dims.size(); i < end; ++i) {
1854 num_tiles *= source_sharding.tile_assignment().dim(source_dims[i]);
1855 tile_dims[target_dims[i]] =
1856 source_sharding.tile_assignment().dim(source_dims[i]);
1857 }
1858   // If there is some partitioning across non-parallel dimensions in the
1859   // other operand, then partially replicate the new sharding over them.
1860 bool to_be_partially_replicated = false;
1861 if (num_tiles != source_sharding.tile_assignment().num_elements()) {
1862 CHECK_EQ(source_sharding.tile_assignment().num_elements() % num_tiles, 0);
1863 to_be_partially_replicated = true;
1864 tile_dims.push_back(source_sharding.tile_assignment().num_elements() /
1865 num_tiles);
1866 }
1867 auto tgt_tile_assignment = source_sharding.tile_assignment();
1868 tgt_tile_assignment.Reshape(tile_dims);
1869 if (to_be_partially_replicated) {
1870 return AlignShardingOnDims(HloSharding::PartialTile(tgt_tile_assignment),
1871 target_dims, source_sharding, source_dims);
1872 } else {
1873 return AlignShardingOnDims(HloSharding::Tile(tgt_tile_assignment),
1874 target_dims, source_sharding, source_dims);
1875 }
1876 }
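// Example (illustrative): if source_sharding tiles only its dimension 1 in 4,
// then for target_dims = {0} and source_dims = {1} the result tiles the
// target's dimension 0 in 4, with device order aligned to the source. Any
// additional tiling of the source over non-parallel dimensions becomes a
// partial-replication factor on the result.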
1877
1878 absl::optional<GatherParallelDimSharding>
1879 GatherOperandsShardedAcrossParallelDims(
1880 const HloInstruction& operand, const HloInstruction& indices,
1881 const hlo_sharding_util::GatherParallelDims& parallel_dims) {
1882 auto& indices_parallel_dims = parallel_dims.indices_parallel_dims;
1883 auto& operand_parallel_dims = parallel_dims.operand_parallel_dims;
1884 if (indices_parallel_dims.size() != operand_parallel_dims.size()) {
1885 return absl::nullopt;
1886 }
1887 auto new_index_shard = indices.sharding();
1888 auto new_operand_shard = operand.sharding();
1889 int idx_parallel_tiles_num = new_index_shard.NumTiles(indices_parallel_dims);
1890 int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims);
1891 if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) {
1892 return absl::nullopt;
1893 }
1894 absl::InlinedVector<int64, 1> indices_parallel_dims_ordered_as_operand;
1895 for (int idx : parallel_dims.index_parallel_in_dim) {
1896 if (idx != -1) {
1897 indices_parallel_dims_ordered_as_operand.push_back(idx);
1898 }
1899 }
1900 if (new_index_shard.IsReplicated()) {
1901 return GatherParallelDimSharding{
1902 CreateMatchingShardingOnDims(indices.shape(), new_operand_shard,
1903 indices_parallel_dims_ordered_as_operand,
1904 operand_parallel_dims),
1905 new_operand_shard};
1906 }
1907 if (new_operand_shard.IsReplicated()) {
1908 return GatherParallelDimSharding{
1909 new_index_shard,
1910 CreateMatchingShardingOnDims(operand.shape(), new_index_shard,
1911 operand_parallel_dims,
1912 indices_parallel_dims_ordered_as_operand)};
1913 }
1914
1915 // Parallel dimension distribution needs to be the same, so try to steal
1916 // sharding from partial replication to compensate.
1917 if (idx_parallel_tiles_num != op_parallel_tiles_num) {
1918 auto to_adjust_dims = operand_parallel_dims;
1919 auto target_dims = indices_parallel_dims_ordered_as_operand;
1920 HloSharding* target = &new_index_shard;
1921 HloSharding* to_adjust = &new_operand_shard;
1922 if (idx_parallel_tiles_num < op_parallel_tiles_num) {
1923 std::swap(to_adjust_dims, target_dims);
1924 std::swap(to_adjust, target);
1925 }
1926 if (!to_adjust->ReplicateOnLastTileDim()) {
1927 return absl::nullopt;
1928 }
1929 auto new_tile_assignment_dims = to_adjust->tile_assignment().dimensions();
1930 for (int i = 0; i < to_adjust_dims.size(); ++i) {
1931 int64_t target_dim = target->tile_assignment().dim(target_dims[i]);
1932 int64_t to_adjust_dim =
1933 to_adjust->tile_assignment().dim(to_adjust_dims[i]);
1934 if (target_dim < to_adjust_dim) {
1935 return absl::nullopt;
1936 }
1937 if (target_dim == to_adjust_dim) {
1938 continue;
1939 }
1940 int64_t ratio = target_dim / to_adjust_dim;
1941 if (target_dim % to_adjust_dim != 0 ||
1942 new_tile_assignment_dims.back() % ratio != 0) {
1943 return absl::nullopt;
1944 }
1945 new_tile_assignment_dims[to_adjust_dims[i]] *= ratio;
1946 new_tile_assignment_dims.back() /= ratio;
1947 }
1948 CHECK_GE(new_tile_assignment_dims.back(), 1);
1949 bool to_partially_replicate = true;
1950 if (new_tile_assignment_dims.back() == 1) {
1951 new_tile_assignment_dims.pop_back();
1952 to_partially_replicate = false;
1953 }
1954 auto new_tile_assignment = to_adjust->tile_assignment();
1955 new_tile_assignment.Reshape(new_tile_assignment_dims);
1956 if (to_partially_replicate) {
1957 *to_adjust =
1958 AlignShardingOnDims(HloSharding::PartialTile(new_tile_assignment),
1959 to_adjust_dims, *target, target_dims);
1960 } else {
1961 *to_adjust = AlignShardingOnDims(HloSharding::Tile(new_tile_assignment),
1962 to_adjust_dims, *target, target_dims);
1963 }
1964 }
1965 // Make sure that the parallel dimensions are aligned.
1966 auto operand_shard_tile_dims =
1967 new_operand_shard.tile_assignment().dimensions();
1968 for (int i = 0; i < indices_parallel_dims_ordered_as_operand.size(); ++i) {
1969 operand_shard_tile_dims[operand_parallel_dims[i]] =
1970 new_index_shard.tile_assignment().dim(
1971 indices_parallel_dims_ordered_as_operand[i]);
1972 }
1973 auto operand_shard_tiles = new_operand_shard.tile_assignment();
1974 operand_shard_tiles.Reshape(operand_shard_tile_dims);
1975 new_operand_shard =
1976 AlignShardingOnDims(new_operand_shard.ReplicateOnLastTileDim()
1977 ? HloSharding::PartialTile(operand_shard_tiles)
1978 : HloSharding::Tile(operand_shard_tiles),
1979 operand_parallel_dims, new_index_shard,
1980 indices_parallel_dims_ordered_as_operand);
1981 return GatherParallelDimSharding{new_index_shard, new_operand_shard};
1982 }
1983
1984 int64 FindRotateRightPattern(const HloInstruction* concat,
1985 const HloInstruction* lhs,
1986 const HloInstruction* rhs) {
1987 if (lhs->opcode() != HloOpcode::kSlice ||
1988 rhs->opcode() != HloOpcode::kSlice ||
1989 lhs->operand(0) != rhs->operand(0)) {
1990 return -1;
1991 }
1992 const HloInstruction* to_rotate = lhs->operand(0);
1993 if (!ShapeUtil::Compatible(to_rotate->shape(), concat->shape()) ||
1994 concat->sharding() != to_rotate->sharding()) {
1995 return -1;
1996 }
1997 const int64_t dim = concat->concatenate_dimension();
1998 if (lhs->slice_strides(dim) != 1 || rhs->slice_strides(dim) != 1 ||
1999 lhs->slice_starts(dim) != rhs->slice_limits(dim)) {
2000 return -1;
2001 }
2002 return lhs->shape().dimensions(dim);
2003 }
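// Example: concat(slice(x, [k, n)), slice(x, [0, k))) along a dimension of
// size n is a rotate-right of x by (n - k) elements, provided shapes and
// shardings match; the return value is the size of the left slice along that
// dimension, i.e. the rotation amount n - k. A return value of -1 means the
// pattern was not matched.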
2004
2005 absl::optional<PadWithWrapPattern> FindPadWithWrapPattern(
2006 const HloInstruction* concat, const HloInstruction* lhs,
2007 const HloInstruction* mid, const HloInstruction* rhs) {
2008 if (!lhs || !mid || !rhs) {
2009 return absl::nullopt;
2010 }
2011
2012   // Skip elementwise unary operations applied to inst, returning the
2013   // skipped (non-copy) operations together with the underlying instruction.
2014 auto skip_elementwise_ops = [&](const HloInstruction* inst) {
2015 std::vector<const HloInstruction*> modifiers;
2016 while (inst->IsElementwise() && inst->operand_count() == 1 &&
2017 inst->user_count() == 1) {
2018 if (inst->opcode() != HloOpcode::kCopy) {
2019 modifiers.push_back(inst);
2020 }
2021 inst = inst->operand(0);
2022 }
2023 return std::make_pair(modifiers, inst);
2024 };
2025
2026 PadWithWrapPattern pad_pattern;
2027 auto skip_result = skip_elementwise_ops(lhs);
2028 pad_pattern.lhs_modifiers = std::move(skip_result.first);
2029 lhs = skip_result.second;
2030
2031 skip_result = skip_elementwise_ops(rhs);
2032 pad_pattern.rhs_modifiers = std::move(skip_result.first);
2033 rhs = skip_result.second;
2034
2035 const int64_t dim = concat->concatenate_dimension();
2036 if (lhs->opcode() != HloOpcode::kSlice ||
2037 rhs->opcode() != HloOpcode::kSlice || lhs->operand(0) != mid ||
2038 rhs->operand(0) != mid || lhs->slice_strides(dim) != 1 ||
2039 rhs->slice_strides(dim) != 1 || lhs->sharding() != mid->sharding() ||
2040 rhs->sharding() != mid->sharding() ||
2041 lhs->sharding() != concat->sharding()) {
2042 return absl::nullopt;
2043 }
2044 pad_pattern.lhs_slice_start = lhs->slice_starts(dim);
2045 pad_pattern.rhs_slice_start = rhs->slice_starts(dim);
2046 return pad_pattern;
2047 }
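// Example: concat(slice(x, [n-p, n)), x, slice(x, [0, q))) along a dimension
// of size n matches the pad-with-wrap pattern: the left slice wraps the last
// p elements in front and the right slice wraps the first q elements behind.
// The recorded slice starts (n - p and 0) and any skipped non-copy
// elementwise modifiers describe how to rebuild the wrapped borders.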
2048
2049 GroupedSharding GetManualSubgroupSharding(const HloSharding& sharding) {
2050 CHECK(sharding.IsManualSubgroup());
2051 int64_t tile_dimensions = sharding.tile_assignment().num_dimensions();
2052 int64_t subgroup_size = sharding.sharding_types().size();
2053 int64_t rank = tile_dimensions - subgroup_size;
2054 std::vector<int64> group_dims;
2055 bool last_tile_dim_replicate = false;
2056
2057 for (int64_t i = 0; i < subgroup_size; i++) {
2058 if (sharding.sharding_types()[i] == OpSharding::MANUAL) {
2059 group_dims.push_back(rank + i);
2060 } else if (sharding.sharding_types()[i] == OpSharding::REPLICATED) {
2061 last_tile_dim_replicate = true;
2062 }
2063 }
2064
2065 GroupedSharding group_sharding = GroupShardingOnDims(sharding, group_dims);
2066
2067 if (last_tile_dim_replicate ||
2068 group_sharding.sharding.tile_assignment().num_dimensions() > rank) {
2069 group_sharding.sharding = HloSharding::PartialTile(
2070 group_sharding.sharding.tile_assignment(), sharding.metadata());
2071 }
2072 return group_sharding;
2073 }
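// Example (illustrative): for a subgroup sharding with tile assignment shape
// [2, 2, 4] and subgroup types {MANUAL}, the trailing manual dimension is
// grouped away, giving 4 device groups whose per-group sharding tiles the two
// data dimensions as [2, 2]; a REPLICATED subgroup type is instead kept as
// partial replication in the per-group sharding.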
2074
2075 } // namespace spmd
2076 } // namespace xla
2077