/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"

#include <algorithm>
#include <map>
#include <numeric>
#include <set>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace hlo_sharding_util {

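// Returns true if `lhs` is a more specific sharding than `rhs`. For tuples,
// no element may be worse and at least one must be better. For non-tuple
// shardings: more tiles beats fewer tiles, any tiled sharding beats a
// non-replicated tile-maximal one, and any non-replicated sharding beats a
// replicated one.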
bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
  CHECK_EQ(lhs.IsTuple(), rhs.IsTuple()) << lhs << " <> " << rhs;
  if (lhs.IsTuple()) {
    // For tuples we consider lhs to have a better sharding if none of the
    // elements are worse and at least one element is better than in the rhs
    // sharding.
    const auto& lhs_shardings = lhs.tuple_elements();
    const auto& rhs_shardings = rhs.tuple_elements();
    CHECK_EQ(lhs_shardings.size(), rhs_shardings.size());
    bool is_better = false;
    for (int64_t i = 0; i < lhs_shardings.size(); ++i) {
      if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) {
        return false;
      }
      if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) {
        is_better = true;
      }
    }
    return is_better;
  }
  if (!rhs.IsTileMaximal()) {
    return lhs.NumTiles() > rhs.NumTiles();
  } else if (!rhs.IsReplicated()) {
    // If we are not replicated then only tiled (not tile maximal) shardings
    // can improve us.
    return !lhs.IsTileMaximal();
  } else {
    // If we are replicated then any non-replicated sharding can improve us.
    return !lhs.IsReplicated();
  }
}

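// Decides whether `*to_merge` should replace `old`. When both shardings are
// partially replicated over the same number of devices and combining is
// allowed, first tries to combine them in place into a sharding with strictly
// more tiles than either; otherwise keeps `*to_merge` only if it is more
// specific than `old`.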
bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
                   bool may_combine_partial_sharding) {
  if (old.IsTuple()) {
    CHECK(to_merge->IsTuple());
    bool changed = false;
    for (int64_t i = 0; i < old.tuple_elements().size(); ++i) {
      changed |=
          MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
                        may_combine_partial_sharding);
    }
    return changed;
  }
  if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() ||
      !to_merge->ReplicateOnLastTileDim() ||
      old.tile_assignment().num_elements() !=
          to_merge->tile_assignment().num_elements()) {
    return IsShardingMoreSpecific(*to_merge, old);
  }

  if (MergeShardingIfCompatible(
          old,
          /*minimum_tiles=*/std::max(old.NumTiles(), to_merge->NumTiles()) + 1,
          to_merge)) {
    return true;
  }
  return IsShardingMoreSpecific(*to_merge, old);
}

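// Tries to combine the tile dimensions of `to_merge` and `*dst` (both may be
// partially replicated) into a single sharding written back to `*dst`. The
// merge succeeds only if the per-dimension tile counts are compatible (equal,
// or one of them is 1), the merged tile count reaches `minimum_tiles`, and
// the replication groups of the two shardings intersect consistently so that
// a concrete device can be assigned to every merged tile.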
bool MergeShardingIfCompatible(const HloSharding& to_merge,
                               int64_t minimum_tiles, HloSharding* dst) {
  if (to_merge.IsTileMaximal()) {
    return false;
  }
  if (dst->IsTileMaximal()) {
    *dst = to_merge;
    return true;
  }
  // Combine the tile dimension sizes from dst and to_merge.
  int64_t num_devices = to_merge.tile_assignment().num_elements();
  std::vector<int64> merged_tile_dims;
  bool compatible = true;
  merged_tile_dims.reserve(dst->tile_assignment().num_dimensions());
  for (int64_t i = 0; i < dst->tile_assignment().num_dimensions() - 1; ++i) {
    int64_t dst_dim = dst->tile_assignment().dim(i);
    int64_t merge_dim = to_merge.tile_assignment().dim(i);
    if (dst_dim == 1) {
      merged_tile_dims.push_back(merge_dim);
    } else if (merge_dim == 1) {
      merged_tile_dims.push_back(dst_dim);
    } else if (dst_dim == merge_dim) {
      merged_tile_dims.push_back(dst_dim);
    } else {
      compatible = false;
      break;
    }
  }
  int64_t merged_tiles = Product(merged_tile_dims);
  if (!compatible || num_devices % Product(merged_tile_dims) != 0 ||
      merged_tiles < minimum_tiles) {
    return false;
  }
  int64_t replication = num_devices / merged_tiles;
  merged_tile_dims.push_back(replication);
  Array<int64> merged_tile(merged_tile_dims);
  // Maps from replication group ID to sorted members.
  absl::flat_hash_map<int64, std::set<int64>> merge_group_members;
  absl::flat_hash_map<int64, std::set<int64>> dst_group_members;
  auto get_group_index = [&](absl::Span<const int64> tile_indices,
                             const HloSharding& sharding) {
    int64_t group_id = 0;
    for (int64_t i = 0; i < tile_indices.size() - 1; ++i) {
      group_id *= dst->tile_assignment().dim(i);
      group_id += tile_indices[i];
    }
    return group_id;
  };
  to_merge.tile_assignment().Each(
      [&](absl::Span<const int64> indices, int64_t device) {
        merge_group_members[get_group_index(indices, to_merge)].insert(device);
      });
  dst->tile_assignment().Each(
      [&](absl::Span<const int64> indices, int64_t device) {
        dst_group_members[get_group_index(indices, *dst)].insert(device);
      });
  // Try to find the intersection of to_merge and dst replication groups, in
  // order to determine the merged tile assignment.
  merged_tile.Each([&](absl::Span<const int64> indices, int64* device) {
    if (!compatible) {
      return;
    }
    std::vector<int64> to_merge_index(indices.begin(), indices.end());
    std::vector<int64> dst_index = to_merge_index;
    for (int64_t i = 0; i < indices.size() - 1; ++i) {
      if (to_merge.tile_assignment().dim(i) == 1) {
        to_merge_index[i] = 0;
      }
      if (dst->tile_assignment().dim(i) == 1) {
        dst_index[i] = 0;
      }
    }
    int64_t to_merge_group_id = get_group_index(to_merge_index, to_merge);
    int64_t dst_group_id = get_group_index(dst_index, *dst);
    if (merge_group_members[to_merge_group_id].empty() ||
        dst_group_members[dst_group_id].empty()) {
      compatible = false;
      return;
    }

    int64_t smallest_to_merge = *merge_group_members[to_merge_group_id].begin();
    int64_t smallest_dst = *dst_group_members[dst_group_id].begin();
    if (smallest_to_merge < smallest_dst) {
      if (merge_group_members[to_merge_group_id].count(smallest_dst) == 0) {
        compatible = false;
        return;
      }
      *device = smallest_dst;
    } else {
      if (dst_group_members[dst_group_id].count(smallest_to_merge) == 0) {
        compatible = false;
        return;
      }
      *device = smallest_to_merge;
    }
    merge_group_members[to_merge_group_id].erase(*device);
    dst_group_members[dst_group_id].erase(*device);
  });
  if (!compatible) {
    return false;
  }
  std::vector<OpMetadata> merged_metadata(std::move(dst->metadata()));
  merged_metadata.reserve(merged_metadata.size() + to_merge.metadata().size());
  const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper,
                            protobuf_util::ProtobufEqualsWrapper>
      metadata_set(merged_metadata.begin(), merged_metadata.end());
  absl::c_copy_if(to_merge.metadata(), std::back_inserter(merged_metadata),
                  [&metadata_set](const OpMetadata& data) {
                    return !ContainsKey(metadata_set, data);
                  });
  if (replication == 1) {
    merged_tile_dims.pop_back();
    merged_tile.Reshape(merged_tile_dims);
    *dst = HloSharding::Tile(merged_tile, merged_metadata);
  } else {
    *dst = HloSharding::PartialTile(merged_tile, merged_metadata);
  }
  return true;
}

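// Picks the device with the highest occurrence count in `device_map` and
// optionally reports that count through `top_count`. Returns nullopt when the
// map contains no positive counts.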
absl::optional<int64> SelectDominantDevice(
    const std::map<int64, int64>& device_map, int64* top_count) {
  int64_t device = 0;
  int64_t count = 0;
  for (auto& it : device_map) {
    if (it.second > count) {
      count = it.second;
      device = it.first;
    }
  }
  if (top_count != nullptr) {
    *top_count = count;
  }
  return count > 0 ? absl::optional<int64>(device) : absl::optional<int64>();
}

Status AssignComputationDevice(HloComputation* computation, int64_t device) {
  VLOG(4) << "Assigning device " << device << " to " << computation->name()
          << " computation";
  for (HloInstruction* instruction : computation->instructions()) {
    if (!instruction->has_sharding()) {
      VLOG(4) << "Assigning device " << device << " to " << instruction->name();
      instruction->set_device_sharding(device);
    }
  }
  return Status::OK();
}

absl::optional<int64> GetMostOccurringDevice(
    absl::Span<HloInstruction* const> instructions) {
  std::map<int64, int64> device_map;
  for (HloInstruction* instruction : instructions) {
    if (instruction->has_sharding()) {
      for (auto& it : instruction->sharding().UsedDevices(nullptr)) {
        // The UsedDevices() API returns a map<device, occurrence_count>.
        device_map[it.first] += it.second;
      }
    }
  }
  return SelectDominantDevice(device_map, nullptr);
}

StatusOr<absl::optional<int64>> GetDominantDevice(
    absl::Span<HloComputation* const> computations, double dominant_factor) {
  int64_t instruction_count = 0;
  std::map<int64, int64> device_map;
  for (HloComputation* computation : computations) {
    for (HloInstruction* instruction : computation->instructions()) {
      int64_t count = 1;
      if (instruction->has_sharding()) {
        for (auto& it : instruction->sharding().UsedDevices(&count)) {
          // The UsedDevices() API returns a map<device, occurrence_count>.
          device_map[it.first] += it.second;
        }
      }
      instruction_count += count;
    }
  }
  int64_t count;
  absl::optional<int64> device = SelectDominantDevice(device_map, &count);
  absl::optional<int64> dominant_device;
  if (device) {
    double factor =
        static_cast<double>(count) / static_cast<double>(instruction_count);
    if (factor >= dominant_factor) {
      dominant_device = device;
    }
  }
  return dominant_device;
}

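// Permutes the tile assignment of `sharding`: output dimension i takes its
// tile count (and devices) from source dimension dimensions[i]. For example,
// a sharding tiled as [2,1,4] transposed with dimensions {1,2,0} yields a
// sharding tiled as [1,4,2]. A trailing replication dimension, if present, is
// kept in place.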
HloSharding TransposeSharding(const HloSharding& sharding,
                              const std::vector<int64>& dimensions) {
  if (sharding.IsTileMaximal()) {
    return sharding;
  }
  auto perm_dimensions = dimensions;
  if (sharding.ReplicateOnLastTileDim() &&
      dimensions.size() < sharding.tile_assignment().num_dimensions()) {
    perm_dimensions.push_back(dimensions.size());
  }
  const int64_t rank = perm_dimensions.size();
  std::vector<int64> tile_assignment_dim(rank);
  for (int64_t i = 0; i < rank; ++i) {
    tile_assignment_dim[i] = sharding.tile_assignment().dim(perm_dimensions[i]);
  }
  Array<int64> tile_assignment = sharding.tile_assignment();
  tile_assignment.Reshape(tile_assignment_dim);
  tile_assignment.Each([&](absl::Span<const int64> indices, int64* value) {
    std::vector<int64> src_indices(indices.size(), -1);
    for (int64_t i = 0; i < indices.size(); ++i) {
      src_indices[perm_dimensions[i]] = indices[i];
    }
    *value = sharding.tile_assignment()(src_indices);
  });
  return sharding.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(tile_assignment, sharding.metadata())
             : HloSharding::Tile(tile_assignment, sharding.metadata());
}

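// Propagates `sharding` on `source_shape` through a reshape to `target_shape`
// when the reshape only adds or removes trivial dimensions, merges dimensions
// whose minor parts are unsharded, splits sharded dimensions evenly, or
// reshapes unsharded dimensions. Returns nullopt when no equivalent target
// sharding can be constructed.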
absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
                                            const Shape& target_shape,
                                            const HloSharding& sharding) {
  if (sharding.IsTileMaximal()) {
    return sharding;
  }

  // In case of a tiled sharding, the reshaped sharding will be valid if the
  // reshape is composed from the following operations:
  // * Adding or removing dimensions with size 1.
  // * Merging consecutive dimensions where only the most major is sharded.
  // * Splitting a dimension to consecutive dimensions.
  // * Any reshaping of unsharded dimensions.
  // Note that merge and split can happen consecutively on the same dimension,
  // e.g., f32[1024,256,1024] to f32[128,2048,1024] can be viewed as 1024
  // getting split into 128 and 8, with the 8 then merged into 256. We use
  // stacks to make supporting such cases easy.
  const Shape tile_shape = sharding.TileShape(source_shape);
  std::vector<int64> target_tile_assignment_dimensions;
  std::vector<int64> source_dims_stack(source_shape.rank());
  std::vector<int64> target_dims_stack(target_shape.rank());
  std::vector<int64> sharding_tile_dims_stack(source_shape.rank());
  for (int64_t i = 0; i < source_shape.rank(); ++i) {
    source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
    sharding_tile_dims_stack[i] =
        sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
  }
  for (int64_t i = 0; i < target_shape.rank(); ++i) {
    target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
  }
  while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
    if (target_dims_stack.empty()) {
      if (Product(sharding_tile_dims_stack) != 1) {
        return absl::nullopt;
      }
      break;
    }
    int64_t s_size = 1;
    int64_t t_size = 1;
    int64_t s_partitions = 1;
    if (!source_dims_stack.empty()) {
      s_size = source_dims_stack.back();
      source_dims_stack.pop_back();
      s_partitions = sharding_tile_dims_stack.back();
      sharding_tile_dims_stack.pop_back();
    }
    t_size = target_dims_stack.back();
    target_dims_stack.pop_back();
    if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
      // No more partitions left.
      target_tile_assignment_dimensions.push_back(1);
      continue;
    }
    if (s_size == t_size) {
      // Same dimension.
      target_tile_assignment_dimensions.push_back(s_partitions);
    } else if (t_size == 1) {
      // Trivial dimension added.
      target_tile_assignment_dimensions.push_back(1);
      source_dims_stack.push_back(s_size);
      sharding_tile_dims_stack.push_back(s_partitions);
    } else if (s_size == 1) {
      // Trivial dimension removed.
      if (s_partitions != 1) {
        return absl::nullopt;
      }
      target_dims_stack.push_back(t_size);
    } else if (s_size > t_size) {
      // Dimension split.
      if (s_size % t_size != 0 || s_size % s_partitions != 0) {
        return absl::nullopt;
      }
      if (t_size % s_partitions == 0) {
        target_tile_assignment_dimensions.push_back(s_partitions);
        // We have part of the s_size unprocessed, so put it back on the stack.
        source_dims_stack.push_back(s_size / t_size);
        sharding_tile_dims_stack.push_back(1);
      } else if (s_partitions % t_size == 0) {
        target_tile_assignment_dimensions.push_back(t_size);
        // We have part of the s_size unprocessed, so put it back on the stack.
        source_dims_stack.push_back(s_size / t_size);
        sharding_tile_dims_stack.push_back(s_partitions / t_size);
      } else {
        return absl::nullopt;
      }
    } else {
      // Dimension merge. Also merge the source dimension with the next, and
      // process it next time.
      if (s_size % s_partitions != 0) {
        return absl::nullopt;
      }
      CHECK(!source_dims_stack.empty());
      if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
        // If the next dimension to combine is sharded, we require the current
        // dimension's shard size to be 1. Otherwise, the new shard would be
        // non-contiguous.
        return absl::nullopt;
      }
      source_dims_stack.back() *= s_size;
      sharding_tile_dims_stack.back() *= s_partitions;
      target_dims_stack.push_back(t_size);
    }
  }
  Array<int64> new_tile_assignment = sharding.tile_assignment();
  if (sharding.ReplicateOnLastTileDim()) {
    target_tile_assignment_dimensions.push_back(
        sharding.tile_assignment().dimensions().back());
  }
  new_tile_assignment.Reshape(target_tile_assignment_dimensions);
  return sharding.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(new_tile_assignment,
                                        sharding.metadata())
             : HloSharding::Tile(new_tile_assignment, sharding.metadata());
}

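// Mirrors the tile assignment along each dimension in `dimensions`, matching
// the data movement of a reverse op: the tile that covered the first slice of
// a reversed dimension now covers the last slice.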
HloSharding ReverseSharding(const HloSharding& sharding,
                            absl::Span<const int64> dimensions) {
  if (sharding.IsTileMaximal() || dimensions.empty()) {
    return sharding;
  }

  Array<int64> new_tile_assignment(sharding.tile_assignment().dimensions());
  new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
    std::vector<int64> original_indices(indices.begin(), indices.end());
    for (int64_t d : dimensions) {
      original_indices[d] =
          new_tile_assignment.dim(d) - 1 - original_indices[d];
    }
    *device = sharding.tile_assignment()(original_indices);
  });
  return sharding.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(new_tile_assignment,
                                        sharding.metadata())
             : HloSharding::Tile(new_tile_assignment, sharding.metadata());
}

HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64_t dim,
                                   absl::Span<const int64> dims) {
  CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
  CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims";
  // We optimize the tile assignment on the single dimension dim in a way to
  // minimize communication among devices caused by the reshard:
  // +---+---+               +---+---+              +-+-+-+-+
  // |   |   |               |   0   |              | | | | |
  // | 0 | 1 |               +-------+              | | | | |
  // |   |   |  reshape on   |   1   |  reshape on  | | | | |
  // +---+---+   dim 0  =>   +-------+   dim 1  =>  |0|2|1|3|
  // |   |   |               |   2   |              | | | | |
  // | 2 | 3 |               +-------+              | | | | |
  // |   |   |               |   3   |              | | | | |
  // +---+---+               +---+---+              +-+-+-+-+

  std::vector<int64> tile_dims(sharding.tile_assignment().num_dimensions(), 1);
  // Handle ignore dimensions.
  std::vector<int64> ignore_sizes;
  int64_t ignore_size = 1;
  for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
    if (absl::c_find(dims, i) == dims.end()) {
      int64_t size = sharding.tile_assignment().dim(i);
      ignore_sizes.push_back(size);
      tile_dims[i] = size;
      ignore_size *= size;
    }
  }

  using Buckets = std::vector<std::vector<int64>>;
  Array<Buckets> buckets(ignore_sizes,
                         Buckets(sharding.tile_assignment().dim(dim)));
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64> index, int64_t device) {
        std::vector<int64> ignore_index;
        for (int64_t i = 0; i < index.size(); ++i) {
          if (absl::c_find(dims, i) == dims.end()) {
            ignore_index.push_back(index[i]);
          }
        }
        buckets(ignore_index)[index[dim]].push_back(device);
      });
  std::vector<int64> devices;
  buckets.Each([&](absl::Span<const int64> index, const Buckets& buckets) {
    for (auto& bucket : buckets) {
      devices.insert(devices.end(), bucket.begin(), bucket.end());
    }
  });
  tile_dims[dim] = devices.size() / ignore_size;
  Array<int64> tile_assignment(tile_dims);
  tile_assignment.SetValues(devices);
  return HloSharding::Tile(tile_assignment, sharding.metadata());
}

bool ContainsTileSharding(const HloModule& module) {
  for (const HloComputation* computation : module.computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      if (instruction->has_sharding() &&
          !instruction->sharding().IsTileMaximal()) {
        return true;
      }
    }
  }
  return false;
}

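// Derives an output sharding for a gather from the sharding of its index
// operand: offset dimensions of the output stay unsharded, while the
// remaining (batch) dimensions inherit the corresponding index tile counts,
// skipping over index_vector_dim. Falls back to replication if the tile
// counts cannot be reshaped consistently.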
HloSharding GatherOutputSharding(const HloSharding& index_sharding,
                                 const HloInstruction* hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }

  const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
  std::vector<int64> output_tile_assignment_dims;
  for (int64_t i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.offset_dims(), i)) {
      output_tile_assignment_dims.push_back(1);
    } else {
      const int64_t new_tile_dimension =
          index_dim >= dnums.index_vector_dim() ? index_dim + 1 : index_dim;
      output_tile_assignment_dims.push_back(
          index_sharding.tile_assignment().dim(new_tile_dimension));
      ++index_dim;
    }
  }

  if (index_sharding.ReplicateOnLastTileDim()) {
    output_tile_assignment_dims.push_back(
        index_sharding.tile_assignment().dimensions().back());
  }

  Array<int64> new_tile_assignment = index_sharding.tile_assignment();
  if (new_tile_assignment.num_elements() !=
      Product(output_tile_assignment_dims)) {
    return HloSharding::Replicate(index_sharding.metadata());
  }
  new_tile_assignment.Reshape(output_tile_assignment_dims);
  return index_sharding.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(new_tile_assignment,
                                        index_sharding.metadata())
             : HloSharding::Tile(new_tile_assignment,
                                 index_sharding.metadata());
}

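// Derives a sharding for the gather indices from the gather output sharding:
// non-offset output dimensions map to index dimensions, index_vector_dim
// stays unsharded, and any leftover partitioning of the output is turned into
// partial replication of the indices (or full replication when nothing can be
// carried over).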
HloSharding GatherIndexSharding(const HloSharding& output_sharding,
                                const HloInstruction* hlo) {
  CHECK(hlo->opcode() == HloOpcode::kGather);
  if (output_sharding.IsTileMaximal()) {
    return output_sharding;
  }

  const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
  std::vector<int64> index_tile_assignment_dims;
  // Relevant output dims have shardings passed to the index.
  std::vector<int64> relevant_output_dims;
  for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.offset_dims(), i)) {
      index_tile_assignment_dims.push_back(
          output_sharding.tile_assignment().dim(i));
      relevant_output_dims.push_back(i);
    }
  }
  int64_t index_rank = hlo->operand(1)->shape().rank();

  // Vector indices sharding is not supported yet.
  if (index_rank > index_tile_assignment_dims.size()) {
    index_tile_assignment_dims.insert(
        index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
  }

  if (Product(index_tile_assignment_dims) == 1) {
    return HloSharding::Replicate(output_sharding.metadata());
  }
  HloSharding relevant_output_sharding =
      PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
                                                     relevant_output_dims);
  int64_t partial_replication_size = 1;
  if (relevant_output_sharding.ReplicateOnLastTileDim()) {
    partial_replication_size *=
        relevant_output_sharding.tile_assignment().dimensions().back();
  }

  Array<int64> new_tile_assignment = relevant_output_sharding.tile_assignment();
  const int64_t index_tile_elements =
      Product(index_tile_assignment_dims) * partial_replication_size;
  if (new_tile_assignment.num_elements() != index_tile_elements) {
    if (new_tile_assignment.num_elements() % index_tile_elements == 0) {
      partial_replication_size *=
          (new_tile_assignment.num_elements() / index_tile_elements);
    } else {
      return HloSharding::Replicate(output_sharding.metadata());
    }
  }
  if (partial_replication_size > 1) {
    index_tile_assignment_dims.push_back(partial_replication_size);
  }
  new_tile_assignment.Reshape(index_tile_assignment_dims);
  return partial_replication_size > 1
             ? HloSharding::PartialTile(new_tile_assignment,
                                        output_sharding.metadata())
             : HloSharding::Tile(new_tile_assignment,
                                 output_sharding.metadata());
}

HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
  if (hlo.sharding().IsTileMaximal()) {
    return hlo.sharding();
  }

  const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers();
  std::vector<int64> tile_assignment_dims(hlo.shape().rank());
  int64_t num_elements = 1;
  for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.offset_dims(), i)) {
      tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i);
      num_elements *= hlo.sharding().tile_assignment().dim(i);
    } else {
      tile_assignment_dims[i] = 1;
    }
  }
  if (num_elements == hlo.sharding().tile_assignment().num_elements()) {
    // Output sharding is only on non offset dimensions. We use output sharding
    // to shard this gather op directly.
    return hlo.sharding();
  }

  if (num_elements == 1) {
    // Output sharding is only on offset dimensions. We do not shard this gather
    // op. Return a tile maximal sharding with the first device in output
    // sharding tile assignment.
    return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin(),
                                     hlo.sharding().metadata());
  }

  // Output sharding is on both offset and non offset dimensions. We shard the
  // gather op only on non offset dimensions.
  // For example:
  // - the gather op has sharding [2,2]{0,1,2,3},
  // - first dimension is non offset dimension,
  // - second dimension is offset dimension,
  // Then the result sharding will be [2,1]{0,2}.
  std::vector<int64> slice_starts(hlo.shape().rank(), 0LL),
      slice_limits(hlo.shape().rank());
  for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.offset_dims(), i)) {
      slice_limits[i] = hlo.sharding().tile_assignment().dim(i);
    } else {
      slice_limits[i] = 1;
    }
  }
  Array<int64> tile_assignment =
      hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
  return HloSharding::Tile(tile_assignment, hlo.sharding().metadata());
}

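// Projects `data_sharding` onto the scatter indices: dimensions listed in
// update_window_dims are dropped, a trivial dimension is appended if the
// index operand has higher rank, and the result falls back to replication
// when the resulting tile counts no longer match the device count.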
HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
                                 const HloInstruction* hlo) {
  if (data_sharding.IsTileMaximal()) {
    return data_sharding;
  }

  const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
  std::vector<int64> index_tile_assignment_dims;
  for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
    if (!absl::c_binary_search(dnums.update_window_dims(), i)) {
      index_tile_assignment_dims.push_back(
          data_sharding.tile_assignment().dim(i));
    }
  }
  if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) {
    index_tile_assignment_dims.push_back(1);
  }
  if (data_sharding.ReplicateOnLastTileDim()) {
    index_tile_assignment_dims.push_back(
        data_sharding.tile_assignment().dimensions().back());
  }
  Array<int64> new_tile_assignment = data_sharding.tile_assignment();
  if (new_tile_assignment.num_elements() !=
      Product(index_tile_assignment_dims)) {
    return HloSharding::Replicate(data_sharding.metadata());
  }
  new_tile_assignment.Reshape(index_tile_assignment_dims);
  return data_sharding.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(new_tile_assignment,
                                        data_sharding.metadata())
             : HloSharding::Tile(new_tile_assignment, data_sharding.metadata());
}

HloSharding ScatterDataSharding(const HloSharding& index_sharding,
                                const HloInstruction* hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }

  const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
  std::vector<int64> data_tile_assignment_dims;
  for (int64_t i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.update_window_dims(), i)) {
      data_tile_assignment_dims.push_back(1);
    } else {
      data_tile_assignment_dims.push_back(
          index_sharding.tile_assignment().dim(index_dim));
      index_dim++;
    }
  }
  if (index_sharding.ReplicateOnLastTileDim()) {
    data_tile_assignment_dims.push_back(
        index_sharding.tile_assignment().dimensions().back());
  }
  Array<int64> new_tile_assignment = index_sharding.tile_assignment();
  if (new_tile_assignment.num_elements() !=
      Product(data_tile_assignment_dims)) {
    return HloSharding::Replicate(index_sharding.metadata());
  }
  new_tile_assignment.Reshape(data_tile_assignment_dims);
  return index_sharding.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(new_tile_assignment,
                                        index_sharding.metadata())
             : HloSharding::Tile(new_tile_assignment,
                                 index_sharding.metadata());
}

HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
                                          const HloInstruction& hlo) {
  if (index_sharding.IsTileMaximal()) {
    return index_sharding;
  }

  // Only shard on first "number of scatter_window_dims" dimensions.
  const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
  int64_t num_elements = 1;
  int64_t index_dim = 0;
  for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
      num_elements *= index_sharding.tile_assignment().dim(index_dim);
      index_dim++;
    }
  }
  if (num_elements == index_sharding.tile_assignment().num_elements()) {
    // Index sharding is only on scatter_window_dims. We use this index sharding
    // directly.
    return index_sharding;
  }

  // Index sharding is only on update_window_dims. We do not shard this scatter
  // op. Return a tile maximal sharding with the first device in index sharding
  // tile assignment.
  if (num_elements == 1) {
    return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin(),
                                     index_sharding.metadata());
  }

  const int64_t index_rank = hlo.operand(1)->shape().rank();
  std::vector<int64> slice_starts(index_rank, 0LL), slice_limits(index_rank);
  for (int64_t i = 0; i < index_rank; ++i) {
    if (i < index_dim) {
      slice_limits[i] = index_sharding.tile_assignment().dim(i);
    } else {
      slice_limits[i] = 1;
    }
  }
  Array<int64> tile_assignment =
      index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
  return HloSharding::Tile(tile_assignment, index_sharding.metadata());
}

HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
                                         const HloInstruction& hlo) {
  if (data_sharding.IsTileMaximal()) {
    return data_sharding;
  }

  const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
  const int64_t data_rank = hlo.operand(2)->shape().rank();
  std::vector<int64> tile_assignment_dims(data_rank, 1LL);
  int64_t num_elements = 1;
  for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
    if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
      CHECK_LT(i, data_rank);
      tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i);
      num_elements *= data_sharding.tile_assignment().dim(i);
    }
  }
  if (num_elements == data_sharding.tile_assignment().num_elements()) {
    // Data sharding is only on scatter_window_dims. We use this data sharding
    // directly.
    return data_sharding;
  }

  if (num_elements == 1) {
    // Data sharding is only on update_window_dims. We do not shard this
    // scatter op. Return a tile maximal sharding with the first device in
    // data sharding tile assignment.
    return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin(),
                                     data_sharding.metadata());
  }

  // Data sharding is on both update_window_dims and scatter_window_dims. We
  // shard the scatter op only on scatter_window_dims. For example:
  // - the scatter data has sharding [2,2]{0,1,2,3},
  // - first dimension is scatter_window_dims,
  // - second dimension is update_window_dims,
  // Then the result sharding will be [2,1]{0,2}.
  std::vector<int64> slice_starts(data_rank, 0LL);
  Array<int64> tile_assignment =
      data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
  return HloSharding::Tile(tile_assignment, data_sharding.metadata());
}

namespace {

// If partitioning in the operand only happens in passthrough dimensions
// (offset dimensions in the gather output (or scatter update) that have the
// same size as in the operand), returns the corresponding output (or update)
// sharding by passing through the input sharding.
absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
    const Shape& operand_shape, const HloSharding& operand_sharding,
    const Shape& update_or_gather_shape,
    absl::Span<const int64> collapsed_or_inserted_dims,
    absl::Span<const int64> index_map,
    absl::Span<const int64> offset_or_window_dims,
    absl::Span<const int64> slice_size) {
  if (operand_sharding.IsTileMaximal()) {
    return operand_sharding;
  }
  std::vector<int64> passthrough_tile(update_or_gather_shape.rank(), 1);
  int64_t collapsed = 0;
  for (int64_t i = 0; i < operand_shape.rank(); ++i) {
    int64_t dim_partitions = operand_sharding.tile_assignment().dim(i);
    if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
        absl::c_linear_search(index_map, i)) {
      if (dim_partitions > 1) {
        return absl::nullopt;
      }
      collapsed++;
      continue;
    }
    if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
      return absl::nullopt;
    }
    int64_t offset_dim = offset_or_window_dims[i - collapsed];
    if (i - collapsed > 0 &&
        offset_dim < offset_or_window_dims[i - collapsed - 1]) {
      // Output offsets are transposed; we do not support this case.
      return absl::nullopt;
    }
    passthrough_tile[offset_dim] = dim_partitions;
  }
  if (operand_sharding.ReplicateOnLastTileDim()) {
    passthrough_tile.push_back(
        operand_sharding.tile_assignment().dimensions().back());
  }
  Array<int64> tile_assignment = operand_sharding.tile_assignment();
  tile_assignment.Reshape(passthrough_tile);
  return operand_sharding.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(tile_assignment,
                                        operand_sharding.metadata())
             : HloSharding::Tile(tile_assignment, operand_sharding.metadata());
}

// Inverse of PassthroughOperandToGatherOutputOrScatterUpdate.
absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
    const Shape& operand_shape, const HloSharding& update_or_gather_sharding,
    absl::Span<const int64> collapsed_or_inserted_dims,
    absl::Span<const int64> index_map,
    absl::Span<const int64> offset_or_window_dims,
    absl::Span<const int64> slice_size) {
  if (update_or_gather_sharding.IsTileMaximal()) {
    return update_or_gather_sharding;
  }
  std::vector<int64> passthrough_tile(operand_shape.rank(), 1);
  int64_t collapsed = 0;
  // Relevant dims have shardings passed to the operand.
  std::vector<int64> relevant_update_or_gather_dims;
  for (int64_t i = 0; i < operand_shape.rank(); ++i) {
    if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
        absl::c_linear_search(index_map, i)) {
      collapsed++;
      continue;
    }
    int64_t offset_dim = offset_or_window_dims[i - collapsed];
    int64_t dim_partitions =
        update_or_gather_sharding.tile_assignment().dim(offset_dim);
    if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
      return absl::nullopt;
    }
    if (i - collapsed > 0 &&
        offset_dim < offset_or_window_dims[i - collapsed - 1]) {
      // Output offsets are transposed; we do not support this case.
      return absl::nullopt;
    }
    relevant_update_or_gather_dims.push_back(offset_dim);
    passthrough_tile[i] = dim_partitions;
  }

  Array<int64> tile_assignment =
      PartiallyReplicateTiledShardingOnAllDimsExcept(
          update_or_gather_sharding, relevant_update_or_gather_dims)
          .tile_assignment();
  bool is_partial = false;
  if (tile_assignment.num_elements() != Product(passthrough_tile)) {
    passthrough_tile.push_back(tile_assignment.num_elements() /
                               Product(passthrough_tile));
    is_partial = true;
  }
  tile_assignment.Reshape(passthrough_tile);
  return is_partial ? HloSharding::PartialTile(
                          tile_assignment, update_or_gather_sharding.metadata())
                    : HloSharding::Tile(tile_assignment,
                                        update_or_gather_sharding.metadata());
}

// Collect data operand sharding for a gather with parallel dimensions from
// the output.
absl::optional<HloSharding> GatherParallelDataOperandSharding(
    const HloSharding& output_sharding, const HloInstruction& gather,
    const GatherParallelDims& parallel_dims) {
  if (output_sharding.IsTileMaximal()) {
    return output_sharding;
  }
  auto output_parallel_dims = GatherParallelOutputDims(gather, parallel_dims);
  auto output_aligned_operand_parallel_dims =
      GatherOutputAlignedOperandParallelDims(gather, parallel_dims);
  const Shape gather_shape = gather.shape();
  CHECK_EQ(output_parallel_dims.size(),
           output_aligned_operand_parallel_dims.size());
  std::vector<int64> operand_tile_assignment(gather.operand(0)->shape().rank(),
                                             1);
  std::vector<int64> relevant_output_dims;
  for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
    if (parallel_idx >= output_parallel_dims.size() ||
        output_parallel_dims[parallel_idx] != i) {
      continue;
    }
    const int64_t operand_dim =
        output_aligned_operand_parallel_dims[parallel_idx++];
    operand_tile_assignment[operand_dim] =
        output_sharding.tile_assignment().dim(i);
    relevant_output_dims.push_back(i);
  }
  HloSharding relevant_output_sharding =
      PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
                                                     relevant_output_dims);
  int64_t partially_replicated_size = 1;
  if (relevant_output_sharding.ReplicateOnLastTileDim()) {
    partially_replicated_size *=
        relevant_output_sharding.tile_assignment().dimensions().back();
  }

  if (relevant_output_sharding.IsTileMaximal()) {
    partially_replicated_size *=
        Product(output_sharding.tile_assignment().dimensions());
  }

  Array<int64> tile_assignment = relevant_output_sharding.tile_assignment();
  if (relevant_output_sharding.IsTileMaximal()) {
    Array<int64> replicated_tile_assignment({gather_shape.rank()}, 1L);
    tile_assignment = replicated_tile_assignment;
  }
  const int64_t operand_tile_elements =
      Product(operand_tile_assignment) * partially_replicated_size;
  if (tile_assignment.num_elements() != operand_tile_elements) {
    if (tile_assignment.num_elements() % operand_tile_elements == 0) {
      partially_replicated_size *=
          (tile_assignment.num_elements() / operand_tile_elements);
    } else {
      return absl::nullopt;
    }
  }
  if (partially_replicated_size > 1) {
    operand_tile_assignment.push_back(partially_replicated_size);
  }
  tile_assignment.Reshape(operand_tile_assignment);
  return partially_replicated_size > 1
             ? HloSharding::PartialTile(tile_assignment,
                                        output_sharding.metadata())
             : HloSharding::Tile(tile_assignment, output_sharding.metadata());
}

}  // namespace

absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
    const HloSharding& data_operand_sharding, const HloInstruction& hlo,
    const Shape& output_shape, const Shape& operand_shape) {
  const auto& dnums = hlo.gather_dimension_numbers();
  std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
                                          dnums.collapsed_slice_dims().end());
  std::vector<int64> start_index_map(dnums.start_index_map().begin(),
                                     dnums.start_index_map().end());
  std::vector<int64> offset_dims(dnums.offset_dims().begin(),
                                 dnums.offset_dims().end());
  return PassthroughOperandToGatherOutputOrScatterUpdate(
      operand_shape, data_operand_sharding, output_shape, collapsed_slice_dims,
      start_index_map, offset_dims, hlo.gather_slice_sizes());
}

absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo) {
  const auto& dnums = hlo.gather_dimension_numbers();
  std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
                                          dnums.collapsed_slice_dims().end());
  std::vector<int64> start_index_map(dnums.start_index_map().begin(),
                                     dnums.start_index_map().end());
  std::vector<int64> offset_dims(dnums.offset_dims().begin(),
                                 dnums.offset_dims().end());

  absl::optional<HloSharding> parallel_sharding;
  auto parallel_dims = GetGatherBatchParallelDims(hlo);
  if (parallel_dims) {
    // Prioritize parallel sharding first, matching how it is handled in the
    // SPMD partitioner.
    parallel_sharding =
        GatherParallelDataOperandSharding(output_sharding, hlo, *parallel_dims);
  }
  absl::optional<HloSharding> passthrough_sharding =
      PassthroughGatherOutputOrScatterUpdateToOperand(
          hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims,
          start_index_map, offset_dims, hlo.gather_slice_sizes());
  // Try to merge the two shardings, or return the one that is present if only
  // one of the two is.
  if (!passthrough_sharding) {
    return parallel_sharding;
  }
  if (!parallel_sharding) {
    return passthrough_sharding;
  }
  if (MergeSharding(*parallel_sharding, &*passthrough_sharding,
                    /*may_combine_partial_sharding=*/true)) {
    return passthrough_sharding;
  }
  if (MergeSharding(*passthrough_sharding, &*parallel_sharding,
                    /*may_combine_partial_sharding=*/true)) {
    return parallel_sharding;
  }
  return absl::nullopt;
}

absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
    const HloSharding& update_sharding, const HloInstruction& hlo) {
  const auto& dnums = hlo.scatter_dimension_numbers();
  std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
                                          dnums.inserted_window_dims().end());
  std::vector<int64> scatter_dims_to_operand_dims(
      dnums.scatter_dims_to_operand_dims().begin(),
      dnums.scatter_dims_to_operand_dims().end());
  std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
                                        dnums.update_window_dims().end());
  std::vector<int64> slice_size(hlo.shape().rank(), 1);
  int64_t num_update_window_dims = 0;
  for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
    if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
      continue;
    }
    slice_size[i] = hlo.operand(2)->shape().dimensions(
        dnums.update_window_dims(num_update_window_dims++));
  }
  return PassthroughGatherOutputOrScatterUpdateToOperand(
      hlo.shape(), update_sharding, inserted_window_dims,
      scatter_dims_to_operand_dims, update_window_dims, slice_size);
}

absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
    const HloSharding& output_sharding, const HloInstruction& hlo) {
  const auto& dnums = hlo.scatter_dimension_numbers();
  std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
                                          dnums.inserted_window_dims().end());
  std::vector<int64> scatter_dims_to_operand_dims(
      dnums.scatter_dims_to_operand_dims().begin(),
      dnums.scatter_dims_to_operand_dims().end());
  std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
                                        dnums.update_window_dims().end());
  std::vector<int64> slice_size(hlo.shape().rank(), 1);
  int64_t num_update_window_dims = 0;
  for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
    if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
      continue;
    }
    slice_size[i] = hlo.operand(2)->shape().dimensions(
        dnums.update_window_dims(num_update_window_dims++));
  }
  return PassthroughOperandToGatherOutputOrScatterUpdate(
      hlo.shape(), output_sharding, hlo.operand(2)->shape(),
      inserted_window_dims, scatter_dims_to_operand_dims, update_window_dims,
      slice_size);
}

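// Inspects a scatter's reduce computation and, when it is a single add / or /
// multiply / and / min / max instruction over the two parameters, returns the
// matching identity constant (0, 1, or the element type's min/max value)
// together with the reduction opcode; any other computation is rejected as an
// invalid argument.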
StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
IdentityValueAndHloOpcodeForScatterReduceComputation(
    const HloScatterInstruction& scatter) {
  auto computation = scatter.to_apply();
  // We only handle computations with 2 parameters and only 1 calculation.
  if (computation->instruction_count() != 3) {
    return Status(
        tensorflow::error::Code::INVALID_ARGUMENT,
        "Expected scatter reduce computation with 2 parameters and only 1 "
        "calculation");
  }

  auto root_instruction = computation->root_instruction();
  if (root_instruction->opcode() == HloOpcode::kAdd ||
      root_instruction->opcode() == HloOpcode::kOr) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMultiply ||
             root_instruction->opcode() == HloOpcode::kAnd) {
    return std::make_pair(HloInstruction::CreateConstant(
                              LiteralUtil::One(scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMaximum) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  } else if (root_instruction->opcode() == HloOpcode::kMinimum) {
    return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue(
                              scatter.shape().element_type())),
                          root_instruction->opcode());
  }

  return Status(tensorflow::error::Code::INVALID_ARGUMENT,
                "Expected scatter reduce computation which is "
                "add/or/multiply/and/min/max");
}

namespace {

void DevicesForShardingInternal(
    const HloSharding& sharding,
    const absl::flat_hash_set<int64>& available_devices,
    absl::flat_hash_set<int64>* used) {
  if (sharding.IsTuple()) {
    for (const auto& subsharding : sharding.tuple_elements()) {
      DevicesForShardingInternal(subsharding, available_devices, used);
    }
    return;
  }

  if (sharding.IsReplicated()) {
    for (int64_t device : available_devices) {
      if (!HloSharding::IsReservedDevice(device)) {
        used->insert(device);
      }
    }
    return;
  }

  DCHECK(std::all_of(
      sharding.tile_assignment().begin(), sharding.tile_assignment().end(),
      [&](int64_t device) { return available_devices.contains(device); }));
  sharding.tile_assignment().Each(
      [&](absl::Span<const int64> /*indices*/, int64_t device) {
        used->insert(device);
      });
}

}  // namespace

std::vector<int64> DevicesForSharding(
    const HloSharding& sharding, const std::vector<int64>& available_devices) {
  absl::flat_hash_set<int64> available_set;
  for (int64_t device : available_devices) {
    available_set.insert(device);
  }
  absl::flat_hash_set<int64> used_set;
  DevicesForShardingInternal(sharding, available_set, &used_set);
  std::vector<int64> devices;
  for (int64_t device : available_devices) {
    if (used_set.contains(device)) {
      devices.push_back(device);
    }
  }
  return devices;
}

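// Replaces the tiling along each dimension in `dims_to_replicate` with
// replication: the affected tiles are folded into the trailing
// "replicate-on-last-tile-dim" dimension, yielding a partially replicated
// sharding (or a fully replicated one if every tile dimension is folded).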
HloSharding PartiallyReplicateTiledShardingOnDims(
    const HloSharding& sharding, absl::Span<const int64> dims_to_replicate) {
  if (sharding.IsTileMaximal()) {
    return sharding;
  }
  int64_t group_count = 1;
  for (int64_t dim : dims_to_replicate) {
    if (sharding.ReplicateOnLastTileDim()) {
      CHECK_LT(dim, sharding.tile_assignment().num_dimensions());
    }
    group_count *= sharding.tile_assignment().dim(dim);
  }
  if (group_count == 1) {
    return sharding;
  }
  if (group_count == sharding.NumTiles()) {
    return HloSharding::Replicate(sharding.metadata());
  }
  std::vector<int64> dim_permutation(
      sharding.tile_assignment().num_dimensions());
  std::iota(dim_permutation.begin(), dim_permutation.end(), 0);
  absl::c_stable_sort(dim_permutation, [&](const int64_t a, const int64_t b) {
    return absl::c_linear_search(dims_to_replicate, a) <
           absl::c_linear_search(dims_to_replicate, b);
  });
  auto transposed = TransposeSharding(sharding, dim_permutation);
  auto new_tile = transposed.tile_assignment();
  std::vector<int64> new_tile_shape(
      sharding.tile_assignment().dimensions().begin(),
      sharding.tile_assignment().dimensions().end());
  for (int64_t dim : dims_to_replicate) {
    new_tile_shape[dim] = 1;
  }
  if (sharding.ReplicateOnLastTileDim()) {
    new_tile_shape.back() *= group_count;
  } else {
    new_tile_shape.push_back(group_count);
  }
  new_tile.Reshape(new_tile_shape);
  return HloSharding::PartialTile(new_tile, sharding.metadata());
}

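// Convenience wrapper around PartiallyReplicateTiledShardingOnDims that
// replicates every data dimension except those listed in `dims_to_keep`.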
HloSharding PartiallyReplicateTiledShardingOnAllDimsExcept(
    const HloSharding& sharding, absl::Span<const int64> dims_to_keep) {
  if (sharding.IsTileMaximal()) {
    return sharding;
  }
  int64_t data_rank = sharding.tile_assignment().num_dimensions();
  if (sharding.ReplicateOnLastTileDim()) {
    data_rank -= 1;
  }
  std::vector<int64> dims_to_replicate(data_rank);
  absl::c_iota(dims_to_replicate, 0);

  dims_to_replicate.erase(
      std::remove_if(
          dims_to_replicate.begin(), dims_to_replicate.end(),
          [&](int64_t i) { return absl::c_linear_search(dims_to_keep, i); }),
      dims_to_replicate.end());
  return PartiallyReplicateTiledShardingOnDims(sharding, dims_to_replicate);
}

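// Drops the tile-assignment dimensions listed in `dims_to_remove` from
// `sharding`. Each removed dimension must have a tile size of 1, so the
// device assignment itself is unchanged.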
HloSharding RemoveShapeDimensions(const HloSharding& sharding,
                                  const std::vector<int64>& dims_to_remove) {
  if (sharding.IsTileMaximal() || dims_to_remove.empty()) {
    return sharding;
  }
  std::vector<int64> new_tile_shape;
  new_tile_shape.reserve(sharding.tile_assignment().num_dimensions() -
                         dims_to_remove.size());
  for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
    if (absl::c_linear_search(dims_to_remove, i)) {
      CHECK_EQ(sharding.tile_assignment().dim(i), 1);
    } else {
      new_tile_shape.push_back(sharding.tile_assignment().dim(i));
    }
  }
  auto new_tile = sharding.tile_assignment();
  new_tile.Reshape(new_tile_shape);
  return sharding.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(new_tile, sharding.metadata())
             : HloSharding::Tile(new_tile, sharding.metadata());
}

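// Transposes `source` according to the source<->target dimension maps, where
// a -1 entry marks a source dimension that is collapsed or a target dimension
// that is newly added. Returns nullopt if a collapsed source dimension is
// tiled, since the sharding cannot be carried over in that case.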
absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
    const HloSharding& source, absl::Span<int64 const> src_to_tgt,
    absl::Span<int64 const> tgt_to_src) {
  if (source.IsTileMaximal()) {
    return source;
  }
  if (source.ReplicateOnLastTileDim() &&
      src_to_tgt.size() < source.tile_assignment().num_dimensions()) {
    std::vector<int64> new_src_to_tgt(src_to_tgt.begin(), src_to_tgt.end());
    new_src_to_tgt.push_back(tgt_to_src.size());
    std::vector<int64> new_tgt_to_src(tgt_to_src.begin(), tgt_to_src.end());
    new_tgt_to_src.push_back(src_to_tgt.size());
    return TransposeShardingWithCollapsedDims(source, new_src_to_tgt,
                                              new_tgt_to_src);
  }
  std::vector<int64> tgt_dims_skipping_new(tgt_to_src.size(), -1);
  int64_t skipped_tgt_dims = 0;
  for (int64_t i = 0; i < tgt_to_src.size(); ++i) {
    if (tgt_to_src[i] < 0) {
      skipped_tgt_dims++;
    } else {
      tgt_dims_skipping_new[i] = i - skipped_tgt_dims;
    }
  }
  int64_t skipped_src_dims = absl::c_count(src_to_tgt, -1);
  std::vector<int64> perm(src_to_tgt.size());
  for (int64_t i = 0; i < src_to_tgt.size(); ++i) {
    if (src_to_tgt[i] < 0) {
      if (source.tile_assignment().dim(i) > 1) {
        return absl::nullopt;
      }
      perm[src_to_tgt.size() - skipped_src_dims] = i;
      skipped_src_dims--;
    } else {
      perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i;
    }
  }
  auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm);
  auto reshape_tiles = tgt_sharding.tile_assignment();
  std::vector<int64> tgt_tiles(tgt_to_src.size(), 1);
  for (int64_t i = 0; i < tgt_tiles.size(); ++i) {
    if (tgt_to_src[i] >= 0) {
      tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]);
    }
  }
  reshape_tiles.Reshape(tgt_tiles);
  return source.ReplicateOnLastTileDim()
             ? HloSharding::PartialTile(reshape_tiles, source.metadata())
             : HloSharding::Tile(reshape_tiles, source.metadata());
}

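// If `maybe_iota` behaves like an iota (an actual iota instruction, an S32
// constant whose values equal the index along some non-degenerate dimension,
// or a broadcast of such a value), returns the dimension along which the
// values increase; otherwise returns nullopt.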
absl::optional<int64> GetDimensionForIota(const HloInstruction* maybe_iota) {
  if (auto* iota = DynCast<HloIotaInstruction>(maybe_iota)) {
    return iota->iota_dimension();
  }

  if (maybe_iota->shape().element_type() != S32) {
    return absl::nullopt;
  }
  if (maybe_iota->IsConstant()) {
    std::vector<bool> is_iota_dim(maybe_iota->shape().rank(), true);
    maybe_iota->literal().EachCell<int32>(
        [&](absl::Span<const int64> indices, int32_t val) {
          for (int64_t i = 0; i < indices.size(); ++i) {
            if (val != indices[i]) {
              is_iota_dim[i] = false;
            }
          }
        });
    for (int64_t i = 0; i < is_iota_dim.size(); ++i) {
      if (is_iota_dim[i] && maybe_iota->shape().dimensions(i) > 1) {
        return i;
      }
    }
    return absl::nullopt;
  }

  if (maybe_iota->opcode() == HloOpcode::kBroadcast) {
    auto operand_dim = GetDimensionForIota(maybe_iota->operand(0));
    if (operand_dim) {
      return maybe_iota->dimensions(*operand_dim);
    }
    return absl::nullopt;
  }
  return absl::nullopt;
}

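// Detects the batch-parallel dimensions of a gather, i.e. index dimensions
// that are driven by an iota and therefore allow the operand, indices and
// output to be sharded along matching dimensions. Returns nullopt if no such
// dimension exists.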
absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
    const HloInstruction& hlo) {
  const auto& dnums = hlo.gather_dimension_numbers();
  int64_t index_dim = dnums.index_vector_dim();
  // Try to identify whether there is a dimension in the indices that increases
  // monotonically with an iota along a certain dimension. This means that the
  // access into the corresponding operand dimension is parallelizable, so we
  // can shard the operand (and the indices/output) across that dimension.
  // For example the pattern:
  //   %iota.1 = iota()
  //   %indices = concatenate(..., %iota.1, ...)
  //   ... = gather(..., %indices)
  // is common for tf.reverse_sequence and would match this case.
  absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
  const HloInstruction* indices = hlo.operand(1);
  const int num_indices = dnums.start_index_map_size();
  std::vector<int64> index_parallel_in_dim(num_indices, -1);
  // Handle cases where we concatenate pieces of the indices one at a time.
  if (indices->opcode() == HloOpcode::kConcatenate &&
      indices->concatenate_dimension() == index_dim) {
    int concatenated_dims = 0;
    for (int i = 0; i < indices->operand_count(); ++i) {
      const HloInstruction* op = indices->operand(i);
      const int64_t num_indices_from_element =
          op->shape().dimensions_size() > index_dim
              ? op->shape().dimensions(index_dim)
              : 1;
      if (absl::optional<int64> maybe_iota_dim = GetDimensionForIota(op)) {
        if (*maybe_iota_dim != index_dim) {
          for (int j = 0; j < num_indices_from_element; ++j) {
            index_parallel_in_dim[concatenated_dims + j] = *maybe_iota_dim;
          }
        }
      }
      concatenated_dims += num_indices_from_element;
    }
  } else if (absl::optional<int64> maybe_iota_dim =
                 GetDimensionForIota(indices)) {
    if (*maybe_iota_dim != index_dim) {
      // This is the case of a single iota with index_dim being out of bounds.
      const int64_t num_indices_from_element =
          indices->shape().dimensions_size() > index_dim
              ? indices->shape().dimensions(index_dim)
              : 1;
      index_parallel_in_dim.assign(num_indices_from_element, *maybe_iota_dim);
    }
  }
  absl::InlinedVector<int64, 1> indices_parallel_dims;
  absl::InlinedVector<int64, 1> operand_parallel_dims;
  // Map the parallelizable dimension from the iota to the dimensions of the
  // output and the operand. These dimensions correspond to each other, but
  // they can sit at different positions in the operand and index shapes,
  // because the position of an index dimension in the operand is determined by
  // start_index_map.
  for (int i = 0; i < index_parallel_in_dim.size(); ++i) {
    int index_parallel_dim = index_parallel_in_dim[i];
    if (index_parallel_dim == -1) {
      continue;
    }
    if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
      return absl::nullopt;
    }
    // Considered parallel only if the slice is of size 1 over the operand.
    if (hlo.gather_slice_sizes()[dnums.start_index_map(i)] == 1) {
      indices_parallel_dims.push_back(index_parallel_dim);
      operand_parallel_dims.push_back(dnums.start_index_map(i));
    } else {
      index_parallel_in_dim[i] = -1;
    }
  }
  absl::c_sort(indices_parallel_dims);
  if (!indices_parallel_dims.empty()) {
    return GatherParallelDims{indices_parallel_dims, operand_parallel_dims,
                              index_parallel_in_dim};
  }
  return absl::nullopt;
}

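// Maps the parallel dimensions found in the gather indices to the
// corresponding dimensions of the gather output, i.e. the output batch
// dimensions that are not offset dimensions.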
absl::InlinedVector<int64, 1> GatherParallelOutputDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dim) {
  absl::InlinedVector<int64, 1> output_parallel_dims;
  auto indices_parallel_dims = parallel_dim.indices_parallel_dims;
  const Shape gather_shape = gather.shape();
  auto dnums = gather.gather_dimension_numbers();
  for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
    if (absl::c_linear_search(dnums.offset_dims(), i)) {
      continue;
    }
    const int index_dim =
        idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
    if (absl::c_binary_search(indices_parallel_dims, index_dim)) {
      output_parallel_dims.push_back(i);
    }
    ++idx_dim;
  }
  return output_parallel_dims;
}

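// Returns the operand parallel dimensions reordered to match the order of the
// corresponding index (and therefore output) parallel dimensions, using
// start_index_map to translate index positions into operand dimensions.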
absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
    const HloInstruction& gather, const GatherParallelDims& parallel_dims) {
  absl::InlinedVector<int64, 1> operand_parallel_dim_to_output(
      parallel_dims.operand_parallel_dims.size(), -1);
  auto dnums = gather.gather_dimension_numbers();
  CHECK_LE(parallel_dims.indices_parallel_dims.size(),
           parallel_dims.operand_parallel_dims.size());
  for (int i = 0; i < parallel_dims.index_parallel_in_dim.size(); ++i) {
    // This is the equivalent batch dimension of the indices that corresponds
    // to this index dimension.
    const int64_t index_parallel_dim = parallel_dims.index_parallel_in_dim[i];
    // If this index dimension is not parallel, skip it.
    if (index_parallel_dim == -1) {
      continue;
    }
    // This is small so just look linearly. Populate the operand parallel
    // dimensions based on the order of the index batch dims (which is the same
    // order as the output).
    for (int j = 0; j < parallel_dims.indices_parallel_dims.size(); ++j) {
      if (parallel_dims.indices_parallel_dims[j] == index_parallel_dim) {
        const int64_t operand_parallel_dim = dnums.start_index_map(i);
        if (operand_parallel_dim_to_output[j] == -1) {
          operand_parallel_dim_to_output[j] = operand_parallel_dim;
        }
        break;
      }
    }
  }
  return operand_parallel_dim_to_output;
}

}  // namespace hlo_sharding_util
}  // namespace xla