1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
17 
18 #include <algorithm>
19 #include <memory>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
32 #include "tensorflow/compiler/xla/service/hlo_module.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
35 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
36 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
37 #include "tensorflow/compiler/xla/service/shape_inference.h"
38 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/compiler/xla/window_util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 
44 namespace xla {
45 namespace spmd {
46 
47 bool HasReplicatedSharding(const HloSharding& sharding) {
48   if (sharding.IsTuple()) {
49     return absl::c_any_of(sharding.tuple_elements(), HasReplicatedSharding);
50   }
51   return sharding.IsReplicated();
52 }
53 
54 HloInstruction* CreateConstant(const Shape& shape, Literal value,
55                                SpmdBuilder* b) {
56   if (shape.IsTuple()) {
57     std::vector<HloInstruction*> elements;
58     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
59       elements.push_back(CreateConstant(
60           ShapeUtil::GetTupleElementShape(shape, i), value.Clone(), b));
61     }
62     return b->AddInstruction(HloInstruction::CreateTuple(elements));
63   }
64 
65   CHECK(
66       ShapeUtil::IsScalarWithElementType(value.shape(), shape.element_type()));
67   auto c = b->AddInstruction(HloInstruction::CreateConstant(std::move(value)));
68   return b->AddInstruction(HloInstruction::CreateBroadcast(shape, c, {}));
69 }
70 
71 HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) {
72   if (shape.IsTuple()) {
73     std::vector<HloInstruction*> elements;
74     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
75       elements.push_back(
76           CreateZero(ShapeUtil::GetTupleElementShape(shape, i), b));
77     }
78     return b->AddInstruction(HloInstruction::CreateTuple(elements));
79   }
80 
81   if (shape.IsToken()) {
82     return b->AddInstruction(HloInstruction::CreateToken());
83   }
84   auto zero = b->AddInstruction(
85       HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
86   if (shape.rank() == 0) {
87     return zero;
88   }
89   return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {}));
90 }
91 
92 HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b) {
93   if (shape.IsTuple()) {
94     std::vector<HloInstruction*> elements;
95     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
96       elements.push_back(
97           CreateOne(ShapeUtil::GetTupleElementShape(shape, i), b));
98     }
99     return b->AddInstruction(HloInstruction::CreateTuple(elements));
100   }
101 
102   if (shape.IsToken()) {
103     return b->AddInstruction(HloInstruction::CreateToken());
104   }
105   auto one = b->AddInstruction(
106       HloInstruction::CreateConstant(LiteralUtil::One(shape.element_type())));
107   return b->AddInstruction(HloInstruction::CreateBroadcast(shape, one, {}));
108 }
109 
110 HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) {
111   HloComputation::Builder sum_b("add");
112   auto x = sum_b.AddInstruction(HloInstruction::CreateParameter(
113       /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x"));
114   auto y = sum_b.AddInstruction(HloInstruction::CreateParameter(
115       /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y"));
116   if (type == PRED) {
117     sum_b.AddInstruction(HloInstruction::CreateBinary(
118         ShapeUtil::MakeShape(type, {}), HloOpcode::kOr, x, y));
119   } else {
120     sum_b.AddInstruction(HloInstruction::CreateBinary(
121         ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y));
122   }
123   HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build());
124   return reduction;
125 }
126 
127 bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) {
128   if (sharding.IsTuple()) {
129     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
130       if (!EvenlyPartitions(ShapeUtil::GetTupleElementShape(shape, i),
131                             sharding.GetSubSharding(shape, {i}))) {
132         return false;
133       }
134     }
135   }
136 
137   if (sharding.IsTileMaximal()) {
138     return sharding.IsReplicated();
139   }
140   for (int64_t i = 0; i < shape.dimensions_size(); ++i) {
141     if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) {
142       return false;
143     }
144   }
145   return true;
146 }
147 
148 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) {
149   if (sharding.IsTuple()) {
150     std::vector<Shape> subshapes;
151     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
152       subshapes.push_back(
153           MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i),
154                                sharding.GetSubSharding(shape, {i})));
155     }
156     return ShapeUtil::MakeTupleShape(subshapes);
157   }
158   return sharding.TileShape(shape);
159 }
160 
161 int64 ShapeSizeInBytes(const Shape& shape) {
162   return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
163          ShapeUtil::ElementsIn(shape);
164 }
165 
166 Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
167                                           const HloSharding& sharding,
168                                           int64_t partition_id) {
169   if (sharding.IsTuple()) {
170     std::vector<Shape> subshapes;
171     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
172       subshapes.push_back(MakeNonPaddedShapeForGivenPartition(
173           ShapeUtil::GetTupleElementShape(shape, i),
174           sharding.GetSubSharding(shape, {i}), partition_id));
175     }
176     return ShapeUtil::MakeTupleShape(subshapes);
177   }
178 
179   if (sharding.IsReplicated()) {
180     return shape;
181   }
182   if (sharding.IsTileMaximal()) {
183     if (partition_id == *sharding.UniqueDevice()) {
184       return shape;
185     }
186     return ShapeUtil::MakeTupleShape({});
187   }
188 
189   auto partition_shape = shape;
190   std::vector<int64> tile_offset =
191       sharding.TileOffsetForDevice(shape, partition_id);
192   std::vector<int64> tile_limit =
193       sharding.TileLimitForDevice(shape, partition_id);
194   for (int64_t i = 0; i < tile_offset.size(); ++i) {
195     if (sharding.UsesDevice(partition_id)) {
196       partition_shape.set_dimensions(i, tile_limit[i] - tile_offset[i]);
197     } else {
198       partition_shape.set_dimensions(i, 0);
199     }
200   }
201   return partition_shape;
202 }
203 
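// Builds one S32 offset per dimension of `shape` for the current partition:
// for each tiled dimension a constant table of shard offsets (tile index times
// shard size) is indexed by partition_id; untiled dimensions, and dimensions
// not listed in a non-empty `dims`, get a constant zero.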
204 std::vector<HloInstruction*> MakePartitionOffsets(
205     const Shape& shape, const HloSharding& sharding,
206     HloInstruction* partition_id, SpmdBuilder* b,
207     absl::Span<const int64> dims) {
208   CHECK(!shape.IsTuple());
209 
210   std::vector<std::vector<int32>> offset_arrays(shape.rank());
211   for (int64_t i = 0; i < shape.rank(); ++i) {
212     offset_arrays[i].resize(sharding.tile_assignment().num_elements());
213   }
214   auto shard_shape = MakePartitionedShape(shape, sharding);
215   sharding.tile_assignment().Each(
216       [&](absl::Span<const int64> indices, int64_t device) {
217         for (int64_t i = 0; i < shape.rank(); ++i) {
218           offset_arrays[i][device] = indices[i] * shard_shape.dimensions(i);
219         }
220       });
221   std::vector<HloInstruction*> offsets;
222   for (int64_t i = 0; i < shape.rank(); ++i) {
223     if (sharding.tile_assignment().dim(i) == 1 ||
224         (!dims.empty() && !absl::c_linear_search(dims, i))) {
225       offsets.push_back(b->AddInstruction(
226           HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
227     } else {
228       auto offset_table = b->AddInstruction(HloInstruction::CreateConstant(
229           LiteralUtil::CreateR1<int32>(offset_arrays[i])));
230       auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice(
231           ShapeUtil::MakeShape(S32, {1}), offset_table, {partition_id}, {1}));
232       offsets.push_back(b->AddInstruction(
233           HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index)));
234     }
235   }
236   return offsets;
237 }
238 
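// Returns, for each tiled dimension, this partition's index in the tile
// assignment: it applies MakePartitionOffsets to an S32 shape whose dimension
// sizes equal the tile counts, so each per-shard offset is the tile index
// itself.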
239 std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
240     const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) {
241   CHECK(!sharding.IsTileMaximal());
242   auto dimensions = sharding.tile_assignment().dimensions();
243   if (sharding.ReplicateOnLastTileDim()) {
244     dimensions.pop_back();
245   }
246   auto table_shape = ShapeUtil::MakeShape(S32, dimensions);
247   return MakePartitionOffsets(table_shape, sharding, partition_id, b);
248 }
249 
250 HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape,
251                            SpmdBuilder* b, HloComputation* computation) {
252   CHECK(b == nullptr || computation == nullptr);
253   if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) {
254     return hlo;
255   }
256   PaddingConfig padding_config;
257   for (int64_t i = 0; i < padded_shape.rank(); ++i) {
258     auto padding_config_dim = padding_config.add_dimensions();
259     padding_config_dim->set_edge_padding_low(0);
260     padding_config_dim->set_interior_padding(0);
261     padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
262                                               hlo->shape().dimensions(i));
263   }
264   auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
265     if (b == nullptr) {
266       return computation->AddInstruction(std::move(to_add));
267     }
268     return b->AddInstruction(std::move(to_add));
269   };
270   auto zero = add_hlo(HloInstruction::CreateConstant(
271       LiteralUtil::Zero(hlo->shape().element_type())));
272   return add_hlo(
273       HloInstruction::CreatePad(padded_shape, hlo, zero, padding_config));
274 }
275 
276 Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
277                                           const HloSharding& sharding) {
278   if (sharding.IsTileMaximal()) {
279     return base_shape;
280   }
281   if (EvenlyPartitions(base_shape, sharding)) {
282     return base_shape;
283   }
284   auto shard_shape = MakePartitionedShape(base_shape, sharding);
285   Shape padded_base_shape = base_shape;
286   for (int64_t i = 0; i < padded_base_shape.rank(); ++i) {
287     padded_base_shape.set_dimensions(
288         i, shard_shape.dimensions(i) * sharding.tile_assignment().dim(i));
289   }
290   return padded_base_shape;
291 }
292 
293 HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
294     HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) {
295   auto padded_base_shape =
296       GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding);
297   if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) {
298     return hlo;
299   }
300   return PadToShape(hlo, padded_base_shape, b);
301 }
302 
303 absl::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
304     const HloSharding& partial_sharding, const HloSharding& target_sharding) {
305   if (!partial_sharding.ReplicateOnLastTileDim()) {
306     return absl::nullopt;
307   }
308   int64_t rank = partial_sharding.tile_assignment().num_dimensions() - 1;
309   int64_t target_rank = target_sharding.tile_assignment().num_dimensions() -
310                         (target_sharding.ReplicateOnLastTileDim() ? 1 : 0);
311   if (target_rank != rank) {
312     return absl::nullopt;
313   }
314 
315   absl::flat_hash_map<int64, int64> device_to_replication_group;
316   partial_sharding.tile_assignment().Each(
317       [&](absl::Span<const int64> indices, int64_t device) {
318         int64_t gid = 0;
319         for (int64_t i = 0; i < rank; ++i) {
320           gid *= partial_sharding.tile_assignment().dim(i);
321           gid += indices[i];
322         }
323         device_to_replication_group[device] = gid;
324       });
325 
326   // A dimension is expanded when target_tile_size > partial_tile_size and
327   // target_tile_size % partial_tile_size == 0.
328   // expand_tile_dims_indices[dim] is the index of dim among the expanded dims (-1 if not expanded).
329   std::vector<int64> expand_tile_dims_indices(rank, -1);
330   // expand_tile_size = target_tile_size / partial_tile_size.
331   std::vector<int64> expand_tile_sizes;
332   int num_expand_dims = 0;
333   for (int64_t dim = 0; dim < rank; dim++) {
334     int64_t partial_tile_size = partial_sharding.tile_assignment().dim(dim);
335     int64_t target_tile_size = target_sharding.tile_assignment().dim(dim);
336     if (target_tile_size % partial_tile_size != 0 ||
337         target_tile_size < partial_tile_size) {
338       return absl::nullopt;
339     }
340 
341     if (target_tile_size > partial_tile_size) {
342       expand_tile_dims_indices[dim] = num_expand_dims++;
343       expand_tile_sizes.emplace_back(target_tile_size / partial_tile_size);
344     }
345   }
346 
347   // Reshape the partial replicate tile_dimensions.
348   int64_t num_target_replication = 1;
349   if (target_sharding.ReplicateOnLastTileDim()) {
350     num_target_replication =
351         target_sharding.tile_assignment().dimensions().back();
352   }
353   auto reshape_dimensions = partial_sharding.tile_assignment().dimensions();
354   int64_t num_replication = reshape_dimensions.back();
355   if (num_replication / num_target_replication != Product(expand_tile_sizes) ||
356       num_replication % num_target_replication != 0) {
357     return absl::nullopt;
358   }
359 
360   reshape_dimensions.pop_back();
361   reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(),
362                             expand_tile_sizes.end());
363 
364   if (target_sharding.ReplicateOnLastTileDim()) {
365     reshape_dimensions.push_back(num_target_replication);
366   }
367 
368   auto reshape_tile_assignment = partial_sharding.tile_assignment();
369   reshape_tile_assignment.Reshape(reshape_dimensions);
370 
371   // Transpose.
372   std::vector<int64> perm;
373   perm.reserve(rank + expand_tile_sizes.size());
374   for (int64_t dim = 0; dim < rank; dim++) {
375     perm.emplace_back(dim);
376     if (expand_tile_dims_indices[dim] > -1) {
377       perm.emplace_back(expand_tile_dims_indices[dim] + rank);
378     }
379   }
380   auto transpose_sharding = hlo_sharding_util::TransposeSharding(
381       target_sharding.ReplicateOnLastTileDim()
382           ? HloSharding::PartialTile(reshape_tile_assignment)
383           : HloSharding::Tile(reshape_tile_assignment),
384       perm);
385 
386   // Reshape to target shape
387   auto transpose_tile_assignment = transpose_sharding.tile_assignment();
388   transpose_tile_assignment.Reshape(
389       target_sharding.tile_assignment().dimensions());
390 
391   bool groups_matching = true;
392   target_sharding.tile_assignment().Each(
393       [&](absl::Span<const int64> indices, int64_t device) {
394         if (device_to_replication_group[device] !=
395             device_to_replication_group[transpose_tile_assignment(indices)]) {
396           groups_matching = false;
397         }
398       });
399 
400   if (groups_matching) {
401     return target_sharding;
402   }
403   return target_sharding.ReplicateOnLastTileDim()
404              ? HloSharding::PartialTile(transpose_tile_assignment)
405              : HloSharding::Tile(transpose_tile_assignment);
406 }
407 
408 absl::optional<HloInstruction*> TileToPartialReplicateHaloExchange(
409     HloInstruction* hlo, const Shape& base_shape,
410     const HloSharding& src_sharding, const HloSharding& dst_sharding,
411     const std::vector<int64>& replicate_dims,
412     const SPMDCollectiveOpsCreator& collective_ops_creator,
413     int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
414   // Source is tile sharding.
415   auto padded_src_shape =
416       GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
417   // Target is partial replicate.
418   auto padded_dst_shape =
419       GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
420   if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
421     return hlo;
422   }
423 
424   auto partition_ordinals =
425       MakeTiledPartitionOrdinals(dst_sharding, partition_id, b);
426 
427   auto result = hlo;
428   auto hlo_shape = hlo->shape();
429   for (auto dim : replicate_dims) {
430     int64_t dst_shard_count = dst_sharding.tile_assignment().dim(dim);
431     int64_t src_per_shard_size =
432         padded_src_shape.dimensions(dim) / dst_shard_count;
433     // Calculate per shard size using the sharding to compare if dst_sharding
434     // needs more padding at the end.
435     int64_t dst_per_shard_size =
436         padded_dst_shape.dimensions(dim) / dst_shard_count;
437 
438     // If src per shard doesn't have redundant data.
439     if (src_per_shard_size <= dst_per_shard_size || dst_shard_count == 1) {
440       continue;
441     }
442 
443     // If src_per_shard * replicate_factor > dst_per_shard, we need to
444     // re-distribute the data between the shards using collective permute.
445     // For example, if the dimension size is 6 and it is sharded 4 ways in
446     // the src but 2 ways in the dst: 4-way sharding has 2 elements in each
447     // shard while 2-way sharding has 3, so the last element in the first
448     // shard would be sliced out; re-distribution is needed.
449     //
450     // 1. Calculate left_halo size.
451     // left-halo size is
452     //   (src_per_shard_size - dst_per_shard_size) * i / replicate_factor
453     int64_t replicate_factor = src_sharding.tile_assignment().dim(dim) /
454                                dst_sharding.tile_assignment().dim(dim);
455     OffsetCalculation left_halo_size_function =
456         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
457             src_per_shard_size - dst_per_shard_size, 0, replicate_factor));
458 
459     // 2. Calculate right_halo size.
460     // right-halo size is 0
461     OffsetCalculation right_halo_size_function =
462         OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));
463 
464     auto concat = result;
465     // 3. Halo exchange.
466     auto halo_exchange_result = ExchangeHalo(
467         result, left_halo_size_function, right_halo_size_function, dim,
468         src_sharding, collective_ops_creator, next_channel_id, b);
469 
470     if (halo_exchange_result.has_value()) {
471       concat = halo_exchange_result.value();
472     } else {
473       return absl::nullopt;
474     }
475 
476     // 4. Slice the valid result.
477     // Slice offset is
478     // (dst_shard_count - i - 1) *
479     // (src_per_shard_size - dst_per_shard_size)
480     // i is the index in dst_sharding.
481     auto zero_s32 = b->AddInstruction(
482         HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
483     OffsetCalculation start_offset_on_padded_concat_calculation =
484         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
485             dst_per_shard_size - src_per_shard_size,
486             (src_per_shard_size - dst_per_shard_size) * (dst_shard_count - 1),
487             1));
488     auto slice_shape = concat->shape();
489     slice_shape.set_dimensions(dim,
490                                padded_src_shape.dimensions(dim) /
491                                    src_sharding.tile_assignment().dim(dim));
492     std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
493                                                zero_s32);
494     slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
495         partition_ordinals[dim], b);
496     result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
497         slice_shape, concat, slice_offsets, slice_shape.dimensions()));
498   }
499   return result;
500 }
501 
502 absl::optional<HloInstruction*> PadFromPartialReplicateShape(
503     HloInstruction* hlo, const Shape& base_shape,
504     const HloSharding& src_sharding, const HloSharding& dst_sharding,
505     const std::vector<int64>& expand_tile_dims,
506     const SPMDCollectiveOpsCreator& collective_ops_creator,
507     int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) {
508   auto padded_src_shape =
509       GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding);
510   auto padded_dst_shape =
511       GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding);
512   if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) {
513     return hlo;
514   }
515 
516   auto partition_ordinals =
517       MakeTiledPartitionOrdinals(src_sharding, partition_id, b);
518 
519   HloInstruction* result = hlo;
520   auto zero = b->AddInstruction(HloInstruction::CreateConstant(
521       LiteralUtil::Zero(hlo->shape().element_type())));
522   std::vector<int64> expand_dims_without_halo_exchange;
523   // Pad the dimensions that need halo exchange and record the padded dims
524   // that won't need halo exchange.
525   for (auto dim : expand_tile_dims) {
526     int64_t src_shard_count = src_sharding.tile_assignment().dim(dim);
527     int64_t src_per_shard_size =
528         padded_src_shape.dimensions(dim) / src_shard_count;
529     // Calculate per shard size using the sharding to compare if dst_sharding
530     // needs more padding at the end.
531     int64_t dst_per_shard_size =
532         padded_dst_shape.dimensions(dim) / src_shard_count;
533 
534     // If dst_sharding doesn't need more padding at the end.
535     if (src_per_shard_size >= dst_per_shard_size) {
536       continue;
537     }
538     // If src sharding at this dimension is not partitioned, simply pad to
539     // the desired shape.
540     if (src_shard_count == 1) {
541       expand_dims_without_halo_exchange.emplace_back(dim);
542       continue;
543     }
544 
545     // If dst_sharding needs more padding at the end, we need to
546     // re-distribute the data between the shards using collective permute.
547     // For example, if the dimension size is 6 and it is sharded 2 ways in
548     // the src but 4 ways in the dst: 4-way sharding needs 2 zeros of
549     // padding at the end and has 2 elements in each shard, while 2-way
550     // sharding has 3 elements in each shard, so re-distribution is needed.
551     //
552     // 1. Calculate left_halo size.
553     // left-halo size is 0
554     OffsetCalculation left_halo_size_function =
555         OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1));
556 
557     // 2. Calculate right_halo size.
558     // right-halo size is D * (i + 1) - S * (i + 1) = (D - S) * i + (D - S)
559     OffsetCalculation right_halo_size_function =
560         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
561             dst_per_shard_size - src_per_shard_size,
562             dst_per_shard_size - src_per_shard_size, 1));
563 
564     auto concat = result;
565     // 3. Halo exchange.
566     auto halo_exchange_result = ExchangeHalo(
567         result, left_halo_size_function, right_halo_size_function, dim,
568         src_sharding, collective_ops_creator, next_channel_id, b);
569 
570     if (halo_exchange_result.has_value()) {
571       concat = halo_exchange_result.value();
572     } else {
573       return absl::nullopt;
574     }
575 
576     // 4. Pad.
577     std::vector<int64> zero_padding(concat->shape().rank());
578     PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
579     pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
580     int64_t max_right_halo_size =
581         right_halo_size_function.MaxInRange(0, src_shard_count - 1);
582     pad_config.mutable_dimensions(dim)->set_edge_padding_high(std::max(
583         int64{0}, padded_dst_shape.dimensions(dim) -
584                       padded_src_shape.dimensions(dim) - max_right_halo_size));
585     auto padded_concat_shape = ShapeInference::InferPadShape(
586                                    concat->shape(), zero->shape(), pad_config)
587                                    .ValueOrDie();
588     concat = b->AddInstruction(HloInstruction::CreatePad(
589         padded_concat_shape, concat, zero, pad_config));
590 
591     // 5. Slice the valid result.
592     // Slice offset is (D-S) * i
593     auto zero_s32 = b->AddInstruction(
594         HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
595     OffsetCalculation start_offset_on_padded_concat_calculation =
596         OffsetCalculation(MultiplyAddDivideOffsetCalculation(
597             dst_per_shard_size - src_per_shard_size, 0, 1));
598     auto slice_shape = concat->shape();
599     slice_shape.set_dimensions(dim, dst_per_shard_size);
600     std::vector<HloInstruction*> slice_offsets(concat->shape().rank(),
601                                                zero_s32);
602     slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
603         partition_ordinals[dim], b);
604     result = b->AddInstruction(HloInstruction::CreateDynamicSlice(
605         slice_shape, concat, slice_offsets, slice_shape.dimensions()));
606   }
607 
608   // Pad other dimensions that won't need halo exchange with a single pad.
609   if (!expand_dims_without_halo_exchange.empty()) {
610     std::vector<int64> zero_padding(result->shape().rank());
611     PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
612 
613     auto padded_shape = result->shape();
614     for (auto dim : expand_dims_without_halo_exchange) {
615       pad_config.mutable_dimensions(dim)->set_edge_padding_low(0);
616       pad_config.mutable_dimensions(dim)->set_edge_padding_high(
617           padded_dst_shape.dimensions(dim) - padded_src_shape.dimensions(dim));
618       padded_shape.set_dimensions(dim, result->shape().dimensions(dim) +
619                                            padded_dst_shape.dimensions(dim) -
620                                            padded_src_shape.dimensions(dim));
621     }
622     result = b->AddInstruction(
623         HloInstruction::CreatePad(padded_shape, result, zero, pad_config));
624   }
625 
626   return result;
627 }
628 
629 absl::optional<int64> UniqueTiledDim(const HloSharding& sharding) {
630   if (sharding.IsTileMaximal()) {
631     return absl::nullopt;
632   }
633   int64_t dim = -1;
634   int64_t rank = sharding.ReplicateOnLastTileDim()
635                      ? sharding.tile_assignment().num_dimensions() - 1
636                      : sharding.tile_assignment().num_dimensions();
637   for (int64_t i = 0; i < rank; ++i) {
638     if (sharding.tile_assignment().dim(i) > 1) {
639       if (dim != -1) {
640         return absl::nullopt;
641       }
642       dim = i;
643     }
644   }
645   CHECK_NE(dim, -1);
646   return dim;
647 }
648 
649 MultiplyAddDivideOffsetCalculation::MultiplyAddDivideOffsetCalculation(
650     int64_t multiplier, int64_t offset, int64_t divisor)
651     : multiplier_(multiplier), offset_(offset), divisor_(divisor) {
652   CHECK_GT(divisor_, 0);
653   Simplify();
654 }
655 
656 OffsetCalculation MultiplyAddDivideOffsetCalculation::operator-(
657     const MultiplyAddDivideOffsetCalculation& other) const {
658   if (divisor_ == 1 && other.divisor_ == 1) {
659     return OffsetCalculation(MultiplyAddDivideOffsetCalculation(
660         multiplier_ - other.multiplier_, offset_ - other.offset_, 1));
661   }
662   return OffsetCalculation(HloOpcode::kSubtract, *this, other);
663 }
664 
665 void MultiplyAddDivideOffsetCalculation::Simplify() {
666   // We could simplify the calculation when multiplier is a multiple of
667   // divisor_. However, when offset_ is not a multiple of divisor_, we must
668   // make sure that offset_ and multiplier_ are both non-negative or both
669   // non-positive. E.g., (3 * i  - 1) / 3 is not equivalent to i or i - 1.
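  // For instance, with multiplier_ = 3, offset_ = -1, divisor_ = 3, truncating
  // division yields 0, 0, 1 for i = 0, 1, 2, which matches neither i
  // (0, 1, 2) nor i - 1 (-1, 0, 1).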
670   if (divisor_ != 1 && multiplier_ % divisor_ == 0 &&
671       (offset_ % divisor_ == 0 || offset_ * multiplier_ > 0)) {
672     multiplier_ /= divisor_;
673     offset_ /= divisor_;
674     divisor_ = 1;
675   }
676 }
677 
678 int64 MultiplyAddDivideOffsetCalculation::Calculate(
679     int64_t shard_ordinal) const {
680   return (shard_ordinal * multiplier_ + offset_) / divisor_;
681 }
682 
683 HloInstruction* MultiplyAddDivideOffsetCalculation::Calculate(
684     HloInstruction* shard_ordinal, SpmdBuilder* b) const {
685   auto scalar_shape = ShapeUtil::MakeShape(S32, {});
686   if (multiplier_ == 0) {
687     return b->AddInstruction(HloInstruction::CreateConstant(
688         LiteralUtil::CreateR0<int32>(offset_ / divisor_)));
689   }
690   HloInstruction* result = shard_ordinal;
691   if (multiplier_ != 1) {
692     result = b->AddInstruction(HloInstruction::CreateBinary(
693         scalar_shape, HloOpcode::kMultiply, shard_ordinal,
694         b->AddInstruction(HloInstruction::CreateConstant(
695             LiteralUtil::CreateR0<int32>(multiplier_)))));
696   }
697   if (offset_ != 0) {
698     auto offset = b->AddInstruction(
699         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(offset_)));
700     result = b->AddInstruction(HloInstruction::CreateBinary(
701         scalar_shape, HloOpcode::kAdd, result, offset));
702   }
703   if (divisor_ != 1) {
704     auto divisor = b->AddInstruction(
705         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(divisor_)));
706     result = b->AddInstruction(HloInstruction::CreateBinary(
707         scalar_shape, HloOpcode::kDivide, result, divisor));
708   }
709   return result;
710 }
711 
712 int64 MultiplyAddDivideOffsetCalculation::MaxInRange(
713     int64_t start_ordinal, int64_t limit_ordinal) const {
714   int64_t max = Calculate(start_ordinal);
715   for (int64_t i = start_ordinal + 1; i < limit_ordinal; ++i) {
716     max = std::max(max, Calculate(i));
717   }
718   return max;
719 }
720 
721 OffsetCalculation& OffsetCalculation::operator=(
722     const OffsetCalculation& other) {
723   opcode_ = other.opcode_;
724   copy_from_ = other.copy_from_;
725   if (opcode_ != HloOpcode::kCopy) {
726     lhs_ = absl::make_unique<OffsetCalculation>(*other.lhs_);
727     rhs_ = absl::make_unique<OffsetCalculation>(*other.rhs_);
728   }
729   return *this;
730 }
731 
732 bool OffsetCalculation::IsConstant() const {
733   if (opcode_ == HloOpcode::kCopy) {
734     return copy_from_.IsConstant();
735   }
736   if (opcode_ == HloOpcode::kSubtract && *lhs_ == *rhs_) {
737     return true;
738   }
739   return lhs_->IsConstant() && rhs_->IsConstant();
740 }
741 
742 OffsetCalculation OffsetCalculation::operator-(
743     const OffsetCalculation& other) const {
744   if (opcode_ == HloOpcode::kCopy && other.opcode_ == HloOpcode::kCopy) {
745     return copy_from_ - other.copy_from_;
746   }
747   return OffsetCalculation(HloOpcode::kSubtract, *this, other);
748 }
749 
750 bool OffsetCalculation::operator==(const OffsetCalculation& other) const {
751   if (opcode_ != other.opcode_) {
752     return false;
753   }
754   if (opcode_ == HloOpcode::kCopy) {
755     return copy_from_ == other.copy_from_;
756   }
757   return *lhs_ == *other.lhs_ && *rhs_ == *other.rhs_;
758 }
759 
760 int64 OffsetCalculation::Calculate(int64_t shard_ordinal) const {
761   switch (opcode_) {
762     case HloOpcode::kCopy:
763       return copy_from_.Calculate(shard_ordinal);
764     case HloOpcode::kSubtract:
765       return lhs_->Calculate(shard_ordinal) - rhs_->Calculate(shard_ordinal);
766     case HloOpcode::kMultiply:
767       return lhs_->Calculate(shard_ordinal) * rhs_->Calculate(shard_ordinal);
768     default:
769       LOG(FATAL) << "Should not happen";
770   }
771 }
772 
773 HloInstruction* OffsetCalculation::Calculate(HloInstruction* shard_ordinal,
774                                              SpmdBuilder* b) const {
775   if (opcode_ == HloOpcode::kCopy) {
776     return copy_from_.Calculate(shard_ordinal, b);
777   }
778   auto lhs = lhs_->Calculate(shard_ordinal, b);
779   auto rhs = rhs_->Calculate(shard_ordinal, b);
780   return b->AddInstruction(
781       HloInstruction::CreateBinary(lhs->shape(), opcode_, lhs, rhs));
782 }
783 
784 int64 OffsetCalculation::MaxInRange(int64_t start_ordinal,
785                                     int64_t limit_ordinal) const {
786   if (IsConstant()) {
787     return Calculate(start_ordinal);
788   }
789   if (opcode_ == HloOpcode::kCopy) {
790     return std::max(Calculate(start_ordinal), Calculate(limit_ordinal - 1));
791   }
792   int64_t max = Calculate(start_ordinal);
793   for (int64_t i = start_ordinal + 1; i < limit_ordinal; ++i) {
794     max = std::max(max, Calculate(i));
795   }
796   return max;
797 }
798 
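// Concatenates, along `dim`, slices received from lower- and higher-indexed
// neighbor shards (left/right halos, sized per shard by the two offset
// functions) around the locally valid region, using collective permutes.
// Returns nullopt when the requested halos are too large to exchange this way.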
799 absl::optional<HloInstruction*> ExchangeHalo(
800     HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
801     const OffsetCalculation& right_halo_size_function, int64_t dim,
802     const HloSharding& target,
803     const SPMDCollectiveOpsCreator& collective_ops_creator,
804     int64* next_channel_id, SpmdBuilder* b) {
805   int64_t input_shard_size = hlo->shape().dimensions(dim);
806   int64_t shard_count = target.tile_assignment().dim(dim);
807 
808   std::vector<HloInstruction*> concat_pieces;
809 
810   int64_t max_left_halo_size =
811       left_halo_size_function.MaxInRange(1, shard_count);
812   int64_t max_right_halo_size =
813       right_halo_size_function.MaxInRange(0, shard_count - 1);
814   if (max_left_halo_size + max_right_halo_size + input_shard_size >=
815           input_shard_size * shard_count &&
816       (max_left_halo_size > input_shard_size ||
817        max_right_halo_size > input_shard_size)) {
818     return absl::nullopt;
819   }
820   // Since max halo sizes could be negative, we only need to include data within
821   // certain bounds. Useful region is [left_bound, right_bound).
822   const int64_t left_bound =
823       -left_halo_size_function.MaxInRange(0, shard_count);
824   const int64_t right_bound =
825       input_shard_size + right_halo_size_function.MaxInRange(0, shard_count);
826   if (left_bound >= right_bound) {
827     return absl::nullopt;
828   }
829   // Left halo.
830   for (int64_t i = CeilOfRatio(max_left_halo_size, input_shard_size) - 1;
831        i >= 0 && (-i - 1) * input_shard_size < right_bound; --i) {
832     std::vector<std::pair<int64, int64>> source_target_pairs;
833     target.tile_assignment().Each(
834         [&](absl::Span<const int64> indices, int64_t device) {
835           if (indices[dim] > i) {
836             std::vector<int64> source_indices(indices.begin(), indices.end());
837             source_indices[dim] -= i + 1;
838             source_target_pairs.emplace_back(
839                 target.tile_assignment()(source_indices), device);
840           }
841         });
842     int64_t halo_size_including_skips =
843         std::min(max_left_halo_size - input_shard_size * i, input_shard_size);
844     int64_t halo_right_skips =
845         std::max<int64>(-i * input_shard_size - right_bound, 0);
846     int64_t halo_size = halo_size_including_skips - halo_right_skips;
847     auto halo_shape = hlo->shape();
848     auto source_halo_slice = hlo;
849     if (halo_size != hlo->shape().dimensions(dim)) {
850       halo_shape.set_dimensions(dim, halo_size);
851       std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
852       halo_start_indices[dim] =
853           hlo->shape().dimensions(dim) - halo_size_including_skips;
854       std::vector<int64> halo_limit_indices(hlo->shape().dimensions().begin(),
855                                             hlo->shape().dimensions().end());
856       halo_limit_indices[dim] -= halo_right_skips;
857       std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
858       source_halo_slice = b->AddInstruction(
859           HloInstruction::CreateSlice(halo_shape, hlo, halo_start_indices,
860                                       halo_limit_indices, halo_slice_strides));
861     }
862     auto left_halo =
863         collective_ops_creator.create_cross_partition_collective_permute(
864             b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
865     concat_pieces.push_back(left_halo);
866   }
867 
868   if (left_bound < input_shard_size && right_bound > 0) {
869     int64_t self_start = std::max<int64>(0, left_bound);
870     int64_t self_limit = std::min<int64>(input_shard_size, right_bound);
871     if (self_start == 0 && self_limit == input_shard_size) {
872       concat_pieces.push_back(hlo);
873     } else {
874       auto self_shape = hlo->shape();
875       self_shape.set_dimensions(dim, self_limit - self_start);
876       std::vector<int64> start_indices(self_shape.rank(), 0);
877       start_indices[dim] = self_start;
878       std::vector<int64> limit_indices(hlo->shape().dimensions().begin(),
879                                        hlo->shape().dimensions().end());
880       limit_indices[dim] = self_limit;
881       std::vector<int64> slice_strides(self_shape.rank(), 1);
882       concat_pieces.push_back(b->AddInstruction(HloInstruction::CreateSlice(
883           self_shape, hlo, start_indices, limit_indices, slice_strides)));
884     }
885   }
886 
887   int64_t skipped_right_halos =
888       std::min<int64>(std::max<int64>(left_bound - input_shard_size, 0),
889                       std::max<int64>(max_right_halo_size, 0)) /
890       input_shard_size;
891   // Right halo.
892   for (int64_t i = skipped_right_halos;
893        i < CeilOfRatio(max_right_halo_size, input_shard_size); ++i) {
894     std::vector<std::pair<int64, int64>> source_target_pairs;
895     target.tile_assignment().Each(
896         [&](absl::Span<const int64> indices, int64_t device) {
897           if (indices[dim] > i) {
898             std::vector<int64> target_indices(indices.begin(), indices.end());
899             target_indices[dim] -= i + 1;
900             source_target_pairs.emplace_back(
901                 device, target.tile_assignment()(target_indices));
902           }
903         });
904     int64_t halo_size_including_skips =
905         std::min(max_right_halo_size - input_shard_size * i, input_shard_size);
906     int64_t halo_left_skips =
907         std::max<int64>(left_bound - (i + 1) * input_shard_size, 0);
908     int64_t halo_size = halo_size_including_skips - halo_left_skips;
909     auto halo_shape = hlo->shape();
910     HloInstruction* source_halo_slice = hlo;
911     if (halo_size != halo_shape.dimensions(dim)) {
912       halo_shape.set_dimensions(dim, halo_size);
913       std::vector<int64> halo_start_indices(halo_shape.rank(), 0);
914       halo_start_indices[dim] = halo_left_skips;
915       std::vector<int64> halo_limit_indices(halo_shape.dimensions().begin(),
916                                             halo_shape.dimensions().end());
917       halo_limit_indices[dim] += halo_left_skips;
918       std::vector<int64> halo_slice_strides(halo_shape.rank(), 1);
919       source_halo_slice = b->AddInstruction(
920           HloInstruction::CreateSlice(halo_shape, hlo, halo_start_indices,
921                                       halo_limit_indices, halo_slice_strides));
922     }
923     auto right_halo =
924         collective_ops_creator.create_cross_partition_collective_permute(
925             b, source_halo_slice, source_target_pairs, (*next_channel_id)++);
926     concat_pieces.push_back(right_halo);
927   }
928 
929   auto concat = concat_pieces[0];
930   // Concat with halos/padding.
931   if (concat_pieces.size() > 1) {
932     auto concat_shape = hlo->shape();
933     int64_t concat_dim_size = 0;
934     for (auto piece : concat_pieces) {
935       concat_dim_size += piece->shape().dimensions(dim);
936     }
937     concat_shape.set_dimensions(dim, concat_dim_size);
938     concat = b->AddInstruction(
939         HloInstruction::CreateConcatenate(concat_shape, concat_pieces, dim));
940   }
941 
942   return concat;
943 }
944 
945 absl::optional<HloInstruction*> ExchangeHalo(
946     HloInstruction* hlo,
947     std::vector<OffsetCalculation> left_halo_size_functions,
948     std::vector<OffsetCalculation> right_halo_size_functions,
949     const HloSharding& target,
950     const SPMDCollectiveOpsCreator& collective_ops_creator,
951     int64* next_channel_id, SpmdBuilder* b) {
952   CHECK(left_halo_size_functions.size() == hlo->shape().rank());
953   CHECK(right_halo_size_functions.size() == hlo->shape().rank());
954 
955   HloInstruction* visiting_hlo = hlo;
956   for (int dim = 0; dim < hlo->shape().rank(); ++dim) {
957     auto concat = ExchangeHalo(visiting_hlo, left_halo_size_functions[dim],
958                                right_halo_size_functions[dim], dim, target,
959                                collective_ops_creator, next_channel_id, b);
960     if (!concat) {
961       return absl::nullopt;
962     }
963     visiting_hlo = *concat;
964   }
965   return visiting_hlo;
966 }
967 
968 absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
969     HloInstruction* hlo, const Shape& base_shape,
970     const OffsetCalculation& left_halo_size_function,
971     const OffsetCalculation& right_halo_size_function,
972     int64_t explicit_left_padding_on_full_shape, int64_t padded_full_shape_size,
973     int64_t shard_size_with_halo, int64_t dim, const HloSharding& target,
974     HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
975     HloInstruction* partition_ordinal,
976     const SPMDCollectiveOpsCreator& collective_ops_creator,
977     int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region) {
978   auto halo_exchange_result =
979       ExchangeHalo(hlo, left_halo_size_function, right_halo_size_function, dim,
980                    target, collective_ops_creator, next_channel_id, b);
981   if (!halo_exchange_result) {
982     return absl::nullopt;
983   }
984   auto concat = *halo_exchange_result;
985   int64_t shard_count = target.tile_assignment().dim(dim);
986   int64_t max_left_halo_size =
987       left_halo_size_function.MaxInRange(1, shard_count);
988 
989   // Now we determine if we need extra padding after the concat.
990   //
991   // The max of halo size or the first shard's explicit left padding.
992   int64_t max_left_halo_or_padding_size =
993       std::max(max_left_halo_size, explicit_left_padding_on_full_shape);
994   // The calculation that returns the dynamic slice index for a shard on the
995   // padded concat, which is the difference between
996   // max_left_halo_or_padding_size and its left halo size.
997   auto start_offset_on_padded_concat_calculation =
998       OffsetCalculation(MultiplyAddDivideOffsetCalculation(
999           0, max_left_halo_or_padding_size, 1)) -
1000       left_halo_size_function;
1001 
1002   // See if we need to pad the concat before dynamic slice.
1003   int64_t extra_left_padding =
1004       std::max(int64{0}, max_left_halo_or_padding_size -
1005                              std::max(int64{0}, max_left_halo_size));
1006   int64_t extra_right_padding =
1007       start_offset_on_padded_concat_calculation.MaxInRange(0, shard_count) +
1008       shard_size_with_halo - concat->shape().dimensions(dim) -
1009       extra_left_padding;
1010   extra_right_padding = std::max(int64{0}, extra_right_padding);
1011   if (extra_left_padding > 0 || extra_right_padding > 0) {
1012     PaddingConfig padding_config;
1013     auto padded_concat_shape = concat->shape();
1014     for (int64_t i = 0; i < base_shape.rank(); ++i) {
1015       auto padding_config_dim = padding_config.add_dimensions();
1016       padding_config_dim->set_interior_padding(0);
1017       padding_config_dim->set_edge_padding_low(0);
1018       padding_config_dim->set_edge_padding_high(0);
1019       if (i != dim) {
1020         continue;
1021       }
1022       padding_config_dim->set_edge_padding_low(extra_left_padding);
1023       padding_config_dim->set_edge_padding_high(extra_right_padding);
1024       padded_concat_shape.set_dimensions(dim, concat->shape().dimensions(dim) +
1025                                                   extra_left_padding +
1026                                                   extra_right_padding);
1027     }
1028     concat = b->AddInstruction(HloInstruction::CreatePad(
1029         padded_concat_shape, concat, pad_value, padding_config));
1030   }
1031 
1032   auto valid_slice = concat;
1033   if (shard_size_with_halo != concat->shape().dimensions(dim)) {
1034     // Concat is bigger than the shard shape, so we need a dynamic slice.
1035     CHECK_LT(shard_size_with_halo, concat->shape().dimensions(dim));
1036     auto slice_shape = concat->shape();
1037     slice_shape.set_dimensions(dim, shard_size_with_halo);
1038 
1039     if (left_halo_size_function.IsConstant() &&
1040         left_halo_size_function.Calculate(0) ==
1041             explicit_left_padding_on_full_shape) {
1042       std::vector<int64> start_indices(slice_shape.rank(), 0);
1043       std::vector<int64> strides(slice_shape.rank(), 1);
1044       valid_slice = b->AddInstruction(
1045           HloInstruction::CreateSlice(slice_shape, concat, start_indices,
1046                                       slice_shape.dimensions(), strides));
1047     } else {
1048       auto zero = b->AddInstruction(
1049           HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
1050       std::vector<HloInstruction*> slice_offsets(base_shape.rank(), zero);
1051       slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate(
1052           partition_ordinal, b);
1053       valid_slice = b->AddInstruction(HloInstruction::CreateDynamicSlice(
1054           slice_shape, concat, slice_offsets, slice_shape.dimensions()));
1055     }
1056   }
1057 
1058   if (!mask_invalid_region) {
1059     return valid_slice;
1060   }
1061 
1062   int64_t total_right_padding = padded_full_shape_size -
1063                                 base_shape.dimensions(dim) -
1064                                 explicit_left_padding_on_full_shape;
1065   // Mask off garbage data due to uneven partition or low/high padding.
1066   if (explicit_left_padding_on_full_shape > 0 || total_right_padding > 0) {
1067     auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32);
1068     auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim));
1069     auto broadcast_start_index_in_padded_shape =
1070         b->AddInstruction(HloInstruction::CreateBroadcast(
1071             index_shape, offset_on_padded_shape, {}));
1072     auto index_in_padded_shape = b->AddInstruction(
1073         HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota,
1074                                      broadcast_start_index_in_padded_shape));
1075     auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED);
1076     std::vector<HloInstruction*> predicates;
1077     if (explicit_left_padding_on_full_shape > 0) {
1078       auto valid_index_start =
1079           b->AddInstruction(HloInstruction::CreateBroadcast(
1080               index_shape,
1081               b->AddInstruction(
1082                   HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1083                       explicit_left_padding_on_full_shape))),
1084               {}));
1085       predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
1086           mask_shape, index_in_padded_shape, valid_index_start,
1087           ComparisonDirection::kGe)));
1088     }
1089     if (total_right_padding > 0) {
1090       auto valid_index_limit =
1091           b->AddInstruction(HloInstruction::CreateBroadcast(
1092               index_shape,
1093               b->AddInstruction(
1094                   HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1095                       base_shape.dimensions(dim) +
1096                       explicit_left_padding_on_full_shape))),
1097               {}));
1098       predicates.push_back(b->AddInstruction(HloInstruction::CreateCompare(
1099           mask_shape, index_in_padded_shape, valid_index_limit,
1100           ComparisonDirection::kLt)));
1101     }
1102     CHECK(!predicates.empty());
1103     auto is_valid =
1104         predicates.size() == 2
1105             ? b->AddInstruction(HloInstruction::CreateBinary(
1106                   mask_shape, HloOpcode::kAnd, predicates[0], predicates[1]))
1107             : predicates[0];
1108     auto masking_value = b->AddInstruction(
1109         HloInstruction::CreateBroadcast(valid_slice->shape(), pad_value, {}));
1110     valid_slice = b->AddInstruction(
1111         HloInstruction::CreateTernary(valid_slice->shape(), HloOpcode::kSelect,
1112                                       is_valid, valid_slice, masking_value));
1113   }
1114   return valid_slice;
1115 }
1116 
1117 HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
1118                                         absl::Span<const int64> dims) {
1119   if (original.sharding().IsTileMaximal()) {
1120     return original.hlo();
1121   }
1122   // Create a window config to halo exchange for unevenly partitioned reverse
1123   // dimensions.
1124   Window window;
1125   for (int64_t i = 0; i < original.base_shape().rank(); ++i) {
1126     WindowDimension* dim = window.add_dimensions();
1127     dim->set_size(1);
1128     dim->set_stride(1);
1129     dim->set_window_dilation(1);
1130     dim->set_window_reversal(false);
1131     int64_t low_padding = 0;
1132     if (absl::c_linear_search(dims, i)) {
1133       low_padding =
1134           RoundUpToNearest(original.base_shape().dimensions(i),
1135                            original.sharding().tile_assignment().dim(i)) -
1136           original.base_shape().dimensions(i);
1137     }
1138     dim->set_padding_low(low_padding);
1139     dim->set_padding_high(0);
1140     dim->set_base_dilation(1);
1141   }
1142 
1143   auto reshard_window = original.ReshardAsWindowedInput(
1144       window, original.sharding(),
1145       CreateZero(ShapeUtil::MakeShape(original.base_shape().element_type(), {}),
1146                  original.state().b),
1147       /*mask_invalid_region=*/false);
1148   if (!reshard_window.has_value()) {
1149     return nullptr;
1150   }
1151   CHECK(!reshard_window->dynamic_slice_index_on_output.has_value());
1152   return reshard_window->sharded_input;
1153 }
1154 
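// Matches the NaN-safe "greater than" comparator used by TopK-style sorts:
// each F32 (or BF16-converted) key is bitcast to an integer, negative values
// are remapped so the integer order is total, and the two keys are compared
// with kGt, possibly under a kSelect root that breaks ties on indices.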
1155 bool IsNanSafeGt(HloComputation* comp) {
1156   namespace m = match;
1157   auto match_bitcast_f32 = [](int64_t parameter_number) {
1158     auto param = m::Parameter(parameter_number)
1159                      .WithShape(m::Shape().WithElementType(F32));
1160     auto param_s32 =
1161         m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
1162     auto param_u32 =
1163         m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
1164     return m::Select(
1165         m::Lt(param_s32, m::ConstantScalar(0)),
1166         m::BitcastConvert(
1167             m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
1168                         param_u32))
1169             .WithShape(m::Shape().WithElementType(S32)),
1170         param_s32);
1171   };
1172   auto match_bitcast_bf16 = [](int64_t parameter_number) {
1173     auto param = m::Convert(m::Parameter(parameter_number)
1174                                 .WithShape(m::Shape().WithElementType(BF16)))
1175                      .WithShape(m::Shape().WithElementType(F32));
1176     auto param_s32 =
1177         m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32));
1178     auto param_u32 =
1179         m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32));
1180     return m::Select(
1181         m::Lt(param_s32, m::ConstantScalar(0)),
1182         m::BitcastConvert(
1183             m::Subtract(m::ConstantScalar(std::numeric_limits<int32>::max()),
1184                         param_u32))
1185             .WithShape(m::Shape().WithElementType(S32)),
1186         param_s32);
1187   };
1188   // If the root is a kSelect, it compares indices when values are equal; match the value comparison in its operand 2.
1189   if (comp->root_instruction()->opcode() == HloOpcode::kSelect) {
1190     return Match(comp->root_instruction()->operand(2),
1191                  m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
1192            Match(comp->root_instruction()->operand(2),
1193                  m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
1194   }
1195   return Match(comp->root_instruction(),
1196                m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) ||
1197          Match(comp->root_instruction(),
1198                m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1)));
1199 }
1200 
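// Detects the sort+slice pattern produced by TopK and returns k when the sort
// dimension is the only partitioned dimension and k is smaller than the
// per-partition size; returns nullopt otherwise.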
1201 absl::optional<int64> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo) {
1202   HloSortInstruction* sort = DynCast<HloSortInstruction>(hlo);
1203   if (sort == nullptr || sort->operand_count() != 2) {
1204     return absl::nullopt;
1205   }
1206   if (!IsNanSafeGt(sort->to_apply())) {
1207     return absl::nullopt;
1208   }
1209   HloInstruction* data = sort->mutable_operand(0);
1210   HloIotaInstruction* iota =
1211       DynCast<HloIotaInstruction>(sort->mutable_operand(1));
1212   const PrimitiveType element_type = data->shape().element_type();
1213   if (iota == nullptr || iota->shape().element_type() != S32 ||
1214       iota->opcode() != HloOpcode::kIota ||
1215       iota->iota_dimension() != sort->sort_dimension()) {
1216     return absl::nullopt;
1217   }
1218 
1219   const int64_t sort_dim = sort->sort_dimension();
1220 
1221   if (element_type != F32 && element_type != BF16 && element_type != S32 &&
1222       element_type != U32) {
1223     return absl::nullopt;
1224   }
1225 
1226   bool supported = true;
1227   absl::optional<int64> k;
1228   for (HloInstruction* gte : sort->users()) {
1229     if (gte->opcode() != HloOpcode::kGetTupleElement) {
1230       supported = false;
1231       break;
1232     }
1233 
1234     const HloInstruction* slice = gte->users()[0];
1235     if (slice->opcode() != HloOpcode::kSlice) {
1236       // A non-slice user means this is not a TopK pattern.
1237       supported = false;
1238       break;
1239     }
1240     if (absl::c_any_of(slice->slice_starts(), [](int x) { return x != 0; }) ||
1241         absl::c_any_of(slice->slice_strides(), [](int x) { return x != 1; })) {
1242       // Strided slices or slices that do not start at the beginning aren't supported.
1243       supported = false;
1244       break;
1245     }
1246     for (int64_t dim = 0; dim < data->shape().dimensions_size(); dim++) {
1247       if (dim == sort_dim) {
1248         continue;
1249       }
1250       if (slice->slice_limits(dim) !=
1251           slice->operand(0)->shape().dimensions(dim)) {
1252         // Slicing along dimensions other than the sort dimension isn't supported.
1253         supported = false;
1254         break;
1255       }
1256     }
1257     if (!k.has_value()) {
1258       k = slice->slice_limits(sort_dim);
1259     } else if (k != slice->slice_limits(sort_dim)) {
1260       // Different k for the different operands isn't supported.
1261       supported = false;
1262       break;
1263     }
1264   }
1265   if (k == absl::nullopt || !supported) {
1266     return absl::nullopt;
1267   }
1268 
1269   // Only supported when the sort dimension is sharded.
1270   if (!data->has_sharding()) {
1271     return absl::nullopt;
1272   }
1273   const HloSharding& sharding = sort->operand(0)->sharding();
1274 
1275   if (sharding.IsTileMaximal()) {
1276     return absl::nullopt;
1277   }
1278 
1279   // Check if partitioned at sort dimension.
1280   for (int64_t dim = 0; dim < sort->shape().tuple_shapes(0).dimensions_size();
1281        ++dim) {
1282     if (sharding.tile_assignment().dim(dim) > 1) {
1283       if (dim != sort_dim) {
1284         return absl::nullopt;
1285       }
1286     }
1287   }
1288 
1289   // Only supported when k is smaller than the per-partition size.
1290   const int64_t shard_count = sharding.tile_assignment().dim(sort_dim);
1291 
1292   if (shard_count <= 1) {
1293     return absl::nullopt;
1294   }
1295 
1296   const int64_t input_size = hlo->operand(0)->shape().dimensions(sort_dim);
1297   const int64_t per_partition_size = CeilOfRatio(input_size, shard_count);
1298 
1299   if (k.value() >= per_partition_size) {
1300     return absl::nullopt;
1301   }
1302 
1303   return k;
1304 }
1305 
1306 // Slices the first k elements along slice_dim.
1307 HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
1308                             int64_t slice_dim, int64_t k) {
1309   const Shape& hlo_shape = hlo->shape();
1310   auto hlo_dims = hlo_shape.dimensions();
1311   std::vector<int64> start_indices(hlo_shape.dimensions_size(), 0);
1312   std::vector<int64> limit_indices(hlo_dims.begin(), hlo_dims.end());
1313   std::vector<int64> strides(hlo_shape.dimensions_size(), 1);
1314   limit_indices[slice_dim] = k;
1315   auto output_shape = hlo_shape;
1316   output_shape.set_dimensions(slice_dim, k);
1317   return builder->AddInstruction(HloInstruction::CreateSlice(
1318       output_shape, hlo, start_indices, limit_indices, strides));
1319 }
1320 
1321 // Returns the number of shards along the given dimension (1 if tile-maximal).
1322 int64 ShardCountAtDim(const HloSharding& sharding, int64_t dim) {
1323   if (sharding.IsTileMaximal()) {
1324     return 1;
1325   }
1326   return sharding.tile_assignment().dim(dim);
1327 }
1328 
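// Returns pairs of dimensions whose shard counts can be exchanged, one pair at
// a time, to turn `source`'s tiling into `target`'s via AllToAll; nullopt if
// the two tilings cannot be matched this way.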
1329 absl::optional<std::vector<std::pair<int64, int64>>>
1330 GetReshardAllToAllSourceTargetDims(const HloSharding& source,
1331                                    const HloSharding& target) {
1332   if (source.IsTileMaximal() || target.IsTileMaximal() ||
1333       source.tile_assignment().num_dimensions() !=
1334           target.tile_assignment().num_dimensions() ||
1335       source.NumTiles() != target.NumTiles()) {
1336     return absl::nullopt;
1337   }
1338   // Record a map from partition count to dimensions, for dimensions whose
1339   // partition counts differ between source and target.
1340   std::map<int64, std::vector<int64>> source_size_to_dim;
1341   std::map<int64, std::vector<int64>> target_size_to_dim;
1342   for (int64_t i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
1343     if (source.tile_assignment().dim(i) == target.tile_assignment().dim(i)) {
1344       continue;
1345     }
1346     source_size_to_dim[source.tile_assignment().dim(i)].push_back(i);
1347     target_size_to_dim[target.tile_assignment().dim(i)].push_back(i);
1348   }
1349   // In order to shard via AllToAll, source_size_to_dim and target_size_to_dim
1350   // must have the same distribution.
1351   if (source_size_to_dim.empty() ||
1352       source_size_to_dim.size() != target_size_to_dim.size()) {
1353     return absl::nullopt;
1354   }
1355   for (const auto& entry : source_size_to_dim) {
1356     auto target_it = target_size_to_dim.find(entry.first);
1357     if (target_it == target_size_to_dim.end() ||
1358         target_it->second.size() != entry.second.size()) {
1359       return absl::nullopt;
1360     }
1361   }
1362   std::vector<std::pair<int64, int64>> result;
1363   auto remove_entry = [](int64_t size, int64_t dim,
1364                          std::map<int64, std::vector<int64>>& size_to_dim) {
1365     size_to_dim[size].erase(
1366         std::remove_if(size_to_dim[size].begin(), size_to_dim[size].end(),
1367                        [dim](int64_t a) { return a == dim; }),
1368         size_to_dim[size].end());
1369     if (size_to_dim[size].empty()) {
1370       size_to_dim.erase(size);
1371     }
1372   };
1373   // Find one pair of dimensions to swap at a time.
1374   while (!source_size_to_dim.empty()) {
1375     int64_t source_size = source_size_to_dim.begin()->first;
1376     int64_t i = source_size_to_dim.begin()->second.back();
1377     int64_t target_i_size = target.tile_assignment().dim(i);
1378     if (target_i_size == source_size) {
1379       remove_entry(source_size, i, source_size_to_dim);
1380       remove_entry(source_size, i, target_size_to_dim);
1381       continue;
1382     }
1383     auto j_it = source_size_to_dim[target_i_size].begin();
1384     int64_t j = *j_it;
1385     if (source_size == 1) {
1386       // If possible, find a j where the target partition count is not one, so
1387       // that when we swap, the resulting size-1 dimension will still be useful
1388       // to other dimensions.
1389       while (target.tile_assignment().dim(j) == 1) {
1390         if (++j_it == source_size_to_dim[target_i_size].end()) {
1391           break;
1392         }
1393         j = *j_it;
1394       }
1395     } else if (target_i_size % source_size == 0) {
1396       // If possible, find a j where the target partition count is source_size,
1397       // so that we can do a single swap.
1398       while (target.tile_assignment().dim(j) != source_size) {
1399         if (++j_it == source_size_to_dim[target_i_size].end()) {
1400           break;
1401         }
1402         j = *j_it;
1403       }
1404     } else {
1405       return absl::nullopt;
1406     }
1407     result.emplace_back(j, i);
1408     remove_entry(target_i_size, i, target_size_to_dim);
1409     source_size_to_dim.begin()->second.back() = j;
1410     remove_entry(target_i_size, j, source_size_to_dim);
1411   }
1412   return result;
1413 }
1414 
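// Returns true if `source` can be resharded to `target` purely by permuting
// devices, i.e. the tile shapes match but the device assignment differs.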
1415 bool CanReshardWithCollectivePermute(const HloSharding& source,
1416                                      const HloSharding& target) {
1417   return !source.IsTileMaximal() && !target.IsTileMaximal() &&
1418          source.tile_assignment().dimensions() ==
1419              target.tile_assignment().dimensions() &&
1420          source.ReplicateOnLastTileDim() == target.ReplicateOnLastTileDim() &&
1421          source.tile_assignment() != target.tile_assignment();
1422 }
1423 
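// Overload of GroupShardingOnDims below that keeps a single shard of each
// grouped dimension inside every group.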
1424 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
1425                                     absl::Span<const int64> group_dims) {
1426   std::vector<int64> group_dim_shards(group_dims.size(), 1);
1427   return GroupShardingOnDims(sharding, group_dims, group_dim_shards);
1428 }
1429 
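// Groups `sharding` along `group_dims`: devices sharing the same tile index
// along those dimensions (keeping `group_dim_shards` shards per group) form a
// device group, and the returned sharding describes the tiling within a group.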
1430 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
1431                                     absl::Span<const int64> group_dims,
1432                                     absl::Span<const int64> group_dim_shards) {
1433   CHECK(!sharding.IsTileMaximal());
1434   std::vector<int64> grouped_tiling_dims =
1435       sharding.tile_assignment().dimensions();
1436   std::vector<int64> group_dim_sizes(group_dims.size());
1437   for (int64_t i = 0; i < group_dims.size(); ++i) {
1438     CHECK_EQ(grouped_tiling_dims[group_dims[i]] % group_dim_shards[i], 0);
1439     group_dim_sizes[i] =
1440         grouped_tiling_dims[group_dims[i]] / group_dim_shards[i];
1441     grouped_tiling_dims[group_dims[i]] = group_dim_shards[i];
1442   }
1443 
1444   std::vector<std::vector<int64>> device_groups(Product(group_dim_sizes));
1445   sharding.tile_assignment().Each([&](absl::Span<const int64> indices,
1446                                       int64_t device) {
1447     int64_t group_id = 0;
1448     for (int64_t i = 0; i < group_dims.size(); ++i) {
1449       group_id *=
1450           sharding.tile_assignment().dim(group_dims[i]) / group_dim_shards[i];
1451       group_id += indices[group_dims[i]] / group_dim_shards[i];
1452     }
1453     device_groups[group_id].push_back(device);
1454   });
1455   auto grouped = GroupedSharding(
1456       std::move(device_groups),
1457       std::vector<int64>(group_dims.begin(), group_dims.end()),
1458       std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(),
1459       HloSharding::Replicate());
1460   if (sharding.ReplicateOnLastTileDim()) {
1461     grouped.data_rank--;
1462   }
1463   if (Product(grouped_tiling_dims) == 1 ||
1464       (sharding.ReplicateOnLastTileDim() &&
1465        Product(grouped_tiling_dims) == grouped_tiling_dims.back())) {
1466     return grouped;
1467   }
1468   if ((sharding.ReplicateOnLastTileDim() || sharding.IsManualSubgroup()) &&
1469       grouped_tiling_dims.back() == 1) {
1470     grouped_tiling_dims.pop_back();
1471   }
1472   Array<int64> grouped_tiling(grouped_tiling_dims);
1473   grouped_tiling.FillIota(0);
1474   grouped.sharding = sharding.ReplicateOnLastTileDim() &&
1475                              grouped_tiling_dims.size() ==
1476                                  sharding.tile_assignment().num_dimensions()
1477                          ? HloSharding::PartialTile(grouped_tiling)
1478                          : HloSharding::Tile(grouped_tiling);
1479   return grouped;
1480 }
1481 
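// Reconstructs the original (ungrouped) HloSharding from a GroupedSharding by
// expanding the per-group tile assignment across all device groups.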
1482 HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) {
1483   std::vector<int64> tiling_dims;
1484   bool partial_sharding = false;
1485   auto grouped_tiling = grouped_sharding.sharding.tile_assignment();
1486   if (grouped_sharding.sharding.IsTileMaximal()) {
1487     tiling_dims = std::vector<int64>(grouped_sharding.data_rank, 1);
1488     if (grouped_sharding.device_groups[0].size() != 1) {
1489       // This is partial sharding.
1490       tiling_dims.push_back(grouped_sharding.device_groups[0].size());
1491       partial_sharding = true;
1492     }
1493     grouped_tiling = Array<int64>(tiling_dims);
1494     grouped_tiling.FillIota(0);
1495   } else {
1496     partial_sharding = grouped_sharding.sharding.ReplicateOnLastTileDim();
1497     tiling_dims = grouped_sharding.sharding.tile_assignment().dimensions();
1498     if (absl::c_linear_search(grouped_sharding.group_dims,
1499                               tiling_dims.size())) {
1500       tiling_dims.push_back(1);
1501       grouped_tiling.Reshape(tiling_dims);
1502       partial_sharding = true;
1503     }
1504   }
1505   for (int64_t i = 0; i < grouped_sharding.group_dims.size(); ++i) {
1506     int64_t dim = grouped_sharding.group_dims[i];
1507     tiling_dims[dim] *= grouped_sharding.group_dim_sizes[i];
1508   }
1509   Array<int64> tiling(tiling_dims);
1510   grouped_tiling.Each([&](absl::Span<const int64> indices, int64_t device) {
1511     std::vector<int64> ungrouped_inds(indices.begin(), indices.end());
1512     for (int64_t g = 0; g < grouped_sharding.device_groups.size(); ++g) {
1513       int64_t remaining_group_index = g;
1514       for (int64_t i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) {
1515         int64_t dim = grouped_sharding.group_dims[i];
1516         int64_t groups_in_this_dim = grouped_sharding.group_dim_sizes[i];
1517         ungrouped_inds[dim] = (remaining_group_index % groups_in_this_dim) *
1518                                   grouped_tiling.dim(dim) +
1519                               indices[dim];
1520         remaining_group_index /= groups_in_this_dim;
1521       }
1522       tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device];
1523     }
1524   });
1525   return partial_sharding ? HloSharding::PartialTile(tiling)
1526                           : HloSharding::Tile(tiling);
1527 }
1528 
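// Aligns `grouped_sharding` with `reference`: if the two groupings place
// devices into compatible groups, the per-group tile assignment is permuted
// accordingly; the result always adopts `reference`'s device groups.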
1529 GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding,
1530                                 const GroupedSharding& reference,
1531                                 bool ignore_group_order) {
1532   // Returns src -> dst index mapping.
1533   auto get_permutation = [](absl::Span<const int64> src,
1534                             absl::Span<const int64> dst) {
1535     CHECK_EQ(src.size(), dst.size());
1536     absl::flat_hash_map<int64, int64> dst_reverse_map;
1537     for (int64_t i = 0; i < dst.size(); ++i) {
1538       dst_reverse_map[dst[i]] = i;
1539     }
1540     std::vector<int64> permutation(src.size());
1541     for (int64_t i = 0; i < src.size(); ++i) {
1542       auto it = dst_reverse_map.find(src[i]);
1543       CHECK(it != dst_reverse_map.end());
1544       permutation[i] = it->second;
1545     }
1546     return permutation;
1547   };
1548   CHECK_EQ(grouped_sharding.device_groups.size(),
1549            reference.device_groups.size());
1550   absl::flat_hash_map<int64, int64> device_to_ref_group;
1551   for (int64_t g = 0; g < reference.device_groups.size(); ++g) {
1552     for (int64_t device : reference.device_groups[g]) {
1553       device_to_ref_group[device] = g;
1554     }
1555   }
1556   auto unique_ref_dev_group = [&](absl::Span<const int64> devices) -> int64 {
1557     int64_t ref_g = -1;
1558     for (int64_t device : devices) {
1559       if (ref_g == -1) {
1560         ref_g = device_to_ref_group[device];
1561       } else if (ref_g != device_to_ref_group[device]) {
1562         return -1;
1563       }
1564     }
1565     return ref_g;
1566   };
1567   bool matching_groups = true;
1568   std::vector<int64> original_src_to_ref_permutation;
1569   for (int64_t g = 0; g < grouped_sharding.device_groups.size(); ++g) {
1570     int64_t ref_g = unique_ref_dev_group(grouped_sharding.device_groups[g]);
1571     if (ref_g < 0 || (!ignore_group_order && g != ref_g)) {
1572       matching_groups = false;
1573       break;
1574     }
1575     if (g == 0) {
1576       original_src_to_ref_permutation = get_permutation(
1577           grouped_sharding.device_groups[g], reference.device_groups[ref_g]);
1578     }
1579   }
1580   if (matching_groups && !grouped_sharding.sharding.IsTileMaximal()) {
1581     auto tiles = grouped_sharding.sharding.tile_assignment();
1582     tiles.Each([&](absl::Span<const int64> indices, int64* device) {
1583       *device = original_src_to_ref_permutation[*device];
1584     });
1585     grouped_sharding.sharding =
1586         grouped_sharding.sharding.ReplicateOnLastTileDim()
1587             ? HloSharding::PartialTile(tiles)
1588             : HloSharding::Tile(tiles);
1589   }
1590   grouped_sharding.device_groups = std::move(reference.device_groups);
1591   return grouped_sharding;
1592 }
1593 
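// Returns `sharding` with its device assignment realigned so that the groups
// it forms on `sharding_dims` match the groups `reference` forms on
// `reference_dims`.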
1594 HloSharding AlignShardingOnDims(const HloSharding& sharding,
1595                                 absl::Span<const int64> sharding_dims,
1596                                 const HloSharding& reference,
1597                                 absl::Span<const int64> reference_dims) {
1598   auto sharding_grouped = GroupShardingOnDims(sharding, sharding_dims);
1599   auto reference_grouped = GroupShardingOnDims(reference, reference_dims);
1600   return UngroupSharding(AlignGroupsWith(sharding_grouped, reference_grouped));
1601 }
1602 
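// Returns the per-group base shape: each dimension listed in group_dims (and
// within the shape's rank) is divided by the number of groups along it,
// rounding up.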
1603 Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
1604                            const Shape& original_base_shape) {
1605   auto result = original_base_shape;
1606   for (int64_t i = 0; i < grouped_sharding.group_dims.size(); ++i) {
1607     int64_t dim = grouped_sharding.group_dims[i];
1608     if (dim >= original_base_shape.rank()) {
1609       continue;
1610     }
1611     int64_t groups = grouped_sharding.group_dim_sizes[i];
1612     result.set_dimensions(dim, CeilOfRatio(result.dimensions(dim), groups));
1613   }
1614   return result;
1615 }
1616 
1617 namespace {
1618 
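// Computes the partition id relative to the device's group by looking up the
// global partition id in a constant table.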
1619 HloInstruction* GetInGroupPartitionId(
1620     HloInstruction* partition_id,
1621     const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
1622   int64_t total_devices = device_groups.size() * device_groups[0].size();
1623   std::vector<uint32> in_group_ids(total_devices);
1624   for (uint32 i = 0; i < device_groups.size(); ++i) {
1625     for (uint32 j = 0; j < device_groups[i].size(); ++j) {
1626       in_group_ids[device_groups[i][j]] = j;
1627     }
1628   }
1629   auto id_table = b->AddInstruction(HloInstruction::CreateConstant(
1630       LiteralUtil::CreateR1<uint32>(in_group_ids)));
1631   return b->AddInstruction(HloInstruction::CreateReshape(
1632       ShapeUtil::MakeScalarShape(U32),
1633       b->AddInstruction(HloInstruction::CreateDynamicSlice(
1634           ShapeUtil::MakeShape(U32, {1}), id_table, {partition_id}, {1}))));
1635 }
1636 
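// Wraps a collective-ops creator so that all collectives operate within each
// device group: partition ids become in-group ids and partition subgroups are
// expanded to per-group subgroups.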
1637 SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator(
1638     const SPMDCollectiveOpsCreator& creator,
1639     const std::vector<std::vector<int64>>& device_groups) {
1640   SPMDCollectiveOpsCreator result;
1641   result.create_partition_id = [creator, device_groups](SpmdBuilder* b) {
1642     return GetInGroupPartitionId(creator.create_partition_id(b), device_groups,
1643                                  b);
1644   };
1645   auto expand_partition_groups =
1646       [device_groups](
1647           const std::vector<std::vector<int64>>& partition_subgroups) {
1648         if (partition_subgroups.empty()) {
1649           return device_groups;
1650         }
1651         std::vector<std::vector<int64>> result(partition_subgroups.size() *
1652                                                device_groups.size());
1653         for (int64_t g = 0; g < device_groups.size(); ++g) {
1654           for (int64_t i = 0; i < partition_subgroups.size(); ++i) {
1655             result[g * partition_subgroups.size() + i].resize(
1656                 partition_subgroups[i].size());
1657             for (int64_t j = 0; j < partition_subgroups[i].size(); ++j) {
1658               result[g * partition_subgroups.size() + i][j] =
1659                   device_groups[g][partition_subgroups[i][j]];
1660             }
1661           }
1662         }
1663         return result;
1664       };
1665   result.create_cross_partition_all_reduce =
1666       [creator, expand_partition_groups](
1667           SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction,
1668           const std::vector<std::vector<int64>>& partition_subgroups,
1669           int64_t channel_id) {
1670         return creator.create_cross_partition_all_reduce(
1671             b, operand, reduction, expand_partition_groups(partition_subgroups),
1672             channel_id);
1673       };
1674   result.create_cross_partition_collective_permute =
1675       [creator, device_groups](
1676           SpmdBuilder* b, HloInstruction* operand,
1677           std::vector<std::pair<int64, int64>>& src_dst_pairs,
1678           int64_t next_channel_id) {
1679         std::vector<std::pair<int64, int64>> expanded_pairs(
1680             src_dst_pairs.size() * device_groups.size());
1681         for (int64_t g = 0; g < device_groups.size(); ++g) {
1682           for (int64_t i = 0; i < src_dst_pairs.size(); ++i) {
1683             expanded_pairs[g * src_dst_pairs.size() + i] =
1684                 std::pair<int64, int64>{
1685                     device_groups[g][src_dst_pairs[i].first],
1686                     device_groups[g][src_dst_pairs[i].second]};
1687           }
1688         }
1689         return creator.create_cross_partition_collective_permute(
1690             b, operand, expanded_pairs, next_channel_id);
1691       };
1692   result.create_cross_partition_all_to_all =
1693       [creator, expand_partition_groups](
1694           SpmdBuilder* b, absl::Span<HloInstruction* const> operands,
1695           const std::vector<std::vector<int64>>& partition_subgroups,
1696           int64_t channel_id, absl::optional<int64> split_dimension) {
1697         return creator.create_cross_partition_all_to_all(
1698             b, operands, expand_partition_groups(partition_subgroups),
1699             channel_id, split_dimension);
1700       };
1701   if (creator.create_cross_partition_all_gather) {
1702     result.create_cross_partition_all_gather =
1703         [creator, expand_partition_groups](
1704             SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape,
1705             const std::vector<std::vector<int64>>& partition_subgroups,
1706             int64_t channel_id, int64_t all_gather_dimension) {
1707           return creator.create_cross_partition_all_gather(
1708               b, operand, ag_shape,
1709               expand_partition_groups(partition_subgroups), channel_id,
1710               all_gather_dimension);
1711         };
1712   }
1713   return result;
1714 }
1715 
1716 }  // namespace
1717 
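// Builds a partitioning state scoped to the given device groups, with a
// group-local partition id, group-scoped collectives, and a per-group reshard
// cache.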
1718 PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
1719     const PartitionedHlo::PartitioningState& state,
1720     const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b) {
1721   auto result = state;
1722   result.collective_ops_creator = GetPerGroupCollectiveOpsCreator(
1723       state.collective_ops_creator, device_groups);
1724   result.partition_id =
1725       GetInGroupPartitionId(state.partition_id, device_groups, b);
1726   // Create a string key for the groups.
1727   std::vector<std::string> per_group_strings(device_groups.size());
1728   for (int64_t i = 0; i < per_group_strings.size(); ++i) {
1729     per_group_strings[i] = absl::StrJoin(device_groups[i], ",");
1730   }
1731   auto& grouped_cache =
1732       state.reshard_cache->groupd_caches[absl::StrJoin(per_group_strings, ";")];
1733   if (!grouped_cache) {
1734     grouped_cache = absl::make_unique<PartitionedHlo::ReshardCache>();
1735   }
1736   result.reshard_cache = grouped_cache.get();
1737   return result;
1738 }
1739 
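// Given a value replicated across all devices, returns a dynamic-slice that
// extracts the shard belonging to the caller's device group along
// `group_dims`.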
1740 HloInstruction* PerGroupSliceFromReplicated(
1741     HloInstruction* replicated, HloInstruction* partition_id,
1742     const std::vector<std::vector<int64>>& device_groups,
1743     absl::Span<const int64> group_dims, absl::Span<const int64> group_dim_sizes,
1744     SpmdBuilder* b) {
1745   std::vector<uint32> group_ids(device_groups.size() * device_groups[0].size());
1746   for (int64_t g = 0; g < device_groups.size(); ++g) {
1747     for (int64_t device : device_groups[g]) {
1748       group_ids[device] = g;
1749     }
1750   }
1751   auto group_id_table = b->AddInstruction(
1752       HloInstruction::CreateConstant(LiteralUtil::CreateR1<uint32>(group_ids)));
1753   auto group_id = b->AddInstruction(HloInstruction::CreateReshape(
1754       ShapeUtil::MakeScalarShape(U32),
1755       b->AddInstruction(HloInstruction::CreateDynamicSlice(
1756           ShapeUtil::MakeShape(U32, {1}), group_id_table, {partition_id},
1757           {1}))));
1758   std::vector<int64> group_level_tile_dims(replicated->shape().rank(), 1);
1759   for (int64_t i = 0; i < group_dims.size(); ++i) {
1760     group_level_tile_dims[group_dims[i]] = group_dim_sizes[i];
1761   }
1762   Array<int64> group_level_tile(group_level_tile_dims);
1763   group_level_tile.Each([&](absl::Span<const int64> indices, int64* group) {
1764     *group = 0;
1765     for (int64_t dim : group_dims) {
1766       *group *= group_level_tile.dim(dim);
1767       *group += indices[dim];
1768     }
1769   });
1770   auto group_level_sharding = HloSharding::Tile(group_level_tile);
1771   auto padded_hlo = PadBaseShapeBeforeUnevenTiledSharding(
1772       replicated, group_level_sharding, b);
1773   auto shard_shape =
1774       MakePartitionedShape(replicated->shape(), group_level_sharding);
1775   return b->AddInstruction(HloInstruction::CreateDynamicSlice(
1776       shard_shape, padded_hlo,
1777       MakePartitionOffsets(replicated->shape(), group_level_sharding, group_id,
1778                            b),
1779       shard_shape.dimensions()));
1780 }
1781 
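// If the reduction computation just applies a single elementwise binary op to
// its two parameters, returns that op's opcode; otherwise returns nullopt.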
1782 absl::optional<HloOpcode> ParseReductionComputation(
1783     const HloComputation* reduction_comp) {
1784   if (reduction_comp->num_parameters() != 2) {
1785     return absl::nullopt;
1786   }
1787   auto root = reduction_comp->root_instruction();
1788   if (!root->IsElementwiseBinary()) {
1789     return absl::nullopt;
1790   }
1791   if (!absl::c_linear_search(root->operands(),
1792                              reduction_comp->parameter_instruction(0)) ||
1793       !absl::c_linear_search(root->operands(),
1794                              reduction_comp->parameter_instruction(1))) {
1795     return absl::nullopt;
1796   }
1797   return root->opcode();
1798 }
1799 
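// Finds the tile dimensions of `sharding` along which devices in the same
// group share the same index, such that those dimensions induce exactly
// `device_groups`; returns nullopt if no such set of dimensions exists.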
1800 absl::optional<std::vector<int64>> FindMatchingPartitionedDimsForGrouping(
1801     const HloSharding& sharding,
1802     const std::vector<std::vector<int64>>& device_groups) {
1803   if (sharding.NumTiles() < device_groups.size() || device_groups.size() < 2 ||
1804       device_groups[0].size() < 2) {
1805     return absl::nullopt;
1806   }
1807   int64_t rank = sharding.tile_assignment().num_dimensions();
1808   if (sharding.ReplicateOnLastTileDim()) {
1809     rank--;
1810   }
1811   absl::flat_hash_map<int64, std::vector<int64>> device_to_index;
1812   sharding.tile_assignment().Each(
1813       [&](absl::Span<const int64> index, int64_t device) {
1814         device_to_index[device] =
1815             std::vector<int64>(index.begin(), index.begin() + rank);
1816       });
1817   std::vector<int64> dims;
1818   int64_t group_count = 1;
1819   for (int64_t i = 0; i < rank; ++i) {
1820     if (device_to_index[device_groups[0][0]][i] ==
1821         device_to_index[device_groups[0][1]][i]) {
1822       dims.push_back(i);
1823       group_count *= sharding.tile_assignment().dim(i);
1824     }
1825   }
1826   if (group_count != device_groups.size()) {
1827     return absl::nullopt;
1828   }
1829   for (const auto& group : device_groups) {
1830     for (int64_t i = 1; i < group.size(); ++i) {
1831       if (absl::c_any_of(dims, [&](const int64_t dim) {
1832             return device_to_index[group[i]][dim] !=
1833                    device_to_index[group[0]][dim];
1834           })) {
1835         return absl::nullopt;
1836       }
1837     }
1838   }
1839   return dims;
1840 }
1841 
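// Builds a sharding for `target_shape` whose `target_dims` are tiled exactly
// like `source_sharding`'s `source_dims`, adding partial replication for any
// devices not consumed by those dimensions, and aligns the device order with
// the source.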
1842 HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
1843                                          const HloSharding& source_sharding,
1844                                          absl::Span<const int64> target_dims,
1845                                          absl::Span<const int64> source_dims) {
1846   CHECK(target_dims.size() == source_dims.size())
1847       << "Expected 1:1 match between parallel dimensions";
1848   if (source_sharding.IsReplicated()) {
1849     return HloSharding::Replicate();
1850   }
1851   absl::InlinedVector<int64, 4> tile_dims(target_shape.dimensions_size(), 1);
1852   int num_tiles = 1;
1853   for (int i = 0, end = target_dims.size(); i < end; ++i) {
1854     num_tiles *= source_sharding.tile_assignment().dim(source_dims[i]);
1855     tile_dims[target_dims[i]] =
1856         source_sharding.tile_assignment().dim(source_dims[i]);
1857   }
1858   // If there is some partitioning across non-parallel dimensions in the
1859   // other operand, partially replicate the new sharding over those devices.
1860   bool to_be_partially_replicated = false;
1861   if (num_tiles != source_sharding.tile_assignment().num_elements()) {
1862     CHECK_EQ(source_sharding.tile_assignment().num_elements() % num_tiles, 0);
1863     to_be_partially_replicated = true;
1864     tile_dims.push_back(source_sharding.tile_assignment().num_elements() /
1865                         num_tiles);
1866   }
1867   auto tgt_tile_assignment = source_sharding.tile_assignment();
1868   tgt_tile_assignment.Reshape(tile_dims);
1869   if (to_be_partially_replicated) {
1870     return AlignShardingOnDims(HloSharding::PartialTile(tgt_tile_assignment),
1871                                target_dims, source_sharding, source_dims);
1872   } else {
1873     return AlignShardingOnDims(HloSharding::Tile(tgt_tile_assignment),
1874                                target_dims, source_sharding, source_dims);
1875   }
1876 }
1877 
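// Computes shardings for the gather indices and operand so that both are
// partitioned the same way along their parallel dimensions, borrowing from
// partial replication when the tile counts differ; returns nullopt when no
// parallel dimension is partitioned or the shardings cannot be aligned.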
1878 absl::optional<GatherParallelDimSharding>
1879 GatherOperandsShardedAcrossParallelDims(
1880     const HloInstruction& operand, const HloInstruction& indices,
1881     const hlo_sharding_util::GatherParallelDims& parallel_dims) {
1882   auto& indices_parallel_dims = parallel_dims.indices_parallel_dims;
1883   auto& operand_parallel_dims = parallel_dims.operand_parallel_dims;
1884   if (indices_parallel_dims.size() != operand_parallel_dims.size()) {
1885     return absl::nullopt;
1886   }
1887   auto new_index_shard = indices.sharding();
1888   auto new_operand_shard = operand.sharding();
1889   int idx_parallel_tiles_num = new_index_shard.NumTiles(indices_parallel_dims);
1890   int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims);
1891   if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) {
1892     return absl::nullopt;
1893   }
1894   absl::InlinedVector<int64, 1> indices_parallel_dims_ordered_as_operand;
1895   for (int idx : parallel_dims.index_parallel_in_dim) {
1896     if (idx != -1) {
1897       indices_parallel_dims_ordered_as_operand.push_back(idx);
1898     }
1899   }
1900   if (new_index_shard.IsReplicated()) {
1901     return GatherParallelDimSharding{
1902         CreateMatchingShardingOnDims(indices.shape(), new_operand_shard,
1903                                      indices_parallel_dims_ordered_as_operand,
1904                                      operand_parallel_dims),
1905         new_operand_shard};
1906   }
1907   if (new_operand_shard.IsReplicated()) {
1908     return GatherParallelDimSharding{
1909         new_index_shard,
1910         CreateMatchingShardingOnDims(operand.shape(), new_index_shard,
1911                                      operand_parallel_dims,
1912                                      indices_parallel_dims_ordered_as_operand)};
1913   }
1914 
1915   // Parallel dimension distribution needs to be the same, so try to steal
1916   // sharding from partial replication to compensate.
1917   if (idx_parallel_tiles_num != op_parallel_tiles_num) {
1918     auto to_adjust_dims = operand_parallel_dims;
1919     auto target_dims = indices_parallel_dims_ordered_as_operand;
1920     HloSharding* target = &new_index_shard;
1921     HloSharding* to_adjust = &new_operand_shard;
1922     if (idx_parallel_tiles_num < op_parallel_tiles_num) {
1923       std::swap(to_adjust_dims, target_dims);
1924       std::swap(to_adjust, target);
1925     }
1926     if (!to_adjust->ReplicateOnLastTileDim()) {
1927       return absl::nullopt;
1928     }
1929     auto new_tile_assignment_dims = to_adjust->tile_assignment().dimensions();
1930     for (int i = 0; i < to_adjust_dims.size(); ++i) {
1931       int64_t target_dim = target->tile_assignment().dim(target_dims[i]);
1932       int64_t to_adjust_dim =
1933           to_adjust->tile_assignment().dim(to_adjust_dims[i]);
1934       if (target_dim < to_adjust_dim) {
1935         return absl::nullopt;
1936       }
1937       if (target_dim == to_adjust_dim) {
1938         continue;
1939       }
1940       int64_t ratio = target_dim / to_adjust_dim;
1941       if (target_dim % to_adjust_dim != 0 ||
1942           new_tile_assignment_dims.back() % ratio != 0) {
1943         return absl::nullopt;
1944       }
1945       new_tile_assignment_dims[to_adjust_dims[i]] *= ratio;
1946       new_tile_assignment_dims.back() /= ratio;
1947     }
1948     CHECK_GE(new_tile_assignment_dims.back(), 1);
1949     bool to_partially_replicate = true;
1950     if (new_tile_assignment_dims.back() == 1) {
1951       new_tile_assignment_dims.pop_back();
1952       to_partially_replicate = false;
1953     }
1954     auto new_tile_assignment = to_adjust->tile_assignment();
1955     new_tile_assignment.Reshape(new_tile_assignment_dims);
1956     if (to_partially_replicate) {
1957       *to_adjust =
1958           AlignShardingOnDims(HloSharding::PartialTile(new_tile_assignment),
1959                               to_adjust_dims, *target, target_dims);
1960     } else {
1961       *to_adjust = AlignShardingOnDims(HloSharding::Tile(new_tile_assignment),
1962                                        to_adjust_dims, *target, target_dims);
1963     }
1964   }
1965   // Make sure that the parallel dimensions are aligned.
1966   auto operand_shard_tile_dims =
1967       new_operand_shard.tile_assignment().dimensions();
1968   for (int i = 0; i < indices_parallel_dims_ordered_as_operand.size(); ++i) {
1969     operand_shard_tile_dims[operand_parallel_dims[i]] =
1970         new_index_shard.tile_assignment().dim(
1971             indices_parallel_dims_ordered_as_operand[i]);
1972   }
1973   auto operand_shard_tiles = new_operand_shard.tile_assignment();
1974   operand_shard_tiles.Reshape(operand_shard_tile_dims);
1975   new_operand_shard =
1976       AlignShardingOnDims(new_operand_shard.ReplicateOnLastTileDim()
1977                               ? HloSharding::PartialTile(operand_shard_tiles)
1978                               : HloSharding::Tile(operand_shard_tiles),
1979                           operand_parallel_dims, new_index_shard,
1980                           indices_parallel_dims_ordered_as_operand);
1981   return GatherParallelDimSharding{new_index_shard, new_operand_shard};
1982 }
1983 
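// Matches a concat(slice(x), slice(x)) pattern that implements a rotate-right
// of x along the concatenate dimension and returns the rotation amount (the
// size of the left slice), or -1 if the pattern does not match.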
1984 int64 FindRotateRightPattern(const HloInstruction* concat,
1985                              const HloInstruction* lhs,
1986                              const HloInstruction* rhs) {
1987   if (lhs->opcode() != HloOpcode::kSlice ||
1988       rhs->opcode() != HloOpcode::kSlice ||
1989       lhs->operand(0) != rhs->operand(0)) {
1990     return -1;
1991   }
1992   const HloInstruction* to_rotate = lhs->operand(0);
1993   if (!ShapeUtil::Compatible(to_rotate->shape(), concat->shape()) ||
1994       concat->sharding() != to_rotate->sharding()) {
1995     return -1;
1996   }
1997   const int64_t dim = concat->concatenate_dimension();
1998   if (lhs->slice_strides(dim) != 1 || rhs->slice_strides(dim) != 1 ||
1999       lhs->slice_starts(dim) != rhs->slice_limits(dim)) {
2000     return -1;
2001   }
2002   return lhs->shape().dimensions(dim);
2003 }
2004 
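// Matches a concat(slice(mid), mid, slice(mid)) pattern (possibly with
// elementwise unary ops applied to the slices) that implements padding with
// wrapped values, and returns the slice offsets and the skipped modifiers.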
2005 absl::optional<PadWithWrapPattern> FindPadWithWrapPattern(
2006     const HloInstruction* concat, const HloInstruction* lhs,
2007     const HloInstruction* mid, const HloInstruction* rhs) {
2008   if (!lhs || !mid || !rhs) {
2009     return absl::nullopt;
2010   }
2011 
2012   // Skip elementwise unary operations applied to inst, returning
2013   // a list of applied operations that were skipped.
2014   auto skip_elementwise_ops = [&](const HloInstruction* inst) {
2015     std::vector<const HloInstruction*> modifiers;
2016     while (inst->IsElementwise() && inst->operand_count() == 1 &&
2017            inst->user_count() == 1) {
2018       if (inst->opcode() != HloOpcode::kCopy) {
2019         modifiers.push_back(inst);
2020       }
2021       inst = inst->operand(0);
2022     }
2023     return std::make_pair(modifiers, inst);
2024   };
2025 
2026   PadWithWrapPattern pad_pattern;
2027   auto skip_result = skip_elementwise_ops(lhs);
2028   pad_pattern.lhs_modifiers = std::move(skip_result.first);
2029   lhs = skip_result.second;
2030 
2031   skip_result = skip_elementwise_ops(rhs);
2032   pad_pattern.rhs_modifiers = std::move(skip_result.first);
2033   rhs = skip_result.second;
2034 
2035   const int64_t dim = concat->concatenate_dimension();
2036   if (lhs->opcode() != HloOpcode::kSlice ||
2037       rhs->opcode() != HloOpcode::kSlice || lhs->operand(0) != mid ||
2038       rhs->operand(0) != mid || lhs->slice_strides(dim) != 1 ||
2039       rhs->slice_strides(dim) != 1 || lhs->sharding() != mid->sharding() ||
2040       rhs->sharding() != mid->sharding() ||
2041       lhs->sharding() != concat->sharding()) {
2042     return absl::nullopt;
2043   }
2044   pad_pattern.lhs_slice_start = lhs->slice_starts(dim);
2045   pad_pattern.rhs_slice_start = rhs->slice_starts(dim);
2046   return pad_pattern;
2047 }
2048 
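// Groups `sharding` on its MANUAL subgroup dimensions, producing a per-group
// sharding over the data dimensions; the result is marked as partially tiled
// when replication remains in the subgroup.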
2049 GroupedSharding GetManualSubgroupSharding(const HloSharding& sharding) {
2050   CHECK(sharding.IsManualSubgroup());
2051   int64_t tile_dimensions = sharding.tile_assignment().num_dimensions();
2052   int64_t subgroup_size = sharding.sharding_types().size();
2053   int64_t rank = tile_dimensions - subgroup_size;
2054   std::vector<int64> group_dims;
2055   bool last_tile_dim_replicate = false;
2056 
2057   for (int64_t i = 0; i < subgroup_size; i++) {
2058     if (sharding.sharding_types()[i] == OpSharding::MANUAL) {
2059       group_dims.push_back(rank + i);
2060     } else if (sharding.sharding_types()[i] == OpSharding::REPLICATED) {
2061       last_tile_dim_replicate = true;
2062     }
2063   }
2064 
2065   GroupedSharding group_sharding = GroupShardingOnDims(sharding, group_dims);
2066 
2067   if (last_tile_dim_replicate ||
2068       group_sharding.sharding.tile_assignment().num_dimensions() > rank) {
2069     group_sharding.sharding = HloSharding::PartialTile(
2070         group_sharding.sharding.tile_assignment(), sharding.metadata());
2071   }
2072   return group_sharding;
2073 }
2074 
2075 }  // namespace spmd
2076 }  // namespace xla
2077