1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
17
18 #include <algorithm>
19 #include <map>
20 #include <vector>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/compiler/xla/array.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33
34 namespace xla {
35 namespace hlo_sharding_util {
36
37 bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) {
38 CHECK_EQ(lhs.IsTuple(), rhs.IsTuple()) << lhs << " <> " << rhs;
39 if (lhs.IsTuple()) {
40 // For tuples we consider lhs to have a better sharding if none of the
41 // elements are worse and at least one element is better than in rhs
42 // sharding.
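// Illustrative example: lhs = ({devices=[2]0,1}, {devices=[2]0,1}) is more
// specific than rhs = ({devices=[2]0,1}, {replicated}): the first elements
// tie, and the second element of lhs is tiled while the one in rhs is
// replicated.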
43 const auto& lhs_shardings = lhs.tuple_elements();
44 const auto& rhs_shardings = rhs.tuple_elements();
45 CHECK_EQ(lhs_shardings.size(), rhs_shardings.size());
46 bool is_better = false;
47 for (int64_t i = 0; i < lhs_shardings.size(); ++i) {
48 if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) {
49 return false;
50 }
51 if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) {
52 is_better = true;
53 }
54 }
55 return is_better;
56 }
57 if (!rhs.IsTileMaximal()) {
58 return lhs.NumTiles() > rhs.NumTiles();
59 } else if (!rhs.IsReplicated()) {
60 // If we are not replicated then only tiled (not tile maximal) shardings
61 // can improve us.
62 return !lhs.IsTileMaximal();
63 } else {
64 // If we are replicated then any non-replicated sharding can improve us.
65 return !lhs.IsReplicated();
66 }
67 }
68
69 bool MergeSharding(const HloSharding& old, HloSharding* to_merge,
70 bool may_combine_partial_sharding) {
71 if (old.IsTuple()) {
72 CHECK(to_merge->IsTuple());
73 bool changed = false;
74 for (int64_t i = 0; i < old.tuple_elements().size(); ++i) {
75 changed |=
76 MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i],
77 may_combine_partial_sharding);
78 }
79 return changed;
80 }
81 if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() ||
82 !to_merge->ReplicateOnLastTileDim() ||
83 old.tile_assignment().num_elements() !=
84 to_merge->tile_assignment().num_elements()) {
85 return IsShardingMoreSpecific(*to_merge, old);
86 }
87
88 if (MergeShardingIfCompatible(
89 old,
90 /*minimum_tiles=*/std::max(old.NumTiles(), to_merge->NumTiles()) + 1,
91 to_merge)) {
92 return true;
93 }
94 return IsShardingMoreSpecific(*to_merge, old);
95 }
96
97 bool MergeShardingIfCompatible(const HloSharding& to_merge,
98 int64_t minimum_tiles, HloSharding* dst) {
99 if (to_merge.IsTileMaximal()) {
100 return false;
101 }
102 if (dst->IsTileMaximal()) {
103 *dst = to_merge;
104 return true;
105 }
106 // Combine the tile dimension sizes from dst and to_merge.
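// Illustrative example (assuming both shardings are partially replicated, so
// the last tile dimension is the replication dimension): with 8 devices, a
// dst tile assignment of shape [2,1,4] and a to_merge tile assignment of
// shape [1,4,2] combine into merged tile dimensions [2,4], leaving
// 8 / (2*4) = 1 of replication.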
107 int64_t num_devices = to_merge.tile_assignment().num_elements();
108 std::vector<int64> merged_tile_dims;
109 bool compatible = true;
110 merged_tile_dims.reserve(dst->tile_assignment().num_dimensions());
111 for (int64_t i = 0; i < dst->tile_assignment().num_dimensions() - 1; ++i) {
112 int64_t dst_dim = dst->tile_assignment().dim(i);
113 int64_t merge_dim = to_merge.tile_assignment().dim(i);
114 if (dst_dim == 1) {
115 merged_tile_dims.push_back(merge_dim);
116 } else if (merge_dim == 1) {
117 merged_tile_dims.push_back(dst_dim);
118 } else if (dst_dim == merge_dim) {
119 merged_tile_dims.push_back(dst_dim);
120 } else {
121 compatible = false;
122 break;
123 }
124 }
125 int64_t merged_tiles = Product(merged_tile_dims);
126 if (!compatible || num_devices % Product(merged_tile_dims) != 0 ||
127 merged_tiles < minimum_tiles) {
128 return false;
129 }
130 int64_t replication = num_devices / merged_tiles;
131 merged_tile_dims.push_back(replication);
132 Array<int64> merged_tile(merged_tile_dims);
133 // Maps from replication group ID to sorted members.
134 absl::flat_hash_map<int64, std::set<int64>> merge_group_members;
135 absl::flat_hash_map<int64, std::set<int64>> dst_group_members;
136 auto get_group_index = [&](absl::Span<const int64> tile_indices,
137 const HloSharding& sharding) {
138 int64_t group_id = 0;
139 for (int64_t i = 0; i < tile_indices.size() - 1; ++i) {
140 group_id *= dst->tile_assignment().dim(i);
141 group_id += tile_indices[i];
142 }
143 return group_id;
144 };
145 to_merge.tile_assignment().Each(
146 [&](absl::Span<const int64> indices, int64_t device) {
147 merge_group_members[get_group_index(indices, to_merge)].insert(device);
148 });
149 dst->tile_assignment().Each(
150 [&](absl::Span<const int64> indices, int64_t device) {
151 dst_group_members[get_group_index(indices, *dst)].insert(device);
152 });
153 // Try to find the intersection of to_merge and dst replication groups, in
154 // order to determine the merged tile assignment.
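// For each cell of the merged tile we pick the larger of the two groups'
// smallest remaining devices; that device must be present in both the
// to_merge replication group and the dst replication group, otherwise the
// shardings are treated as incompatible.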
155 merged_tile.Each([&](absl::Span<const int64> indices, int64* device) {
156 if (!compatible) {
157 return;
158 }
159 std::vector<int64> to_merge_index(indices.begin(), indices.end());
160 std::vector<int64> dst_index = to_merge_index;
161 for (int64_t i = 0; i < indices.size() - 1; ++i) {
162 if (to_merge.tile_assignment().dim(i) == 1) {
163 to_merge_index[i] = 0;
164 }
165 if (dst->tile_assignment().dim(i) == 1) {
166 dst_index[i] = 0;
167 }
168 }
169 int64_t to_merge_group_id = get_group_index(to_merge_index, to_merge);
170 int64_t dst_group_id = get_group_index(dst_index, *dst);
171 if (merge_group_members[to_merge_group_id].empty() ||
172 dst_group_members[dst_group_id].empty()) {
173 compatible = false;
174 return;
175 }
176
177 int64_t smallest_to_merge = *merge_group_members[to_merge_group_id].begin();
178 int64_t smallest_dst = *dst_group_members[dst_group_id].begin();
179 if (smallest_to_merge < smallest_dst) {
180 if (merge_group_members[to_merge_group_id].count(smallest_dst) == 0) {
181 compatible = false;
182 return;
183 }
184 *device = smallest_dst;
185 } else {
186 if (dst_group_members[dst_group_id].count(smallest_to_merge) == 0) {
187 compatible = false;
188 return;
189 }
190 *device = smallest_to_merge;
191 }
192 merge_group_members[to_merge_group_id].erase(*device);
193 dst_group_members[dst_group_id].erase(*device);
194 });
195 if (!compatible) {
196 return false;
197 }
198 std::vector<OpMetadata> merged_metadata(std::move(dst->metadata()));
199 merged_metadata.reserve(merged_metadata.size() + to_merge.metadata().size());
200 const absl::flat_hash_set<OpMetadata, protobuf_util::ProtobufHashWrapper,
201 protobuf_util::ProtobufEqualsWrapper>
202 metadata_set(merged_metadata.begin(), merged_metadata.end());
203 absl::c_copy_if(to_merge.metadata(), std::back_inserter(merged_metadata),
204 [&metadata_set](const OpMetadata& data) {
205 return !ContainsKey(metadata_set, data);
206 });
207 if (replication == 1) {
208 merged_tile_dims.pop_back();
209 merged_tile.Reshape(merged_tile_dims);
210 *dst = HloSharding::Tile(merged_tile, merged_metadata);
211 } else {
212 *dst = HloSharding::PartialTile(merged_tile, merged_metadata);
213 }
214 return true;
215 }
216
217 absl::optional<int64> SelectDominantDevice(
218 const std::map<int64, int64>& device_map, int64* top_count) {
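// Example (illustrative): for device_map = {0: 3, 1: 5, 2: 5} this returns
// device 1 with *top_count = 5; ties are broken in favor of the lowest device
// id because the map is iterated in key order and only strictly larger counts
// replace the current pick.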
219 int64_t device = 0;
220 int64_t count = 0;
221 for (auto& it : device_map) {
222 if (it.second > count) {
223 count = it.second;
224 device = it.first;
225 }
226 }
227 if (top_count != nullptr) {
228 *top_count = count;
229 }
230 return count > 0 ? absl::optional<int64>(device) : absl::optional<int64>();
231 }
232
233 Status AssignComputationDevice(HloComputation* computation, int64_t device) {
234 VLOG(4) << "Assigning device " << device << " to " << computation->name()
235 << " computation";
236 for (HloInstruction* instruction : computation->instructions()) {
237 if (!instruction->has_sharding()) {
238 VLOG(4) << "Assigning device " << device << " to " << instruction->name();
239 instruction->set_device_sharding(device);
240 }
241 }
242 return Status::OK();
243 }
244
245 absl::optional<int64> GetMostOccurringDevice(
246 absl::Span<HloInstruction* const> instructions) {
247 std::map<int64, int64> device_map;
248 for (HloInstruction* instruction : instructions) {
249 if (instruction->has_sharding()) {
250 for (auto& it : instruction->sharding().UsedDevices(nullptr)) {
251 // The UsedDevices() API returns a map<device, occurrence_count>.
252 device_map[it.first] += it.second;
253 }
254 }
255 }
256 return SelectDominantDevice(device_map, nullptr);
257 }
258
259 StatusOr<absl::optional<int64>> GetDominantDevice(
260 absl::Span<HloComputation* const> computations, double dominant_factor) {
261 int64_t instruction_count = 0;
262 std::map<int64, int64> device_map;
263 for (HloComputation* computation : computations) {
264 for (HloInstruction* instruction : computation->instructions()) {
265 int64_t count = 1;
266 if (instruction->has_sharding()) {
267 for (auto& it : instruction->sharding().UsedDevices(&count)) {
268 // The UsedDevices() API returns a map<device, occurrence_count>.
269 device_map[it.first] += it.second;
270 }
271 }
272 instruction_count += count;
273 }
274 }
275 int64_t count;
276 absl::optional<int64> device = SelectDominantDevice(device_map, &count);
277 absl::optional<int64> dominant_device;
278 if (device) {
279 double factor =
280 static_cast<double>(count) / static_cast<double>(instruction_count);
281 if (factor >= dominant_factor) {
282 dominant_device = device;
283 }
284 }
285 return dominant_device;
286 }
287
288 HloSharding TransposeSharding(const HloSharding& sharding,
289 const std::vector<int64>& dimensions) {
290 if (sharding.IsTileMaximal()) {
291 return sharding;
292 }
293 auto perm_dimensions = dimensions;
294 if (sharding.ReplicateOnLastTileDim() &&
295 dimensions.size() < sharding.tile_assignment().num_dimensions()) {
296 perm_dimensions.push_back(dimensions.size());
297 }
298 const int64_t rank = perm_dimensions.size();
299 std::vector<int64> tile_assignment_dim(rank);
300 for (int64_t i = 0; i < rank; ++i) {
301 tile_assignment_dim[i] = sharding.tile_assignment().dim(perm_dimensions[i]);
302 }
303 Array<int64> tile_assignment = sharding.tile_assignment();
304 tile_assignment.Reshape(tile_assignment_dim);
305 tile_assignment.Each([&](absl::Span<const int64> indices, int64* value) {
306 std::vector<int64> src_indices(indices.size(), -1);
307 for (int64_t i = 0; i < indices.size(); ++i) {
308 src_indices[perm_dimensions[i]] = indices[i];
309 }
310 *value = sharding.tile_assignment()(src_indices);
311 });
312 return sharding.ReplicateOnLastTileDim()
313 ? HloSharding::PartialTile(tile_assignment, sharding.metadata())
314 : HloSharding::Tile(tile_assignment, sharding.metadata());
315 }
316
317 absl::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
318 const Shape& target_shape,
319 const HloSharding& sharding) {
320 if (sharding.IsTileMaximal()) {
321 return sharding;
322 }
323
324 // In case of a tiled sharding the reshaped sharding will be valid if the
325 // reshape is composed from the following operations:
326 // * Adding or removing dimensions with size 1.
327 // * Merging consecutive dimensions where only the most major is sharded.
328 // * Splitting a dimension to consecutive dimensions.
329 // * Any reshaping of unsharded dimensions.
330 // Note that merge and split can happen consecutively on the same dimension,
331 // e.g., f32[1024,256,1024] to f32[128,2048,1024] can be viewed as the 1024
332 // being split into 128 and 8, with the 8 then merged with 256. We use stacks
333 // to make supporting such cases easy.
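// Illustrative example: reshaping f32[16,8]{devices=[4,1]0,1,2,3} to f32[128]
// produces {devices=[4]0,1,2,3}, since only the most major (and sharded)
// dimension participates in the merge.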
334 const Shape tile_shape = sharding.TileShape(source_shape);
335 std::vector<int64> target_tile_assignment_dimensions;
336 std::vector<int64> source_dims_stack(source_shape.rank());
337 std::vector<int64> target_dims_stack(target_shape.rank());
338 std::vector<int64> sharding_tile_dims_stack(source_shape.rank());
339 for (int64_t i = 0; i < source_shape.rank(); ++i) {
340 source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
341 sharding_tile_dims_stack[i] =
342 sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
343 }
344 for (int64_t i = 0; i < target_shape.rank(); ++i) {
345 target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
346 }
347 while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
348 if (target_dims_stack.empty()) {
349 if (Product(sharding_tile_dims_stack) != 1) {
350 return absl::nullopt;
351 }
352 break;
353 }
354 int64_t s_size = 1;
355 int64_t t_size = 1;
356 int64_t s_partitions = 1;
357 if (!source_dims_stack.empty()) {
358 s_size = source_dims_stack.back();
359 source_dims_stack.pop_back();
360 s_partitions = sharding_tile_dims_stack.back();
361 sharding_tile_dims_stack.pop_back();
362 }
363 t_size = target_dims_stack.back();
364 target_dims_stack.pop_back();
365 if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
366 // No more partitions left.
367 target_tile_assignment_dimensions.push_back(1);
368 continue;
369 }
370 if (s_size == t_size) {
371 // Same dimension.
372 target_tile_assignment_dimensions.push_back(s_partitions);
373 } else if (t_size == 1) {
374 // Trivial dimension added.
375 target_tile_assignment_dimensions.push_back(1);
376 source_dims_stack.push_back(s_size);
377 sharding_tile_dims_stack.push_back(s_partitions);
378 } else if (s_size == 1) {
379 // Trivial dimension removed.
380 if (s_partitions != 1) {
381 return absl::nullopt;
382 }
383 target_dims_stack.push_back(t_size);
384 } else if (s_size > t_size) {
385 // Dimension split.
386 if (s_size % t_size != 0 || s_size % s_partitions != 0) {
387 return absl::nullopt;
388 }
389 if (t_size % s_partitions == 0) {
390 target_tile_assignment_dimensions.push_back(s_partitions);
391 // We have part of the s_size unprocessed, so put it back on the stack.
392 source_dims_stack.push_back(s_size / t_size);
393 sharding_tile_dims_stack.push_back(1);
394 } else if (s_partitions % t_size == 0) {
395 target_tile_assignment_dimensions.push_back(t_size);
396 // We have part of the s_size unprocessed, so put it back on the stack.
397 source_dims_stack.push_back(s_size / t_size);
398 sharding_tile_dims_stack.push_back(s_partitions / t_size);
399 } else {
400 return absl::nullopt;
401 }
402 } else {
403 // Dimension merge. Also merge the source dimension with the next, and
404 // process it next time.
405 if (s_size % s_partitions != 0) {
406 return absl::nullopt;
407 }
408 CHECK(!source_dims_stack.empty());
409 if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
410 // If the next dimension to combine is sharded, we require that the
411 // current dimension's shard size to be 1. Otherwise, the new shard
412 // would be non-contiguous.
413 return absl::nullopt;
414 }
415 source_dims_stack.back() *= s_size;
416 sharding_tile_dims_stack.back() *= s_partitions;
417 target_dims_stack.push_back(t_size);
418 }
419 }
420 Array<int64> new_tile_assignment = sharding.tile_assignment();
421 if (sharding.ReplicateOnLastTileDim()) {
422 target_tile_assignment_dimensions.push_back(
423 sharding.tile_assignment().dimensions().back());
424 }
425 new_tile_assignment.Reshape(target_tile_assignment_dimensions);
426 return sharding.ReplicateOnLastTileDim()
427 ? HloSharding::PartialTile(new_tile_assignment,
428 sharding.metadata())
429 : HloSharding::Tile(new_tile_assignment, sharding.metadata());
430 }
431
432 HloSharding ReverseSharding(const HloSharding& sharding,
433 absl::Span<const int64> dimensions) {
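// Illustrative example: reversing {devices=[2,1]0,1} along dimension 0 yields
// {devices=[2,1]1,0}, i.e. the tile assignment is mirrored along the reversed
// dimensions.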
434 if (sharding.IsTileMaximal() || dimensions.empty()) {
435 return sharding;
436 }
437
438 Array<int64> new_tile_assignment(sharding.tile_assignment().dimensions());
439 new_tile_assignment.Each([&](absl::Span<const int64> indices, int64* device) {
440 std::vector<int64> original_indices(indices.begin(), indices.end());
441 for (int64_t d : dimensions) {
442 original_indices[d] =
443 new_tile_assignment.dim(d) - 1 - original_indices[d];
444 }
445 *device = sharding.tile_assignment()(original_indices);
446 });
447 return sharding.ReplicateOnLastTileDim()
448 ? HloSharding::PartialTile(new_tile_assignment,
449 sharding.metadata())
450 : HloSharding::Tile(new_tile_assignment, sharding.metadata());
451 }
452
453 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64_t dim,
454 absl::Span<const int64> dims) {
455 CHECK(!sharding.IsTuple() && !sharding.IsTileMaximal());
456 CHECK_NE(absl::c_find(dims, dim), dims.end()) << "dim is not in dims";
457 // We optimize the tile assignment on the single dimension dim in a way to
458 // minimize communication among devices caused by the reshard:
459 // +---+---+ +---+---+ +-+-+-+-+
460 // | | | | 0 | | | | | |
461 // | 0 | 1 | +-------+ | | | | |
462 // | | | reshape on | 1 | reshape on | | | | |
463 // +---+---+ dim 0 => +-------+ dim 1 => |0|2|1|3|
464 // | | | | 2 | | | | | |
465 // | 2 | 3 | +-------+ | | | | |
466 // | | | | 3 | | | | | |
467 // +---+---+ +---+---+ +-+-+-+-+
468
469 std::vector<int64> tile_dims(sharding.tile_assignment().num_dimensions(), 1);
470 // Handle ignore dimensions.
471 std::vector<int64> ignore_sizes;
472 int64_t ignore_size = 1;
473 for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
474 if (absl::c_find(dims, i) == dims.end()) {
475 int64_t size = sharding.tile_assignment().dim(i);
476 ignore_sizes.push_back(size);
477 tile_dims[i] = size;
478 ignore_size *= size;
479 }
480 }
481
482 using Buckets = std::vector<std::vector<int64>>;
483 Array<Buckets> buckets(ignore_sizes,
484 Buckets(sharding.tile_assignment().dim(dim)));
485 sharding.tile_assignment().Each(
486 [&](absl::Span<const int64> index, int64_t device) {
487 std::vector<int64> ignore_index;
488 for (int64_t i = 0; i < index.size(); ++i) {
489 if (absl::c_find(dims, i) == dims.end()) {
490 ignore_index.push_back(index[i]);
491 }
492 }
493 buckets(ignore_index)[index[dim]].push_back(device);
494 });
495 std::vector<int64> devices;
496 buckets.Each([&](absl::Span<const int64> index, const Buckets& buckets) {
497 for (auto& bucket : buckets) {
498 devices.insert(devices.end(), bucket.begin(), bucket.end());
499 }
500 });
501 tile_dims[dim] = devices.size() / ignore_size;
502 Array<int64> tile_assignment(tile_dims);
503 tile_assignment.SetValues(devices);
504 return HloSharding::Tile(tile_assignment, sharding.metadata());
505 }
506
507 bool ContainsTileSharding(const HloModule& module) {
508 for (const HloComputation* computation : module.computations()) {
509 for (const HloInstruction* instruction : computation->instructions()) {
510 if (instruction->has_sharding() &&
511 !instruction->sharding().IsTileMaximal()) {
512 return true;
513 }
514 }
515 }
516 return false;
517 }
518
519 HloSharding GatherOutputSharding(const HloSharding& index_sharding,
520 const HloInstruction* hlo) {
521 if (index_sharding.IsTileMaximal()) {
522 return index_sharding;
523 }
524
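// Illustrative example: with indices of shape s32[4,8,2], index_vector_dim=2
// and offset_dims={2}, an index sharding of {devices=[2,2,1]0,1,2,3} maps to
// the output sharding {devices=[2,2,1]0,1,2,3}; offset dimensions stay
// unsharded.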
525 const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
526 std::vector<int64> output_tile_assignment_dims;
527 for (int64_t i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
528 if (absl::c_binary_search(dnums.offset_dims(), i)) {
529 output_tile_assignment_dims.push_back(1);
530 } else {
531 const int64_t new_tile_dimension =
532 index_dim >= dnums.index_vector_dim() ? index_dim + 1 : index_dim;
533 output_tile_assignment_dims.push_back(
534 index_sharding.tile_assignment().dim(new_tile_dimension));
535 ++index_dim;
536 }
537 }
538
539 if (index_sharding.ReplicateOnLastTileDim()) {
540 output_tile_assignment_dims.push_back(
541 index_sharding.tile_assignment().dimensions().back());
542 }
543
544 Array<int64> new_tile_assignment = index_sharding.tile_assignment();
545 if (new_tile_assignment.num_elements() !=
546 Product(output_tile_assignment_dims)) {
547 return HloSharding::Replicate(index_sharding.metadata());
548 }
549 new_tile_assignment.Reshape(output_tile_assignment_dims);
550 return index_sharding.ReplicateOnLastTileDim()
551 ? HloSharding::PartialTile(new_tile_assignment,
552 index_sharding.metadata())
553 : HloSharding::Tile(new_tile_assignment,
554 index_sharding.metadata());
555 }
556
557 HloSharding GatherIndexSharding(const HloSharding& output_sharding,
558 const HloInstruction* hlo) {
559 CHECK(hlo->opcode() == HloOpcode::kGather);
560 if (output_sharding.IsTileMaximal()) {
561 return output_sharding;
562 }
563
564 const GatherDimensionNumbers& dnums = hlo->gather_dimension_numbers();
565 std::vector<int64> index_tile_assignment_dims;
566 // Relevant output dims have shardings passed to the index.
567 std::vector<int64> relevant_output_dims;
568 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
569 if (!absl::c_binary_search(dnums.offset_dims(), i)) {
570 index_tile_assignment_dims.push_back(
571 output_sharding.tile_assignment().dim(i));
572 relevant_output_dims.push_back(i);
573 }
574 }
575 int64_t index_rank = hlo->operand(1)->shape().rank();
576
577 // Vector indices sharding is not supported yet.
578 if (index_rank > index_tile_assignment_dims.size()) {
579 index_tile_assignment_dims.insert(
580 index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1);
581 }
582
583 if (Product(index_tile_assignment_dims) == 1) {
584 return HloSharding::Replicate(output_sharding.metadata());
585 }
586 HloSharding relevant_output_sharding =
587 PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
588 relevant_output_dims);
589 int64_t partial_replication_size = 1;
590 if (relevant_output_sharding.ReplicateOnLastTileDim()) {
591 partial_replication_size *=
592 relevant_output_sharding.tile_assignment().dimensions().back();
593 }
594
595 Array<int64> new_tile_assignment = relevant_output_sharding.tile_assignment();
596 const int64_t index_tile_elements =
597 Product(index_tile_assignment_dims) * partial_replication_size;
598 if (new_tile_assignment.num_elements() != index_tile_elements) {
599 if (new_tile_assignment.num_elements() % index_tile_elements == 0) {
600 partial_replication_size *=
601 (new_tile_assignment.num_elements() / index_tile_elements);
602 } else {
603 return HloSharding::Replicate(output_sharding.metadata());
604 }
605 }
606 if (partial_replication_size > 1) {
607 index_tile_assignment_dims.push_back(partial_replication_size);
608 }
609 new_tile_assignment.Reshape(index_tile_assignment_dims);
610 return partial_replication_size > 1
611 ? HloSharding::PartialTile(new_tile_assignment,
612 output_sharding.metadata())
613 : HloSharding::Tile(new_tile_assignment,
614 output_sharding.metadata());
615 }
616
617 HloSharding GatherEffectiveOutputSharding(const HloInstruction& hlo) {
618 if (hlo.sharding().IsTileMaximal()) {
619 return hlo.sharding();
620 }
621
622 const GatherDimensionNumbers& dnums = hlo.gather_dimension_numbers();
623 std::vector<int64> tile_assignment_dims(hlo.shape().rank());
624 int64_t num_elements = 1;
625 for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
626 if (!absl::c_binary_search(dnums.offset_dims(), i)) {
627 tile_assignment_dims[i] = hlo.sharding().tile_assignment().dim(i);
628 num_elements *= hlo.sharding().tile_assignment().dim(i);
629 } else {
630 tile_assignment_dims[i] = 1;
631 }
632 }
633 if (num_elements == hlo.sharding().tile_assignment().num_elements()) {
634 // Output sharding is only on non offset dimensions. We use output sharding
635 // to shard this gather op directly.
636 return hlo.sharding();
637 }
638
639 if (num_elements == 1) {
640 // Output sharding is only on offset dimensions. We do not shard this gather
641 // op. Return a tile maximal sharding with the first device in output
642 // sharding tile assignment.
643 return HloSharding::AssignDevice(*hlo.sharding().tile_assignment().begin(),
644 hlo.sharding().metadata());
645 }
646
647 // Output sharding is on both offset and non offset dimensions. We shard the
648 // gather op only on non offset dimensions.
649 // For example:
650 // - the gather op has sharding [2,2]{0,1,2,3},
651 // - first dimension is non offset dimension,
652 // - second dimension is offset dimension,
653 // Then the result sharding will be [2,1]{0,2}.
654 std::vector<int64> slice_starts(hlo.shape().rank(), 0LL),
655 slice_limits(hlo.shape().rank());
656 for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
657 if (!absl::c_binary_search(dnums.offset_dims(), i)) {
658 slice_limits[i] = hlo.sharding().tile_assignment().dim(i);
659 } else {
660 slice_limits[i] = 1;
661 }
662 }
663 Array<int64> tile_assignment =
664 hlo.sharding().tile_assignment().Slice(slice_starts, slice_limits);
665 return HloSharding::Tile(tile_assignment, hlo.sharding().metadata());
666 }
667
668 HloSharding ScatterIndexSharding(const HloSharding& data_sharding,
669 const HloInstruction* hlo) {
670 if (data_sharding.IsTileMaximal()) {
671 return data_sharding;
672 }
673
674 const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
675 std::vector<int64> index_tile_assignment_dims;
676 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
677 if (!absl::c_binary_search(dnums.update_window_dims(), i)) {
678 index_tile_assignment_dims.push_back(
679 data_sharding.tile_assignment().dim(i));
680 }
681 }
682 if (index_tile_assignment_dims.size() < hlo->operand(1)->shape().rank()) {
683 index_tile_assignment_dims.push_back(1);
684 }
685 if (data_sharding.ReplicateOnLastTileDim()) {
686 index_tile_assignment_dims.push_back(
687 data_sharding.tile_assignment().dimensions().back());
688 }
689 Array<int64> new_tile_assignment = data_sharding.tile_assignment();
690 if (new_tile_assignment.num_elements() !=
691 Product(index_tile_assignment_dims)) {
692 return HloSharding::Replicate(data_sharding.metadata());
693 }
694 new_tile_assignment.Reshape(index_tile_assignment_dims);
695 return data_sharding.ReplicateOnLastTileDim()
696 ? HloSharding::PartialTile(new_tile_assignment,
697 data_sharding.metadata())
698 : HloSharding::Tile(new_tile_assignment, data_sharding.metadata());
699 }
700
701 HloSharding ScatterDataSharding(const HloSharding& index_sharding,
702 const HloInstruction* hlo) {
703 if (index_sharding.IsTileMaximal()) {
704 return index_sharding;
705 }
706
707 const ScatterDimensionNumbers& dnums = hlo->scatter_dimension_numbers();
708 std::vector<int64> data_tile_assignment_dims;
709 for (int64_t i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) {
710 if (absl::c_binary_search(dnums.update_window_dims(), i)) {
711 data_tile_assignment_dims.push_back(1);
712 } else {
713 data_tile_assignment_dims.push_back(
714 index_sharding.tile_assignment().dim(index_dim));
715 index_dim++;
716 }
717 }
718 if (index_sharding.ReplicateOnLastTileDim()) {
719 data_tile_assignment_dims.push_back(
720 index_sharding.tile_assignment().dimensions().back());
721 }
722 Array<int64> new_tile_assignment = index_sharding.tile_assignment();
723 if (new_tile_assignment.num_elements() !=
724 Product(data_tile_assignment_dims)) {
725 return HloSharding::Replicate(index_sharding.metadata());
726 }
727 new_tile_assignment.Reshape(data_tile_assignment_dims);
728 return index_sharding.ReplicateOnLastTileDim()
729 ? HloSharding::PartialTile(new_tile_assignment,
730 index_sharding.metadata())
731 : HloSharding::Tile(new_tile_assignment,
732 index_sharding.metadata());
733 }
734
735 HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding,
736 const HloInstruction& hlo) {
737 if (index_sharding.IsTileMaximal()) {
738 return index_sharding;
739 }
740
741 // Only shard on first "number of scatter_window_dims" dimensions.
742 const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
743 int64_t num_elements = 1;
744 int64_t index_dim = 0;
745 for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
746 if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
747 num_elements *= index_sharding.tile_assignment().dim(index_dim);
748 index_dim++;
749 }
750 }
751 if (num_elements == index_sharding.tile_assignment().num_elements()) {
752 // Index sharding is only on scatter_window_dims. We use this index sharding
753 // directly.
754 return index_sharding;
755 }
756
757 // Index sharding is only on update_window_dims. We do not shard this scatter
758 // op. Return a tile maximal sharding with the first device in index sharding
759 // tile assignment.
760 if (num_elements == 1) {
761 return HloSharding::AssignDevice(*index_sharding.tile_assignment().begin(),
762 index_sharding.metadata());
763 }
764
765 const int64_t index_rank = hlo.operand(1)->shape().rank();
766 std::vector<int64> slice_starts(index_rank, 0LL), slice_limits(index_rank);
767 for (int64_t i = 0; i < index_rank; ++i) {
768 if (i < index_dim) {
769 slice_limits[i] = index_sharding.tile_assignment().dim(i);
770 } else {
771 slice_limits[i] = 1;
772 }
773 }
774 Array<int64> tile_assignment =
775 index_sharding.tile_assignment().Slice(slice_starts, slice_limits);
776 return HloSharding::Tile(tile_assignment, index_sharding.metadata());
777 }
778
779 HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding,
780 const HloInstruction& hlo) {
781 if (data_sharding.IsTileMaximal()) {
782 return data_sharding;
783 }
784
785 const ScatterDimensionNumbers& dnums = hlo.scatter_dimension_numbers();
786 const int64_t data_rank = hlo.operand(2)->shape().rank();
787 std::vector<int64> tile_assignment_dims(data_rank, 1LL);
788 int64_t num_elements = 1;
789 for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
790 if (absl::c_binary_search(dnums.inserted_window_dims(), i)) {
791 CHECK_LT(i, data_rank);
792 tile_assignment_dims[i] = data_sharding.tile_assignment().dim(i);
793 num_elements *= data_sharding.tile_assignment().dim(i);
794 }
795 }
796 if (num_elements == data_sharding.tile_assignment().num_elements()) {
797 // Data sharding is only on scatter_window_dims. We use this data sharding
798 // directly.
799 return data_sharding;
800 }
801
802 if (num_elements == 1) {
803 // Data sharding is only on update_window_dims. We do not shard this
804 // scatter op. Return a tile maximal sharding with the first device in
805 // data sharding tile assignment.
806 return HloSharding::AssignDevice(*data_sharding.tile_assignment().begin(),
807 data_sharding.metadata());
808 }
809
810 // Data sharding is on both update_window_dims and scatter_window_dims. We
811 // shard the scatter op only on scatter_window_dims. For example:
812 // - the scatter data has sharding [2,2]{0,1,2,3},
813 // - first dimension is scatter_window_dims,
814 // - second dimension is update_window_dims,
815 // Then the result sharding will be [2,1]{0,2}.
816 std::vector<int64> slice_starts(data_rank, 0LL);
817 Array<int64> tile_assignment =
818 data_sharding.tile_assignment().Slice(slice_starts, tile_assignment_dims);
819 return HloSharding::Tile(tile_assignment, data_sharding.metadata());
820 }
821
822 namespace {
823
824 // If partitioning in the operand only happens in dimensions in passthrough
825 // dimensions (offset dimensions in the gather output (or scatter update) that
826 // have the same size as the operand), returns the corresponding output (or
827 // update) sharding by passing through the input sharding.
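// Illustrative example: for a gather over an f32[8,16] operand sharded
// {devices=[1,2]0,1}, with collapsed_slice_dims={0}, start_index_map={0},
// offset_dims={1} and slice sizes {1,16}, the sharding of the second operand
// dimension passes through to the offset dimension of the output, giving
// {devices=[1,2]0,1} on an output of shape f32[N,16].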
828 absl::optional<HloSharding> PassthroughOperandToGatherOutputOrScatterUpdate(
829 const Shape& operand_shape, const HloSharding& operand_sharding,
830 const Shape& update_or_gather_shape,
831 absl::Span<const int64> collapsed_or_inserted_dims,
832 absl::Span<const int64> index_map,
833 absl::Span<const int64> offset_or_window_dims,
834 absl::Span<const int64> slice_size) {
835 if (operand_sharding.IsTileMaximal()) {
836 return operand_sharding;
837 }
838 std::vector<int64> passthrough_tile(update_or_gather_shape.rank(), 1);
839 int64_t collapsed = 0;
840 for (int64_t i = 0; i < operand_shape.rank(); ++i) {
841 int64_t dim_partitions = operand_sharding.tile_assignment().dim(i);
842 if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
843 absl::c_linear_search(index_map, i)) {
844 if (dim_partitions > 1) {
845 return absl::nullopt;
846 }
847 collapsed++;
848 continue;
849 }
850 if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
851 return absl::nullopt;
852 }
853 int64_t offset_dim = offset_or_window_dims[i - collapsed];
854 if (i - collapsed > 0 &&
855 offset_dim < offset_or_window_dims[i - collapsed - 1]) {
856 // Output offsets are transposed; we do not support this case.
857 return absl::nullopt;
858 }
859 passthrough_tile[offset_dim] = dim_partitions;
860 }
861 if (operand_sharding.ReplicateOnLastTileDim()) {
862 passthrough_tile.push_back(
863 operand_sharding.tile_assignment().dimensions().back());
864 }
865 Array<int64> tile_assignment = operand_sharding.tile_assignment();
866 tile_assignment.Reshape(passthrough_tile);
867 return operand_sharding.ReplicateOnLastTileDim()
868 ? HloSharding::PartialTile(tile_assignment,
869 operand_sharding.metadata())
870 : HloSharding::Tile(tile_assignment, operand_sharding.metadata());
871 }
872
873 // Inverse of PassthroughOperandToGatherOutputOrScatterUpdate.
874 absl::optional<HloSharding> PassthroughGatherOutputOrScatterUpdateToOperand(
875 const Shape& operand_shape, const HloSharding& update_or_gather_sharding,
876 absl::Span<const int64> collapsed_or_inserted_dims,
877 absl::Span<const int64> index_map,
878 absl::Span<const int64> offset_or_window_dims,
879 absl::Span<const int64> slice_size) {
880 if (update_or_gather_sharding.IsTileMaximal()) {
881 return update_or_gather_sharding;
882 }
883 std::vector<int64> passthrough_tile(operand_shape.rank(), 1);
884 int64_t collapsed = 0;
885 // Relevant dims have shardings passed to the operand.
886 std::vector<int64> relevant_update_or_gather_dims;
887 for (int64_t i = 0; i < operand_shape.rank(); ++i) {
888 if (absl::c_linear_search(collapsed_or_inserted_dims, i) ||
889 absl::c_linear_search(index_map, i)) {
890 collapsed++;
891 continue;
892 }
893 int64_t offset_dim = offset_or_window_dims[i - collapsed];
894 int64_t dim_partitions =
895 update_or_gather_sharding.tile_assignment().dim(offset_dim);
896 if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) {
897 return absl::nullopt;
898 }
899 if (i - collapsed > 0 &&
900 offset_dim < offset_or_window_dims[i - collapsed - 1]) {
901 // Output offsets are transposed; we do not support this case.
902 return absl::nullopt;
903 }
904 relevant_update_or_gather_dims.push_back(offset_dim);
905 passthrough_tile[i] = dim_partitions;
906 }
907
908 Array<int64> tile_assignment =
909 PartiallyReplicateTiledShardingOnAllDimsExcept(
910 update_or_gather_sharding, relevant_update_or_gather_dims)
911 .tile_assignment();
912 bool is_partial = false;
913 if (tile_assignment.num_elements() != Product(passthrough_tile)) {
914 passthrough_tile.push_back(tile_assignment.num_elements() /
915 Product(passthrough_tile));
916 is_partial = true;
917 }
918 tile_assignment.Reshape(passthrough_tile);
919 return is_partial ? HloSharding::PartialTile(
920 tile_assignment, update_or_gather_sharding.metadata())
921 : HloSharding::Tile(tile_assignment,
922 update_or_gather_sharding.metadata());
923 }
924
925 // Collect data operand sharding for a gather with parallel dimensions from
926 // the output.
927 absl::optional<HloSharding> GatherParallelDataOperandSharding(
928 const HloSharding& output_sharding, const HloInstruction& gather,
929 const GatherParallelDims& parallel_dims) {
930 if (output_sharding.IsTileMaximal()) {
931 return output_sharding;
932 }
933 auto output_parallel_dims = GatherParallelOutputDims(gather, parallel_dims);
934 auto output_aligned_operand_parallel_dims =
935 GatherOutputAlignedOperandParallelDims(gather, parallel_dims);
936 const Shape gather_shape = gather.shape();
937 CHECK_EQ(output_parallel_dims.size(),
938 output_aligned_operand_parallel_dims.size());
939 std::vector<int64> operand_tile_assignment(gather.operand(0)->shape().rank(),
940 1);
941 std::vector<int64> relevant_output_dims;
942 for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
943 if (parallel_idx >= output_parallel_dims.size() ||
944 output_parallel_dims[parallel_idx] != i) {
945 continue;
946 }
947 const int64_t operand_dim =
948 output_aligned_operand_parallel_dims[parallel_idx++];
949 operand_tile_assignment[operand_dim] =
950 output_sharding.tile_assignment().dim(i);
951 relevant_output_dims.push_back(i);
952 }
953 HloSharding relevant_output_sharding =
954 PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
955 relevant_output_dims);
956 int64_t partially_replicated_size = 1;
957 if (relevant_output_sharding.ReplicateOnLastTileDim()) {
958 partially_replicated_size *=
959 relevant_output_sharding.tile_assignment().dimensions().back();
960 }
961
962 if (relevant_output_sharding.IsTileMaximal()) {
963 partially_replicated_size *=
964 Product(output_sharding.tile_assignment().dimensions());
965 }
966
967 Array<int64> tile_assignment = relevant_output_sharding.tile_assignment();
968 if (relevant_output_sharding.IsTileMaximal()) {
969 Array<int64> replicated_tile_assignment({gather_shape.rank()}, 1L);
970 tile_assignment = replicated_tile_assignment;
971 }
972 const int64_t operand_tile_elements =
973 Product(operand_tile_assignment) * partially_replicated_size;
974 if (tile_assignment.num_elements() != operand_tile_elements) {
975 if (tile_assignment.num_elements() % operand_tile_elements == 0) {
976 partially_replicated_size *=
977 (tile_assignment.num_elements() / operand_tile_elements);
978 } else {
979 return absl::nullopt;
980 }
981 }
982 if (partially_replicated_size > 1) {
983 operand_tile_assignment.push_back(partially_replicated_size);
984 }
985 tile_assignment.Reshape(operand_tile_assignment);
986 return partially_replicated_size > 1
987 ? HloSharding::PartialTile(tile_assignment,
988 output_sharding.metadata())
989 : HloSharding::Tile(tile_assignment, output_sharding.metadata());
990 }
991
992 } // namespace
993
994 absl::optional<HloSharding> GatherOutputShardingFromDataOperand(
995 const HloSharding& data_operand_sharding, const HloInstruction& hlo,
996 const Shape& output_shape, const Shape& operand_shape) {
997 const auto& dnums = hlo.gather_dimension_numbers();
998 std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
999 dnums.collapsed_slice_dims().end());
1000 std::vector<int64> start_index_map(dnums.start_index_map().begin(),
1001 dnums.start_index_map().end());
1002 std::vector<int64> offset_dims(dnums.offset_dims().begin(),
1003 dnums.offset_dims().end());
1004 return PassthroughOperandToGatherOutputOrScatterUpdate(
1005 operand_shape, data_operand_sharding, output_shape, collapsed_slice_dims,
1006 start_index_map, offset_dims, hlo.gather_slice_sizes());
1007 }
1008
1009 absl::optional<HloSharding> GatherDataOperandShardingFromOutput(
1010 const HloSharding& output_sharding, const HloInstruction& hlo) {
1011 const auto& dnums = hlo.gather_dimension_numbers();
1012 std::vector<int64> collapsed_slice_dims(dnums.collapsed_slice_dims().begin(),
1013 dnums.collapsed_slice_dims().end());
1014 std::vector<int64> start_index_map(dnums.start_index_map().begin(),
1015 dnums.start_index_map().end());
1016 std::vector<int64> offset_dims(dnums.offset_dims().begin(),
1017 dnums.offset_dims().end());
1018
1019 absl::optional<HloSharding> parallel_sharding;
1020 auto parallel_dims = GetGatherBatchParallelDims(hlo);
1021 if (parallel_dims) {
1022 // Prioritize the parallel sharding, as this matches what the SPMD
1023 // partitioner does.
1024 parallel_sharding =
1025 GatherParallelDataOperandSharding(output_sharding, hlo, *parallel_dims);
1026 }
1027 absl::optional<HloSharding> passthrough_sharding =
1028 PassthroughGatherOutputOrScatterUpdateToOperand(
1029 hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims,
1030 start_index_map, offset_dims, hlo.gather_slice_sizes());
1031 // Try to merge the two shardings or return the one that is present if only
1032 // one of the two is.
1033 if (!passthrough_sharding) {
1034 return parallel_sharding;
1035 }
1036 if (!parallel_sharding) {
1037 return passthrough_sharding;
1038 }
1039 if (MergeSharding(*parallel_sharding, &*passthrough_sharding,
1040 /*may_combine_partial_sharding=*/true)) {
1041 return passthrough_sharding;
1042 }
1043 if (MergeSharding(*passthrough_sharding, &*parallel_sharding,
1044 /*may_combine_partial_sharding=*/true)) {
1045 return parallel_sharding;
1046 }
1047 return absl::nullopt;
1048 }
1049
1050 absl::optional<HloSharding> ScatterOutputShardingFromUpdate(
1051 const HloSharding& update_sharding, const HloInstruction& hlo) {
1052 const auto& dnums = hlo.scatter_dimension_numbers();
1053 std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
1054 dnums.inserted_window_dims().end());
1055 std::vector<int64> scatter_dims_to_operand_dims(
1056 dnums.scatter_dims_to_operand_dims().begin(),
1057 dnums.scatter_dims_to_operand_dims().end());
1058 std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
1059 dnums.update_window_dims().end());
1060 std::vector<int64> slice_size(hlo.shape().rank(), 1);
1061 int64_t num_update_window_dims = 0;
1062 for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
1063 if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
1064 continue;
1065 }
1066 slice_size[i] = hlo.operand(2)->shape().dimensions(
1067 dnums.update_window_dims(num_update_window_dims++));
1068 }
1069 return PassthroughGatherOutputOrScatterUpdateToOperand(
1070 hlo.shape(), update_sharding, inserted_window_dims,
1071 scatter_dims_to_operand_dims, update_window_dims, slice_size);
1072 }
1073
1074 absl::optional<HloSharding> ScatterUpdateShardingFromOutput(
1075 const HloSharding& output_sharding, const HloInstruction& hlo) {
1076 const auto& dnums = hlo.scatter_dimension_numbers();
1077 std::vector<int64> inserted_window_dims(dnums.inserted_window_dims().begin(),
1078 dnums.inserted_window_dims().end());
1079 std::vector<int64> scatter_dims_to_operand_dims(
1080 dnums.scatter_dims_to_operand_dims().begin(),
1081 dnums.scatter_dims_to_operand_dims().end());
1082 std::vector<int64> update_window_dims(dnums.update_window_dims().begin(),
1083 dnums.update_window_dims().end());
1084 std::vector<int64> slice_size(hlo.shape().rank(), 1);
1085 int64_t num_update_window_dims = 0;
1086 for (int64_t i = 0; i < hlo.shape().rank(); ++i) {
1087 if (absl::c_linear_search(dnums.inserted_window_dims(), i)) {
1088 continue;
1089 }
1090 slice_size[i] = hlo.operand(2)->shape().dimensions(
1091 dnums.update_window_dims(num_update_window_dims++));
1092 }
1093 return PassthroughOperandToGatherOutputOrScatterUpdate(
1094 hlo.shape(), output_sharding, hlo.operand(2)->shape(),
1095 inserted_window_dims, scatter_dims_to_operand_dims, update_window_dims,
1096 slice_size);
1097 }
1098
1099 StatusOr<std::pair<std::unique_ptr<HloInstruction>, HloOpcode>>
1100 IdentityValueAndHloOpcodeForScatterReduceComputation(
1101 const HloScatterInstruction& scatter) {
1102 auto computation = scatter.to_apply();
1103 // We only handle computations with 2 parameters and only 1 calculation.
1104 if (computation->instruction_count() != 3) {
1105 return Status(
1106 tensorflow::error::Code::INVALID_ARGUMENT,
1107 "Expected scatter reduce computation with 2 parameters and only 1 "
1108 "calculation");
1109 }
1110
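// For example, an add (or logical-or) combiner maps to an identity of zero, a
// multiply (or logical-and) combiner maps to an identity of one, and max/min
// combiners map to the smallest/largest representable value, as handled below.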
1111 auto root_instruction = computation->root_instruction();
1112 if (root_instruction->opcode() == HloOpcode::kAdd ||
1113 root_instruction->opcode() == HloOpcode::kOr) {
1114 return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::Zero(
1115 scatter.shape().element_type())),
1116 root_instruction->opcode());
1117 } else if (root_instruction->opcode() == HloOpcode::kMultiply ||
1118 root_instruction->opcode() == HloOpcode::kAnd) {
1119 return std::make_pair(HloInstruction::CreateConstant(
1120 LiteralUtil::One(scatter.shape().element_type())),
1121 root_instruction->opcode());
1122 } else if (root_instruction->opcode() == HloOpcode::kMaximum) {
1123 return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MinValue(
1124 scatter.shape().element_type())),
1125 root_instruction->opcode());
1126 } else if (root_instruction->opcode() == HloOpcode::kMinimum) {
1127 return std::make_pair(HloInstruction::CreateConstant(LiteralUtil::MaxValue(
1128 scatter.shape().element_type())),
1129 root_instruction->opcode());
1130 }
1131
1132 return Status(tensorflow::error::Code::INVALID_ARGUMENT,
1133 "Expected scatter reduce computation which is "
1134 "add/or/multiply/add/min/max");
1135 }
1136
1137 namespace {
1138
1139 void DevicesForShardingInternal(
1140 const HloSharding& sharding,
1141 const absl::flat_hash_set<int64>& available_devices,
1142 absl::flat_hash_set<int64>* used) {
1143 if (sharding.IsTuple()) {
1144 for (const auto& subsharding : sharding.tuple_elements()) {
1145 DevicesForShardingInternal(subsharding, available_devices, used);
1146 }
1147 return;
1148 }
1149
1150 if (sharding.IsReplicated()) {
1151 for (int64_t device : available_devices) {
1152 if (!HloSharding::IsReservedDevice(device)) {
1153 used->insert(device);
1154 }
1155 }
1156 return;
1157 }
1158
1159 DCHECK(std::all_of(
1160 sharding.tile_assignment().begin(), sharding.tile_assignment().end(),
1161 [&](int64_t device) { return available_devices.contains(device); }));
1162 sharding.tile_assignment().Each(
1163 [&](absl::Span<const int64> /*indices*/, int64_t device) {
1164 used->insert(device);
1165 });
1166 }
1167
1168 } // namespace
1169
1170 std::vector<int64> DevicesForSharding(
1171 const HloSharding& sharding, const std::vector<int64>& available_devices) {
1172 absl::flat_hash_set<int64> available_set;
1173 for (int64_t device : available_devices) {
1174 available_set.insert(device);
1175 }
1176 absl::flat_hash_set<int64> used_set;
1177 DevicesForShardingInternal(sharding, available_set, &used_set);
1178 std::vector<int64> devices;
1179 for (int64_t device : available_devices) {
1180 if (used_set.contains(device)) {
1181 devices.push_back(device);
1182 }
1183 }
1184 return devices;
1185 }
1186
1187 HloSharding PartiallyReplicateTiledShardingOnDims(
1188 const HloSharding& sharding, absl::Span<const int64> dims_to_replicate) {
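// Illustrative example: partially replicating {devices=[2,2]0,1,2,3} on
// dimension 1 produces {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}.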
1189 if (sharding.IsTileMaximal()) {
1190 return sharding;
1191 }
1192 int64_t group_count = 1;
1193 for (int64_t dim : dims_to_replicate) {
1194 if (sharding.ReplicateOnLastTileDim()) {
1195 CHECK_LT(dim, sharding.tile_assignment().num_dimensions());
1196 }
1197 group_count *= sharding.tile_assignment().dim(dim);
1198 }
1199 if (group_count == 1) {
1200 return sharding;
1201 }
1202 if (group_count == sharding.NumTiles()) {
1203 return HloSharding::Replicate(sharding.metadata());
1204 }
1205 std::vector<int64> dim_permutation(
1206 sharding.tile_assignment().num_dimensions());
1207 std::iota(dim_permutation.begin(), dim_permutation.end(), 0);
1208 absl::c_stable_sort(dim_permutation, [&](const int64_t a, const int64_t b) {
1209 return absl::c_linear_search(dims_to_replicate, a) <
1210 absl::c_linear_search(dims_to_replicate, b);
1211 });
1212 auto transposed = TransposeSharding(sharding, dim_permutation);
1213 auto new_tile = transposed.tile_assignment();
1214 std::vector<int64> new_tile_shape(
1215 sharding.tile_assignment().dimensions().begin(),
1216 sharding.tile_assignment().dimensions().end());
1217 for (int64_t dim : dims_to_replicate) {
1218 new_tile_shape[dim] = 1;
1219 }
1220 if (sharding.ReplicateOnLastTileDim()) {
1221 new_tile_shape.back() *= group_count;
1222 } else {
1223 new_tile_shape.push_back(group_count);
1224 }
1225 new_tile.Reshape(new_tile_shape);
1226 return HloSharding::PartialTile(new_tile, sharding.metadata());
1227 }
1228
1229 HloSharding PartiallyReplicateTiledShardingOnAllDimsExcept(
1230 const HloSharding& sharding, absl::Span<const int64> dims_to_keep) {
1231 if (sharding.IsTileMaximal()) {
1232 return sharding;
1233 }
1234 int64_t data_rank = sharding.tile_assignment().num_dimensions();
1235 if (sharding.ReplicateOnLastTileDim()) {
1236 data_rank -= 1;
1237 }
1238 std::vector<int64> dims_to_replicate(data_rank);
1239 absl::c_iota(dims_to_replicate, 0);
1240
1241 dims_to_replicate.erase(
1242 std::remove_if(
1243 dims_to_replicate.begin(), dims_to_replicate.end(),
1244 [&](int64_t i) { return absl::c_linear_search(dims_to_keep, i); }),
1245 dims_to_replicate.end());
1246 return PartiallyReplicateTiledShardingOnDims(sharding, dims_to_replicate);
1247 }
1248
1249 HloSharding RemoveShapeDimensions(const HloSharding& sharding,
1250 const std::vector<int64>& dims_to_remove) {
1251 if (sharding.IsTileMaximal() || dims_to_remove.empty()) {
1252 return sharding;
1253 }
1254 std::vector<int64> new_tile_shape;
1255 new_tile_shape.reserve(sharding.tile_assignment().num_dimensions() -
1256 dims_to_remove.size());
1257 for (int64_t i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
1258 if (absl::c_linear_search(dims_to_remove, i)) {
1259 CHECK_EQ(sharding.tile_assignment().dim(i), 1);
1260 } else {
1261 new_tile_shape.push_back(sharding.tile_assignment().dim(i));
1262 }
1263 }
1264 auto new_tile = sharding.tile_assignment();
1265 new_tile.Reshape(new_tile_shape);
1266 return sharding.ReplicateOnLastTileDim()
1267 ? HloSharding::PartialTile(new_tile, sharding.metadata())
1268 : HloSharding::Tile(new_tile, sharding.metadata());
1269 }
1270
1271 absl::optional<HloSharding> TransposeShardingWithCollapsedDims(
1272 const HloSharding& source, absl::Span<int64 const> src_to_tgt,
1273 absl::Span<int64 const> tgt_to_src) {
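// src_to_tgt maps each source dimension to its target dimension (-1 for a
// source dimension that is collapsed away), and tgt_to_src is the inverse
// mapping (-1 for a newly added target dimension). Collapsed source
// dimensions must not be sharded, otherwise there is no valid result.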
1274 if (source.IsTileMaximal()) {
1275 return source;
1276 }
1277 if (source.ReplicateOnLastTileDim() &&
1278 src_to_tgt.size() < source.tile_assignment().num_dimensions()) {
1279 std::vector<int64> new_src_to_tgt(src_to_tgt.begin(), src_to_tgt.end());
1280 new_src_to_tgt.push_back(tgt_to_src.size());
1281 std::vector<int64> new_tgt_to_src(tgt_to_src.begin(), tgt_to_src.end());
1282 new_tgt_to_src.push_back(src_to_tgt.size());
1283 return TransposeShardingWithCollapsedDims(source, new_src_to_tgt,
1284 new_tgt_to_src);
1285 }
1286 std::vector<int64> tgt_dims_skipping_new(tgt_to_src.size(), -1);
1287 int64_t skipped_tgt_dims = 0;
1288 for (int64_t i = 0; i < tgt_to_src.size(); ++i) {
1289 if (tgt_to_src[i] < 0) {
1290 skipped_tgt_dims++;
1291 } else {
1292 tgt_dims_skipping_new[i] = i - skipped_tgt_dims;
1293 }
1294 }
1295 int64_t skipped_src_dims = absl::c_count(src_to_tgt, -1);
1296 std::vector<int64> perm(src_to_tgt.size());
1297 for (int64_t i = 0; i < src_to_tgt.size(); ++i) {
1298 if (src_to_tgt[i] < 0) {
1299 if (source.tile_assignment().dim(i) > 1) {
1300 return absl::nullopt;
1301 }
1302 perm[src_to_tgt.size() - skipped_src_dims] = i;
1303 skipped_src_dims--;
1304 } else {
1305 perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i;
1306 }
1307 }
1308 auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm);
1309 auto reshape_tiles = tgt_sharding.tile_assignment();
1310 std::vector<int64> tgt_tiles(tgt_to_src.size(), 1);
1311 for (int64_t i = 0; i < tgt_tiles.size(); ++i) {
1312 if (tgt_to_src[i] >= 0) {
1313 tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]);
1314 }
1315 }
1316 reshape_tiles.Reshape(tgt_tiles);
1317 return source.ReplicateOnLastTileDim()
1318 ? HloSharding::PartialTile(reshape_tiles, source.metadata())
1319 : HloSharding::Tile(reshape_tiles, source.metadata());
1320 }
1321
1322 absl::optional<int64> GetDimensionForIota(const HloInstruction* maybe_iota) {
1323 if (auto* iota = DynCast<HloIotaInstruction>(maybe_iota)) {
1324 return iota->iota_dimension();
1325 }
1326
1327 if (maybe_iota->shape().element_type() != S32) {
1328 return absl::nullopt;
1329 }
1330 if (maybe_iota->IsConstant()) {
1331 std::vector<bool> is_iota_dim(maybe_iota->shape().rank(), true);
1332 maybe_iota->literal().EachCell<int32>(
1333 [&](absl::Span<const int64> indices, int32_t val) {
1334 for (int64_t i = 0; i < indices.size(); ++i) {
1335 if (val != indices[i]) {
1336 is_iota_dim[i] = false;
1337 }
1338 }
1339 });
1340 for (int64_t i = 0; i < is_iota_dim.size(); ++i) {
1341 if (is_iota_dim[i] && maybe_iota->shape().dimensions(i) > 1) {
1342 return i;
1343 }
1344 }
1345 return absl::nullopt;
1346 }
1347
1348 if (maybe_iota->opcode() == HloOpcode::kBroadcast) {
1349 auto operand_dim = GetDimensionForIota(maybe_iota->operand(0));
1350 if (operand_dim) {
1351 return maybe_iota->dimensions(*operand_dim);
1352 }
1353 return absl::nullopt;
1354 }
1355 return absl::nullopt;
1356 }
1357
1358 absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
1359 const HloInstruction& hlo) {
1360 const auto& dnums = hlo.gather_dimension_numbers();
1361 int64_t index_dim = dnums.index_vector_dim();
1362 // Try to identify whether a dimension in the indices is monotonically
1363 // increasing via an iota across a certain dimension. This would mean that the
1364 // access in the corresponding operand dimension indexed by this index is
1365 // parallelizable and that we can shard the operand (and the index/output)
1366 // across such dimension.
1367 // For example the pattern:
1368 // %iota.1 = iota()
1369 // %indices = concatenate(..., %iota.1, ...)
1370 // ... = gather(..., %indices)
1371 // is common for tf.reverse_sequence and would match this case.
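// Illustrative example: with indices = concatenate(s32[4,1] iota over dim 0,
// s32[4,1] other) and index_vector_dim = 1, index column 0 is produced by an
// iota over dimension 0, so (if the corresponding slice size is 1) dimension 0
// of the indices and dimension start_index_map(0) of the operand are treated
// as parallel.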
1372 absl::InlinedVector<const HloIotaInstruction*, 4> iotas;
1373 const HloInstruction* indices = hlo.operand(1);
1374 const int num_indices = dnums.start_index_map_size();
1375 std::vector<int64> index_parallel_in_dim(num_indices, -1);
1376 // Handle cases where we concatenate pieces of the indices one at a time.
1377 if (indices->opcode() == HloOpcode::kConcatenate &&
1378 indices->concatenate_dimension() == index_dim) {
1379 int concatenated_dims = 0;
1380 for (int i = 0; i < indices->operand_count(); ++i) {
1381 const HloInstruction* op = indices->operand(i);
1382 const int64_t num_indices_from_element =
1383 op->shape().dimensions_size() > index_dim
1384 ? op->shape().dimensions(index_dim)
1385 : 1;
1386 if (absl::optional<int64> maybe_iota_dim = GetDimensionForIota(op)) {
1387 if (*maybe_iota_dim != index_dim) {
1388 for (int j = 0; j < num_indices_from_element; ++j) {
1389 index_parallel_in_dim[concatenated_dims + j] = *maybe_iota_dim;
1390 }
1391 }
1392 }
1393 concatenated_dims += num_indices_from_element;
1394 }
1395 } else if (absl::optional<int64> maybe_iota_dim =
1396 GetDimensionForIota(indices)) {
1397 if (*maybe_iota_dim != index_dim) {
1398 // This is a case of a single iota with index_dim being out of bounds.
1399 const int64_t num_indices_from_element =
1400 indices->shape().dimensions_size() > index_dim
1401 ? indices->shape().dimensions(index_dim)
1402 : 1;
1403 index_parallel_in_dim.assign(num_indices_from_element, *maybe_iota_dim);
1404 }
1405 }
1406 absl::InlinedVector<int64, 1> indices_parallel_dims;
1407 absl::InlinedVector<int64, 1> operand_parallel_dims;
1408 // Map the parallelizable dimension from the iota to the dimensions of the
1409 // output and the operand. These dimensions are interconnected, but between
1410 // operands and index they could have different spots in the shape because the
1411 // position of the index dimension in the operand is determined by
1412 // start_index_map.
1413 for (int i = 0; i < index_parallel_in_dim.size(); ++i) {
1414 int index_parallel_dim = index_parallel_in_dim[i];
1415 if (index_parallel_dim == -1) {
1416 continue;
1417 }
1418 if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
1419 return absl::nullopt;
1420 }
1421 // Considered parallel only if the slice is of size 1 over the operand.
1422 if (hlo.gather_slice_sizes()[dnums.start_index_map(i)] == 1) {
1423 indices_parallel_dims.push_back(index_parallel_dim);
1424 operand_parallel_dims.push_back(dnums.start_index_map(i));
1425 } else {
1426 index_parallel_in_dim[i] = -1;
1427 }
1428 }
1429 absl::c_sort(indices_parallel_dims);
1430 if (!indices_parallel_dims.empty()) {
1431 return GatherParallelDims{indices_parallel_dims, operand_parallel_dims,
1432 index_parallel_in_dim};
1433 }
1434 return absl::nullopt;
1435 }
1436
1437 absl::InlinedVector<int64, 1> GatherParallelOutputDims(
1438 const HloInstruction& gather, const GatherParallelDims& parallel_dim) {
1439 absl::InlinedVector<int64, 1> output_parallel_dims;
1440 auto indices_parallel_dims = parallel_dim.indices_parallel_dims;
1441 const Shape gather_shape = gather.shape();
1442 auto dnums = gather.gather_dimension_numbers();
1443 for (int i = 0, idx_dim = 0; i < gather_shape.dimensions_size(); ++i) {
1444 if (absl::c_linear_search(dnums.offset_dims(), i)) {
1445 continue;
1446 }
1447 const int index_dim =
1448 idx_dim < dnums.index_vector_dim() ? idx_dim : idx_dim + 1;
1449 if (absl::c_binary_search(indices_parallel_dims, index_dim)) {
1450 output_parallel_dims.push_back(i);
1451 }
1452 ++idx_dim;
1453 }
1454 return output_parallel_dims;
1455 }
1456
1457 absl::InlinedVector<int64, 1> GatherOutputAlignedOperandParallelDims(
1458 const HloInstruction& gather, const GatherParallelDims& parallel_dims) {
1459 absl::InlinedVector<int64, 1> operand_parallel_dim_to_output(
1460 parallel_dims.operand_parallel_dims.size(), -1);
1461 auto dnums = gather.gather_dimension_numbers();
1462 CHECK_LE(parallel_dims.indices_parallel_dims.size(),
1463 parallel_dims.operand_parallel_dims.size());
1464 for (int i = 0; i < parallel_dims.index_parallel_in_dim.size(); ++i) {
1465 // This is the equivalent batch dimension of the indices that corresponds
1466 // to this index dimension.
1467 const int64_t index_parallel_dim = parallel_dims.index_parallel_in_dim[i];
1468 // If it's not an index that is parallel skip.
1469 if (index_parallel_dim == -1) {
1470 continue;
1471 }
1472 // This is small so just look linearly. Populate the operand parallel
1473 // dimensions based on the order of the index batch dims (which is the same
1474 // order as the output).
1475 for (int j = 0; j < parallel_dims.indices_parallel_dims.size(); ++j) {
1476 if (parallel_dims.indices_parallel_dims[j] == index_parallel_dim) {
1477 const int64_t operand_parallel_dim = dnums.start_index_map(i);
1478 if (operand_parallel_dim_to_output[j] == -1) {
1479 operand_parallel_dim_to_output[j] = operand_parallel_dim;
1480 }
1481 break;
1482 }
1483 }
1484 }
1485 return operand_parallel_dim_to_output;
1486 }
1487
1488 } // namespace hlo_sharding_util
1489 } // namespace xla
1490