/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/numbers.h"

namespace xla {
namespace spmd {

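// Partitions a dot by describing its dimensions as a DotConvDimsMapping
// (batch, contracting, and non-contracting dims of LHS/RHS/output) and
// delegating to HandleDotHelper with a callback that creates the dot on
// already-sharded operands.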
Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
  DotConvDimsMapping mapping;
  const auto& dnums = hlo->dot_dimension_numbers();
  int64_t next_output_dim = 0;
  for (int64_t i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) {
    mapping.batch_dims.emplace_back();
    mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i);
    mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i);
    mapping.batch_dims.back().output = next_output_dim++;
  }
  for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) {
    mapping.contracting_dims.emplace_back();
    mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i);
    mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i);
    mapping.contracting_dims.back().output = -1;
  }
  for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
    if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) ||
        absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) {
      continue;
    }
    mapping.lhs_non_contracting_dims.emplace_back();
    mapping.lhs_non_contracting_dims.back().lhs = i;
    mapping.lhs_non_contracting_dims.back().rhs = -1;
    mapping.lhs_non_contracting_dims.back().output = next_output_dim++;
  }
  for (int64_t i = 0; i < hlo->operand(1)->shape().rank(); ++i) {
    if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) ||
        absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) {
      continue;
    }
    mapping.rhs_non_contracting_dims.emplace_back();
    mapping.rhs_non_contracting_dims.back().lhs = -1;
    mapping.rhs_non_contracting_dims.back().rhs = i;
    mapping.rhs_non_contracting_dims.back().output = next_output_dim++;
  }
  auto create_sharded_dot =
      [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
          const Window& conv_window) -> StatusOr<HloInstruction*> {
    TF_ASSIGN_OR_RETURN(
        auto sharded_dot_shape,
        ShapeInference::InferDotOpShape(
            l->shape(), r->shape(), hlo->dot_dimension_numbers(),
            /*preferred_element_type=*/hlo->shape().element_type()));
    return b->AddInstruction(HloInstruction::CreateDot(
        sharded_dot_shape, l, r, hlo->dot_dimension_numbers(),
        hlo->precision_config()));
  };
  return HandleDotHelper(hlo, mapping, create_sharded_dot);
}

namespace {

enum class WindowedEinsumOperand { LHS, RHS };

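// Describes how a windowed einsum loop is generated: which operand is rotated
// between partitions, whether that rotation happens along contracting or
// batch dimensions, and whether both operands are sharded along contracting
// dimensions (the reduce-scatter-like case).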
struct WindowedEinsumConfig {
  WindowedEinsumOperand windowed_op;
  bool windowed_at_contracting_dims;
  bool windowed_at_batch_dims;
  bool operands_sharded_at_contracting_dims;
};

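// Dimension index maps between LHS, RHS, and output of a dot; -1 means there
// is no corresponding dimension.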
struct DotDimensionIndexMapping {
  std::vector<int64> lhs_to_rhs_indices;
  std::vector<int64> lhs_to_output_indices;
  std::vector<int64> rhs_to_lhs_indices;
  std::vector<int64> rhs_to_output_indices;
  std::vector<int64> output_to_lhs_indices;
  std::vector<int64> output_to_rhs_indices;
};

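// Adjusts one operand's dot dimension numbers after a dimension has been
// inserted at `reshaped_dim`: existing contracting/batch dimension indices at
// or after `reshaped_dim` are shifted up by one, and if `reshaped_dim` itself
// was a contracting/batch dimension, the newly inserted dimension is added to
// the corresponding list as well.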
void UpdateDDNums(DotDimensionNumbers* new_ddnums, int64_t reshaped_dim,
                  bool lhs) {
  auto update_dims =
      [&reshaped_dim](tensorflow::protobuf::RepeatedField<int64>* dims) {
        bool add_reshaped_dim = false;
        if (absl::c_linear_search(*dims, reshaped_dim)) {
          add_reshaped_dim = true;
        }
        for (int64_t i = 0; i < dims->size(); ++i) {
          auto dim = dims->at(i);
          if (reshaped_dim <= dim) {
            dims->Set(i, dim + 1);
          }
        }
        if (add_reshaped_dim) {
          dims->Add(reshaped_dim);
        }
      };

  if (lhs) {
    update_dims(new_ddnums->mutable_lhs_contracting_dimensions());
    update_dims(new_ddnums->mutable_lhs_batch_dimensions());
  } else {  // rhs
    update_dims(new_ddnums->mutable_rhs_contracting_dimensions());
    update_dims(new_ddnums->mutable_rhs_batch_dimensions());
  }
}

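// Builds the convolution window for the doubled (concatenated) operands used
// inside the bidirectional windowed einsum loop, starting from the original
// convolution's window and accounting for the concatenated dimension on the
// LHS or RHS.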
Window GenNewWindow(const HloInstruction* original_dot,
                    const HloInstruction* dot_lhs,
                    const HloInstruction* dot_rhs, int64_t lhs_concat_dim,
                    int64_t rhs_concat_dim, bool windowed_at_contracting_dims,
                    bool windowed_at_batch_dims) {
  auto new_window = original_dot->window();
  const ConvolutionDimensionNumbers& conv_dnums =
      original_dot->convolution_dimension_numbers();
  if (lhs_concat_dim != -1) {
    for (int64_t i = 0; i < conv_dnums.input_spatial_dimensions_size(); ++i) {
      if (conv_dnums.input_spatial_dimensions(i) == lhs_concat_dim) {
        auto wd = new_window.mutable_dimensions(i);
        auto lhs_size = dot_lhs->shape().dimensions(lhs_concat_dim + 1);
        if (windowed_at_contracting_dims) {
          wd->set_size(lhs_size);
        }
        if (windowed_at_batch_dims) {
          wd->set_size(lhs_size);
          wd->set_padding_low(0);
          wd->set_padding_high(0);
          wd->set_stride(std::max<int64>(1, lhs_size - 1));
          wd->set_window_dilation(1);
          wd->set_base_dilation(lhs_size);
          wd->set_window_reversal(false);
        }
      }
    }
  }
  if (rhs_concat_dim != -1) {
    for (int64_t i = 0; i < conv_dnums.kernel_spatial_dimensions_size(); ++i) {
      if (conv_dnums.kernel_spatial_dimensions(i) == rhs_concat_dim &&
          !windowed_at_contracting_dims && !windowed_at_batch_dims &&
          lhs_concat_dim == -1) {
        auto wd = new_window.mutable_dimensions(i);
        auto rhs_size = dot_rhs->shape().dimensions(rhs_concat_dim + 1);
        wd->set_size(rhs_size);
        wd->set_padding_low(rhs_size - 1);
        wd->set_padding_high(rhs_size - 1);
      }
    }
  }
  // Add the extra dimension to the window.
  WindowDimension* new_dim = new_window.add_dimensions();
  if (windowed_at_contracting_dims) {
    new_dim->set_size(2);
    new_dim->set_padding_low(0);
    new_dim->set_padding_high(0);
    new_dim->set_stride(1);
    new_dim->set_window_dilation(1);
    new_dim->set_base_dilation(1);
    new_dim->set_window_reversal(false);
  } else if (windowed_at_batch_dims) {
    new_dim->set_size(2);
    new_dim->set_padding_low(0);
    new_dim->set_padding_high(0);
    new_dim->set_stride(1);  // std::max<int64>(1, 2 - 1)
    new_dim->set_window_dilation(1);
    new_dim->set_base_dilation(2);
    new_dim->set_window_reversal(false);
  } else {
    if (lhs_concat_dim != -1) {
      new_dim->set_size(1);
      new_dim->set_padding_low(0);
      new_dim->set_padding_high(0);
      new_dim->set_stride(1);
      new_dim->set_window_dilation(1);
      new_dim->set_base_dilation(1);
      new_dim->set_window_reversal(false);
    }
    if (rhs_concat_dim != -1) {
      new_dim->set_size(2);          // rhs_size
      new_dim->set_padding_low(1);   // rhs_size - 1
      new_dim->set_padding_high(1);  // rhs_size - 1
      new_dim->set_stride(1);
      new_dim->set_window_dilation(1);
      new_dim->set_base_dilation(1);
      new_dim->set_window_reversal(true);
    }
  }

  VLOG(2) << "new_window: " << new_window.ShortDebugString();
  return new_window;
}

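// Builds convolution dimension numbers for the doubled operands: existing
// dimension indices are shifted to account for the inserted concat dimension
// on the LHS/RHS, and the new dimension is registered as an extra spatial
// dimension on the input, kernel, and output.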
ConvolutionDimensionNumbers GenNewConvDNums(
    const HloInstruction* original_dot, const HloInstruction* dot_lhs,
    const HloInstruction* dot_rhs, int64_t lhs_concat_dim,
    int64_t rhs_concat_dim, bool windowed_at_contracting_dims,
    bool windowed_at_batch_dims,
    const std::vector<int64>& lhs_to_output_indices,
    const std::vector<int64>& rhs_to_output_indices,
    const Shape& new_dot_shape) {
  // Generate the new conv dimension numbers.
  const ConvolutionDimensionNumbers& dnums =
      original_dot->convolution_dimension_numbers();
  // Handle the LHS dimension numbers.
  int64_t input_batch_dimension = dnums.input_batch_dimension();
  int64_t input_feature_dimension = dnums.input_feature_dimension();
  std::vector<int64> input_spatial_dimensions(
      dnums.input_spatial_dimensions().begin(),
      dnums.input_spatial_dimensions().end());
  if (lhs_concat_dim != -1) {
    if (lhs_concat_dim <= input_batch_dimension) {
      input_batch_dimension++;
    }
    if (lhs_concat_dim <= input_feature_dimension) {
      input_feature_dimension++;
    }
    for (int64_t i = 0; i < input_spatial_dimensions.size(); ++i) {
      if (lhs_concat_dim <= input_spatial_dimensions[i]) {
        input_spatial_dimensions[i]++;
      }
    }
    input_spatial_dimensions.push_back(lhs_concat_dim);
  }
  if (rhs_concat_dim != -1 && !windowed_at_contracting_dims &&
      !windowed_at_batch_dims) {
    input_spatial_dimensions.push_back(dot_lhs->shape().dimensions_size() - 1);
  }
  // Handle the RHS dimension numbers.
  int64_t kernel_input_feature_dimension =
      dnums.kernel_input_feature_dimension();
  int64_t kernel_output_feature_dimension =
      dnums.kernel_output_feature_dimension();
  std::vector<int64> kernel_spatial_dimensions(
      dnums.kernel_spatial_dimensions().begin(),
      dnums.kernel_spatial_dimensions().end());
  if (rhs_concat_dim != -1) {
    if (rhs_concat_dim <= kernel_input_feature_dimension) {
      kernel_input_feature_dimension++;
    }
    if (rhs_concat_dim <= kernel_output_feature_dimension) {
      kernel_output_feature_dimension++;
    }
    for (int64_t i = 0; i < kernel_spatial_dimensions.size(); ++i) {
      if (rhs_concat_dim <= kernel_spatial_dimensions[i]) {
        kernel_spatial_dimensions[i]++;
      }
    }
    kernel_spatial_dimensions.push_back(rhs_concat_dim);
  }
  if (lhs_concat_dim != -1 && !windowed_at_contracting_dims &&
      !windowed_at_batch_dims) {
    kernel_spatial_dimensions.push_back(dot_rhs->shape().dimensions_size() - 1);
  }
  // Handle the output dimension numbers.
  int64_t output_batch_dimension = dnums.output_batch_dimension();
  int64_t output_feature_dimension = dnums.output_feature_dimension();
  std::vector<int64> output_spatial_dimensions(
      dnums.output_spatial_dimensions().begin(),
      dnums.output_spatial_dimensions().end());
  if (!windowed_at_contracting_dims) {
    auto output_slice_dim = lhs_concat_dim != -1
                                ? lhs_to_output_indices[lhs_concat_dim]
                                : rhs_to_output_indices[rhs_concat_dim];
    if (output_slice_dim <= output_batch_dimension) {
      output_batch_dimension++;
    }
    if (output_slice_dim <= output_feature_dimension) {
      output_feature_dimension++;
    }
    for (int64_t i = 0; i < output_spatial_dimensions.size(); ++i) {
      if (output_slice_dim <= output_spatial_dimensions[i]) {
        output_spatial_dimensions[i]++;
      }
    }
    output_spatial_dimensions.push_back(output_slice_dim);
  } else {
    output_spatial_dimensions.push_back(new_dot_shape.dimensions_size() - 1);
  }
  // Construct the new conv dimension numbers.
  ConvolutionDimensionNumbers new_dnums;
  new_dnums.set_input_batch_dimension(input_batch_dimension);
  new_dnums.set_input_feature_dimension(input_feature_dimension);
  for (auto dim : input_spatial_dimensions) {
    new_dnums.add_input_spatial_dimensions(dim);
  }
  new_dnums.set_kernel_input_feature_dimension(kernel_input_feature_dimension);
  new_dnums.set_kernel_output_feature_dimension(
      kernel_output_feature_dimension);
  for (auto dim : kernel_spatial_dimensions) {
    new_dnums.add_kernel_spatial_dimensions(dim);
  }
  new_dnums.set_output_batch_dimension(output_batch_dimension);
  new_dnums.set_output_feature_dimension(output_feature_dimension);
  for (auto dim : output_spatial_dimensions) {
    new_dnums.add_output_spatial_dimensions(dim);
  }

  return new_dnums;
}

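// Builds the dimension index maps between LHS, RHS, and output from a
// DotConvDimsMapping; unmapped dimensions stay -1.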
DotDimensionIndexMapping ComputeDimensionIndexMapping(
    const DotConvDimsMapping& dims_mapping, int64_t lhs_rank, int64_t rhs_rank,
    int64_t output_rank) {
  std::vector<int64> lhs_to_rhs_indices(lhs_rank, -1);
  std::vector<int64> lhs_to_output_indices(lhs_rank, -1);
  std::vector<int64> rhs_to_lhs_indices(rhs_rank, -1);
  std::vector<int64> rhs_to_output_indices(rhs_rank, -1);
  std::vector<int64> output_to_lhs_indices(output_rank, -1);
  std::vector<int64> output_to_rhs_indices(output_rank, -1);
  auto populate_indices_mapping =
      [&](const DotConvDimsMapping::DimsMapping& mapping) {
        if (mapping.lhs >= 0) {
          lhs_to_rhs_indices[mapping.lhs] = mapping.rhs;
          lhs_to_output_indices[mapping.lhs] = mapping.output;
        }
        if (mapping.rhs >= 0) {
          rhs_to_lhs_indices[mapping.rhs] = mapping.lhs;
          rhs_to_output_indices[mapping.rhs] = mapping.output;
        }
        if (mapping.output >= 0) {
          output_to_lhs_indices[mapping.output] = mapping.lhs;
          output_to_rhs_indices[mapping.output] = mapping.rhs;
        }
      };
  for (const auto& mapping : dims_mapping.batch_dims) {
    populate_indices_mapping(mapping);
  }
  for (const auto& mapping : dims_mapping.contracting_dims) {
    populate_indices_mapping(mapping);
  }
  for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) {
    populate_indices_mapping(mapping);
  }
  for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) {
    populate_indices_mapping(mapping);
  }
  for (const auto& mapping : dims_mapping.conv_spatial_dims) {
    populate_indices_mapping(mapping);
  }
  return DotDimensionIndexMapping{lhs_to_rhs_indices,    lhs_to_output_indices,
                                  rhs_to_lhs_indices,    rhs_to_output_indices,
                                  output_to_lhs_indices, output_to_rhs_indices};
}

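// Decides whether a windowed einsum loop applies and, if so, which operand is
// windowed and along which kind of dimension. The operand to be windowed (or
// the output, when both operands are sharded along contracting dimensions)
// must be at least einsum_threshold_mib in size, and the non-windowed
// operand's sharding must already match the output (or vice versa).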
absl::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
    int64_t num_partitions, int64_t output_lhs_non_contracting_partitions,
    int64_t output_rhs_non_contracting_partitions,
    int64_t rhs_contracting_partitions, int64_t rhs_non_contracting_partitions,
    int64_t rhs_batch_partitions, int64_t lhs_contracting_partitions,
    int64_t lhs_non_contracting_partitions, int64_t lhs_batch_partitions,
    int64_t rhs_shape_size, int64_t lhs_shape_size, int64_t output_shape_size,
    int64_t einsum_threshold_mib,
    const absl::optional<HloSharding>& output_sharding_transposed_to_match_lhs,
    const absl::optional<HloSharding>& output_sharding_transposed_to_match_rhs,
    const HloSharding& lhs_sharding, const HloSharding& rhs_sharding) {
  if (output_lhs_non_contracting_partitions == num_partitions &&
      output_sharding_transposed_to_match_lhs == lhs_sharding &&
      rhs_shape_size >= einsum_threshold_mib * 1024 * 1024) {
    if (rhs_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::RHS,
          /*windowed_at_contracting_dims=*/true,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/false};
    }
    if (rhs_non_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::RHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/false};
    }
    if (rhs_batch_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::RHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/true,
          /*operands_sharded_at_contracting_dims=*/false};
    }
  }
  if (output_rhs_non_contracting_partitions == num_partitions &&
      output_sharding_transposed_to_match_rhs == rhs_sharding &&
      lhs_shape_size >= einsum_threshold_mib * 1024 * 1024) {
    if (lhs_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::LHS,
          /*windowed_at_contracting_dims=*/true,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/false};
    }
    if (lhs_non_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::LHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/false};
    }
    if (lhs_batch_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::LHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/true,
          /*operands_sharded_at_contracting_dims=*/false};
    }
  }
  if (lhs_contracting_partitions == rhs_contracting_partitions &&
      lhs_contracting_partitions == num_partitions &&
      (output_lhs_non_contracting_partitions == num_partitions ||
       output_rhs_non_contracting_partitions == num_partitions) &&
      output_shape_size >= einsum_threshold_mib * 1024 * 1024) {
    if (output_lhs_non_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::RHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/true};
    }
    if (output_rhs_non_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::LHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/true};
    }
  }
  return absl::nullopt;
}

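// Derives replica groups for a windowed einsum while loop by following the
// cycles formed by the source-target pairs of the collective-permute in the
// loop body; each cycle becomes one (sorted) replica group.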
std::vector<ReplicaGroup> GetLoopReplicaGroups(HloInstruction* while_loop) {
  std::vector<ReplicaGroup> groups;
  for (auto inst : while_loop->while_body()->instructions()) {
    if (inst->opcode() == HloOpcode::kCollectivePermute) {
      std::vector<std::pair<int64, int64>> st_pairs =
          inst->source_target_pairs();
      std::vector<int64> source_index(st_pairs.size());
      for (int64_t i = 0; i < st_pairs.size(); ++i) {
        source_index[st_pairs[i].first] = i;
      }

      absl::flat_hash_set<int64> visited;
      for (int64_t i = 0; i < st_pairs.size(); ++i) {
        if (visited.contains(st_pairs[i].first)) {
          continue;
        }
        std::vector<int64> replica_group;
        int64_t source = st_pairs[i].first;
        int64_t target = st_pairs[i].second;
        replica_group.push_back(source);
        replica_group.push_back(target);
        visited.insert(source);
        visited.insert(target);
        while (target != source) {
          target = st_pairs[source_index[target]].second;
          if (target != source) {
            replica_group.push_back(target);
            visited.insert(target);
          }
        }
        absl::c_sort(replica_group);
        groups.emplace_back();
        for (auto id : replica_group) {
          groups.back().add_replica_ids(id);
        }
      }

      VLOG(3) << "while loop: " << while_loop->name()
              << ", replica groups: " << ReplicaGroupsToString(groups);
      break;
    }
  }
  return groups;
}

// We use a recursive approach where sets of matching dimensions are recognized
// one at a time. The base shapes and shardings can be changed during the
// recursion as we group devices together. So refer to the passed-in shapes and
// shardings for inputs and output, and do not use shape inference.
StatusOr<HloInstruction*> PartitionBaseCase(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
    int64_t num_partitions,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_dot,
    const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
    int64_t lhs_batch_partitions, int64_t rhs_batch_partitions,
    int64_t output_batch_partitions, int64_t lhs_contracting_partitions,
    int64_t rhs_contracting_partitions, int64_t lhs_non_contracting_partitions,
    int64_t rhs_non_contracting_partitions,
    int64_t output_lhs_non_contracting_partitions,
    int64_t output_rhs_non_contracting_partitions,
    const SpmdPartitionerOptions& options, SpmdBuilder* b,
    std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
        windowed_dot_general_loops,
    bool may_reshard_without_detecting_match) {
  const HloSharding& lhs_sharding = lhs.sharding();
  const HloSharding& rhs_sharding = rhs.sharding();
  if (lhs_sharding.ReplicateOnLastTileDim() ||
      rhs_sharding.ReplicateOnLastTileDim() ||
      output_sharding.ReplicateOnLastTileDim()) {
    return nullptr;
  }
  DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
      dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
      output_base_shape.rank());
  auto lhs_sharding_transposed_to_match_rhs =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          lhs_sharding, indices_map.lhs_to_rhs_indices,
          indices_map.rhs_to_lhs_indices);
  auto rhs_sharding_transposed_to_match_lhs =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          rhs_sharding, indices_map.rhs_to_lhs_indices,
          indices_map.lhs_to_rhs_indices);
  auto lhs_sharding_transposed_to_match_output =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          lhs_sharding, indices_map.lhs_to_output_indices,
          indices_map.output_to_lhs_indices);
  auto rhs_sharding_transposed_to_match_output =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          rhs_sharding, indices_map.rhs_to_output_indices,
          indices_map.output_to_rhs_indices);
  auto output_sharding_transposed_to_match_lhs =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          output_sharding, indices_map.output_to_lhs_indices,
          indices_map.lhs_to_output_indices);
  auto output_sharding_transposed_to_match_rhs =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          output_sharding, indices_map.output_to_rhs_indices,
          indices_map.rhs_to_output_indices);

  // LHS and RHS are partitioned the same way and only partitioned in batch
  // dimensions.
  if (lhs_batch_partitions == rhs_batch_partitions &&
      rhs_batch_partitions == num_partitions &&
      lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
    TF_ASSIGN_OR_RETURN(
        auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
    dot->set_sharding(*lhs_sharding_transposed_to_match_output);
    return PartitionedHlo(dot, output_base_shape, lhs.state())
        .Reshard(output_sharding)
        .hlo();
  }

  // Try to emit a batch-partitioned einsum with one operand resharded. Returns
  // the partitioned HLO or nullptr if the attempt fails. If
  // may_reshard_with_allreduce is false, the reshard must be done using
  // all-to-all/collective-permute; otherwise this attempt fails.
  auto try_emit_output_batch_partitioned_einsum_with_reshard =
      [&](bool may_reshard_with_allreduce) -> StatusOr<HloInstruction*> {
    // LHS and output are batch partitioned in the same way.
    if (lhs_batch_partitions == num_partitions &&
        output_batch_partitions == num_partitions &&
        lhs_sharding_transposed_to_match_output == output_sharding) {
      if (!may_reshard_with_allreduce &&
          !CanReshardWithCollectivePermute(
              rhs.sharding(), *lhs_sharding_transposed_to_match_rhs) &&
          !GetReshardAllToAllSourceTargetDims(
              rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) {
        return nullptr;
      }
      auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
      TF_ASSIGN_OR_RETURN(
          auto dot,
          create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window));
      return dot;
    }
    // RHS and output are batch partitioned in the same way.
    if (rhs_batch_partitions == num_partitions &&
        output_batch_partitions == num_partitions &&
        rhs_sharding_transposed_to_match_output == output_sharding) {
      if (!may_reshard_with_allreduce &&
          !CanReshardWithCollectivePermute(
              lhs.sharding(), *rhs_sharding_transposed_to_match_lhs) &&
          !GetReshardAllToAllSourceTargetDims(
              lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) {
        return nullptr;
      }
      auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
      TF_ASSIGN_OR_RETURN(
          auto dot,
          create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window));
      return dot;
    }
    return nullptr;
  };

  {
    // Try batch-parallel by resharding one operand, and not using all-reduce.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * partitioned_dot,
        try_emit_output_batch_partitioned_einsum_with_reshard(false));
    if (partitioned_dot) {
      return partitioned_dot;
    }
  }

  // Try to emit a windowed DotGeneral when one operand is partitioned in the
  // same way as the output along non-contracting dimensions, but the other
  // operand is tiled in other dimensions; or when both operands are
  // partitioned in the same way along contracting dimensions, but the output
  // is partitioned along non-contracting dimensions.
  auto emit_windowed_dot_general =
      [&](const WindowedEinsumConfig& einsum_config)
      -> StatusOr<HloInstruction*> {
    CHECK(!einsum_config.windowed_at_batch_dims ||
          !einsum_config.windowed_at_contracting_dims);
    const bool windowed_at_batch_dims = einsum_config.windowed_at_batch_dims;
    const bool windowed_at_contracting_dims =
        einsum_config.windowed_at_contracting_dims;
    const bool operands_sharded_at_contracting_dims =
        einsum_config.operands_sharded_at_contracting_dims;
    auto unpadded_result_buffer_shape =
        MakePartitionedShape(output_base_shape, output_sharding);
    auto padded_result_buffer_shape = unpadded_result_buffer_shape;
    const bool windowed_op_is_lhs =
        einsum_config.windowed_op == WindowedEinsumOperand::LHS;
    // For windowing at batch/non-contracting dims, we produce the result one
    // partition at a time, so we need to pad the shape in case of uneven
    // partitioning in order to make dynamic-update-slice in-bound.
    if (!windowed_at_contracting_dims &&
        !operands_sharded_at_contracting_dims) {
      padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning(
          padded_result_buffer_shape,
          windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
                             : *rhs_sharding_transposed_to_match_output);
    }
    // Mask the padding area of the windowed operand with zero if there is
    // uneven partitioning.
    if (windowed_at_contracting_dims) {
      auto& to_mask = windowed_op_is_lhs ? lhs : rhs;
      to_mask = to_mask.PadWithValue(
          b->AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::Zero(output_base_shape.element_type()))));
    }
    if (operands_sharded_at_contracting_dims) {
      auto zero = b->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(output_base_shape.element_type())));
      lhs = lhs.PadWithValue(zero);
      rhs = rhs.PadWithValue(zero);
    }

    // Get slice sharding, sharding dim, and lhs/rhs concat dim.
    const HloSharding* slice_sharding;
    if (operands_sharded_at_contracting_dims) {
      slice_sharding = windowed_op_is_lhs
                           ? &*output_sharding_transposed_to_match_rhs
                           : &*output_sharding_transposed_to_match_lhs;
    } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
      slice_sharding = windowed_op_is_lhs
                           ? &*lhs_sharding_transposed_to_match_rhs
                           : &*rhs_sharding_transposed_to_match_lhs;
    } else {
      slice_sharding = windowed_op_is_lhs
                           ? &*lhs_sharding_transposed_to_match_output
                           : &*rhs_sharding_transposed_to_match_output;
    }
    CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()),
             num_partitions);
    int64_t slice_sharding_dim = -1;
    for (int64_t i = 0; i < slice_sharding->tile_assignment().num_dimensions();
         ++i) {
      if (slice_sharding->tile_assignment().dim(i) > 1) {
        slice_sharding_dim = i;
        break;
      }
    }
    int64_t lhs_concat_dim = -1;
    int64_t rhs_concat_dim = -1;
    if (operands_sharded_at_contracting_dims) {
      if (windowed_op_is_lhs) {
        rhs_concat_dim = slice_sharding_dim;
      } else {
        lhs_concat_dim = slice_sharding_dim;
      }
    } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
      lhs_concat_dim = windowed_op_is_lhs
                           ? indices_map.rhs_to_lhs_indices[slice_sharding_dim]
                           : slice_sharding_dim;
      rhs_concat_dim = windowed_op_is_lhs
                           ? slice_sharding_dim
                           : indices_map.lhs_to_rhs_indices[slice_sharding_dim];
    } else {
      if (windowed_op_is_lhs) {
        lhs_concat_dim = indices_map.output_to_lhs_indices[slice_sharding_dim];
      } else {
        rhs_concat_dim = indices_map.output_to_rhs_indices[slice_sharding_dim];
      }
    }

    auto lhs_hlo = lhs.hlo();
    auto rhs_hlo = rhs.hlo();
    // Reshape lhs and rhs before the loop for the bidirectional communication
    // case.
    if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
      if (lhs_concat_dim != -1 && windowed_op_is_lhs &&
          !operands_sharded_at_contracting_dims) {
        std::vector<int64> reshaped_dims(lhs_hlo->shape().dimensions().begin(),
                                         lhs_hlo->shape().dimensions().end());
        reshaped_dims.insert(reshaped_dims.begin() + lhs_concat_dim, 1);
        lhs_hlo = b->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(lhs_hlo->shape().element_type(),
                                 reshaped_dims),
            lhs_hlo));
      }
      if (rhs_concat_dim != -1 && !windowed_op_is_lhs &&
          !operands_sharded_at_contracting_dims) {
        std::vector<int64> reshaped_dims(rhs_hlo->shape().dimensions().begin(),
                                         rhs_hlo->shape().dimensions().end());
        reshaped_dims.insert(reshaped_dims.begin() + rhs_concat_dim, 1);
        rhs_hlo = b->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(rhs_hlo->shape().element_type(),
                                 reshaped_dims),
            rhs_hlo));
      }
    }

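    // result_buffer accumulates the output. extra_buffer is either a second
    // accumulator (when operands are sharded at contracting dims) or the
    // clockwise-traveling copy of the windowed operand used by the
    // bidirectional algorithm.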
    auto result_buffer = CreateZero(padded_result_buffer_shape, b);
    auto extra_buffer =
        (!(options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
         operands_sharded_at_contracting_dims)
            ? CreateZero(padded_result_buffer_shape, b)
            : windowed_op_is_lhs ? lhs_hlo : rhs_hlo;

    if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0 &&
        !operands_sharded_at_contracting_dims) {
      std::vector<std::pair<int64, int64>> pre_sd_pairs(num_partitions);
      for (int64_t source = 0; source < num_partitions; ++source) {
        // 0 -> 1, 1 -> 2, 2 -> 3, ...
        pre_sd_pairs[source] = {source, (source + 1) % num_partitions};
      }
      extra_buffer =
          lhs.state()
              .collective_ops_creator.create_cross_partition_collective_permute(
                  b, extra_buffer, pre_sd_pairs,
                  (*lhs.state().next_channel_id)++);
    }

    auto iteration = b->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));

    // Create a while loop that computes one window per iteration. During each
    // iteration, each partition sends its input window to its neighbor using
    // collective-permute for the next iteration.
    SpmdBuilder body_b("windowed_dot_general_body", original_hlo);

    // Generate partial results used by the bidirectional algorithm.
    auto get_partial_bid_results =
        [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
            HloInstruction* extra_inout, HloInstruction* cw_cp_output,
            HloInstruction* i) -> StatusOr<std::vector<HloInstruction*>> {
      auto partition_id =
          lhs.state().collective_ops_creator.create_partition_id(&body_b);
      auto partition_count =
          body_b.AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::CreateR0<uint32>(num_partitions)));
      auto ccw_data_partition_id =
          body_b.AddInstruction(HloInstruction::CreateBinary(
              i->shape(), HloOpcode::kAdd, i, partition_id));
      auto cw_data_partition_id =
          body_b.AddInstruction(HloInstruction::CreateBinary(
              i->shape(), HloOpcode::kAdd, partition_count, partition_id));
      if (operands_sharded_at_contracting_dims) {
        ccw_data_partition_id =
            body_b.AddInstruction(HloInstruction::CreateBinary(
                i->shape(), HloOpcode::kAdd, ccw_data_partition_id,
                body_b.AddInstruction(HloInstruction::CreateConstant(
                    LiteralUtil::CreateR0<uint32>(num_partitions / 2 + 1)))));
        cw_data_partition_id =
            body_b.AddInstruction(HloInstruction::CreateBinary(
                i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
                body_b.AddInstruction(HloInstruction::CreateConstant(
                    LiteralUtil::CreateR0<uint32>(num_partitions / 2)))));
      } else {
        cw_data_partition_id =
            body_b.AddInstruction(HloInstruction::CreateBinary(
                i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
                CreateOne(cw_data_partition_id->shape(), &body_b)));
      }
      ccw_data_partition_id = body_b.AddInstruction(
          HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
                                       ccw_data_partition_id, partition_count));
      cw_data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
          i->shape(), HloOpcode::kSubtract, cw_data_partition_id, i));
      cw_data_partition_id = body_b.AddInstruction(
          HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
                                       cw_data_partition_id, partition_count));

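      // For a dot, start from the original dimension numbers; they are
      // updated below as operands get concatenated/reshaped.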
      DotDimensionNumbers new_ddnums;
      if (original_hlo->opcode() == HloOpcode::kDot) {
        new_ddnums = original_hlo->dot_dimension_numbers();
      }

      auto dot_lhs = l;
      auto dot_rhs = r;
      auto original_dot_lhs = l;
      auto original_dot_rhs = r;
      // Recover original lhs and rhs, will not be used in real computation.
      if (lhs_concat_dim != -1 && windowed_op_is_lhs) {
        std::vector<int64> reshaped_dims(
            original_dot_lhs->shape().dimensions().begin(),
            original_dot_lhs->shape().dimensions().end());
        reshaped_dims.erase(reshaped_dims.begin() + lhs_concat_dim);
        original_dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(original_dot_lhs->shape().element_type(),
                                 reshaped_dims),
            original_dot_lhs));
      }
      if (rhs_concat_dim != -1 && !windowed_op_is_lhs) {
        std::vector<int64> reshaped_dims(
            original_dot_rhs->shape().dimensions().begin(),
            original_dot_rhs->shape().dimensions().end());
        reshaped_dims.erase(reshaped_dims.begin() + rhs_concat_dim);
        original_dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(original_dot_rhs->shape().element_type(),
                                 reshaped_dims),
            original_dot_rhs));
      }

      if (windowed_at_contracting_dims || windowed_at_batch_dims ||
          operands_sharded_at_contracting_dims) {
        // Slice the matching operand according to the partitioned dimensions
        // on the windowed operand or the output.
        auto slice_operand = !windowed_op_is_lhs ? l : r;

        // Pad the sharding dim first (then the concat dim) for correctness.
        auto sharding_dim_size =
            slice_operand->shape().dimensions(slice_sharding_dim);
        if (sharding_dim_size % num_partitions != 0) {
          slice_operand = PadBaseShapeBeforeUnevenTiledSharding(
              slice_operand, *slice_sharding, &body_b);
        }

        // We do this by treating the matching operand as replicated, and
        // resharding it to match the windowed operand or the output.
        auto gen_slice = [&](HloInstruction* data_partition_id,
                             bool ccw) -> HloInstruction* {
          std::vector<int64> new_dims;
          for (int64_t i = 0; i < slice_operand->shape().dimensions_size();
               ++i) {
            if (i == slice_sharding_dim) {
              new_dims.push_back(1);
            }
            new_dims.push_back(slice_operand->shape().dimensions(i));
          }
          auto reshaped_slice_operand =
              body_b.AddInstruction(HloInstruction::CreateReshape(
                  ShapeUtil::MakeShape(slice_operand->shape().element_type(),
                                       new_dims),
                  slice_operand));
          auto min = body_b.AddInstruction(
              HloInstruction::CreateConstant(LiteralUtil::MinValue(
                  reshaped_slice_operand->shape().element_type())));
          std::vector<int64> min_padding(
              reshaped_slice_operand->shape().rank());
          auto padded_slice_operand = reshaped_slice_operand;
          auto padded_shape = padded_slice_operand->shape();
          int64_t padding_dim = slice_sharding_dim;
          padded_shape.set_dimensions(padding_dim, 2);
          if (ccw) {
            // ccw pad high
            PaddingConfig ccw_pad_config =
                window_util::MakeSymmetricPadding(min_padding);
            ccw_pad_config.mutable_dimensions(padding_dim)
                ->set_edge_padding_low(0);
            ccw_pad_config.mutable_dimensions(padding_dim)
                ->set_edge_padding_high(1);
            padded_slice_operand =
                body_b.AddInstruction(HloInstruction::CreatePad(
                    padded_shape, padded_slice_operand, min, ccw_pad_config));
          } else {
            // cw pad low
            PaddingConfig cw_pad_config =
                window_util::MakeSymmetricPadding(min_padding);
            cw_pad_config.mutable_dimensions(padding_dim)
                ->set_edge_padding_low(1);
            cw_pad_config.mutable_dimensions(padding_dim)
                ->set_edge_padding_high(0);
            padded_slice_operand =
                body_b.AddInstruction(HloInstruction::CreatePad(
                    padded_shape, padded_slice_operand, min, cw_pad_config));
          }

          padded_slice_operand->set_sharding(HloSharding::Replicate());
          auto state = lhs.state();
          state.b = &body_b;
          state.partition_id = data_partition_id;
          state.reshard_cache->per_hlo_cache.erase(padded_slice_operand);
          auto padded_slice_sharding = hlo_sharding_util::ReshapeSharding(
              slice_operand->shape(), reshaped_slice_operand->shape(),
              *slice_sharding);
          auto padded_slice =
              PartitionedHlo(padded_slice_operand,
                             padded_slice_operand->shape(), state)
                  .Reshard(*padded_slice_sharding)
                  .hlo();
          padded_slice_operand->clear_sharding();
          return padded_slice;
        };

        auto ccw_slice = gen_slice(ccw_data_partition_id, true);
        auto cw_slice = gen_slice(cw_data_partition_id, false);
        auto slice = body_b.AddInstruction(HloInstruction::CreateBinary(
            ccw_slice->shape(), HloOpcode::kMaximum, ccw_slice, cw_slice));
        // Reshape. The reshaped slice will not be used to produce the final
        // result, but used as a hint for the shape inference.
        std::vector<int64> reshaped_slice_dims;
        for (int64_t i = 0; i < slice->shape().dimensions_size(); ++i) {
          auto dim_size = slice->shape().dimensions(i);
          if (i == (slice_sharding_dim + 1)) {
            reshaped_slice_dims.push_back(dim_size * 2);
          } else if (i != slice_sharding_dim) {
            reshaped_slice_dims.push_back(dim_size);
          }
        }
        auto reshaped_slice =
            body_b.AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::MakeShape(slice->shape().element_type(),
                                     reshaped_slice_dims),
                slice));

        if (!windowed_op_is_lhs) {
          dot_lhs = slice;
          original_dot_lhs = reshaped_slice;
          if (original_hlo->opcode() == HloOpcode::kDot) {
            UpdateDDNums(&new_ddnums, slice_sharding_dim, true);
          }
        } else {
          dot_rhs = slice;
          original_dot_rhs = reshaped_slice;
          if (original_hlo->opcode() == HloOpcode::kDot) {
            UpdateDDNums(&new_ddnums, slice_sharding_dim, false);
          }
        }
      }

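      // Build the doubled operand for this iteration by concatenating the
      // counter-clockwise and clockwise shards along the concat dimension.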
      auto ccw_dot_lhs = l;
      auto ccw_dot_rhs = r;
      auto cw_dot_lhs = windowed_op_is_lhs ? extra_inout : l;
      auto cw_dot_rhs = windowed_op_is_lhs ? r : extra_inout;
      if (lhs_concat_dim != -1 && windowed_op_is_lhs) {
        // Concat
        auto lhs_concat_shape = ccw_dot_lhs->shape();
        lhs_concat_shape.set_dimensions(lhs_concat_dim, 2);
        dot_lhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
            lhs_concat_shape, {ccw_dot_lhs, cw_dot_lhs}, lhs_concat_dim));

        std::vector<int64> reshaped_dims(
            ccw_dot_lhs->shape().dimensions().begin(),
            ccw_dot_lhs->shape().dimensions().end());
        reshaped_dims.erase(reshaped_dims.begin() + lhs_concat_dim);
        reshaped_dims[lhs_concat_dim] *= 2;
        original_dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(dot_lhs->shape().element_type(),
                                 reshaped_dims),
            dot_lhs));

        if (original_hlo->opcode() == HloOpcode::kDot) {
          UpdateDDNums(&new_ddnums, lhs_concat_dim, true);
        }
      }
      if (rhs_concat_dim != -1 && !windowed_op_is_lhs) {
        // Concat
        auto rhs_concat_shape = ccw_dot_rhs->shape();
        rhs_concat_shape.set_dimensions(rhs_concat_dim, 2);
        dot_rhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
            rhs_concat_shape, {ccw_dot_rhs, cw_dot_rhs}, rhs_concat_dim));

        std::vector<int64> reshaped_dims(
            ccw_dot_rhs->shape().dimensions().begin(),
            ccw_dot_rhs->shape().dimensions().end());
        reshaped_dims.erase(reshaped_dims.begin() + rhs_concat_dim);
        reshaped_dims[rhs_concat_dim] *= 2;
        original_dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(dot_rhs->shape().element_type(),
                                 reshaped_dims),
            dot_rhs));

        if (original_hlo->opcode() == HloOpcode::kDot) {
          UpdateDDNums(&new_ddnums, rhs_concat_dim, false);
        }
      }

      // The generated original dot will not be used.
      TF_ASSIGN_OR_RETURN(auto original_dot,
                          create_sharded_dot(original_dot_lhs,
                                             original_dot_rhs, &body_b,
                                             conv_window));
      VLOG(2) << original_dot->ToString();

      // Generate the correct shape of the new dot/conv.
      auto original_sharded_dot_shape = original_dot->shape();
      auto new_dot_shape = original_sharded_dot_shape;
      std::vector<int64> new_dims(new_dot_shape.dimensions().begin(),
                                  new_dot_shape.dimensions().end());
      if (!windowed_at_contracting_dims) {
        auto slice_dim =
            lhs_concat_dim != -1
                ? indices_map.lhs_to_output_indices[lhs_concat_dim]
                : indices_map.rhs_to_output_indices[rhs_concat_dim];
        new_dims[slice_dim] /= 2;
        new_dims.insert(new_dims.begin() + slice_dim, 2);
      } else if (original_hlo->opcode() != HloOpcode::kDot) {
        new_dims.push_back(1);
      }
      new_dot_shape =
          ShapeUtil::MakeShape(original_hlo->shape().element_type(), new_dims);

      HloInstruction* dot;
      if (original_hlo->opcode() == HloOpcode::kDot) {
        dot = body_b.AddInstruction(HloInstruction::CreateDot(
            new_dot_shape, dot_lhs, dot_rhs, new_ddnums,
            original_hlo->precision_config()));
      } else {
        if (!windowed_at_contracting_dims && !windowed_at_batch_dims) {
          if (lhs_concat_dim != -1) {
            std::vector<int64> new_dims(dot_rhs->shape().dimensions().begin(),
                                        dot_rhs->shape().dimensions().end());
            new_dims.push_back(1);
            dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::MakeShape(dot_rhs->shape().element_type(),
                                     new_dims),
                dot_rhs));
          }
          if (rhs_concat_dim != -1) {
            std::vector<int64> new_dims(dot_lhs->shape().dimensions().begin(),
                                        dot_lhs->shape().dimensions().end());
            new_dims.push_back(1);
            dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::MakeShape(dot_lhs->shape().element_type(),
                                     new_dims),
                dot_lhs));
          }
        }

        dot = body_b.AddInstruction(HloInstruction::CreateConvolve(
            new_dot_shape, dot_lhs, dot_rhs,
            original_dot->feature_group_count(),
            original_dot->batch_group_count(),
            GenNewWindow(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
                         rhs_concat_dim, windowed_at_contracting_dims,
                         windowed_at_batch_dims),
            GenNewConvDNums(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
                            rhs_concat_dim, windowed_at_contracting_dims,
                            windowed_at_batch_dims,
                            indices_map.lhs_to_output_indices,
                            indices_map.rhs_to_output_indices, new_dot_shape),
            original_dot->precision_config()));
      }
      VLOG(2) << dot->ToString();

      if (windowed_at_contracting_dims) {
        if (original_hlo->opcode() != HloOpcode::kDot) {
          // Reshape to the original sharded dot shape.
          dot = body_b.AddInstruction(
              HloInstruction::CreateReshape(original_sharded_dot_shape, dot));
        }

        // Accumulate the partial output to the result buffer.
        o = body_b.AddInstruction(
            HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
      } else {
        // The windowing operand is partitioned along batch/non-contracting
        // dimensions, so we need a dynamic-update-slice to save the partial
        // output in the result buffer.
        auto slice_shape = dot->shape();
        auto slice_dim =
            lhs_concat_dim != -1
                ? indices_map.lhs_to_output_indices[lhs_concat_dim]
                : indices_map.rhs_to_output_indices[rhs_concat_dim];
        slice_shape.set_dimensions(slice_dim, 1);
        std::vector<int64> ccw_start_indices(dot->shape().rank(), 0);
        std::vector<int64> cw_start_indices(dot->shape().rank(), 0);
        cw_start_indices[slice_dim] = 1;
        auto ccw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
            slice_shape, dot, ccw_start_indices, slice_shape.dimensions(),
            std::vector<int64>(dot->shape().rank(), 1)));
        auto cw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
            slice_shape, dot, cw_start_indices, dot->shape().dimensions(),
            std::vector<int64>(dot->shape().rank(), 1)));

        std::vector<int64> reshaped_dims(
            original_sharded_dot_shape.dimensions().begin(),
            original_sharded_dot_shape.dimensions().end());
        reshaped_dims[slice_dim] /= 2;
        ccw_dot = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(ccw_dot->shape().element_type(),
                                 reshaped_dims),
            ccw_dot));
        cw_dot = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(cw_dot->shape().element_type(),
                                 reshaped_dims),
            cw_dot));

        if (operands_sharded_at_contracting_dims) {
          // Accumulate the partial output to the result buffer.
          o = body_b.AddInstruction(HloInstruction::CreateBinary(
              o->shape(), HloOpcode::kAdd, o, ccw_dot));
          cw_cp_output = body_b.AddInstruction(HloInstruction::CreateBinary(
              o->shape(), HloOpcode::kAdd, cw_cp_output, cw_dot));
        } else {
          auto ccw_offsets = MakePartitionOffsets(
              o->shape(),
              windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
                                 : *rhs_sharding_transposed_to_match_output,
              ccw_data_partition_id, &body_b);
          auto cw_offsets = MakePartitionOffsets(
              o->shape(),
              windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
                                 : *rhs_sharding_transposed_to_match_output,
              cw_data_partition_id, &body_b);
          o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
              o->shape(), o, ccw_dot, ccw_offsets));
          o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
              o->shape(), o, cw_dot, cw_offsets));
        }
      }

      std::vector<HloInstruction*> partial_results;
      partial_results.push_back(o);
      partial_results.push_back(cw_cp_output);
      return partial_results;
    };

    // Generate the partial result used by the unidirectional algorithm.
    auto get_partial_unid_result =
        [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
            HloInstruction* i) -> StatusOr<HloInstruction*> {
      auto partition_id =
          lhs.state().collective_ops_creator.create_partition_id(&body_b);
      auto data_partition_id =
          body_b.AddInstruction(HloInstruction::CreateBinary(
              i->shape(), HloOpcode::kAdd, i, partition_id));
      auto partition_count =
          body_b.AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::CreateR0<uint32>(num_partitions)));
      data_partition_id = body_b.AddInstruction(
          HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
                                       data_partition_id, partition_count));
      auto dot_lhs = l;
      auto dot_rhs = r;
      if (windowed_at_contracting_dims || windowed_at_batch_dims ||
          operands_sharded_at_contracting_dims) {
        // Slice the matching operand according to the partitioned dimensions
        // on the windowed operand or the output.
        auto slice_operand = !windowed_op_is_lhs ? l : r;
        // We do this by treating the matching operand as replicated, and
        // resharding it to match the windowed operand or the output.
        slice_operand->set_sharding(HloSharding::Replicate());
        auto state = lhs.state();
        state.b = &body_b;
        state.partition_id = data_partition_id;
        state.reshard_cache->per_hlo_cache.erase(slice_operand);
        auto slice =
            PartitionedHlo(slice_operand, slice_operand->shape(), state)
                .Reshard(*slice_sharding)
                .hlo();
        slice_operand->clear_sharding();
        if (!windowed_op_is_lhs) {
          dot_lhs = slice;
        } else {
          dot_rhs = slice;
        }
      }
      TF_ASSIGN_OR_RETURN(
          auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window));
      if (windowed_at_contracting_dims ||
          operands_sharded_at_contracting_dims) {
        // Accumulate the partial output to the result buffer.
        o = body_b.AddInstruction(
            HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
      } else {
        // The windowing operand is partitioned along batch/non-contracting
        // dimensions, so we need a dynamic-update-slice to save the partial
        // output in the result buffer.
        auto offsets = MakePartitionOffsets(
            o->shape(),
            windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
                               : *rhs_sharding_transposed_to_match_output,
            data_partition_id, &body_b);
        o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
            o->shape(), o, dot, offsets));
      }
      return o;
    };

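    // The while-loop state is a 5-tuple: (lhs, rhs, result buffer, extra
    // buffer, iteration counter).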
1215 auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
1216 /*parameter_number=*/0,
1217 ShapeUtil::MakeTupleShape({lhs_hlo->shape(), rhs_hlo->shape(),
1218 result_buffer->shape(),
1219 extra_buffer->shape(), iteration->shape()}),
1220 "param"));
1221 auto l = body_b.AddInstruction(
1222 HloInstruction::CreateGetTupleElement(lhs_hlo->shape(), param, 0));
1223 auto r = body_b.AddInstruction(
1224 HloInstruction::CreateGetTupleElement(rhs_hlo->shape(), param, 1));
1225 auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1226 result_buffer->shape(), param, 2));
1227 auto extra_inout = body_b.AddInstruction(
1228 HloInstruction::CreateGetTupleElement(extra_buffer->shape(), param, 3));
1229 auto i = body_b.AddInstruction(
1230 HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
1231
1232 // The bidirectional collective permute implementation has loop unrolling
1233 // of degree 2, so num_partitions is required to be a multiple of 4.
1234 if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1235 std::vector<std::pair<int64, int64>> ccw_sd_pairs(num_partitions);
1236 for (int64_t source = 0; source < num_partitions; ++source) {
1237 // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1238 ccw_sd_pairs[source] = {source,
1239 (source - 1 + num_partitions) % num_partitions};
1240 }
1241 std::vector<std::pair<int64, int64>> cw_sd_pairs(num_partitions);
1242 for (int64_t source = 0; source < num_partitions; ++source) {
1243 // 0 -> 1, 1 -> 2, 2 -> 3, ...
1244 cw_sd_pairs[source] = {source, (source + 1) % num_partitions};
1245 }
1246
1247 // Even number iteration.
1248 auto next_l = l;
1249 auto next_r = r;
1250 auto ccw_cp_input = operands_sharded_at_contracting_dims ? o
1251 : windowed_op_is_lhs ? l
1252 : r;
1253 auto ccw_cp_output =
1254 lhs.state()
1255 .collective_ops_creator.create_cross_partition_collective_permute(
1256 &body_b, ccw_cp_input, ccw_sd_pairs,
1257 (*lhs.state().next_channel_id)++);
1258 if (operands_sharded_at_contracting_dims) {
1259 o = ccw_cp_output;
1260 } else if (windowed_op_is_lhs) {
1261 next_l = ccw_cp_output;
1262 } else {
1263 next_r = ccw_cp_output;
1264 }
1265 auto cw_cp_input = extra_inout;
1266 auto cw_cp_output =
1267 lhs.state()
1268 .collective_ops_creator.create_cross_partition_collective_permute(
1269 &body_b, cw_cp_input, cw_sd_pairs,
1270 (*lhs.state().next_channel_id)++);
1271
1272 TF_ASSIGN_OR_RETURN(
1273 auto outputs,
1274 get_partial_bid_results(l, r, o, extra_inout, cw_cp_output, i));
1275 o = outputs[0];
1276 cw_cp_output = outputs[1];
1277
1278 // ++i
1279 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1280 i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1281
1282 // Odd number iteration.
1283 auto second_next_l = next_l;
1284 auto second_next_r = next_r;
1285 ccw_cp_input = operands_sharded_at_contracting_dims ? o
1286 : windowed_op_is_lhs ? next_l
1287 : next_r;
1288 ccw_cp_output =
1289 lhs.state()
1290 .collective_ops_creator.create_cross_partition_collective_permute(
1291 &body_b, ccw_cp_input, ccw_sd_pairs,
1292 (*lhs.state().next_channel_id)++);
1293 if (operands_sharded_at_contracting_dims) {
1294 o = ccw_cp_output;
1295 } else if (windowed_op_is_lhs) {
1296 second_next_l = ccw_cp_output;
1297 } else {
1298 second_next_r = ccw_cp_output;
1299 }
1300 auto next_cw_cp_input = cw_cp_output;
1301 auto next_cw_cp_output =
1302 lhs.state()
1303 .collective_ops_creator.create_cross_partition_collective_permute(
1304 &body_b, next_cw_cp_input, cw_sd_pairs,
1305 (*lhs.state().next_channel_id)++);
1306
1307 TF_ASSIGN_OR_RETURN(
1308 outputs, get_partial_bid_results(next_l, next_r, o, cw_cp_output,
1309 next_cw_cp_output, i));
1310 o = outputs[0];
1311 next_cw_cp_output = outputs[1];
1312
1313 // ++i
1314 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1315 i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1316
1317 body_b.AddInstruction(HloInstruction::CreateTuple(
1318 {second_next_l, second_next_r, o, next_cw_cp_output, i}));
1319
1320 } else if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1321 if (operands_sharded_at_contracting_dims) {
1322 std::vector<std::pair<int64, int64>> output_sd_pairs(num_partitions);
1323 for (int64_t source = 0; source < num_partitions; ++source) {
1324 // 0 -> n-2, 1 -> n-1, 2 -> 0, ...
1325 output_sd_pairs[source] = {
1326 source, (source - 2 + num_partitions) % num_partitions};
1327 }
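      // Rotate the accumulated output by two partitions per loop body (e.g.
      // with num_partitions = 4 the pairs are {0->2, 1->3, 2->0, 3->1}),
      // matching the two partial results computed below.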
1328
1329 o = lhs.state()
1330 .collective_ops_creator
1331 .create_cross_partition_collective_permute(
1332 &body_b, o, output_sd_pairs,
1333 (*lhs.state().next_channel_id)++);
1334
1335 TF_ASSIGN_OR_RETURN(extra_inout,
1336 get_partial_unid_result(l, r, extra_inout, i));
1337
1338 extra_inout = lhs.state()
1339 .collective_ops_creator
1340 .create_cross_partition_collective_permute(
1341 &body_b, extra_inout, output_sd_pairs,
1342 (*lhs.state().next_channel_id)++);
1343
1344 // i+2
1345 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1346 i->shape(), HloOpcode::kAdd, i,
1347 body_b.AddInstruction(HloInstruction::CreateConstant(
1348 LiteralUtil::CreateR0<uint32>(2)))));
1349 auto real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1350 i->shape(), HloOpcode::kAdd, i,
1351 body_b.AddInstruction(HloInstruction::CreateConstant(
1352 LiteralUtil::CreateR0<uint32>(1)))));
1353
1354 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1355 body_b.AddInstruction(
1356 HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1357 } else {
1358 std::vector<std::pair<int64, int64>> sd_pairs(num_partitions);
1359 for (int64_t source = 0; source < num_partitions; ++source) {
1360 // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1361 sd_pairs[source] = {source,
1362 (source - 1 + num_partitions) % num_partitions};
1363 }
1364
1365 // Even number iteration.
1366 auto next_l = l;
1367 auto next_r = r;
1368 auto cp_input = windowed_op_is_lhs ? l : r;
1369 auto cp_output = lhs.state()
1370 .collective_ops_creator
1371 .create_cross_partition_collective_permute(
1372 &body_b, cp_input, sd_pairs,
1373 (*lhs.state().next_channel_id)++);
1374 if (windowed_op_is_lhs) {
1375 next_l = cp_output;
1376 } else {
1377 next_r = cp_output;
1378 }
1379 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, i));
1380
1381 // ++i
1382 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1383 i->shape(), HloOpcode::kAdd, i,
1384 body_b.AddInstruction(HloInstruction::CreateConstant(
1385 LiteralUtil::CreateR0<uint32>(1)))));
1386
1387 // Odd number iteration.
1388 auto second_next_l = next_l;
1389 auto second_next_r = next_r;
1390 cp_input = windowed_op_is_lhs ? next_l : next_r;
1391 cp_output = lhs.state()
1392 .collective_ops_creator
1393 .create_cross_partition_collective_permute(
1394 &body_b, cp_input, sd_pairs,
1395 (*lhs.state().next_channel_id)++);
1396 if (windowed_op_is_lhs) {
1397 second_next_l = cp_output;
1398 } else {
1399 second_next_r = cp_output;
1400 }
1401 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(next_l, next_r, o, i));
1402
1403 // ++i
1404 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1405 i->shape(), HloOpcode::kAdd, i,
1406 body_b.AddInstruction(HloInstruction::CreateConstant(
1407 LiteralUtil::CreateR0<uint32>(1)))));
1408
1409 body_b.AddInstruction(HloInstruction::CreateTuple(
1410 {second_next_l, second_next_r, o, extra_inout, i}));
1411 }
1412 } else {
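    // Default (non-unrolled) loop body: compute one partial result per
    // iteration.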
1413 auto real_i = i;
1414 if (operands_sharded_at_contracting_dims) {
1415       // For the reduce-scatter case, start from data_partition_id + 1 so that
1416       // the final data shard held by each partition has a data_partition_id
1417       // equal to that partition's partition_id.
1418 real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1419 real_i->shape(), HloOpcode::kAdd, real_i,
1420 CreateOne(real_i->shape(), &body_b)));
1421 }
1422 TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1423
1424 // ++i
1425 i = body_b.AddInstruction(HloInstruction::CreateBinary(
1426 i->shape(), HloOpcode::kAdd, i,
1427 body_b.AddInstruction(HloInstruction::CreateConstant(
1428 LiteralUtil::CreateR0<uint32>(1)))));
1429 auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare(
1430 ShapeUtil::MakeShape(PRED, {}), i,
1431 body_b.AddInstruction(HloInstruction::CreateConstant(
1432 LiteralUtil::CreateR0<uint32>(num_partitions))),
1433 ComparisonDirection::kLt));
1434 // Collective-permute for the next window. We don't need it for the last
1435 // iteration, so we use a conditional around the collective-permute.
1436 HloInstruction* conditional;
1437 {
1438 SpmdBuilder cp_b("window_collective_permute", original_hlo);
1439 {
1440 auto p = cp_b.AddInstruction(HloInstruction::CreateParameter(
1441 0,
1442 operands_sharded_at_contracting_dims ? o->shape()
1443 : windowed_op_is_lhs ? l->shape()
1444 : r->shape(),
1445 "window"));
1446 std::vector<std::pair<int64, int64>> sd_pairs(num_partitions);
1447 for (int64_t source = 0; source < num_partitions; ++source) {
1448 // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1449 sd_pairs[source] = {source,
1450 (source - 1 + num_partitions) % num_partitions};
1451 }
1452 lhs.state()
1453 .collective_ops_creator.create_cross_partition_collective_permute(
1454 &cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++);
1455 }
1456 SpmdBuilder ncp_b("last_iteration_noop", original_hlo);
1457 {
1458 ncp_b.AddInstruction(HloInstruction::CreateParameter(
1459 0,
1460 operands_sharded_at_contracting_dims ? o->shape()
1461 : windowed_op_is_lhs ? l->shape()
1462 : r->shape(),
1463 "window"));
1464 }
1465 conditional = body_b.AddInstruction(HloInstruction::CreateConditional(
1466 operands_sharded_at_contracting_dims ? o->shape()
1467 : windowed_op_is_lhs ? l->shape()
1468 : r->shape(),
1469 has_more,
1470 operands_sharded_at_contracting_dims ? o
1471 : windowed_op_is_lhs ? l
1472 : r,
1473 module->AddEmbeddedComputation(cp_b.Build()),
1474 operands_sharded_at_contracting_dims ? o
1475 : windowed_op_is_lhs ? l
1476 : r,
1477 module->AddEmbeddedComputation(ncp_b.Build())));
1478 }
1479 if (operands_sharded_at_contracting_dims) {
1480 o = conditional;
1481 } else if (windowed_op_is_lhs) {
1482 l = conditional;
1483 } else {
1484 r = conditional;
1485 }
1486 body_b.AddInstruction(
1487 HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1488 }
1489
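  // Loop condition: keep iterating while the counter is below num_partitions,
  // or num_partitions / 2 for the bidirectional variant.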
1490 SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo);
1491 auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
1492 /*parameter_number=*/0,
1493 ShapeUtil::MakeTupleShape({lhs_hlo->shape(), rhs_hlo->shape(),
1494 result_buffer->shape(),
1495 extra_buffer->shape(), iteration->shape()}),
1496 "param"));
1497 auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1498 iteration->shape(), cond_param, 4));
1499 int64_t adapted_num_partitions =
1500 (options.bidirectional_windowed_einsum && num_partitions % 4 == 0)
1501 ? num_partitions / 2
1502 : num_partitions;
1503 cond_b.AddInstruction(HloInstruction::CreateCompare(
1504 ShapeUtil::MakeShape(PRED, {}), cond_i,
1505 cond_b.AddInstruction(HloInstruction::CreateConstant(
1506 LiteralUtil::CreateR0<uint32>(adapted_num_partitions))),
1507 ComparisonDirection::kLt));
1508 auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
1509 cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
1510 module->AddEmbeddedComputation(body_b.Build()),
1511 b->AddInstruction(HloInstruction::CreateTuple(
1512 {lhs_hlo, rhs_hlo, result_buffer, extra_buffer, iteration}))));
1513 windowed_dot_general_loops->push_back(
1514 {while_loop, windowed_op_is_lhs ? 0 : 1, windowed_at_contracting_dims,
1515 windowed_at_batch_dims, operands_sharded_at_contracting_dims,
1516 num_partitions, GetLoopReplicaGroups(while_loop)});
1517 auto result = b->AddInstruction(HloInstruction::CreateGetTupleElement(
1518 result_buffer->shape(), while_loop, 2));
1519 if (((options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
1520 (options.unroll_windowed_einsum && num_partitions % 2 == 0)) &&
1521 operands_sharded_at_contracting_dims) {
1522 std::vector<std::pair<int64, int64>> extra_sd_pairs(num_partitions);
1523 for (int64_t source = 0; source < num_partitions; ++source) {
1524 // 0 -> 1, 1 -> 2, 2 -> 3, ...
1525 extra_sd_pairs[source] = {source, (source + 1) % num_partitions};
1526 }
1527 auto extra_result =
1528 b->AddInstruction(HloInstruction::CreateGetTupleElement(
1529 extra_buffer->shape(), while_loop, 3));
1530 if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1531 extra_result = lhs.state()
1532 .collective_ops_creator
1533 .create_cross_partition_collective_permute(
1534 b, extra_result, extra_sd_pairs,
1535 (*lhs.state().next_channel_id)++);
1536 }
1537 if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1538 result = lhs.state()
1539 .collective_ops_creator
1540 .create_cross_partition_collective_permute(
1541 b, result, extra_sd_pairs,
1542 (*lhs.state().next_channel_id)++);
1543 }
1544 result = b->AddInstruction(HloInstruction::CreateBinary(
1545 result->shape(), HloOpcode::kAdd, result, extra_result));
1546 }
1547 if (!ShapeUtil::Compatible(padded_result_buffer_shape,
1548 unpadded_result_buffer_shape)) {
1549 result = b->AddInstruction(HloInstruction::CreateSlice(
1550 unpadded_result_buffer_shape, result,
1551 std::vector<int64>(padded_result_buffer_shape.rank(), 0),
1552 unpadded_result_buffer_shape.dimensions(),
1553 std::vector<int64>(padded_result_buffer_shape.rank(), 1)));
1554 }
1555 return result;
1556 };
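  // Decide whether a windowed einsum loop is profitable based on operand and
  // output sizes, their shardings, and the configured size threshold.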
1557 absl::optional<WindowedEinsumConfig> e_config =
1558 GetWindowedEinsumConfiguration(
1559 num_partitions, output_lhs_non_contracting_partitions,
1560 output_rhs_non_contracting_partitions, rhs_contracting_partitions,
1561 rhs_non_contracting_partitions, rhs_batch_partitions,
1562 lhs_contracting_partitions, lhs_non_contracting_partitions,
1563 lhs_batch_partitions, ShapeSizeInBytes(rhs.base_shape()),
1564 ShapeSizeInBytes(lhs.base_shape()),
1565 ShapeSizeInBytes(output_base_shape),
1566 options.threshold_for_windowed_einsum_mib,
1567 output_sharding_transposed_to_match_lhs,
1568 output_sharding_transposed_to_match_rhs, lhs_sharding, rhs_sharding);
1569 if (e_config) {
1570 return emit_windowed_dot_general(*e_config);
1571 }
1572
1573 {
1574     // Try batch-parallel partitioning by resharding one operand, allowing an
1574     // all-reduce on the output.
1575 TF_ASSIGN_OR_RETURN(
1576 HloInstruction * partitioned_dot,
1577 try_emit_output_batch_partitioned_einsum_with_reshard(true));
1578 if (partitioned_dot) {
1579 return partitioned_dot;
1580 }
1581 }
1582
1583 // LHS and RHS have the same partitioned contracting dimensions.
1584 if (lhs_contracting_partitions == rhs_contracting_partitions &&
1585 lhs_contracting_partitions == num_partitions) {
1586 auto zero = b->AddInstruction(HloInstruction::CreateConstant(
1587 LiteralUtil::Zero(output_base_shape.element_type())));
1588 // Pad both sides with zero, since NaN at one side cannot be masked by zero
1589 // on the other side.
1590 if (ShapeSizeInBytes(lhs.base_shape()) <
1591 ShapeSizeInBytes(rhs.base_shape())) {
1592 lhs =
1593 lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
1594 rhs = rhs.PadWithValue(zero);
1595 } else {
1596 lhs = lhs.PadWithValue(zero);
1597 rhs =
1598 rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
1599 }
1600 TF_ASSIGN_OR_RETURN(
1601 auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1602 std::vector<int64> lhs_contracting_dims;
1603 lhs_contracting_dims.reserve(lhs.base_shape().rank());
1604 for (const auto& cd : dims_mapping.contracting_dims) {
1605 lhs_contracting_dims.push_back(cd.lhs);
1606 }
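    // Each partition has computed a partial sum over its slice of the
    // contracting dimension; all-reduce along the contracting-dim sharding to
    // combine the partial results.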
1607 auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
1608 b, dot, lhs.sharding(), lhs.state().next_channel_id,
1609 lhs_contracting_dims, lhs.state().collective_ops_creator,
1610 MakeBinaryAdd(output_base_shape.element_type(), module));
1611 ar->set_sharding(HloSharding::Replicate());
1612 return PartitionedHlo(ar, output_base_shape, lhs.state())
1613 .Reshard(output_sharding)
1614 .hlo();
1615 }
1616
1617 // LHS and output have the same partitioned non-contracting dimensions.
1618 if (lhs_non_contracting_partitions == num_partitions &&
1619 output_lhs_non_contracting_partitions == num_partitions &&
1620 lhs_sharding_transposed_to_match_output == output_sharding) {
1621 auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
1622 TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated,
1623 b, conv_window));
1624 return dot;
1625 }
1626
1627 // RHS and output have the same partitioned non-contracting dimensions.
1628 if (rhs_non_contracting_partitions == num_partitions &&
1629 output_rhs_non_contracting_partitions == num_partitions &&
1630 rhs_sharding_transposed_to_match_output == output_sharding) {
1631 auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
1632 TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(),
1633 b, conv_window));
1634 return dot;
1635 }
1636
1637 if (may_reshard_without_detecting_match) {
1638 // Output is batch partitioned.
1639 if (output_batch_partitions == num_partitions) {
1640 auto resharded_lhs =
1641 lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1642 auto resharded_rhs =
1643 rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1644 TF_ASSIGN_OR_RETURN(
1645 auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(),
1646 b, conv_window));
1647 return dot;
1648 }
1649 // Output is partitioned along LHS non-contracting dimensions.
1650 if (output_lhs_non_contracting_partitions == num_partitions) {
1651 auto resharded_lhs =
1652 lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1653 auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
1654 TF_ASSIGN_OR_RETURN(
1655 auto dot, create_sharded_dot(resharded_lhs.hlo(),
1656 replicated_rhs.hlo(), b, conv_window));
1657 return dot;
1658 }
1659 // Output is partitioned along RHS non-contracting dimensions.
1660 if (output_rhs_non_contracting_partitions == num_partitions) {
1661 auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
1662 auto resharded_rhs =
1663 rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1664 TF_ASSIGN_OR_RETURN(
1665 auto dot, create_sharded_dot(replicated_lhs.hlo(),
1666 resharded_rhs.hlo(), b, conv_window));
1667 return dot;
1668 }
1669 }
1670
1671 // Returns true if it is beneficial to reshard the operand at `operand_idx`
1672 // across the contracting dimension.
1673 const auto should_partition_contracting_dim = [&](int64_t operand_idx) {
1674 if (!output_sharding.IsReplicated()) {
1675 return false;
1676 }
1677
1678 if (operand_idx == 0) {
1679 // If LHS and output are replicated, we compare the cost of all-gather
1680 // on RHS vs all-reduce on the output.
1681 return (rhs_contracting_partitions == num_partitions) &&
1682 lhs.sharding().IsReplicated() &&
1683 ShapeUtil::ElementsIn(rhs.base_shape()) >
1684 ShapeUtil::ElementsIn(output_base_shape);
1685 } else {
1686 return (lhs_contracting_partitions == num_partitions) &&
1687 rhs.sharding().IsReplicated() &&
1688 ShapeUtil::ElementsIn(lhs.base_shape()) >
1689 ShapeUtil::ElementsIn(output_base_shape);
1690 }
1691 };
1692
1693 // When the output is replicated and one of the operands is partitioned along
1694 // contracting dimension, align the other operand to be partitioned along
1695 // the contracting dimensions.
1696 if (output_sharding.IsReplicated() && (should_partition_contracting_dim(0) ||
1697 should_partition_contracting_dim(1))) {
1698 auto zero = b->AddInstruction(HloInstruction::CreateConstant(
1699 LiteralUtil::Zero(output_base_shape.element_type())));
1700 if (should_partition_contracting_dim(0)) {
1701 lhs =
1702 lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
1703 rhs = rhs.PadWithValue(zero);
1704 } else {
1705 lhs = lhs.PadWithValue(zero);
1706 rhs =
1707 rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
1708 }
1709 TF_ASSIGN_OR_RETURN(
1710 auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1711
1712 std::vector<int64> lhs_contracting_dims;
1713 lhs_contracting_dims.reserve(lhs.base_shape().rank());
1714 for (const auto& cd : dims_mapping.contracting_dims) {
1715 lhs_contracting_dims.push_back(cd.lhs);
1716 }
1717 return lhs.state().partitioner->AllReduceAlongShardingDims(
1718 b, dot, lhs.sharding(), lhs.state().next_channel_id,
1719 lhs_contracting_dims, lhs.state().collective_ops_creator,
1720 MakeBinaryAdd(output_base_shape.element_type(), module));
1721 }
1722 return nullptr;
1723 }
1724
1725 StatusOr<HloInstruction*> PartitionDot(
1726 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
1727 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
1728 int64_t num_partitions,
1729 const std::function<StatusOr<HloInstruction*>(
1730 HloInstruction*, HloInstruction*, SpmdBuilder*,
1731 const Window& conv_window)>& create_sharded_dot,
1732 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
1733 const SpmdPartitionerOptions& options, SpmdBuilder* b,
1734 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
1735 windowed_dot_general_loops);
1736
1737 StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
1738 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
1739 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
1740 int64_t num_partitions, int64_t lhs_contracting_partitions,
1741 int64_t rhs_contracting_partitions, int64_t lhs_non_contracting_partitions,
1742 int64_t rhs_non_contracting_partitions,
1743 const std::function<StatusOr<HloInstruction*>(
1744 HloInstruction*, HloInstruction*, SpmdBuilder*,
1745 const Window& conv_window)>& create_sharded_dot,
1746 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
1747 bool require_matching_devices_to_group,
1748 const SpmdPartitionerOptions& options, SpmdBuilder* b,
1749 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
1750 windowed_dot_general_loops) {
1751 std::vector<std::pair<HloInstruction*, HloSharding>>
1752 top_level_sharding_to_reset;
1753 auto cleaner = tensorflow::gtl::MakeCleanup([&] {
1754 for (auto& to_reset : top_level_sharding_to_reset) {
1755 to_reset.first->set_sharding(to_reset.second);
1756 }
1757 });
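  // Grouping temporarily overwrites shardings on the operand HLOs; the
  // cleanup above restores the recorded top-level shardings on exit.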
1758 std::vector<int64> lhs_dims;
1759 std::vector<int64> rhs_dims;
1760 std::vector<int64> output_dims;
1761 auto lhs_sharding_dims_adjusted_to_output =
1762 lhs.sharding().IsReplicated()
1763 ? std::vector<int64>(lhs.base_shape().rank(), 1)
1764 : lhs.sharding().tile_assignment().dimensions();
1765 auto rhs_sharding_dims_adjusted_to_output =
1766 rhs.sharding().IsReplicated()
1767 ? std::vector<int64>(rhs.base_shape().rank(), 1)
1768 : rhs.sharding().tile_assignment().dimensions();
1769 auto output_sharding_dims_adjusted_to_lhs =
1770 output_sharding.tile_assignment().dimensions();
1771 bool lhs_rhs_dims_matching = true;
1772 for (const auto& dim : dims_mapping.batch_dims) {
1773 lhs_dims.push_back(dim.lhs);
1774 rhs_dims.push_back(dim.rhs);
1775 output_dims.push_back(dim.output);
1776 if (lhs_sharding_dims_adjusted_to_output[dim.lhs] !=
1777 rhs_sharding_dims_adjusted_to_output[dim.rhs]) {
1778 lhs_rhs_dims_matching = false;
1779 }
1780 lhs_sharding_dims_adjusted_to_output[dim.lhs] =
1781 output_sharding.tile_assignment().dim(dim.output);
1782 rhs_sharding_dims_adjusted_to_output[dim.rhs] =
1783 output_sharding.tile_assignment().dim(dim.output);
1784 output_sharding_dims_adjusted_to_lhs[dim.output] =
1785 lhs.sharding().tile_assignment().dim(dim.lhs);
1786 }
1787 if (require_matching_devices_to_group && lhs_rhs_dims_matching) {
1788 lhs_rhs_dims_matching =
1789 rhs.sharding() == UngroupSharding(AlignGroupsWith(
1790 GroupShardingOnDims(rhs.sharding(), rhs_dims),
1791 GroupShardingOnDims(lhs.sharding(), lhs_dims)));
1792 }
1793 auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
1794 PartitionedHlo per_group_lhs = lhs;
1795 PartitionedHlo per_group_rhs = rhs;
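  // When LHS and RHS are partitioned the same way on the batch dims, group
  // both operands directly; otherwise reshard each operand to match the
  // output's per-group batch partitioning.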
1796 if (lhs_rhs_dims_matching) {
1797 auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims);
1798 auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims);
1799 if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
1800 ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
1801 rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped);
1802 rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
1803 } else {
1804 lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped);
1805 lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
1806 }
1807 auto reshaped_output_tiling = output_sharding.tile_assignment();
1808 reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs);
1809 output_grouped = AlignGroupsWith(
1810 GroupShardingOnDims(
1811 output_sharding.ReplicateOnLastTileDim()
1812 ? HloSharding::PartialTile(reshaped_output_tiling)
1813 : HloSharding::Tile(reshaped_output_tiling),
1814 output_dims),
1815 lhs_grouped);
1816 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1817 lhs.state(), lhs_grouped.device_groups, b);
1818 top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs.sharding());
1819 lhs.hlo()->set_sharding(lhs_grouped.sharding);
1820 top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs.sharding());
1821 rhs.hlo()->set_sharding(rhs_grouped.sharding);
1822 CHECK(lhs.hlo() != rhs.hlo() ||
1823 lhs_grouped.sharding == rhs_grouped.sharding);
1824 per_group_lhs = PartitionedHlo(
1825 lhs.hlo(), GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
1826 per_group_partitioner_state);
1827 per_group_rhs = PartitionedHlo(
1828 rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
1829 per_group_partitioner_state);
1830 } else {
1831 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1832 lhs.state(), output_grouped.device_groups, b);
1833 auto reshard_to_output_batch =
1834 [&](PartitionedHlo operand, absl::Span<const int64> batch_dims,
1835 absl::Span<const int64> contracting_dims,
1836 absl::Span<const int64> non_contracting_dims,
1837 int64_t contracting_dim_partitions,
1838 int64_t non_contracting_dim_partitions,
1839 int64_t other_contracting_dim_partitions,
1840 std::vector<int64>* sharding_dims_adjusted_to_output)
1841 -> absl::optional<PartitionedHlo> {
1842 if (operand.sharding().IsTileMaximal()) {
1843 auto partially_sharded = PerGroupSliceFromReplicated(
1844 operand.Replicate().hlo(), operand.state().partition_id,
1845 output_grouped.device_groups, batch_dims,
1846 output_grouped.group_dim_sizes, b);
1847 partially_sharded->set_sharding(HloSharding::Replicate());
1848 return PartitionedHlo(partially_sharded, partially_sharded->shape(),
1849 per_group_partitioner_state);
1850 }
1851 auto reshaped_tiling = operand.sharding().tile_assignment();
1852     // It's possible that the operand, although tiled, is not initially sharded
1853     // on the batch dimensions in the same way as the output. In that case,
1854     // sharding_dims_adjusted_to_output may contain more partitions than there
1855     // are available devices, so we remove partitioning on the other
1856     // dimensions.
1857 if (Product(*sharding_dims_adjusted_to_output) >
1858 reshaped_tiling.num_elements()) {
1859 if (Product(*sharding_dims_adjusted_to_output) %
1860 reshaped_tiling.num_elements() !=
1861 0) {
1862 return absl::nullopt;
1863 }
1864 int64_t ratio = Product(*sharding_dims_adjusted_to_output) /
1865 reshaped_tiling.num_elements();
1866 if (operand.sharding().ReplicateOnLastTileDim() &&
1867 reshaped_tiling.dimensions().back() % ratio == 0) {
1868 sharding_dims_adjusted_to_output->back() /= ratio;
1869 if (sharding_dims_adjusted_to_output->back() == 1) {
1870 sharding_dims_adjusted_to_output->pop_back();
1871 }
1872 } else if (ratio == non_contracting_dim_partitions &&
1873 (ratio != contracting_dim_partitions ||
1874 contracting_dim_partitions ==
1875 other_contracting_dim_partitions)) {
1876 for (int64_t dim : non_contracting_dims) {
1877 (*sharding_dims_adjusted_to_output)[dim] = 1;
1878 }
1879 } else if (ratio == contracting_dim_partitions) {
1880 for (int64_t dim : contracting_dims) {
1881 (*sharding_dims_adjusted_to_output)[dim] = 1;
1882 }
1883 } else {
1884 return absl::nullopt;
1885 }
1886 }
1887 // If the operand is initially sharded more ways than the output in the
1888 // batch dimensions, sharding_dims_adjusted_to_output currently contains
1889 // fewer partitions than available devices. We do not handle this case.
1890 if (Product(*sharding_dims_adjusted_to_output) <
1891 reshaped_tiling.num_elements()) {
1892 return absl::nullopt;
1893 }
1894 reshaped_tiling.Reshape(*sharding_dims_adjusted_to_output);
1895 auto grouped = AlignGroupsWith(
1896 GroupShardingOnDims(operand.base_shape().rank() <
1897 sharding_dims_adjusted_to_output->size()
1898 ? HloSharding::PartialTile(reshaped_tiling)
1899 : HloSharding::Tile(reshaped_tiling),
1900 batch_dims),
1901 output_grouped);
1902 if (require_matching_devices_to_group &&
1903 operand.sharding() != UngroupSharding(grouped)) {
1904 return absl::nullopt;
1905 }
1906 auto resharded = operand.Reshard(UngroupSharding(grouped));
1907 top_level_sharding_to_reset.emplace_back(resharded.hlo(),
1908 resharded.sharding());
1909 resharded.hlo()->set_sharding(grouped.sharding);
1910 return PartitionedHlo(resharded.hlo(),
1911 GetPerGroupBaseShape(grouped, operand.base_shape()),
1912 per_group_partitioner_state);
1913 };
1914 std::vector<int64> lhs_contracting_dims;
1915 std::vector<int64> rhs_contracting_dims;
1916 lhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
1917 rhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
1918 for (const auto& dim : dims_mapping.contracting_dims) {
1919 lhs_contracting_dims.push_back(dim.lhs);
1920 rhs_contracting_dims.push_back(dim.rhs);
1921 }
1922 std::vector<int64> lhs_non_contracting_dims;
1923 std::vector<int64> rhs_non_contracting_dims;
1924 lhs_non_contracting_dims.reserve(
1925 dims_mapping.lhs_non_contracting_dims.size());
1926 rhs_non_contracting_dims.reserve(
1927 dims_mapping.rhs_non_contracting_dims.size());
1928 for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
1929 lhs_non_contracting_dims.push_back(dim.lhs);
1930 }
1931 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
1932 rhs_non_contracting_dims.push_back(dim.rhs);
1933 }
1934 if (auto resharded = reshard_to_output_batch(
1935 lhs, lhs_dims, lhs_contracting_dims, lhs_non_contracting_dims,
1936 lhs_contracting_partitions, lhs_non_contracting_partitions,
1937 rhs_contracting_partitions,
1938 &lhs_sharding_dims_adjusted_to_output)) {
1939 per_group_lhs = *resharded;
1940 } else {
1941 return nullptr;
1942 }
1943 if (auto resharded = reshard_to_output_batch(
1944 rhs, rhs_dims, rhs_contracting_dims, rhs_non_contracting_dims,
1945 rhs_contracting_partitions, rhs_non_contracting_partitions,
1946 lhs_contracting_partitions,
1947 &rhs_sharding_dims_adjusted_to_output)) {
1948 per_group_rhs = *resharded;
1949 } else {
1950 return nullptr;
1951 }
1952 CHECK(lhs.hlo() != rhs.hlo() ||
1953 per_group_lhs.sharding() == per_group_rhs.sharding());
1954 }
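  // Recursively partition the dot within each group, using the per-group base
  // shapes and the remaining partition count.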
1955 TF_ASSIGN_OR_RETURN(
1956 auto dot,
1957 PartitionDot(per_group_lhs, per_group_rhs,
1958 GetPerGroupBaseShape(output_grouped, output_base_shape),
1959 output_grouped.sharding, dims_mapping,
1960 num_partitions / output_grouped.device_groups.size(),
1961 create_sharded_dot, conv_window, module, original_hlo,
1962 options, b, windowed_dot_general_loops));
1963 dot->set_sharding(UngroupSharding(output_grouped));
1964 return PartitionedHlo(dot, output_base_shape, lhs.state())
1965 .Reshard(output_sharding)
1966 .hlo();
1967 }
1968
1969 GroupedSharding GetNonContractingPartitionGroupedShardingForMatchedOperand(
1970 bool lhs_matching, const HloSharding& matching_sharding,
1971 const HloSharding& output_sharding,
1972 absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_dims) {
1973 std::vector<int64> matching_sharding_dims =
1974 matching_sharding.tile_assignment().dimensions();
1975 std::vector<int64> matching_dims;
1976 std::vector<int64> output_dims;
1977 // Make sure the partitioning on matching's non-contracting dimensions
1978 // defines the same device groups for both matching and output.
1979 for (const auto& dim : partitioned_dims) {
1980 int64_t md = lhs_matching ? dim.lhs : dim.rhs;
1981 matching_sharding_dims[md] =
1982 output_sharding.tile_assignment().dim(dim.output);
1983 matching_dims.push_back(md);
1984 output_dims.push_back(dim.output);
1985 }
1986 GroupedSharding output_grouped =
1987 GroupShardingOnDims(output_sharding, output_dims);
1988 Array<int64> reshaped_matching_tiling = matching_sharding.tile_assignment();
1989 reshaped_matching_tiling.Reshape(matching_sharding_dims);
1990 return AlignGroupsWith(
1991 GroupShardingOnDims(
1992 matching_sharding.ReplicateOnLastTileDim()
1993 ? HloSharding::PartialTile(reshaped_matching_tiling)
1994 : HloSharding::Tile(reshaped_matching_tiling),
1995 matching_dims),
1996 output_grouped);
1997 }
1998
1999 absl::optional<GroupedSharding>
2000 GetNonContractingPartitionGroupedShardingForOtherOperand(
2001 bool lhs_matching, const Shape& output_base_shape, const Shape& other_shape,
2002 int64_t other_contracting_partitions,
2003 int64_t other_non_contracting_partitions,
2004 int64_t matching_contracting_partitions,
2005 int64_t output_other_non_contracting_partitions,
2006 const HloSharding& other_sharding, const HloSharding& output_sharding,
2007 absl::Span<const DotConvDimsMapping::DimsMapping> matching_partitioned_dims,
2008 absl::Span<const DotConvDimsMapping::DimsMapping>
2009 other_non_contracting_dims,
2010 absl::Span<const DotConvDimsMapping::DimsMapping> other_contracting_dims) {
2011 int64_t group_count = 1;
2012 std::vector<int64> output_dims;
2013 for (const auto& dim : matching_partitioned_dims) {
2014 output_dims.push_back(dim.output);
2015 group_count *= output_sharding.tile_assignment().dim(dim.output);
2016 }
2017 GroupedSharding output_grouped =
2018 GroupShardingOnDims(output_sharding, output_dims);
2019 std::vector<int64> other_group_dims;
2020 if (other_sharding.ReplicateOnLastTileDim() &&
2021 other_sharding.tile_assignment().dimensions().back() % group_count == 0) {
2022 other_group_dims.push_back(
2023 other_sharding.tile_assignment().num_dimensions() - 1);
2024 } else {
2025 const bool may_replicate_other_contracting_dims =
2026 (other_contracting_partitions == group_count &&
2027 other_non_contracting_partitions ==
2028 output_other_non_contracting_partitions);
2029 const bool may_replicate_other_non_contracting_dims =
2030 group_count == other_non_contracting_partitions &&
2031 matching_contracting_partitions == other_contracting_partitions;
2032 if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
2033 other_sharding, output_grouped.device_groups)) {
2034 other_group_dims = std::move(*found_dims);
2035     } else if (may_replicate_other_contracting_dims &&
2036                (!may_replicate_other_non_contracting_dims ||
2037                 ShapeUtil::ByteSizeOf(other_shape) <=
2038                     ShapeUtil::ByteSizeOf(MakePartitionedShape(
2039                         output_base_shape, output_sharding)))) {
2040 for (const auto& dim : other_contracting_dims) {
2041 other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
2042 }
2043 } else if (may_replicate_other_non_contracting_dims) {
2044 for (const auto& dim : other_non_contracting_dims) {
2045 other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
2046 }
2047 } else {
2048 return absl::nullopt;
2049 }
2050 }
2051 if (other_group_dims.size() == 1 &&
2052 other_group_dims[0] ==
2053 other_sharding.tile_assignment().num_dimensions() - 1) {
2054 return AlignGroupsWith(
2055 GroupShardingOnDims(
2056 other_sharding, {other_group_dims[0]},
2057 {other_sharding.tile_assignment().dimensions().back() /
2058 group_count}),
2059 output_grouped, /*ignore_group_order=*/true);
2060
2061 } else if (!other_sharding.IsReplicated()) {
2062 return AlignGroupsWith(
2063 GroupShardingOnDims(other_sharding, other_group_dims), output_grouped,
2064 /*ignore_group_order=*/true);
2065 }
2066 return absl::nullopt;
2067 }
2068
2069 StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
2070 bool lhs_matching, PartitionedHlo matching, PartitionedHlo other,
2071 int64_t matching_contracting_partitions,
2072 int64_t other_contracting_partitions,
2073 absl::Span<const DotConvDimsMapping::DimsMapping>
2074 partitioned_non_contracting_dims,
2075 int64_t other_non_contracting_partitions,
2076 int64_t output_other_non_contracting_partitions,
2077 const Shape& output_base_shape, const HloSharding& output_sharding,
2078 const DotConvDimsMapping& dims_mapping, int64_t num_partitions,
2079 const std::function<StatusOr<HloInstruction*>(
2080 HloInstruction*, HloInstruction*, SpmdBuilder*,
2081 const Window& conv_window)>& create_sharded_dot,
2082 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2083 bool require_matching_devices_to_group,
2084 const SpmdPartitionerOptions& options, SpmdBuilder* b,
2085 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2086 windowed_dot_general_loops) {
2087 std::vector<std::pair<HloInstruction*, HloSharding>>
2088 top_level_sharding_to_reset;
2089 auto cleaner = tensorflow::gtl::MakeCleanup([&] {
2090 for (auto& to_reset : top_level_sharding_to_reset) {
2091 to_reset.first->set_sharding(to_reset.second);
2092 }
2093 });
2094
2095 std::vector<int64> output_dims;
2096 for (const auto& dim : partitioned_non_contracting_dims) {
2097 output_dims.push_back(dim.output);
2098 }
2099 GroupedSharding output_grouped =
2100 GroupShardingOnDims(output_sharding, output_dims);
2101 GroupedSharding matching_grouped =
2102 GetNonContractingPartitionGroupedShardingForMatchedOperand(
2103 lhs_matching, matching.sharding(), output_sharding,
2104 partitioned_non_contracting_dims);
2105 if (require_matching_devices_to_group &&
2106 matching.sharding() != UngroupSharding(matching_grouped)) {
2107 return nullptr;
2108 }
2109 absl::optional<GroupedSharding> other_grouped =
2110 GetNonContractingPartitionGroupedShardingForOtherOperand(
2111 lhs_matching, output_base_shape, other.hlo()->shape(),
2112 other_contracting_partitions, other_non_contracting_partitions,
2113 matching_contracting_partitions,
2114 output_other_non_contracting_partitions, other.sharding(),
2115 output_sharding, partitioned_non_contracting_dims,
2116 lhs_matching ? dims_mapping.rhs_non_contracting_dims
2117 : dims_mapping.lhs_non_contracting_dims,
2118 dims_mapping.contracting_dims);
2119
2120 if (!other_grouped) {
2121 other = other.Replicate();
2122 }
2123 matching = matching.Reshard(UngroupSharding(matching_grouped));
2124 auto per_group_partitioner_state = CreatePerGroupPartitioningState(
2125 matching.state(), matching_grouped.device_groups, b);
2126 top_level_sharding_to_reset.emplace_back(matching.hlo(), matching.sharding());
2127 matching.hlo()->set_sharding(matching_grouped.sharding);
2128 auto matching_p = PartitionedHlo(
2129 matching.hlo(),
2130 GetPerGroupBaseShape(matching_grouped, matching.base_shape()),
2131 per_group_partitioner_state);
2132
2133 auto partially_replicated_other = other.hlo();
2134 if (other_grouped && other_grouped->group_dims.size() == 1 &&
2135 other_grouped->group_dims[0] == other.base_shape().rank()) {
2136 // Group on replication dim.
2137 other = other.Reshard(UngroupSharding(*other_grouped));
2138 partially_replicated_other = other.hlo();
2139 top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding());
2140 partially_replicated_other->set_sharding(other_grouped->sharding);
2141 } else if (!other.sharding().IsReplicated()) {
2142 other = other.Reshard(UngroupSharding(*other_grouped));
2143 partially_replicated_other =
2144 other
2145 .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2146 other.sharding(), other_grouped->group_dims))
2147 .hlo();
2148 top_level_sharding_to_reset.emplace_back(
2149 partially_replicated_other, partially_replicated_other->sharding());
2150 partially_replicated_other->set_sharding(other_grouped->sharding);
2151 }
2152 auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(),
2153 per_group_partitioner_state);
2154 TF_ASSIGN_OR_RETURN(
2155 auto dot,
2156 PartitionDot(lhs_matching ? matching_p : other_p,
2157 lhs_matching ? other_p : matching_p,
2158 GetPerGroupBaseShape(output_grouped, output_base_shape),
2159 output_grouped.sharding, dims_mapping,
2160 num_partitions / matching_grouped.device_groups.size(),
2161 create_sharded_dot, conv_window, module, original_hlo,
2162 options, b, windowed_dot_general_loops));
2163 return dot;
2164 }
2165
2166 std::pair<HloSharding, HloSharding>
2167 GetDotGroupPartitionContractingOutputShardings(
2168 const DotConvDimsMapping& dims_mapping, const GroupedSharding& lhs_grouped,
2169 const Shape& output_base_shape, const HloSharding& output_sharding,
2170 int64_t group_count, int64_t output_lhs_non_contracting_partitions,
2171 int64_t output_rhs_non_contracting_partitions,
2172 int64_t output_batch_partitions,
2173 std::vector<int64>* output_slice_dims_out) {
2174 HloSharding inner_output_sharding = HloSharding::Replicate();
2175 HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
2176 std::vector<int64> output_slice_dims;
2177 if (output_sharding.ReplicateOnLastTileDim() &&
2178 output_sharding.tile_assignment().dimensions().back() % group_count ==
2179 0) {
2180 auto grouped = AlignGroupsWith(
2181 GroupShardingOnDims(
2182 output_sharding,
2183 {output_sharding.tile_assignment().num_dimensions() - 1},
2184 {output_sharding.tile_assignment().dimensions().back() /
2185 group_count}),
2186 lhs_grouped,
2187 /*ignore_group_order=*/true);
2188 outer_output_tmp_sharding = UngroupSharding(grouped);
2189 inner_output_sharding = std::move(grouped.sharding);
2190 } else {
2191 if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
2192 output_sharding, lhs_grouped.device_groups)) {
2193 output_slice_dims = std::move(*found_dims);
2194 } else if (output_lhs_non_contracting_partitions == group_count ||
2195 output_rhs_non_contracting_partitions == group_count ||
2196 output_batch_partitions == group_count) {
2197 if (output_lhs_non_contracting_partitions == group_count) {
2198 for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
2199 output_slice_dims.push_back(dim.output);
2200 }
2201 } else if (output_rhs_non_contracting_partitions == group_count) {
2202 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
2203 output_slice_dims.push_back(dim.output);
2204 }
2205 } else {
2206 for (const auto& dim : dims_mapping.batch_dims) {
2207 output_slice_dims.push_back(dim.output);
2208 }
2209 }
2210 }
2211 if (!output_slice_dims.empty()) {
2212 auto grouped = AlignGroupsWith(
2213 GroupShardingOnDims(output_sharding, output_slice_dims), lhs_grouped);
2214 inner_output_sharding = grouped.sharding;
2215 outer_output_tmp_sharding = UngroupSharding(grouped);
2216 }
2217 }
2218 if (output_slice_dims_out) {
2219 (*output_slice_dims_out) = std::move(output_slice_dims);
2220 }
2221 return std::make_pair(inner_output_sharding, outer_output_tmp_sharding);
2222 }
2223
2224 std::pair<HloSharding, HloSharding>
2225 GetDotGroupPartitionContractingLhsRhsShardings(
2226 const PartitionedHlo& lhs, const PartitionedHlo& rhs,
2227 absl::Span<const DotConvDimsMapping::DimsMapping>
2228 partitioned_contracting_dims) {
2229 HloSharding lhs_sharding = lhs.sharding();
2230 HloSharding rhs_sharding = rhs.sharding();
2231 std::vector<int64> lhs_tile_shape =
2232 lhs_sharding.tile_assignment().dimensions();
2233 std::vector<int64> rhs_tile_shape =
2234 rhs_sharding.tile_assignment().dimensions();
2235 if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2236 ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2237 for (const auto& dim : partitioned_contracting_dims) {
2238 rhs_tile_shape[dim.rhs] = lhs_tile_shape[dim.lhs];
2239 }
2240 auto new_tile = rhs.sharding().tile_assignment();
2241 new_tile.Reshape(rhs_tile_shape);
2242 rhs_sharding = rhs_sharding.ReplicateOnLastTileDim()
2243 ? HloSharding::PartialTile(new_tile)
2244 : HloSharding::Tile(new_tile);
2245 } else {
2246 for (const auto& dim : partitioned_contracting_dims) {
2247 lhs_tile_shape[dim.lhs] = rhs_tile_shape[dim.rhs];
2248 }
2249 auto new_tile = lhs.sharding().tile_assignment();
2250 new_tile.Reshape(lhs_tile_shape);
2251 lhs_sharding = lhs_sharding.ReplicateOnLastTileDim()
2252 ? HloSharding::PartialTile(new_tile)
2253 : HloSharding::Tile(new_tile);
2254 }
2255 return std::make_pair(lhs_sharding, rhs_sharding);
2256 }
2257
2258 StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
2259 PartitionedHlo lhs, PartitionedHlo rhs,
2260 absl::Span<const DotConvDimsMapping::DimsMapping>
2261 partitioned_contracting_dims,
2262 int64_t output_batch_partitions,
2263 int64_t output_lhs_non_contracting_partitions,
2264 int64_t output_rhs_non_contracting_partitions,
2265 const Shape& output_base_shape, const HloSharding& output_sharding,
2266 const DotConvDimsMapping& dims_mapping, int64_t num_partitions,
2267 const std::function<StatusOr<HloInstruction*>(
2268 HloInstruction*, HloInstruction*, SpmdBuilder*,
2269 const Window& conv_window)>& create_sharded_dot,
2270 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2271 bool require_matching_devices_to_group,
2272 const SpmdPartitionerOptions& options, SpmdBuilder* b,
2273 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2274 windowed_dot_general_loops) {
2275 std::vector<std::pair<HloInstruction*, HloSharding>>
2276 top_level_sharding_to_reset;
2277 auto cleaner = tensorflow::gtl::MakeCleanup([&] {
2278 for (auto& to_reset : top_level_sharding_to_reset) {
2279 to_reset.first->set_sharding(to_reset.second);
2280 }
2281 });
2282 std::vector<int64> lhs_dims;
2283 std::vector<int64> rhs_dims;
2284 int64_t group_count = 1;
2285 for (const auto& dim : partitioned_contracting_dims) {
2286 lhs_dims.push_back(dim.lhs);
2287 rhs_dims.push_back(dim.rhs);
2288 group_count *= lhs.sharding().tile_assignment().dim(dim.lhs);
2289 }
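  // Align LHS and RHS groups on the contracting dims; the larger operand
  // keeps its grouping and the smaller one is resharded to match.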
2290 HloSharding lhs_sharding = HloSharding::Replicate();
2291 HloSharding rhs_sharding = HloSharding::Replicate();
2292 std::tie(lhs_sharding, rhs_sharding) =
2293 GetDotGroupPartitionContractingLhsRhsShardings(
2294 lhs, rhs, partitioned_contracting_dims);
2295 auto lhs_grouped = GroupShardingOnDims(lhs_sharding, lhs_dims);
2296 auto rhs_grouped = GroupShardingOnDims(rhs_sharding, rhs_dims);
2297 if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2298 ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2299 rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped);
2300 rhs_sharding = UngroupSharding(rhs_grouped);
2301 if (require_matching_devices_to_group && rhs.sharding() != rhs_sharding) {
2302 return nullptr;
2303 }
2304 rhs = rhs.Reshard(rhs_sharding);
2305 } else {
2306 lhs_grouped = AlignGroupsWith(lhs_grouped, rhs_grouped);
2307 lhs_sharding = UngroupSharding(lhs_grouped);
2308 if (require_matching_devices_to_group && lhs.sharding() != lhs_sharding) {
2309 return nullptr;
2310 }
2311 lhs = lhs.Reshard(lhs_sharding);
2312 }
2313   // Mask out invalid (padded) data with zeros so it does not contribute to the
2313   // contraction sum.
2314 std::vector<int64> lhs_skipped_dims;
2315 for (int64_t i = 0; i < lhs.base_shape().rank(); ++i) {
2316 if (absl::c_linear_search(lhs_dims, i)) {
2317 continue;
2318 }
2319 lhs_skipped_dims.push_back(i);
2320 }
2321 lhs = lhs.PadWithValue(
2322 CreateZero(ShapeUtil::MakeShape(lhs.base_shape().element_type(), {}), b),
2323 /*left_padded_dims=*/{}, lhs_skipped_dims);
2324 std::vector<int64> rhs_skipped_dims;
2325 for (int64_t i = 0; i < rhs.base_shape().rank(); ++i) {
2326 if (absl::c_linear_search(rhs_dims, i)) {
2327 continue;
2328 }
2329 rhs_skipped_dims.push_back(i);
2330 }
2331 rhs = rhs.PadWithValue(
2332 CreateZero(ShapeUtil::MakeShape(rhs.base_shape().element_type(), {}), b),
2333 /*left_padded_dims=*/{}, rhs_skipped_dims);
2334 top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding);
2335 lhs.hlo()->set_sharding(lhs_grouped.sharding);
2336 top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding);
2337 rhs.hlo()->set_sharding(rhs_grouped.sharding);
2338
2339 HloSharding inner_output_sharding = HloSharding::Replicate();
2340 HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
2341 std::vector<int64> output_slice_dims;
2342 std::tie(inner_output_sharding, outer_output_tmp_sharding) =
2343 GetDotGroupPartitionContractingOutputShardings(
2344 dims_mapping, lhs_grouped, output_base_shape, output_sharding,
2345 group_count, output_lhs_non_contracting_partitions,
2346 output_rhs_non_contracting_partitions, output_batch_partitions,
2347 &output_slice_dims);
2348 Shape inner_output_base_shape = output_base_shape;
2349 auto get_non_slice_dims = [&] {
2350 std::vector<int64> non_group_dims;
2351 for (int64_t i = 0; i < output_base_shape.rank(); ++i) {
2352 if (!absl::c_linear_search(output_slice_dims, i)) {
2353 non_group_dims.push_back(i);
2354 }
2355 }
2356 return non_group_dims;
2357 };
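  // If the output can be sliced along the grouped dims, shrink the base shape
  // used by the inner dot accordingly.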
2358 if (!output_slice_dims.empty()) {
2359 inner_output_base_shape = MakePartitionedShape(
2360 output_base_shape,
2361 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2362 output_sharding, get_non_slice_dims()));
2363 }
2364 std::function<StatusOr<HloInstruction*>(HloInstruction*, HloInstruction*,
2365 SpmdBuilder*, const Window&)>
2366 inner_creator =
2367 [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
2368 const Window& conv_window) -> StatusOr<HloInstruction*> {
2369 TF_ASSIGN_OR_RETURN(auto inner_dot,
2370 create_sharded_dot(l, r, b, conv_window));
2371 auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
2372 b, inner_dot, lhs_sharding, lhs.state().next_channel_id, lhs_dims,
2373 lhs.state().collective_ops_creator,
2374 MakeBinaryAdd(output_base_shape.element_type(), module));
2375 if (output_slice_dims.empty()) {
2376 return ar;
2377 }
2378     // Use resharding to slice the output. Use a temporary reshard cache since
2379     // we are faking a replicated sharding on the all-reduced result.
2380 PartitionedHlo::PartitioningState new_state = lhs.state();
2381 new_state.b = b;
2382 new_state.partition_id =
2383 lhs.state().collective_ops_creator.create_partition_id(b);
2384 PartitionedHlo::ReshardCache tmp_cache;
2385 new_state.reshard_cache = &tmp_cache;
2386 ar->set_sharding(HloSharding::Replicate());
2387 return PartitionedHlo(ar, ar->shape(), new_state)
2388 .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2389 output_sharding, get_non_slice_dims()))
2390 .hlo();
2391 };
2392   // Skip the inner reshard when the "choose faster windowed einsum" option is
2393   // enabled, because the windowed einsum implementation is currently slow when
2394   // this kind of reshard happens.
2395 if (options.choose_faster_windowed_einsum_over_mem) {
2396 inner_output_base_shape = output_base_shape;
2397 inner_creator = create_sharded_dot;
2398 outer_output_tmp_sharding =
2399 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2400 outer_output_tmp_sharding, output_slice_dims);
2401 }
2402 PartitionedHlo::PartitioningState inner_state =
2403 CreatePerGroupPartitioningState(lhs.state(), lhs_grouped.device_groups,
2404 b);
2405 TF_ASSIGN_OR_RETURN(
2406 auto dot,
2407 PartitionDot(
2408 PartitionedHlo(lhs.hlo(),
2409 GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
2410 inner_state),
2411 PartitionedHlo(rhs.hlo(),
2412 GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
2413 inner_state),
2414 inner_output_base_shape, inner_output_sharding, dims_mapping,
2415 num_partitions / group_count, inner_creator, conv_window, module,
2416 original_hlo, options, b, windowed_dot_general_loops));
2417 if (!dot) {
2418 return nullptr;
2419 }
2420
2421 if (options.choose_faster_windowed_einsum_over_mem) {
2422 HloInstruction* ar = lhs.state().partitioner->AllReduceAlongShardingDims(
2423 b, dot, lhs_sharding, lhs.state().next_channel_id, lhs_dims,
2424 lhs.state().collective_ops_creator,
2425 MakeBinaryAdd(output_base_shape.element_type(), module));
2426 dot = ar;
2427 }
2428
2429 dot->set_sharding(outer_output_tmp_sharding);
2430 auto d = PartitionedHlo(dot, output_base_shape, lhs.state())
2431 .Reshard(output_sharding)
2432 .hlo();
2433 return d;
2434 }
2435
2436 DotConvDimsMapping ConvertDimsMappingWithFeatureGroupCount(
2437 const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
2438 const auto& dnums = original_hlo->convolution_dimension_numbers();
2439 DotConvDimsMapping new_dims_mapping;
2440 new_dims_mapping.batch_dims = dims_mapping.batch_dims;
2441 new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
2442 // Append batch dims.
2443 new_dims_mapping.batch_dims.emplace_back();
2444 new_dims_mapping.batch_dims.back().lhs = dnums.input_feature_dimension();
2445 new_dims_mapping.batch_dims.back().rhs =
2446 dnums.kernel_output_feature_dimension();
2447 new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
2448 new_dims_mapping.batch_dims.back().spatial = -1;
2449   // Set up non-contracting dims.
2450 new_dims_mapping.lhs_non_contracting_dims.emplace_back();
2451 new_dims_mapping.lhs_non_contracting_dims.back().lhs =
2452 dnums.input_batch_dimension();
2453 new_dims_mapping.rhs_non_contracting_dims.emplace_back();
2454 new_dims_mapping.rhs_non_contracting_dims.back().rhs =
2455 dnums.kernel_input_feature_dimension();
2456 return new_dims_mapping;
2457 }
2458
2459 DotConvDimsMapping ConvertDimsMappingWithBatchGroupCount(
2460 const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
2461 const auto& dnums = original_hlo->convolution_dimension_numbers();
2462 DotConvDimsMapping new_dims_mapping;
2463 new_dims_mapping.batch_dims = dims_mapping.batch_dims;
2464 new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
2465 new_dims_mapping.contracting_dims = dims_mapping.contracting_dims;
2466 // Append batch dims.
2467 new_dims_mapping.batch_dims.emplace_back();
2468 new_dims_mapping.batch_dims.back().lhs = dnums.input_batch_dimension();
2469 new_dims_mapping.batch_dims.back().rhs =
2470 dnums.kernel_output_feature_dimension();
2471 new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
2472 new_dims_mapping.batch_dims.back().spatial = -1;
2473 return new_dims_mapping;
2474 }
2475
2476 // Estimates the number of iterations of a subsequent windowed einsum
2477 // partitioning if it is partitioned on the non-contracting dimensions.
2478 // The first value returned is the estimated number of iterations if LHS is
2479 // matched; the second is the number of iterations if RHS is matched.
2480 std::pair<absl::optional<int64>, absl::optional<int64>>
2481 EstimateWindowedEinsumIterationsForNonContractingPartitioning(
2482 const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
2483 const PartitionedHlo& rhs, const Shape& output_base_shape,
2484 const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
2485 int64_t num_partitions, int64_t lhs_non_contracting_partitions,
2486 int64_t rhs_non_contracting_partitions, int64_t lhs_matching_partitions,
2487 int64_t rhs_matching_partitions, int64_t lhs_contracting_partitions,
2488 int64_t rhs_contracting_partitions,
2489 int64_t output_lhs_non_contracting_partitions,
2490 int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
2491 int64_t rhs_batch_partitions) {
2492 const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
2493 dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
2494 output_base_shape.rank());
2495 auto subsequent_einsum_iterations_estimate =
2496 [&](bool assume_lhs_match) -> absl::optional<int64> {
2497 const std::vector<DotConvDimsMapping::DimsMapping>&
2498 matching_non_contracting_dims =
2499 assume_lhs_match ? dims_mapping.lhs_non_contracting_dims
2500 : dims_mapping.rhs_non_contracting_dims;
2501 const std::vector<DotConvDimsMapping::DimsMapping>&
2502 other_non_contracting_dims =
2503 assume_lhs_match ? dims_mapping.rhs_non_contracting_dims
2504 : dims_mapping.lhs_non_contracting_dims;
2505 const std::vector<int64>& output_to_matching_indices =
2506 assume_lhs_match ? indices_map.output_to_lhs_indices
2507 : indices_map.output_to_rhs_indices;
2508 const std::vector<int64>& output_to_other_indices =
2509 assume_lhs_match ? indices_map.output_to_rhs_indices
2510 : indices_map.output_to_lhs_indices;
2511 const std::vector<int64>& matching_to_output_indices =
2512 assume_lhs_match ? indices_map.lhs_to_output_indices
2513 : indices_map.rhs_to_output_indices;
2514 const std::vector<int64>& other_to_output_indices =
2515 assume_lhs_match ? indices_map.rhs_to_output_indices
2516 : indices_map.lhs_to_output_indices;
2517 const HloSharding& matching_sharding =
2518 assume_lhs_match ? lhs.sharding() : rhs.sharding();
2519 const HloSharding& other_sharding =
2520 assume_lhs_match ? rhs.sharding() : lhs.sharding();
2521 const PartitionedHlo& matching_partitioned = assume_lhs_match ? lhs : rhs;
2522 const PartitionedHlo& other_partitioned = assume_lhs_match ? rhs : lhs;
2523 const int64_t matching_non_contracting_partitions =
2524 assume_lhs_match ? lhs_non_contracting_partitions
2525 : rhs_non_contracting_partitions;
2526 const int64_t other_non_contracting_partitions =
2527 assume_lhs_match ? rhs_non_contracting_partitions
2528 : lhs_non_contracting_partitions;
2529 const int64_t matching_contracting_partitions =
2530 assume_lhs_match ? lhs_contracting_partitions
2531 : rhs_contracting_partitions;
2532 const int64_t other_contracting_partitions =
2533 assume_lhs_match ? rhs_contracting_partitions
2534 : lhs_contracting_partitions;
2535 const int64_t output_matching_non_contracting_partitions =
2536 assume_lhs_match ? output_lhs_non_contracting_partitions
2537 : output_rhs_non_contracting_partitions;
2538 const int64_t output_other_non_contracting_partitions =
2539 assume_lhs_match ? output_rhs_non_contracting_partitions
2540 : output_lhs_non_contracting_partitions;
2541 const int64_t matching_batch_partitions =
2542 assume_lhs_match ? lhs_batch_partitions : rhs_batch_partitions;
2543 const int64_t other_batch_partitions =
2544 assume_lhs_match ? rhs_batch_partitions : lhs_batch_partitions;
2545 const int64_t matching_matched_non_contracting_partitions =
2546 assume_lhs_match ? lhs_non_contracting_partitions
2547 : rhs_non_contracting_partitions;
2548 std::vector<int64> output_dims;
2549 output_dims.reserve(matching_non_contracting_dims.size());
2550 for (const DotConvDimsMapping::DimsMapping& dim :
2551 matching_non_contracting_dims) {
2552 output_dims.push_back(dim.output);
2553 }
2554 GroupedSharding output_grouped =
2555 GroupShardingOnDims(output_sharding, output_dims);
2556 GroupedSharding matching_grouped =
2557 GetNonContractingPartitionGroupedShardingForMatchedOperand(
2558 assume_lhs_match, matching_sharding, output_sharding,
2559 matching_non_contracting_dims);
2560 absl::optional<GroupedSharding> other_grouped =
2561 GetNonContractingPartitionGroupedShardingForOtherOperand(
2562 assume_lhs_match, output_base_shape,
2563 other_partitioned.hlo()->shape(), other_contracting_partitions,
2564 other_non_contracting_partitions, matching_contracting_partitions,
2565 output_other_non_contracting_partitions, other_sharding,
2566 output_sharding, matching_non_contracting_dims,
2567 other_non_contracting_dims, dims_mapping.contracting_dims);
2568 if (!other_grouped) {
2569 return absl::nullopt;
2570 }
2571 absl::optional<HloSharding> output_sharding_transposed_to_match_matching =
2572 hlo_sharding_util::TransposeShardingWithCollapsedDims(
2573 output_grouped.sharding, output_to_matching_indices,
2574 matching_to_output_indices);
2575 absl::optional<HloSharding> output_sharding_transposed_to_match_other =
2576 hlo_sharding_util::TransposeShardingWithCollapsedDims(
2577 output_grouped.sharding, output_to_other_indices,
2578 other_to_output_indices);
2579 const int64_t new_num_partitions =
2580 num_partitions / matching_non_contracting_partitions;
2581 absl::optional<WindowedEinsumConfig> e_config =
2582 GetWindowedEinsumConfiguration(
2583 new_num_partitions, output_matching_non_contracting_partitions,
2584 output_other_non_contracting_partitions,
2585 other_contracting_partitions, other_non_contracting_partitions,
2586 other_batch_partitions, matching_contracting_partitions,
2587 matching_non_contracting_partitions /
2588 matching_matched_non_contracting_partitions,
2589 matching_batch_partitions,
2590 ShapeSizeInBytes(other_partitioned.base_shape()),
2591 ShapeSizeInBytes(matching_partitioned.base_shape()) /
2592 matching_non_contracting_partitions,
2593 ShapeSizeInBytes(
2594 GetPerGroupBaseShape(output_grouped, output_base_shape)),
2595 options.threshold_for_windowed_einsum_mib,
2596 output_sharding_transposed_to_match_matching,
2597 output_sharding_transposed_to_match_other,
2598 matching_grouped.sharding, other_grouped->sharding);
2599 return e_config ? new_num_partitions : absl::optional<int64>(absl::nullopt);
2600 };
2601 absl::optional<int64> lhs_matching_iterations;
2602 if (lhs_matching_partitions != 0) {
2603 lhs_matching_iterations = subsequent_einsum_iterations_estimate(true);
2604 }
2605 absl::optional<int64> rhs_matching_iterations;
2606 if (rhs_matching_partitions != 0) {
2607 rhs_matching_iterations = subsequent_einsum_iterations_estimate(false);
2608 }
2609 return std::make_pair(lhs_matching_iterations, rhs_matching_iterations);
2610 }
2611
2612 // Returns whether we should prioritize partitioning in the contracting
2613 // dimensions first (before the non-contracting dimensions) if we estimate
2614 // that doing so would result in fewer iterations of the windowed einsum.
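// For example (a rough, hypothetical sizing): if grouping on a non-contracting
// dimension would leave a windowed einsum of about 8 iterations while grouping
// on the shared contracting dimension would leave only about 2, this returns
// true so that the contracting-dimension grouping is tried first.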
2615 bool PrioritizeContractingDimensionsPartitioning(
2616 const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
2617 const PartitionedHlo& rhs, const Shape& output_base_shape,
2618 const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
2619 int64_t num_partitions, int64_t lhs_non_contracting_partitions,
2620 int64_t rhs_non_contracting_partitions, int64_t lhs_contracting_partitions,
2621 int64_t rhs_contracting_partitions,
2622 int64_t output_lhs_non_contracting_partitions,
2623 int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
2624 int64_t rhs_batch_partitions, int64_t output_batch_partitions,
2625 bool require_matching_devices_to_group) {
2626 const bool may_group_on_lhs_non_contracting =
2627 lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
2628 lhs_non_contracting_partitions > 1;
2629 const bool may_group_on_rhs_non_contracting =
2630 rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
2631 rhs_non_contracting_partitions > 1;
2632 if (!options.choose_faster_windowed_einsum_over_mem) {
2633 return false;
2634 }
2635   // Check only for a perfect dimension match for now.
2636 if (!may_group_on_lhs_non_contracting && !may_group_on_rhs_non_contracting) {
2637 return false;
2638 }
2639 absl::optional<int64> lhs_matching_iterations;
2640 absl::optional<int64> rhs_matching_iterations;
2641 const int64_t lhs_matching_non_contracting_partitions =
2642 may_group_on_lhs_non_contracting ? lhs_non_contracting_partitions : 0;
2643 const int64_t rhs_matching_non_contracting_partitions =
2644 may_group_on_rhs_non_contracting ? rhs_non_contracting_partitions : 0;
2645 std::tie(lhs_matching_iterations, rhs_matching_iterations) =
2646 EstimateWindowedEinsumIterationsForNonContractingPartitioning(
2647 dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
2648 num_partitions, lhs_non_contracting_partitions,
2649 rhs_non_contracting_partitions,
2650 lhs_matching_non_contracting_partitions,
2651 rhs_matching_non_contracting_partitions, lhs_contracting_partitions,
2652 rhs_contracting_partitions, output_lhs_non_contracting_partitions,
2653 output_rhs_non_contracting_partitions, lhs_batch_partitions,
2654 rhs_batch_partitions);
2655 if (!lhs_matching_iterations && !rhs_matching_iterations) {
2656 return false;
2657 }
2658   // Be conservative and handle only the case where the contracting
2659   // partition counts of lhs and rhs match.
2660 if (!(lhs_contracting_partitions == rhs_contracting_partitions &&
2661 lhs_contracting_partitions > 1)) {
2662 return false;
2663 }
2664 // Estimate the iterations in the case we perform the partitioning on the
2665 // contracting dimensions instead.
2666 std::vector<int64> lhs_dims;
2667 std::vector<int64> rhs_dims;
2668 int64_t group_count = 1;
2669 for (const auto& dim : dims_mapping.contracting_dims) {
2670 lhs_dims.push_back(dim.lhs);
2671 rhs_dims.push_back(dim.rhs);
2672 group_count *= lhs.sharding().tile_assignment().dim(dim.lhs);
2673 }
2674 HloSharding lhs_sharding = HloSharding::Replicate();
2675 HloSharding rhs_sharding = HloSharding::Replicate();
2676 std::tie(lhs_sharding, rhs_sharding) =
2677 GetDotGroupPartitionContractingLhsRhsShardings(
2678 lhs, rhs, dims_mapping.contracting_dims);
2679 auto lhs_grouped = GroupShardingOnDims(lhs_sharding, lhs_dims);
2680 auto rhs_grouped = GroupShardingOnDims(rhs_sharding, rhs_dims);
2681 rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped);
2682 rhs_sharding = UngroupSharding(rhs_grouped);
2683
2684 if (require_matching_devices_to_group && rhs.sharding() != rhs_sharding) {
2685 return false;
2686 }
2687 const int64_t new_num_partitions =
2688 num_partitions / lhs_contracting_partitions;
2689
2690 HloSharding inner_output_sharding = HloSharding::Replicate();
2691 HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
2692 std::vector<int64> output_slice_dims;
2693 std::tie(inner_output_sharding, outer_output_tmp_sharding) =
2694 GetDotGroupPartitionContractingOutputShardings(
2695 dims_mapping, lhs_grouped, output_base_shape, output_sharding,
2696 group_count, output_lhs_non_contracting_partitions,
2697 output_rhs_non_contracting_partitions, output_batch_partitions,
2698 &output_slice_dims);
2699 Shape inner_output_base_shape = output_base_shape;
2700 if (!output_slice_dims.empty()) {
2701 std::vector<int64> non_group_dims;
2702 for (int64_t i = 0; i < output_base_shape.rank(); ++i) {
2703 if (!absl::c_linear_search(output_slice_dims, i)) {
2704 non_group_dims.push_back(i);
2705 }
2706 }
2707 inner_output_base_shape = MakePartitionedShape(
2708 output_base_shape,
2709 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2710 output_sharding, non_group_dims));
2711 }
2712 int64_t new_output_lhs_non_contracting_partitions = 1;
2713 int64_t new_output_rhs_non_contracting_partitions = 1;
2714 if (!inner_output_sharding.IsTileMaximal()) {
2715 for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
2716 new_output_lhs_non_contracting_partitions *=
2717 inner_output_sharding.tile_assignment().dim(dim.output);
2718 }
2719 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
2720 if (dim.output != -1) {
2721 new_output_rhs_non_contracting_partitions *=
2722 inner_output_sharding.tile_assignment().dim(dim.output);
2723 }
2724 }
2725 }
2726
2727 const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
2728 dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
2729 inner_output_base_shape.rank());
2730 absl::optional<HloSharding> output_sharding_transposed_to_match_lhs =
2731 hlo_sharding_util::TransposeShardingWithCollapsedDims(
2732 inner_output_sharding, indices_map.output_to_lhs_indices,
2733 indices_map.lhs_to_output_indices);
2734 absl::optional<HloSharding> output_sharding_transposed_to_match_rhs =
2735 hlo_sharding_util::TransposeShardingWithCollapsedDims(
2736 inner_output_sharding, indices_map.output_to_rhs_indices,
2737 indices_map.rhs_to_output_indices);
2738 absl::optional<WindowedEinsumConfig> e_config =
2739 GetWindowedEinsumConfiguration(
2740 new_num_partitions, new_output_lhs_non_contracting_partitions,
2741 new_output_rhs_non_contracting_partitions, 1,
2742 rhs_non_contracting_partitions, rhs_batch_partitions, 1,
2743 lhs_non_contracting_partitions, lhs_batch_partitions,
2744 ShapeSizeInBytes(GetPerGroupBaseShape(rhs_grouped, rhs.base_shape())),
2745 ShapeSizeInBytes(GetPerGroupBaseShape(lhs_grouped, lhs.base_shape())),
2746 ShapeSizeInBytes(inner_output_base_shape),
2747 options.threshold_for_windowed_einsum_mib,
2748 output_sharding_transposed_to_match_lhs,
2749 output_sharding_transposed_to_match_rhs, lhs_grouped.sharding,
2750 rhs_grouped.sharding);
2751 if (!e_config) {
2752 return false;
2753 }
2754 const int64_t min_nc_iterations =
2755 std::min(lhs_matching_iterations ? *lhs_matching_iterations : INT64_MAX,
2756 rhs_matching_iterations ? *rhs_matching_iterations : INT64_MAX);
2757 return min_nc_iterations > new_num_partitions;
2758 }
2759
2760 // Returns whether it is better to match the LHS operand, rather than the
2761 // RHS operand, of a dot for non-contracting partitioning.
2762 bool LhsIsBestMatchForNonContractingPartitioning(
2763 const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
2764 const PartitionedHlo& rhs, const Shape& output_base_shape,
2765 const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
2766 int64_t num_partitions, int64_t lhs_non_contracting_partitions,
2767 int64_t rhs_non_contracting_partitions, int64_t lhs_matching_partitions,
2768 int64_t rhs_matching_partitions, int64_t lhs_contracting_partitions,
2769 int64_t rhs_contracting_partitions,
2770 int64_t output_lhs_non_contracting_partitions,
2771 int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
2772 int64_t rhs_batch_partitions) {
2773 const bool may_group_on_lhs_non_contracting =
2774 lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
2775 lhs_non_contracting_partitions > 1;
2776 const bool may_group_on_rhs_non_contracting =
2777 rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
2778 rhs_non_contracting_partitions > 1;
2779 // If both match output non-contracting dimensions, choose the one which
2780 // will result in smaller replication of the other operand.
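  // As a hypothetical example: with the LHS non-contracting dims split 4 ways
  // and the RHS non-contracting dims split 2 ways, matching the LHS replicates
  // the RHS across 4 groups while matching the RHS replicates the LHS across
  // only 2, so the byte-size products below pick whichever choice replicates
  // fewer total bytes of the other operand.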
2781 bool lhs_matching = may_group_on_lhs_non_contracting &&
2782 (!may_group_on_rhs_non_contracting ||
2783 lhs_non_contracting_partitions *
2784 ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <
2785 rhs_non_contracting_partitions *
2786 ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
2787   // If both groupings are available and the option to choose faster
2788   // windowed einsums over saving memory is enabled, then try to determine
2789   // which of the operands will generate the fewest iterations for the
2790   // windowed einsum when matched (if a windowed einsum is going to be
2791   // generated at all).
2792 if (may_group_on_lhs_non_contracting && may_group_on_rhs_non_contracting &&
2793 options.choose_faster_windowed_einsum_over_mem) {
2794 const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
2795 dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
2796 output_base_shape.rank());
2797 absl::optional<int64> lhs_matching_iterations;
2798 absl::optional<int64> rhs_matching_iterations;
2799 std::tie(lhs_matching_iterations, rhs_matching_iterations) =
2800 EstimateWindowedEinsumIterationsForNonContractingPartitioning(
2801 dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
2802 num_partitions, lhs_non_contracting_partitions,
2803 rhs_non_contracting_partitions, lhs_matching_partitions,
2804 rhs_matching_partitions, lhs_contracting_partitions,
2805 rhs_contracting_partitions, output_lhs_non_contracting_partitions,
2806 output_rhs_non_contracting_partitions, lhs_batch_partitions,
2807 rhs_batch_partitions);
2808 if (lhs_matching_iterations && rhs_matching_iterations &&
2809 *lhs_matching_iterations != *rhs_matching_iterations) {
2810 lhs_matching = *lhs_matching_iterations < *rhs_matching_iterations;
2811 }
2812 }
2813 return lhs_matching;
2814 }
2815
2816 // Recursive partitioning function. If there are partial dimensions matching
2817 // in the operands and output, group the devices and recursively partition
2818 // the in-group dot.
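// As a sketch of the recursion (hypothetical shapes): for a dot with LHS
// [batch=8, m, k] and RHS [batch=8, k, n], both sharded 4 ways on the batch
// dimension, the devices are first grouped into 4 groups; within each group
// the remaining [2, m, k] x [2, k, n] dot is partitioned again by a recursive
// call with a correspondingly reduced num_partitions.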
2819 StatusOr<HloInstruction*> PartitionDot(
2820 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
2821 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
2822 int64_t num_partitions,
2823 const std::function<StatusOr<HloInstruction*>(
2824 HloInstruction*, HloInstruction*, SpmdBuilder*,
2825 const Window& conv_window)>& create_sharded_dot,
2826 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2827 bool require_matching_devices_to_group,
2828 const SpmdPartitionerOptions& options, SpmdBuilder* b,
2829 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2830 windowed_dot_general_loops) {
2831   // If lhs' hlo and rhs' hlo are identical, make a copy for rhs.
2832 if (lhs.hlo() == rhs.hlo()) {
2833 auto copy_hlo = b->AddInstruction(HloInstruction::CreateUnary(
2834 rhs.hlo()->shape(), HloOpcode::kCopy, rhs.hlo()));
2835 copy_hlo->set_sharding(rhs.sharding());
2836 rhs = PartitionedHlo(copy_hlo, rhs.base_shape(), rhs.state());
2837 }
2838
2839 // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
2840 auto get_partitions_for_dims =
2841 [&](const HloSharding& sharding,
2842 absl::Span<const DotConvDimsMapping::DimsMapping> dims,
2843 int lhs_rhs_or_output) {
2844 int64_t partitions = 1;
2845 if (sharding.IsTileMaximal()) {
2846 return partitions;
2847 }
2848 for (const auto& dim : dims) {
2849 if (lhs_rhs_or_output == 0) {
2850 partitions *= sharding.tile_assignment().dim(dim.lhs);
2851 } else if (lhs_rhs_or_output == 1) {
2852 partitions *= sharding.tile_assignment().dim(dim.rhs);
2853 } else {
2854 CHECK_EQ(lhs_rhs_or_output, 2);
2855 partitions *= sharding.tile_assignment().dim(dim.output);
2856 }
2857 }
2858 return partitions;
2859 };
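  // For instance (a hypothetical sharding, purely for illustration): if
  // lhs.sharding() tiles the LHS as [2,1,4] and the only contracting dim maps
  // to LHS dimension 2, then get_partitions_for_dims(lhs.sharding(),
  // dims_mapping.contracting_dims, 0) returns 4.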
2860 const int64_t lhs_batch_partitions =
2861 get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0);
2862 const int64_t rhs_batch_partitions =
2863 get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1);
2864 const int64_t output_batch_partitions =
2865 get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2);
2866 const int64_t lhs_contracting_partitions =
2867 get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0);
2868 const int64_t rhs_contracting_partitions =
2869 get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1);
2870 const int64_t lhs_non_contracting_partitions = get_partitions_for_dims(
2871 lhs.sharding(), dims_mapping.lhs_non_contracting_dims, 0);
2872 const int64_t rhs_non_contracting_partitions = get_partitions_for_dims(
2873 rhs.sharding(), dims_mapping.rhs_non_contracting_dims, 1);
2874 const int64_t output_lhs_non_contracting_partitions = get_partitions_for_dims(
2875 output_sharding, dims_mapping.lhs_non_contracting_dims, 2);
2876 const int64_t output_rhs_non_contracting_partitions = get_partitions_for_dims(
2877 output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
2878 const int64_t lhs_conv_spatial_partitions = get_partitions_for_dims(
2879 lhs.sharding(), dims_mapping.conv_spatial_dims, 0);
2880 const int64_t rhs_conv_spatial_partitions = get_partitions_for_dims(
2881 rhs.sharding(), dims_mapping.conv_spatial_dims, 1);
2882 const int64_t output_conv_spatial_partitions = get_partitions_for_dims(
2883 output_sharding, dims_mapping.conv_spatial_dims, 2);
2884 // Before we find partial matches along the dimensions, invoke base case
2885 // again without may_reshard_without_detecting_match.
2886
2887   // Try to partition the purely spatially-partitioned convolution, with
2888   // either the convolution spatial dimensions partitioned or the depthwise
2889   // parallel dimension partitioned.
2890 bool is_conv_spatial_dim_partitioned =
2891 (lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 ||
2892 output_conv_spatial_partitions > 1);
2893 bool is_conv_batch_or_contracting_dim_partitioned =
2894 (lhs_batch_partitions > 1 || rhs_batch_partitions > 1 ||
2895 output_batch_partitions > 1 ||
2896 (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1));
2897 if ((!dims_mapping.conv_spatial_dims.empty() &&
2898 is_conv_spatial_dim_partitioned &&
2899 !is_conv_batch_or_contracting_dim_partitioned) ||
2900 (original_hlo->opcode() == HloOpcode::kConvolution &&
2901 (original_hlo->batch_group_count() > 1 ||
2902 original_hlo->feature_group_count() > 1))) {
2903     // Partitioning with kernel_input_feature_dim > 1 and
2904     // feature_group_count > 1 is not supported.
2905 const auto& dnums = original_hlo->convolution_dimension_numbers();
2906 if (original_hlo->feature_group_count() > 1 &&
2907 rhs.hlo()->shape().dimensions(dnums.kernel_input_feature_dimension()) >
2908 1) {
2909 return nullptr;
2910 }
2911
2912 TF_ASSIGN_OR_RETURN(
2913 auto partitioned_conv,
2914 PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
2915 dims_mapping, create_sharded_dot, conv_window,
2916 original_hlo, num_partitions, options,
2917 lhs.state().partition_id, module, b));
2918
2919 if (partitioned_conv) {
2920 return partitioned_conv;
2921 }
2922
2923 // Recursively partition on different types of dimensions for
2924 // convolution. Case 0.a: Group partitions by feature group count.
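    // For example (a simplified illustration): a grouped convolution with
    // feature_group_count = 4 is remapped so that the grouped feature
    // dimensions behave like a batch dimension across the 4 groups, which may
    // allow PartitionDotGroupOnBatch below to handle it when those dimensions
    // are sharded consistently with the output.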
2925 if (original_hlo->feature_group_count() > 1 ||
2926 original_hlo->batch_group_count() > 1) {
2927 DotConvDimsMapping new_dims_mapping;
2928 if (original_hlo->feature_group_count() > 1) {
2929 new_dims_mapping =
2930 ConvertDimsMappingWithFeatureGroupCount(dims_mapping, original_hlo);
2931 }
2932
2933 if (original_hlo->batch_group_count() > 1) {
2934 new_dims_mapping =
2935 ConvertDimsMappingWithBatchGroupCount(dims_mapping, original_hlo);
2936 }
2937
2938 const int64_t conv_lhs_contracting_partitions = get_partitions_for_dims(
2939 lhs.sharding(), new_dims_mapping.contracting_dims, 0);
2940 const int64_t conv_rhs_contracting_partitions = get_partitions_for_dims(
2941 rhs.sharding(), new_dims_mapping.contracting_dims, 1);
2942 const int64_t conv_lhs_non_contracting_partitions =
2943 get_partitions_for_dims(lhs.sharding(),
2944 new_dims_mapping.lhs_non_contracting_dims, 0);
2945 const int64_t conv_rhs_non_contracting_partitions =
2946 get_partitions_for_dims(rhs.sharding(),
2947 new_dims_mapping.rhs_non_contracting_dims, 1);
2948 const int64_t conv_lhs_batch_partitions = get_partitions_for_dims(
2949 lhs.sharding(), new_dims_mapping.batch_dims, 0);
2950 const int64_t conv_rhs_batch_partitions = get_partitions_for_dims(
2951 rhs.sharding(), new_dims_mapping.batch_dims, 1);
2952 const int64_t conv_output_batch_partitions = get_partitions_for_dims(
2953 output_sharding, new_dims_mapping.batch_dims, 2);
2954 if ((conv_lhs_batch_partitions == conv_output_batch_partitions ||
2955 conv_rhs_batch_partitions == conv_output_batch_partitions) &&
2956 conv_output_batch_partitions > 1) {
2957 TF_ASSIGN_OR_RETURN(
2958 auto try_partitioned_conv,
2959 PartitionDotGroupOnBatch(
2960 lhs, rhs, output_base_shape, output_sharding, new_dims_mapping,
2961 num_partitions, conv_lhs_contracting_partitions,
2962 conv_rhs_contracting_partitions,
2963 conv_lhs_non_contracting_partitions,
2964 conv_rhs_non_contracting_partitions, create_sharded_dot,
2965 conv_window, module, original_hlo,
2966 require_matching_devices_to_group, options, b,
2967 windowed_dot_general_loops));
2968 if (try_partitioned_conv) {
2969 return try_partitioned_conv;
2970 }
2971 }
2972 return nullptr;
2973 }
2974 }
2975
2976 TF_ASSIGN_OR_RETURN(
2977 auto try_partitioned_dot,
2978 PartitionBaseCase(
2979 lhs, rhs, output_base_shape, output_sharding, dims_mapping,
2980 num_partitions, create_sharded_dot, conv_window, module, original_hlo,
2981 lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
2982 lhs_contracting_partitions, rhs_contracting_partitions,
2983 lhs_non_contracting_partitions, rhs_non_contracting_partitions,
2984 output_lhs_non_contracting_partitions,
2985 output_rhs_non_contracting_partitions, options, b,
2986 windowed_dot_general_loops,
2987 /*may_reshard_without_detecting_match=*/false));
2988 if (try_partitioned_dot) {
2989 return try_partitioned_dot;
2990 }
2991
2992 // Recursively partition on different types of dimensions.
2993 //
2994 // Case 1: Group partitions by batch.
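  // E.g., if the batch dimension is sharded 4 ways identically in an operand
  // and in the output, the 4 device groups each own an independent slice of
  // the batch and can compute their part of the dot without cross-group
  // communication (a simplified view of what PartitionDotGroupOnBatch sets
  // up).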
2995 if ((lhs_batch_partitions == output_batch_partitions ||
2996 rhs_batch_partitions == output_batch_partitions) &&
2997 output_batch_partitions > 1) {
2998 TF_ASSIGN_OR_RETURN(
2999 auto dot,
3000 PartitionDotGroupOnBatch(
3001 lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3002 num_partitions, lhs_contracting_partitions,
3003 rhs_contracting_partitions, lhs_non_contracting_partitions,
3004 rhs_non_contracting_partitions, create_sharded_dot, conv_window,
3005 module, original_hlo, require_matching_devices_to_group, options, b,
3006 windowed_dot_general_loops));
3007 if (dot) {
3008 return dot;
3009 }
3010 }
3011
3012 // Case 2: Group partitions by non-contracting dimensions.
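  // E.g., sharding only the LHS non-contracting (row) dimension the same way
  // as the corresponding output dimension lets each group keep its rows of
  // the LHS and of the output, typically at the cost of replicating or
  // windowing the RHS within the group (again, a simplified illustration).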
3013 const bool may_group_on_lhs_non_contracting =
3014 lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
3015 lhs_non_contracting_partitions > 1;
3016 const bool may_group_on_rhs_non_contracting =
3017 rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
3018 rhs_non_contracting_partitions > 1;
3019 bool lhs_matching = false;
3020 std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
3021 if (may_group_on_lhs_non_contracting || may_group_on_rhs_non_contracting) {
3022 lhs_matching = LhsIsBestMatchForNonContractingPartitioning(
3023 dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
3024 num_partitions, lhs_non_contracting_partitions,
3025 rhs_non_contracting_partitions, lhs_non_contracting_partitions,
3026 rhs_non_contracting_partitions, lhs_contracting_partitions,
3027 rhs_contracting_partitions, output_lhs_non_contracting_partitions,
3028 output_rhs_non_contracting_partitions, lhs_batch_partitions,
3029 rhs_batch_partitions);
3030 matching_dims = lhs_matching ? dims_mapping.lhs_non_contracting_dims
3031 : dims_mapping.rhs_non_contracting_dims;
3032 } else if (lhs_non_contracting_partitions > 1 &&
3033 output_lhs_non_contracting_partitions > 1) {
3034 lhs_matching = true;
3035 for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
3036 int64_t lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
3037 if (lhs_partitions > 1 &&
3038 lhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
3039 matching_dims.push_back(dim);
3040 }
3041 }
3042 } else if (rhs_non_contracting_partitions > 1 &&
3043 output_rhs_non_contracting_partitions > 1) {
3044 lhs_matching = false;
3045 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
3046 int64_t rhs_partitions = rhs.sharding().tile_assignment().dim(dim.rhs);
3047 if (rhs_partitions > 1 &&
3048 rhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
3049 matching_dims.push_back(dim);
3050 }
3051 }
3052 }
3053 const bool prioritize_contracting_for_faster_windowed_einsum =
3054 PrioritizeContractingDimensionsPartitioning(
3055 dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
3056 num_partitions, lhs_non_contracting_partitions,
3057 rhs_non_contracting_partitions, lhs_contracting_partitions,
3058 rhs_contracting_partitions, output_lhs_non_contracting_partitions,
3059 output_rhs_non_contracting_partitions, lhs_batch_partitions,
3060 rhs_batch_partitions, output_batch_partitions,
3061 require_matching_devices_to_group);
3062 if (!(matching_dims.empty() ||
3063 prioritize_contracting_for_faster_windowed_einsum)) {
3064 TF_ASSIGN_OR_RETURN(
3065 auto dot,
3066 PartitionDotGroupOnNonContracting(
3067 lhs_matching, lhs_matching ? lhs : rhs, lhs_matching ? rhs : lhs,
3068 lhs_matching ? lhs_contracting_partitions
3069 : rhs_contracting_partitions,
3070 lhs_matching ? rhs_contracting_partitions
3071 : lhs_contracting_partitions,
3072 matching_dims,
3073 lhs_matching ? rhs_non_contracting_partitions
3074 : lhs_non_contracting_partitions,
3075 lhs_matching ? output_rhs_non_contracting_partitions
3076 : output_lhs_non_contracting_partitions,
3077 output_base_shape, output_sharding, dims_mapping, num_partitions,
3078 create_sharded_dot, conv_window, module, original_hlo,
3079 require_matching_devices_to_group, options, b,
3080 windowed_dot_general_loops));
3081 if (dot) {
3082 return dot;
3083 }
3084 }
3085
3086 // Case 3: Group partitions by contracting dimensions.
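  // E.g., if both operands shard the contracting (k) dimension 4 ways, each
  // group computes a partial dot over its slice of k, and the partial results
  // are then combined across groups; this is roughly what
  // PartitionDotGroupOnContracting arranges.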
3087 if (lhs_contracting_partitions == rhs_contracting_partitions &&
3088 lhs_contracting_partitions > 1) {
3089 TF_ASSIGN_OR_RETURN(
3090 auto dot,
3091 PartitionDotGroupOnContracting(
3092 lhs, rhs, dims_mapping.contracting_dims, output_batch_partitions,
3093 output_lhs_non_contracting_partitions,
3094 output_rhs_non_contracting_partitions, output_base_shape,
3095 output_sharding, dims_mapping, num_partitions, create_sharded_dot,
3096 conv_window, module, original_hlo,
3097 require_matching_devices_to_group, options, b,
3098 windowed_dot_general_loops));
3099 if (dot) {
3100 return dot;
3101 }
3102 }
3103 if (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1) {
3104     // If a subset of the contracting dims match, try them.
3105 std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
3106 for (const auto& dim : dims_mapping.contracting_dims) {
3107 int64_t lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
3108 if (lhs_partitions > 1 &&
3109 lhs_partitions == rhs.sharding().tile_assignment().dim(dim.rhs)) {
3110 matching_dims.push_back(dim);
3111 }
3112 }
3113 if (!matching_dims.empty()) {
3114 TF_ASSIGN_OR_RETURN(
3115 auto dot, PartitionDotGroupOnContracting(
3116 lhs, rhs, matching_dims, output_batch_partitions,
3117 output_lhs_non_contracting_partitions,
3118 output_rhs_non_contracting_partitions,
3119 output_base_shape, output_sharding, dims_mapping,
3120 num_partitions, create_sharded_dot, conv_window, module,
3121 original_hlo, require_matching_devices_to_group,
3122 options, b, windowed_dot_general_loops));
3123 if (dot) {
3124 return dot;
3125 }
3126 }
3127 }
3128
3129 // Case 4: If operands are replicated but output is partially replicated,
3130 // recursive call with partial replication removed.
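  // E.g. (hypothetical numbers): with both operands fully replicated across 8
  // devices and an output sharding tiled 2 ways with the remaining 4-way
  // replication on the last tile dimension, the last tile dimension is grouped
  // away and the dot is re-partitioned among the 2 effective tiles.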
3131 if (lhs.sharding().IsReplicated() && rhs.sharding().IsReplicated() &&
3132 output_sharding.ReplicateOnLastTileDim()) {
3133 auto grouped_output =
3134 GroupShardingOnDims(output_sharding, {output_base_shape.rank()});
3135 auto inner_state = CreatePerGroupPartitioningState(
3136 lhs.state(), grouped_output.device_groups, b);
3137 TF_ASSIGN_OR_RETURN(
3138 auto dot,
3139 PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state),
3140 PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
3141 output_base_shape, grouped_output.sharding, dims_mapping,
3142 output_sharding.NumTiles(), create_sharded_dot,
3143 conv_window, module, original_hlo, options, b,
3144 windowed_dot_general_loops));
3145 if (dot) {
3146 return dot;
3147 }
3148 }
3149
3150   // We failed to find partial matches, so invoke the base case again with
3151   // may_reshard_without_detecting_match.
3152 TF_ASSIGN_OR_RETURN(
3153 auto dot,
3154 PartitionBaseCase(
3155 lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3156 num_partitions, create_sharded_dot, conv_window, module, original_hlo,
3157 lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
3158 lhs_contracting_partitions, rhs_contracting_partitions,
3159 lhs_non_contracting_partitions, rhs_non_contracting_partitions,
3160 output_lhs_non_contracting_partitions,
3161 output_rhs_non_contracting_partitions, options, b,
3162 windowed_dot_general_loops,
3163 /*may_reshard_without_detecting_match=*/true));
3164 if (dot) {
3165 return dot;
3166 }
3167 return nullptr;
3168 }
3169
3170 StatusOr<HloInstruction*> PartitionDot(
3171 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
3172 const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
3173 int64_t num_partitions,
3174 const std::function<StatusOr<HloInstruction*>(
3175 HloInstruction*, HloInstruction*, SpmdBuilder*,
3176 const Window& conv_window)>& create_sharded_dot,
3177 const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
3178 const SpmdPartitionerOptions& options, SpmdBuilder* b,
3179 std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
3180 windowed_dot_general_loops) {
3181   // First try partitioning without resharding the groups, then try again
3182   // while allowing the groups to be resharded.
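  // Roughly speaking, when require_matching_devices_to_group is true a
  // candidate grouping is rejected unless the operands' existing device groups
  // already line up; the second pass relaxes this and allows resharding to
  // form matching groups.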
3183 for (bool require_matching_devices_to_group : {true, false}) {
3184 TF_ASSIGN_OR_RETURN(
3185 auto try_partition,
3186 PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3187 num_partitions, create_sharded_dot, conv_window, module,
3188 original_hlo, require_matching_devices_to_group, options,
3189 b, windowed_dot_general_loops));
3190 if (try_partition) {
3191 return try_partition;
3192 }
3193 }
3194
3195 // Default action.
3196 TF_ASSIGN_OR_RETURN(
3197 auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(),
3198 b, conv_window));
3199 dot->set_sharding(HloSharding::Replicate());
3200 return PartitionedHlo(dot, output_base_shape, lhs.state())
3201 .Reshard(output_sharding)
3202 .hlo();
3203 }
3204
3205 } // namespace
3206
3207 Status SpmdPartitioningVisitor::HandleDotHelper(
3208 HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
3209 const std::function<StatusOr<HloInstruction*>(
3210 HloInstruction*, HloInstruction*, SpmdBuilder*,
3211 const Window& conv_window)>& create_sharded_dot) {
3212 if (hlo->sharding().HasUniqueDevice()) {
3213 return DefaultAction(hlo);
3214 }
3215 auto& lhs = GetPartitionedHlo(hlo->operand(0));
3216 auto& rhs = GetPartitionedHlo(hlo->operand(1));
3217 Window conv_window;
3218 if (hlo->opcode() == HloOpcode::kConvolution) {
3219 conv_window = hlo->window();
3220 }
3221
3222 TF_ASSIGN_OR_RETURN(
3223 auto partitioned_dot,
3224 PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
3225 num_partitions_, create_sharded_dot, conv_window, module_,
3226 hlo, options_, &b_, &windowed_dot_general_loops_));
3227 SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
3228 return Status::OK();
3229 }
3230
3231 namespace {
3232
3233 // Finds a cluster of nodes that produce the inputs for `hlo` which only
3234 // depend on small operands, which means the cluster should start with
3235 // broadcasts, constants and iotas. All other internal nodes must be
3236 // non-side-effecting elementwise ops. Returns the set of nodes, and the small
3237 // operands. E.g., for the following graph,
3238 //
3239 // a -> broadcast -> multiply
3240 // iota ---> add--/
3241 // constant/
3242 //
3243 // FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return
3244 // <{broadcast, iota, constant, add, multiply}, [a]>.
3245 std::pair<absl::flat_hash_set<HloInstruction*>, std::vector<HloInstruction*>>
3246 FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) {
3247 absl::flat_hash_set<HloInstruction*> nodes_found;
3248 std::vector<HloInstruction*> new_operands;
3249 absl::flat_hash_set<const HloInstruction*> new_operands_set;
3250 std::vector<HloInstruction*> worklist;
3251 worklist.push_back(hlo);
3252 while (!worklist.empty()) {
3253 auto inst = worklist.back();
3254 worklist.pop_back();
3255 if (nodes_found.count(inst) > 0) {
3256 continue;
3257 }
3258 if (inst->opcode() == HloOpcode::kBroadcast ||
3259 inst->opcode() == HloOpcode::kConstant ||
3260 inst->opcode() == HloOpcode::kIota) {
3261 nodes_found.insert(inst);
3262 for (auto o : inst->operands()) {
3263 auto res = new_operands_set.emplace(o);
3264 if (res.second) {
3265 new_operands.push_back(o);
3266 }
3267 }
3268 } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() &&
3269 absl::c_all_of(inst->operands(),
3270 [inst](const HloInstruction* o) {
3271 return ShapeUtil::CompatibleIgnoringElementType(
3272 o->shape(), inst->shape());
3273 })) {
3274 nodes_found.insert(inst);
3275 for (auto o : inst->operands()) {
3276 worklist.push_back(o);
3277 }
3278 } else {
3279 nodes_found.clear();
3280 new_operands.clear();
3281 break;
3282 }
3283 }
3284 return {std::move(nodes_found), std::move(new_operands)};
3285 }
3286
3287 // Moves a cluster of memory-reducing nodes into the windowed dot-general loop
3288 // on contracting dimensions. Such a loop has a dynamic slice on the
3289 // non-windowed operand. If we move the input nodes into the loop, the
3290 // dynamic-slice could be merged with them by later optimization passes, which
3291 // reduces memory.
3292 //
3293 // small_operands small_operands
3294 // | |
3295 // input_nodes loop { |
3296 // | => input_nodes
3297 // loop { | |
3298 // dynamic-slice dynamic-slice
3299 // ... ...
3300 // } }
3301 //
3302 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
3303 // with the input nodes.
3304 Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
3305 HloInstruction* loop, int64_t non_windowed_operand_index) {
3306 auto input_tuple = loop->mutable_operand(0);
3307 auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index);
3308 auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand);
3309 auto to_sink = std::move(input_nodes.first);
3310 auto new_operands = std::move(input_nodes.second);
3311 if (to_sink.empty()) {
3312 return Status::OK();
3313 }
3314 auto computation = loop->parent();
3315 // Replace the old operand with a tuple of the found small operands.
3316 auto new_input_subtuple =
3317 computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
3318 TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape(
3319 non_windowed_operand_index, new_input_subtuple));
3320
3321 auto body = loop->while_body();
3322 auto body_param = body->parameter_instruction(0);
3323 auto old_body_param_users = body_param->users();
3324 // Update all tuple shapes.
3325 for (auto tuple : std::vector<HloInstruction*>{
3326 input_tuple, loop, loop->while_condition()->parameter_instruction(0),
3327 body_param, body->root_instruction()}) {
3328 *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(),
3329 {non_windowed_operand_index}) =
3330 new_input_subtuple->shape();
3331 }
3332 // Now update the loop body.
3333 auto new_operand_tuple_inside =
3334 body->AddInstruction(HloInstruction::CreateGetTupleElement(
3335 new_input_subtuple->shape(), body_param, non_windowed_operand_index));
3336 TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape(
3337 non_windowed_operand_index, new_operand_tuple_inside));
3338
3339 // Create nodes inside the loop body.
3340 std::vector<HloInstruction*> worklist;
3341 absl::flat_hash_map<const HloInstruction*, HloInstruction*> outside_to_inside;
3342 auto add_users_if_available = [&](HloInstruction* inst) {
3343 for (auto u : inst->users()) {
3344 if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 &&
3345 absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
3346 return outside_to_inside.count(o) > 0;
3347 })) {
3348 worklist.push_back(u);
3349 }
3350 }
3351 };
3352 for (int64_t i = 0; i < new_operands.size(); ++i) {
3353 outside_to_inside[new_operands[i]] =
3354 body->AddInstruction(HloInstruction::CreateGetTupleElement(
3355 new_operands[i]->shape(), new_operand_tuple_inside, i));
3356 add_users_if_available(new_operands[i]);
3357 }
3358 // HLOs to sink without operands.
3359 std::vector<HloInstruction*> nullaries_to_sink;
3360 for (auto inst : to_sink) {
3361 if (inst->operand_count() == 0) {
3362 nullaries_to_sink.push_back(inst);
3363 }
3364 }
3365 // Sort nullaries_to_sink to make it deterministic.
3366 absl::c_sort(nullaries_to_sink,
3367 [](const HloInstruction* a, const HloInstruction* b) {
3368 return a->unique_id() < b->unique_id();
3369 });
3370 worklist.reserve(nullaries_to_sink.size());
3371 for (auto inst : nullaries_to_sink) {
3372 worklist.push_back(inst);
3373 }
3374 while (!worklist.empty()) {
3375 auto inst = worklist.back();
3376 worklist.pop_back();
3377 std::vector<HloInstruction*> inst_new_operands(inst->operand_count());
3378 for (int64_t i = 0; i < inst->operand_count(); ++i) {
3379 inst_new_operands[i] = outside_to_inside[inst->operand(i)];
3380 }
3381 outside_to_inside[inst] = body->AddInstruction(
3382 inst->CloneWithNewOperands(inst->shape(), inst_new_operands));
3383 add_users_if_available(inst);
3384 }
3385 TF_RET_CHECK(outside_to_inside.count(old_operand) > 0);
3386 for (auto ou : old_body_param_users) {
3387 if (ou->opcode() == HloOpcode::kGetTupleElement &&
3388 ou->tuple_index() == non_windowed_operand_index) {
3389 TF_RETURN_IF_ERROR(
3390 ou->ReplaceAllUsesWith(outside_to_inside[old_operand]));
3391 TF_RETURN_IF_ERROR(body->RemoveInstruction(ou));
3392 }
3393 }
3394 return Status::OK();
3395 }
3396
3397 // Moves a cluster of memory-reducing nodes (with reduce nodes at the end)
3398 // into the windowed dot-general loop on non-contracting dimensions. Such a
3399 // loop has a dynamic-update-slice at the output. If we move the user nodes
3400 // into the loop and before the dynamic-update-slice, the user nodes can
3401 // operate on smaller shapes, which reduces memory.
3402 //
3403 // small_operands small_operands
3404 // | | => | |
3405 // | | loop { loop { | |
3406 // | | conv | broadcast conv
3407 // | | | | | /
3408 // | | dynamic-update-slice | dynamic-slice /
3409 // | | | | | /
3410 // | | } | | multiply-----
3411 // |broadcast / | /
3412 // | | / reduce
3413 // |multiply-- |
3414 // \ | dynamic-update-slice
3415 // reduce }
3416 //
3417 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
3418 // with the input nodes (broadcast).
3419 Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
3420 HloInstruction* loop) {
3421 CHECK_EQ(loop->user_count(), 1);
3422 // There should be a single direct user of the while loop, which is the
3423 // gte for element 2, i.e., the dot output.
3424 auto user_gte = loop->users().front();
3425 CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement);
3426 CHECK_EQ(user_gte->tuple_index(), 2);
3427 auto computation = loop->parent();
3428
3429 // Find the reduce outputs and the input nodes they depend on, if input
3430 // nodes only have small operands.
3431 absl::flat_hash_set<HloInstruction*> to_move;
3432 std::vector<HloInstruction*> new_operands;
3433 absl::flat_hash_set<const HloInstruction*> new_operands_set;
3434 std::vector<HloInstruction*> reduce_outputs;
3435 std::vector<HloInstruction*> worklist;
3436 Shape padded_shape = user_gte->shape();
3437 Shape unpadded_shape = user_gte->shape();
3438 auto original_output = user_gte;
3439
3440 if (user_gte->user_count() == 1 &&
3441 user_gte->users().back()->opcode() == HloOpcode::kSlice) {
3442 original_output = user_gte->users().back();
3443 unpadded_shape = original_output->shape();
3444 }
3445 for (auto u : original_output->users()) {
3446 worklist.push_back(u);
3447 }
3448 to_move.insert(original_output);
3449 while (!worklist.empty()) {
3450 auto inst = worklist.back();
3451 worklist.pop_back();
3452 if (to_move.count(inst) > 0) {
3453 continue;
3454 }
3455     // We only support reduces with a simple reduction function, since we may
3456     // need to accumulate across iterations manually.
3457 if (inst->opcode() == HloOpcode::kReduce &&
3458 inst->to_apply()->instruction_count() == 3 &&
3459 inst->to_apply()->num_parameters() == 2 &&
3460 inst->to_apply()->root_instruction()->IsElementwise()) {
3461 to_move.insert(inst);
3462 auto other_operand = inst->mutable_operand(1);
3463 auto res = new_operands_set.emplace(other_operand);
3464 if (res.second) {
3465 new_operands.push_back(other_operand);
3466 }
3467 reduce_outputs.push_back(inst);
3468 } else if (inst != computation->root_instruction() &&
3469 inst->user_count() > 0 && inst->IsElementwise() &&
3470 !inst->HasSideEffectNoRecurse() &&
3471 absl::c_all_of(inst->operands(),
3472 [inst](const HloInstruction* o) {
3473 return ShapeUtil::CompatibleIgnoringElementType(
3474 o->shape(), inst->shape());
3475 })) {
3476 // For an elementwise op, we need to make sure that they depend on only
3477 // nodes already in to_move and nodes with small operands.
3478 bool can_include = true;
3479 for (auto operand : inst->operands()) {
3480 if (to_move.count(operand) > 0) {
3481 continue;
3482 }
3483 auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand);
3484 if (find_result.first.empty()) {
3485 can_include = false;
3486 break;
3487 }
3488 for (auto n : find_result.first) {
3489 to_move.insert(n);
3490 }
3491 for (auto new_operand : find_result.second) {
3492 auto res = new_operands_set.insert(new_operand);
3493 if (res.second) {
3494 new_operands.push_back(new_operand);
3495 }
3496 }
3497 }
3498 if (!can_include) {
3499 to_move.clear();
3500 break;
3501 }
3502 to_move.insert(inst);
3503 for (auto u : inst->users()) {
3504 worklist.push_back(u);
3505 }
3506 } else {
3507 to_move.clear();
3508 break;
3509 }
3510 }
3511   // If nothing is found, to_move could contain only original_output, or have
3512   // been cleared by the above code.
3513 if (to_move.size() <= 1) {
3514 return Status::OK();
3515 }
3516
3517 // We will replace the original loop output with reduce-shape outputs.
3518 // Create the initial buffers before the loop.
3519 for (auto out : reduce_outputs) {
3520 auto padded_out_shape = out->shape();
3521 int64_t operand_dim = 0;
3522 int64_t output_dim = 0;
3523 while (output_dim < padded_out_shape.rank()) {
3524 if (absl::c_linear_search(out->dimensions(), operand_dim)) {
3525         // Dimension collapsed.
3526 ++operand_dim;
3527 continue;
3528 }
3529       // Kept dimensions have the same size as the padded shape.
3530 padded_out_shape.set_dimensions(output_dim,
3531 padded_shape.dimensions(operand_dim));
3532 ++operand_dim;
3533 ++output_dim;
3534 }
3535 auto broadcast =
3536 computation->AddInstruction(HloInstruction::CreateBroadcast(
3537 padded_out_shape,
3538 computation->AddInstruction(HloInstruction::CreateConstant(
3539 LiteralUtil::Zero(out->shape().element_type()))),
3540 {}));
3541 new_operands.push_back(broadcast);
3542 }
3543
3544 auto input_tuple = loop->mutable_operand(0);
3545 // Create the new input subtuple that contains the small operands and the
3546 // reduce-shape result buffers.
3547 auto new_input_subtuple =
3548 computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
3549 TF_RETURN_IF_ERROR(
3550 input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple));
3551 auto body = loop->while_body();
3552 auto body_param = body->parameter_instruction(0);
3553 auto body_root = body->root_instruction();
3554 CHECK_EQ(body_root->opcode(), HloOpcode::kTuple);
3555 // Update tuple shapes.
3556 for (auto tuple : std::vector<HloInstruction*>{
3557 input_tuple, loop, loop->while_condition()->parameter_instruction(0),
3558 body_param, body_root}) {
3559 *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) =
3560 new_input_subtuple->shape();
3561 }
3562 auto new_loop_input =
3563 body->AddInstruction(HloInstruction::CreateGetTupleElement(
3564 new_input_subtuple->shape(), body_param, 2));
3565
3566 // Now create the moved nodes inside the loop body.
3567 absl::flat_hash_map<const HloInstruction*, HloInstruction*> outside_to_inside;
3568 worklist.clear();
3569 auto add_users_if_available = [&](HloInstruction* inst) {
3570 for (auto u : inst->users()) {
3571 if (outside_to_inside.count(u) == 0 && to_move.count(u) > 0 &&
3572 absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
3573 return outside_to_inside.count(o) > 0;
3574 })) {
3575 worklist.push_back(u);
3576 }
3577 }
3578 };
3579 for (int64_t i = 0; i < new_operands.size(); ++i) {
3580 outside_to_inside[new_operands[i]] =
3581 body->AddInstruction(HloInstruction::CreateGetTupleElement(
3582 new_operands[i]->shape(), new_loop_input, i));
3583 add_users_if_available(new_operands[i]);
3584 }
3585 // The elementwise nodes will be created with sliced shape. The original
3586 // loop output corresponds to the dynamic-update-slice's update slice.
3587 auto dus = body_root->mutable_operand(2);
3588 CHECK_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice);
3589 outside_to_inside[original_output] = dus->mutable_operand(1);
3590 add_users_if_available(original_output);
3591 std::vector<HloInstruction*> slice_offsets(padded_shape.rank());
3592 for (int64_t i = 0; i < slice_offsets.size(); ++i) {
3593 slice_offsets[i] = dus->mutable_operand(i + 2);
3594 }
3595 auto get_slice = [&](HloInstruction* padded) {
3596 return body->AddInstruction(HloInstruction::CreateDynamicSlice(
3597 ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
3598 padded->shape().element_type()),
3599 padded, slice_offsets, dus->operand(1)->shape().dimensions()));
3600 };
3601 // Helper functions to create nodes with small operands.
3602 auto add_broadcast = [&](const HloInstruction* broadcast) {
3603 auto padded_operand_shape = broadcast->operand(0)->shape();
3604 for (int64_t i = 0; i < broadcast->dimensions().size(); ++i) {
3605 padded_operand_shape.set_dimensions(
3606 i, padded_shape.dimensions(broadcast->dimensions(i)));
3607 }
3608 auto padded_operand = PadToShape(outside_to_inside[broadcast->operand(0)],
3609 padded_operand_shape, nullptr, body);
3610 outside_to_inside[broadcast] =
3611 get_slice(body->AddInstruction(broadcast->CloneWithNewOperands(
3612 ShapeUtil::ChangeElementType(padded_shape,
3613 padded_operand_shape.element_type()),
3614 {padded_operand})));
3615 };
3616 auto add_iota = [&](const HloInstruction* iota) {
3617 outside_to_inside[iota] =
3618 get_slice(body->AddInstruction(iota->CloneWithNewOperands(
3619 ShapeUtil::ChangeElementType(padded_shape,
3620 iota->shape().element_type()),
3621 {})));
3622 };
3623 auto add_constant = [&](const HloInstruction* constant) {
3624 outside_to_inside[constant] = body->AddInstruction(constant->Clone());
3625 outside_to_inside[constant] = get_slice(
3626 PadToShape(outside_to_inside[constant],
3627 ShapeUtil::ChangeElementType(
3628 padded_shape, constant->shape().element_type()),
3629 nullptr, body));
3630 };
3631 while (!worklist.empty()) {
3632 auto inst = worklist.back();
3633 worklist.pop_back();
3634 if (outside_to_inside.count(inst) > 0) {
3635 continue;
3636 }
3637 if (inst->opcode() == HloOpcode::kBroadcast) {
3638 add_broadcast(inst);
3639 } else if (inst->opcode() == HloOpcode::kIota) {
3640 add_iota(inst);
3641 } else if (inst->opcode() == HloOpcode::kConstant) {
3642 add_constant(inst);
3643 } else if (inst->opcode() == HloOpcode::kReduce) {
3644       // This is an output, for which we have special handling later.
3645 } else {
3646 std::vector<HloInstruction*> operands_inside(inst->operand_count());
3647 for (int64_t i = 0; i < operands_inside.size(); ++i) {
3648 operands_inside[i] = outside_to_inside[inst->operand(i)];
3649 }
3650 outside_to_inside[inst] = body->AddInstruction(inst->CloneWithNewOperands(
3651 ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
3652 inst->shape().element_type()),
3653 operands_inside));
3654 }
3655 add_users_if_available(inst);
3656 }
3657 std::vector<HloInstruction*> new_outputs_inside(new_operands.size());
3658 for (int64_t i = 0; i < new_outputs_inside.size(); ++i) {
3659 new_outputs_inside[i] = outside_to_inside[new_operands[i]];
3660 }
3661   // Now create the reduce outputs inside the loop.
3662 for (int64_t i = 0; i < reduce_outputs.size(); ++i) {
3663 auto reduce_outside = reduce_outputs[i];
3664 CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce);
3665 int64_t index_in_operand = new_operands.size() - reduce_outputs.size() + i;
3666 auto last_iter_result = outside_to_inside[new_operands[index_in_operand]];
3667 auto operand0 = outside_to_inside[reduce_outside->operand(0)];
3668 auto operand1 = outside_to_inside[reduce_outside->operand(1)];
3669 TF_ASSIGN_OR_RETURN(auto reduce_shape,
3670 ShapeInference::InferReduceShape(
3671 {&operand0->shape(), &operand1->shape()},
3672 reduce_outside->dimensions(),
3673 reduce_outside->to_apply()->ComputeProgramShape()));
3674 *reduce_shape.mutable_layout() = reduce_outside->shape().layout();
3675 std::vector<HloInstruction*> reduce_dus_offsets;
3676 // If any collapsed dimension is windowed, we need to accumulate with last
3677 // iteration's result. If such a dimension has padding, we also need to
3678 // mask off invalid data.
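    // Hypothetical example: if a reduced dimension had logical size 10 but was
    // padded to 12 so that it divides evenly across windows, the final window
    // sees 2 padded elements; the iota/compare/select masking below replaces
    // them with the reduce's init value so they do not pollute the accumulated
    // result.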
3679 bool needs_accumulate = false;
3680 std::vector<int64> dims_to_mask;
3681 for (int64_t i = 0; i < slice_offsets.size(); ++i) {
3682 if (absl::c_linear_search(reduce_outside->dimensions(), i)) {
3683 if (reduce_outside->operand(0)->shape().dimensions(i) !=
3684 operand0->shape().dimensions(i)) {
3685 needs_accumulate = true;
3686 if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) {
3687 dims_to_mask.push_back(i);
3688 }
3689 }
3690 continue;
3691 }
3692 reduce_dus_offsets.push_back(slice_offsets[i]);
3693 }
3694 // Mask off invalid data in collapsed dimensions.
3695 for (int64_t dim : dims_to_mask) {
3696 auto iota = body->AddInstruction(HloInstruction::CreateIota(
3697 ShapeUtil::ChangeElementType(operand0->shape(), S32), dim));
3698 auto add = body->AddInstruction(HloInstruction::CreateBinary(
3699 iota->shape(), HloOpcode::kAdd, iota,
3700 body->AddInstruction(HloInstruction::CreateBroadcast(
3701 iota->shape(), slice_offsets[dim], {}))));
3702 auto limit = body->AddInstruction(HloInstruction::CreateBroadcast(
3703 iota->shape(),
3704 body->AddInstruction(
3705 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
3706 reduce_outside->operand(0)->shape().dimensions(dim)))),
3707 {}));
3708 auto compare = body->AddInstruction(HloInstruction::CreateCompare(
3709 ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit,
3710 ComparisonDirection::kLt));
3711 operand0 = body->AddInstruction(HloInstruction::CreateTernary(
3712 operand0->shape(), HloOpcode::kSelect, compare, operand0,
3713 body->AddInstruction(HloInstruction::CreateBroadcast(
3714 operand0->shape(), operand1, {}))));
3715 }
3716 auto output_inside =
3717 body->AddInstruction(reduce_outside->CloneWithNewOperands(
3718 reduce_shape, {operand0, operand1}));
3719 // Accumulate with previous results if needed.
    if (needs_accumulate) {
      auto input_slice =
          body->AddInstruction(HloInstruction::CreateDynamicSlice(
              output_inside->shape(), last_iter_result, reduce_dus_offsets,
              output_inside->shape().dimensions()));
      output_inside = body->AddInstruction(HloInstruction::CreateBinary(
          output_inside->shape(),
          reduce_outside->to_apply()->root_instruction()->opcode(),
          output_inside, input_slice));
    }
    // Dynamic-update-slice if needed.
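    // If the per-iteration result covers only a slice of the full output
    // (i.e. some non-collapsed dimension is windowed), write it back into the
    // previous result at the corresponding slice offsets.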
    if (!ShapeUtil::Compatible(output_inside->shape(),
                               last_iter_result->shape())) {
      output_inside =
          body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
              last_iter_result->shape(), last_iter_result, output_inside,
              reduce_dus_offsets));
    }
    new_outputs_inside[index_in_operand] = output_inside;
  }
  // Body output.
  auto new_output_inside =
      body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside));
  TF_RETURN_IF_ERROR(
      body_root->ReplaceOperandWithDifferentShape(2, new_output_inside));
  TF_RETURN_IF_ERROR(body->RemoveInstructionAndUnusedOperands(dus));
  // Replace uses of the reduce ops outside the loop.
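  // Each original reduce becomes a get-tuple-element into the loop's new
  // output tuple; if the loop's result is padded relative to the original
  // shape, slice it back down before replacing the uses.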
  auto new_output_gte =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_output_inside->shape(), loop, 2));
  for (int64_t i = 0; i < reduce_outputs.size(); ++i) {
    int64_t index_in_operand = new_operands.size() - reduce_outputs.size() + i;
    auto new_output =
        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
            new_outputs_inside[index_in_operand]->shape(), new_output_gte,
            index_in_operand));
    if (!ShapeUtil::Compatible(new_output->shape(),
                               reduce_outputs[i]->shape())) {
      new_output = computation->AddInstruction(HloInstruction::CreateSlice(
          reduce_outputs[i]->shape(), new_output,
          std::vector<int64_t>(new_output->shape().rank(), 0),
          reduce_outputs[i]->shape().dimensions(),
          std::vector<int64_t>(new_output->shape().rank(), 1)));
    }
    TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output));
    TF_RETURN_IF_ERROR(
        computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i]));
  }
  return Status::OK();
}

}  // namespace

Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops(
    HloComputation* computation, const SpmdPartitionerOptions& options) {
  for (auto& loop : windowed_dot_general_loops_) {
    if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims ||
        loop.operands_sharded_at_contracting_dims) {
      // We have a dynamic-slice for the non-windowed operand in
      // batch/contracting-dim/noncontracting-dim windowed dot-general. So
      // moving the broadcast/iota/elementwise ops into the loop could help
      // reduce memory via fusion.
      TF_RETURN_IF_ERROR(
          SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
              loop.while_loop, 1 - loop.windowed_operand));
    }
    // The bidirectional and unrolled loop variants do not currently support
    // this optimization.
    if (!options.bidirectional_windowed_einsum &&
        !options.unroll_windowed_einsum && !loop.windowed_in_contracting_dims &&
        !loop.operands_sharded_at_contracting_dims) {
      // We have a dynamic-update-slice for the output in
      // batch/non-contracting-dim windowed dot-general. So moving reduce ops
      // into the loop could help reduce memory.
      TF_RETURN_IF_ERROR(
          MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
              loop.while_loop));
    }
  }
  return Status::OK();
}

}  // namespace spmd
}  // namespace xla