/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/numbers.h"

namespace xla {
namespace spmd {

Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
  DotConvDimsMapping mapping;
  const auto& dnums = hlo->dot_dimension_numbers();
  int64_t next_output_dim = 0;
  for (int64_t i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) {
    mapping.batch_dims.emplace_back();
    mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i);
    mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i);
    mapping.batch_dims.back().output = next_output_dim++;
  }
  for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) {
    mapping.contracting_dims.emplace_back();
    mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i);
    mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i);
    mapping.contracting_dims.back().output = -1;
  }
  for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
    if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) ||
        absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) {
      continue;
    }
    mapping.lhs_non_contracting_dims.emplace_back();
    mapping.lhs_non_contracting_dims.back().lhs = i;
    mapping.lhs_non_contracting_dims.back().rhs = -1;
    mapping.lhs_non_contracting_dims.back().output = next_output_dim++;
  }
  for (int64_t i = 0; i < hlo->operand(1)->shape().rank(); ++i) {
    if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) ||
        absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) {
      continue;
    }
    mapping.rhs_non_contracting_dims.emplace_back();
    mapping.rhs_non_contracting_dims.back().lhs = -1;
    mapping.rhs_non_contracting_dims.back().rhs = i;
    mapping.rhs_non_contracting_dims.back().output = next_output_dim++;
  }
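  // Illustration (an assumed example): for a dot with lhs = [B, M, K] and
  // rhs = [B, K, N], where lhs_batch = rhs_batch = {0}, lhs_contracting = {2}
  // and rhs_contracting = {1}, the mapping built above is
  //   batch_dims:               {lhs: 0, rhs: 0, output: 0}
  //   contracting_dims:         {lhs: 2, rhs: 1, output: -1}
  //   lhs_non_contracting_dims: {lhs: 1, rhs: -1, output: 1}
  //   rhs_non_contracting_dims: {lhs: -1, rhs: 2, output: 2}
  // where -1 marks a dimension absent on that side.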
  auto create_sharded_dot =
      [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
          const Window& conv_window) -> StatusOr<HloInstruction*> {
    TF_ASSIGN_OR_RETURN(
        auto sharded_dot_shape,
        ShapeInference::InferDotOpShape(
            l->shape(), r->shape(), hlo->dot_dimension_numbers(),
            /*preferred_element_type=*/hlo->shape().element_type()));
    return b->AddInstruction(HloInstruction::CreateDot(
        sharded_dot_shape, l, r, hlo->dot_dimension_numbers(),
        hlo->precision_config()));
  };
  return HandleDotHelper(hlo, mapping, create_sharded_dot);
}

namespace {

enum class WindowedEinsumOperand { LHS, RHS };

struct WindowedEinsumConfig {
  WindowedEinsumOperand windowed_op;
  bool windowed_at_contracting_dims;
  bool windowed_at_batch_dims;
  bool operands_sharded_at_contracting_dims;
};

struct DotDimensionIndexMapping {
  std::vector<int64> lhs_to_rhs_indices;
  std::vector<int64> lhs_to_output_indices;
  std::vector<int64> rhs_to_lhs_indices;
  std::vector<int64> rhs_to_output_indices;
  std::vector<int64> output_to_lhs_indices;
  std::vector<int64> output_to_rhs_indices;
};
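
// Each *_to_*_indices vector maps a dimension index on one side of the dot to
// the corresponding dimension index on the other side, with -1 where no
// counterpart exists (e.g. a contracting dimension has no output dimension).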

void UpdateDDNums(DotDimensionNumbers* new_ddnums, int64_t reshaped_dim,
                  bool lhs) {
  auto update_dims =
      [&reshaped_dim](tensorflow::protobuf::RepeatedField<int64>* dims) {
        bool add_reshaped_dim = false;
        if (absl::c_linear_search(*dims, reshaped_dim)) {
          add_reshaped_dim = true;
        }
        for (int64_t i = 0; i < dims->size(); ++i) {
          auto dim = dims->at(i);
          if (reshaped_dim <= dim) {
            dims->Set(i, dim + 1);
          }
        }
        if (add_reshaped_dim) {
          dims->Add(reshaped_dim);
        }
      };

  if (lhs) {
    update_dims(new_ddnums->mutable_lhs_contracting_dimensions());
    update_dims(new_ddnums->mutable_lhs_batch_dimensions());
  } else {  // rhs
    update_dims(new_ddnums->mutable_rhs_contracting_dimensions());
    update_dims(new_ddnums->mutable_rhs_batch_dimensions());
  }
}
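
// Illustration (an assumed example): with lhs_contracting_dimensions = {2}
// and reshaped_dim = 2, every listed dimension >= 2 is shifted up by one
// (giving {3}), and since 2 was present in the list the newly created
// dimension is appended as well, giving {3, 2}.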

Window GenNewWindow(const HloInstruction* original_dot,
                    const HloInstruction* dot_lhs,
                    const HloInstruction* dot_rhs, int64_t lhs_concat_dim,
                    int64_t rhs_concat_dim, bool windowed_at_contracting_dims,
                    bool windowed_at_batch_dims) {
  auto new_window = original_dot->window();
  const ConvolutionDimensionNumbers& conv_dnums =
      original_dot->convolution_dimension_numbers();
  if (lhs_concat_dim != -1) {
    for (int64_t i = 0; i < conv_dnums.input_spatial_dimensions_size(); ++i) {
      if (conv_dnums.input_spatial_dimensions(i) == lhs_concat_dim) {
        auto wd = new_window.mutable_dimensions(i);
        auto lhs_size = dot_lhs->shape().dimensions(lhs_concat_dim + 1);
        if (windowed_at_contracting_dims) {
          wd->set_size(lhs_size);
        }
        if (windowed_at_batch_dims) {
          wd->set_size(lhs_size);
          wd->set_padding_low(0);
          wd->set_padding_high(0);
          wd->set_stride(std::max<int64>(1, lhs_size - 1));
          wd->set_window_dilation(1);
          wd->set_base_dilation(lhs_size);
          wd->set_window_reversal(false);
        }
      }
    }
  }
  if (rhs_concat_dim != -1) {
    for (int64_t i = 0; i < conv_dnums.kernel_spatial_dimensions_size(); ++i) {
      if (conv_dnums.kernel_spatial_dimensions(i) == rhs_concat_dim &&
          !windowed_at_contracting_dims && !windowed_at_batch_dims &&
          lhs_concat_dim == -1) {
        auto wd = new_window.mutable_dimensions(i);
        auto rhs_size = dot_rhs->shape().dimensions(rhs_concat_dim + 1);
        wd->set_size(rhs_size);
        wd->set_padding_low(rhs_size - 1);
        wd->set_padding_high(rhs_size - 1);
      }
    }
  }
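  // Note: the extra dimension added below always has size 2 in the windowed
  // cases because the bidirectional loop body concatenates exactly two slices
  // (the ccw and cw halves) along the concat dimension.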
  // Add the extra dimension to window.
  WindowDimension* new_dim = new_window.add_dimensions();
  if (windowed_at_contracting_dims) {
    new_dim->set_size(2);
    new_dim->set_padding_low(0);
    new_dim->set_padding_high(0);
    new_dim->set_stride(1);
    new_dim->set_window_dilation(1);
    new_dim->set_base_dilation(1);
    new_dim->set_window_reversal(false);
  } else if (windowed_at_batch_dims) {
    new_dim->set_size(2);
    new_dim->set_padding_low(0);
    new_dim->set_padding_high(0);
    new_dim->set_stride(1);  // std::max<int64>(1, 2 - 1)
    new_dim->set_window_dilation(1);
    new_dim->set_base_dilation(2);
    new_dim->set_window_reversal(false);
  } else {
    if (lhs_concat_dim != -1) {
      new_dim->set_size(1);
      new_dim->set_padding_low(0);
      new_dim->set_padding_high(0);
      new_dim->set_stride(1);
      new_dim->set_window_dilation(1);
      new_dim->set_base_dilation(1);
      new_dim->set_window_reversal(false);
    }
    if (rhs_concat_dim != -1) {
      new_dim->set_size(2);          // rhs_size
      new_dim->set_padding_low(1);   // rhs_size - 1
      new_dim->set_padding_high(1);  // rhs_size - 1
      new_dim->set_stride(1);
      new_dim->set_window_dilation(1);
      new_dim->set_base_dilation(1);
      new_dim->set_window_reversal(true);
    }
  }

  VLOG(2) << "new_window: " << new_window.ShortDebugString();
  return new_window;
}

ConvolutionDimensionNumbers GenNewConvDNums(
    const HloInstruction* original_dot, const HloInstruction* dot_lhs,
    const HloInstruction* dot_rhs, int64_t lhs_concat_dim,
    int64_t rhs_concat_dim, bool windowed_at_contracting_dims,
    bool windowed_at_batch_dims,
    const std::vector<int64>& lhs_to_output_indices,
    const std::vector<int64>& rhs_to_output_indices,
    const Shape& new_dot_shape) {
  // Generate the new conv dimension numbers.
  const ConvolutionDimensionNumbers& dnums =
      original_dot->convolution_dimension_numbers();
  // Handle the LHS dimension numbers.
  int64_t input_batch_dimension = dnums.input_batch_dimension();
  int64_t input_feature_dimension = dnums.input_feature_dimension();
  std::vector<int64> input_spatial_dimensions(
      dnums.input_spatial_dimensions().begin(),
      dnums.input_spatial_dimensions().end());
  if (lhs_concat_dim != -1) {
    if (lhs_concat_dim <= input_batch_dimension) {
      input_batch_dimension++;
    }
    if (lhs_concat_dim <= input_feature_dimension) {
      input_feature_dimension++;
    }
    for (int64_t i = 0; i < input_spatial_dimensions.size(); ++i) {
      if (lhs_concat_dim <= input_spatial_dimensions[i]) {
        input_spatial_dimensions[i]++;
      }
    }
    input_spatial_dimensions.push_back(lhs_concat_dim);
  }
  if (rhs_concat_dim != -1 && !windowed_at_contracting_dims &&
      !windowed_at_batch_dims) {
    input_spatial_dimensions.push_back(dot_lhs->shape().dimensions_size() - 1);
  }
  // Handle the RHS dimension numbers.
  int64_t kernel_input_feature_dimension =
      dnums.kernel_input_feature_dimension();
  int64_t kernel_output_feature_dimension =
      dnums.kernel_output_feature_dimension();
  std::vector<int64> kernel_spatial_dimensions(
      dnums.kernel_spatial_dimensions().begin(),
      dnums.kernel_spatial_dimensions().end());
  if (rhs_concat_dim != -1) {
    if (rhs_concat_dim <= kernel_input_feature_dimension) {
      kernel_input_feature_dimension++;
    }
    if (rhs_concat_dim <= kernel_output_feature_dimension) {
      kernel_output_feature_dimension++;
    }
    for (int64_t i = 0; i < kernel_spatial_dimensions.size(); ++i) {
      if (rhs_concat_dim <= kernel_spatial_dimensions[i]) {
        kernel_spatial_dimensions[i]++;
      }
    }
    kernel_spatial_dimensions.push_back(rhs_concat_dim);
  }
  if (lhs_concat_dim != -1 && !windowed_at_contracting_dims &&
      !windowed_at_batch_dims) {
    kernel_spatial_dimensions.push_back(dot_rhs->shape().dimensions_size() - 1);
  }
  // Handle the Output dimension numbers.
  int64_t output_batch_dimension = dnums.output_batch_dimension();
  int64_t output_feature_dimension = dnums.output_feature_dimension();
  std::vector<int64> output_spatial_dimensions(
      dnums.output_spatial_dimensions().begin(),
      dnums.output_spatial_dimensions().end());
  if (!windowed_at_contracting_dims) {
    auto output_slice_dim = lhs_concat_dim != -1
                                ? lhs_to_output_indices[lhs_concat_dim]
                                : rhs_to_output_indices[rhs_concat_dim];
    if (output_slice_dim <= output_batch_dimension) {
      output_batch_dimension++;
    }
    if (output_slice_dim <= output_feature_dimension) {
      output_feature_dimension++;
    }
    for (int64_t i = 0; i < output_spatial_dimensions.size(); ++i) {
      if (output_slice_dim <= output_spatial_dimensions[i]) {
        output_spatial_dimensions[i]++;
      }
    }
    output_spatial_dimensions.push_back(output_slice_dim);
  } else {
    output_spatial_dimensions.push_back(new_dot_shape.dimensions_size() - 1);
  }
  // Construct the new conv dimension numbers.
  ConvolutionDimensionNumbers new_dnums;
  new_dnums.set_input_batch_dimension(input_batch_dimension);
  new_dnums.set_input_feature_dimension(input_feature_dimension);
  for (auto dim : input_spatial_dimensions) {
    new_dnums.add_input_spatial_dimensions(dim);
  }
  new_dnums.set_kernel_input_feature_dimension(kernel_input_feature_dimension);
  new_dnums.set_kernel_output_feature_dimension(
      kernel_output_feature_dimension);
  for (auto dim : kernel_spatial_dimensions) {
    new_dnums.add_kernel_spatial_dimensions(dim);
  }
  new_dnums.set_output_batch_dimension(output_batch_dimension);
  new_dnums.set_output_feature_dimension(output_feature_dimension);
  for (auto dim : output_spatial_dimensions) {
    new_dnums.add_output_spatial_dimensions(dim);
  }

  return new_dnums;
}

DotDimensionIndexMapping ComputeDimensionIndexMapping(
    const DotConvDimsMapping& dims_mapping, int64_t lhs_rank, int64_t rhs_rank,
    int64_t output_rank) {
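  // Illustration (an assumed example): for lhs = [B, M, K], rhs = [B, K, N],
  // output = [B, M, N], this returns
  //   lhs_to_rhs_indices = {0, -1, 1},    lhs_to_output_indices = {0, 1, -1},
  //   rhs_to_lhs_indices = {0, 2, -1},    rhs_to_output_indices = {0, -1, 2},
  //   output_to_lhs_indices = {0, 1, -1}, output_to_rhs_indices = {0, -1, 2}.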
  std::vector<int64> lhs_to_rhs_indices(lhs_rank, -1);
  std::vector<int64> lhs_to_output_indices(lhs_rank, -1);
  std::vector<int64> rhs_to_lhs_indices(rhs_rank, -1);
  std::vector<int64> rhs_to_output_indices(rhs_rank, -1);
  std::vector<int64> output_to_lhs_indices(output_rank, -1);
  std::vector<int64> output_to_rhs_indices(output_rank, -1);
  auto populate_indices_mapping =
      [&](const DotConvDimsMapping::DimsMapping& mapping) {
        if (mapping.lhs >= 0) {
          lhs_to_rhs_indices[mapping.lhs] = mapping.rhs;
          lhs_to_output_indices[mapping.lhs] = mapping.output;
        }
        if (mapping.rhs >= 0) {
          rhs_to_lhs_indices[mapping.rhs] = mapping.lhs;
          rhs_to_output_indices[mapping.rhs] = mapping.output;
        }
        if (mapping.output >= 0) {
          output_to_lhs_indices[mapping.output] = mapping.lhs;
          output_to_rhs_indices[mapping.output] = mapping.rhs;
        }
      };
  for (const auto& mapping : dims_mapping.batch_dims) {
    populate_indices_mapping(mapping);
  }
  for (const auto& mapping : dims_mapping.contracting_dims) {
    populate_indices_mapping(mapping);
  }
  for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) {
    populate_indices_mapping(mapping);
  }
  for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) {
    populate_indices_mapping(mapping);
  }
  for (const auto& mapping : dims_mapping.conv_spatial_dims) {
    populate_indices_mapping(mapping);
  }
  return DotDimensionIndexMapping{lhs_to_rhs_indices,    lhs_to_output_indices,
                                  rhs_to_lhs_indices,    rhs_to_output_indices,
                                  output_to_lhs_indices, output_to_rhs_indices};
}

absl::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
    int64_t num_partitions, int64_t output_lhs_non_contracting_partitions,
    int64_t output_rhs_non_contracting_partitions,
    int64_t rhs_contracting_partitions, int64_t rhs_non_contracting_partitions,
    int64_t rhs_batch_partitions, int64_t lhs_contracting_partitions,
    int64_t lhs_non_contracting_partitions, int64_t lhs_batch_partitions,
    int64_t rhs_shape_size, int64_t lhs_shape_size, int64_t output_shape_size,
    int64_t einsum_threshold_mib,
    const absl::optional<HloSharding>& output_sharding_transposed_to_match_lhs,
    const absl::optional<HloSharding>& output_sharding_transposed_to_match_rhs,
    const HloSharding& lhs_sharding, const HloSharding& rhs_sharding) {
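  // The *_shape_size arguments are compared against einsum_threshold_mib
  // scaled to bytes, so a windowed config is only chosen when the operand (or
  // output) being windowed over is large. Illustration (assumed numbers):
  // with einsum_threshold_mib = 256, rhs_shape_size must be at least
  // 256 * 1024 * 1024 for an RHS-windowed config below to be selected.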
  if (output_lhs_non_contracting_partitions == num_partitions &&
      output_sharding_transposed_to_match_lhs == lhs_sharding &&
      rhs_shape_size >= einsum_threshold_mib * 1024 * 1024) {
    if (rhs_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::RHS,
          /*windowed_at_contracting_dims=*/true,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/false};
    }
    if (rhs_non_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::RHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/false};
    }
    if (rhs_batch_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::RHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/true,
          /*operands_sharded_at_contracting_dims=*/false};
    }
  }
  if (output_rhs_non_contracting_partitions == num_partitions &&
      output_sharding_transposed_to_match_rhs == rhs_sharding &&
      lhs_shape_size >= einsum_threshold_mib * 1024 * 1024) {
    if (lhs_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::LHS,
          /*windowed_at_contracting_dims=*/true,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/false};
    }
    if (lhs_non_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::LHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/false};
    }
    if (lhs_batch_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::LHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/true,
          /*operands_sharded_at_contracting_dims=*/false};
    }
  }
  if (lhs_contracting_partitions == rhs_contracting_partitions &&
      lhs_contracting_partitions == num_partitions &&
      (output_lhs_non_contracting_partitions == num_partitions ||
       output_rhs_non_contracting_partitions == num_partitions) &&
      output_shape_size >= einsum_threshold_mib * 1024 * 1024) {
    if (output_lhs_non_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::RHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/true};
    }
    if (output_rhs_non_contracting_partitions == num_partitions) {
      return WindowedEinsumConfig{
          /*windowed_op=*/WindowedEinsumOperand::LHS,
          /*windowed_at_contracting_dims=*/false,
          /*windowed_at_batch_dims=*/false,
          /*operands_sharded_at_contracting_dims=*/true};
    }
  }
  return absl::nullopt;
}

std::vector<ReplicaGroup> GetLoopReplicaGroups(HloInstruction* while_loop) {
  std::vector<ReplicaGroup> groups;
  for (auto inst : while_loop->while_body()->instructions()) {
    if (inst->opcode() == HloOpcode::kCollectivePermute) {
      std::vector<std::pair<int64, int64>> st_pairs =
          inst->source_target_pairs();
      std::vector<int64> source_index(st_pairs.size());
      for (int64_t i = 0; i < st_pairs.size(); ++i) {
        source_index[st_pairs[i].first] = i;
      }

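      // Each source->target chain below is followed until it loops back to
      // its starting partition, so every collective-permute cycle becomes one
      // replica group. Illustration (assumed pairs): {0->1, 1->2, 2->0,
      // 3->4, 4->3} yields the sorted groups {0, 1, 2} and {3, 4}.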
      absl::flat_hash_set<int64> visited;
      for (int64_t i = 0; i < st_pairs.size(); ++i) {
        if (visited.contains(st_pairs[i].first)) {
          continue;
        }
        std::vector<int64> replica_group;
        int64_t source = st_pairs[i].first;
        int64_t target = st_pairs[i].second;
        replica_group.push_back(source);
        replica_group.push_back(target);
        visited.insert(source);
        visited.insert(target);
        while (target != source) {
          target = st_pairs[source_index[target]].second;
          if (target != source) {
            replica_group.push_back(target);
            visited.insert(target);
          }
        }
        absl::c_sort(replica_group);
        groups.emplace_back();
        for (auto id : replica_group) {
          groups.back().add_replica_ids(id);
        }
      }

      VLOG(3) << "while loop: " << while_loop->name()
              << ", replica groups: " << ReplicaGroupsToString(groups);
      break;
    }
  }
  return groups;
}

// We use a recursive approach where sets of matching dimensions are recognized
// one at a time. The base shapes and shardings can be changed during the
// recursion as we group devices together, so refer to the passed-in shapes and
// shardings for inputs and output, and do not use shape inference.

StatusOr<HloInstruction*> PartitionBaseCase(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
    int64_t num_partitions,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_dot,
    const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
    int64_t lhs_batch_partitions, int64_t rhs_batch_partitions,
    int64_t output_batch_partitions, int64_t lhs_contracting_partitions,
    int64_t rhs_contracting_partitions, int64_t lhs_non_contracting_partitions,
    int64_t rhs_non_contracting_partitions,
    int64_t output_lhs_non_contracting_partitions,
    int64_t output_rhs_non_contracting_partitions,
    const SpmdPartitionerOptions& options, SpmdBuilder* b,
    std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
        windowed_dot_general_loops,
    bool may_reshard_without_detecting_match) {
  const HloSharding& lhs_sharding = lhs.sharding();
  const HloSharding& rhs_sharding = rhs.sharding();
  if (lhs_sharding.ReplicateOnLastTileDim() ||
      rhs_sharding.ReplicateOnLastTileDim() ||
      output_sharding.ReplicateOnLastTileDim()) {
    return nullptr;
  }
  DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
      dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
      output_base_shape.rank());
  auto lhs_sharding_transposed_to_match_rhs =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          lhs_sharding, indices_map.lhs_to_rhs_indices,
          indices_map.rhs_to_lhs_indices);
  auto rhs_sharding_transposed_to_match_lhs =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          rhs_sharding, indices_map.rhs_to_lhs_indices,
          indices_map.lhs_to_rhs_indices);
  auto lhs_sharding_transposed_to_match_output =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          lhs_sharding, indices_map.lhs_to_output_indices,
          indices_map.output_to_lhs_indices);
  auto rhs_sharding_transposed_to_match_output =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          rhs_sharding, indices_map.rhs_to_output_indices,
          indices_map.output_to_rhs_indices);
  auto output_sharding_transposed_to_match_lhs =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          output_sharding, indices_map.output_to_lhs_indices,
          indices_map.lhs_to_output_indices);
  auto output_sharding_transposed_to_match_rhs =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          output_sharding, indices_map.output_to_rhs_indices,
          indices_map.rhs_to_output_indices);

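  // Each *_sharding_transposed_to_match_* value above rewrites one sharding
  // through the dimension index maps so it can be compared directly with the
  // sharding on the other side. Illustration (an assumed example): if
  // lhs = [B, M, K] is tiled [2, 1, 1] along the batch dimension,
  // lhs_sharding_transposed_to_match_output is the corresponding [2, 1, 1]
  // tiling of output = [B, M, N]; the optionals are nullopt when no such
  // transpose exists.
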
  // LHS and RHS are partitioned the same way and only partitioned in batch
  // dimensions.
  if (lhs_batch_partitions == rhs_batch_partitions &&
      rhs_batch_partitions == num_partitions &&
      lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
    TF_ASSIGN_OR_RETURN(
        auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
    dot->set_sharding(*lhs_sharding_transposed_to_match_output);
    return PartitionedHlo(dot, output_base_shape, lhs.state())
        .Reshard(output_sharding)
        .hlo();
  }

  // Tries to emit a batch-partitioned einsum with one operand resharded.
  // Returns the partitioned HLO, or nullptr if the attempt fails. If
  // may_reshard_with_allreduce is false, the reshard must be achievable with
  // all-to-all or collective-permute; otherwise this attempt fails.
  auto try_emit_output_batch_partitioned_einsum_with_reshard =
      [&](bool may_reshard_with_allreduce) -> StatusOr<HloInstruction*> {
    // LHS and output are batch partitioned in the same way.
    if (lhs_batch_partitions == num_partitions &&
        output_batch_partitions == num_partitions &&
        lhs_sharding_transposed_to_match_output == output_sharding) {
      if (!may_reshard_with_allreduce &&
          !CanReshardWithCollectivePermute(
              rhs.sharding(), *lhs_sharding_transposed_to_match_rhs) &&
          !GetReshardAllToAllSourceTargetDims(
              rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) {
        return nullptr;
      }
      auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
      TF_ASSIGN_OR_RETURN(
          auto dot,
          create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window));
      return dot;
    }
    // RHS and output are batch partitioned in the same way.
    if (rhs_batch_partitions == num_partitions &&
        output_batch_partitions == num_partitions &&
        rhs_sharding_transposed_to_match_output == output_sharding) {
      if (!may_reshard_with_allreduce &&
          !CanReshardWithCollectivePermute(
              lhs.sharding(), *rhs_sharding_transposed_to_match_lhs) &&
          !GetReshardAllToAllSourceTargetDims(
              lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) {
        return nullptr;
      }
      auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
      TF_ASSIGN_OR_RETURN(
          auto dot,
          create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window));
      return dot;
    }
    return nullptr;
  };

  {
    // Try batch-parallel by resharding one operand, and not using all-reduce.
    TF_ASSIGN_OR_RETURN(
        HloInstruction * partitioned_dot,
        try_emit_output_batch_partitioned_einsum_with_reshard(false));
    if (partitioned_dot) {
      return partitioned_dot;
    }
  }

  // Try to emit a windowed DotGeneral when one operand is partitioned in the
  // same way as the output along non-contracting dimensions but the other
  // operand is tiled in other dimensions, or when both operands are
  // partitioned in the same way along contracting dimensions but the output
  // is partitioned along non-contracting dimensions.
  auto emit_windowed_dot_general =
      [&](const WindowedEinsumConfig& einsum_config)
      -> StatusOr<HloInstruction*> {
    CHECK(!einsum_config.windowed_at_batch_dims ||
          !einsum_config.windowed_at_contracting_dims);
    const bool windowed_at_batch_dims = einsum_config.windowed_at_batch_dims;
    const bool windowed_at_contracting_dims =
        einsum_config.windowed_at_contracting_dims;
    const bool operands_sharded_at_contracting_dims =
        einsum_config.operands_sharded_at_contracting_dims;
    auto unpadded_result_buffer_shape =
        MakePartitionedShape(output_base_shape, output_sharding);
    auto padded_result_buffer_shape = unpadded_result_buffer_shape;
    const bool windowed_op_is_lhs =
        einsum_config.windowed_op == WindowedEinsumOperand::LHS;
    // For windowing at batch/non-contracting dims, we produce the result one
    // partition at a time, so we need to pad the shape in case of uneven
    // partitioning in order to make dynamic-update-slice in-bound.
    if (!windowed_at_contracting_dims &&
        !operands_sharded_at_contracting_dims) {
      padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning(
          padded_result_buffer_shape,
          windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
                             : *rhs_sharding_transposed_to_match_output);
    }
    // Mask the padding area of the windowed operand with zero if there is
    // uneven partitioning.
    if (windowed_at_contracting_dims) {
      auto& to_mask = windowed_op_is_lhs ? lhs : rhs;
      to_mask =
          to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::Zero(output_base_shape.element_type()))));
    }
    if (operands_sharded_at_contracting_dims) {
      auto zero = b->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(output_base_shape.element_type())));
      lhs = lhs.PadWithValue(zero);
      rhs = rhs.PadWithValue(zero);
    }

    // Get slice sharding, sharding dim, and lhs/rhs concat dim.
    const HloSharding* slice_sharding;
    if (operands_sharded_at_contracting_dims) {
      slice_sharding = windowed_op_is_lhs
                           ? &*output_sharding_transposed_to_match_rhs
                           : &*output_sharding_transposed_to_match_lhs;
    } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
      slice_sharding = windowed_op_is_lhs
                           ? &*lhs_sharding_transposed_to_match_rhs
                           : &*rhs_sharding_transposed_to_match_lhs;
    } else {
      slice_sharding = windowed_op_is_lhs
                           ? &*lhs_sharding_transposed_to_match_output
                           : &*rhs_sharding_transposed_to_match_output;
    }
    CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()),
             num_partitions);
    int64_t slice_sharding_dim = -1;
    for (int64_t i = 0; i < slice_sharding->tile_assignment().num_dimensions();
         ++i) {
      if (slice_sharding->tile_assignment().dim(i) > 1) {
        slice_sharding_dim = i;
        break;
      }
    }
    int64_t lhs_concat_dim = -1;
    int64_t rhs_concat_dim = -1;
    if (operands_sharded_at_contracting_dims) {
      if (windowed_op_is_lhs) {
        rhs_concat_dim = slice_sharding_dim;
      } else {
        lhs_concat_dim = slice_sharding_dim;
      }
    } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
      lhs_concat_dim = windowed_op_is_lhs
                           ? indices_map.rhs_to_lhs_indices[slice_sharding_dim]
                           : slice_sharding_dim;
      rhs_concat_dim = windowed_op_is_lhs
                           ? slice_sharding_dim
                           : indices_map.lhs_to_rhs_indices[slice_sharding_dim];
    } else {
      if (windowed_op_is_lhs) {
        lhs_concat_dim = indices_map.output_to_lhs_indices[slice_sharding_dim];
      } else {
        rhs_concat_dim = indices_map.output_to_rhs_indices[slice_sharding_dim];
      }
    }
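    // Summary (not new logic): slice_sharding_dim is the single dimension
    // along which the sliced operand is partitioned, and the concat dim is
    // that dimension translated onto the lhs or rhs via the index maps, so
    // the ccw/cw slices can later be concatenated there in the loop body.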

    auto lhs_hlo = lhs.hlo();
    auto rhs_hlo = rhs.hlo();
    // Reshape lhs and rhs before the loop for the bidirectional communication
    // case.
    if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
      if (lhs_concat_dim != -1 && windowed_op_is_lhs &&
          !operands_sharded_at_contracting_dims) {
        std::vector<int64> reshaped_dims(lhs_hlo->shape().dimensions().begin(),
                                         lhs_hlo->shape().dimensions().end());
        reshaped_dims.insert(reshaped_dims.begin() + lhs_concat_dim, 1);
        lhs_hlo = b->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(lhs_hlo->shape().element_type(),
                                 reshaped_dims),
            lhs_hlo));
      }
      if (rhs_concat_dim != -1 && !windowed_op_is_lhs &&
          !operands_sharded_at_contracting_dims) {
        std::vector<int64> reshaped_dims(rhs_hlo->shape().dimensions().begin(),
                                         rhs_hlo->shape().dimensions().end());
        reshaped_dims.insert(reshaped_dims.begin() + rhs_concat_dim, 1);
        rhs_hlo = b->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(rhs_hlo->shape().element_type(),
                                 reshaped_dims),
            rhs_hlo));
      }
    }

    auto result_buffer = CreateZero(padded_result_buffer_shape, b);
    auto extra_buffer =
        (!(options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
         operands_sharded_at_contracting_dims)
            ? CreateZero(padded_result_buffer_shape, b)
        : windowed_op_is_lhs ? lhs_hlo
                             : rhs_hlo;

    if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0 &&
        !operands_sharded_at_contracting_dims) {
      std::vector<std::pair<int64, int64>> pre_sd_pairs(num_partitions);
      for (int64_t source = 0; source < num_partitions; ++source) {
        // 0 -> 1, 1 -> 2, 2 -> 3, ...
        pre_sd_pairs[source] = {source, (source + 1) % num_partitions};
      }
      extra_buffer =
          lhs.state()
              .collective_ops_creator.create_cross_partition_collective_permute(
                  b, extra_buffer, pre_sd_pairs,
                  (*lhs.state().next_channel_id)++);
    }

    auto iteration = b->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(0)));

    // Create a while loop that computes one window per iteration. During each
    // iteration, each partition sends its input window to its neighbor using
    // collective-permute for the next iteration.
    SpmdBuilder body_b("windowed_dot_general_body", original_hlo);

    // Generate the partial results used by the bidirectional algorithm.
    auto get_partial_bid_results =
        [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
            HloInstruction* extra_inout, HloInstruction* cw_cp_output,
            HloInstruction* i) -> StatusOr<std::vector<HloInstruction*>> {
      auto partition_id =
          lhs.state().collective_ops_creator.create_partition_id(&body_b);
      auto partition_count =
          body_b.AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::CreateR0<uint32>(num_partitions)));
      auto ccw_data_partition_id =
          body_b.AddInstruction(HloInstruction::CreateBinary(
              i->shape(), HloOpcode::kAdd, i, partition_id));
      auto cw_data_partition_id =
          body_b.AddInstruction(HloInstruction::CreateBinary(
              i->shape(), HloOpcode::kAdd, partition_count, partition_id));
      if (operands_sharded_at_contracting_dims) {
        ccw_data_partition_id =
            body_b.AddInstruction(HloInstruction::CreateBinary(
                i->shape(), HloOpcode::kAdd, ccw_data_partition_id,
                body_b.AddInstruction(HloInstruction::CreateConstant(
                    LiteralUtil::CreateR0<uint32>(num_partitions / 2 + 1)))));
        cw_data_partition_id =
            body_b.AddInstruction(HloInstruction::CreateBinary(
                i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
                body_b.AddInstruction(HloInstruction::CreateConstant(
                    LiteralUtil::CreateR0<uint32>(num_partitions / 2)))));
      } else {
        cw_data_partition_id =
            body_b.AddInstruction(HloInstruction::CreateBinary(
                i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
                CreateOne(cw_data_partition_id->shape(), &body_b)));
      }
      ccw_data_partition_id = body_b.AddInstruction(
          HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
                                       ccw_data_partition_id, partition_count));
      cw_data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
          i->shape(), HloOpcode::kSubtract, cw_data_partition_id, i));
      cw_data_partition_id = body_b.AddInstruction(
          HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
                                       cw_data_partition_id, partition_count));
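      // Illustration (assumed values): with num_partitions = 4 and operands
      // not sharded at contracting dims, partition p at iteration i reads
      // ccw_data_partition_id = (p + i) % 4 and
      // cw_data_partition_id = (p + 4 - 1 - i) % 4, i.e. the two data windows
      // travel around the ring in opposite directions.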

      DotDimensionNumbers new_ddnums;
      if (original_hlo->opcode() == HloOpcode::kDot) {
        new_ddnums = original_hlo->dot_dimension_numbers();
      }

      auto dot_lhs = l;
      auto dot_rhs = r;
      auto original_dot_lhs = l;
      auto original_dot_rhs = r;
      // Recover the original lhs and rhs; these will not be used in the real
      // computation.
      if (lhs_concat_dim != -1 && windowed_op_is_lhs) {
        std::vector<int64> reshaped_dims(
            original_dot_lhs->shape().dimensions().begin(),
            original_dot_lhs->shape().dimensions().end());
        reshaped_dims.erase(reshaped_dims.begin() + lhs_concat_dim);
        original_dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(original_dot_lhs->shape().element_type(),
                                 reshaped_dims),
            original_dot_lhs));
      }
      if (rhs_concat_dim != -1 && !windowed_op_is_lhs) {
        std::vector<int64> reshaped_dims(
            original_dot_rhs->shape().dimensions().begin(),
            original_dot_rhs->shape().dimensions().end());
        reshaped_dims.erase(reshaped_dims.begin() + rhs_concat_dim);
        original_dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(original_dot_rhs->shape().element_type(),
                                 reshaped_dims),
            original_dot_rhs));
      }

      if (windowed_at_contracting_dims || windowed_at_batch_dims ||
          operands_sharded_at_contracting_dims) {
        // Slice the matching operand according to the partitioned dimensions
        // on the windowed operand or the output.
        auto slice_operand = !windowed_op_is_lhs ? l : r;

        // Pad the sharding dim first (then the concat dim) for correctness.
        auto sharding_dim_size =
            slice_operand->shape().dimensions(slice_sharding_dim);
        if (sharding_dim_size % num_partitions != 0) {
          slice_operand = PadBaseShapeBeforeUnevenTiledSharding(
              slice_operand, *slice_sharding, &body_b);
        }

        // We do this by treating the matching operand as replicated, and
        // resharding it to match the windowed operand or the output.
        auto gen_slice = [&](HloInstruction* data_partition_id,
                             bool ccw) -> HloInstruction* {
          std::vector<int64> new_dims;
          for (int64_t i = 0; i < slice_operand->shape().dimensions_size();
               ++i) {
            if (i == slice_sharding_dim) {
              new_dims.push_back(1);
            }
            new_dims.push_back(slice_operand->shape().dimensions(i));
          }
          auto reshaped_slice_operand =
              body_b.AddInstruction(HloInstruction::CreateReshape(
                  ShapeUtil::MakeShape(slice_operand->shape().element_type(),
                                       new_dims),
                  slice_operand));
          auto min = body_b.AddInstruction(
              HloInstruction::CreateConstant(LiteralUtil::MinValue(
                  reshaped_slice_operand->shape().element_type())));
          std::vector<int64> min_padding(
              reshaped_slice_operand->shape().rank());
          auto padded_slice_operand = reshaped_slice_operand;
          auto padded_shape = padded_slice_operand->shape();
          int64_t padding_dim = slice_sharding_dim;
          padded_shape.set_dimensions(padding_dim, 2);
          if (ccw) {
            // ccw pad high
            PaddingConfig ccw_pad_config =
                window_util::MakeSymmetricPadding(min_padding);
            ccw_pad_config.mutable_dimensions(padding_dim)
                ->set_edge_padding_low(0);
            ccw_pad_config.mutable_dimensions(padding_dim)
                ->set_edge_padding_high(1);
            padded_slice_operand =
                body_b.AddInstruction(HloInstruction::CreatePad(
                    padded_shape, padded_slice_operand, min, ccw_pad_config));
          } else {
            // cw pad low
            PaddingConfig cw_pad_config =
                window_util::MakeSymmetricPadding(min_padding);
            cw_pad_config.mutable_dimensions(padding_dim)
                ->set_edge_padding_low(1);
            cw_pad_config.mutable_dimensions(padding_dim)
                ->set_edge_padding_high(0);
            padded_slice_operand =
                body_b.AddInstruction(HloInstruction::CreatePad(
                    padded_shape, padded_slice_operand, min, cw_pad_config));
          }

          padded_slice_operand->set_sharding(HloSharding::Replicate());
          auto state = lhs.state();
          state.b = &body_b;
          state.partition_id = data_partition_id;
          state.reshard_cache->per_hlo_cache.erase(padded_slice_operand);
          auto padded_slice_sharding = hlo_sharding_util::ReshapeSharding(
              slice_operand->shape(), reshaped_slice_operand->shape(),
              *slice_sharding);
          auto padded_slice =
              PartitionedHlo(padded_slice_operand,
                             padded_slice_operand->shape(), state)
                  .Reshard(*padded_slice_sharding)
                  .hlo();
          padded_slice_operand->clear_sharding();
          return padded_slice;
        };

        auto ccw_slice = gen_slice(ccw_data_partition_id, true);
        auto cw_slice = gen_slice(cw_data_partition_id, false);
        auto slice = body_b.AddInstruction(HloInstruction::CreateBinary(
            ccw_slice->shape(), HloOpcode::kMaximum, ccw_slice, cw_slice));
        // Reshape. The reshaped slice will not be used to produce the final
        // result, but serves as a hint for shape inference.
        std::vector<int64> reshaped_slice_dims;
        for (int64_t i = 0; i < slice->shape().dimensions_size(); ++i) {
          auto dim_size = slice->shape().dimensions(i);
          if (i == (slice_sharding_dim + 1)) {
            reshaped_slice_dims.push_back(dim_size * 2);
          } else if (i != slice_sharding_dim) {
            reshaped_slice_dims.push_back(dim_size);
          }
        }
        auto reshaped_slice =
            body_b.AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::MakeShape(slice->shape().element_type(),
                                     reshaped_slice_dims),
                slice));

        if (!windowed_op_is_lhs) {
          dot_lhs = slice;
          original_dot_lhs = reshaped_slice;
          if (original_hlo->opcode() == HloOpcode::kDot) {
            UpdateDDNums(&new_ddnums, slice_sharding_dim, true);
          }
        } else {
          dot_rhs = slice;
          original_dot_rhs = reshaped_slice;
          if (original_hlo->opcode() == HloOpcode::kDot) {
            UpdateDDNums(&new_ddnums, slice_sharding_dim, false);
          }
        }
      }

      auto ccw_dot_lhs = l;
      auto ccw_dot_rhs = r;
      auto cw_dot_lhs = windowed_op_is_lhs ? extra_inout : l;
      auto cw_dot_rhs = windowed_op_is_lhs ? r : extra_inout;
      if (lhs_concat_dim != -1 && windowed_op_is_lhs) {
        // Concat
        auto lhs_concat_shape = ccw_dot_lhs->shape();
        lhs_concat_shape.set_dimensions(lhs_concat_dim, 2);
        dot_lhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
            lhs_concat_shape, {ccw_dot_lhs, cw_dot_lhs}, lhs_concat_dim));

        std::vector<int64> reshaped_dims(
            ccw_dot_lhs->shape().dimensions().begin(),
            ccw_dot_lhs->shape().dimensions().end());
        reshaped_dims.erase(reshaped_dims.begin() + lhs_concat_dim);
        reshaped_dims[lhs_concat_dim] *= 2;
        original_dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(dot_lhs->shape().element_type(),
                                 reshaped_dims),
            dot_lhs));

        if (original_hlo->opcode() == HloOpcode::kDot) {
          UpdateDDNums(&new_ddnums, lhs_concat_dim, true);
        }
      }
      if (rhs_concat_dim != -1 && !windowed_op_is_lhs) {
        // Concat
        auto rhs_concat_shape = ccw_dot_rhs->shape();
        rhs_concat_shape.set_dimensions(rhs_concat_dim, 2);
        dot_rhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
            rhs_concat_shape, {ccw_dot_rhs, cw_dot_rhs}, rhs_concat_dim));

        std::vector<int64> reshaped_dims(
            ccw_dot_rhs->shape().dimensions().begin(),
            ccw_dot_rhs->shape().dimensions().end());
        reshaped_dims.erase(reshaped_dims.begin() + rhs_concat_dim);
        reshaped_dims[rhs_concat_dim] *= 2;
        original_dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(dot_rhs->shape().element_type(),
                                 reshaped_dims),
            dot_rhs));

        if (original_hlo->opcode() == HloOpcode::kDot) {
          UpdateDDNums(&new_ddnums, rhs_concat_dim, false);
        }
      }

      // The generated original dot will not be used.
      TF_ASSIGN_OR_RETURN(auto original_dot,
                          create_sharded_dot(original_dot_lhs, original_dot_rhs,
                                             &body_b, conv_window));
      VLOG(2) << original_dot->ToString();

      // Generate the correct shape of the new dot/conv.
      auto original_sharded_dot_shape = original_dot->shape();
      auto new_dot_shape = original_sharded_dot_shape;
      std::vector<int64> new_dims(new_dot_shape.dimensions().begin(),
                                  new_dot_shape.dimensions().end());
      if (!windowed_at_contracting_dims) {
        auto slice_dim =
            lhs_concat_dim != -1
                ? indices_map.lhs_to_output_indices[lhs_concat_dim]
                : indices_map.rhs_to_output_indices[rhs_concat_dim];
        new_dims[slice_dim] /= 2;
        new_dims.insert(new_dims.begin() + slice_dim, 2);
      } else if (original_hlo->opcode() != HloOpcode::kDot) {
        new_dims.push_back(1);
      }
      new_dot_shape =
          ShapeUtil::MakeShape(original_hlo->shape().element_type(), new_dims);

      HloInstruction* dot;
      if (original_hlo->opcode() == HloOpcode::kDot) {
        dot = body_b.AddInstruction(HloInstruction::CreateDot(
            new_dot_shape, dot_lhs, dot_rhs, new_ddnums,
            original_hlo->precision_config()));
      } else {
        if (!windowed_at_contracting_dims && !windowed_at_batch_dims) {
          if (lhs_concat_dim != -1) {
            std::vector<int64> new_dims(dot_rhs->shape().dimensions().begin(),
                                        dot_rhs->shape().dimensions().end());
            new_dims.push_back(1);
            dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::MakeShape(dot_rhs->shape().element_type(), new_dims),
                dot_rhs));
          }
          if (rhs_concat_dim != -1) {
            std::vector<int64> new_dims(dot_lhs->shape().dimensions().begin(),
                                        dot_lhs->shape().dimensions().end());
            new_dims.push_back(1);
            dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::MakeShape(dot_lhs->shape().element_type(), new_dims),
                dot_lhs));
          }
        }

        dot = body_b.AddInstruction(HloInstruction::CreateConvolve(
            new_dot_shape, dot_lhs, dot_rhs,
            original_dot->feature_group_count(),
            original_dot->batch_group_count(),
            GenNewWindow(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
                         rhs_concat_dim, windowed_at_contracting_dims,
                         windowed_at_batch_dims),
            GenNewConvDNums(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
                            rhs_concat_dim, windowed_at_contracting_dims,
                            windowed_at_batch_dims,
                            indices_map.lhs_to_output_indices,
                            indices_map.rhs_to_output_indices, new_dot_shape),
            original_dot->precision_config()));
      }
      VLOG(2) << dot->ToString();

      if (windowed_at_contracting_dims) {
        if (original_hlo->opcode() != HloOpcode::kDot) {
          // Reshape to the original sharded dot shape.
          dot = body_b.AddInstruction(
              HloInstruction::CreateReshape(original_sharded_dot_shape, dot));
        }

        // Accumulate the partial output to the result buffer.
        o = body_b.AddInstruction(
            HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
      } else {
        // The windowing operand is partitioned along batch/non-contracting
        // dimensions, so we need a dynamic-update-slice to save the partial
        // output in the result buffer.
        auto slice_shape = dot->shape();
        auto slice_dim =
            lhs_concat_dim != -1
                ? indices_map.lhs_to_output_indices[lhs_concat_dim]
                : indices_map.rhs_to_output_indices[rhs_concat_dim];
        slice_shape.set_dimensions(slice_dim, 1);
        std::vector<int64> ccw_start_indices(dot->shape().rank(), 0);
        std::vector<int64> cw_start_indices(dot->shape().rank(), 0);
        cw_start_indices[slice_dim] = 1;
        auto ccw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
            slice_shape, dot, ccw_start_indices, slice_shape.dimensions(),
            std::vector<int64>(dot->shape().rank(), 1)));
        auto cw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
            slice_shape, dot, cw_start_indices, dot->shape().dimensions(),
            std::vector<int64>(dot->shape().rank(), 1)));

        std::vector<int64> reshaped_dims(
            original_sharded_dot_shape.dimensions().begin(),
            original_sharded_dot_shape.dimensions().end());
        reshaped_dims[slice_dim] /= 2;
        ccw_dot = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(ccw_dot->shape().element_type(),
                                 reshaped_dims),
            ccw_dot));
        cw_dot = body_b.AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(cw_dot->shape().element_type(), reshaped_dims),
            cw_dot));

        if (operands_sharded_at_contracting_dims) {
          // Accumulate the partial output to the result buffer.
          o = body_b.AddInstruction(HloInstruction::CreateBinary(
              o->shape(), HloOpcode::kAdd, o, ccw_dot));
          cw_cp_output = body_b.AddInstruction(HloInstruction::CreateBinary(
              o->shape(), HloOpcode::kAdd, cw_cp_output, cw_dot));
        } else {
          auto ccw_offsets = MakePartitionOffsets(
              o->shape(),
              windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
                                 : *rhs_sharding_transposed_to_match_output,
              ccw_data_partition_id, &body_b);
          auto cw_offsets = MakePartitionOffsets(
              o->shape(),
              windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
                                 : *rhs_sharding_transposed_to_match_output,
              cw_data_partition_id, &body_b);
          o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
              o->shape(), o, ccw_dot, ccw_offsets));
          o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
              o->shape(), o, cw_dot, cw_offsets));
        }
      }

      std::vector<HloInstruction*> partial_results;
      partial_results.push_back(o);
      partial_results.push_back(cw_cp_output);
      return partial_results;
    };

1153     // Generate the partial result used by the unidirectional algorithm.
1154     auto get_partial_unid_result =
1155         [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
1156             HloInstruction* i) -> StatusOr<HloInstruction*> {
1157       auto partition_id =
1158           lhs.state().collective_ops_creator.create_partition_id(&body_b);
1159       auto data_partition_id =
1160           body_b.AddInstruction(HloInstruction::CreateBinary(
1161               i->shape(), HloOpcode::kAdd, i, partition_id));
1162       auto partition_count =
1163           body_b.AddInstruction(HloInstruction::CreateConstant(
1164               LiteralUtil::CreateR0<uint32>(num_partitions)));
1165       data_partition_id = body_b.AddInstruction(
1166           HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
1167                                        data_partition_id, partition_count));
1168       auto dot_lhs = l;
1169       auto dot_rhs = r;
1170       if (windowed_at_contracting_dims || windowed_at_batch_dims ||
1171           operands_sharded_at_contracting_dims) {
1172         // Slice the matching operand according to the partitioned dimensions on
1173         // the windowed operand or the output.
1174         auto slice_operand = !windowed_op_is_lhs ? l : r;
1175         // We do this by treating the matching operand as replicated, and
1176         // resharding it to match the windowed operand or the output.
1177         slice_operand->set_sharding(HloSharding::Replicate());
1178         auto state = lhs.state();
1179         state.b = &body_b;
1180         state.partition_id = data_partition_id;
1181         state.reshard_cache->per_hlo_cache.erase(slice_operand);
1182         auto slice =
1183             PartitionedHlo(slice_operand, slice_operand->shape(), state)
1184                 .Reshard(*slice_sharding)
1185                 .hlo();
1186         slice_operand->clear_sharding();
1187         if (!windowed_op_is_lhs) {
1188           dot_lhs = slice;
1189         } else {
1190           dot_rhs = slice;
1191         }
1192       }
1193       TF_ASSIGN_OR_RETURN(
1194           auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window));
1195       if (windowed_at_contracting_dims ||
1196           operands_sharded_at_contracting_dims) {
1197         // Accumulate the partial output to the result buffer.
1198         o = body_b.AddInstruction(
1199             HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
1200       } else {
1201         // The windowing operand is partitioned along batch/non-contracting
1202         // dimensions, so we need a dynamic-update-slice to save the partial
1203         // output in the result buffer.
1204         auto offsets = MakePartitionOffsets(
1205             o->shape(),
1206             windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1207                                : *rhs_sharding_transposed_to_match_output,
1208             data_partition_id, &body_b);
1209         o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1210             o->shape(), o, dot, offsets));
1211       }
1212       return o;
1213     };
1214 
1215     auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
1216         /*parameter_number=*/0,
1217         ShapeUtil::MakeTupleShape({lhs_hlo->shape(), rhs_hlo->shape(),
1218                                    result_buffer->shape(),
1219                                    extra_buffer->shape(), iteration->shape()}),
1220         "param"));
1221     auto l = body_b.AddInstruction(
1222         HloInstruction::CreateGetTupleElement(lhs_hlo->shape(), param, 0));
1223     auto r = body_b.AddInstruction(
1224         HloInstruction::CreateGetTupleElement(rhs_hlo->shape(), param, 1));
1225     auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1226         result_buffer->shape(), param, 2));
1227     auto extra_inout = body_b.AddInstruction(
1228         HloInstruction::CreateGetTupleElement(extra_buffer->shape(), param, 3));
1229     auto i = body_b.AddInstruction(
1230         HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
1231 
1232     // The bidirectional collective permute implementation has loop unrolling
1233     // of degree 2, so num_partitions is required to be a multiple of 4.
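    // For example, with num_partitions = 8 and the iteration counter starting
    // at 0: the loop condition built further below checks
    // i < num_partitions / 2 = 4, each body execution advances i twice, so the
    // body runs twice, and each run performs two bidirectional steps yielding
    // four partial results, which together cover all 8 shards.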
1234     if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1235       std::vector<std::pair<int64, int64>> ccw_sd_pairs(num_partitions);
1236       for (int64_t source = 0; source < num_partitions; ++source) {
1237         // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1238         ccw_sd_pairs[source] = {source,
1239                                 (source - 1 + num_partitions) % num_partitions};
1240       }
1241       std::vector<std::pair<int64, int64>> cw_sd_pairs(num_partitions);
1242       for (int64_t source = 0; source < num_partitions; ++source) {
1243         // 0 -> 1, 1 -> 2, 2 -> 3, ...
1244         cw_sd_pairs[source] = {source, (source + 1) % num_partitions};
1245       }
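      // For example, with num_partitions = 4, ccw_sd_pairs is
      // {0,3}, {1,0}, {2,1}, {3,2} and cw_sd_pairs is
      // {0,1}, {1,2}, {2,3}, {3,0}.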
1246 
1247       // Even-numbered iteration.
1248       auto next_l = l;
1249       auto next_r = r;
1250       auto ccw_cp_input = operands_sharded_at_contracting_dims ? o
1251                           : windowed_op_is_lhs                 ? l
1252                                                                : r;
1253       auto ccw_cp_output =
1254           lhs.state()
1255               .collective_ops_creator.create_cross_partition_collective_permute(
1256                   &body_b, ccw_cp_input, ccw_sd_pairs,
1257                   (*lhs.state().next_channel_id)++);
1258       if (operands_sharded_at_contracting_dims) {
1259         o = ccw_cp_output;
1260       } else if (windowed_op_is_lhs) {
1261         next_l = ccw_cp_output;
1262       } else {
1263         next_r = ccw_cp_output;
1264       }
1265       auto cw_cp_input = extra_inout;
1266       auto cw_cp_output =
1267           lhs.state()
1268               .collective_ops_creator.create_cross_partition_collective_permute(
1269                   &body_b, cw_cp_input, cw_sd_pairs,
1270                   (*lhs.state().next_channel_id)++);
1271 
1272       TF_ASSIGN_OR_RETURN(
1273           auto outputs,
1274           get_partial_bid_results(l, r, o, extra_inout, cw_cp_output, i));
1275       o = outputs[0];
1276       cw_cp_output = outputs[1];
1277 
1278       // ++i
1279       i = body_b.AddInstruction(HloInstruction::CreateBinary(
1280           i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1281 
1282       // Odd-numbered iteration.
1283       auto second_next_l = next_l;
1284       auto second_next_r = next_r;
1285       ccw_cp_input = operands_sharded_at_contracting_dims ? o
1286                      : windowed_op_is_lhs                 ? next_l
1287                                                           : next_r;
1288       ccw_cp_output =
1289           lhs.state()
1290               .collective_ops_creator.create_cross_partition_collective_permute(
1291                   &body_b, ccw_cp_input, ccw_sd_pairs,
1292                   (*lhs.state().next_channel_id)++);
1293       if (operands_sharded_at_contracting_dims) {
1294         o = ccw_cp_output;
1295       } else if (windowed_op_is_lhs) {
1296         second_next_l = ccw_cp_output;
1297       } else {
1298         second_next_r = ccw_cp_output;
1299       }
1300       auto next_cw_cp_input = cw_cp_output;
1301       auto next_cw_cp_output =
1302           lhs.state()
1303               .collective_ops_creator.create_cross_partition_collective_permute(
1304                   &body_b, next_cw_cp_input, cw_sd_pairs,
1305                   (*lhs.state().next_channel_id)++);
1306 
1307       TF_ASSIGN_OR_RETURN(
1308           outputs, get_partial_bid_results(next_l, next_r, o, cw_cp_output,
1309                                            next_cw_cp_output, i));
1310       o = outputs[0];
1311       next_cw_cp_output = outputs[1];
1312 
1313       // ++i
1314       i = body_b.AddInstruction(HloInstruction::CreateBinary(
1315           i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1316 
1317       body_b.AddInstruction(HloInstruction::CreateTuple(
1318           {second_next_l, second_next_r, o, next_cw_cp_output, i}));
1319 
1320     } else if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1321       if (operands_sharded_at_contracting_dims) {
1322         std::vector<std::pair<int64, int64>> output_sd_pairs(num_partitions);
1323         for (int64_t source = 0; source < num_partitions; ++source) {
1324           // 0 -> n-2, 1 -> n-1, 2 -> 0, ...
1325           output_sd_pairs[source] = {
1326               source, (source - 2 + num_partitions) % num_partitions};
1327         }
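        // For example, with num_partitions = 4, output_sd_pairs is
        // {0,2}, {1,3}, {2,0}, {3,1}: a shift by two, matching the two
        // partial results each unrolled iteration produces.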
1328 
1329         o = lhs.state()
1330                 .collective_ops_creator
1331                 .create_cross_partition_collective_permute(
1332                     &body_b, o, output_sd_pairs,
1333                     (*lhs.state().next_channel_id)++);
1334 
1335         TF_ASSIGN_OR_RETURN(extra_inout,
1336                             get_partial_unid_result(l, r, extra_inout, i));
1337 
1338         extra_inout = lhs.state()
1339                           .collective_ops_creator
1340                           .create_cross_partition_collective_permute(
1341                               &body_b, extra_inout, output_sd_pairs,
1342                               (*lhs.state().next_channel_id)++);
1343 
1344         // i+2
1345         i = body_b.AddInstruction(HloInstruction::CreateBinary(
1346             i->shape(), HloOpcode::kAdd, i,
1347             body_b.AddInstruction(HloInstruction::CreateConstant(
1348                 LiteralUtil::CreateR0<uint32>(2)))));
1349         auto real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1350             i->shape(), HloOpcode::kAdd, i,
1351             body_b.AddInstruction(HloInstruction::CreateConstant(
1352                 LiteralUtil::CreateR0<uint32>(1)))));
1353 
1354         TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1355         body_b.AddInstruction(
1356             HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1357       } else {
1358         std::vector<std::pair<int64, int64>> sd_pairs(num_partitions);
1359         for (int64_t source = 0; source < num_partitions; ++source) {
1360           // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1361           sd_pairs[source] = {source,
1362                               (source - 1 + num_partitions) % num_partitions};
1363         }
1364 
1365         // Even-numbered iteration.
1366         auto next_l = l;
1367         auto next_r = r;
1368         auto cp_input = windowed_op_is_lhs ? l : r;
1369         auto cp_output = lhs.state()
1370                              .collective_ops_creator
1371                              .create_cross_partition_collective_permute(
1372                                  &body_b, cp_input, sd_pairs,
1373                                  (*lhs.state().next_channel_id)++);
1374         if (windowed_op_is_lhs) {
1375           next_l = cp_output;
1376         } else {
1377           next_r = cp_output;
1378         }
1379         TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, i));
1380 
1381         // ++i
1382         i = body_b.AddInstruction(HloInstruction::CreateBinary(
1383             i->shape(), HloOpcode::kAdd, i,
1384             body_b.AddInstruction(HloInstruction::CreateConstant(
1385                 LiteralUtil::CreateR0<uint32>(1)))));
1386 
1387         // Odd-numbered iteration.
1388         auto second_next_l = next_l;
1389         auto second_next_r = next_r;
1390         cp_input = windowed_op_is_lhs ? next_l : next_r;
1391         cp_output = lhs.state()
1392                         .collective_ops_creator
1393                         .create_cross_partition_collective_permute(
1394                             &body_b, cp_input, sd_pairs,
1395                             (*lhs.state().next_channel_id)++);
1396         if (windowed_op_is_lhs) {
1397           second_next_l = cp_output;
1398         } else {
1399           second_next_r = cp_output;
1400         }
1401         TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(next_l, next_r, o, i));
1402 
1403         // ++i
1404         i = body_b.AddInstruction(HloInstruction::CreateBinary(
1405             i->shape(), HloOpcode::kAdd, i,
1406             body_b.AddInstruction(HloInstruction::CreateConstant(
1407                 LiteralUtil::CreateR0<uint32>(1)))));
1408 
1409         body_b.AddInstruction(HloInstruction::CreateTuple(
1410             {second_next_l, second_next_r, o, extra_inout, i}));
1411       }
1412     } else {
1413       auto real_i = i;
1414       if (operands_sharded_at_contracting_dims) {
1415         // For the reduce-scatter case, start from data_partition_id + 1 so
1416         // that the data_partition_id of the final data shard in each
1417         // partition matches the corresponding partition_id.
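        // For example, with num_partitions = 4, partition p computes the shard
        // for data_partition_id = (i + 1 + p) mod 4 at step i, so after the
        // final step (i = 3) the shard left on partition p is shard p itself.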
1418         real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1419             real_i->shape(), HloOpcode::kAdd, real_i,
1420             CreateOne(real_i->shape(), &body_b)));
1421       }
1422       TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1423 
1424       // ++i
1425       i = body_b.AddInstruction(HloInstruction::CreateBinary(
1426           i->shape(), HloOpcode::kAdd, i,
1427           body_b.AddInstruction(HloInstruction::CreateConstant(
1428               LiteralUtil::CreateR0<uint32>(1)))));
1429       auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare(
1430           ShapeUtil::MakeShape(PRED, {}), i,
1431           body_b.AddInstruction(HloInstruction::CreateConstant(
1432               LiteralUtil::CreateR0<uint32>(num_partitions))),
1433           ComparisonDirection::kLt));
1434       // Collective-permute for the next window. We don't need it for the last
1435       // iteration, so we use a conditional around the collective-permute.
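      // The true computation (cp_b below) permutes the window to the
      // neighboring partition, while the false computation (ncp_b) simply
      // forwards its parameter, so the last trip performs no communication.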
1436       HloInstruction* conditional;
1437       {
1438         SpmdBuilder cp_b("window_collective_permute", original_hlo);
1439         {
1440           auto p = cp_b.AddInstruction(HloInstruction::CreateParameter(
1441               0,
1442               operands_sharded_at_contracting_dims ? o->shape()
1443               : windowed_op_is_lhs                 ? l->shape()
1444                                                    : r->shape(),
1445               "window"));
1446           std::vector<std::pair<int64, int64>> sd_pairs(num_partitions);
1447           for (int64_t source = 0; source < num_partitions; ++source) {
1448             // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1449             sd_pairs[source] = {source,
1450                                 (source - 1 + num_partitions) % num_partitions};
1451           }
1452           lhs.state()
1453               .collective_ops_creator.create_cross_partition_collective_permute(
1454                   &cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++);
1455         }
1456         SpmdBuilder ncp_b("last_iteration_noop", original_hlo);
1457         {
1458           ncp_b.AddInstruction(HloInstruction::CreateParameter(
1459               0,
1460               operands_sharded_at_contracting_dims ? o->shape()
1461               : windowed_op_is_lhs                 ? l->shape()
1462                                                    : r->shape(),
1463               "window"));
1464         }
1465         conditional = body_b.AddInstruction(HloInstruction::CreateConditional(
1466             operands_sharded_at_contracting_dims ? o->shape()
1467             : windowed_op_is_lhs                 ? l->shape()
1468                                                  : r->shape(),
1469             has_more,
1470             operands_sharded_at_contracting_dims ? o
1471             : windowed_op_is_lhs                 ? l
1472                                                  : r,
1473             module->AddEmbeddedComputation(cp_b.Build()),
1474             operands_sharded_at_contracting_dims ? o
1475             : windowed_op_is_lhs                 ? l
1476                                                  : r,
1477             module->AddEmbeddedComputation(ncp_b.Build())));
1478       }
1479       if (operands_sharded_at_contracting_dims) {
1480         o = conditional;
1481       } else if (windowed_op_is_lhs) {
1482         l = conditional;
1483       } else {
1484         r = conditional;
1485       }
1486       body_b.AddInstruction(
1487           HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1488     }
1489 
1490     SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo);
1491     auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
1492         /*parameter_number=*/0,
1493         ShapeUtil::MakeTupleShape({lhs_hlo->shape(), rhs_hlo->shape(),
1494                                    result_buffer->shape(),
1495                                    extra_buffer->shape(), iteration->shape()}),
1496         "param"));
1497     auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1498         iteration->shape(), cond_param, 4));
1499     int64_t adapted_num_partitions =
1500         (options.bidirectional_windowed_einsum && num_partitions % 4 == 0)
1501             ? num_partitions / 2
1502             : num_partitions;
1503     cond_b.AddInstruction(HloInstruction::CreateCompare(
1504         ShapeUtil::MakeShape(PRED, {}), cond_i,
1505         cond_b.AddInstruction(HloInstruction::CreateConstant(
1506             LiteralUtil::CreateR0<uint32>(adapted_num_partitions))),
1507         ComparisonDirection::kLt));
1508     auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
1509         cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
1510         module->AddEmbeddedComputation(body_b.Build()),
1511         b->AddInstruction(HloInstruction::CreateTuple(
1512             {lhs_hlo, rhs_hlo, result_buffer, extra_buffer, iteration}))));
1513     windowed_dot_general_loops->push_back(
1514         {while_loop, windowed_op_is_lhs ? 0 : 1, windowed_at_contracting_dims,
1515          windowed_at_batch_dims, operands_sharded_at_contracting_dims,
1516          num_partitions, GetLoopReplicaGroups(while_loop)});
1517     auto result = b->AddInstruction(HloInstruction::CreateGetTupleElement(
1518         result_buffer->shape(), while_loop, 2));
1519     if (((options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
1520          (options.unroll_windowed_einsum && num_partitions % 2 == 0)) &&
1521         operands_sharded_at_contracting_dims) {
1522       std::vector<std::pair<int64, int64>> extra_sd_pairs(num_partitions);
1523       for (int64_t source = 0; source < num_partitions; ++source) {
1524         // 0 -> 1, 1 -> 2, 2 -> 3, ...
1525         extra_sd_pairs[source] = {source, (source + 1) % num_partitions};
1526       }
1527       auto extra_result =
1528           b->AddInstruction(HloInstruction::CreateGetTupleElement(
1529               extra_buffer->shape(), while_loop, 3));
1530       if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1531         extra_result = lhs.state()
1532                            .collective_ops_creator
1533                            .create_cross_partition_collective_permute(
1534                                b, extra_result, extra_sd_pairs,
1535                                (*lhs.state().next_channel_id)++);
1536       }
1537       if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1538         result = lhs.state()
1539                      .collective_ops_creator
1540                      .create_cross_partition_collective_permute(
1541                          b, result, extra_sd_pairs,
1542                          (*lhs.state().next_channel_id)++);
1543       }
1544       result = b->AddInstruction(HloInstruction::CreateBinary(
1545           result->shape(), HloOpcode::kAdd, result, extra_result));
1546     }
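    // If the result buffer was padded (e.g. so that the windowed dimension
    // divides evenly across partitions), slice the result back down: a buffer
    // padded to [128, 256] would be cut to an unpadded [100, 256] here.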
1547     if (!ShapeUtil::Compatible(padded_result_buffer_shape,
1548                                unpadded_result_buffer_shape)) {
1549       result = b->AddInstruction(HloInstruction::CreateSlice(
1550           unpadded_result_buffer_shape, result,
1551           std::vector<int64>(padded_result_buffer_shape.rank(), 0),
1552           unpadded_result_buffer_shape.dimensions(),
1553           std::vector<int64>(padded_result_buffer_shape.rank(), 1)));
1554     }
1555     return result;
1556   };
1557   absl::optional<WindowedEinsumConfig> e_config =
1558       GetWindowedEinsumConfiguration(
1559           num_partitions, output_lhs_non_contracting_partitions,
1560           output_rhs_non_contracting_partitions, rhs_contracting_partitions,
1561           rhs_non_contracting_partitions, rhs_batch_partitions,
1562           lhs_contracting_partitions, lhs_non_contracting_partitions,
1563           lhs_batch_partitions, ShapeSizeInBytes(rhs.base_shape()),
1564           ShapeSizeInBytes(lhs.base_shape()),
1565           ShapeSizeInBytes(output_base_shape),
1566           options.threshold_for_windowed_einsum_mib,
1567           output_sharding_transposed_to_match_lhs,
1568           output_sharding_transposed_to_match_rhs, lhs_sharding, rhs_sharding);
1569   if (e_config) {
1570     return emit_windowed_dot_general(*e_config);
1571   }
1572 
1573   {
1574     // Try batch-parallel by resharding one operand, and allowing all-reduce.
1575     TF_ASSIGN_OR_RETURN(
1576         HloInstruction * partitioned_dot,
1577         try_emit_output_batch_partitioned_einsum_with_reshard(true));
1578     if (partitioned_dot) {
1579       return partitioned_dot;
1580     }
1581   }
1582 
1583   // LHS and RHS have the same partitioned contracting dimensions.
1584   if (lhs_contracting_partitions == rhs_contracting_partitions &&
1585       lhs_contracting_partitions == num_partitions) {
1586     auto zero = b->AddInstruction(HloInstruction::CreateConstant(
1587         LiteralUtil::Zero(output_base_shape.element_type())));
1588     // Pad both sides with zero, since NaN at one side cannot be masked by zero
1589     // on the other side.
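    // (In IEEE arithmetic 0 * NaN = NaN, so zeros on only one operand would
    // still let NaN garbage from the other operand's padded region poison the
    // contraction; with zeros on both sides the padding contributes 0 * 0.)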
1590     if (ShapeSizeInBytes(lhs.base_shape()) <
1591         ShapeSizeInBytes(rhs.base_shape())) {
1592       lhs =
1593           lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
1594       rhs = rhs.PadWithValue(zero);
1595     } else {
1596       lhs = lhs.PadWithValue(zero);
1597       rhs =
1598           rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
1599     }
1600     TF_ASSIGN_OR_RETURN(
1601         auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1602     std::vector<int64> lhs_contracting_dims;
1603     lhs_contracting_dims.reserve(lhs.base_shape().rank());
1604     for (const auto& cd : dims_mapping.contracting_dims) {
1605       lhs_contracting_dims.push_back(cd.lhs);
1606     }
1607     auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
1608         b, dot, lhs.sharding(), lhs.state().next_channel_id,
1609         lhs_contracting_dims, lhs.state().collective_ops_creator,
1610         MakeBinaryAdd(output_base_shape.element_type(), module));
1611     ar->set_sharding(HloSharding::Replicate());
1612     return PartitionedHlo(ar, output_base_shape, lhs.state())
1613         .Reshard(output_sharding)
1614         .hlo();
1615   }
1616 
1617   // LHS and output have the same partitioned non-contracting dimensions.
1618   if (lhs_non_contracting_partitions == num_partitions &&
1619       output_lhs_non_contracting_partitions == num_partitions &&
1620       lhs_sharding_transposed_to_match_output == output_sharding) {
1621     auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
1622     TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated,
1623                                                      b, conv_window));
1624     return dot;
1625   }
1626 
1627   // RHS and output have the same partitioned non-contracting dimensions.
1628   if (rhs_non_contracting_partitions == num_partitions &&
1629       output_rhs_non_contracting_partitions == num_partitions &&
1630       rhs_sharding_transposed_to_match_output == output_sharding) {
1631     auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
1632     TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(),
1633                                                      b, conv_window));
1634     return dot;
1635   }
1636 
1637   if (may_reshard_without_detecting_match) {
1638     // Output is batch partitioned.
1639     if (output_batch_partitions == num_partitions) {
1640       auto resharded_lhs =
1641           lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1642       auto resharded_rhs =
1643           rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1644       TF_ASSIGN_OR_RETURN(
1645           auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(),
1646                                        b, conv_window));
1647       return dot;
1648     }
1649     // Output is partitioned along LHS non-contracting dimensions.
1650     if (output_lhs_non_contracting_partitions == num_partitions) {
1651       auto resharded_lhs =
1652           lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1653       auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
1654       TF_ASSIGN_OR_RETURN(
1655           auto dot, create_sharded_dot(resharded_lhs.hlo(),
1656                                        replicated_rhs.hlo(), b, conv_window));
1657       return dot;
1658     }
1659     // Output is partitioned along RHS non-contracting dimensions.
1660     if (output_rhs_non_contracting_partitions == num_partitions) {
1661       auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
1662       auto resharded_rhs =
1663           rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1664       TF_ASSIGN_OR_RETURN(
1665           auto dot, create_sharded_dot(replicated_lhs.hlo(),
1666                                        resharded_rhs.hlo(), b, conv_window));
1667       return dot;
1668     }
1669   }
1670 
1671   // Returns true if it is beneficial to reshard the operand at `operand_idx`
1672   // across the contracting dimension.
1673   const auto should_partition_contracting_dim = [&](int64_t operand_idx) {
1674     if (!output_sharding.IsReplicated()) {
1675       return false;
1676     }
1677 
1678     if (operand_idx == 0) {
1679       // If LHS and output are replicated, we compare the cost of all-gather
1680       // on RHS vs all-reduce on the output.
1681       return (rhs_contracting_partitions == num_partitions) &&
1682              lhs.sharding().IsReplicated() &&
1683              ShapeUtil::ElementsIn(rhs.base_shape()) >
1684                  ShapeUtil::ElementsIn(output_base_shape);
1685     } else {
1686       return (lhs_contracting_partitions == num_partitions) &&
1687              rhs.sharding().IsReplicated() &&
1688              ShapeUtil::ElementsIn(lhs.base_shape()) >
1689                  ShapeUtil::ElementsIn(output_base_shape);
1690     }
1691   };
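  // A worked cost comparison, assuming operand_idx == 0: for a
  // [1024, 65536] x [65536, 1024] dot on 8 partitions with LHS and output
  // replicated and RHS split along the contracting dimension, all-gathering
  // RHS would materialize ~64M elements, while all-reducing the [1024, 1024]
  // output touches only ~1M, so resharding LHS along the contracting
  // dimension is the cheaper choice.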
1692 
1693   // When the output is replicated and one of the operands is partitioned
1694   // along a contracting dimension, align the other operand to be partitioned
1695   // along the same contracting dimensions.
1696   if (output_sharding.IsReplicated() && (should_partition_contracting_dim(0) ||
1697                                          should_partition_contracting_dim(1))) {
1698     auto zero = b->AddInstruction(HloInstruction::CreateConstant(
1699         LiteralUtil::Zero(output_base_shape.element_type())));
1700     if (should_partition_contracting_dim(0)) {
1701       lhs =
1702           lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero);
1703       rhs = rhs.PadWithValue(zero);
1704     } else {
1705       lhs = lhs.PadWithValue(zero);
1706       rhs =
1707           rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero);
1708     }
1709     TF_ASSIGN_OR_RETURN(
1710         auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1711 
1712     std::vector<int64> lhs_contracting_dims;
1713     lhs_contracting_dims.reserve(lhs.base_shape().rank());
1714     for (const auto& cd : dims_mapping.contracting_dims) {
1715       lhs_contracting_dims.push_back(cd.lhs);
1716     }
1717     return lhs.state().partitioner->AllReduceAlongShardingDims(
1718         b, dot, lhs.sharding(), lhs.state().next_channel_id,
1719         lhs_contracting_dims, lhs.state().collective_ops_creator,
1720         MakeBinaryAdd(output_base_shape.element_type(), module));
1721   }
1722   return nullptr;
1723 }
1724 
1725 StatusOr<HloInstruction*> PartitionDot(
1726     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
1727     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
1728     int64_t num_partitions,
1729     const std::function<StatusOr<HloInstruction*>(
1730         HloInstruction*, HloInstruction*, SpmdBuilder*,
1731         const Window& conv_window)>& create_sharded_dot,
1732     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
1733     const SpmdPartitionerOptions& options, SpmdBuilder* b,
1734     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
1735         windowed_dot_general_loops);
1736 
1737 StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
1738     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
1739     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
1740     int64_t num_partitions, int64_t lhs_contracting_partitions,
1741     int64_t rhs_contracting_partitions, int64_t lhs_non_contracting_partitions,
1742     int64_t rhs_non_contracting_partitions,
1743     const std::function<StatusOr<HloInstruction*>(
1744         HloInstruction*, HloInstruction*, SpmdBuilder*,
1745         const Window& conv_window)>& create_sharded_dot,
1746     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
1747     bool require_matching_devices_to_group,
1748     const SpmdPartitionerOptions& options, SpmdBuilder* b,
1749     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
1750         windowed_dot_general_loops) {
1751   std::vector<std::pair<HloInstruction*, HloSharding>>
1752       top_level_sharding_to_reset;
1753   auto cleaner = tensorflow::gtl::MakeCleanup([&] {
1754     for (auto& to_reset : top_level_sharding_to_reset) {
1755       to_reset.first->set_sharding(to_reset.second);
1756     }
1757   });
1758   std::vector<int64> lhs_dims;
1759   std::vector<int64> rhs_dims;
1760   std::vector<int64> output_dims;
1761   auto lhs_sharding_dims_adjusted_to_output =
1762       lhs.sharding().IsReplicated()
1763           ? std::vector<int64>(lhs.base_shape().rank(), 1)
1764           : lhs.sharding().tile_assignment().dimensions();
1765   auto rhs_sharding_dims_adjusted_to_output =
1766       rhs.sharding().IsReplicated()
1767           ? std::vector<int64>(rhs.base_shape().rank(), 1)
1768           : rhs.sharding().tile_assignment().dimensions();
1769   auto output_sharding_dims_adjusted_to_lhs =
1770       output_sharding.tile_assignment().dimensions();
1771   bool lhs_rhs_dims_matching = true;
1772   for (const auto& dim : dims_mapping.batch_dims) {
1773     lhs_dims.push_back(dim.lhs);
1774     rhs_dims.push_back(dim.rhs);
1775     output_dims.push_back(dim.output);
1776     if (lhs_sharding_dims_adjusted_to_output[dim.lhs] !=
1777         rhs_sharding_dims_adjusted_to_output[dim.rhs]) {
1778       lhs_rhs_dims_matching = false;
1779     }
1780     lhs_sharding_dims_adjusted_to_output[dim.lhs] =
1781         output_sharding.tile_assignment().dim(dim.output);
1782     rhs_sharding_dims_adjusted_to_output[dim.rhs] =
1783         output_sharding.tile_assignment().dim(dim.output);
1784     output_sharding_dims_adjusted_to_lhs[dim.output] =
1785         lhs.sharding().tile_assignment().dim(dim.lhs);
1786   }
1787   if (require_matching_devices_to_group && lhs_rhs_dims_matching) {
1788     lhs_rhs_dims_matching =
1789         rhs.sharding() == UngroupSharding(AlignGroupsWith(
1790                               GroupShardingOnDims(rhs.sharding(), rhs_dims),
1791                               GroupShardingOnDims(lhs.sharding(), lhs_dims)));
1792   }
1793   auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
1794   PartitionedHlo per_group_lhs = lhs;
1795   PartitionedHlo per_group_rhs = rhs;
1796   if (lhs_rhs_dims_matching) {
1797     auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims);
1798     auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims);
1799     if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
1800         ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
1801       rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped);
1802       rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
1803     } else {
1804       lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped);
1805       lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
1806     }
1807     auto reshaped_output_tiling = output_sharding.tile_assignment();
1808     reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs);
1809     output_grouped = AlignGroupsWith(
1810         GroupShardingOnDims(
1811             output_sharding.ReplicateOnLastTileDim()
1812                 ? HloSharding::PartialTile(reshaped_output_tiling)
1813                 : HloSharding::Tile(reshaped_output_tiling),
1814             output_dims),
1815         lhs_grouped);
1816     auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1817         lhs.state(), lhs_grouped.device_groups, b);
1818     top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs.sharding());
1819     lhs.hlo()->set_sharding(lhs_grouped.sharding);
1820     top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs.sharding());
1821     rhs.hlo()->set_sharding(rhs_grouped.sharding);
1822     CHECK(lhs.hlo() != rhs.hlo() ||
1823           lhs_grouped.sharding == rhs_grouped.sharding);
1824     per_group_lhs = PartitionedHlo(
1825         lhs.hlo(), GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
1826         per_group_partitioner_state);
1827     per_group_rhs = PartitionedHlo(
1828         rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
1829         per_group_partitioner_state);
1830   } else {
1831     auto per_group_partitioner_state = CreatePerGroupPartitioningState(
1832         lhs.state(), output_grouped.device_groups, b);
1833     auto reshard_to_output_batch =
1834         [&](PartitionedHlo operand, absl::Span<const int64> batch_dims,
1835             absl::Span<const int64> contracting_dims,
1836             absl::Span<const int64> non_contracting_dims,
1837             int64_t contracting_dim_partitions,
1838             int64_t non_contracting_dim_partitions,
1839             int64_t other_contracting_dim_partitions,
1840             std::vector<int64>* sharding_dims_adjusted_to_output)
1841         -> absl::optional<PartitionedHlo> {
1842       if (operand.sharding().IsTileMaximal()) {
1843         auto partially_sharded = PerGroupSliceFromReplicated(
1844             operand.Replicate().hlo(), operand.state().partition_id,
1845             output_grouped.device_groups, batch_dims,
1846             output_grouped.group_dim_sizes, b);
1847         partially_sharded->set_sharding(HloSharding::Replicate());
1848         return PartitionedHlo(partially_sharded, partially_sharded->shape(),
1849                               per_group_partitioner_state);
1850       }
1851       auto reshaped_tiling = operand.sharding().tile_assignment();
1852       // It's possible that the operand, while tiled, is not initially sharded
1853       // on its batch dimensions in the same way as the output. In that case,
1854       // the current sharding_dims_adjusted_to_output may contain more
1855       // partitions than there are available devices, so we remove the
1856       // partitioning on other dimensions.
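      // For illustration, assume 4 devices and a rank-2 operand tiled [1,4]
      // on a non-contracting dimension while the output splits the batch
      // dimension 4 ways: the adjusted dims become [4,4] (16 > 4 devices),
      // ratio = 4 equals non_contracting_dim_partitions, so the
      // non-contracting entry is reset to 1, leaving [4,1].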
1857       if (Product(*sharding_dims_adjusted_to_output) >
1858           reshaped_tiling.num_elements()) {
1859         if (Product(*sharding_dims_adjusted_to_output) %
1860                 reshaped_tiling.num_elements() !=
1861             0) {
1862           return absl::nullopt;
1863         }
1864         int64_t ratio = Product(*sharding_dims_adjusted_to_output) /
1865                         reshaped_tiling.num_elements();
1866         if (operand.sharding().ReplicateOnLastTileDim() &&
1867             reshaped_tiling.dimensions().back() % ratio == 0) {
1868           sharding_dims_adjusted_to_output->back() /= ratio;
1869           if (sharding_dims_adjusted_to_output->back() == 1) {
1870             sharding_dims_adjusted_to_output->pop_back();
1871           }
1872         } else if (ratio == non_contracting_dim_partitions &&
1873                    (ratio != contracting_dim_partitions ||
1874                     contracting_dim_partitions ==
1875                         other_contracting_dim_partitions)) {
1876           for (int64_t dim : non_contracting_dims) {
1877             (*sharding_dims_adjusted_to_output)[dim] = 1;
1878           }
1879         } else if (ratio == contracting_dim_partitions) {
1880           for (int64_t dim : contracting_dims) {
1881             (*sharding_dims_adjusted_to_output)[dim] = 1;
1882           }
1883         } else {
1884           return absl::nullopt;
1885         }
1886       }
1887       // If the operand is initially sharded more ways than the output in the
1888       // batch dimensions, sharding_dims_adjusted_to_output currently contains
1889       // fewer partitions than available devices. We do not handle this case.
1890       if (Product(*sharding_dims_adjusted_to_output) <
1891           reshaped_tiling.num_elements()) {
1892         return absl::nullopt;
1893       }
1894       reshaped_tiling.Reshape(*sharding_dims_adjusted_to_output);
1895       auto grouped = AlignGroupsWith(
1896           GroupShardingOnDims(operand.base_shape().rank() <
1897                                       sharding_dims_adjusted_to_output->size()
1898                                   ? HloSharding::PartialTile(reshaped_tiling)
1899                                   : HloSharding::Tile(reshaped_tiling),
1900                               batch_dims),
1901           output_grouped);
1902       if (require_matching_devices_to_group &&
1903           operand.sharding() != UngroupSharding(grouped)) {
1904         return absl::nullopt;
1905       }
1906       auto resharded = operand.Reshard(UngroupSharding(grouped));
1907       top_level_sharding_to_reset.emplace_back(resharded.hlo(),
1908                                                resharded.sharding());
1909       resharded.hlo()->set_sharding(grouped.sharding);
1910       return PartitionedHlo(resharded.hlo(),
1911                             GetPerGroupBaseShape(grouped, operand.base_shape()),
1912                             per_group_partitioner_state);
1913     };
1914     std::vector<int64> lhs_contracting_dims;
1915     std::vector<int64> rhs_contracting_dims;
1916     lhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
1917     rhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
1918     for (const auto& dim : dims_mapping.contracting_dims) {
1919       lhs_contracting_dims.push_back(dim.lhs);
1920       rhs_contracting_dims.push_back(dim.rhs);
1921     }
1922     std::vector<int64> lhs_non_contracting_dims;
1923     std::vector<int64> rhs_non_contracting_dims;
1924     lhs_non_contracting_dims.reserve(
1925         dims_mapping.lhs_non_contracting_dims.size());
1926     rhs_non_contracting_dims.reserve(
1927         dims_mapping.rhs_non_contracting_dims.size());
1928     for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
1929       lhs_non_contracting_dims.push_back(dim.lhs);
1930     }
1931     for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
1932       rhs_non_contracting_dims.push_back(dim.rhs);
1933     }
1934     if (auto resharded = reshard_to_output_batch(
1935             lhs, lhs_dims, lhs_contracting_dims, lhs_non_contracting_dims,
1936             lhs_contracting_partitions, lhs_non_contracting_partitions,
1937             rhs_contracting_partitions,
1938             &lhs_sharding_dims_adjusted_to_output)) {
1939       per_group_lhs = *resharded;
1940     } else {
1941       return nullptr;
1942     }
1943     if (auto resharded = reshard_to_output_batch(
1944             rhs, rhs_dims, rhs_contracting_dims, rhs_non_contracting_dims,
1945             rhs_contracting_partitions, rhs_non_contracting_partitions,
1946             lhs_contracting_partitions,
1947             &rhs_sharding_dims_adjusted_to_output)) {
1948       per_group_rhs = *resharded;
1949     } else {
1950       return nullptr;
1951     }
1952     CHECK(lhs.hlo() != rhs.hlo() ||
1953           per_group_lhs.sharding() == per_group_rhs.sharding());
1954   }
1955   TF_ASSIGN_OR_RETURN(
1956       auto dot,
1957       PartitionDot(per_group_lhs, per_group_rhs,
1958                    GetPerGroupBaseShape(output_grouped, output_base_shape),
1959                    output_grouped.sharding, dims_mapping,
1960                    num_partitions / output_grouped.device_groups.size(),
1961                    create_sharded_dot, conv_window, module, original_hlo,
1962                    options, b, windowed_dot_general_loops));
1963   dot->set_sharding(UngroupSharding(output_grouped));
1964   return PartitionedHlo(dot, output_base_shape, lhs.state())
1965       .Reshard(output_sharding)
1966       .hlo();
1967 }
1968 
1969 GroupedSharding GetNonContractingPartitionGroupedShardingForMatchedOperand(
1970     bool lhs_matching, const HloSharding& matching_sharding,
1971     const HloSharding& output_sharding,
1972     absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_dims) {
1973   std::vector<int64> matching_sharding_dims =
1974       matching_sharding.tile_assignment().dimensions();
1975   std::vector<int64> matching_dims;
1976   std::vector<int64> output_dims;
1977   // Make sure the partitioning on matching's non-contracting dimensions
1978   // defines the same device groups for both matching and output.
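  // For example, if the output splits dim.output 4 ways, the matching
  // operand's tile assignment is reshaped so the corresponding operand
  // dimension also has 4 partitions, and the groups are then aligned to the
  // output's device groups.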
1979   for (const auto& dim : partitioned_dims) {
1980     int64_t md = lhs_matching ? dim.lhs : dim.rhs;
1981     matching_sharding_dims[md] =
1982         output_sharding.tile_assignment().dim(dim.output);
1983     matching_dims.push_back(md);
1984     output_dims.push_back(dim.output);
1985   }
1986   GroupedSharding output_grouped =
1987       GroupShardingOnDims(output_sharding, output_dims);
1988   Array<int64> reshaped_matching_tiling = matching_sharding.tile_assignment();
1989   reshaped_matching_tiling.Reshape(matching_sharding_dims);
1990   return AlignGroupsWith(
1991       GroupShardingOnDims(
1992           matching_sharding.ReplicateOnLastTileDim()
1993               ? HloSharding::PartialTile(reshaped_matching_tiling)
1994               : HloSharding::Tile(reshaped_matching_tiling),
1995           matching_dims),
1996       output_grouped);
1997 }
1998 
1999 absl::optional<GroupedSharding>
2000 GetNonContractingPartitionGroupedShardingForOtherOperand(
2001     bool lhs_matching, const Shape& output_base_shape, const Shape& other_shape,
2002     int64_t other_contracting_partitions,
2003     int64_t other_non_contracting_partitions,
2004     int64_t matching_contracting_partitions,
2005     int64_t output_other_non_contracting_partitions,
2006     const HloSharding& other_sharding, const HloSharding& output_sharding,
2007     absl::Span<const DotConvDimsMapping::DimsMapping> matching_partitioned_dims,
2008     absl::Span<const DotConvDimsMapping::DimsMapping>
2009         other_non_contracting_dims,
2010     absl::Span<const DotConvDimsMapping::DimsMapping> other_contracting_dims) {
2011   int64_t group_count = 1;
2012   std::vector<int64> output_dims;
2013   for (const auto& dim : matching_partitioned_dims) {
2014     output_dims.push_back(dim.output);
2015     group_count *= output_sharding.tile_assignment().dim(dim.output);
2016   }
2017   GroupedSharding output_grouped =
2018       GroupShardingOnDims(output_sharding, output_dims);
2019   std::vector<int64> other_group_dims;
2020   if (other_sharding.ReplicateOnLastTileDim() &&
2021       other_sharding.tile_assignment().dimensions().back() % group_count == 0) {
2022     other_group_dims.push_back(
2023         other_sharding.tile_assignment().num_dimensions() - 1);
2024   } else {
2025     const bool may_replicate_other_contracting_dims =
2026         (other_contracting_partitions == group_count &&
2027          other_non_contracting_partitions ==
2028              output_other_non_contracting_partitions);
2029     const bool may_replicate_other_non_contracting_dims =
2030         group_count == other_non_contracting_partitions &&
2031         matching_contracting_partitions == other_contracting_partitions;
2032     if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
2033             other_sharding, output_grouped.device_groups)) {
2034       other_group_dims = std::move(*found_dims);
2035     } else if (may_replicate_other_contracting_dims &&
2036                (!may_replicate_other_non_contracting_dims ||
2037                 ShapeUtil::ByteSizeOf(other_shape) <=
2038                     ShapeUtil::ByteSizeOf(MakePartitionedShape(
2039                         output_base_shape, output_sharding)))) {
2040       for (const auto& dim : other_contracting_dims) {
2041         other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
2042       }
2043     } else if (may_replicate_other_non_contracting_dims) {
2044       for (const auto& dim : other_non_contracting_dims) {
2045         other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
2046       }
2047     } else {
2048       return absl::nullopt;
2049     }
2050   }
2051   if (other_group_dims.size() == 1 &&
2052       other_group_dims[0] ==
2053           other_sharding.tile_assignment().num_dimensions() - 1) {
2054     return AlignGroupsWith(
2055         GroupShardingOnDims(
2056             other_sharding, {other_group_dims[0]},
2057             {other_sharding.tile_assignment().dimensions().back() /
2058              group_count}),
2059         output_grouped, /*ignore_group_order=*/true);
2060 
2061   } else if (!other_sharding.IsReplicated()) {
2062     return AlignGroupsWith(
2063         GroupShardingOnDims(other_sharding, other_group_dims), output_grouped,
2064         /*ignore_group_order=*/true);
2065   }
2066   return absl::nullopt;
2067 }
2068 
2069 StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
2070     bool lhs_matching, PartitionedHlo matching, PartitionedHlo other,
2071     int64_t matching_contracting_partitions,
2072     int64_t other_contracting_partitions,
2073     absl::Span<const DotConvDimsMapping::DimsMapping>
2074         partitioned_non_contracting_dims,
2075     int64_t other_non_contracting_partitions,
2076     int64_t output_other_non_contracting_partitions,
2077     const Shape& output_base_shape, const HloSharding& output_sharding,
2078     const DotConvDimsMapping& dims_mapping, int64_t num_partitions,
2079     const std::function<StatusOr<HloInstruction*>(
2080         HloInstruction*, HloInstruction*, SpmdBuilder*,
2081         const Window& conv_window)>& create_sharded_dot,
2082     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2083     bool require_matching_devices_to_group,
2084     const SpmdPartitionerOptions& options, SpmdBuilder* b,
2085     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2086         windowed_dot_general_loops) {
2087   std::vector<std::pair<HloInstruction*, HloSharding>>
2088       top_level_sharding_to_reset;
2089   auto cleaner = tensorflow::gtl::MakeCleanup([&] {
2090     for (auto& to_reset : top_level_sharding_to_reset) {
2091       to_reset.first->set_sharding(to_reset.second);
2092     }
2093   });
2094 
2095   std::vector<int64> output_dims;
2096   for (const auto& dim : partitioned_non_contracting_dims) {
2097     output_dims.push_back(dim.output);
2098   }
2099   GroupedSharding output_grouped =
2100       GroupShardingOnDims(output_sharding, output_dims);
2101   GroupedSharding matching_grouped =
2102       GetNonContractingPartitionGroupedShardingForMatchedOperand(
2103           lhs_matching, matching.sharding(), output_sharding,
2104           partitioned_non_contracting_dims);
2105   if (require_matching_devices_to_group &&
2106       matching.sharding() != UngroupSharding(matching_grouped)) {
2107     return nullptr;
2108   }
2109   absl::optional<GroupedSharding> other_grouped =
2110       GetNonContractingPartitionGroupedShardingForOtherOperand(
2111           lhs_matching, output_base_shape, other.hlo()->shape(),
2112           other_contracting_partitions, other_non_contracting_partitions,
2113           matching_contracting_partitions,
2114           output_other_non_contracting_partitions, other.sharding(),
2115           output_sharding, partitioned_non_contracting_dims,
2116           lhs_matching ? dims_mapping.rhs_non_contracting_dims
2117                        : dims_mapping.lhs_non_contracting_dims,
2118           dims_mapping.contracting_dims);
2119 
2120   if (!other_grouped) {
2121     other = other.Replicate();
2122   }
2123   matching = matching.Reshard(UngroupSharding(matching_grouped));
2124   auto per_group_partitioner_state = CreatePerGroupPartitioningState(
2125       matching.state(), matching_grouped.device_groups, b);
2126   top_level_sharding_to_reset.emplace_back(matching.hlo(), matching.sharding());
2127   matching.hlo()->set_sharding(matching_grouped.sharding);
2128   auto matching_p = PartitionedHlo(
2129       matching.hlo(),
2130       GetPerGroupBaseShape(matching_grouped, matching.base_shape()),
2131       per_group_partitioner_state);
2132 
2133   auto partially_replicated_other = other.hlo();
2134   if (other_grouped && other_grouped->group_dims.size() == 1 &&
2135       other_grouped->group_dims[0] == other.base_shape().rank()) {
2136     // Group on replication dim.
2137     other = other.Reshard(UngroupSharding(*other_grouped));
2138     partially_replicated_other = other.hlo();
2139     top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding());
2140     partially_replicated_other->set_sharding(other_grouped->sharding);
2141   } else if (!other.sharding().IsReplicated()) {
2142     other = other.Reshard(UngroupSharding(*other_grouped));
2143     partially_replicated_other =
2144         other
2145             .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2146                 other.sharding(), other_grouped->group_dims))
2147             .hlo();
2148     top_level_sharding_to_reset.emplace_back(
2149         partially_replicated_other, partially_replicated_other->sharding());
2150     partially_replicated_other->set_sharding(other_grouped->sharding);
2151   }
2152   auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(),
2153                                 per_group_partitioner_state);
2154   TF_ASSIGN_OR_RETURN(
2155       auto dot,
2156       PartitionDot(lhs_matching ? matching_p : other_p,
2157                    lhs_matching ? other_p : matching_p,
2158                    GetPerGroupBaseShape(output_grouped, output_base_shape),
2159                    output_grouped.sharding, dims_mapping,
2160                    num_partitions / matching_grouped.device_groups.size(),
2161                    create_sharded_dot, conv_window, module, original_hlo,
2162                    options, b, windowed_dot_general_loops));
2163   return dot;
2164 }
2165 
2166 std::pair<HloSharding, HloSharding>
2167 GetDotGroupPartitionContractingOutputShardings(
2168     const DotConvDimsMapping& dims_mapping, const GroupedSharding& lhs_grouped,
2169     const Shape& output_base_shape, const HloSharding& output_sharding,
2170     int64_t group_count, int64_t output_lhs_non_contracting_partitions,
2171     int64_t output_rhs_non_contracting_partitions,
2172     int64_t output_batch_partitions,
2173     std::vector<int64>* output_slice_dims_out) {
2174   HloSharding inner_output_sharding = HloSharding::Replicate();
2175   HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
2176   std::vector<int64> output_slice_dims;
2177   if (output_sharding.ReplicateOnLastTileDim() &&
2178       output_sharding.tile_assignment().dimensions().back() % group_count ==
2179           0) {
2180     auto grouped = AlignGroupsWith(
2181         GroupShardingOnDims(
2182             output_sharding,
2183             {output_sharding.tile_assignment().num_dimensions() - 1},
2184             {output_sharding.tile_assignment().dimensions().back() /
2185              group_count}),
2186         lhs_grouped,
2187         /*ignore_group_order=*/true);
2188     outer_output_tmp_sharding = UngroupSharding(grouped);
2189     inner_output_sharding = std::move(grouped.sharding);
2190   } else {
2191     if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
2192             output_sharding, lhs_grouped.device_groups)) {
2193       output_slice_dims = std::move(*found_dims);
2194     } else if (output_lhs_non_contracting_partitions == group_count ||
2195                output_rhs_non_contracting_partitions == group_count ||
2196                output_batch_partitions == group_count) {
2197       if (output_lhs_non_contracting_partitions == group_count) {
2198         for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
2199           output_slice_dims.push_back(dim.output);
2200         }
2201       } else if (output_rhs_non_contracting_partitions == group_count) {
2202         for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
2203           output_slice_dims.push_back(dim.output);
2204         }
2205       } else {
2206         for (const auto& dim : dims_mapping.batch_dims) {
2207           output_slice_dims.push_back(dim.output);
2208         }
2209       }
2210     }
2211     if (!output_slice_dims.empty()) {
2212       auto grouped = AlignGroupsWith(
2213           GroupShardingOnDims(output_sharding, output_slice_dims), lhs_grouped);
2214       inner_output_sharding = grouped.sharding;
2215       outer_output_tmp_sharding = UngroupSharding(grouped);
2216     }
2217   }
2218   if (output_slice_dims_out) {
2219     (*output_slice_dims_out) = std::move(output_slice_dims);
2220   }
2221   return std::make_pair(inner_output_sharding, outer_output_tmp_sharding);
2222 }
2223 
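// Returns the shardings to use for lhs and rhs when grouping on contracting
// dimensions. The tiling of the partitioned contracting dimensions of the
// smaller operand is overwritten to match the larger operand's tiling, so
// that only the smaller operand needs to be resharded.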
std::pair<HloSharding, HloSharding>
GetDotGroupPartitionContractingLhsRhsShardings(
    const PartitionedHlo& lhs, const PartitionedHlo& rhs,
    absl::Span<const DotConvDimsMapping::DimsMapping>
        partitioned_contracting_dims) {
  HloSharding lhs_sharding = lhs.sharding();
  HloSharding rhs_sharding = rhs.sharding();
  std::vector<int64> lhs_tile_shape =
      lhs_sharding.tile_assignment().dimensions();
  std::vector<int64> rhs_tile_shape =
      rhs_sharding.tile_assignment().dimensions();
  if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
      ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
    for (const auto& dim : partitioned_contracting_dims) {
      rhs_tile_shape[dim.rhs] = lhs_tile_shape[dim.lhs];
    }
    auto new_tile = rhs.sharding().tile_assignment();
    new_tile.Reshape(rhs_tile_shape);
    rhs_sharding = rhs_sharding.ReplicateOnLastTileDim()
                       ? HloSharding::PartialTile(new_tile)
                       : HloSharding::Tile(new_tile);
  } else {
    for (const auto& dim : partitioned_contracting_dims) {
      lhs_tile_shape[dim.lhs] = rhs_tile_shape[dim.rhs];
    }
    auto new_tile = lhs.sharding().tile_assignment();
    new_tile.Reshape(lhs_tile_shape);
    lhs_sharding = lhs_sharding.ReplicateOnLastTileDim()
                       ? HloSharding::PartialTile(new_tile)
                       : HloSharding::Tile(new_tile);
  }
  return std::make_pair(lhs_sharding, rhs_sharding);
}

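// Partitions a dot by grouping devices along (a subset of) the contracting
// dimensions: each group computes a partial dot on its shard, and the partial
// results are combined with an all-reduce over the contracting-dimension
// groups. Hypothetical example: with num_partitions = 8 and a contracting
// dimension tiled 4 ways on both operands, group_count = 4, the inner dot is
// partitioned across 8 / 4 = 2 devices per group, and the 4 partial products
// are summed by the all-reduce.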
StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
    PartitionedHlo lhs, PartitionedHlo rhs,
    absl::Span<const DotConvDimsMapping::DimsMapping>
        partitioned_contracting_dims,
    int64_t output_batch_partitions,
    int64_t output_lhs_non_contracting_partitions,
    int64_t output_rhs_non_contracting_partitions,
    const Shape& output_base_shape, const HloSharding& output_sharding,
    const DotConvDimsMapping& dims_mapping, int64_t num_partitions,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_dot,
    const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
    bool require_matching_devices_to_group,
    const SpmdPartitionerOptions& options, SpmdBuilder* b,
    std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
        windowed_dot_general_loops) {
  std::vector<std::pair<HloInstruction*, HloSharding>>
      top_level_sharding_to_reset;
  auto cleaner = tensorflow::gtl::MakeCleanup([&] {
    for (auto& to_reset : top_level_sharding_to_reset) {
      to_reset.first->set_sharding(to_reset.second);
    }
  });
  std::vector<int64> lhs_dims;
  std::vector<int64> rhs_dims;
  int64_t group_count = 1;
  for (const auto& dim : partitioned_contracting_dims) {
    lhs_dims.push_back(dim.lhs);
    rhs_dims.push_back(dim.rhs);
    group_count *= lhs.sharding().tile_assignment().dim(dim.lhs);
  }
  HloSharding lhs_sharding = HloSharding::Replicate();
  HloSharding rhs_sharding = HloSharding::Replicate();
  std::tie(lhs_sharding, rhs_sharding) =
      GetDotGroupPartitionContractingLhsRhsShardings(
          lhs, rhs, partitioned_contracting_dims);
  auto lhs_grouped = GroupShardingOnDims(lhs_sharding, lhs_dims);
  auto rhs_grouped = GroupShardingOnDims(rhs_sharding, rhs_dims);
  if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
      ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
    rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped);
    rhs_sharding = UngroupSharding(rhs_grouped);
    if (require_matching_devices_to_group && rhs.sharding() != rhs_sharding) {
      return nullptr;
    }
    rhs = rhs.Reshard(rhs_sharding);
  } else {
    lhs_grouped = AlignGroupsWith(lhs_grouped, rhs_grouped);
    lhs_sharding = UngroupSharding(lhs_grouped);
    if (require_matching_devices_to_group && lhs.sharding() != lhs_sharding) {
      return nullptr;
    }
    lhs = lhs.Reshard(lhs_sharding);
  }
  // Mask out invalid data: zero-pad the partitioned contracting dimensions so
  // that padding elements contribute nothing to the accumulated products.
  std::vector<int64> lhs_skipped_dims;
  for (int64_t i = 0; i < lhs.base_shape().rank(); ++i) {
    if (absl::c_linear_search(lhs_dims, i)) {
      continue;
    }
    lhs_skipped_dims.push_back(i);
  }
  lhs = lhs.PadWithValue(
      CreateZero(ShapeUtil::MakeShape(lhs.base_shape().element_type(), {}), b),
      /*left_padded_dims=*/{}, lhs_skipped_dims);
  std::vector<int64> rhs_skipped_dims;
  for (int64_t i = 0; i < rhs.base_shape().rank(); ++i) {
    if (absl::c_linear_search(rhs_dims, i)) {
      continue;
    }
    rhs_skipped_dims.push_back(i);
  }
  rhs = rhs.PadWithValue(
      CreateZero(ShapeUtil::MakeShape(rhs.base_shape().element_type(), {}), b),
      /*left_padded_dims=*/{}, rhs_skipped_dims);
  top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding);
  lhs.hlo()->set_sharding(lhs_grouped.sharding);
  top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding);
  rhs.hlo()->set_sharding(rhs_grouped.sharding);

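  // Derive the per-group (inner) output sharding and the temporary (outer)
  // sharding for the combined result, plus any output dimensions that can be
  // sliced to match the device groups.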
  HloSharding inner_output_sharding = HloSharding::Replicate();
  HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
  std::vector<int64> output_slice_dims;
  std::tie(inner_output_sharding, outer_output_tmp_sharding) =
      GetDotGroupPartitionContractingOutputShardings(
          dims_mapping, lhs_grouped, output_base_shape, output_sharding,
          group_count, output_lhs_non_contracting_partitions,
          output_rhs_non_contracting_partitions, output_batch_partitions,
          &output_slice_dims);
  Shape inner_output_base_shape = output_base_shape;
  auto get_non_slice_dims = [&] {
    std::vector<int64> non_group_dims;
    for (int64_t i = 0; i < output_base_shape.rank(); ++i) {
      if (!absl::c_linear_search(output_slice_dims, i)) {
        non_group_dims.push_back(i);
      }
    }
    return non_group_dims;
  };
  if (!output_slice_dims.empty()) {
    inner_output_base_shape = MakePartitionedShape(
        output_base_shape,
        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
            output_sharding, get_non_slice_dims()));
  }
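  // inner_creator wraps create_sharded_dot: it builds the per-group dot and
  // all-reduces the partial results along the contracting-dimension groups.
  // When output_slice_dims is non-empty, it also reshards the all-reduced
  // value so that each group keeps only its own slice of the output.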
  std::function<StatusOr<HloInstruction*>(HloInstruction*, HloInstruction*,
                                          SpmdBuilder*, const Window&)>
      inner_creator =
          [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
              const Window& conv_window) -> StatusOr<HloInstruction*> {
    TF_ASSIGN_OR_RETURN(auto inner_dot,
                        create_sharded_dot(l, r, b, conv_window));
    auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
        b, inner_dot, lhs_sharding, lhs.state().next_channel_id, lhs_dims,
        lhs.state().collective_ops_creator,
        MakeBinaryAdd(output_base_shape.element_type(), module));
    if (output_slice_dims.empty()) {
      return ar;
    }
    // Use resharding to slice the output. Use a temporary reshard cache since
    // we are faking a replicated sharding.
    PartitionedHlo::PartitioningState new_state = lhs.state();
    new_state.b = b;
    new_state.partition_id =
        lhs.state().collective_ops_creator.create_partition_id(b);
    PartitionedHlo::ReshardCache tmp_cache;
    new_state.reshard_cache = &tmp_cache;
    ar->set_sharding(HloSharding::Replicate());
    return PartitionedHlo(ar, ar->shape(), new_state)
        .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
            output_sharding, get_non_slice_dims()))
        .hlo();
  };
  // Disable doing the inner reshard when the "faster windowed einsum" flag is
  // enabled, because the windowed einsum implementation is currently slow with
  // this kind of reshard happening.
  if (options.choose_faster_windowed_einsum_over_mem) {
    inner_output_base_shape = output_base_shape;
    inner_creator = create_sharded_dot;
    outer_output_tmp_sharding =
        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
            outer_output_tmp_sharding, output_slice_dims);
  }
  PartitionedHlo::PartitioningState inner_state =
      CreatePerGroupPartitioningState(lhs.state(), lhs_grouped.device_groups,
                                      b);
  TF_ASSIGN_OR_RETURN(
      auto dot,
      PartitionDot(
          PartitionedHlo(lhs.hlo(),
                         GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
                         inner_state),
          PartitionedHlo(rhs.hlo(),
                         GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
                         inner_state),
          inner_output_base_shape, inner_output_sharding, dims_mapping,
          num_partitions / group_count, inner_creator, conv_window, module,
          original_hlo, options, b, windowed_dot_general_loops));
  if (!dot) {
    return nullptr;
  }

  if (options.choose_faster_windowed_einsum_over_mem) {
    HloInstruction* ar = lhs.state().partitioner->AllReduceAlongShardingDims(
        b, dot, lhs_sharding, lhs.state().next_channel_id, lhs_dims,
        lhs.state().collective_ops_creator,
        MakeBinaryAdd(output_base_shape.element_type(), module));
    dot = ar;
  }

  dot->set_sharding(outer_output_tmp_sharding);
  auto d = PartitionedHlo(dot, output_base_shape, lhs.state())
               .Reshard(output_sharding)
               .hlo();
  return d;
}

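// Rewrites the dims mapping of a convolution with feature_group_count > 1 so
// that it can be partitioned like a batched dot: the (input feature, kernel
// output feature, output feature) dimension triple is treated as an extra
// batch dimension, while the input batch and kernel input feature dimensions
// become non-contracting dimensions.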
DotConvDimsMapping ConvertDimsMappingWithFeatureGroupCount(
    const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
  const auto& dnums = original_hlo->convolution_dimension_numbers();
  DotConvDimsMapping new_dims_mapping;
  new_dims_mapping.batch_dims = dims_mapping.batch_dims;
  new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
  // Append batch dims.
  new_dims_mapping.batch_dims.emplace_back();
  new_dims_mapping.batch_dims.back().lhs = dnums.input_feature_dimension();
  new_dims_mapping.batch_dims.back().rhs =
      dnums.kernel_output_feature_dimension();
  new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
  new_dims_mapping.batch_dims.back().spatial = -1;
  // Set up non-contracting dims.
  new_dims_mapping.lhs_non_contracting_dims.emplace_back();
  new_dims_mapping.lhs_non_contracting_dims.back().lhs =
      dnums.input_batch_dimension();
  new_dims_mapping.rhs_non_contracting_dims.emplace_back();
  new_dims_mapping.rhs_non_contracting_dims.back().rhs =
      dnums.kernel_input_feature_dimension();
  return new_dims_mapping;
}

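// Rewrites the dims mapping of a convolution with batch_group_count > 1
// analogously: the (input batch, kernel output feature, output feature)
// dimension triple is treated as an extra batch dimension.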
DotConvDimsMapping ConvertDimsMappingWithBatchGroupCount(
    const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
  const auto& dnums = original_hlo->convolution_dimension_numbers();
  DotConvDimsMapping new_dims_mapping;
  new_dims_mapping.batch_dims = dims_mapping.batch_dims;
  new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
  new_dims_mapping.contracting_dims = dims_mapping.contracting_dims;
  // Append batch dims.
  new_dims_mapping.batch_dims.emplace_back();
  new_dims_mapping.batch_dims.back().lhs = dnums.input_batch_dimension();
  new_dims_mapping.batch_dims.back().rhs =
      dnums.kernel_output_feature_dimension();
  new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
  new_dims_mapping.batch_dims.back().spatial = -1;
  return new_dims_mapping;
}

// Estimates the number of iterations of a subsequent windowed einsum
// partitioning if it is partitioned on the non-contracting dimensions.
// The first value returned is the estimated number of iterations if the LHS
// is matched, and the second is the number of iterations if the RHS is
// matched.
std::pair<absl::optional<int64>, absl::optional<int64>>
EstimateWindowedEinsumIterationsForNonContractingPartitioning(
    const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
    const PartitionedHlo& rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
    int64_t num_partitions, int64_t lhs_non_contracting_partitions,
    int64_t rhs_non_contracting_partitions, int64_t lhs_matching_partitions,
    int64_t rhs_matching_partitions, int64_t lhs_contracting_partitions,
    int64_t rhs_contracting_partitions,
    int64_t output_lhs_non_contracting_partitions,
    int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
    int64_t rhs_batch_partitions) {
  const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
      dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
      output_base_shape.rank());
  auto subsequent_einsum_iterations_estimate =
      [&](bool assume_lhs_match) -> absl::optional<int64> {
    const std::vector<DotConvDimsMapping::DimsMapping>&
        matching_non_contracting_dims =
            assume_lhs_match ? dims_mapping.lhs_non_contracting_dims
                             : dims_mapping.rhs_non_contracting_dims;
    const std::vector<DotConvDimsMapping::DimsMapping>&
        other_non_contracting_dims =
            assume_lhs_match ? dims_mapping.rhs_non_contracting_dims
                             : dims_mapping.lhs_non_contracting_dims;
    const std::vector<int64>& output_to_matching_indices =
        assume_lhs_match ? indices_map.output_to_lhs_indices
                         : indices_map.output_to_rhs_indices;
    const std::vector<int64>& output_to_other_indices =
        assume_lhs_match ? indices_map.output_to_rhs_indices
                         : indices_map.output_to_lhs_indices;
    const std::vector<int64>& matching_to_output_indices =
        assume_lhs_match ? indices_map.lhs_to_output_indices
                         : indices_map.rhs_to_output_indices;
    const std::vector<int64>& other_to_output_indices =
        assume_lhs_match ? indices_map.rhs_to_output_indices
                         : indices_map.lhs_to_output_indices;
    const HloSharding& matching_sharding =
        assume_lhs_match ? lhs.sharding() : rhs.sharding();
    const HloSharding& other_sharding =
        assume_lhs_match ? rhs.sharding() : lhs.sharding();
    const PartitionedHlo& matching_partitioned = assume_lhs_match ? lhs : rhs;
    const PartitionedHlo& other_partitioned = assume_lhs_match ? rhs : lhs;
    const int64_t matching_non_contracting_partitions =
        assume_lhs_match ? lhs_non_contracting_partitions
                         : rhs_non_contracting_partitions;
    const int64_t other_non_contracting_partitions =
        assume_lhs_match ? rhs_non_contracting_partitions
                         : lhs_non_contracting_partitions;
    const int64_t matching_contracting_partitions =
        assume_lhs_match ? lhs_contracting_partitions
                         : rhs_contracting_partitions;
    const int64_t other_contracting_partitions =
        assume_lhs_match ? rhs_contracting_partitions
                         : lhs_contracting_partitions;
    const int64_t output_matching_non_contracting_partitions =
        assume_lhs_match ? output_lhs_non_contracting_partitions
                         : output_rhs_non_contracting_partitions;
    const int64_t output_other_non_contracting_partitions =
        assume_lhs_match ? output_rhs_non_contracting_partitions
                         : output_lhs_non_contracting_partitions;
    const int64_t matching_batch_partitions =
        assume_lhs_match ? lhs_batch_partitions : rhs_batch_partitions;
    const int64_t other_batch_partitions =
        assume_lhs_match ? rhs_batch_partitions : lhs_batch_partitions;
    const int64_t matching_matched_non_contracting_partitions =
        assume_lhs_match ? lhs_non_contracting_partitions
                         : rhs_non_contracting_partitions;
    std::vector<int64> output_dims;
    output_dims.reserve(matching_non_contracting_dims.size());
    for (const DotConvDimsMapping::DimsMapping& dim :
         matching_non_contracting_dims) {
      output_dims.push_back(dim.output);
    }
    GroupedSharding output_grouped =
        GroupShardingOnDims(output_sharding, output_dims);
    GroupedSharding matching_grouped =
        GetNonContractingPartitionGroupedShardingForMatchedOperand(
            assume_lhs_match, matching_sharding, output_sharding,
            matching_non_contracting_dims);
    absl::optional<GroupedSharding> other_grouped =
        GetNonContractingPartitionGroupedShardingForOtherOperand(
            assume_lhs_match, output_base_shape,
            other_partitioned.hlo()->shape(), other_contracting_partitions,
            other_non_contracting_partitions, matching_contracting_partitions,
            output_other_non_contracting_partitions, other_sharding,
            output_sharding, matching_non_contracting_dims,
            other_non_contracting_dims, dims_mapping.contracting_dims);
    if (!other_grouped) {
      return absl::nullopt;
    }
    absl::optional<HloSharding> output_sharding_transposed_to_match_matching =
        hlo_sharding_util::TransposeShardingWithCollapsedDims(
            output_grouped.sharding, output_to_matching_indices,
            matching_to_output_indices);
    absl::optional<HloSharding> output_sharding_transposed_to_match_other =
        hlo_sharding_util::TransposeShardingWithCollapsedDims(
            output_grouped.sharding, output_to_other_indices,
            other_to_output_indices);
    const int64_t new_num_partitions =
        num_partitions / matching_non_contracting_partitions;
    absl::optional<WindowedEinsumConfig> e_config =
        GetWindowedEinsumConfiguration(
            new_num_partitions, output_matching_non_contracting_partitions,
            output_other_non_contracting_partitions,
            other_contracting_partitions, other_non_contracting_partitions,
            other_batch_partitions, matching_contracting_partitions,
            matching_non_contracting_partitions /
                matching_matched_non_contracting_partitions,
            matching_batch_partitions,
            ShapeSizeInBytes(other_partitioned.base_shape()),
            ShapeSizeInBytes(matching_partitioned.base_shape()) /
                matching_non_contracting_partitions,
            ShapeSizeInBytes(
                GetPerGroupBaseShape(output_grouped, output_base_shape)),
            options.threshold_for_windowed_einsum_mib,
            output_sharding_transposed_to_match_matching,
            output_sharding_transposed_to_match_other,
            matching_grouped.sharding, other_grouped->sharding);
    return e_config ? new_num_partitions : absl::optional<int64>(absl::nullopt);
  };
  absl::optional<int64> lhs_matching_iterations;
  if (lhs_matching_partitions != 0) {
    lhs_matching_iterations = subsequent_einsum_iterations_estimate(true);
  }
  absl::optional<int64> rhs_matching_iterations;
  if (rhs_matching_partitions != 0) {
    rhs_matching_iterations = subsequent_einsum_iterations_estimate(false);
  }
  return std::make_pair(lhs_matching_iterations, rhs_matching_iterations);
}

// Returns true if we should prioritize partitioning in the contracting
// dimensions first, and only then in the non-contracting dimensions, when we
// estimate that doing so would need fewer windowed-einsum iterations.
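// Roughly, this estimates the windowed-einsum iteration counts when matching
// on the LHS or RHS non-contracting dimensions, estimates the iterations when
// grouping on the contracting dimensions instead, and prefers the contracting
// grouping only when it is expected to take fewer iterations.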
bool PrioritizeContractingDimensionsPartitioning(
    const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
    const PartitionedHlo& rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
    int64_t num_partitions, int64_t lhs_non_contracting_partitions,
    int64_t rhs_non_contracting_partitions, int64_t lhs_contracting_partitions,
    int64_t rhs_contracting_partitions,
    int64_t output_lhs_non_contracting_partitions,
    int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
    int64_t rhs_batch_partitions, int64_t output_batch_partitions,
    bool require_matching_devices_to_group) {
  const bool may_group_on_lhs_non_contracting =
      lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
      lhs_non_contracting_partitions > 1;
  const bool may_group_on_rhs_non_contracting =
      rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
      rhs_non_contracting_partitions > 1;
  if (!options.choose_faster_windowed_einsum_over_mem) {
    return false;
  }
  // Check only for perfect dimension matches for now.
  if (!may_group_on_lhs_non_contracting && !may_group_on_rhs_non_contracting) {
    return false;
  }
  absl::optional<int64> lhs_matching_iterations;
  absl::optional<int64> rhs_matching_iterations;
  const int64_t lhs_matching_non_contracting_partitions =
      may_group_on_lhs_non_contracting ? lhs_non_contracting_partitions : 0;
  const int64_t rhs_matching_non_contracting_partitions =
      may_group_on_rhs_non_contracting ? rhs_non_contracting_partitions : 0;
  std::tie(lhs_matching_iterations, rhs_matching_iterations) =
      EstimateWindowedEinsumIterationsForNonContractingPartitioning(
          dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
          num_partitions, lhs_non_contracting_partitions,
          rhs_non_contracting_partitions,
          lhs_matching_non_contracting_partitions,
          rhs_matching_non_contracting_partitions, lhs_contracting_partitions,
          rhs_contracting_partitions, output_lhs_non_contracting_partitions,
          output_rhs_non_contracting_partitions, lhs_batch_partitions,
          rhs_batch_partitions);
  if (!lhs_matching_iterations && !rhs_matching_iterations) {
    return false;
  }
  // Be conservative and handle only the case where the contracting-dimension
  // partition counts of lhs and rhs match.
  if (!(lhs_contracting_partitions == rhs_contracting_partitions &&
        lhs_contracting_partitions > 1)) {
    return false;
  }
  // Estimate the iterations in the case where we perform the partitioning on
  // the contracting dimensions instead.
  std::vector<int64> lhs_dims;
  std::vector<int64> rhs_dims;
  int64_t group_count = 1;
  for (const auto& dim : dims_mapping.contracting_dims) {
    lhs_dims.push_back(dim.lhs);
    rhs_dims.push_back(dim.rhs);
    group_count *= lhs.sharding().tile_assignment().dim(dim.lhs);
  }
  HloSharding lhs_sharding = HloSharding::Replicate();
  HloSharding rhs_sharding = HloSharding::Replicate();
  std::tie(lhs_sharding, rhs_sharding) =
      GetDotGroupPartitionContractingLhsRhsShardings(
          lhs, rhs, dims_mapping.contracting_dims);
  auto lhs_grouped = GroupShardingOnDims(lhs_sharding, lhs_dims);
  auto rhs_grouped = GroupShardingOnDims(rhs_sharding, rhs_dims);
  rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped);
  rhs_sharding = UngroupSharding(rhs_grouped);

  if (require_matching_devices_to_group && rhs.sharding() != rhs_sharding) {
    return false;
  }
  const int64_t new_num_partitions =
      num_partitions / lhs_contracting_partitions;

  HloSharding inner_output_sharding = HloSharding::Replicate();
  HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
  std::vector<int64> output_slice_dims;
  std::tie(inner_output_sharding, outer_output_tmp_sharding) =
      GetDotGroupPartitionContractingOutputShardings(
          dims_mapping, lhs_grouped, output_base_shape, output_sharding,
          group_count, output_lhs_non_contracting_partitions,
          output_rhs_non_contracting_partitions, output_batch_partitions,
          &output_slice_dims);
  Shape inner_output_base_shape = output_base_shape;
  if (!output_slice_dims.empty()) {
    std::vector<int64> non_group_dims;
    for (int64_t i = 0; i < output_base_shape.rank(); ++i) {
      if (!absl::c_linear_search(output_slice_dims, i)) {
        non_group_dims.push_back(i);
      }
    }
    inner_output_base_shape = MakePartitionedShape(
        output_base_shape,
        hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
            output_sharding, non_group_dims));
  }
  int64_t new_output_lhs_non_contracting_partitions = 1;
  int64_t new_output_rhs_non_contracting_partitions = 1;
  if (!inner_output_sharding.IsTileMaximal()) {
    for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
      new_output_lhs_non_contracting_partitions *=
          inner_output_sharding.tile_assignment().dim(dim.output);
    }
    for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
      if (dim.output != -1) {
        new_output_rhs_non_contracting_partitions *=
            inner_output_sharding.tile_assignment().dim(dim.output);
      }
    }
  }

  const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
      dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
      inner_output_base_shape.rank());
  absl::optional<HloSharding> output_sharding_transposed_to_match_lhs =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          inner_output_sharding, indices_map.output_to_lhs_indices,
          indices_map.lhs_to_output_indices);
  absl::optional<HloSharding> output_sharding_transposed_to_match_rhs =
      hlo_sharding_util::TransposeShardingWithCollapsedDims(
          inner_output_sharding, indices_map.output_to_rhs_indices,
          indices_map.rhs_to_output_indices);
  absl::optional<WindowedEinsumConfig> e_config =
      GetWindowedEinsumConfiguration(
          new_num_partitions, new_output_lhs_non_contracting_partitions,
          new_output_rhs_non_contracting_partitions, 1,
          rhs_non_contracting_partitions, rhs_batch_partitions, 1,
          lhs_non_contracting_partitions, lhs_batch_partitions,
          ShapeSizeInBytes(GetPerGroupBaseShape(rhs_grouped, rhs.base_shape())),
          ShapeSizeInBytes(GetPerGroupBaseShape(lhs_grouped, lhs.base_shape())),
          ShapeSizeInBytes(inner_output_base_shape),
          options.threshold_for_windowed_einsum_mib,
          output_sharding_transposed_to_match_lhs,
          output_sharding_transposed_to_match_rhs, lhs_grouped.sharding,
          rhs_grouped.sharding);
  if (!e_config) {
    return false;
  }
  const int64_t min_nc_iterations =
      std::min(lhs_matching_iterations ? *lhs_matching_iterations : INT64_MAX,
               rhs_matching_iterations ? *rhs_matching_iterations : INT64_MAX);
  return min_nc_iterations > new_num_partitions;
}

// Returns true if it is better to match the LHS operand rather than the RHS
// operand of a dot for non-contracting partitioning.
bool LhsIsBestMatchForNonContractingPartitioning(
    const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
    const PartitionedHlo& rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
    int64_t num_partitions, int64_t lhs_non_contracting_partitions,
    int64_t rhs_non_contracting_partitions, int64_t lhs_matching_partitions,
    int64_t rhs_matching_partitions, int64_t lhs_contracting_partitions,
    int64_t rhs_contracting_partitions,
    int64_t output_lhs_non_contracting_partitions,
    int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
    int64_t rhs_batch_partitions) {
  const bool may_group_on_lhs_non_contracting =
      lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
      lhs_non_contracting_partitions > 1;
  const bool may_group_on_rhs_non_contracting =
      rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
      rhs_non_contracting_partitions > 1;
  // If both match output non-contracting dimensions, choose the one which
  // will result in smaller replication of the other operand.
  bool lhs_matching = may_group_on_lhs_non_contracting &&
                      (!may_group_on_rhs_non_contracting ||
                       lhs_non_contracting_partitions *
                               ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <
                           rhs_non_contracting_partitions *
                               ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
  // If both groupings are available and the option to choose faster windowed
  // einsums over saving memory is enabled, try to determine which operand
  // will generate the fewest windowed-einsum iterations when matched (if a
  // windowed einsum is going to be generated at all).
  if (may_group_on_lhs_non_contracting && may_group_on_rhs_non_contracting &&
      options.choose_faster_windowed_einsum_over_mem) {
    const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
        dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
        output_base_shape.rank());
    absl::optional<int64> lhs_matching_iterations;
    absl::optional<int64> rhs_matching_iterations;
    std::tie(lhs_matching_iterations, rhs_matching_iterations) =
        EstimateWindowedEinsumIterationsForNonContractingPartitioning(
            dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
            num_partitions, lhs_non_contracting_partitions,
            rhs_non_contracting_partitions, lhs_matching_partitions,
            rhs_matching_partitions, lhs_contracting_partitions,
            rhs_contracting_partitions, output_lhs_non_contracting_partitions,
            output_rhs_non_contracting_partitions, lhs_batch_partitions,
            rhs_batch_partitions);
    if (lhs_matching_iterations && rhs_matching_iterations &&
        *lhs_matching_iterations != *rhs_matching_iterations) {
      lhs_matching = *lhs_matching_iterations < *rhs_matching_iterations;
    }
  }
  return lhs_matching;
}

// Recursive partitioning function. If there are partial dimensions matching
// in the operands and output, group the devices and recursively partition
// the in-group dot.
StatusOr<HloInstruction*> PartitionDot(
    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
    const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
    int64_t num_partitions,
    const std::function<StatusOr<HloInstruction*>(
        HloInstruction*, HloInstruction*, SpmdBuilder*,
        const Window& conv_window)>& create_sharded_dot,
    const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
    bool require_matching_devices_to_group,
    const SpmdPartitionerOptions& options, SpmdBuilder* b,
    std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
        windowed_dot_general_loops) {
  // If lhs' hlo and rhs' hlo are identical, make a copy for rhs.
  if (lhs.hlo() == rhs.hlo()) {
    auto copy_hlo = b->AddInstruction(HloInstruction::CreateUnary(
        rhs.hlo()->shape(), HloOpcode::kCopy, rhs.hlo()));
    copy_hlo->set_sharding(rhs.sharding());
    rhs = PartitionedHlo(copy_hlo, rhs.base_shape(), rhs.state());
  }

  // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
  auto get_partitions_for_dims =
      [&](const HloSharding& sharding,
          absl::Span<const DotConvDimsMapping::DimsMapping> dims,
          int lhs_rhs_or_output) {
        int64_t partitions = 1;
        if (sharding.IsTileMaximal()) {
          return partitions;
        }
        for (const auto& dim : dims) {
          if (lhs_rhs_or_output == 0) {
            partitions *= sharding.tile_assignment().dim(dim.lhs);
          } else if (lhs_rhs_or_output == 1) {
            partitions *= sharding.tile_assignment().dim(dim.rhs);
          } else {
            CHECK_EQ(lhs_rhs_or_output, 2);
            partitions *= sharding.tile_assignment().dim(dim.output);
          }
        }
        return partitions;
      };
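  // Count how many ways each family of dimensions is partitioned.
  // Hypothetical example: if lhs has two contracting dimensions tiled 8 and 4
  // ways, lhs_contracting_partitions is 8 * 4 = 32.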
  const int64_t lhs_batch_partitions =
      get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0);
  const int64_t rhs_batch_partitions =
      get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1);
  const int64_t output_batch_partitions =
      get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2);
  const int64_t lhs_contracting_partitions =
      get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0);
  const int64_t rhs_contracting_partitions =
      get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1);
  const int64_t lhs_non_contracting_partitions = get_partitions_for_dims(
      lhs.sharding(), dims_mapping.lhs_non_contracting_dims, 0);
  const int64_t rhs_non_contracting_partitions = get_partitions_for_dims(
      rhs.sharding(), dims_mapping.rhs_non_contracting_dims, 1);
  const int64_t output_lhs_non_contracting_partitions = get_partitions_for_dims(
      output_sharding, dims_mapping.lhs_non_contracting_dims, 2);
  const int64_t output_rhs_non_contracting_partitions = get_partitions_for_dims(
      output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
  const int64_t lhs_conv_spatial_partitions = get_partitions_for_dims(
      lhs.sharding(), dims_mapping.conv_spatial_dims, 0);
  const int64_t rhs_conv_spatial_partitions = get_partitions_for_dims(
      rhs.sharding(), dims_mapping.conv_spatial_dims, 1);
  const int64_t output_conv_spatial_partitions = get_partitions_for_dims(
      output_sharding, dims_mapping.conv_spatial_dims, 2);
  // Before we find partial matches along the dimensions, invoke base case
  // again without may_reshard_without_detecting_match.

  // Try to partition the purely spatially-partitioned convolution with the
  // convolution spatial dimensions partitioned, or with the depthwise
  // parallel dimension partitioned.
  bool is_conv_spatial_dim_partitioned =
      (lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 ||
       output_conv_spatial_partitions > 1);
  bool is_conv_batch_or_contracting_dim_partitioned =
      (lhs_batch_partitions > 1 || rhs_batch_partitions > 1 ||
       output_batch_partitions > 1 ||
       (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1));
  if ((!dims_mapping.conv_spatial_dims.empty() &&
       is_conv_spatial_dim_partitioned &&
       !is_conv_batch_or_contracting_dim_partitioned) ||
      (original_hlo->opcode() == HloOpcode::kConvolution &&
       (original_hlo->batch_group_count() > 1 ||
        original_hlo->feature_group_count() > 1))) {
    // Partitioning with kernel_input_feature_dim > 1 and feature_group_count >
    // 1 is not supported.
    const auto& dnums = original_hlo->convolution_dimension_numbers();
    if (original_hlo->feature_group_count() > 1 &&
        rhs.hlo()->shape().dimensions(dnums.kernel_input_feature_dimension()) >
            1) {
      return nullptr;
    }

    TF_ASSIGN_OR_RETURN(
        auto partitioned_conv,
        PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
                             dims_mapping, create_sharded_dot, conv_window,
                             original_hlo, num_partitions, options,
                             lhs.state().partition_id, module, b));

    if (partitioned_conv) {
      return partitioned_conv;
    }

    // Recursively partition on different types of dimensions for
    // convolution. Case 0.a: Group partitions by feature group count.
    if (original_hlo->feature_group_count() > 1 ||
        original_hlo->batch_group_count() > 1) {
      DotConvDimsMapping new_dims_mapping;
      if (original_hlo->feature_group_count() > 1) {
        new_dims_mapping =
            ConvertDimsMappingWithFeatureGroupCount(dims_mapping, original_hlo);
      }

      if (original_hlo->batch_group_count() > 1) {
        new_dims_mapping =
            ConvertDimsMappingWithBatchGroupCount(dims_mapping, original_hlo);
      }

      const int64_t conv_lhs_contracting_partitions = get_partitions_for_dims(
          lhs.sharding(), new_dims_mapping.contracting_dims, 0);
      const int64_t conv_rhs_contracting_partitions = get_partitions_for_dims(
          rhs.sharding(), new_dims_mapping.contracting_dims, 1);
      const int64_t conv_lhs_non_contracting_partitions =
          get_partitions_for_dims(lhs.sharding(),
                                  new_dims_mapping.lhs_non_contracting_dims, 0);
      const int64_t conv_rhs_non_contracting_partitions =
          get_partitions_for_dims(rhs.sharding(),
                                  new_dims_mapping.rhs_non_contracting_dims, 1);
      const int64_t conv_lhs_batch_partitions = get_partitions_for_dims(
          lhs.sharding(), new_dims_mapping.batch_dims, 0);
      const int64_t conv_rhs_batch_partitions = get_partitions_for_dims(
          rhs.sharding(), new_dims_mapping.batch_dims, 1);
      const int64_t conv_output_batch_partitions = get_partitions_for_dims(
          output_sharding, new_dims_mapping.batch_dims, 2);
      if ((conv_lhs_batch_partitions == conv_output_batch_partitions ||
           conv_rhs_batch_partitions == conv_output_batch_partitions) &&
          conv_output_batch_partitions > 1) {
        TF_ASSIGN_OR_RETURN(
            auto try_partitioned_conv,
            PartitionDotGroupOnBatch(
                lhs, rhs, output_base_shape, output_sharding, new_dims_mapping,
                num_partitions, conv_lhs_contracting_partitions,
                conv_rhs_contracting_partitions,
                conv_lhs_non_contracting_partitions,
                conv_rhs_non_contracting_partitions, create_sharded_dot,
                conv_window, module, original_hlo,
                require_matching_devices_to_group, options, b,
                windowed_dot_general_loops));
        if (try_partitioned_conv) {
          return try_partitioned_conv;
        }
      }
      return nullptr;
    }
  }

  TF_ASSIGN_OR_RETURN(
      auto try_partitioned_dot,
      PartitionBaseCase(
          lhs, rhs, output_base_shape, output_sharding, dims_mapping,
          num_partitions, create_sharded_dot, conv_window, module, original_hlo,
          lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
          lhs_contracting_partitions, rhs_contracting_partitions,
          lhs_non_contracting_partitions, rhs_non_contracting_partitions,
          output_lhs_non_contracting_partitions,
          output_rhs_non_contracting_partitions, options, b,
          windowed_dot_general_loops,
          /*may_reshard_without_detecting_match=*/false));
  if (try_partitioned_dot) {
    return try_partitioned_dot;
  }

  // Recursively partition on different types of dimensions.
  //
  // Case 1: Group partitions by batch.
  if ((lhs_batch_partitions == output_batch_partitions ||
       rhs_batch_partitions == output_batch_partitions) &&
      output_batch_partitions > 1) {
    TF_ASSIGN_OR_RETURN(
        auto dot,
        PartitionDotGroupOnBatch(
            lhs, rhs, output_base_shape, output_sharding, dims_mapping,
            num_partitions, lhs_contracting_partitions,
            rhs_contracting_partitions, lhs_non_contracting_partitions,
            rhs_non_contracting_partitions, create_sharded_dot, conv_window,
            module, original_hlo, require_matching_devices_to_group, options, b,
            windowed_dot_general_loops));
    if (dot) {
      return dot;
    }
  }

  // Case 2: Group partitions by non-contracting dimensions.
  const bool may_group_on_lhs_non_contracting =
      lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
      lhs_non_contracting_partitions > 1;
  const bool may_group_on_rhs_non_contracting =
      rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
      rhs_non_contracting_partitions > 1;
  bool lhs_matching = false;
  std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
  if (may_group_on_lhs_non_contracting || may_group_on_rhs_non_contracting) {
    lhs_matching = LhsIsBestMatchForNonContractingPartitioning(
        dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
        num_partitions, lhs_non_contracting_partitions,
        rhs_non_contracting_partitions, lhs_non_contracting_partitions,
        rhs_non_contracting_partitions, lhs_contracting_partitions,
        rhs_contracting_partitions, output_lhs_non_contracting_partitions,
        output_rhs_non_contracting_partitions, lhs_batch_partitions,
        rhs_batch_partitions);
    matching_dims = lhs_matching ? dims_mapping.lhs_non_contracting_dims
                                 : dims_mapping.rhs_non_contracting_dims;
  } else if (lhs_non_contracting_partitions > 1 &&
             output_lhs_non_contracting_partitions > 1) {
    lhs_matching = true;
    for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
      int64_t lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
      if (lhs_partitions > 1 &&
          lhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
        matching_dims.push_back(dim);
      }
    }
  } else if (rhs_non_contracting_partitions > 1 &&
             output_rhs_non_contracting_partitions > 1) {
    lhs_matching = false;
    for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
      int64_t rhs_partitions = rhs.sharding().tile_assignment().dim(dim.rhs);
      if (rhs_partitions > 1 &&
          rhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
        matching_dims.push_back(dim);
      }
    }
  }
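  // Check whether grouping on the contracting dimensions first is estimated
  // to need fewer windowed-einsum iterations than grouping on the chosen
  // non-contracting dimensions; if so, skip the non-contracting grouping
  // below.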
3053   const bool prioritize_contracting_for_faster_windowed_einsum =
3054       PrioritizeContractingDimensionsPartitioning(
3055           dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
3056           num_partitions, lhs_non_contracting_partitions,
3057           rhs_non_contracting_partitions, lhs_contracting_partitions,
3058           rhs_contracting_partitions, output_lhs_non_contracting_partitions,
3059           output_rhs_non_contracting_partitions, lhs_batch_partitions,
3060           rhs_batch_partitions, output_batch_partitions,
3061           require_matching_devices_to_group);
3062   if (!(matching_dims.empty() ||
3063         prioritize_contracting_for_faster_windowed_einsum)) {
3064     TF_ASSIGN_OR_RETURN(
3065         auto dot,
3066         PartitionDotGroupOnNonContracting(
3067             lhs_matching, lhs_matching ? lhs : rhs, lhs_matching ? rhs : lhs,
3068             lhs_matching ? lhs_contracting_partitions
3069                          : rhs_contracting_partitions,
3070             lhs_matching ? rhs_contracting_partitions
3071                          : lhs_contracting_partitions,
3072             matching_dims,
3073             lhs_matching ? rhs_non_contracting_partitions
3074                          : lhs_non_contracting_partitions,
3075             lhs_matching ? output_rhs_non_contracting_partitions
3076                          : output_lhs_non_contracting_partitions,
3077             output_base_shape, output_sharding, dims_mapping, num_partitions,
3078             create_sharded_dot, conv_window, module, original_hlo,
3079             require_matching_devices_to_group, options, b,
3080             windowed_dot_general_loops));
3081     if (dot) {
3082       return dot;
3083     }
3084   }
3085 
3086   // Case 3: Group partitions by contracting dimensions.
3087   if (lhs_contracting_partitions == rhs_contracting_partitions &&
3088       lhs_contracting_partitions > 1) {
3089     TF_ASSIGN_OR_RETURN(
3090         auto dot,
3091         PartitionDotGroupOnContracting(
3092             lhs, rhs, dims_mapping.contracting_dims, output_batch_partitions,
3093             output_lhs_non_contracting_partitions,
3094             output_rhs_non_contracting_partitions, output_base_shape,
3095             output_sharding, dims_mapping, num_partitions, create_sharded_dot,
3096             conv_window, module, original_hlo,
3097             require_matching_devices_to_group, options, b,
3098             windowed_dot_general_loops));
3099     if (dot) {
3100       return dot;
3101     }
3102   }
3103   if (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1) {
3104     // If part of contracting dims match, try them.
3105     std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
3106     for (const auto& dim : dims_mapping.contracting_dims) {
3107       int64_t lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
3108       if (lhs_partitions > 1 &&
3109           lhs_partitions == rhs.sharding().tile_assignment().dim(dim.rhs)) {
3110         matching_dims.push_back(dim);
3111       }
3112     }
3113     if (!matching_dims.empty()) {
3114       TF_ASSIGN_OR_RETURN(
3115           auto dot, PartitionDotGroupOnContracting(
3116                         lhs, rhs, matching_dims, output_batch_partitions,
3117                         output_lhs_non_contracting_partitions,
3118                         output_rhs_non_contracting_partitions,
3119                         output_base_shape, output_sharding, dims_mapping,
3120                         num_partitions, create_sharded_dot, conv_window, module,
3121                         original_hlo, require_matching_devices_to_group,
3122                         options, b, windowed_dot_general_loops));
3123       if (dot) {
3124         return dot;
3125       }
3126     }
3127   }
3128 
3129   // Case 4: If operands are replicated but output is partially replicated,
3130   // recursive call with partial replication removed.
3131   if (lhs.sharding().IsReplicated() && rhs.sharding().IsReplicated() &&
3132       output_sharding.ReplicateOnLastTileDim()) {
3133     auto grouped_output =
3134         GroupShardingOnDims(output_sharding, {output_base_shape.rank()});
3135     auto inner_state = CreatePerGroupPartitioningState(
3136         lhs.state(), grouped_output.device_groups, b);
3137     TF_ASSIGN_OR_RETURN(
3138         auto dot,
3139         PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state),
3140                      PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
3141                      output_base_shape, grouped_output.sharding, dims_mapping,
3142                      output_sharding.NumTiles(), create_sharded_dot,
3143                      conv_window, module, original_hlo, options, b,
3144                      windowed_dot_general_loops));
3145     if (dot) {
3146       return dot;
3147     }
3148   }
3149 
3150   // We failed to find partial matches, invoke base case again with
3151   // may_reshard_without_detecting_match.
3152   TF_ASSIGN_OR_RETURN(
3153       auto dot,
3154       PartitionBaseCase(
3155           lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3156           num_partitions, create_sharded_dot, conv_window, module, original_hlo,
3157           lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
3158           lhs_contracting_partitions, rhs_contracting_partitions,
3159           lhs_non_contracting_partitions, rhs_non_contracting_partitions,
3160           output_lhs_non_contracting_partitions,
3161           output_rhs_non_contracting_partitions, options, b,
3162           windowed_dot_general_loops,
3163           /*may_reshard_without_detecting_match=*/true));
3164   if (dot) {
3165     return dot;
3166   }
3167   return nullptr;
3168 }
3169 
3170 StatusOr<HloInstruction*> PartitionDot(
3171     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
3172     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
3173     int64_t num_partitions,
3174     const std::function<StatusOr<HloInstruction*>(
3175         HloInstruction*, HloInstruction*, SpmdBuilder*,
3176         const Window& conv_window)>& create_sharded_dot,
3177     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
3178     const SpmdPartitionerOptions& options, SpmdBuilder* b,
3179     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
3180         windowed_dot_general_loops) {
3181   // First try partitioning without resharding the groups, then retry while
3182   // allowing the groups to be resharded.
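  // The first pass only accepts groupings whose device assignments already
  // match across the operands and output; the second pass may reshard tensors
  // to align the groups, trading extra communication for a partitioned dot.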
3183   for (bool require_matching_devices_to_group : {true, false}) {
3184     TF_ASSIGN_OR_RETURN(
3185         auto try_partition,
3186         PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3187                      num_partitions, create_sharded_dot, conv_window, module,
3188                      original_hlo, require_matching_devices_to_group, options,
3189                      b, windowed_dot_general_loops));
3190     if (try_partition) {
3191       return try_partition;
3192     }
3193   }
3194 
3195   // Default action.
3196   TF_ASSIGN_OR_RETURN(
3197       auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(),
3198                                    b, conv_window));
3199   dot->set_sharding(HloSharding::Replicate());
3200   return PartitionedHlo(dot, output_base_shape, lhs.state())
3201       .Reshard(output_sharding)
3202       .hlo();
3203 }
3204 
3205 }  // namespace
3206 
3207 Status SpmdPartitioningVisitor::HandleDotHelper(
3208     HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
3209     const std::function<StatusOr<HloInstruction*>(
3210         HloInstruction*, HloInstruction*, SpmdBuilder*,
3211         const Window& conv_window)>& create_sharded_dot) {
3212   if (hlo->sharding().HasUniqueDevice()) {
3213     return DefaultAction(hlo);
3214   }
3215   auto& lhs = GetPartitionedHlo(hlo->operand(0));
3216   auto& rhs = GetPartitionedHlo(hlo->operand(1));
3217   Window conv_window;
3218   if (hlo->opcode() == HloOpcode::kConvolution) {
3219     conv_window = hlo->window();
3220   }
3221 
3222   TF_ASSIGN_OR_RETURN(
3223       auto partitioned_dot,
3224       PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
3225                    num_partitions_, create_sharded_dot, conv_window, module_,
3226                    hlo, options_, &b_, &windowed_dot_general_loops_));
3227   SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
3228   return Status::OK();
3229 }
3230 
3231 namespace {
3232 
3233 // Finds a cluster of nodes that produce the inputs for `hlo` which only
3234 // depend on small operands, which means the cluster should start with
3235 // broadcasts, constants and iotas. All other internal nodes must be
3236 // non-side-effecting elementwise ops. Returns the set of nodes, and the small
3237 // operands. E.g., for the following graph,
3238 //
3239 //     a -> broadcast -> multiply
3240 //     iota  ---> add--/
3241 //     constant/
3242 //
3243 // FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return
3244 //    <{broadcast, iota, constant, add, multiply}, [a]>.
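// Returns an empty set if any node in the cluster fails these conditions.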
3245 std::pair<absl::flat_hash_set<HloInstruction*>, std::vector<HloInstruction*>>
3246 FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) {
3247   absl::flat_hash_set<HloInstruction*> nodes_found;
3248   std::vector<HloInstruction*> new_operands;
3249   absl::flat_hash_set<const HloInstruction*> new_operands_set;
3250   std::vector<HloInstruction*> worklist;
3251   worklist.push_back(hlo);
3252   while (!worklist.empty()) {
3253     auto inst = worklist.back();
3254     worklist.pop_back();
3255     if (nodes_found.count(inst) > 0) {
3256       continue;
3257     }
3258     if (inst->opcode() == HloOpcode::kBroadcast ||
3259         inst->opcode() == HloOpcode::kConstant ||
3260         inst->opcode() == HloOpcode::kIota) {
3261       nodes_found.insert(inst);
3262       for (auto o : inst->operands()) {
3263         auto res = new_operands_set.emplace(o);
3264         if (res.second) {
3265           new_operands.push_back(o);
3266         }
3267       }
3268     } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() &&
3269                absl::c_all_of(inst->operands(),
3270                               [inst](const HloInstruction* o) {
3271                                 return ShapeUtil::CompatibleIgnoringElementType(
3272                                     o->shape(), inst->shape());
3273                               })) {
3274       nodes_found.insert(inst);
3275       for (auto o : inst->operands()) {
3276         worklist.push_back(o);
3277       }
3278     } else {
3279       nodes_found.clear();
3280       new_operands.clear();
3281       break;
3282     }
3283   }
3284   return {std::move(nodes_found), std::move(new_operands)};
3285 }
3286 
3287 // Moves a cluster of memory-reducing nodes into the windowed dot-general loop
3288 // on contracting dimensions. Such a loop has a dynamic slice on the
3289 // non-windowed operand. If we move the input nodes into the loop, the
3290 // dynamic-slice could be merged with them by later optimization passes, which
3291 // reduces memory.
3292 //
3293 // small_operands             small_operands
3294 //        |                          |
3295 // input_nodes                loop { |
3296 //        |          =>         input_nodes
3297 // loop { |                          |
3298 //    dynamic-slice             dynamic-slice
3299 //    ...                       ...
3300 // }                          }
3301 //
3302 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
3303 // with the input nodes.
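// E.g., dynamic-slice(broadcast(v)) can be folded into a broadcast over the
// slice shape, so the full-sized broadcast is never materialized.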
3304 Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
3305     HloInstruction* loop, int64_t non_windowed_operand_index) {
3306   auto input_tuple = loop->mutable_operand(0);
3307   auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index);
3308   auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand);
3309   auto to_sink = std::move(input_nodes.first);
3310   auto new_operands = std::move(input_nodes.second);
3311   if (to_sink.empty()) {
3312     return Status::OK();
3313   }
3314   auto computation = loop->parent();
3315   // Replace the old operand with a tuple of the found small operands.
3316   auto new_input_subtuple =
3317       computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
3318   TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape(
3319       non_windowed_operand_index, new_input_subtuple));
3320 
3321   auto body = loop->while_body();
3322   auto body_param = body->parameter_instruction(0);
3323   auto old_body_param_users = body_param->users();
3324   // Update all tuple shapes.
3325   for (auto tuple : std::vector<HloInstruction*>{
3326            input_tuple, loop, loop->while_condition()->parameter_instruction(0),
3327            body_param, body->root_instruction()}) {
3328     *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(),
3329                                    {non_windowed_operand_index}) =
3330         new_input_subtuple->shape();
3331   }
3332   // Now update the loop body.
3333   auto new_operand_tuple_inside =
3334       body->AddInstruction(HloInstruction::CreateGetTupleElement(
3335           new_input_subtuple->shape(), body_param, non_windowed_operand_index));
3336   TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape(
3337       non_windowed_operand_index, new_operand_tuple_inside));
3338 
3339   // Create nodes inside the loop body.
3340   std::vector<HloInstruction*> worklist;
3341   absl::flat_hash_map<const HloInstruction*, HloInstruction*> outside_to_inside;
3342   auto add_users_if_available = [&](HloInstruction* inst) {
3343     for (auto u : inst->users()) {
3344       if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 &&
3345           absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
3346             return outside_to_inside.count(o) > 0;
3347           })) {
3348         worklist.push_back(u);
3349       }
3350     }
3351   };
3352   for (int64_t i = 0; i < new_operands.size(); ++i) {
3353     outside_to_inside[new_operands[i]] =
3354         body->AddInstruction(HloInstruction::CreateGetTupleElement(
3355             new_operands[i]->shape(), new_operand_tuple_inside, i));
3356     add_users_if_available(new_operands[i]);
3357   }
3358   // HLOs to sink that have no operands (nullaries).
3359   std::vector<HloInstruction*> nullaries_to_sink;
3360   for (auto inst : to_sink) {
3361     if (inst->operand_count() == 0) {
3362       nullaries_to_sink.push_back(inst);
3363     }
3364   }
3365   // Sort nullaries_to_sink to make the sinking order deterministic.
3366   absl::c_sort(nullaries_to_sink,
3367                [](const HloInstruction* a, const HloInstruction* b) {
3368                  return a->unique_id() < b->unique_id();
3369                });
3370   worklist.reserve(nullaries_to_sink.size());
3371   for (auto inst : nullaries_to_sink) {
3372     worklist.push_back(inst);
3373   }
3374   while (!worklist.empty()) {
3375     auto inst = worklist.back();
3376     worklist.pop_back();
3377     std::vector<HloInstruction*> inst_new_operands(inst->operand_count());
3378     for (int64_t i = 0; i < inst->operand_count(); ++i) {
3379       inst_new_operands[i] = outside_to_inside[inst->operand(i)];
3380     }
3381     outside_to_inside[inst] = body->AddInstruction(
3382         inst->CloneWithNewOperands(inst->shape(), inst_new_operands));
3383     add_users_if_available(inst);
3384   }
3385   TF_RET_CHECK(outside_to_inside.count(old_operand) > 0);
3386   for (auto ou : old_body_param_users) {
3387     if (ou->opcode() == HloOpcode::kGetTupleElement &&
3388         ou->tuple_index() == non_windowed_operand_index) {
3389       TF_RETURN_IF_ERROR(
3390           ou->ReplaceAllUsesWith(outside_to_inside[old_operand]));
3391       TF_RETURN_IF_ERROR(body->RemoveInstruction(ou));
3392     }
3393   }
3394   return Status::OK();
3395 }
3396 
3397 // Moves a cluster of memory-reducing nodes (with reduce nodes at the end)
3398 // into the windowed dot-general loop on non-contracting dimensions. Such a
3399 // loop has a dynamic-update-slice at the output. If we move the user nodes
3400 // into the loop and before the dynamic-update-slice, the user nodes can
3401 // operate on smaller shapes, which reduces memory.
3402 //
3403 // small_operands                   small_operands
3404 //  | |                 =>                  | |
3405 //  | |  loop {                     loop {  | |
3406 //  | |    conv                             | broadcast      conv
3407 //  | |      |                              |     |           /
3408 //  | | dynamic-update-slice                |  dynamic-slice /
3409 //  | |         |                           |     |         /
3410 //  | |  }      |                           |  multiply-----
3411 //  |broadcast  /                           |    /
3412 //  | |        /                            reduce
3413 //  |multiply--                             |
3414 //  \ |                                dynamic-update-slice
3415 //   reduce                         }
3416 //
3417 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
3418 // with the input nodes (broadcast).
3419 Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
3420     HloInstruction* loop) {
3421   CHECK_EQ(loop->user_count(), 1);
3422   // There should be a single direct user of the while loop, which is the
3423   // gte for element 2, i.e., the dot output.
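  // (Elements 0 and 1 of the loop's state tuple carry the lhs and rhs
  // operands; element 2 accumulates the dot output.)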
3424   auto user_gte = loop->users().front();
3425   CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement);
3426   CHECK_EQ(user_gte->tuple_index(), 2);
3427   auto computation = loop->parent();
3428 
3429   // Find the reduce outputs and the input nodes they depend on, if the
3430   // input nodes depend only on small operands.
3431   absl::flat_hash_set<HloInstruction*> to_move;
3432   std::vector<HloInstruction*> new_operands;
3433   absl::flat_hash_set<const HloInstruction*> new_operands_set;
3434   std::vector<HloInstruction*> reduce_outputs;
3435   std::vector<HloInstruction*> worklist;
3436   Shape padded_shape = user_gte->shape();
3437   Shape unpadded_shape = user_gte->shape();
3438   auto original_output = user_gte;
3439 
3440   if (user_gte->user_count() == 1 &&
3441       user_gte->users().back()->opcode() == HloOpcode::kSlice) {
3442     original_output = user_gte->users().back();
3443     unpadded_shape = original_output->shape();
3444   }
3445   for (auto u : original_output->users()) {
3446     worklist.push_back(u);
3447   }
3448   to_move.insert(original_output);
3449   while (!worklist.empty()) {
3450     auto inst = worklist.back();
3451     worklist.pop_back();
3452     if (to_move.count(inst) > 0) {
3453       continue;
3454     }
3455     // We only support reduces with a simple reduction function, since we may
3456     // need to accumulate across iterations manually.
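    // E.g., an add reduction whose computation is
    //   (param0, param1) -> add(param0, param1)
    // qualifies: exactly 3 instructions (two parameters plus an elementwise
    // root).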
3457     if (inst->opcode() == HloOpcode::kReduce &&
3458         inst->to_apply()->instruction_count() == 3 &&
3459         inst->to_apply()->num_parameters() == 2 &&
3460         inst->to_apply()->root_instruction()->IsElementwise()) {
3461       to_move.insert(inst);
3462       auto other_operand = inst->mutable_operand(1);
3463       auto res = new_operands_set.emplace(other_operand);
3464       if (res.second) {
3465         new_operands.push_back(other_operand);
3466       }
3467       reduce_outputs.push_back(inst);
3468     } else if (inst != computation->root_instruction() &&
3469                inst->user_count() > 0 && inst->IsElementwise() &&
3470                !inst->HasSideEffectNoRecurse() &&
3471                absl::c_all_of(inst->operands(),
3472                               [inst](const HloInstruction* o) {
3473                                 return ShapeUtil::CompatibleIgnoringElementType(
3474                                     o->shape(), inst->shape());
3475                               })) {
3476       // For an elementwise op, we need to make sure that it depends only on
3477       // nodes already in to_move or on nodes with small operands.
3478       bool can_include = true;
3479       for (auto operand : inst->operands()) {
3480         if (to_move.count(operand) > 0) {
3481           continue;
3482         }
3483         auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand);
3484         if (find_result.first.empty()) {
3485           can_include = false;
3486           break;
3487         }
3488         for (auto n : find_result.first) {
3489           to_move.insert(n);
3490         }
3491         for (auto new_operand : find_result.second) {
3492           auto res = new_operands_set.insert(new_operand);
3493           if (res.second) {
3494             new_operands.push_back(new_operand);
3495           }
3496         }
3497       }
3498       if (!can_include) {
3499         to_move.clear();
3500         break;
3501       }
3502       to_move.insert(inst);
3503       for (auto u : inst->users()) {
3504         worklist.push_back(u);
3505       }
3506     } else {
3507       to_move.clear();
3508       break;
3509     }
3510   }
3511   // If nothing is found, to_move could contain only original_output, or it
3512   // may have been cleared by the above code.
3513   if (to_move.size() <= 1) {
3514     return Status::OK();
3515   }
3516 
3517   // We will replace the original loop output with reduce-shape outputs.
3518   // Create the initial buffers before the loop.
3519   for (auto out : reduce_outputs) {
3520     auto padded_out_shape = out->shape();
3521     int64_t operand_dim = 0;
3522     int64_t output_dim = 0;
3523     while (output_dim < padded_out_shape.rank()) {
3524       if (absl::c_linear_search(out->dimensions(), operand_dim)) {
3525         // Dimension collapsed.
3526         ++operand_dim;
3527         continue;
3528       }
3529       // Kept dimensions have the same size as the padded shape.
3530       padded_out_shape.set_dimensions(output_dim,
3531                                       padded_shape.dimensions(operand_dim));
3532       ++operand_dim;
3533       ++output_dim;
3534     }
3535     auto broadcast =
3536         computation->AddInstruction(HloInstruction::CreateBroadcast(
3537             padded_out_shape,
3538             computation->AddInstruction(HloInstruction::CreateConstant(
3539                 LiteralUtil::Zero(out->shape().element_type()))),
3540             {}));
3541     new_operands.push_back(broadcast);
3542   }
3543 
3544   auto input_tuple = loop->mutable_operand(0);
3545   // Create the new input subtuple that contains the small operands and the
3546   // reduce-shape result buffers.
3547   auto new_input_subtuple =
3548       computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
3549   TF_RETURN_IF_ERROR(
3550       input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple));
3551   auto body = loop->while_body();
3552   auto body_param = body->parameter_instruction(0);
3553   auto body_root = body->root_instruction();
3554   CHECK_EQ(body_root->opcode(), HloOpcode::kTuple);
3555   // Update tuple shapes.
3556   for (auto tuple : std::vector<HloInstruction*>{
3557            input_tuple, loop, loop->while_condition()->parameter_instruction(0),
3558            body_param, body_root}) {
3559     *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) =
3560         new_input_subtuple->shape();
3561   }
3562   auto new_loop_input =
3563       body->AddInstruction(HloInstruction::CreateGetTupleElement(
3564           new_input_subtuple->shape(), body_param, 2));
3565 
3566   // Now create the moved nodes inside the loop body.
3567   absl::flat_hash_map<const HloInstruction*, HloInstruction*> outside_to_inside;
3568   worklist.clear();
3569   auto add_users_if_available = [&](HloInstruction* inst) {
3570     for (auto u : inst->users()) {
3571       if (outside_to_inside.count(u) == 0 && to_move.count(u) > 0 &&
3572           absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
3573             return outside_to_inside.count(o) > 0;
3574           })) {
3575         worklist.push_back(u);
3576       }
3577     }
3578   };
3579   for (int64_t i = 0; i < new_operands.size(); ++i) {
3580     outside_to_inside[new_operands[i]] =
3581         body->AddInstruction(HloInstruction::CreateGetTupleElement(
3582             new_operands[i]->shape(), new_loop_input, i));
3583     add_users_if_available(new_operands[i]);
3584   }
3585   // The elementwise nodes will be created with the sliced shape. The original
3586   // loop output corresponds to the dynamic-update-slice's update slice.
3587   auto dus = body_root->mutable_operand(2);
3588   CHECK_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice);
3589   outside_to_inside[original_output] = dus->mutable_operand(1);
3590   add_users_if_available(original_output);
3591   std::vector<HloInstruction*> slice_offsets(padded_shape.rank());
3592   for (int64_t i = 0; i < slice_offsets.size(); ++i) {
3593     slice_offsets[i] = dus->mutable_operand(i + 2);
3594   }
3595   auto get_slice = [&](HloInstruction* padded) {
3596     return body->AddInstruction(HloInstruction::CreateDynamicSlice(
3597         ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
3598                                      padded->shape().element_type()),
3599         padded, slice_offsets, dus->operand(1)->shape().dimensions()));
3600   };
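  // get_slice extracts the current iteration's window from a padded
  // full-shape value, reusing the dynamic-update-slice's offsets so the
  // slice lines up with the region being updated.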
3601   // Helper functions to create nodes with small operands.
3602   auto add_broadcast = [&](const HloInstruction* broadcast) {
3603     auto padded_operand_shape = broadcast->operand(0)->shape();
3604     for (int64_t i = 0; i < broadcast->dimensions().size(); ++i) {
3605       padded_operand_shape.set_dimensions(
3606           i, padded_shape.dimensions(broadcast->dimensions(i)));
3607     }
3608     auto padded_operand = PadToShape(outside_to_inside[broadcast->operand(0)],
3609                                      padded_operand_shape, nullptr, body);
3610     outside_to_inside[broadcast] =
3611         get_slice(body->AddInstruction(broadcast->CloneWithNewOperands(
3612             ShapeUtil::ChangeElementType(padded_shape,
3613                                          padded_operand_shape.element_type()),
3614             {padded_operand})));
3615   };
3616   auto add_iota = [&](const HloInstruction* iota) {
3617     outside_to_inside[iota] =
3618         get_slice(body->AddInstruction(iota->CloneWithNewOperands(
3619             ShapeUtil::ChangeElementType(padded_shape,
3620                                          iota->shape().element_type()),
3621             {})));
3622   };
3623   auto add_constant = [&](const HloInstruction* constant) {
3624     outside_to_inside[constant] = body->AddInstruction(constant->Clone());
3625     outside_to_inside[constant] = get_slice(
3626         PadToShape(outside_to_inside[constant],
3627                    ShapeUtil::ChangeElementType(
3628                        padded_shape, constant->shape().element_type()),
3629                    nullptr, body));
3630   };
3631   while (!worklist.empty()) {
3632     auto inst = worklist.back();
3633     worklist.pop_back();
3634     if (outside_to_inside.count(inst) > 0) {
3635       continue;
3636     }
3637     if (inst->opcode() == HloOpcode::kBroadcast) {
3638       add_broadcast(inst);
3639     } else if (inst->opcode() == HloOpcode::kIota) {
3640       add_iota(inst);
3641     } else if (inst->opcode() == HloOpcode::kConstant) {
3642       add_constant(inst);
3643     } else if (inst->opcode() == HloOpcode::kReduce) {
3644       // This is an output, which has special handling later.
3645     } else {
3646       std::vector<HloInstruction*> operands_inside(inst->operand_count());
3647       for (int64_t i = 0; i < operands_inside.size(); ++i) {
3648         operands_inside[i] = outside_to_inside[inst->operand(i)];
3649       }
3650       outside_to_inside[inst] = body->AddInstruction(inst->CloneWithNewOperands(
3651           ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
3652                                        inst->shape().element_type()),
3653           operands_inside));
3654     }
3655     add_users_if_available(inst);
3656   }
3657   std::vector<HloInstruction*> new_outputs_inside(new_operands.size());
3658   for (int64_t i = 0; i < new_outputs_inside.size(); ++i) {
3659     new_outputs_inside[i] = outside_to_inside[new_operands[i]];
3660   }
3661   // Now create the reduce outputs inside the loop.
3662   for (int64_t i = 0; i < reduce_outputs.size(); ++i) {
3663     auto reduce_outside = reduce_outputs[i];
3664     CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce);
3665     int64_t index_in_operand = new_operands.size() - reduce_outputs.size() + i;
3666     auto last_iter_result = outside_to_inside[new_operands[index_in_operand]];
3667     auto operand0 = outside_to_inside[reduce_outside->operand(0)];
3668     auto operand1 = outside_to_inside[reduce_outside->operand(1)];
3669     TF_ASSIGN_OR_RETURN(auto reduce_shape,
3670                         ShapeInference::InferReduceShape(
3671                             {&operand0->shape(), &operand1->shape()},
3672                             reduce_outside->dimensions(),
3673                             reduce_outside->to_apply()->ComputeProgramShape()));
3674     *reduce_shape.mutable_layout() = reduce_outside->shape().layout();
3675     std::vector<HloInstruction*> reduce_dus_offsets;
3676     // If any collapsed dimension is windowed, we need to accumulate with last
3677     // iteration's result. If such a dimension has padding, we also need to
3678     // mask off invalid data.
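    // E.g., if a contracting dim of size 10 is padded to 12 and processed in
    // 4 windows of size 3, the last 2 elements of the final window are
    // padding and must be masked to the reduce's init value before
    // accumulating.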
3679     bool needs_accumulate = false;
3680     std::vector<int64_t> dims_to_mask;
3681     for (int64_t i = 0; i < slice_offsets.size(); ++i) {
3682       if (absl::c_linear_search(reduce_outside->dimensions(), i)) {
3683         if (reduce_outside->operand(0)->shape().dimensions(i) !=
3684             operand0->shape().dimensions(i)) {
3685           needs_accumulate = true;
3686           if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) {
3687             dims_to_mask.push_back(i);
3688           }
3689         }
3690         continue;
3691       }
3692       reduce_dus_offsets.push_back(slice_offsets[i]);
3693     }
3694     // Mask off invalid data in collapsed dimensions.
3695     for (int64_t dim : dims_to_mask) {
3696       auto iota = body->AddInstruction(HloInstruction::CreateIota(
3697           ShapeUtil::ChangeElementType(operand0->shape(), S32), dim));
3698       auto add = body->AddInstruction(HloInstruction::CreateBinary(
3699           iota->shape(), HloOpcode::kAdd, iota,
3700           body->AddInstruction(HloInstruction::CreateBroadcast(
3701               iota->shape(), slice_offsets[dim], {}))));
3702       auto limit = body->AddInstruction(HloInstruction::CreateBroadcast(
3703           iota->shape(),
3704           body->AddInstruction(
3705               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
3706                   reduce_outside->operand(0)->shape().dimensions(dim)))),
3707           {}));
3708       auto compare = body->AddInstruction(HloInstruction::CreateCompare(
3709           ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit,
3710           ComparisonDirection::kLt));
3711       operand0 = body->AddInstruction(HloInstruction::CreateTernary(
3712           operand0->shape(), HloOpcode::kSelect, compare, operand0,
3713           body->AddInstruction(HloInstruction::CreateBroadcast(
3714               operand0->shape(), operand1, {}))));
3715     }
3716     auto output_inside =
3717         body->AddInstruction(reduce_outside->CloneWithNewOperands(
3718             reduce_shape, {operand0, operand1}));
3719     // Accumulate with previous results if needed.
3720     if (needs_accumulate) {
3721       auto input_slice =
3722           body->AddInstruction(HloInstruction::CreateDynamicSlice(
3723               output_inside->shape(), last_iter_result, reduce_dus_offsets,
3724               output_inside->shape().dimensions()));
3725       output_inside = body->AddInstruction(HloInstruction::CreateBinary(
3726           output_inside->shape(),
3727           reduce_outside->to_apply()->root_instruction()->opcode(),
3728           output_inside, input_slice));
3729     }
3730     // Dynamic-update-slice if needed.
3731     if (!ShapeUtil::Compatible(output_inside->shape(),
3732                                last_iter_result->shape())) {
3733       output_inside =
3734           body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
3735               last_iter_result->shape(), last_iter_result, output_inside,
3736               reduce_dus_offsets));
3737     }
3738     new_outputs_inside[index_in_operand] = output_inside;
3739   }
3740   // Body output.
3741   auto new_output_inside =
3742       body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside));
3743   TF_RETURN_IF_ERROR(
3744       body_root->ReplaceOperandWithDifferentShape(2, new_output_inside));
3745   TF_RETURN_IF_ERROR(body->RemoveInstructionAndUnusedOperands(dus));
3746   // Replace uses of the reduces outside the loop.
3747   auto new_output_gte =
3748       computation->AddInstruction(HloInstruction::CreateGetTupleElement(
3749           new_output_inside->shape(), loop, 2));
3750   for (int64_t i = 0; i < reduce_outputs.size(); ++i) {
3751     int64_t index_in_operand = new_operands.size() - reduce_outputs.size() + i;
3752     auto new_output =
3753         computation->AddInstruction(HloInstruction::CreateGetTupleElement(
3754             new_outputs_inside[index_in_operand]->shape(), new_output_gte,
3755             index_in_operand));
3756     if (!ShapeUtil::Compatible(new_output->shape(),
3757                                reduce_outputs[i]->shape())) {
3758       new_output = computation->AddInstruction(HloInstruction::CreateSlice(
3759           reduce_outputs[i]->shape(), new_output,
3760           std::vector<int64_t>(new_output->shape().rank(), 0),
3761           reduce_outputs[i]->shape().dimensions(),
3762           std::vector<int64_t>(new_output->shape().rank(), 1)));
3763     }
3764     TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output));
3765     TF_RETURN_IF_ERROR(
3766         computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i]));
3767   }
3768   return Status::OK();
3769 }
3770 
3771 }  // namespace
3772 
3773 Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops(
3774     HloComputation* computation, const SpmdPartitionerOptions& options) {
3775   for (auto& loop : windowed_dot_general_loops_) {
3776     if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims ||
3777         loop.operands_sharded_at_contracting_dims) {
3778       // We have a dynamic-slice for the non-windowed operand in
3779       // batch/contracting-dim/noncontracting-dim windowed dot-general. So
3780       // moving the broadcast/iota/elementwise ops into the loop could help
3781       // reduce memory via fusion.
3782       TF_RETURN_IF_ERROR(
3783           SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
3784               loop.while_loop, 1 - loop.windowed_operand));
3785     }
3786     // Bidirectional and unrolled loops do not support this optimization yet.
3787     if (!options.bidirectional_windowed_einsum &&
3788         !options.unroll_windowed_einsum && !loop.windowed_in_contracting_dims &&
3789         !loop.operands_sharded_at_contracting_dims) {
3790       // We have a dynamic-update-slice for the output in
3791       // batch/non-contracting-dim windowed dot-general. So moving reduce ops
3792       // into the loop could help reduce memory.
3793       TF_RETURN_IF_ERROR(
3794           MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
3795               loop.while_loop));
3796     }
3797   }
3798   return Status::OK();
3799 }
3800 
3801 }  // namespace spmd
3802 }  // namespace xla
3803