• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
19 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
23 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 
29 namespace xla {
30 namespace spmd {
31 namespace {
32 
33 using ::testing::_;
34 using ::testing::AllOf;
35 namespace op = xla::testing::opcode_matchers;
36 
37 class SpmdPartitioningTest : public HloTestBase {
38  public:
PartitionComputation(absl::string_view hlo_module,int64_t num_devices,bool conv_halo_exchange_always_on_lhs=true,bool choose_faster_windowed_einsum=false)39   StatusOr<std::unique_ptr<HloModule>> PartitionComputation(
40       absl::string_view hlo_module, int64_t num_devices,
41       bool conv_halo_exchange_always_on_lhs = true,
42       bool choose_faster_windowed_einsum = false) {
43     // Some tests (BackpropFilter convs) set this flag false to test two
44     // different paths of the implementation.
45     SpmdPartitionerOptions options;
46     options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs;
47     options.allow_module_signature_change = true;
48     options.choose_faster_windowed_einsum_over_mem =
49         choose_faster_windowed_einsum;
50     auto collective_ops_creator =
51         GetDefaultCollectiveOpsCreator(num_devices, /*num_replicas=*/1);
52     // Do not use all-gather for pattern-matching purpose, as the partitioner
53     // might create reshape/transposes around it.
54     collective_ops_creator.create_cross_partition_all_gather = nullptr;
55 
56     TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(
57                                          hlo_module, GetModuleConfigForTest()));
58     HloPassPipeline pass("spmd-partitioning");
59     pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
60                               /*allow_mixed_precision=*/false);
61     pass.AddPass<SpmdPartitioner>(num_devices, /*num_replicas=*/1, options,
62                                   collective_ops_creator);
63     pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
64                               /*allow_mixed_precision=*/false);
65     TF_RETURN_IF_ERROR(pass.Run(module.get()).status());
66     return StatusOr<std::unique_ptr<HloModule>>(std::move(module));
67   }
68 };
69 
TEST_F(SpmdPartitioningTest, InvalidSharding) {
  // The infeed is tiled over devices 0,1 only, but the module is partitioned
  // for 4 devices, so partitioning must fail with a descriptive error.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = (f32[8,2]{1,0}, token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {maximal device=0}}
  ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0,
    sharding={maximal device=0}
})";
  const auto result = PartitionComputation(kHloText, /*num_devices=*/4);
  const auto& status = result.status();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.ToString(),
              ::testing::HasSubstr(
                  "only supports tile sharding that includes all partitions"));
}
87 
TEST_F(SpmdPartitioningTest, SingleDeviceToReplicated) {
  // A constant pinned to device 0 is copied with a replicated sharding.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={maximal device=0}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  // Each partition selects between the constant and a broadcast; an
  // all-reduce then produces the full s32[2,3] value everywhere.
  const auto selected = op::Select(op::Broadcast(op::Compare()),
                                   op::Constant(), op::Broadcast());
  EXPECT_THAT(root, AllOf(op::Copy(op::AllReduce(selected)),
                          op::Shape("s32[2,3]")));
}
106 
TEST_F(SpmdPartitioningTest, SingleDeviceToSingleDevice) {
  // A constant pinned to device 0 is copied onto device 1.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={maximal device=0}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  VLOG(1) << partitioned->ToString();

  // The value passes through a replicated intermediate (select + all-reduce)
  // that is copied before the final copy.
  const auto replicated = op::AllReduce(op::Select(
      op::Broadcast(op::Compare()), op::Constant(), op::Broadcast()));
  EXPECT_THAT(root, op::Copy(AllOf(op::Copy(replicated),
                                   op::Shape("s32[2,3]"))));
}
125 
TEST_F(SpmdPartitioningTest, SingleDeviceToTiled) {
  // A constant pinned to device 0 is copied into a [2,1] tiling with the
  // device order reversed (1,0).
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={maximal device=0}
  ROOT %copy = s32[2,3]{1,0} copy(%constant),
    sharding={devices=[2,1]1,0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  // First make the value available everywhere, then slice out this
  // partition's row using an offset looked up by partition id.
  const auto replicated = op::AllReduce(op::Select(
      op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
      op::Constant(), op::Broadcast()));
  const auto row_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  EXPECT_THAT(root, AllOf(op::Copy(op::DynamicSlice(replicated, row_offset,
                                                    op::Constant())),
                          op::Shape("s32[1,3]")));
}
151 
TEST_F(SpmdPartitioningTest, TiledToReplicated) {
  // A [2,1]-tiled constant is copied with a replicated sharding.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  // Each partition writes its s32[1,3] shard into a full-size broadcast at
  // its own offset; the all-reduce assembles the s32[2,3] result.
  const auto row_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  const auto local_shard = AllOf(op::Constant(), op::Shape("s32[1,3]"));
  const auto scattered =
      AllOf(op::DynamicUpdateSlice(op::Broadcast(), local_shard, row_offset,
                                   op::Constant()),
            op::Shape("s32[2,3]"));
  EXPECT_THAT(root, op::Copy(op::AllReduce(scattered)));
}
173 
TEST_F(SpmdPartitioningTest, TiledToSingleDevice) {
  // A [2,1]-tiled constant is copied onto device 0 only.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  // Same assembly as the replicated case (dynamic-update-slice of each
  // s32[1,3] shard + all-reduce), followed by two copies.
  const auto row_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  const auto local_shard = AllOf(op::Constant(), op::Shape("s32[1,3]"));
  const auto scattered =
      AllOf(op::DynamicUpdateSlice(op::Broadcast(), local_shard, row_offset,
                                   op::Constant()),
            op::Shape("s32[2,3]"));
  EXPECT_THAT(root, op::Copy(op::Copy(op::AllReduce(scattered))));
}
195 
TEST_F(SpmdPartitioningTest, TiledToTiledEven) {
  // Reshards s32[8,2] from tiling on dim 0 to tiling on dim 1. Both dims
  // divide evenly by 2.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  %param= s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]0,1}
  ROOT %copy = s32[8,2]{1,0} copy(%param), sharding={devices=[1,2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  // The reshard is an all-to-all surrounded by reshape/transpose.
  const auto split_param =
      AllOf(op::Reshape(op::Parameter()), op::Shape("s32[4,2,1]"));
  const auto exchanged =
      op::Reshape(op::Transpose(op::AllToAll(split_param)));
  EXPECT_THAT(root, AllOf(op::Copy(exchanged), op::Shape("s32[8,1]")));
}
215 
TEST_F(SpmdPartitioningTest, TiledToTiledUneven) {
  // Reshards f32[7,31,128] from tiling on dim 1 to tiling on dim 0. Both dims
  // are uneven over 2 partitions, so the input shard is padded to
  // f32[8,16,128] before the all-to-all, and the padding is sliced off
  // afterwards.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  %param= f32[7,31,128]{2,1,0} parameter(0), sharding={devices=[1,2,1]0,1}
  ROOT %copy = f32[7,31,128]{2,1,0} copy(%param), sharding={devices=[2,1,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  const auto padded_input = AllOf(op::Pad(), op::Shape("f32[8,16,128]"));
  const auto resharded =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(padded_input))));
  // Also pin the per-partition result shape: dim 0 is split 2 ways
  // (ceil(7 / 2) = 4) and dim 1 is whole again (31).
  EXPECT_THAT(root, AllOf(op::Copy(op::Slice(resharded)),
                          op::Shape("f32[4,31,128]")));
}
234 
TEST_F(SpmdPartitioningTest, GetTupleElementSwapDevice) {
  // Both tuple elements are pinned to device 1 in the parameter and read
  // back with device-0 sharding.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  %param.0 = (f32[2,3]{1,0}, u32[]) parameter(0),
    sharding={{maximal device=1}, {maximal device=1}}
  %gte.0 = f32[2,3]{1,0} get-tuple-element(%param.0), index=0,
    sharding={maximal device=0}
  %gte.1 = u32[] get-tuple-element(%param.0), index=1,
    sharding={maximal device=0}
  ROOT %tuple = (f32[2,3]{1,0}, u32[]) tuple(%gte.0, %gte.1),
    sharding={{maximal device=0},{maximal device=0}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::Tuple());

  // Each element is moved through select + all-reduce, then copied.
  const auto moved_element = op::Copy(op::AllReduce(op::Select(
      op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
      op::GetTupleElement(op::Parameter()), op::Broadcast())));
  EXPECT_THAT(root->operand(0), moved_element);
  EXPECT_THAT(root->operand(1), moved_element);
}
264 
TEST_F(SpmdPartitioningTest, GetTupleElementTiled) {
  // Replicated tuple elements are read back with a [2,1] tiling.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  param.0 = (f32[2,3]{1,0}, u32[2,3]{1,0}) parameter(0),
    sharding={{replicated}, {replicated}}
  gte.0 = f32[2,3]{1,0} get-tuple-element(param.0), index=0,
    sharding={devices=[2,1]0,1}
  gte.1 = u32[2,3]{1,0} get-tuple-element(param.0), index=1,
    sharding={devices=[2,1]0,1}
  ROOT %tuple = (f32[2,3]{1,0}, u32[2,3]{1,0}) tuple(gte.0, gte.1),
    sharding={{devices=[2,1]0,1},{devices=[2,1]0,1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::Tuple());

  // No collectives appear in the expected chain: each element is simply
  // dynamic-sliced at the offset looked up by partition id.
  const auto partition_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  const auto local_slice =
      op::DynamicSlice(op::GetTupleElement(op::Parameter()), partition_offset,
                       op::Constant());
  EXPECT_THAT(root->operand(0), local_slice);
  EXPECT_THAT(root->operand(1), local_slice);
}
295 
TEST_F(SpmdPartitioningTest, TiledInfeed) {
  // The infeed data is [2,1]-tiled but consumed on device 0 only.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = (f32[8,2]{1,0}, token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {maximal device=0}}
  ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0,
    sharding={maximal device=0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  // Each partition infeeds an f32[4,2] shard, which is written at its offset
  // and all-reduced into the full value before the copy.
  const auto local_infeed = op::GetTupleElement(
      AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])")));
  const auto row_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  EXPECT_THAT(root,
              op::Copy(op::AllReduce(op::DynamicUpdateSlice(
                  op::Broadcast(), local_infeed, row_offset, op::Constant()))));
}
319 
TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) {
  // 9 rows over 2 partitions is uneven. The infeed becomes a conditional on
  // the partition id: one branch infeeds f32[5,2] directly, the other infeeds
  // f32[4,2] and pads it to f32[5,2].
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = (f32[9,2]{1,0}, token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {maximal device=0}}
  ROOT infeed.data = f32[9,2]{1,0} get-tuple-element(infeed), index=0,
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[5,2]"),
                          op::GetTupleElement(op::Conditional(
                              op::Convert(op::PartitionId()), op::AfterAll(),
                              op::AfterAll()))));

  const auto& branches = root->operand(0)->called_computations();
  EXPECT_THAT(branches[0]->root_instruction(),
              AllOf(op::Shape("(f32[5,2], token[])"),
                    op::Infeed(op::Parameter())));

  const auto short_infeed =
      AllOf(op::Shape("(f32[4,2], token[])"), op::Infeed(op::Parameter()));
  const auto short_element = op::GetTupleElement(short_infeed);
  EXPECT_THAT(branches[1]->root_instruction(),
              AllOf(op::Shape("(f32[5,2], token[])"),
                    op::Tuple(op::Pad(short_element, op::Constant()),
                              short_element)));
}
350 
TEST_F(SpmdPartitioningTest, UnevenTiledTupleInfeed) {
  // Tuple-shaped variant of the uneven infeed: only the f32[9,2] element is
  // tiled (unevenly); the f32[2] element is replicated.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {replicated}, {maximal device=0}}
  ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed),
    index=0, sharding={{devices=[2,1]0,1}, {replicated}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("(f32[5,2], f32[2])"),
                          op::GetTupleElement(op::Conditional(
                              op::Convert(op::PartitionId()), op::AfterAll(),
                              op::AfterAll()))));

  const auto& branches = root->operand(0)->called_computations();
  EXPECT_THAT(branches[0]->root_instruction(),
              AllOf(op::Shape("((f32[5,2], f32[2]), token[])"),
                    op::Infeed(op::Parameter())));

  // The short branch pads the tiled element and rebuilds the tuple.
  const auto short_infeed = AllOf(op::Shape("((f32[4,2], f32[2]), token[])"),
                                  op::Infeed(op::Parameter()));
  const auto short_outer = op::GetTupleElement(short_infeed);
  const auto short_inner = op::GetTupleElement(short_outer);
  EXPECT_THAT(
      branches[1]->root_instruction(),
      AllOf(op::Shape("((f32[5,2], f32[2]), token[])"),
            op::Tuple(op::Tuple(op::Pad(short_inner, op::Constant()),
                                short_inner),
                      short_outer)));
}
385 
TEST_F(SpmdPartitioningTest, MixedTupleInfeed) {
  // The two data elements of the infeed are pinned to different devices.
  // Each conditional branch infeeds only its own element (the other slot is
  // an empty tuple) and fills the missing element with a broadcast constant.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0),
    sharding={{maximal device=0}, {maximal device=1}, {maximal device=0}}
  ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed),
    index=0, sharding={{maximal device=0}, {maximal device=1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("(f32[9,2], f32[2])"),
                          op::GetTupleElement(op::Conditional(
                              op::Convert(op::PartitionId()), op::AfterAll(),
                              op::AfterAll()))));

  const auto& branches = root->operand(0)->called_computations();

  const auto first_infeed = AllOf(op::Shape("((f32[9,2], ()), token[])"),
                                  op::Infeed(op::Parameter()));
  const auto first_gte = op::GetTupleElement(first_infeed);
  EXPECT_THAT(branches[0]->root_instruction(),
              AllOf(op::Shape("((f32[9,2], f32[2]), token[])"),
                    op::Tuple(op::Tuple(op::GetTupleElement(first_gte),
                                        op::Broadcast(op::Constant())),
                              first_gte)));

  const auto second_infeed =
      AllOf(op::Shape("(((), f32[2]), token[])"), op::Infeed(op::Parameter()));
  const auto second_gte = op::GetTupleElement(second_infeed);
  EXPECT_THAT(branches[1]->root_instruction(),
              AllOf(op::Shape("((f32[9,2], f32[2]), token[])"),
                    op::Tuple(op::Tuple(op::Broadcast(op::Constant()),
                                        op::GetTupleElement(second_gte)),
                              second_gte)));
}
422 
TEST_F(SpmdPartitioningTest, TiledToReplicatedReduce) {
  // Full reduction of an unevenly [2,1]-tiled f32[3,3] into a replicated
  // scalar.
  const absl::string_view kHloText = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT reduce = f32[] reduce(constant, constant.1), dimensions={0,1},
    to_apply=sum, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  // Each partition reduces its f32[2,3] slice of the padded constant after a
  // select against an iota-based comparison; the partial results are then
  // all-reduced.
  const auto mask =
      op::Compare(op::Add(op::Iota(), op::Broadcast(op::Reshape())),
                  op::Broadcast(op::Constant()));
  const auto local_slice =
      AllOf(op::Shape("f32[2,3]{1,0}"),
            op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
                             op::Reshape(), op::Constant()));
  const auto masked =
      op::Select(mask, local_slice, op::Broadcast(op::Constant()));
  EXPECT_THAT(root, op::AllReduce(op::Reduce(masked, op::Constant())));
}
456 
TEST_F(SpmdPartitioningTest, TiledElementwise) {
  // Elementwise ops mixing tiled and replicated constants, all producing a
  // [2,1]-tiled result.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  constant.1 = f32[3,3]{1,0} constant({{2,2,2},{2,2,2},{2,2,2}}),
    sharding={replicated}
  multiply = f32[3,3]{1,0} multiply(constant, constant.1),
    sharding={devices=[2,1]0,1}
  ROOT add = f32[3,3]{1,0} add(multiply, constant.1),
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  // Every operand is reduced to the local slice of a padded constant.
  const auto sliced_constant =
      op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), op::Reshape(),
                       op::Constant());
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[2,3]{1,0}"),
                    op::Add(op::Multiply(sliced_constant, sliced_constant),
                            sliced_constant)));
}
487 
TEST_F(SpmdPartitioningTest, TiledAllReduce) {
  // An all-reduce whose operand and result share the same [2,1] tiling.
  const absl::string_view kHloText = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  parameter = f32[3,3]{1,0} parameter(0), sharding={devices=[2,1]0,1}
  ROOT all-reduce = f32[3,3]{1,0} all-reduce(parameter), to_apply=sum,
    replica_groups={}, sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  // The all-reduce consumes the local f32[2,3] parameter shard directly.
  const auto expected =
      AllOf(op::Shape("f32[2,3]{1,0}"), op::AllReduce(op::Parameter(0)));
  EXPECT_THAT(root, expected);
}
510 
TEST_F(SpmdPartitioningTest, BroadcastOnlyNewDimsSharded) {
  // Only dimension 0 of the result is sharded, and it is introduced by the
  // broadcast (the operand maps to dims {1,2}).
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
    sharding={replicated}
  ROOT broadcast = f32[3,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[2,1,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  // The replicated constant is broadcast straight into the local shape.
  const auto local_broadcast = op::Broadcast(op::Constant());
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[2,4,3]{2,1,0}"), local_broadcast));
}
528 
TEST_F(SpmdPartitioningTest, BroadcastOnlyOldDimsSharded) {
  // Only dimension 1 of the result is sharded, and it comes from the
  // operand's dimension 0.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
    sharding={replicated}
  ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[1,2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  // The operand is sliced to the local rows before broadcasting.
  const auto sliced_operand =
      op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"),
                          op::Broadcast(sliced_operand)));
}
547 
TEST_F(SpmdPartitioningTest, BroadcastBothOldAndNewDimsSharded) {
  // Result dims 0 (new) and 1 (from the operand) are both sharded, 2x2 over
  // 4 devices.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
    sharding={replicated}
  ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[2,2,1]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  const auto sliced_operand =
      AllOf(op::Shape("f32[2,3]{1,0}"),
            op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,2,3]{2,1,0}"),
                          op::Broadcast(sliced_operand)));
}
569 
TEST_F(SpmdPartitioningTest,
       BroadcastBothOldAndNewDimsShardedPartiallySharded) {
  // Partially replicated (last_tile_dim_replicate) shardings on both the
  // operand and the result, over 8 devices.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  param = f32[4,3] parameter(0),
    sharding={devices=[1,2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
  ROOT broadcast = f32[4,4,3] broadcast(param), dimensions={1,2},
    sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  // The local f32[4,2] parameter shard is broadcast directly.
  const auto local_param = AllOf(op::Shape("f32[4,2]"), op::Parameter(0));
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[2,4,2]"), op::Broadcast(local_param)));
}
590 
TEST_F(SpmdPartitioningTest,
       ConvWithParallelDimAndNonParallelSpatialDimPartitioned) {
  // Both conv operands and the result are tiled 2x2 on dims 0 and 1.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,12,12,24,32] parameter(0)
  %lhs.copy = f32[32,12,12,24,32] copy(%lhs),
    sharding={devices=[2,2,1,1,1]0,1,2,3}
  %rhs = f32[32,6,6,16,32] parameter(1)
  %rhs.copy = f32[32,6,6,16,32] copy(%rhs),
    sharding={devices=[2,2,1,1,1]0,1,2,3}
  ROOT %conv = f32[32,7,7,24,16] convolution(%lhs.copy, %rhs.copy),
    dim_labels=012bf_012oi->012bf,
    window={size=32x6x6 stride=31x1x1 lhs_dilate=32x1x1},
    sharding={devices=[2,2,1,1,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  // Each operand is a copy of its parameter, dynamic-sliced to the local tile.
  const auto sliced_param_copy = [](const char* shape) {
    return AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                           op::Reshape(), op::Constant(),
                                           op::Constant(), op::Constant())),
                 op::Shape(shape));
  };
  const auto lhs = sliced_param_copy("f32[16,6,12,24,32]");
  const auto rhs = sliced_param_copy("f32[16,3,6,16,32]");

  // The rhs dim 1 is reassembled to its full size 6 with
  // dynamic-update-slice + all-reduce.
  const auto resharded_rhs =
      AllOf(op::Shape("f32[16,6,6,16,32]"),
            op::AllReduce(op::DynamicUpdateSlice(
                op::Broadcast(), rhs, op::Constant(), op::Reshape(),
                op::Constant(), op::Constant(), op::Constant())));

  // The lhs is extended with left/right halos (collective-permuted slices of
  // lhs), then windowed and masked.
  const auto halo = [&lhs](const char* shape) {
    return AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape(shape));
  };
  const auto windowed_lhs = op::Select(
      op::Compare(),
      op::DynamicSlice(op::Concatenate(halo("f32[16,2,12,24,32]"), lhs,
                                       halo("f32[16,3,12,24,32]")),
                       op::Constant(), op::Add(), op::Constant(),
                       op::Constant(), op::Constant()),
      op::Broadcast());

  EXPECT_THAT(root, AllOf(op::Convolution(windowed_lhs, resharded_rhs),
                          op::Shape("f32[16,4,7,24,16]")));
}
643 
TEST_F(SpmdPartitioningTest, BroadcastPropagateTiledSharding) {
  // The operand's [2,1] tiling lines up with the result's sharded dim 1.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,4,1},{1,3,1},{1,2,1}}),
    sharding={devices=[2,1]0,1}
  ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[1,2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  const auto sliced_operand =
      op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  EXPECT_THAT(root, AllOf(op::Shape("f32[4,2,3]{2,1,0}"),
                          op::Broadcast(sliced_operand)));
}
662 
TEST_F(SpmdPartitioningTest, OutfeedSingleDevice) {
  // An outfeed pinned to device 0 becomes a conditional keyed on a
  // comparison with the partition id: branch 0 performs the outfeed and
  // branch 1 only produces a token via after-all.
  const absl::string_view kHloText = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = f32[1024]{0} parameter(0), sharding={maximal device=0}
  outfeed = token[] outfeed(data, token.0), sharding={maximal device=0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  const auto branch_operand = op::Tuple(op::Parameter(0), op::AfterAll());
  EXPECT_THAT(root, AllOf(op::Shape("token[]"),
                          op::Conditional(
                              op::Compare(op::PartitionId(), op::Constant()),
                              branch_operand, branch_operand)));

  const HloInstruction* outfeed_branch_root =
      root->branch_computation(0)->root_instruction();
  EXPECT_THAT(outfeed_branch_root,
              AllOf(op::Shape("token[]"),
                    op::Outfeed(op::GetTupleElement(op::Parameter(), 0),
                                op::GetTupleElement(op::Parameter(), 1))));

  const HloInstruction* token_branch_root =
      root->branch_computation(1)->root_instruction();
  EXPECT_THAT(token_branch_root, AllOf(op::Shape("token[]"), op::AfterAll()));
}
691 
// An evenly tiled f32[1024] outfeed over 2 partitions: the partitioned root is
// expected to be the outfeed itself, fed directly by the parameter shard.
TEST_F(SpmdPartitioningTest, OutfeedEvenlyTiled) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = f32[1024]{0} parameter(0), sharding={devices=[2]0,1}
  ROOT outfeed = token[] outfeed(data, token.0), sharding={devices=[2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto direct_outfeed = AllOf(op::Shape("token[]"),
                              op::Outfeed(op::Parameter(), op::AfterAll()));
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, direct_outfeed);
}
708 
// A tuple outfeed whose elements are all evenly tiled: the root is expected to
// stay a single outfeed, and its outfeed_shape must keep the layouts written
// in the HLO ({0,1} for element 0, {0} for element 1).
TEST_F(SpmdPartitioningTest, OutfeedTupleEvenlyTiled) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
  ROOT outfeed = token[] outfeed(data, token.0),
    outfeed_shape=(f32[1024,2]{0,1}, f32[2]{0}), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  // ASSERT rather than EXPECT: outfeed_shape() below is only valid when
  // `root` really is an outfeed; a mismatch must stop the test here.
  ASSERT_THAT(root, AllOf(op::Shape("token[]"),
                          op::Outfeed(op::Parameter(), op::AfterAll())));
  auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
  auto expected_layout1 = LayoutUtil::MakeLayout({0});
  EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(0).layout(),
                                expected_layout0));
  EXPECT_TRUE(LayoutUtil::Equal(root->outfeed_shape().tuple_shapes(1).layout(),
                                expected_layout1));
}
734 
// A tuple outfeed mixing a tiled element and a replicated element: the
// partitioned root is expected to be a plain outfeed of the parameter.
TEST_F(SpmdPartitioningTest, OutfeedReplicated) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
    {replicated}}
  ROOT outfeed = token[] outfeed(data, token.0), sharding={{devices=[2,1]0,1},
    {replicated}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto direct_outfeed = AllOf(op::Shape("token[]"),
                              op::Outfeed(op::Parameter(), op::AfterAll()));
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, direct_outfeed);
}
753 
// A tuple outfeed whose elements (f32[1023,2] and f32[3]) do not split evenly
// over 2 partitions. The root is expected to be a conditional with one outfeed
// per branch:
//   - one branch outfeeds the f32[512,2] / f32[2] pieces;
//   - the other outfeeds the f32[511,2] / f32[1] pieces.
// Both outfeeds must keep the layouts from the HLO's outfeed_shape.
TEST_F(SpmdPartitioningTest, OutfeedUnevenlyTiled) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = (f32[1023,2]{1,0}, f32[3]{0}) parameter(0), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
  outfeed = token[] outfeed(data, token.0),
    outfeed_shape=(f32[1023,2]{0,1}, f32[3]{0}), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();
  // ASSERT rather than EXPECT: called_computations()[0] and [1] are indexed
  // below, which is only safe once `root` is known to be the conditional.
  ASSERT_THAT(
      root, AllOf(op::Shape("token[]"),
                  op::Conditional(op::Convert(),
                                  op::Tuple(op::Parameter(), op::AfterAll()),
                                  op::Tuple(op::Parameter(), op::AfterAll()))));

  HloInstruction* first_outfeed_instr =
      root->called_computations()[0]->root_instruction();
  HloInstruction* second_outfeed_instr =
      root->called_computations()[1]->root_instruction();

  // These match the data operand of each branch's outfeed, not the outfeed.
  // ASSERT again because outfeed_shape() is read from both roots afterwards.
  auto first_outfeed_data =
      AllOf(op::Shape("(f32[512,2], f32[2])"), op::GetTupleElement());
  ASSERT_THAT(first_outfeed_instr,
              AllOf(op::Shape("token[]"),
                    op::Outfeed(first_outfeed_data, op::GetTupleElement())));

  auto second_outfeed_data =
      AllOf(op::Shape("(f32[511,2], f32[1])"), op::Tuple());
  ASSERT_THAT(second_outfeed_instr,
              AllOf(op::Shape("token[]"),
                    op::Outfeed(second_outfeed_data, op::GetTupleElement())));

  auto expected_layout0 = LayoutUtil::MakeLayout({0, 1});
  auto expected_layout1 = LayoutUtil::MakeLayout({0});
  for (const HloInstruction* outfeed :
       {first_outfeed_instr, second_outfeed_instr}) {
    SCOPED_TRACE(outfeed->name());
    EXPECT_TRUE(LayoutUtil::Equal(
        outfeed->outfeed_shape().tuple_shapes(0).layout(), expected_layout0));
    EXPECT_TRUE(LayoutUtil::Equal(
        outfeed->outfeed_shape().tuple_shapes(1).layout(), expected_layout1));
  }
}
806 
// A replicated input feeding a tiled reduce-window output. The operand is
// expected to be the padded f32[9,2] constant, dynamic-sliced at an offset
// computed as Multiply(Reshape, Constant). No halo exchange is involved.
TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}),
    sharding={replicated}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT reduce-window = f32[3,2]{1,0} reduce-window(constant, constant.1),
    window={size=3x1 stride=2x1 pad=1_0x0_0}, to_apply=sum,
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto padded_constant = AllOf(op::Shape("f32[9,2]{1,0}"),
                               op::Pad(op::Constant(), op::Constant()));
  auto slice_offset = op::Multiply(op::Reshape(), op::Constant());
  auto local_input =
      op::DynamicSlice(padded_constant, slice_offset, op::Constant());
  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
                          op::ReduceWindow(local_input, op::Constant())));
}
839 
// A tiled input with pad=0_1 on dim 0. Only a trailing halo (f32[2,2], via a
// collective-permute of a slice of the local shard) is expected. The shard
// and halo are concatenated, padded, dynamic-sliced, and masked with a single
// index comparison before the reduce-window.
TEST_F(SpmdPartitioningTest, ReduceWindowTiledNegativeLeftHalo) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}),
    sharding={devices=[2,1]0,1}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT %reduce-window = f32[3,2]{1,0} reduce-window(%constant, %constant.1),
    window={size=3x1 stride=2x1 pad=0_1x0_0}, to_apply=sum,
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto shard = op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  auto trailing_halo = AllOf(op::Shape("f32[2,2]{1,0}"),
                             op::CollectivePermute(op::Slice(shard)));
  auto padded_with_halo = AllOf(
      op::Shape("f32[6,2]{1,0}"),
      op::Pad(op::Concatenate(shard, trailing_halo), op::Constant()));
  auto unmasked =
      op::DynamicSlice(padded_with_halo, op::Reshape(), op::Constant());
  auto padded_index = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  auto keep = op::Compare(padded_index, op::Broadcast(op::Constant()));
  auto masked_input =
      op::Select(keep, unmasked, op::Broadcast(op::Constant()));

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
                          op::ReduceWindow(masked_input, op::Constant())));
}
880 
// A 9-row input over 5 partitions with left padding 3. The left halo is
// expected to come from two collective-permutes, concatenated ahead of a slice
// of the local parameter shard:
//   - a sliced piece (f32[1,2]);
//   - an unsliced shard (f32[2,2]).
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1),
    window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum,
    sharding={devices=[5,1]0,1,2,3,4}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/5));
  VLOG(1) << module->ToString();

  auto sliced_halo = AllOf(op::Shape("f32[1,2]"),
                           op::CollectivePermute(op::Slice(op::Parameter(0))));
  auto whole_shard_halo =
      AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(op::Parameter(0)));
  auto window_input =
      AllOf(op::Shape("f32[4,2]"),
            op::Concatenate(sliced_halo, whole_shard_halo,
                            op::Slice(op::Parameter(0))));
  auto row_index = op::Add(op::Iota(), op::Broadcast(op::Multiply()));
  auto masked_input =
      op::Select(op::Compare(row_index, op::Broadcast(op::Constant())),
                 window_input, op::Broadcast(op::Constant()));

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
                          op::ReduceWindow(masked_input, op::Constant())));
}
916 
// A 9-row constant over 3 partitions with pad=1_1. Only a trailing halo
// (f32[2,2]) is expected to be exchanged. After pad and dynamic-slice, the
// data is masked using two index comparisons combined with And.
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  constant = f32[9,2]{1,0} constant(
    {{1,1},{1,4},{2,1},{3,1},{1,2},{2,2},{4,1},{1,2},{2,1}}),
    sharding={devices=[3,1]0,1,2}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT reduce-window = f32[5,2]{1,0} reduce-window(constant, constant.1),
    window={size=3x1 stride=2x1 pad=1_1x0_0}, to_apply=sum,
    sharding={devices=[3,1]0,1,2}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/3));
  VLOG(1) << module->ToString();

  auto shard = op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  auto trailing_halo = AllOf(op::Shape("f32[2,2]{1,0}"),
                             op::CollectivePermute(op::Slice(shard)));
  auto padded_with_halo = AllOf(
      op::Shape("f32[7,2]{1,0}"),
      op::Pad(op::Concatenate(shard, trailing_halo), op::Constant()));
  auto unmasked =
      op::DynamicSlice(padded_with_halo, op::Reshape(), op::Constant());
  auto padded_index = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  auto in_bounds =
      op::And(op::Compare(padded_index, op::Broadcast(op::Constant())),
              op::Compare(padded_index, op::Broadcast(op::Constant())));
  auto masked_input =
      op::Select(in_bounds, unmasked, op::Broadcast(op::Constant()));

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
                          op::ReduceWindow(masked_input, op::Constant())));
}
959 
// Window 5 / stride 3 / pad 2_2 over a 4-row input on 2 partitions. A
// one-row halo is expected on each side of the local shard. The result is
// padded to f32[6,2], sliced to f32[5,2], and masked with two comparisons.
TEST_F(SpmdPartitioningTest, ReduceWindowTiledTwoSideHalo) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}}),
    sharding={devices=[2,1]0,1}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT reduce-window = f32[2,2]{1,0} reduce-window(constant, constant.1),
    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum,
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto shard = op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  auto leading_halo = AllOf(op::Shape("f32[1,2]{1,0}"),
                            op::CollectivePermute(op::Slice(shard)));
  auto trailing_halo = AllOf(op::Shape("f32[1,2]{1,0}"),
                             op::CollectivePermute(op::Slice(shard)));
  auto padded_with_halos =
      AllOf(op::Shape("f32[6,2]{1,0}"),
            op::Pad(op::Concatenate(leading_halo, shard, trailing_halo),
                    op::Constant()));
  auto unmasked =
      AllOf(op::Shape("f32[5,2]{1,0}"),
            op::DynamicSlice(padded_with_halos, op::Reshape(), op::Constant()));
  auto padded_index = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  auto in_bounds =
      op::And(op::Compare(padded_index, op::Broadcast(op::Constant())),
              op::Compare(padded_index, op::Broadcast(op::Constant())));
  auto masked_input =
      op::Select(in_bounds, unmasked, op::Broadcast(op::Constant()));

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
                          op::ReduceWindow(masked_input, op::Constant())));
}
1005 
// A reduce-window tiled on two dimensions (2x2 partitions). Halo exchange,
// padding, slicing and masking are expected for dim 0 first. The dim-1
// exchange then operates on the dim-0 result.
TEST_F(SpmdPartitioningTest, ReduceWindowTiled2D) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = (f32[4,4,2,2]{3,2,1,0}, token[]) infeed(token0),
    sharding={{devices=[2,2,1,1]0,1,2,3}, {maximal device=0}}
  infeed.data = f32[4,4,2,2]{3,2,1,0} get-tuple-element(infeed), index=0,
    sharding={devices=[2,2,1,1]0,1,2,3}
  constant = f32[] constant(0), sharding={replicated}
  ROOT reduce-window = f32[2,2,2,2]{3,2,1,0} reduce-window(infeed.data, constant),
    window={size=5x5x1x1 stride=3x3x1x1 pad=2_2x2_2x0_0x0_0}, to_apply=sum,
    sharding={devices=[2,2,1,1]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  // Matcher for select(and(compare, compare), data, broadcast(constant)).
  // Both compares test the same index: iota plus a broadcast
  // Multiply(Reshape, Constant) offset. The per-dimension mask has the same
  // structure for dim 0 and dim 1.
  auto masked_along_dim = [](auto data) {
    auto index = op::Add(
        op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
    return op::Select(
        op::And(op::Compare(index, op::Broadcast(op::Constant())),
                op::Compare(index, op::Broadcast(op::Constant()))),
        data, op::Broadcast(op::Constant()));
  };

  auto infeed_shard = AllOf(op::Shape("f32[2,2,2,2]{3,2,1,0}"),
                            op::GetTupleElement(op::Infeed()));

  // Dimension 0.
  auto dim0_halo_lo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"),
                            op::CollectivePermute(op::Slice(infeed_shard)));
  auto dim0_halo_hi = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"),
                            op::CollectivePermute(op::Slice(infeed_shard)));
  auto dim0_windowed = op::DynamicSlice(
      AllOf(op::Shape("f32[6,2,2,2]{3,2,1,0}"),
            op::Pad(op::Concatenate(dim0_halo_lo, infeed_shard, dim0_halo_hi),
                    op::Constant())),
      op::Reshape(), op::Constant(), op::Constant(), op::Constant());
  auto dim0_done = AllOf(op::Shape("f32[5,2,2,2]{3,2,1,0}"),
                         masked_along_dim(dim0_windowed));

  // Dimension 1, applied on top of the dim-0 result.
  auto dim1_halo_lo = AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"),
                            op::CollectivePermute(op::Slice(dim0_done)));
  auto dim1_halo_hi = AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"),
                            op::CollectivePermute(op::Slice(dim0_done)));
  auto dim1_windowed = op::DynamicSlice(
      AllOf(op::Shape("f32[5,6,2,2]{3,2,1,0}"),
            op::Pad(op::Concatenate(dim1_halo_lo, dim0_done, dim1_halo_hi),
                    op::Constant())),
      op::Constant(), op::Reshape(), op::Constant(), op::Constant());
  auto dim1_done = AllOf(op::Shape("f32[5,5,2,2]{3,2,1,0}"),
                         masked_along_dim(dim1_windowed));

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[1,1,2,2]{3,2,1,0}"),
                          op::ReduceWindow(dim1_done, op::Constant())));
}
1072 
// The LHS and the output are both tiled on spatial dim 1 (b01f layout); the
// RHS is replicated. The 7x7 stride-2 pad-3 window is expected to need a
// 3-row leading halo and a 2-row trailing halo on the LHS shard, followed by
// a select mask.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicated) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,224,224,3] parameter(0)
  %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
    sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[7,7,3,64] parameter(1)
  %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
    sharding={replicated}
  ROOT %conv = f32[128,112,112,64] convolution(
    f32[128,224,224,3] %lhs.copy,
    f32[7,7,3,64] %rhs.copy),
    window={size=7x7 stride=2x2 pad=3_3x3_3},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[128,112,224,3]"));
  auto full_rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));

  auto leading_halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                            op::Shape("f32[128,3,224,3]"));
  auto trailing_halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                             op::Shape("f32[128,2,224,3]"));
  auto windowed_lhs =
      op::Select(op::And(),
                 op::Concatenate(leading_halo, lhs_shard, trailing_halo),
                 op::Broadcast());

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(windowed_lhs, full_rhs),
                          op::Shape("f32[128,56,112,64]")));
}
1115 
// The LHS is tiled on the batch dimension but the output is tiled on spatial
// dim 1. The LHS is expected to be resharded with an all-to-all
// (reshape -> all-to-all -> transpose -> reshape). The same halo exchange as
// in the spatially tiled case is then applied.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedNeedReshard) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,224,224,3] parameter(0)
  %lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
    sharding={devices=[2,1,1,1]0,1}
  %rhs = f32[7,7,3,64] parameter(1)
  %rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
    sharding={replicated}
  ROOT %conv = f32[128,112,112,64] convolution(
    f32[128,224,224,3] %lhs.copy,
    f32[7,7,3,64] %rhs.copy),
    window={size=7x7 stride=2x2 pad=3_3x3_3},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto lhs_batch_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[64,224,224,3]"));
  auto exchanged = AllOf(op::AllToAll(op::Reshape(lhs_batch_shard)),
                         op::Shape("f32[64,2,112,224,3]"));
  auto lhs_spatial_shard = AllOf(op::Reshape(op::Transpose(exchanged)),
                                 op::Shape("f32[128,112,224,3]"));

  auto full_rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));

  auto leading_halo =
      AllOf(op::CollectivePermute(op::Slice(lhs_spatial_shard)),
            op::Shape("f32[128,3,224,3]"));
  auto trailing_halo =
      AllOf(op::CollectivePermute(op::Slice(lhs_spatial_shard)),
            op::Shape("f32[128,2,224,3]"));
  auto windowed_lhs = op::Select(
      op::And(), op::Concatenate(leading_halo, lhs_spatial_shard, trailing_halo),
      op::Broadcast());

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(windowed_lhs, full_rhs),
                          op::Shape("f32[128,56,112,64]")));
}
1164 
// The LHS uses dimension order 01fb, so its tiled dim 0 is the same spatial
// dimension as the output's tiled dim 1 (b01f). No resharding is expected,
// only halo exchange along LHS dim 0.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedReordered) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[224,224,3,128] parameter(0)
  %lhs.copy = f32[224,224,3,128] copy(%lhs), sharding={devices=[2,1,1,1]0,1}
  %rhs = f32[7,7,3,64] parameter(1)
  %rhs.copy = f32[7,7,3,64] copy(%rhs), sharding={replicated}
  ROOT %conv = f32[128,112,112,64] convolution(%lhs.copy, %rhs.copy),
    window={size=7x7 stride=2x2 pad=3_3x3_3},
    dim_labels=01fb_01io->b01f,
    sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[112,224,3,128]"));
  auto full_rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));

  auto leading_halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                            op::Shape("f32[3,224,3,128]"));
  auto trailing_halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                             op::Shape("f32[2,224,3,128]"));
  auto windowed_lhs =
      op::Select(op::And(),
                 op::Concatenate(leading_halo, lhs_shard, trailing_halo),
                 op::Broadcast());

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(windowed_lhs, full_rhs),
                          op::Shape("f32[128,56,112,64]")));
}
1203 
1204 // (stride * per_shard_window_count) % dilation == 0
TEST_F(SpmdPartitioningTest,ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated)1205 TEST_F(SpmdPartitioningTest,
1206        ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated) {
1207   absl::string_view hlo_string = R"(
1208 HloModule module
1209 
1210 ENTRY entry {
1211   %lhs = f32[128,7,7,512] parameter(0)
1212   %lhs.copy = f32[128,7,7,512] copy(%lhs),
1213     sharding={devices=[1,2,1,1]0,1}
1214   %rhs = f32[3,3,512,512] parameter(1)
1215   %rhs.copy = f32[3,3,512,512] copy(%rhs),
1216     sharding={replicated}
1217   ROOT %conv = f32[128,4,4,512] convolution(%lhs.copy, %rhs.copy),
1218     window={size=3x3 stride=4x4 pad=1_1x1_1 lhs_dilate=2x2 rhs_reversal=1x1},
1219     dim_labels=b01f_01io->b01f,
1220     sharding={devices=[1,2,1,1]0,1}
1221 })";
1222 
1223   TF_ASSERT_OK_AND_ASSIGN(auto module,
1224                           PartitionComputation(hlo_string, /*num_devices=*/2));
1225   VLOG(1) << module->ToString();
1226 
1227   auto root = module->entry_computation()->root_instruction();
1228   // There is no halo exchange, and because the last element in the shard is not
1229   // needed (stride == 4), the LHS will be just a slice.
1230   auto sliced_lhs =
1231       AllOf(op::Slice(op::Copy(op::DynamicSlice(
1232                 op::Pad(op::Parameter(), op::Constant()), op::Constant(),
1233                 op::Reshape(), op::Constant(), op::Constant()))),
1234             op::Shape("f32[128,3,7,512]"));
1235   auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]"));
1236   EXPECT_THAT(root, AllOf(op::Convolution(sliced_lhs, rhs),
1237                           op::Shape("f32[128,2,4,512]")));
1238   EXPECT_EQ(root->window().dimensions(0).padding_low(), 1);
1239   EXPECT_EQ(root->window().dimensions(0).padding_high(), 1);
1240 }
1241 
1242 // (stride * per_shard_window_count) % dilation != 0 but stride == 1
TEST_F(SpmdPartitioningTest,ConvolutionBaseDilationStride1LhsTiledRhsReplicated)1243 TEST_F(SpmdPartitioningTest,
1244        ConvolutionBaseDilationStride1LhsTiledRhsReplicated) {
1245   absl::string_view hlo_string = R"(
1246 HloModule module
1247 
1248 ENTRY entry {
1249   %lhs = f32[128,7,7,512] parameter(0)
1250   %lhs.copy = f32[128,7,7,512] copy(%lhs),
1251     sharding={devices=[1,2,1,1]0,1}
1252   %rhs = f32[3,3,512,512] parameter(1)
1253   %rhs.copy = f32[3,3,512,512] copy(%rhs),
1254     sharding={replicated}
1255   ROOT %conv = f32[128,14,14,512] convolution(%lhs.copy, %rhs.copy),
1256     window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1},
1257     dim_labels=b01f_01io->b01f,
1258     sharding={devices=[1,2,1,1]0,1}
1259 })";
1260 
1261   TF_ASSERT_OK_AND_ASSIGN(auto module,
1262                           PartitionComputation(hlo_string, /*num_devices=*/2));
1263   VLOG(1) << module->ToString();
1264 
1265   auto root = module->entry_computation()->root_instruction();
1266   auto lhs = AllOf(op::Copy(op::DynamicSlice(
1267                        op::Pad(op::Parameter(), op::Constant()), op::Constant(),
1268                        op::Reshape(), op::Constant(), op::Constant())),
1269                    op::Shape("f32[128,4,7,512]"));
1270   auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]"));
1271 
1272   auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1273                          op::Shape("f32[128,1,7,512]"));
1274   auto start_window = op::Multiply(op::Reshape(), op::Constant());
1275   auto start_input_element = op::Divide(start_window, op::Constant());
1276   auto dynamic_offset_for_padded_concat = op::Subtract(
1277       op::Constant(), op::Subtract(op::Multiply(op::Reshape(), op::Constant()),
1278                                    start_input_element));
1279   auto pre_masking =
1280       AllOf(op::Shape("f32[128,5,7,512]"),
1281             op::DynamicSlice(
1282                 AllOf(op::Shape("f32[128,6,7,512]"),
1283                       op::Pad(op::Concatenate(left_halo, lhs), op::Constant())),
1284                 op::Constant(), dynamic_offset_for_padded_concat,
1285                 op::Constant(), op::Constant()));
1286   auto masked = op::Select(
1287       op::Compare(op::Add(op::Iota(), op::Broadcast(start_input_element)),
1288                   op::Broadcast(op::Constant())),
1289       pre_masking, op::Broadcast(op::Constant()));
1290   auto dynamic_offset_on_output = op::Subtract(
1291       start_window, op::Multiply(start_input_element, op::Constant()));
1292   EXPECT_THAT(root,
1293               AllOf(op::DynamicSlice(AllOf(op::Convolution(masked, rhs),
1294                                            op::Shape("f32[128,8,14,512]")),
1295                                      op::Constant(), dynamic_offset_on_output,
1296                                      op::Constant(), op::Constant()),
1297                     op::Shape("f32[128,7,14,512]")));
1298   EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 1);
1299   EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0);
1300 }
1301 
// Select-and-scatter whose windows do not overlap (size == stride == 3x2),
// with operand and source both tiled on dim 0 across 4 partitions.
// - The operand is expected to be padded, dynamic-sliced, and masked locally.
// - The source shard is expected to be a dynamic-slice of the constant.
// - The partitioned op must have zero padding on dim 0.
TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlap) {
  absl::string_view hlo_string = R"(
HloModule module

ge {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT compare = pred[] compare(a, b), direction=GE
}

sum {
  c = f32[] parameter(0)
  d = f32[] parameter(1)
  ROOT add = f32[] add(c, d)
}

ENTRY entry {
  %param = f32[11,4]{1,0} parameter(0)
  %param.copy = f32[11,4] copy(%param),
    sharding={devices=[4,1]0,1,2,3}
  constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}),
    sharding={devices=[4,1]0,1,2,3}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
    constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0},
    select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  auto root = module->entry_computation()->root_instruction();
  auto source =
      AllOf(op::Shape("f32[1,2]{1,0}"),
            op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
  auto masked_data = AllOf(
      op::Shape("f32[3,4]{1,0}"),
      op::Select(
          op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply(
                                              op::Reshape(), op::Constant()))),
                      op::Broadcast(op::Constant())),
          op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                    op::Reshape(), op::Constant())),
          op::Broadcast(op::Constant())));

  // ASSERT rather than EXPECT: root->window() below is only valid once
  // `root` is known to be the select-and-scatter.
  ASSERT_THAT(root,
              AllOf(op::SelectAndScatter(masked_data, source, op::Constant()),
                    op::Shape("f32[3,4]{1,0}")));
  EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
  EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
}
1352 
TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlapReshard) {
  // Like SelectAndScatterNoOverlap, but the operand copy is sharded along
  // dimension 1 while the select-and-scatter is sharded along dimension 0.
  absl::string_view hlo_string = R"(
HloModule module

ge {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT compare = pred[] compare(a, b), direction=GE
}

sum {
  c = f32[] parameter(0)
  d = f32[] parameter(1)
  ROOT add = f32[] add(c, d)
}

ENTRY entry {
  %param = f32[11,4]{1,0} parameter(0)
  %param.copy = f32[11,4] copy(%param),
    sharding={devices=[1,4]0,1,2,3}
  constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}),
    sharding={devices=[4,1]0,1,2,3}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
    constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0},
    select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  const auto* result = module->entry_computation()->root_instruction();

  // Column-sharded operand as produced by the copy: f32[11,1] per partition.
  auto column_shard =
      AllOf(op::Shape("f32[11,1]"),
            op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
                                      op::Reshape())));
  // Reshard from column sharding to row sharding via all-to-all.
  auto row_shard = op::Reshape(op::Transpose(
      op::AllToAll(op::Reshape(op::Pad(column_shard, op::Constant())))));
  auto row_index = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  auto masked_operand =
      AllOf(op::Select(op::Compare(row_index, op::Broadcast(op::Constant())),
                       row_shard, op::Broadcast(op::Constant())),
            op::Shape("f32[3,4]{1,0}"));
  auto source_shard =
      AllOf(op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()),
            op::Shape("f32[1,2]{1,0}"));

  EXPECT_THAT(result, AllOf(op::Shape("f32[3,4]{1,0}"),
                            op::SelectAndScatter(masked_operand, source_shard,
                                                 op::Constant())));
  const auto& window = result->window();
  EXPECT_EQ(window.dimensions(0).padding_low(), 0);
  EXPECT_EQ(window.dimensions(0).padding_high(), 0);
}
1406 
TEST_F(SpmdPartitioningTest,SelectAndScatterWithOverlap)1407 TEST_F(SpmdPartitioningTest, SelectAndScatterWithOverlap) {
1408   absl::string_view hlo_string = R"(
1409 HloModule module
1410 
1411 ge {
1412   a = f32[] parameter(0)
1413   b = f32[] parameter(1)
1414   ROOT compare = pred[] compare(a, b), direction=GE
1415 }
1416 
1417 sum {
1418   c = f32[] parameter(0)
1419   d = f32[] parameter(1)
1420   ROOT add = f32[] add(c, d)
1421 }
1422 
1423 ENTRY entry {
1424   %param = f32[11,4]{1,0} parameter(0)
1425   %param.copy = f32[11,4] copy(%param),
1426     sharding={devices=[4,1]0,1,2,3}
1427   constant = f32[6,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8},{6,6},{1,9}}),
1428     sharding={devices=[4,1]0,1,2,3}
1429   constant.1 = f32[] constant(0), sharding={replicated}
1430   ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
1431     constant, constant.1), window={size=3x2 stride=2x2 pad=1_1x0_0},
1432     select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
1433 })";
1434   TF_ASSERT_OK_AND_ASSIGN(auto module,
1435                           PartitionComputation(hlo_string, /*num_devices=*/4));
1436   VLOG(1) << module->ToString();
1437   auto root = module->entry_computation()->root_instruction();
1438 
1439   auto source_shard =
1440       AllOf(op::Shape("f32[2,2]{1,0}"),
1441             op::DynamicSlice(op::Pad(), op::Reshape(), op::Constant()));
1442   // Max halo size is the same as the shard size, so slice is not needed.
1443   auto source_left_halo = op::CollectivePermute(source_shard);
1444   auto required_source_shard_start =
1445       op::Divide(op::Multiply(op::Reshape(), op::Constant()), op::Constant());
1446   auto source_with_halo = op::DynamicSlice(
1447       AllOf(op::Shape("f32[5,2]{1,0}"),
1448             op::Pad(op::Concatenate(source_left_halo, source_shard),
1449                     op::Constant())),
1450       op::Subtract(op::Constant(),
1451                    op::Subtract(op::Multiply(op::Reshape(), op::Constant()),
1452                                 required_source_shard_start)),
1453       op::Constant());
1454   auto masked_source_with_halo = AllOf(
1455       AllOf(op::Shape("f32[3,2]{1,0}")),
1456       op::Select(
1457           op::Compare(
1458               op::Add(op::Iota(), op::Broadcast(required_source_shard_start)),
1459               op::Broadcast(op::Constant())),
1460           source_with_halo, op::Broadcast(op::Constant())));
1461 
1462   auto data_shard =
1463       AllOf(op::Shape("f32[3,4]{1,0}"),
1464             op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
1465                                       op::Reshape(), op::Constant())));
1466   auto data_left_halo = AllOf(op::Shape("f32[2,4]{1,0}"),
1467                               op::CollectivePermute(op::Slice(data_shard)));
1468   auto data_right_halo = AllOf(op::Shape("f32[2,4]{1,0}"),
1469                                op::CollectivePermute(op::Slice(data_shard)));
1470   auto required_data_start_on_padded =
1471       op::Multiply(required_source_shard_start, op::Constant());
1472   auto left_halo_size = op::Subtract(
1473       op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()),
1474       required_data_start_on_padded);
1475   auto data_with_halo =
1476       AllOf(op::Shape("f32[7,4]{1,0}"),
1477             op::DynamicSlice(
1478                 AllOf(op::Shape("f32[8,4]{1,0}"),
1479                       op::Pad(op::Concatenate(data_left_halo, data_shard,
1480                                               data_right_halo),
1481                               op::Constant())),
1482                 op::Subtract(op::Constant(), left_halo_size), op::Constant()));
1483   auto index_on_padded =
1484       op::Add(op::Iota(), op::Broadcast(required_data_start_on_padded));
1485   auto masked_data_with_halo = op::Select(
1486       op::And(op::Compare(index_on_padded, op::Broadcast(op::Constant())),
1487               op::Compare(index_on_padded, op::Broadcast(op::Constant()))),
1488       data_with_halo, op::Broadcast(op::Constant()));
1489 
1490   EXPECT_THAT(
1491       root, AllOf(op::DynamicSlice(op::SelectAndScatter(masked_data_with_halo,
1492                                                         masked_source_with_halo,
1493                                                         op::Constant()),
1494                                    left_halo_size, op::Constant()),
1495                   op::Shape("f32[3,4]{1,0}")));
1496   EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 0);
1497   EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0);
1498 }
1499 
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiled) {
  // Both operands are split in two along dimension 1 and the 56x56 window
  // covers the whole spatial extent. The expected result is a per-partition
  // convolution followed by an all-reduce.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,56,56,64] parameter(0)
  %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,56,56,256] parameter(1)
  %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy),
    window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Matches a copy of a parameter dynamically sliced along dimension 1.
  auto dim1_shard = [](const char* shape) {
    return AllOf(op::Shape(shape),
                 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                           op::Reshape(), op::Constant(),
                                           op::Constant())));
  };
  auto lhs_shard = dim1_shard("f32[128,28,56,64]");
  auto rhs_shard = dim1_shard("f32[128,28,56,256]");

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[1,1,64,256]"),
                    op::AllReduce(op::Convolution(lhs_shard, rhs_shard))));
}
1530 
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowReversal) {
  // The rhs window is reversed and the rhs uses the opposite device order
  // (1,0) from the lhs (0,1).
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[5,128,64] parameter(0), sharding={devices=[2,1,1]0,1}
  %rhs = f32[5,128,256] parameter(1), sharding={devices=[2,1,1]1,0}
  ROOT %conv = f32[1,64,256] convolution(%lhs, %rhs),
    window={size=5 rhs_reversal=1}, dim_labels=0fb_0io->0bf,
    sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // The rhs is rebuilt from a collective-permuted slice concatenated with a
  // local slice of parameter 1, then masked; the lhs is masked directly.
  auto rhs_halo = op::CollectivePermute(op::Slice(op::Parameter(1)));
  auto rhs_rebuilt = op::Concatenate(rhs_halo, op::Slice(op::Parameter(1)));
  auto masked_rhs =
      AllOf(op::Select(_, rhs_rebuilt, _), op::Shape("f32[3,128,256]"));
  auto masked_lhs =
      AllOf(op::Select(_, op::Parameter(0), _), op::Shape("f32[3,128,64]"));

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[1,64,256]"),
                    op::AllReduce(op::Convolution(masked_lhs, masked_rhs))));
}
1560 
// Despite the "Dot" prefix in its name, this test partitions a convolution:
// the lhs is sharded along dimension 1 and the rhs along dimension 0, so the
// lhs is expected to be resharded through an all-to-all before convolving.
TEST_F(SpmdPartitioningTest, DotLhsTiledRhsTiledWithReshard) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,56,56,64] parameter(0)
  %lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,56,56,256] parameter(1)
  %rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[2,1,1,1]0,1}
  ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy),
    window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  auto lhs = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[128,28,56,64]"));
  auto rhs = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[64,56,56,256]"));
  auto all_to_all =
      AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[2,64,28,56,64]"));
  // Reshape(Transpose(all-to-all)) is the reshard pattern; a single-matcher
  // AllOf wrapper adds nothing, so the matcher is used directly.
  auto reshard = op::Reshape(op::Transpose(all_to_all));

  EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(reshard, rhs)),
                          op::Shape("f32[1,1,64,256]")));
}
1594 
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithReshard) {
  // The lhs is sharded on dimension 1 and the rhs on dimension 0. The
  // expected graph reshards the rhs through an all-to-all.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,56,56,512] parameter(0)
  %lhs.copy = f32[128,56,56,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,28,28,64] parameter(1)
  %rhs.copy = f32[128,28,28,64] copy(%rhs), sharding={devices=[2,1,1,1]0,1}
  ROOT %conv = f32[1,1,512,64] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2},
    dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto lhs_shard = AllOf(
      op::Shape("f32[128,28,56,512]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  auto rhs_shard = AllOf(
      op::Shape("f32[64,28,28,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
                                op::Constant(), op::Constant())));
  auto rhs_exchanged = AllOf(op::Shape("f32[64,2,14,28,64]"),
                             op::AllToAll(op::Reshape(rhs_shard)));
  auto rhs_resharded = op::Reshape(op::Transpose(rhs_exchanged));

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[1,1,512,64]"),
                    op::AllReduce(op::Convolution(op::Slice(lhs_shard),
                                                  rhs_resharded))));
}
1630 
TEST_F(SpmdPartitioningTest,
       ConvolutionLhsTiledRhsTiled_UnevenDilatedRHSPartitioned) {
  // The 14-element rhs dimension does not split evenly across 4 devices, so
  // the rhs is padded (to 16) before slicing and then masked.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[8,28,28,8] parameter(0)
  %lhs.copy = f32[8,28,28,8] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
  %rhs = f32[8,14,14,64] parameter(1)
  %rhs.copy = f32[8,14,14,64] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
  ROOT %conv = f32[1,1,8,64] convolution(%lhs.copy, %rhs.copy),
    window={size=14x14 pad=0_-1x0_-1 rhs_dilate=2x2},
    dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  auto lhs_shard = AllOf(
      op::Shape("f32[8,7,28,8]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  auto padded_rhs = AllOf(op::Shape("f32[8,16,14,64]"),
                          op::Pad(op::Parameter(), op::Constant()));
  auto rhs_slice =
      op::Copy(op::DynamicSlice(padded_rhs, op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant()));
  auto masked_rhs_shard =
      AllOf(op::Shape("f32[8,4,14,64]"),
            op::Select(op::Compare(), rhs_slice, op::Broadcast()));

  // The lhs receives a right halo, then a window of it is sliced back out.
  auto lhs_right_halo = AllOf(op::Shape("f32[8,2,28,8]"),
                              op::CollectivePermute(op::Slice(lhs_shard)));
  auto lhs_window = AllOf(
      op::Shape("f32[8,7,28,8]"),
      op::DynamicSlice(
          op::Pad(op::Concatenate(lhs_shard, lhs_right_halo), op::Constant()),
          op::Constant(), op::Reshape(), op::Constant(), op::Constant()));

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[1,1,8,64]"),
                    op::AllReduce(
                        op::Convolution(lhs_window, masked_rhs_shard))));
}
1674 
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,28,28,128] parameter(0)
  %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,28,28,64] parameter(1)
  %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  // Halo exchange is not forced onto the lhs, so it is expected on the rhs.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/2,
                           /*conv_halo_exchange_always_on_lhs=*/false));
  VLOG(1) << module->ToString();

  // Matches a copy of a parameter dynamically sliced along dimension 1.
  auto dim1_shard = [](const char* shape) {
    return AllOf(op::Shape(shape),
                 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                           op::Reshape(), op::Constant(),
                                           op::Constant())));
  };
  auto lhs_shard = dim1_shard("f32[32,14,28,128]");
  auto rhs_shard = dim1_shard("f32[32,14,28,64]");

  // Left and right halos have the same expected form, so one matcher serves
  // both sides of the concatenation.
  auto rhs_halo = AllOf(op::Shape("f32[32,1,28,64]"),
                        op::CollectivePermute(op::Slice(rhs_shard)));
  auto rhs_with_halos = AllOf(op::Shape("f32[32,16,28,64]"),
                              op::Concatenate(rhs_halo, rhs_shard, rhs_halo));

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[3,3,128,64]"),
                    op::AllReduce(op::Convolution(lhs_shard, rhs_with_halos))));
}
1714 
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilate) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,224,224,3] parameter(0)
  %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,112,112,64] parameter(1)
  %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy),
    window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  // Halo exchange is not forced onto the lhs, so it is expected on the rhs.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/2,
                           /*conv_halo_exchange_always_on_lhs=*/false));
  VLOG(1) << module->ToString();

  auto lhs_shard = AllOf(
      op::Shape("f32[128,112,224,3]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  auto rhs_shard = AllOf(
      op::Shape("f32[128,56,112,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));

  // Both rhs halos are two rows wide along dimension 1.
  auto rhs_halo = AllOf(op::Shape("f32[128,2,112,64]"),
                        op::CollectivePermute(op::Slice(rhs_shard)));
  auto rhs_with_halos = AllOf(op::Shape("f32[128,60,112,64]"),
                              op::Concatenate(rhs_halo, rhs_shard, rhs_halo));

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[7,7,3,64]"),
                    op::AllReduce(op::Convolution(lhs_shard, rhs_with_halos))));
}
1754 
TEST_F(SpmdPartitioningTest,
       ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding) {
  // With 0_-1 padding the expected graph contains no halo exchange at all:
  // the local shards feed the convolution directly.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,56,56,256] parameter(0)
  %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,28,28,512] parameter(1)
  %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/2,
                           /*conv_halo_exchange_always_on_lhs=*/false));
  VLOG(1) << module->ToString();

  // Matches a copy of a parameter dynamically sliced along dimension 1.
  auto dim1_shard = [](const char* shape) {
    return AllOf(op::Shape(shape),
                 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                           op::Reshape(), op::Constant(),
                                           op::Constant())));
  };

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[1,1,256,512]"),
                    op::AllReduce(op::Convolution(
                        dim1_shard("f32[128,28,56,256]"),
                        dim1_shard("f32[128,14,28,512]")))));
}
1788 
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilateUneven) {
  // The 7-element rhs dimension splits unevenly across 2 devices, so the rhs
  // shard is padded and masked before receiving its left halo.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,14,14,512] parameter(0)
  %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,7,7,512] parameter(1)
  %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy),
    window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/2,
                           /*conv_halo_exchange_always_on_lhs=*/false));
  VLOG(1) << module->ToString();

  auto lhs_shard = AllOf(
      op::Shape("f32[128,7,14,512]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  auto rhs_slice = op::Copy(
      op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), op::Constant(),
                       op::Reshape(), op::Constant(), op::Constant()));
  auto rhs_shard =
      AllOf(op::Shape("f32[128,4,7,512]"),
            op::Select(op::Compare(), rhs_slice, op::Broadcast()));
  auto rhs_left_halo = AllOf(op::Shape("f32[128,1,7,512]"),
                             op::CollectivePermute(op::Slice(rhs_shard)));

  auto lhs_window = AllOf(
      op::Shape("f32[128,10,14,512]"),
      op::DynamicSlice(op::Pad(lhs_shard, op::Constant()), op::Constant(),
                       op::Subtract(), op::Constant(), op::Constant()));
  auto rhs_with_halo = AllOf(op::Shape("f32[128,5,7,512]"),
                             op::Concatenate(rhs_left_halo, rhs_shard));

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[3,3,512,512]"),
                    op::AllReduce(op::Convolution(lhs_window, rhs_with_halo))));
}
1833 
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding_HaloOnLhs) {
  // Same convolution as ConvolutionLhsTiledRhsTiledWithPadding, partitioned
  // with the default options, so the halos appear on the lhs.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,28,28,128] parameter(0)
  %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,28,28,64] parameter(1)
  %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Matches a copy of a parameter dynamically sliced along dimension 1.
  auto dim1_shard = [](const char* shape) {
    return AllOf(op::Shape(shape),
                 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                           op::Reshape(), op::Constant(),
                                           op::Constant())));
  };
  auto lhs_shard = dim1_shard("f32[32,14,28,128]");
  auto rhs_shard = dim1_shard("f32[32,14,28,64]");

  // Left and right lhs halos share one expected form.
  auto lhs_halo = AllOf(op::Shape("f32[32,1,28,128]"),
                        op::CollectivePermute(op::Slice(lhs_shard)));
  auto lhs_with_halos = AllOf(op::Shape("f32[32,16,28,128]"),
                              op::Concatenate(lhs_halo, lhs_shard, lhs_halo));

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[3,3,128,64]"),
                    op::AllReduce(op::Convolution(lhs_with_halos, rhs_shard))));
}
1871 
TEST_F(SpmdPartitioningTest,
       ConvolutionLhsTiledRhsTiledWindowDilate_HaloOnLhs) {
  // The 3_2 padding yields asymmetric lhs halos: 3 rows on the left and 2
  // rows on the right (dimension 1).
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,224,224,3] parameter(0)
  %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,112,112,64] parameter(1)
  %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy),
    window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto lhs_shard = AllOf(
      op::Shape("f32[128,112,224,3]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  auto rhs_shard = AllOf(
      op::Shape("f32[128,56,112,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));

  auto lhs_left_halo = AllOf(op::Shape("f32[128,3,224,3]"),
                             op::CollectivePermute(op::Slice(lhs_shard)));
  auto lhs_right_halo = AllOf(op::Shape("f32[128,2,224,3]"),
                              op::CollectivePermute(op::Slice(lhs_shard)));
  auto lhs_with_halos =
      AllOf(op::Shape("f32[128,117,224,3]"),
            op::Concatenate(lhs_left_halo, lhs_shard, lhs_right_halo));

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[7,7,3,64]"),
                    op::AllReduce(op::Convolution(lhs_with_halos, rhs_shard))));
}
1910 
TEST_F(SpmdPartitioningTest,
       ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding_HaloOnLhs) {
  // With default options, the expected lhs is a plain slice of the local
  // shard; no collective-permute is matched.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,56,56,256] parameter(0)
  %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,28,28,512] parameter(1)
  %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Matches a copy of a parameter dynamically sliced along dimension 1.
  auto dim1_shard = [](const char* shape) {
    return AllOf(op::Shape(shape),
                 op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                           op::Reshape(), op::Constant(),
                                           op::Constant())));
  };
  auto trimmed_lhs = op::Slice(dim1_shard("f32[128,28,56,256]"));
  auto rhs_shard = dim1_shard("f32[128,14,28,512]");

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[1,1,256,512]"),
                    op::AllReduce(op::Convolution(trimmed_lhs, rhs_shard))));
}
1942 
TEST_F(SpmdPartitioningTest,
       ConvolutionLhsTiledRhsTiledWindowDilateUneven_HaloOnLhs) {
  // The rhs is padded and masked because its 7-element dimension splits
  // unevenly. The lhs receives a right halo, is padded, and is then sliced.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,14,14,512] parameter(0)
  %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,7,7,512] parameter(1)
  %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy),
    window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto lhs_shard = AllOf(
      op::Shape("f32[128,7,14,512]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  auto rhs_slice = op::Copy(
      op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), op::Constant(),
                       op::Reshape(), op::Constant(), op::Constant()));
  auto rhs_shard =
      AllOf(op::Shape("f32[128,4,7,512]"),
            op::Select(op::Compare(), rhs_slice, op::Broadcast()));

  auto lhs_right_halo = AllOf(op::Shape("f32[128,1,14,512]"),
                              op::CollectivePermute(op::Slice(lhs_shard)));
  auto padded_lhs = AllOf(
      op::Shape("f32[128,10,14,512]"),
      op::Pad(op::Concatenate(lhs_shard, lhs_right_halo), op::Constant()));
  auto lhs_window = AllOf(
      op::Shape("f32[128,9,14,512]"),
      op::DynamicSlice(padded_lhs, op::Constant(), op::Reshape(),
                       op::Constant(), op::Constant()));

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[3,3,512,512]"),
                    op::AllReduce(op::Convolution(lhs_window, rhs_shard))));
}
1988 
TEST_F(SpmdPartitioningTest, ConcatenateAlongNonPartitionedDimension) {
  // Concatenation runs along dimension 1 while sharding is on dimension 0, so
  // each partition concatenates its local row shards with no communication.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[14,257] parameter(0)
  %param0.copy = f32[14,257] copy(%param0), sharding={devices=[2,1]0,1}
  %param1 = f32[14,116] parameter(1)
  %param1.copy = f32[14,116] copy(%param1), sharding={devices=[2,1]0,1}
  ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy),
    dimensions={1}, sharding={devices=[2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Matches a copy of a parameter dynamically sliced along dimension 0.
  auto row_shard = [](const char* shape) {
    return AllOf(op::Shape(shape),
                 op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                           op::Constant())));
  };

  const auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result, AllOf(op::Shape("f32[7,373]"),
                            op::Concatenate(row_shard("f32[7,257]"),
                                            row_shard("f32[7,116]"))));
}
2016 
TEST_F(SpmdPartitioningTest, ConcatenateAlongPartitionedDimension) {
  // Operands are split along the concatenated dimension (1). The expected
  // lowering writes both local shards into a broadcast buffer of the padded
  // full width, all-reduces that buffer, and slices out this partition's
  // window.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[14,257] parameter(0)
  %param0.copy = f32[14,257] copy(%param0), sharding={devices=[1,2]0,1}
  %param1 = f32[14,116] parameter(1)
  %param1.copy = f32[14,116] copy(%param1), sharding={devices=[1,2]0,1}
  ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy),
    dimensions={1}, sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // 257 columns do not divide evenly by 2; the first operand is expected to be
  // padded before its per-partition slice is taken.
  auto lhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape())),
            op::Shape("f32[14,129]"));
  auto rhs_shard = AllOf(op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Reshape())),
                         op::Shape("f32[14,58]"));

  auto lhs_written = op::DynamicUpdateSlice(op::Broadcast(), lhs_shard,
                                            op::Constant(), op::Multiply());
  auto both_written = op::DynamicUpdateSlice(lhs_written, rhs_shard,
                                             op::Constant(), op::Add());
  auto full_concat =
      AllOf(op::AllReduce(both_written), op::Shape("f32[14,374]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::DynamicSlice(full_concat, op::Constant(),
                                           op::Multiply()),
                          op::Shape("f32[14,187]")));
}
2052 
TEST_F(SpmdPartitioningTest, ConcatenateAlongBothDimensions) {
  // 2x2 tiling: dimension 0 is halved (7 rows per shard), and dimension 1,
  // the concatenated one, is also split. The expected result therefore uses
  // the update-slice + all-reduce + slice lowering on 7-row shards.
  // The HLO text is held in an absl::string_view, matching the other tests and
  // the PartitionComputation parameter type.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[14,257] parameter(0), sharding={devices=[2,2]0,1,2,3}
  %param1 = f32[14,116] parameter(1), sharding={devices=[2,2]0,1,2,3}
  ROOT %concatenate = f32[14,373] concatenate(%param0, %param1),
    dimensions={1}, sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  auto param0 = AllOf(op::Parameter(0), op::Shape("f32[7,129]"));
  auto param1 = AllOf(op::Parameter(1), op::Shape("f32[7,58]"));
  EXPECT_THAT(root, AllOf(op::DynamicSlice(
                              AllOf(op::AllReduce(op::DynamicUpdateSlice(
                                        op::DynamicUpdateSlice(
                                            op::Broadcast(), param0,
                                            op::Constant(), op::Multiply()),
                                        param1, op::Constant(), op::Add())),
                                    op::Shape("f32[7,374]")),
                              op::Constant(), op::Multiply()),
                          op::Shape("f32[7,187]")));
}
2081 
TEST_F(SpmdPartitioningTest, PadAlongNonPartitionedDimension) {
  // Padding touches only dimension 1 while the sharding splits dimension 2,
  // so the pad is expected to be applied directly to the local shard.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2]0,1}
  %const = f32[] constant(0)
  ROOT %pad = f32[128,17,257] pad(%param0, %const), padding=0_0x1_2x0_0,
    sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto local_input = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
  auto local_pad = op::Pad(local_input, op::Constant());
  EXPECT_THAT(root, AllOf(local_pad, op::Shape("f32[128,17,129]")));
}
2102 
TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimension) {
  // High-padding the sharded dimension moves data across the shard boundary.
  // The expected lowering appends a halo received from another partition,
  // pads, slices out the local window, and masks it with a select.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[14,257] parameter(0), sharding={devices=[1,2]0,1}
  %const = f32[] constant(0)
  ROOT %pad = f32[14,259] pad(%param0, %const), padding=0_0x0_2,
    sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto local_input = AllOf(op::Parameter(), op::Shape("f32[14,129]"));
  auto received_halo = op::CollectivePermute(op::Slice(local_input));
  auto with_halo = AllOf(op::Shape("f32[14,130]"),
                         op::Concatenate(local_input, received_halo));
  auto padded =
      AllOf(op::Shape("f32[14,131]"), op::Pad(with_halo, op::Constant()));
  auto window = op::DynamicSlice(padded, op::Constant(), _);
  EXPECT_THAT(root, op::Select(_, window, _));
}
2127 
TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimensionWithInteriorPadding) {
  // Interior padding on the sharded dimension: f32[7] with padding 2_1_2
  // becomes f32[22] (2 low + 7 + 6*2 interior + 1 high), i.e. f32[11] on each
  // of the 2 partitions.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[7] parameter(0), sharding={devices=[2]0,1}
  %param1 = f32[] parameter(1), sharding={replicated}
  ROOT %pad = f32[22] pad(%param0, %param1), padding=2_1_2,
    sharding={devices=[2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  auto root = module->entry_computation()->root_instruction();

  auto param0 = AllOf(op::Parameter(), op::Shape("f32[4]"));
  // Halo exchange: a slice received from another partition is concatenated
  // with a local slice, padded with the pad value, then re-windowed.
  auto after_halo_exchange = AllOf(
      op::Shape("f32[4]"),
      op::DynamicSlice(
          AllOf(op::Shape("f32[5]"),
                op::Pad(AllOf(op::Shape("f32[4]"),
                              op::Concatenate(
                                  op::CollectivePermute(op::Slice(param0)),
                                  op::Slice(param0))),
                        op::Parameter(1))),
          _));
  auto pad = op::Pad(after_halo_exchange, op::Parameter(1));
  // Also pin the per-partition result shape; without it a wrongly sized root
  // would still satisfy the structural matcher.
  EXPECT_THAT(root, AllOf(op::DynamicSlice(pad, _), op::Shape("f32[11]")));
}
2158 
TEST_F(SpmdPartitioningTest, PartialReplicatePad) {
  // Same interior-padding pattern as the 1-D case, but on dimension 1 of a
  // partially replicated (last_tile_dim_replicate) sharding over 4 devices.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[11,7] parameter(0),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  %param1 = f32[] parameter(1), sharding={replicated}
  ROOT %pad = f32[27,22] pad(%param0, %param1), padding=2_4_1x2_1_2,
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto local_input = AllOf(op::Parameter(), op::Shape("f32[11,4]"));
  auto halo_concat =
      AllOf(op::Shape("f32[11,4]"),
            op::Concatenate(op::CollectivePermute(op::Slice(local_input)),
                            op::Slice(local_input)));
  auto padded_halo = AllOf(op::Shape("f32[11,5]"),
                           op::Pad(halo_concat, op::Parameter(1)));
  auto rewindowed = AllOf(op::Shape("f32[11,4]"),
                          op::DynamicSlice(padded_halo, op::Constant(), _));
  auto final_pad = op::Pad(rewindowed, op::Parameter(1));
  EXPECT_THAT(root, AllOf(op::DynamicSlice(final_pad, op::Constant(), _),
                          op::Shape("f32[27,11]")));
}
2191 
TEST_F(SpmdPartitioningTest, SliceAlongNonPartitionedDimension) {
  // The slice only restricts dimension 1, while the sharding splits dimension
  // 2, so the slice is expected to apply directly to the local shard.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0)
  %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1}
  ROOT %slice = f32[128,11,257] slice(%param0.copy),
    slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto sharded_param =
      op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()), op::Constant(),
                       op::Constant(), op::Reshape());
  auto local_input =
      AllOf(op::Copy(sharded_param), op::Shape("f32[128,14,129]"));
  EXPECT_THAT(root,
              AllOf(op::Slice(local_input), op::Shape("f32[128,11,129]")));
}
2214 
TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension) {
  // The slice starts at offset 5 of the sharded dimension 2, so each partition
  // needs a 2-column halo from another partition before windowing locally.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2]0,1}
  ROOT %slice = f32[63,14,251] slice(%param0),
    slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto local_input = AllOf(op::Parameter(0), op::Shape("f32[128,14,129]"));
  auto incoming_halo = AllOf(op::CollectivePermute(op::Slice(local_input)),
                             op::Shape("f32[128,14,2]"));
  auto shifted = AllOf(op::Concatenate(op::Slice(local_input), incoming_halo),
                       op::Shape("f32[128,14,129]"));
  auto window = AllOf(
      op::DynamicSlice(shifted, op::Constant(), op::Constant(), op::Add()),
      op::Shape("f32[128,14,126]"));
  EXPECT_THAT(root, AllOf(op::Slice(window), op::Shape("f32[63,14,126]")));
}
2244 
TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension2) {
  // Only the last element of a 4-way sharded f32[4] survives the slice; the
  // expected result is a copy of a single collective-permute of the local
  // one-element shard.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
  ROOT %slice = f32[1] slice(%param0),
    slice={[3:4]}, sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto moved_element =
      op::CollectivePermute(AllOf(op::Parameter(0), op::Shape("f32[1]")));
  EXPECT_THAT(root, AllOf(op::Copy(moved_element), op::Shape("f32[1]")));
}
2264 
TEST_F(SpmdPartitioningTest, MergedPadThenSliceShiftRight) {
  // Left-padding by one and then keeping the first four elements shifts the
  // data right by one. The expected result is a single collective-permute,
  // with a select to inject the non-zero pad value.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
  %init = f32[] constant(2.0)
  %pad = f32[5] pad(%param0, %init), padding=1_0, sharding={devices=[4]0,1,2,3}
  %copy = f32[5] copy(%pad), sharding={devices=[4]0,1,2,3}
  %copy.1 = f32[5] copy(%copy), sharding={devices=[4]0,1,2,3}
  ROOT %slice = f32[4] slice(%copy.1), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto local_input = AllOf(op::Parameter(0), op::Shape("f32[1]"));
  auto shifted = op::CollectivePermute(local_input);
  EXPECT_THAT(root, AllOf(op::Select(_, shifted, _), op::Shape("f32[1]")));
}
2287 
2288 // Same as above except that it uses zero padding, so there is no need for
2289 // masking.
TEST_F(SpmdPartitioningTest,MergedPadThenSliceShiftRightNoMasking)2290 TEST_F(SpmdPartitioningTest, MergedPadThenSliceShiftRightNoMasking) {
2291   absl::string_view hlo_string = R"(
2292 HloModule module
2293 
2294 ENTRY entry {
2295   %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
2296   %init = f32[] constant(0)
2297   %pad = f32[5] pad(%param0, %init), padding=1_0, sharding={devices=[4]0,1,2,3}
2298   %copy = f32[5] copy(%pad), sharding={devices=[4]0,1,2,3}
2299   %copy.1 = f32[5] copy(%copy), sharding={devices=[4]0,1,2,3}
2300   ROOT %slice = f32[4] slice(%copy.1), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
2301 })";
2302 
2303   TF_ASSERT_OK_AND_ASSIGN(auto module,
2304                           PartitionComputation(hlo_string, /*num_devices=*/4));
2305   VLOG(1) << module->ToString();
2306 
2307   auto root = module->entry_computation()->root_instruction();
2308   auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1]"));
2309   EXPECT_THAT(root, AllOf(op::CollectivePermute(param0), op::Shape("f32[1]")));
2310 }
2311 
TEST_F(SpmdPartitioningTest, MergedSliceThenConcatRotateRight) {
  // Concatenating the tail [10:12] in front of [0:10] rotates the array right
  // by two. With 3 elements per shard, each partition is expected to combine a
  // piece received from another partition with a piece of its own shard.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[12] parameter(0), sharding={devices=[4]0,1,2,3}
  %slice0 = f32[2] slice(%param0), slice={[10:12]}, sharding={devices=[4]0,1,2,3}
  %slice1 = f32[10] slice(%param0), slice={[0:10]}, sharding={devices=[4]0,1,2,3}
  ROOT %concat = f32[12] concatenate(%slice0, %slice1), dimensions={0},
    sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto local_input = AllOf(op::Parameter(0), op::Shape("f32[3]"));
  auto from_other_shard = op::CollectivePermute(op::Slice(local_input));
  auto rotated = op::Concatenate(from_other_shard, op::Slice(local_input));
  EXPECT_THAT(root, AllOf(rotated, op::Shape("f32[3]")));
}
2334 
TEST_F(SpmdPartitioningTest,
       MergedSliceThenConcatRotateRightWithAlignedPadding) {
  // Rotating f32[6] right by two over 4 shards of 2 elements each: the rotation
  // amount equals the shard size, so a bare collective-permute is expected.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[6] parameter(0), sharding={devices=[4]0,1,2,3}
  %slice0 = f32[2] slice(%param0), slice={[4:6]}, sharding={devices=[4]0,1,2,3}
  %slice1 = f32[4] slice(%param0), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
  ROOT %concat = f32[6] concatenate(%slice0, %slice1), dimensions={0},
    sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto local_input = AllOf(op::Parameter(0), op::Shape("f32[2]"));
  EXPECT_THAT(root, op::CollectivePermute(local_input));
}
2356 
TEST_F(SpmdPartitioningTest,
       MergedSliceThenConcatRotateRightWithUnalignedPadding) {
  // f32[10] over 4 shards of 3 elements does not divide evenly. The expected
  // result selects between a whole-shard collective-permute and a
  // concatenation of two permuted partial slices.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[10] parameter(0), sharding={devices=[4]0,1,2,3}
  %slice0 = f32[6] slice(%param0), slice={[4:10]}, sharding={devices=[4]0,1,2,3}
  %slice1 = f32[4] slice(%param0), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
  ROOT %concat = f32[10] concatenate(%slice0, %slice1), dimensions={0},
    sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto local_input = AllOf(op::Parameter(0), op::Shape("f32[3]"));
  auto whole_shard_shift = op::CollectivePermute(local_input);
  auto split_shard_shift =
      op::Concatenate(op::CollectivePermute(op::Slice(local_input)),
                      op::CollectivePermute(op::Slice(local_input)));
  EXPECT_THAT(root, AllOf(op::Select(_, split_shard_shift, whole_shard_shift),
                          op::Shape("f32[3]")));
}
2382 
TEST_F(SpmdPartitioningTest,
       PartialReplicateSliceAlongNonPartitionedDimension) {
  // Dimension 2 is split in two, and the last tile dimension replicates
  // (last_tile_dim_replicate). The slice only restricts dimension 1, so it is
  // expected to apply directly to the local shard.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %slice = f32[128,11,257] slice(%param0),
    slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto local_input = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
  auto local_slice = op::Slice(local_input);
  EXPECT_THAT(root, AllOf(local_slice, op::Shape("f32[128,11,129]")));
}
2402 
TEST_F(SpmdPartitioningTest, PartialReplicateSliceAlongPartitionedDimension) {
  // Partially replicated version of SliceAlongPartitionedDimension. The window
  // offset is derived from a table lookup indexed by the partition id, since
  // several devices share each tile.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %slice = f32[63,14,251] slice(%param0),
    slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto local_input = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
  auto incoming_halo = AllOf(op::CollectivePermute(op::Slice(local_input)),
                             op::Shape("f32[128,14,2]"));
  auto shifted = AllOf(op::Concatenate(op::Slice(local_input), incoming_halo),
                       op::Shape("f32[128,14,129]"));
  auto tile_index =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  auto window_offset =
      op::Add(op::Multiply(tile_index, op::Constant()), op::Constant());
  auto window = AllOf(op::DynamicSlice(shifted, op::Constant(), op::Constant(),
                                       window_offset),
                      op::Shape("f32[128,14,126]"));
  EXPECT_THAT(root, AllOf(op::Slice(window), op::Shape("f32[63,14,126]")));
}
2437 
TEST_F(SpmdPartitioningTest, SortAlongNonPartitionedDimension) {
  // The sort runs along dimension 2 while both operands are split along
  // dimension 1, so each partition is expected to sort its own rows locally.
  absl::string_view hlo_text = R"(
HloModule module

ge {
  p.0.lhs.1247 = f32[]{:T(256)} parameter(0), sharding={replicated}
  bitcast-convert = s32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated}
  constant = s32[]{:T(256)} constant(0), sharding={replicated}
  compare = pred[]{:T(256)E(32)} compare(bitcast-convert, constant), direction=LT, sharding={replicated}
  constant.1 = u32[]{:T(256)} constant(2147483647), sharding={replicated}
  bitcast-convert.1 = u32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated}
  subtract = u32[]{:T(256)} subtract(constant.1, bitcast-convert.1), sharding={replicated}
  bitcast-convert.2 = s32[]{:T(256)} bitcast-convert(subtract), sharding={replicated}
  select = s32[]{:T(256)} select(compare, bitcast-convert.2, bitcast-convert), sharding={replicated}
  p.0.rhs.1248 = f32[]{:T(256)} parameter(1), sharding={replicated}
  bitcast-convert.3 = s32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated}
  compare.1 = pred[]{:T(256)E(32)} compare(bitcast-convert.3, constant), direction=LT, sharding={replicated}
  bitcast-convert.4 = u32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated}
  subtract.1 = u32[]{:T(256)} subtract(constant.1, bitcast-convert.4), sharding={replicated}
  bitcast-convert.5 = s32[]{:T(256)} bitcast-convert(subtract.1), sharding={replicated}
  select.1 = s32[]{:T(256)} select(compare.1, bitcast-convert.5, bitcast-convert.3), sharding={replicated}
  compare.2 = pred[]{:T(256)E(32)} compare(select, select.1), direction=GT, sharding={replicated}
  compare.258 = pred[]{:T(256)E(32)} compare(select.1, select), direction=GT, sharding={replicated}
  compare.259 = pred[]{:T(256)E(32)} compare(compare.2, compare.258), direction=EQ, sharding={replicated}
  p.1.lhs.1249 = s32[]{:T(256)} parameter(2), sharding={replicated}
  p.1.rhs.1250 = s32[]{:T(256)} parameter(3), sharding={replicated}
  compare.260 = pred[]{:T(256)E(32)} compare(p.1.lhs.1249, p.1.rhs.1250), direction=LT, sharding={replicated}
  ROOT select.86 = pred[]{:T(256)E(32)} select(compare.259, compare.260, compare.2), sharding={replicated}
}

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0)
  %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,2,1]0,1}
  %param1 = s32[128,14,257] parameter(1)
  %param1.copy = s32[128,14,257] copy(%param1), sharding={devices=[1,2,1]0,1}
  ROOT %sort.6 = (f32[128,14,257]{2,1,0:T(8,128)}, s32[128,14,257]{2,1,0:T(8,128)})
    sort(%param0.copy, %param1.copy), dimensions={2}, is_stable=true,
    to_apply=%ge, sharding={{devices=[1,2,1]0,1},{devices=[1,2,1]0,1}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto f32_shard = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0),
                                                   op::Constant(),
                                                   op::Reshape(),
                                                   op::Constant())),
                         op::Shape("f32[128,7,257]"));
  auto s32_shard = AllOf(op::Copy(op::DynamicSlice(op::Parameter(1),
                                                   op::Constant(),
                                                   op::Reshape(),
                                                   op::Constant())),
                         op::Shape("s32[128,7,257]"));
  auto local_sort = op::Sort(f32_shard, s32_shard);
  EXPECT_THAT(root, AllOf(local_sort,
                          op::Shape("(f32[128,7,257], s32[128,7,257])")));
}
2494 
TEST_F(SpmdPartitioningTest, PartitionCustomCall) {
  // A TopK custom-call on an operand sharded along dimension 1. The checks
  // below expect a per-shard custom-call on 104832 columns (209664 / 2),
  // followed by a sort over 4000 (= 2 x 2000) candidates.
  absl::string_view hlo_string = R"(
HloModule cluster_2013453984438090939__.47

ENTRY %cluster_2013453984438090939__.47
  (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
  %custom-call = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
    custom-call(bf16[2,209664]{1,0} %copy.arg_tuple.1), custom_call_target="TopK"
  %get-tuple-element = bf16[2,2000]{1,0}
    get-tuple-element((bf16[2,2000]{1,0}, s32[2,2000]{1,0}) %custom-call),
    index=0, sharding={replicated}
  %get-tuple-element.1 = s32[2,2000]{1,0} get-tuple-element((bf16[2,2000]{1,0},
    s32[2,2000]{1,0}) %custom-call), index=1, sharding={replicated}
  ROOT %tuple.46 = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
    tuple(bf16[2,2000]{1,0} %get-tuple-element, s32[2,2000]{1,0}
    %get-tuple-element.1), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  // Check each lookup for null before dereferencing it, so a missing or
  // renamed instruction is reported as a test failure rather than crashing
  // the test binary.
  auto custom_call = FindInstruction(module.get(), "custom-call.1");
  ASSERT_NE(custom_call, nullptr);
  EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 104832);
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4000);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4000);
}
2525 
TEST_F(SpmdPartitioningTest, PartitionCustomCall_TwoPartitionedDims) {
  // TopK over an operand tiled 4x2. The checks below expect the custom-call
  // to see 16064 columns (32128 / 2) and the final sort to see 2 rows
  // (8 / 4) by 4 candidates.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,32128] parameter(0)
  %copy.0 = f32[8,32128] copy(%param0),
    sharding={devices=[4,2]0,1,2,3,4,5,6,7}
  %custom-call = (f32[8,2]{1,0}, s32[8,2]{1,0})
    custom-call(%copy.0), custom_call_target="TopK"
  %get-tuple-element = f32[8,2]{1,0}
    get-tuple-element((f32[8,2]{1,0}, s32[8,2]{1,0}) %custom-call), index=0,
    sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %get-tuple-element.1 = s32[8,2]{1,0}
    get-tuple-element((f32[8,2]{1,0}, s32[8,2]{1,0}) %custom-call), index=1,
    sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %tuple = (f32[8,2]{1,0}, s32[8,2]{1,0})
    tuple(%get-tuple-element, %get-tuple-element.1),
    sharding={{replicated}, {replicated}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  // Check each lookup for null before dereferencing it, so a missing or
  // renamed instruction is reported as a test failure rather than crashing
  // the test binary.
  auto custom_call = FindInstruction(module.get(), "custom-call.1");
  ASSERT_NE(custom_call, nullptr);
  EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 16064);
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(0), 2);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(0), 2);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4);
}
2558 
TEST_F(SpmdPartitioningTest, PartitionSortInTopK) {
  // A TopK expressed as sort + slice on an operand sharded along the sort
  // dimension. The checks below expect a per-shard sort over 104832 columns
  // (209664 / 2), followed by a final sort over 4000 (= 2 x 2000) candidates.
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.9: bf16[], p.0.rhs.10: bf16[], p.1.lhs.11:
   s32[], p.1.rhs.12: s32[]) -> pred[] {
  %p.1.lhs.11 = s32[] parameter(2)
  %p.1.rhs.12 = s32[] parameter(3)
  %p.0.lhs.9 = bf16[] parameter(0)
  %convert.13 = f32[] convert(bf16[] %p.0.lhs.9)
  %bitcast-convert.16 = s32[] bitcast-convert(f32[] %convert.13)
  %constant.20 = s32[] constant(0)
  %compare.21 = pred[] compare(s32[] %bitcast-convert.16, s32[] %constant.20),
    direction=LT
  %constant.15 = u32[] constant(2147483647)
  %bitcast-convert.17 = u32[] bitcast-convert(f32[] %convert.13)
  %subtract.18 = u32[] subtract(u32[] %constant.15, u32[] %bitcast-convert.17)
  %bitcast-convert.19 = s32[] bitcast-convert(u32[] %subtract.18)
  %select.22 = s32[] select(pred[] %compare.21, s32[] %bitcast-convert.19, s32[]
    %bitcast-convert.16)
  %p.0.rhs.10 = bf16[] parameter(1)
  %convert.14 = f32[] convert(bf16[] %p.0.rhs.10)
  %bitcast-convert.24 = s32[] bitcast-convert(f32[] %convert.14)
  %constant.28 = s32[] constant(0)
  %compare.29 = pred[] compare(s32[] %bitcast-convert.24, s32[] %constant.28),
    direction=LT
  %constant.23 = u32[] constant(2147483647)
  %bitcast-convert.25 = u32[] bitcast-convert(f32[] %convert.14)
  %subtract.26 = u32[] subtract(u32[] %constant.23, u32[] %bitcast-convert.25)
  %bitcast-convert.27 = s32[] bitcast-convert(u32[] %subtract.26)
  %select.30 = s32[] select(pred[] %compare.29, s32[] %bitcast-convert.27, s32[]
    %bitcast-convert.24)
  ROOT %compare.31 = pred[] compare(s32[] %select.22, s32[] %select.30),
    direction=GT
}

ENTRY entry
  (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
  %iota.7 = s32[2,209664] iota(), iota_dimension=1,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %sort.32 = (bf16[2,209664], s32[2,209664])
    sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
    dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.33 = bf16[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.34 = bf16[2,2000] slice(bf16[2,209664]
    %get-tuple-element.33), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.35 = s32[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.36 = s32[2,2000] slice(s32[2,209664]
    %get-tuple-element.35), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
    tuple(bf16[2,2000] %slice.34, s32[2,2000]
    %slice.36), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  // Check each lookup for null before dereferencing it, so a missing or
  // renamed instruction is reported as a test failure rather than crashing
  // the test binary.
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
  auto final_sort = FindInstruction(module.get(), "sort.1");
  ASSERT_NE(final_sort, nullptr);
  EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
  EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
}
2633 
// A TopK-style sort (values sorted together with an iota of indices, then
// sliced to the first 2000 entries) whose comparator breaks ties with a
// select on the index comparison. The sort dimension is tiled across 2
// devices, so the partitioner is expected to sort each local half and then
// run a final sort over the per-partition top-2000 candidates.
TEST_F(SpmdPartitioningTest, PartitionSortInTopKWhenComparisonWithSelect) {
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
  p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
  %p.0.lhs.2566 = bf16[] parameter(0)
  %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
  %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
  %constant.285 = s32[] constant(0)
  %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
    direction=LT
  %constant.286 = u32[] constant(2147483647)
  %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
  %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
  %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
  %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
    s32[] %bitcast-convert.48)
  %p.0.rhs.2567 = bf16[] parameter(1)
  %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
  %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
  %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
    direction=LT
  %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
  %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
  %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
  %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
    s32[] %bitcast-convert.51)
  %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
  %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
  %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
    direction=EQ
  %p.1.lhs.2586 = s32[] parameter(2)
  %p.1.rhs.2587 = s32[] parameter(3)
  %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
    direction=LT
  ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
    pred[] %compare.86)
}

ENTRY entry
  (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
  %iota.7 = s32[2,209664] iota(), iota_dimension=1,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %sort.32 = (bf16[2,209664], s32[2,209664])
    sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
    dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.33 = bf16[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.34 = bf16[2,2000] slice(bf16[2,209664]
    %get-tuple-element.33), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.35 = s32[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.36 = s32[2,2000] slice(s32[2,209664]
    %get-tuple-element.35), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
    tuple(bf16[2,2000] %slice.34, s32[2,2000]
    %slice.36), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Guard the lookups: if no instruction has the requested name (the helper
  // is expected to return nullptr then), fail this test instead of crashing
  // the whole test binary on the dereference below.
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  // Each partition sorts only its local half of dimension 1 (209664 / 2).
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);

  auto final_sort = FindInstruction(module.get(), "sort.1");
  ASSERT_NE(final_sort, nullptr);
  // The final sort sees 2000 (slice size) x 2 (partitions) candidates.
  EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
  EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
}
2712 
// Same TopK-shaped sort as the partitioned case, except that the second sort
// operand is a plain parameter rather than an iota. The sort must then stay
// unpartitioned along the sort dimension (full 209664 elements).
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSecondOperandIsNotIota) {
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
  p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
  %p.0.lhs.2566 = bf16[] parameter(0)
  %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
  %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
  %constant.285 = s32[] constant(0)
  %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
    direction=LT
  %constant.286 = u32[] constant(2147483647)
  %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
  %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
  %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
  %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
    s32[] %bitcast-convert.48)
  %p.0.rhs.2567 = bf16[] parameter(1)
  %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
  %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
  %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
    direction=LT
  %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
  %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
  %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
  %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
    s32[] %bitcast-convert.51)
  %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
  %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
  %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
    direction=EQ
  %p.1.lhs.2586 = s32[] parameter(2)
  %p.1.rhs.2587 = s32[] parameter(3)
  %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
    direction=LT
  ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
    pred[] %compare.86)
}

ENTRY entry {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %arg_tuple.2 = s32[2,209664] parameter(1)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
  %sort.32 = (bf16[2,209664], s32[2,209664])
    sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %arg_tuple.2),
    dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.33 = bf16[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.34 = bf16[2,2000] slice(bf16[2,209664]
    %get-tuple-element.33), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.35 = s32[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.36 = s32[2,2000] slice(s32[2,209664]
    %get-tuple-element.35), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
    tuple(bf16[2,2000] %slice.34, s32[2,2000]
    %slice.36), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Fail (rather than crash) if no instruction named "sort" survives.
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  // The sort dimension is not split: both operands keep all 209664 elements.
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
2786 
// TopK-shaped sort whose input is tiled on dimension 0 ([2,1] devices), not
// on the sort dimension 1. The sort dimension must keep its full size.
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenNoPartitionInSortDim) {
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
  p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
  %p.0.lhs.2566 = bf16[] parameter(0)
  %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
  %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
  %constant.285 = s32[] constant(0)
  %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
    direction=LT
  %constant.286 = u32[] constant(2147483647)
  %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
  %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
  %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
  %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
    s32[] %bitcast-convert.48)
  %p.0.rhs.2567 = bf16[] parameter(1)
  %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
  %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
  %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
    direction=LT
  %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
  %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
  %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
  %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
    s32[] %bitcast-convert.51)
  %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
  %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
  %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
    direction=EQ
  %p.1.lhs.2586 = s32[] parameter(2)
  %p.1.rhs.2587 = s32[] parameter(3)
  %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
    direction=LT
  ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
    pred[] %compare.86)
}

ENTRY entry
  (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[2,1]0,1}
  %iota.7 = s32[2,209664] iota(), iota_dimension=1,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %sort.32 = (bf16[2,209664], s32[2,209664])
    sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
    dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.33 = bf16[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.34 = bf16[2,2000] slice(bf16[2,209664]
    %get-tuple-element.33), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.35 = s32[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.36 = s32[2,2000] slice(s32[2,209664]
    %get-tuple-element.35), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
    tuple(bf16[2,2000] %slice.34, s32[2,2000]
    %slice.36), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Fail (rather than crash) if no instruction named "sort" survives.
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  // Dimension 1 (the sort dimension) is untouched by the dim-0 tiling.
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
2862 
// TopK-shaped sort where the sort outputs are sliced only along dimension 0
// ([0:1]) and keep all of the sort dimension ([0:209664]). With no top-k
// truncation along the sort dimension, the sort must not be split there.
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSliceInOtherDim) {
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
  p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
  %p.0.lhs.2566 = bf16[] parameter(0)
  %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
  %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
  %constant.285 = s32[] constant(0)
  %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
    direction=LT
  %constant.286 = u32[] constant(2147483647)
  %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
  %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
  %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
  %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
    s32[] %bitcast-convert.48)
  %p.0.rhs.2567 = bf16[] parameter(1)
  %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
  %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
  %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
    direction=LT
  %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
  %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
  %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
  %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
    s32[] %bitcast-convert.51)
  %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
  %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
  %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
    direction=EQ
  %p.1.lhs.2586 = s32[] parameter(2)
  %p.1.rhs.2587 = s32[] parameter(3)
  %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
    direction=LT
  ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
    pred[] %compare.86)
}

ENTRY entry {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
  %iota.7 = s32[2,209664] iota(), iota_dimension=1,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %sort.32 = (bf16[2,209664], s32[2,209664])
    sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
    dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.33 = bf16[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.34 = bf16[1,209664] slice(bf16[2,209664]
    %get-tuple-element.33), slice={[0:1], [0:209664]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.35 = s32[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.36 = s32[1,209664] slice(s32[2,209664]
    %get-tuple-element.35), slice={[0:1], [0:209664]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  ROOT %tuple.46 = (bf16[1,209664], s32[1,209664])
    tuple(bf16[1,209664] %slice.34, s32[1,209664]
    %slice.36), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Fail (rather than crash) if no instruction named "sort" survives.
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  // The sort dimension keeps its full 209664 elements on every partition.
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
2937 
// The operand is tiled on dimension 1, and transpose dimensions={0,3,1,2} map
// operand dimension 1 to output dimension 2, which is exactly the tiled
// output dimension. The transpose therefore applies directly to the local
// operand shard.
TEST_F(SpmdPartitioningTest, ShardableTranspose) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[16,38,38,4] parameter(0)
  %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
  ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
    dimensions={0,3,1,2}, sharding={devices=[1,1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Local shard of the copied parameter: dimension 1 halved (38 -> 19).
  const auto local_operand = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[16,19,38,4]"));
  auto* transpose = module->entry_computation()->root_instruction();
  EXPECT_THAT(transpose, AllOf(op::Transpose(local_operand),
                               op::Shape("f32[16,4,19,38]")));
}
2960 
// The operand is tiled 4-way on dimension 0 and 2-way on dimension 1. With
// dimensions={1,3,0,2}, those land on output dimensions 2 and 0, and the
// output sharding tiles exactly those, so no resharding is needed before
// the local transpose.
TEST_F(SpmdPartitioningTest, MultiDimensionShardedTranspose) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[16,38,38,4] parameter(0)
  %param0.copy = f32[16,38,38,4] copy(%param0),
    sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
  ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy),
    dimensions={1,3,0,2}, sharding={devices=[2,1,4,1]0,2,4,6,1,3,5,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  // Both tiled operand dimensions get a computed (reshaped) slice offset.
  const auto local_operand = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[4,19,38,4]"));
  auto* transpose = module->entry_computation()->root_instruction();
  EXPECT_THAT(transpose, AllOf(op::Transpose(local_operand),
                               op::Shape("f32[19,4,4,38]")));
}
2984 
// The output is tiled on dimension 1, which transpose dimensions={0,3,1,2}
// take from operand dimension 3. The operand, however, is tiled on dimension
// 1, so it has to be resharded (via all-to-all) to a dimension-3 tiling
// before the local transpose.
TEST_F(SpmdPartitioningTest, NonShardableTranspose) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[16,38,38,4] parameter(0)
  %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
  ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
    dimensions={0,3,1,2}, sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  // Core of the all-to-all reshard. Both reshape/all-to-all nestings are
  // accepted: this file's ReshapeWithReshard tests match
  // AllToAll(Reshape(...)), while this test previously spelled it as
  // Reshape(AllToAll()).
  auto all_to_all_core = ::testing::AnyOf(op::AllToAll(op::Reshape()),
                                          op::Reshape(op::AllToAll()));
  // Operand resharded to a dimension-3 tiling: 4 -> 2 locally.
  auto reshard = AllOf(op::Reshape(op::Transpose(all_to_all_core)),
                       op::Shape("f32[16,38,38,2]"));
  // Previously this matcher was built but never checked; verify it now.
  EXPECT_THAT(root,
              AllOf(op::Transpose(reshard), op::Shape("f32[16,2,38,38]")));
}
3005 
// Partially replicated variant of ShardableTranspose. The operand is tiled
// 2-way on dimension 1 and replicated 2-way via last_tile_dim_replicate. The
// transposed operand sharding equals the output sharding, so the transpose
// is applied to the local shard directly.
TEST_F(SpmdPartitioningTest, PartialReplicateShardableTranspose) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[16,38,38,4] parameter(0)
  %param0.copy = f32[16,38,38,4] copy(%param0),
    sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
    dimensions={0,3,1,2},
    sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  const auto local_operand = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[16,19,38,4]"));
  auto* transpose = module->entry_computation()->root_instruction();
  EXPECT_THAT(transpose, AllOf(op::Transpose(local_operand),
                               op::Shape("f32[16,4,19,38]")));
}
3030 
// Partially replicated variant of NonShardableTranspose. The output is tiled
// on dimension 1 (taken from operand dimension 3), while the operand is tiled
// on dimension 1, so the operand must be resharded before the local
// transpose.
TEST_F(SpmdPartitioningTest, PartialReplicateNonShardableTranspose) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[16,38,38,4] parameter(0)
  %param0.copy = f32[16,38,38,4] copy(%param0),
    sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
    dimensions={0,3,1,2},
    sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  // The transposed operand must be the resharded local shard (dimension 3
  // split 4 -> 2). The exact reshard instruction sequence for partially
  // replicated shardings is not pinned here. NOTE(review): consider asserting
  // it once the expected collective pattern is confirmed. (This replaces an
  // unused, misspelled matcher that was never checked.)
  EXPECT_THAT(root, AllOf(op::Transpose(op::Shape("f32[16,38,38,2]")),
                          op::Shape("f32[16,2,38,38]")));
}
3053 
// Partially replicated multi-dimension case. The operand is tiled 2x2 on
// dimensions 0 and 1 with 2-way replication. The output sharding (device
// order 0,1,4,5,2,3,6,7) is the transposed operand sharding, so only a local
// transpose is expected.
TEST_F(SpmdPartitioningTest, PartialReplicateMultiDimensionShardedTranspose) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[16,38,38,4] parameter(0)
  %param0.copy = f32[16,38,38,4] copy(%param0),
    sharding={devices=[2,2,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy),
    dimensions={1,3,0,2},
    sharding={devices=[2,1,2,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  const auto local_operand = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[8,19,38,4]"));
  auto* transpose = module->entry_computation()->root_instruction();
  EXPECT_THAT(transpose, AllOf(op::Transpose(local_operand),
                               op::Shape("f32[19,4,8,38]")));
}
3078 
// The reshape only splits the untiled last dimension (324 -> 4x81), while
// dimension 0 stays tiled the same way on input and output. It can therefore
// be applied to the local shard directly.
TEST_F(SpmdPartitioningTest, ShardableReshape) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[38,38,324] parameter(0)
  %param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[2,1,1]0,1}
  ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
    sharding={devices=[2,1,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto local_input =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                      op::Constant(), op::Constant())),
            op::Shape("f32[19,38,324]"));
  auto* reshape = module->entry_computation()->root_instruction();
  EXPECT_THAT(reshape,
              AllOf(op::Reshape(local_input), op::Shape("f32[19,38,4,81]")));
}
3101 
// Input is tiled on dimension 0 but the output wants dimension 1 tiled. The
// parameter is first resharded with an all-to-all (reshape / all-to-all /
// transpose / reshape), and then reshaped locally.
TEST_F(SpmdPartitioningTest, ReshapeWithReshard) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
  ROOT %reshape = f32[38,38,4,81] reshape(%param0),
    sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto split_param = op::Reshape(op::Parameter(0));
  const auto resharded_input =
      op::Reshape(op::Transpose(op::AllToAll(split_param)));
  auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result, AllOf(op::Reshape(resharded_input),
                            op::Shape("f32[38,19,4,81]")));
}
3122 
// Input is tiled on dimension 0 and the output on the last dimension (162).
// Here the local reshape happens first, on the dimension-0 shard, and its
// result is then resharded with an all-to-all.
TEST_F(SpmdPartitioningTest, ReshapeWithReshard2) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
  ROOT %reshape = f32[38,38,2,162] reshape(%param0),
    sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto reshaped_shard =
      AllOf(op::Reshape(op::Parameter(0)), op::Shape("f32[19,38,2,162]"));
  const auto all_to_all_reshard =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(reshaped_shard))));
  auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[38,38,2,81]"), all_to_all_reshard));
}
3144 
// Partially replicated variant of ShardableReshape. Dimension 0 is tiled
// 2-way with 2-way replication on both sides, and the reshape only splits the
// untiled last dimension.
TEST_F(SpmdPartitioningTest, PartialReplicateShardableReshape) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[38,38,324] parameter(0)
  %param0.copy = f32[38,38,324] copy(%param0),
    sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
    sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  const auto local_input =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                      op::Constant(), op::Constant())),
            op::Shape("f32[19,38,324]"));
  auto* reshape = module->entry_computation()->root_instruction();
  EXPECT_THAT(reshape,
              AllOf(op::Reshape(local_input), op::Shape("f32[19,38,4,81]")));
}
3168 
// The local reshape of the dimension-2 shard yields a [3,2,1,8,5] piece whose
// rows do not line up with the output tiling of dimension 3 (14 -> 7 per
// partition). A halo from the neighbor (collective-permute of a slice) is
// concatenated and then dynamic-sliced down to the 7-row output shard.
TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = s32[2,3,7,10] parameter(0), sharding={devices=[1,1,2,1]0,1}
  ROOT %reshape = s32[3,2,1,14,5] reshape(%input),
    sharding={devices=[1,1,1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto local_reshape =
      AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]"));
  const auto with_halo = op::Concatenate(
      op::CollectivePermute(op::Slice(local_reshape)), op::Slice(local_reshape));
  auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result, AllOf(op::DynamicSlice(with_halo, _, _, _, _, _),
                            op::Shape("s32[3,2,1,7,5]")));
}
3191 
// Partially replicated variant of ReshapeMergeDimsWithHaloExchange. The same
// local reshape + halo (collective-permute) + dynamic-slice structure is
// expected, with 2-way replication on top of the 2-way tiling.
TEST_F(SpmdPartitioningTest, PartialReplicateReshapeMergeDimsWithHaloExchange) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = s32[2,3,7,10] parameter(0),
    sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %reshape = s32[3,2,1,14,5] reshape(%input),
    sharding={devices=[1,1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  const auto local_reshape =
      AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]"));
  const auto with_halo = op::Concatenate(
      op::CollectivePermute(op::Slice(local_reshape)), op::Slice(local_reshape));
  auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result, AllOf(op::DynamicSlice(with_halo, _, _, _, _, _),
                            op::Shape("s32[3,2,1,7,5]")));
}
3215 
3216 // Produces an invalid module after transformation.
TEST_F(SpmdPartitioningTest, InceptionV3_4_way_ReduceWindowDilated) {
  // A reduce-window with lhs_dilate=3 and padding, tiled 4-way on dimension 1.
  // The expected partitioned form:
  //   - each partition takes a padded, dynamic-sliced input shard;
  //   - a neighbor halo is fetched (collective-permute) and concatenated;
  //   - the result is dynamic-sliced and out-of-range rows are masked with a
  //     select;
  //   - the reduce-window runs locally;
  //   - a final dynamic-slice extracts the [128,5,17,768] output shard.
  const absl::string_view hlo_text = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  %param0 = f32[128,5,5,768] parameter(0)
  %param0.copy = f32[128,5,5,768] copy(%param0),
    sharding={devices=[1,4,1,1]0,1,2,3}
  %constant.1 = f32[] constant(0), sharding={replicated}
  ROOT %rw = f32[128,17,17,768] reduce-window(%param0.copy, %constant.1),
    window={size=1x5x5x1 pad=0_0x4_4x4_4x0_0 lhs_dilate=1x3x3x1},
    to_apply=sum, sharding={devices=[1,4,1,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  const auto shard = op::Copy(op::DynamicSlice(
      op::Pad(op::Parameter(0), op::Constant()), op::Constant(), op::Reshape(),
      op::Constant(), op::Constant()));

  // Index arithmetic built from a reshaped scalar (presumably the partition
  // id; not verified here) and constants.
  const auto id_times4_plus1 =
      op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant());
  const auto id_times5 = op::Multiply(op::Reshape(), op::Constant());
  const auto id_times5_plus1_over3 =
      op::Divide(op::Add(id_times5, op::Constant()), op::Constant());

  const auto shard_with_halo =
      AllOf(op::Shape("f32[128,4,5,768]"),
            op::Concatenate(op::CollectivePermute(shard), shard));
  const auto halo_offset = op::Subtract(
      op::Constant(), op::Subtract(id_times4_plus1, id_times5_plus1_over3));
  const auto unmasked = AllOf(
      op::Shape("f32[128,3,5,768]"),
      op::DynamicSlice(shard_with_halo, op::Constant(), halo_offset,
                       op::Constant(), op::Constant()));

  const auto shifted_iota =
      op::Add(op::Iota(), op::Broadcast(id_times5_plus1_over3));
  const auto in_range =
      op::And(op::Compare(shifted_iota, op::Broadcast(op::Constant())),
              op::Compare(shifted_iota, op::Broadcast(op::Constant())));
  const auto masked =
      op::Select(in_range, unmasked, op::Broadcast(op::Constant()));

  const auto local_rw = AllOf(op::Shape("f32[128,7,17,768]"),
                              op::ReduceWindow(masked, op::Constant()));
  const auto output_offset = op::Subtract(
      id_times5, op::Add(op::Multiply(id_times5_plus1_over3, op::Constant()),
                         op::Constant()));

  auto* result = module->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Shape("f32[128,5,17,768]"),
                    op::DynamicSlice(local_rw, op::Constant(), output_offset,
                                     op::Constant(), op::Constant())));
}
3276 
// Reduces dims {0,1,2} of an operand tiled 2 ways on dim 3, with the f32[128]
// result tiled 2 ways. The expected root is a plain Reduce (no collective) over
// the f32[4,32,32,64] local shard, producing f32[64].
TEST_F(SpmdPartitioningTest, TiledToTiledReduce) {
  constexpr char kHloText[] = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  %param0 = f32[4,32,32,128] parameter(0)
  %param0.copy = f32[4,32,32,128] copy(%param0),
    sharding={devices=[1,1,1,2]0,1}
  %constant.1 = f32[] constant(0), sharding={replicated}
  %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2},
    to_apply=%sum, sharding={devices=[2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();

  // Each device takes its half of dim 3 (the only index driven by a Reshape).
  auto sharded_input = AllOf(
      op::Shape("f32[4,32,32,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())));

  EXPECT_THAT(root, AllOf(op::Shape("f32[64]"),
                          op::Reduce(sharded_input, op::Constant())));
}
3309 
// Reduces dim 0 of a partially tiled f32[4,4] operand ([2,2] tiles, last tile
// dim replicated) into a partially tiled f32[4] result. The expected root is an
// AllReduce of a local Reduce of parameter 0, with per-device shape f32[2].
TEST_F(SpmdPartitioningTest, PartialTiledToPartialTiledReduce) {
  constexpr char kHloText[] = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  %param0 = f32[4,4] parameter(0),
    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %constant.1 = f32[] constant(0), sharding={replicated}
  ROOT %reduce = f32[4] reduce(%param0, %constant.1), dimensions={0},
    to_apply=%sum,
    sharding={devices=[2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  // Partial sums over the local rows are combined across the reduced dim.
  auto local_reduce = op::Reduce(op::Parameter(0), op::Constant());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[2]"), op::AllReduce(local_reduce)));
}
3338 
// Variadic (value, index) reduce over dim 1, where both inputs and both
// outputs are tiled 2 ways on the kept dim 0. The expected root is a purely
// local tuple Reduce of the four parameters, yielding (f32[14], s32[14]).
TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce) {
  constexpr char kHloText[] = R"(
HloModule module

%minmax_func {
  %lhs_value = f32[] parameter(0)
  %rhs_value = f32[] parameter(2)
  %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT
  %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value)
  %lhs_index = s32[] parameter(1)
  %rhs_index = s32[] parameter(3)
  %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index)
  ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5)
}

ENTRY %main {
  %param0 = f32[28,10] parameter(0), sharding={devices=[2,1]0,1}
  %param1 = s32[28,10] parameter(1), sharding={devices=[2,1]0,1}
  %init0 = f32[] parameter(2)
  %init1 = s32[] parameter(3)
  ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1),
    dimensions={1}, to_apply=%minmax_func,
    sharding={{devices=[2]0,1}, {devices=[2]0,1}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();
  auto local_tuple_reduce = op::Reduce(op::Parameter(0), op::Parameter(1),
                                       op::Parameter(2), op::Parameter(3));
  EXPECT_THAT(root,
              AllOf(op::Shape("(f32[14], s32[14])"), local_tuple_reduce));
}
3373 
// Variadic (value, index) reduce over dim 1, where the inputs are tiled [2,4]
// but the outputs are tiled only on dim 0 (last tile dim replicated).
// Expected shape of the partitioned graph:
//   1. a local tuple Reduce of the [14,3] shards,
//   2. each tuple element reshaped to [14,1] and written with
//      DynamicUpdateSlice into a buffer that is AllReduce'd to [14,4],
//   3. a final tuple Reduce over those gathered partial results.
TEST_F(SpmdPartitioningTest, TiledToPartiallyTiledTupleReduce) {
  constexpr char kHloText[] = R"(
HloModule module

%minmax_func {
  %lhs_value = f32[] parameter(0)
  %rhs_value = f32[] parameter(2)
  %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT
  %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value)
  %lhs_index = s32[] parameter(1)
  %rhs_index = s32[] parameter(3)
  %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index)
  ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5)
}

ENTRY %main {
  %param0 = f32[28,12] parameter(0), sharding={devices=[2,4]0,1,2,3,4,5,6,7}
  %param1 = s32[28,12] parameter(1), sharding={devices=[2,4]0,1,2,3,4,5,6,7}
  %init0 = f32[] parameter(2)
  %init1 = s32[] parameter(3)
  ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1),
    dimensions={1}, to_apply=%minmax_func,
    sharding={{devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate},
              {devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  // Step 1: local reduce over this device's [14,3] tile.
  auto value_shard = AllOf(op::Parameter(0), op::Shape("f32[14,3]"));
  auto index_shard = AllOf(op::Parameter(1), op::Shape("s32[14,3]"));
  auto partial_reduce =
      AllOf(op::Shape("(f32[14], s32[14])"),
            op::Reduce(value_shard, index_shard, op::Parameter(2),
                       op::Parameter(3)));

  // Step 2: per-element partial results, widened to [14,1] and combined.
  auto partial_values = AllOf(op::Shape("f32[14,1]"),
                              op::Reshape(op::GetTupleElement(partial_reduce)));
  auto partial_indices = AllOf(
      op::Shape("s32[14,1]"), op::Reshape(op::GetTupleElement(partial_reduce)));
  auto gathered_values =
      AllOf(op::Shape("f32[14,4]"),
            op::AllReduce(op::DynamicUpdateSlice(_, partial_values, _, _)));
  auto gathered_indices =
      AllOf(op::Shape("s32[14,4]"),
            op::AllReduce(op::DynamicUpdateSlice(_, partial_indices, _, _)));

  // Step 3: final reduce across the gathered columns.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("(f32[14], s32[14])"),
                    op::Reduce(gathered_values, gathered_indices,
                               op::Parameter(2), op::Parameter(3))));
}
3424 
// The operand is tiled on reduced dim 1, but the f32[128] output is tiled on
// the (kept) feature dim. Expected: local Reduce, AllReduce to the full f32[128],
// then a DynamicSlice down to this device's f32[64] half.
TEST_F(SpmdPartitioningTest, TiledToTiledReduceOutputReshard) {
  constexpr char kHloText[] = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  %param0 = f32[4,32,32,128] parameter(0)
  %param0.copy = f32[4,32,32,128] copy(%param0),
    sharding={devices=[1,2,1,1]0,1}
  %constant.1 = f32[] constant(0), sharding={replicated}
  %reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2},
    to_apply=%sum, sharding={devices=[2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();

  auto sharded_input = AllOf(
      op::Shape("f32[4,16,32,128]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  auto fully_reduced =
      AllOf(op::Shape("f32[128]"),
            op::AllReduce(op::Reduce(sharded_input, op::Constant())));

  EXPECT_THAT(root, AllOf(op::Shape("f32[64]"),
                          op::DynamicSlice(fully_reduced, op::Reshape())));
}
3461 
// Iota along dim 1 while tiling dim 2: each shard's iota values are unaffected
// by the partition offset, so a bare Iota of the shard shape is expected.
// The shard width is ceil(91 / 2) = 46.
TEST_F(SpmdPartitioningTest, IotaAlongNonTileDimension) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  ROOT %iota = s32[16,80,91] iota(), iota_dimension=1,
    sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("s32[16,80,46]"), op::Iota()));
}
3478 
// Iota along the tiled dim 2: the local Iota must be shifted by a broadcast
// per-partition offset, so an Add(Iota, Broadcast) of shape s32[16,80,46]
// is expected.
TEST_F(SpmdPartitioningTest, IotaAlongTileDimension) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  ROOT %iota = s32[16,80,91] iota(), iota_dimension=2,
    sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto offset_iota = op::Add(op::Iota(), op::Broadcast());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("s32[16,80,46]"), offset_iota));
}
3496 
// Same as IotaAlongTileDimension but with an unsigned (u32) element type: the
// offset Add(Iota, Broadcast) must keep the u32 element type.
TEST_F(SpmdPartitioningTest, U32IotaAlongTileDimension) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  ROOT %iota = u32[16,80,91] iota(), iota_dimension=2,
    sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto offset_iota = op::Add(op::Iota(), op::Broadcast());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("u32[16,80,46]"), offset_iota));
}
3514 
// Conditional with a predicate placed on device 0 only (maximal sharding), a
// replicated true branch and a dim-0-tiled false branch; the conditional
// result is tiled [2,1]. Expected after partitioning:
//   * the predicate operand is produced by an AllReduce,
//   * the replicated branch operand stays a full f32[4,5] copy, while the tiled
//     one is sliced to f32[2,5],
//   * the true branch negates the full input, then slices to the f32[2,5] shard,
//   * the false branch copies its already-sharded parameter.
TEST_F(SpmdPartitioningTest, Conditional) {
  absl::string_view hlo_string = R"(
HloModule module

Negate {
  x = f32[4,5] parameter(0), sharding={replicated}
  ROOT negate = f32[4,5] negate(x), sharding={replicated}
}

Identity {
  y = f32[4,5] parameter(0), sharding={devices=[2,1]0,1}
  ROOT copy = f32[4,5] copy(y), sharding={devices=[2,1]0,1}
}

ENTRY entry {
  %param.0 = pred[] parameter(0)
  %param.0.copy = pred[] copy(%param.0), sharding={maximal device=0}
  %param.1 = f32[4,5] parameter(1)
  %param.1.copy = f32[4,5] copy(%param.1), sharding={replicated}
  %param.2 = f32[4,5] parameter(2)
  %param.2.copy = f32[4,5] copy(%param.2), sharding={devices=[2,1]0,1}
  ROOT cond = f32[4,5] conditional(%param.0.copy, %param.1.copy, %param.2.copy),
    true_computation=Negate, false_computation=Identity,
    sharding={devices=[2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // The predicate, which lives only on device 0, is checked only as far as the
  // AllReduce that makes it available on every device.
  auto param1 = AllOf(op::Copy(op::Parameter()), op::Shape("f32[4,5]"));
  auto param2 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                                op::Constant())),
                      op::Shape("f32[2,5]"));

  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Conditional(op::AllReduce(), param1, param2),
                          op::Shape("f32[2,5]")));

  auto then_branch_root = root->branch_computation(0)->root_instruction();
  EXPECT_THAT(then_branch_root,
              AllOf(op::DynamicSlice(op::Negate(op::Parameter()), op::Reshape(),
                                     op::Constant()),
                    op::Shape("f32[2,5]")));

  auto else_branch_root = root->branch_computation(1)->root_instruction();
  EXPECT_THAT(else_branch_root,
              AllOf(op::Copy(op::Parameter()), op::Shape("f32[2,5]")));
}
3565 
// Every instruction carries {manual} sharding, so the conditional is expected
// to survive partitioning unchanged: same parameters, same f32[4,5] shape.
TEST_F(SpmdPartitioningTest, ConditionalManual) {
  constexpr char kHloText[] = R"(
HloModule module

Negate {
  x = f32[4,5] parameter(0), sharding={manual}
  ROOT negate = f32[4,5] negate(x), sharding={manual}
}

Identity {
  y = f32[4,5] parameter(0), sharding={manual}
  ROOT copy = f32[4,5] copy(y), sharding={manual}
}

ENTRY entry {
  %param.0 = pred[] parameter(0), sharding={manual}
  %param.1 = f32[4,5] parameter(1), sharding={manual}
  %param.2 = f32[4,5] parameter(2), sharding={manual}
  ROOT cond = f32[4,5] conditional(%param.0, %param.1, %param.2),
    true_computation=Negate, false_computation=Identity, sharding={manual}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();
  auto predicate = AllOf(op::Shape("pred[]"), op::Parameter(0));
  auto true_operand = AllOf(op::Shape("f32[4,5]"), op::Parameter(1));
  auto false_operand = AllOf(op::Shape("f32[4,5]"), op::Parameter(2));
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[4,5]"),
                    op::Conditional(predicate, true_operand, false_operand)));
}
3600 
// A counting while loop in which every instruction is {manual}; the loop is
// expected to remain a While over the original s32[] parameter 0.
TEST_F(SpmdPartitioningTest, WhileManual) {
  constexpr char kHloText[] = R"(
HloModule module

LoopCond {
  x = s32[] parameter(0), sharding={manual}
  const = s32[] constant(5), sharding={manual}
  ROOT lt = pred[] compare(x, const), direction=LT, sharding={manual}
}

Inc {
  x = s32[] parameter(0), sharding={manual}
  const = s32[] constant(1), sharding={manual}
  ROOT add = s32[] add(x, const), sharding={manual}
}

ENTRY entry {
  zero = s32[] parameter(0), sharding={manual}
  ROOT while = s32[] while(zero), body=Inc, condition=LoopCond,
    sharding={manual}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto loop_init = AllOf(op::Shape("s32[]"), op::Parameter(0));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("s32[]"), op::While(loop_init)));
}
3631 
// Select-and-scatter with a 1x1x1x1 window and stride 1x2x2x1, where operand,
// source and result are all tiled 8 ways on dim 1. Expected: a local
// SelectAndScatter over the sliced operand/source, with no padding added on
// window dimension 0.
TEST_F(SpmdPartitioningTest, SelectAndScatter_RetinaNet) {
  constexpr char kHloText[] = R"(
HloModule module

ge {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT compare = pred[] compare(a, b), direction=GE
}

sum {
  c = f32[] parameter(0)
  d = f32[] parameter(1)
  ROOT add = f32[] add(c, d)
}

ENTRY entry {
  %param.0 = f32[32,128,384,64] parameter(0)
  %param.0.copy = f32[32,128,384,64] copy(%param.0),
    sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
  %param.1 = f32[32,64,192,64] parameter(1)
  %param.1.copy = f32[32,64,192,64] copy(%param.1),
    sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT select-and-scatter = f32[32,128,384,64] select-and-scatter(param.0.copy,
    %param.1.copy, constant.1), window={size=1x1x1x1 stride=1x2x2x1},
    select=ge, scatter=sum, sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();

  auto operand_shard = AllOf(
      op::Shape("f32[32,16,384,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  auto source_shard = AllOf(
      op::Shape("f32[32,8,192,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  EXPECT_THAT(root,
              op::SelectAndScatter(operand_shard, source_shard, op::Constant()));

  const auto& first_window_dim = root->window().dimensions(0);
  EXPECT_EQ(first_window_dim.padding_low(), 0);
  EXPECT_EQ(first_window_dim.padding_high(), 0);
}
3678 
// A dot expressed as a convolution whose contracting dimension (64) is tiled on
// both sides, with a replicated result. Expected: a local Convolution of the two
// 32-wide shards followed by an AllReduce of the full f32[128,256].
// Halo exchange on the LHS is disabled for this test via the third argument.
TEST_F(SpmdPartitioningTest, TiledDot) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,64] parameter(0)
  %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1}
  %rhs = f32[64,256] parameter(1)
  %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1}
  ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy),
    dim_labels=bf_io->bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(kHloText, /*num_devices=*/2,
                           /*conv_halo_exchange_always_on_lhs=*/false));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();
  auto lhs_shard = AllOf(op::Shape("f32[128,32]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Reshape())));
  auto rhs_shard = AllOf(op::Shape("f32[32,256]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Reshape(), op::Constant())));
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[128,256]"),
                    op::AllReduce(op::Convolution(lhs_shard, rhs_shard))));
}
3708 
// Same contracting-dim-tiled convolution as TiledDot, but the result is tiled
// on dim 1. Expected: AllReduce of the local Convolution to the full
// f32[128,256], then a DynamicSlice to this device's f32[128,128] columns.
TEST_F(SpmdPartitioningTest, TiledDotOutputTiled) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,64] parameter(0)
  %lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1}
  %rhs = f32[64,256] parameter(1)
  %rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1}
  ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy),
    dim_labels=bf_io->bf, sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();
  auto lhs_shard = AllOf(op::Shape("f32[128,32]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Reshape())));
  auto rhs_shard = AllOf(op::Shape("f32[32,256]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Reshape(), op::Constant())));
  auto full_result =
      AllOf(op::Shape("f32[128,256]"),
            op::AllReduce(op::Convolution(lhs_shard, rhs_shard)));
  EXPECT_THAT(root, AllOf(op::Shape("f32[128,128]"),
                          op::DynamicSlice(full_result, op::Constant(),
                                           op::Reshape())));
}
3739 
// 1-D convolution (dim_labels=0bf_io0->0bf) with the LHS and result tiled on
// dim 1 and a replicated kernel. Expected: a purely local Convolution of the
// f32[128,128,256] LHS shard with the full kernel.
TEST_F(SpmdPartitioningTest, BatchPartitionedConvolution) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,256,256] parameter(0)
  %lhs.copy = f32[128,256,256] copy(%lhs), sharding={devices=[1,2,1]0,1}
  %rhs = f32[256,8,1] parameter(1)
  %rhs.copy = f32[256,8,1] copy(%rhs), sharding={replicated}
  ROOT %conv = f32[128,256,8] convolution(%lhs.copy, %rhs.copy),
    window={size=1}, dim_labels=0bf_io0->0bf, sharding={devices=[1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto input_shard = AllOf(
      op::Shape("f32[128,128,256]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
                                op::Constant())));
  auto kernel = AllOf(op::Shape("f32[256,8,1]"), op::Copy(op::Parameter(1)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[128,128,8]"),
                    op::Convolution(input_shard, kernel)));
}
3765 
// Dot with a replicated LHS and an RHS tiled on its non-contracting dim, so the
// result is tiled on its output-feature dim. Expected: a local Dot producing
// the f32[24,19648] shard, with no collective.
TEST_F(SpmdPartitioningTest, DotOutputFeaturePartitioned) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,64] parameter(0)
  %lhs.copy = f32[24,64] copy(%lhs), sharding={replicated}
  %rhs = f32[39296,64] parameter(1)
  %rhs.copy = f32[39296,64] copy(%rhs), sharding={devices=[2,1]0,1}
  ROOT %dot = f32[24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();
  auto full_lhs = AllOf(op::Shape("f32[24,64]"), op::Copy(op::Parameter(0)));
  auto rhs_shard = AllOf(op::Shape("f32[19648,64]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(1), op::Reshape(), op::Constant())));
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[24,19648]"), op::Dot(full_lhs, rhs_shard)));
}
3792 
// Partially replicated shardings whose device orders differ between LHS, RHS
// and result. Expected: the parameters feed the Dot directly (no resharding
// ops in between), and the Dot is followed by one AllReduce, giving
// f32[16,256,1024].
TEST_F(SpmdPartitioningTest, DotPartialDeviceOrder) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,256,4096] parameter(0), sharding={devices=[1,1,2,2]1,3,0,2 last_tile_dim_replicate}
  %rhs = f32[4096,2048] parameter(1), sharding={devices=[2,2]3,1,2,0}
  ROOT %dot = f32[16,256,2048] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={0},
    sharding={devices=[1,1,2,2]2,3,0,1 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  auto lhs_shard = AllOf(op::Shape("f32[16,256,2048]"), op::Parameter(0));
  auto rhs_shard = AllOf(op::Shape("f32[2048,1024]"), op::Parameter(1));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[16,256,1024]"),
                    op::AllReduce(op::Dot(lhs_shard, rhs_shard))));
}
3816 
// Batched dot where both operands and the result are tiled on batch dim 0.
// Expected: a purely local Dot of the two batch shards.
TEST_F(SpmdPartitioningTest, EinsumBatchPartitioned) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64] parameter(0)
  %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1}
  %rhs = f32[32,39296,64] parameter(1)
  %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();
  auto lhs_batch_shard = AllOf(
      op::Shape("f32[16,24,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Constant(),
                                op::Constant())));
  auto rhs_batch_shard = AllOf(
      op::Shape("f32[16,39296,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Constant(),
                                op::Constant())));
  EXPECT_THAT(root, AllOf(op::Shape("f32[16,24,39296]"),
                          op::Dot(lhs_batch_shard, rhs_batch_shard)));
}
3845 
// Batched dot where the LHS and result are tiled on batch dim 0 but the RHS is
// replicated. Expected: the replicated RHS is sliced along the batch dim and
// fed with the LHS shard into a local Dot.
TEST_F(SpmdPartitioningTest, EinsumLHSandOutputBatchPartitioned) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64] parameter(0)
  %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1}
  %rhs = f32[32,39296,64] parameter(1)
  %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();
  auto lhs_batch_shard = AllOf(
      op::Shape("f32[16,24,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Constant(),
                                op::Constant())));
  auto full_rhs =
      AllOf(op::Shape("f32[32,39296,64]"), op::Copy(op::Parameter(1)));
  auto rhs_batch_slice = op::DynamicSlice(full_rhs, op::Reshape(),
                                          op::Constant(), op::Constant());
  EXPECT_THAT(root, AllOf(op::Shape("f32[16,24,39296]"),
                          op::Dot(lhs_batch_shard, rhs_batch_slice)));
}
3875 
// Batched dot where the RHS and result are tiled on batch dim 0, but the LHS is
// tiled on its non-contracting dim 1. Expected: the LHS is resharded with a
// Reshape -> AllToAll -> Transpose -> Reshape sequence, then multiplied
// locally with the RHS batch shard.
TEST_F(SpmdPartitioningTest, EinsumRHSandOutputBatchPartitioned) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64] parameter(0)
  %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[1,2,1]0,1}
  %rhs = f32[32,39296,64] parameter(1)
  %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();
  auto lhs_row_shard = AllOf(
      op::Shape("f32[32,12,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
                                op::Constant())));
  auto rhs_batch_shard = AllOf(
      op::Shape("f32[16,39296,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Constant(),
                                op::Constant())));
  auto lhs_as_batch_shard = op::Reshape(
      op::Transpose(op::AllToAll(op::Reshape(lhs_row_shard))));
  EXPECT_THAT(root, AllOf(op::Shape("f32[16,24,39296]"),
                          op::Dot(lhs_as_batch_shard, rhs_batch_shard)));
}
3906 
// Batched dot with both operands replicated and only the result tiled on batch
// dim 0. Expected: each replicated operand is sliced to its batch half, and the
// two slices feed a local Dot.
TEST_F(SpmdPartitioningTest, EinsumOutputBatchPartitioned) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64] parameter(0)
  %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated}
  %rhs = f32[32,39296,64] parameter(1)
  %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto lhs_batch_slice =
      AllOf(op::Shape("f32[16,24,64]"),
            op::DynamicSlice(op::Copy(op::Parameter(0)), op::Reshape(),
                             op::Constant(), op::Constant()));
  auto rhs_batch_slice =
      AllOf(op::Shape("f32[16,39296,64]"),
            op::DynamicSlice(op::Copy(op::Parameter(1)), op::Reshape(),
                             op::Constant(), op::Constant()));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[16,24,39296]"),
                    op::Dot(lhs_batch_slice, rhs_batch_slice)));
}
3938 
// Batched dot whose two contracting dims (2 and 3) are each tiled 2 ways on
// both operands, with a replicated result. Expected: a local Dot over the
// [.., 32, 64] shards, followed by two nested AllReduces.
TEST_F(SpmdPartitioningTest, EinsumContractingDimsPartitioned) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,1,2,2]0,1,2,3}
  %rhs = f32[32,39296,64,128] parameter(1)
  %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,1,2,2]0,1,2,3}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();
  auto lhs_shard = AllOf(
      op::Shape("f32[32,24,32,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
                                op::Constant(), op::Reshape(), op::Reshape())));
  auto rhs_shard = AllOf(
      op::Shape("f32[32,39296,32,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
                                op::Constant(), op::Reshape(), op::Reshape())));
  auto local_dot = op::Dot(lhs_shard, rhs_shard);
  EXPECT_THAT(root, AllOf(op::Shape("f32[32,24,39296]"),
                          op::AllReduce(op::AllReduce(local_dot))));
}
3970 
// Batched dot where the LHS is tiled on its two non-contracting dims (1 and 3)
// and the RHS is replicated; the result is tiled on the matching dims 1 and 2.
// Expected: a purely local Dot of the LHS shard with the full RHS.
TEST_F(SpmdPartitioningTest, EinsumLHSNonContractingDimsPartitioned) {
  constexpr char kHloText[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,2]0,1,2,3}
  %rhs = f32[32,39296,64] parameter(1)
  %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
  ROOT %dot = f32[32,24,128,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[1,2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(kHloText, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  const auto* root = module->entry_computation()->root_instruction();
  auto lhs_shard = AllOf(
      op::Shape("f32[32,12,64,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
                                op::Constant(), op::Reshape())));
  auto full_rhs =
      AllOf(op::Shape("f32[32,39296,64]"), op::Copy(op::Parameter(1)));
  EXPECT_THAT(root, AllOf(op::Shape("f32[32,12,64,39296]"),
                          op::Dot(lhs_shard, full_rhs)));
}
3998 
// Mirror of the LHS case: the RHS is tiled along dimensions 1 and 3 (both
// non-batch, non-contracting) and the output follows that tiling. Each
// partition is expected to dot the full LHS with its local RHS tile.
TEST_F(SpmdPartitioningTest, EinsumRHSNonContractingDimsPartitioned) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64] parameter(0)
  %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated}
  %rhs = f32[32,39296,64,128] parameter(1)
  %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3}
  ROOT %dot = f32[32,24,39296,128] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[1,1,2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // LHS is replicated, so it is just copied.
  auto lhs_full = AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64]"));
  // RHS tile: dimensions 1 and 3 halved.
  auto rhs_tile = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[32,19648,64,64]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_full, rhs_tile),
                          op::Shape("f32[32,24,19648,64]")));
}
4026 
// Both operands are replicated but the output is tiled along dimension 1,
// which comes from the LHS non-contracting dimension. The LHS is expected to
// be sliced along that dimension right before the dot.
TEST_F(SpmdPartitioningTest, EinsumOutputLHSNonContractingDimPartitioned) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated}
  %rhs = f32[32,39296,64,128] parameter(1)
  %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_full =
      AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]"));
  auto rhs_full =
      AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]"));
  // Only dimension 1 of the LHS is sliced (the single Reshape offset).
  auto lhs_slice = AllOf(op::DynamicSlice(lhs_full, op::Constant(),
                                          op::Reshape(), op::Constant(),
                                          op::Constant()),
                         op::Shape("f32[32,12,64,128]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_slice, rhs_full),
                          op::Shape("f32[32,12,39296]")));
}
4058 
// Both operands are replicated but the output is tiled along dimension 2,
// which comes from the RHS non-contracting dimension. The RHS is expected to
// be sliced along its dimension 1 right before the dot.
TEST_F(SpmdPartitioningTest, EinsumOutputRHSNonContractingDimPartitioned) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated}
  %rhs = f32[32,39296,64,128] parameter(1)
  %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_full =
      AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]"));
  auto rhs_full =
      AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]"));
  // Only dimension 1 of the RHS is sliced (the single Reshape offset).
  auto rhs_slice = AllOf(op::DynamicSlice(rhs_full, op::Constant(),
                                          op::Reshape(), op::Constant(),
                                          op::Constant()),
                         op::Shape("f32[32,19648,64,128]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_full, rhs_slice),
                          op::Shape("f32[32,24,19648]")));
}
4089 
TEST_F(SpmdPartitioningTest,EinsumRHSWindowedInContractingOutNonContractingPartitioned)4090 TEST_F(SpmdPartitioningTest,
4091        EinsumRHSWindowedInContractingOutNonContractingPartitioned) {
4092   absl::string_view hlo_string = R"(
4093 HloModule module
4094 
4095 ENTRY entry {
4096   %lhs = f32[320,25,64,128] parameter(0)
4097   %lhs.copy = f32[320,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
4098   %rhs = f32[320,39296,64,128] parameter(1)
4099   %rhs.copy = f32[320,39296,64,128] copy(%rhs),
4100     sharding={devices=[1,1,4,1]0,1,2,3}
4101   ROOT %dot = f32[320,25,39296] dot(%lhs.copy, %rhs.copy),
4102     lhs_batch_dims={0}, rhs_batch_dims={0},
4103     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
4104     sharding={devices=[1,4,1]0,1,2,3}
4105 })";
4106 
4107   TF_ASSERT_OK_AND_ASSIGN(auto module,
4108                           PartitionComputation(hlo_string, /*num_devices=*/4));
4109   VLOG(1) << module->ToString();
4110 
4111   auto root = module->entry_computation()->root_instruction();
4112   auto lhs = AllOf(
4113       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
4114                                 op::Constant(), op::Reshape(), op::Constant())),
4115       op::Shape("f32[320,25,16,128]"));
4116   auto rhs = AllOf(
4117       op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
4118                                 op::Constant(), op::Reshape(), op::Constant())),
4119       op::Shape("f32[320,39296,16,128]"));
4120   EXPECT_THAT(
4121       root,
4122       AllOf(op::GetTupleElement(op::While(op::Tuple(
4123                 lhs, rhs, op::Broadcast(), op::Broadcast(), op::Constant()))),
4124             op::Shape("f32[320,7,39296]")));
4125 
4126   auto while_loop = root->operand(0);
4127   // Check loop condition.
4128   EXPECT_THAT(
4129       while_loop->while_condition()->root_instruction(),
4130       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4131 
4132   // Check loop body.
4133   auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4134   auto ds =
4135       AllOf(op::DynamicSlice(
4136                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4137                 op::Constant(), op::Reshape(), op::Constant(), op::Constant()),
4138             op::Shape("f32[320,7,16,128]"));
4139   auto partial_output =
4140       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
4141                     op::Dot(ds, op::GetTupleElement(op::Parameter(0)))),
4142             op::Shape("f32[320,7,39296]"));
4143   auto window = op::Conditional(op::Compare(next_i, op::Constant()),
4144                                 partial_output, partial_output);
4145   EXPECT_THAT(while_loop->while_body()->root_instruction(),
4146               op::Tuple(op::GetTupleElement(op::Parameter(0)),
4147                         op::GetTupleElement(op::Parameter(0)), window,
4148                         op::GetTupleElement(op::Parameter(0)), next_i));
4149 
4150   // Check the conditional that contains the collective permute.
4151   auto cp_conditional =
4152       while_loop->while_body()->root_instruction()->operand(2);
4153   EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
4154               op::CollectivePermute(op::Parameter(0)));
4155   EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
4156               op::Parameter(0));
4157 }
4158 
// Windowed-einsum configuration (operands tiled 4 ways on contracting
// dimension 2, output tiled on dimension 1) where the LHS is produced by a
// broadcast + add chain instead of a parameter.
TEST_F(SpmdPartitioningTest,
       EinsumRHSWindowedInContractingOutNonContractingFromBroadcast) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %constant.1 = f32[] constant(2)
  %broadcast = f32[32,25,64,128] broadcast(%constant.1), dimensions={},
    sharding={devices=[1,1,4,1]0,1,2,3}
  %add = f32[32,25,64,128] add(%broadcast, %broadcast),
    sharding={devices=[1,1,4,1]0,1,2,3}
  %rhs = f32[32,39296,64,128] parameter(0)
  %rhs.copy = f32[32,39296,64,128] copy(%rhs),
    sharding={devices=[1,1,4,1]0,1,2,3}
  ROOT %dot = f32[32,25,39296] dot(%add, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,4,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  // The result involves loop code motion, so no structural pattern is matched;
  // this test only requires partitioning to succeed.
}
4184 
// The contracting dimension (LHS dim 2 / RHS dim 0) is tiled 4 ways on both
// operands, and the RHS tiling is additionally replicated across a trailing
// group of 2 (last_tile_dim_replicate). The output tiles dim 0 two ways and
// the RHS non-contracting dim (67) four ways, i.e. 17 per partition after
// padding. A while loop is expected whose body windows the RHS and
// accumulates partial dots.
TEST_F(SpmdPartitioningTest,
       EinsumLHSWindowedInContractingOutNonContractingPartitioned) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,1024,16384] parameter(0)
  %lhs.copy = f32[16,1024,16384] copy(%lhs),
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  %rhs = f32[16384,67,128] parameter(1)
  %rhs.copy = f32[16384,67,128] copy(%rhs),
    sharding={devices=[4,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
  ROOT %dot = f32[16,1024,67,128] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={0},
    sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                      op::Constant(), op::Reshape())),
            op::Shape("f32[8,1024,4096]"));
  auto rhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
                                      op::Constant(), op::Constant())),
            op::Shape("f32[4096,67,128]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::GetTupleElement(op::While(
                        op::Tuple(lhs_shard, rhs_shard, op::Broadcast(),
                                  op::Broadcast(), op::Constant()))),
                    op::Shape("f32[8,1024,17,128]")));

  const HloInstruction* loop = root->operand(0);
  // Inside the loop, every carried value is read as gte(parameter(0)).
  auto carried = op::GetTupleElement(op::Parameter(0));

  // Condition: a carried value compared against a constant bound.
  EXPECT_THAT(loop->while_condition()->root_instruction(),
              op::Compare(carried, op::Constant()));

  // Body: slice a 17-wide window out of the padded RHS, dot the LHS shard
  // with it, and add the product into the accumulator.
  auto next_iter = op::Add(carried, op::Constant());
  auto rhs_window =
      AllOf(op::DynamicSlice(op::Pad(carried, op::Constant()), op::Constant(),
                             op::Reshape(), op::Constant()),
            op::Shape("f32[4096,17,128]"));
  auto accumulated = AllOf(op::Add(carried, op::Dot(carried, rhs_window)),
                           op::Shape("f32[8,1024,17,128]"));
  auto maybe_permuted = op::Conditional(op::Compare(next_iter, op::Constant()),
                                        accumulated, accumulated);
  const HloInstruction* body_root = loop->while_body()->root_instruction();
  EXPECT_THAT(body_root,
              op::Tuple(carried, carried, maybe_permuted, carried, next_iter));

  // The conditional's true branch collective-permutes its input; the false
  // branch forwards it unchanged.
  const HloInstruction* permute_cond = body_root->operand(2);
  EXPECT_THAT(permute_cond->true_computation()->root_instruction(),
              op::CollectivePermute(op::Parameter(0)));
  EXPECT_THAT(permute_cond->false_computation()->root_instruction(),
              op::Parameter(0));
}
4252 
TEST_F(SpmdPartitioningTest,EinsumLHSWindowedInContractingOutNonContractingPartitioned2)4253 TEST_F(SpmdPartitioningTest,
4254        EinsumLHSWindowedInContractingOutNonContractingPartitioned2) {
4255   absl::string_view hlo_string = R"(
4256 HloModule module
4257 
4258 ENTRY entry {
4259   %lhs = f32[16,1024,16384] parameter(0)
4260   %lhs.copy = f32[16,1024,16384] copy(%lhs),
4261     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
4262   %rhs = f32[16384,2,33,128] parameter(1)
4263   %rhs.copy = f32[16384,2,33,128] copy(%rhs),
4264     sharding={devices=[4,1,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
4265   ROOT %dot = f32[16,1024,2,33,128] dot(%lhs.copy, %rhs.copy),
4266     lhs_batch_dims={}, rhs_batch_dims={},
4267     lhs_contracting_dims={2}, rhs_contracting_dims={0},
4268     sharding={devices=[2,1,2,2,1]0,1,2,3,4,5,6,7}
4269 })";
4270 
4271   TF_ASSERT_OK_AND_ASSIGN(auto module,
4272                           PartitionComputation(hlo_string, /*num_devices=*/8));
4273   VLOG(1) << module->ToString();
4274 
4275   auto root = module->entry_computation()->root_instruction();
4276   auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
4277                                              op::Constant(), op::Reshape())),
4278                    op::Shape("f32[8,1024,4096]"));
4279   auto rhs = AllOf(
4280       op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Constant(),
4281                                 op::Constant(), op::Constant())),
4282       op::Shape("f32[4096,2,33,128]"));
4283   EXPECT_THAT(
4284       root,
4285       AllOf(op::GetTupleElement(op::While(op::Tuple(
4286                 lhs, rhs, op::Broadcast(), op::Broadcast(), op::Constant()))),
4287             op::Shape("f32[8,1024,1,17,128]")));
4288 
4289   auto while_loop = root->operand(0);
4290   // Check loop condition.
4291   EXPECT_THAT(
4292       while_loop->while_condition()->root_instruction(),
4293       op::Compare(op::GetTupleElement(op::Parameter(0)), op::Constant()));
4294 
4295   // Check loop body.
4296   auto next_i = op::Add(op::GetTupleElement(op::Parameter(0)), op::Constant());
4297   auto ds =
4298       AllOf(op::DynamicSlice(
4299                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
4300                 op::Constant(), op::Reshape(), op::Reshape(), op::Constant()),
4301             op::Shape("f32[4096,1,17,128]"));
4302   auto partial_output =
4303       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
4304                     op::Dot(op::GetTupleElement(op::Parameter(0)), ds)),
4305             op::Shape("f32[8,1024,1,17,128]"));
4306   auto window = op::Conditional(op::Compare(next_i, op::Constant()),
4307                                 partial_output, partial_output);
4308   EXPECT_THAT(while_loop->while_body()->root_instruction(),
4309               op::Tuple(op::GetTupleElement(op::Parameter(0)),
4310                         op::GetTupleElement(op::Parameter(0)), window,
4311                         op::GetTupleElement(op::Parameter(0)), next_i));
4312 
4313   // Check the conditional that contains the collective permute.
4314   auto cp_conditional =
4315       while_loop->while_body()->root_instruction()->operand(2);
4316   EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
4317               op::CollectivePermute(op::Parameter(0)));
4318   EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
4319               op::Parameter(0));
4320 }
4321 
// Both operands and the output are tiled 2 ways along dimension 1. For the
// RHS this is a non-contracting dimension of odd size (39295), so the RHS is
// padded before slicing and the loop result is sliced back to 39295 columns.
// Inside the loop the RHS shard (tuple element 1) goes through a conditional
// that may collective-permute it, while partial dots are written into the
// accumulator with dynamic-update-slice.
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContracting) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,39295,64,128] parameter(1)
  %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[32,12,64,128]"));
  auto rhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
                                      op::Constant(), op::Reshape(),
                                      op::Constant(), op::Constant())),
            op::Shape("f32[32,19648,64,128]"));
  auto loop_result = AllOf(op::GetTupleElement(op::While(
                               op::Tuple(lhs_shard, rhs_shard, op::Broadcast(),
                                         op::Broadcast(), op::Constant()))),
                           op::Shape("f32[32,12,39296]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Slice(loop_result), op::Shape("f32[32,12,39295]")));

  // root = slice(gte(while)), so the loop is two operands down.
  const HloInstruction* loop = root->operand(0)->operand(0);
  // Inside the loop, every carried value is read as gte(parameter(0)).
  auto carried = op::GetTupleElement(op::Parameter(0));

  // Condition: a carried value compared against a constant bound.
  EXPECT_THAT(loop->while_condition()->root_instruction(),
              op::Compare(carried, op::Constant()));

  // Body.
  auto next_iter = op::Add(carried, op::Constant());
  auto maybe_permuted = op::Conditional(op::Compare(next_iter, op::Constant()),
                                        carried, carried);
  auto partial_dot = op::Dot(carried, carried);
  auto written = op::DynamicUpdateSlice(carried, partial_dot, op::Constant(),
                                        op::Constant(), op::Reshape());
  const HloInstruction* body_root = loop->while_body()->root_instruction();
  EXPECT_THAT(body_root,
              op::Tuple(carried, maybe_permuted, written, carried, next_iter));

  // The conditional's true branch collective-permutes its input; the false
  // branch forwards it unchanged.
  const HloInstruction* permute_cond = body_root->operand(1);
  EXPECT_THAT(permute_cond->true_computation()->root_instruction(),
              op::CollectivePermute(op::Parameter(0)));
  EXPECT_THAT(permute_cond->false_computation()->root_instruction(),
              op::Parameter(0));
}
4385 
// The LHS is tiled on non-contracting dimension 1, while the RHS is tiled on
// contracting dimension 2, which has odd size (63). The RHS shard is padded,
// sliced and then masked with a select before entering the loop. Inside the
// loop the RHS shard (tuple element 1) goes through a conditional that may
// collective-permute it, and partial dots are added into the accumulator.
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContracting) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,63,128] parameter(0)
  %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,39296,63,128] parameter(1)
  %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[32,12,63,128]"));
  auto rhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(1), op::Constant()),
                                      op::Constant(), op::Constant(),
                                      op::Reshape(), op::Constant())),
            op::Shape("f32[32,39296,32,128]"));
  // The padded RHS shard is selected against a broadcast constant.
  auto rhs_masked =
      op::Select(op::Compare(), rhs_shard, op::Broadcast(op::Constant()));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(op::Tuple(
                              lhs_shard, rhs_masked, op::Broadcast(),
                              op::Broadcast(), op::Constant()))),
                          op::Shape("f32[32,12,39296]")));

  const HloInstruction* loop = root->operand(0);
  // Inside the loop, every carried value is read as gte(parameter(0)).
  auto carried = op::GetTupleElement(op::Parameter(0));

  // Condition: a carried value compared against a constant bound.
  EXPECT_THAT(loop->while_condition()->root_instruction(),
              op::Compare(carried, op::Constant()));

  // Body: slice a window of the padded LHS along its contracting dimension,
  // dot it with the RHS shard, and add into the accumulator.
  auto next_iter = op::Add(carried, op::Constant());
  auto maybe_permuted = op::Conditional(op::Compare(next_iter, op::Constant()),
                                        carried, carried);
  auto lhs_window =
      op::DynamicSlice(op::Pad(carried, op::Constant()), op::Constant(),
                       op::Constant(), op::Reshape(), op::Constant());
  auto partial_dot = op::Dot(lhs_window, carried);
  const HloInstruction* body_root = loop->while_body()->root_instruction();
  EXPECT_THAT(body_root,
              op::Tuple(carried, maybe_permuted, op::Add(carried, partial_dot),
                        carried, next_iter));

  // The conditional's true branch collective-permutes its input; the false
  // branch forwards it unchanged.
  const HloInstruction* permute_cond = body_root->operand(1);
  EXPECT_THAT(permute_cond->true_computation()->root_instruction(),
              op::CollectivePermute(op::Parameter(0)));
  EXPECT_THAT(permute_cond->false_computation()->root_instruction(),
              op::Parameter(0));
}
4450 
// Windowed non-contracting einsum (odd-sized RHS dimension 1, tiled 2 ways)
// whose result is scaled and then reduced over dimension 2, keeping the
// dimension-1 tiling on the reduce output.
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce1) {
  absl::string_view hlo_text = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,39295,64,128] parameter(1)
  %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,2,1]0,1}
  %constant = f32[] constant(0)
  %constant.1 = f32[] constant(2)
  %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
    sharding={devices=[1,2,1]0,1}
  %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
  sharding={devices=[1,2,1]0,1}
  ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2},
    to_apply=sum, sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  // The result involves loop code motion, so no structural pattern is matched;
  // this test only requires partitioning to succeed.
}
4485 
// Windowed non-contracting einsum (odd-sized RHS dimension 1, tiled 2 ways)
// whose result is scaled and then reduced over the tiled dimension 1 into a
// replicated output.
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce2) {
  absl::string_view hlo_text = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,39295,64,128] parameter(1)
  %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,2,1]0,1}
  %constant = f32[] constant(0)
  %constant.1 = f32[] constant(2)
  %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
    sharding={devices=[1,2,1]0,1}
  %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
    sharding={devices=[1,2,1]0,1}
  ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1},
    to_apply=sum, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  // The result involves loop code motion, so no structural pattern is matched;
  // this test only requires partitioning to succeed.
}
4520 
// Windowed contracting einsum (RHS tiled on odd-sized contracting dimension 2)
// where the LHS is produced by a broadcast + add chain instead of a parameter.
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContractingFromBroadcast) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %rhs = f32[32,39296,63,128] parameter(0)
  %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
  %constant.1 = f32[] constant(2)
  %broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={},
    sharding={devices=[1,2,1,1]0,1}
  %add = f32[32,24,63,128] add(%broadcast, %broadcast),
    sharding={devices=[1,2,1,1]0,1}
  ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  // The result involves loop code motion, so no structural pattern is matched;
  // this test only requires partitioning to succeed.
}
4544 
// A fully replicated rng. The expected lowering runs the rng under a select
// keyed on PartitionId (a broadcast constant is chosen otherwise) and
// all-reduces the result.
TEST_F(SpmdPartitioningTest, ReplicatedRng) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = s32[] parameter(0)
  %lhs.copy = s32[] copy(%lhs), sharding={replicated}
  %rhs = s32[] parameter(1)
  %rhs.copy = s32[] copy(%rhs), sharding={replicated}
  ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy),
      distribution=rng_uniform, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  // The rng's operands are intentionally left unconstrained (op::Rng() with no
  // operand matchers). Only the select / all-reduce structure around it is
  // checked, so no operand matchers are built here.
  EXPECT_THAT(
      root,
      AllOf(op::AllReduce(op::Select(
                op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
                op::Rng(), op::Broadcast(op::Constant()))),
            op::Shape("s32[4]")));
}
4572 
// With manual sharding on the parameters and the rng, the rng is expected to
// survive partitioning unchanged, consuming the parameters directly.
TEST_F(SpmdPartitioningTest, ManualRng) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = s32[] parameter(0), sharding={manual}
  %rhs = s32[] parameter(1), sharding={manual}
  ROOT %rng = s32[4]{0} rng(%lhs, %rhs),
      distribution=rng_uniform, sharding={manual}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  auto untouched_rng = op::Rng(op::Parameter(0), op::Parameter(1));
  EXPECT_THAT(root, AllOf(untouched_rng, op::Shape("s32[4]")));
}
4592 
// The rng output is tiled 2 ways. Its first operand is replicated and its
// second is pinned to device 1. The pinned operand is expected to reach every
// partition through a select + all-reduce before feeding the local rng, which
// produces half of the output.
TEST_F(SpmdPartitioningTest, PartitionedRng) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = s32[] parameter(0)
  %lhs.copy = s32[] copy(%lhs), sharding={replicated}
  %rhs = s32[] parameter(1)
  %rhs.copy = s32[] copy(%rhs), sharding={maximal device=1}
  ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy),
      distribution=rng_uniform, sharding={devices=[2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_replicated = AllOf(op::Copy(op::Parameter(0)), op::Shape("s32[]"));
  auto rhs_pinned =
      AllOf(op::Copy(op::Copy(op::Parameter(1))), op::Shape("s32[]"));
  auto rhs_everywhere = op::AllReduce(op::Select(
      op::Broadcast(op::Compare()), rhs_pinned, op::Broadcast(op::Constant())));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Rng(lhs_replicated, rhs_everywhere),
                          op::Shape("s32[2]")));
}
4618 
// The rng output is tiled 2 ways with each tile replicated across 4 devices
// (last_tile_dim_replicate). The expected lowering selects the rng result
// based on a partition id looked up from a constant, then all-reduces it.
TEST_F(SpmdPartitioningTest, PartialReplicatedRng) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = s32[] parameter(0), sharding={replicated}
  %rhs = s32[] parameter(1), sharding={replicated}
  ROOT %rng = s32[8]{0} rng(%lhs, %rhs),
      distribution=rng_uniform,
      sharding={devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned, PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto lhs_param = AllOf(op::Parameter(0), op::Shape("s32[]"));
  auto rhs_param = AllOf(op::Parameter(1), op::Shape("s32[]"));
  // PartitionId indexed into a constant table and reshaped to a scalar.
  auto mapped_partition_id =
      AllOf(op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
            op::Shape("u32[]"));
  auto selected = op::Select(
      op::Broadcast(op::Compare(mapped_partition_id, op::Constant())),
      op::Rng(lhs_param, rhs_param), op::Broadcast(op::Constant()));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::AllReduce(selected), op::Shape("s32[4]")));
}
4647 
// A dynamic-slice along dimension 1, which is not partitioned, should run
// directly on each device's [64,64] row shard of the input and reuse the
// original start indices.
TEST_F(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = s32[128,64] parameter(0)
  %input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1}
  %index = s32[] parameter(1)
  %constant = s32[] constant(0)
  ROOT %dynamic-slice = s32[128,2] dynamic-slice(%input.copy, %constant, %index),
    dynamic_slice_sizes={128,2}, sharding={devices=[2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // Parameter 0 is cut down to the local row shard ahead of the sharded copy.
  auto operand_shard = AllOf(op::Copy(op::DynamicSlice(
                                 op::Parameter(0), op::Reshape(), op::Constant())),
                             op::Shape("s32[64,64]"));
  auto expected = AllOf(
      op::DynamicSlice(operand_shard, op::Constant(), op::Parameter(1)),
      op::Shape("s32[64,2]"));

  HloInstruction* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, expected);
}
4673 
// The operand and the update are both row-sharded, and the update lands along
// dimension 1, which is unpartitioned. The expected partitioned graph has no
// collectives: each device applies the dynamic-update-slice to its own shards
// using the original indices.
TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongNonPartitionedDimension) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = s32[128,64] parameter(0)
  %input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1}
  %index = s32[] parameter(1)
  %constant = s32[] constant(0)
  %update = s32[128,2] parameter(2)
  %update.copy = s32[128,2] copy(%update), sharding={devices=[2,1]0,1}
  ROOT %dynamic-update-slice = s32[128,64]
    dynamic-update-slice(%input.copy, %update.copy, %constant, %index),
    sharding={devices=[2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto operand_shard = op::Copy(
      op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Constant()));
  auto update_shard = op::Copy(
      op::DynamicSlice(op::Parameter(2), op::Reshape(), op::Constant()));
  auto expected = AllOf(
      op::DynamicUpdateSlice(AllOf(operand_shard, op::Shape("s32[64,64]")),
                             AllOf(update_shard, op::Shape("s32[64,2]")),
                             op::Constant(), op::Parameter(1)),
      op::Shape("s32[64,64]"));

  HloInstruction* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, expected);
}
4705 
// The update is written along the column-partitioned dimension 1. Expected
// lowering:
// - each device places its column shard of the update into a broadcast
//   buffer, and an all-reduce rebuilds the full [128,2] update;
// - the full update is inserted into the local [128,32] operand shard at a
//   select-computed column offset;
// - the result is chosen against the untouched shard by a broadcast
//   predicate.
TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongPartitionedDimension) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = s32[128,64] parameter(0)
  %input.copy = s32[128,64] copy(%input), sharding={devices=[1,2]0,1}
  %index = s32[] parameter(1)
  %constant = s32[] constant(60)
  %update = s32[128,2] parameter(2)
  %update.copy = s32[128,2] copy(%update), sharding={devices=[1,2]0,1}
  ROOT %dynamic-update-slice = s32[128,64]
    dynamic-update-slice(%input.copy, %update.copy, %index, %constant),
    sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto operand_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
                                      op::Reshape())),
            op::Shape("s32[128,32]"));
  auto local_update = op::Copy(
      op::DynamicSlice(op::Parameter(2), op::Constant(), op::Reshape()));
  auto gathered_update =
      AllOf(op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), local_update,
                                                 op::Constant(), op::Reshape())),
            op::Shape("s32[128,2]"));
  auto updated_shard = op::DynamicUpdateSlice(operand_shard, gathered_update,
                                              op::Parameter(1), op::Select());

  HloInstruction* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Select(op::Broadcast(), updated_shard, operand_shard),
                    op::Shape("s32[128,32]")));
}
4744 
// Dimension 0 is split 8 ways, so each device owns one [1,790,2] row of both
// the operand and the update, while the destination row comes from a runtime
// parameter. Expected lowering:
// - the update is masked with a select and all-reduced;
// - it is written at a select-computed row index;
// - the result is chosen against the original shard.
TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongPartitionedDimension2) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = s32[8,790,2] parameter(0),
    sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
  %index = s32[] parameter(1)
  %constant = s32[] constant(0)
  %update = s32[1,790,2] parameter(2),
    sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
  ROOT %dynamic-update-slice = s32[8,790,2]
    dynamic-update-slice(%input, %update, %index, %constant, %constant),
    sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto operand_shard = AllOf(op::Parameter(0), op::Shape("s32[1,790,2]"));
  auto shared_update =
      AllOf(op::AllReduce(op::Select(op::Broadcast(), op::Parameter(2),
                                     op::Broadcast())),
            op::Shape("s32[1,790,2]"));
  auto updated_shard =
      op::DynamicUpdateSlice(operand_shard, shared_update, op::Select(),
                             op::Constant(), op::Constant());

  HloInstruction* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Select(op::Broadcast(), updated_shard, operand_shard),
                    op::Shape("s32[1,790,2]")));
}
4778 
// Both dimensions are split 2x2, and the update goes to the constant column
// offset 60. Expected lowering:
// - the column-sharded update is rebuilt to [64,2] via a dynamic-update-slice
//   into a broadcast buffer plus an all-reduce;
// - it is inserted into the local [64,32] shard at a select-computed column;
// - the result is chosen against the untouched shard.
TEST_F(SpmdPartitioningTest, DynamicUpdateSlicePartitionSliceAndNonSliceDims) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = s32[128,64] parameter(0)
  %input.copy = s32[128,64] copy(%input), sharding={devices=[2,2]0,1,2,3}
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(60)
  %update = s32[128,2] parameter(1)
  %update.copy = s32[128,2] copy(%update), sharding={devices=[2,2]0,1,2,3}
  ROOT %dynamic-update-slice = s32[128,64]
    dynamic-update-slice(%input.copy, %update.copy, %constant.0, %constant.1),
    sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto operand_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                      op::Reshape())),
            op::Shape("s32[64,32]"));
  auto local_update = op::Copy(
      op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Reshape()));
  auto gathered_update =
      AllOf(op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), local_update,
                                                 op::Constant(), op::Reshape())),
            op::Shape("s32[64,2]"));
  auto updated_shard = op::DynamicUpdateSlice(operand_shard, gathered_update,
                                              op::Constant(), op::Select());

  HloInstruction* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Select(op::Broadcast(), updated_shard, operand_shard),
                    op::Shape("s32[64,32]")));
}
4817 
// The input is column-sharded and the gather output is sharded the same way.
// The partitioned gather should consume the parameters directly, with no
// resharding, and yield an f32[3,5] result per device.
TEST_F(SpmdPartitioningTest, PassthroughGather) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
  %indices = s32[3] parameter(1), sharding={replicated}
  ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
    slice_sizes={1,9}, sharding={devices=[1,2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const auto expected = AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
                              op::Shape("f32[3,5]"));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(), expected);
}
4836 
// Same as PassthroughGather, except the column tiling is also replicated 2 ways
// (last_tile_dim_replicate) across 4 devices. The gather should still pass
// straight through to the parameters.
TEST_F(SpmdPartitioningTest, PassthroughGather_PartialReplicate) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = f32[2,9] parameter(0),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  %indices = s32[3] parameter(1), sharding={replicated}
  ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
    slice_sizes={1,9}, sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const auto expected = AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
                              op::Shape("f32[3,5]"));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(), expected);
}
4856 
// The indices are sharded on their batch dimensions (0 and 2), and the output
// sharding follows those dimensions. The gather should run directly on the
// local indices and the replicated input, giving f32[8,2,2] per device.
TEST_F(SpmdPartitioningTest, IndexPassthroughGather) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = f32[2,9,8] parameter(0), sharding={replicated}
  %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
  ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1,
    slice_sizes={1,1,8}, sharding={devices=[1,2,2]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const auto expected = AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
                              op::Shape("f32[8,2,2]"));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(), expected);
}
4875 
// Partially replicated variant of IndexPassthroughGather on 8 devices. The
// index-batch tiling is additionally replicated 2 ways, and the gather should
// still pass through unchanged.
TEST_F(SpmdPartitioningTest, IndexPassthroughGather_PartialReplicate) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = f32[2,9,8] parameter(0), sharding={replicated}
  %indices = s32[4,2,4] parameter(1),
    sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1,
    slice_sizes={1,1,8},
    sharding={devices=[1,2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const auto expected = AllOf(op::Gather(op::Parameter(0), op::Parameter(1)),
                              op::Shape("f32[8,2,2]"));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(), expected);
}
4896 
// The input is split on dimension 0, which is gathered with slice size 1.
// Expected lowering per device:
// - the global indices are clamped into the device's [start, start + c] range
//   and rebased to local coordinates;
// - results whose indices fell outside that range are replaced with a
//   broadcast constant;
// - an all-reduce combines the per-device contributions.
TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
  %indices = s32[2,3] parameter(1), sharding={replicated}
  ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2},
    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
    slice_sizes={1,9}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  auto shard_start =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  auto lower = AllOf(op::Broadcast(shard_start), op::Shape("s32[2,3]"));
  auto upper = AllOf(op::Broadcast(op::Add(shard_start, op::Constant())),
                     op::Shape("s32[2,3]"));
  auto local_indices =
      op::Subtract(op::Clamp(lower, op::Parameter(1), upper), lower);
  auto out_of_range = op::Or(op::Lt(op::Parameter(1), lower),
                             op::Gt(op::Parameter(1), upper));
  auto masked_gather =
      op::Select(op::Broadcast(out_of_range), op::Broadcast(op::Constant()),
                 op::Gather(op::Parameter(0), local_indices));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::AllReduce(masked_gather), op::Shape("f32[2,3,9]")));
}
4925 
// Partially replicated variant of GatherPartitionedOnTrivialSliceDims (4
// devices, row tiling replicated 2 ways). The expected clamp / rebase / mask /
// all-reduce structure is identical.
TEST_F(SpmdPartitioningTest,
       GatherPartitionedOnTrivialSliceDims_PartialReplicate) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %input = f32[17,9] parameter(0),
    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
  %indices = s32[2,3] parameter(1), sharding={replicated}
  ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2},
    collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
    slice_sizes={1,9}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  auto shard_start =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  auto lower = AllOf(op::Broadcast(shard_start), op::Shape("s32[2,3]"));
  auto upper = AllOf(op::Broadcast(op::Add(shard_start, op::Constant())),
                     op::Shape("s32[2,3]"));
  auto local_indices =
      op::Subtract(op::Clamp(lower, op::Parameter(1), upper), lower);
  auto out_of_range = op::Or(op::Lt(op::Parameter(1), lower),
                             op::Gt(op::Parameter(1), upper));
  auto masked_gather =
      op::Select(op::Broadcast(out_of_range), op::Broadcast(op::Constant()),
                 op::Gather(op::Parameter(0), local_indices));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::AllReduce(masked_gather), op::Shape("f32[2,3,9]")));
}
4956 
// The operand, the updates and the result share the same column sharding on
// the window dimension. The scatter should apply directly to the parameters,
// producing f32[2,5] per device.
TEST_F(SpmdPartitioningTest, PassthroughScatter) {
  const absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
  %indices = s32[3] parameter(1), sharding={replicated}
  %updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1}
  ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1, sharding={devices=[1,2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const auto local_scatter =
      op::Scatter(op::Parameter(0), op::Parameter(1), op::Parameter(2));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(local_scatter, op::Shape("f32[2,5]")));
}
4986 
// Partially replicated variant of PassthroughScatter on 4 devices. The column
// tiling is replicated 2 ways, and the scatter should still pass through.
TEST_F(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) {
  const absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %input = f32[2,9] parameter(0),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  %indices = s32[3] parameter(1), sharding={replicated}
  %updates = f32[3,9] parameter(2),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1,
      sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const auto local_scatter =
      op::Scatter(op::Parameter(0), op::Parameter(1), op::Parameter(2));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(local_scatter, op::Shape("f32[2,5]")));
}
5019 
// The indices and updates are sharded on their scatter batch dimensions, and
// the result is replicated. Expected lowering:
// - the replicated operand is replaced by a broadcast constant on devices
//   chosen by a predicate derived from the partition id;
// - each device scatters its local indices and updates;
// - two all-reduces combine the results.
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter) {
  const absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %input = f32[2,9,8] parameter(0), sharding={replicated}
  %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
  %updates = f32[4,4,8] parameter(2), sharding={devices=[2,2,1]0,1,2,3}
  ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={2},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  auto masked_operand =
      op::Select(op::Broadcast(op::Convert(op::PartitionId())),
                 op::Broadcast(op::Constant()), op::Parameter(0));
  auto local_scatter =
      op::Scatter(masked_operand, op::Parameter(1), op::Parameter(2));
  EXPECT_THAT(
      partitioned->entry_computation()->root_instruction(),
      AllOf(op::AllReduce(op::AllReduce(local_scatter)),
            op::Shape("f32[2,9,8]")));
}
5053 
// Partially replicated variant of IndexPassthroughScatter on 8 devices. Here
// the operand-masking predicate is built from a reshape rather than from the
// raw partition id.
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) {
  const absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %input = f32[2,9,8] parameter(0), sharding={replicated}
  %indices = s32[4,2,4] parameter(1),
    sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %updates = f32[4,4,8] parameter(2),
    sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={2},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  auto masked_operand =
      op::Select(op::Broadcast(op::Convert(op::Reshape())),
                 op::Broadcast(op::Constant()), op::Parameter(0));
  auto local_scatter =
      op::Scatter(masked_operand, op::Parameter(1), op::Parameter(2));
  EXPECT_THAT(
      partitioned->entry_computation()->root_instruction(),
      AllOf(op::AllReduce(op::AllReduce(local_scatter)),
            op::Shape("f32[2,9,8]")));
}
5089 
// Same sharding setup as IndexPassthroughScatter, but the reduction is
// `minimum`. The test expects the same select / scatter / double all-reduce
// structure.
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_Min) {
  const absl::string_view hlo_text = R"(
HloModule module

min (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT min = f32[] minimum(lhs, rhs)
}

ENTRY entry {
  %input = f32[2,9,8] parameter(0), sharding={replicated}
  %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
  %updates = f32[4,4,8] parameter(2), sharding={devices=[2,2,1]0,1,2,3}
  ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
      to_apply=min,
      update_window_dims={2},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  auto masked_operand =
      op::Select(op::Broadcast(op::Convert(op::PartitionId())),
                 op::Broadcast(op::Constant()), op::Parameter(0));
  auto local_scatter =
      op::Scatter(masked_operand, op::Parameter(1), op::Parameter(2));
  EXPECT_THAT(
      partitioned->entry_computation()->root_instruction(),
      AllOf(op::AllReduce(op::AllReduce(local_scatter)),
            op::Shape("f32[2,9,8]")));
}
5123 
// The operand is row-split on the inserted (size-1) window dimension, while
// the indices and updates are replicated. Each device should rebase the global
// indices by its shard start and scatter into its local f32[9,9] shard.
TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) {
  const absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
  %indices = s32[2,3] parameter(1), sharding={replicated}
  %updates = f32[2,3,9] parameter(2), sharding={replicated}
  ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={2},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=2, sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  auto shard_start =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  auto rebased_indices =
      op::Subtract(op::Parameter(1), AllOf(op::Broadcast(shard_start),
                                           op::Shape("s32[2,3]")));
  EXPECT_THAT(
      partitioned->entry_computation()->root_instruction(),
      AllOf(op::Scatter(op::Parameter(0), rebased_indices, op::Parameter(2)),
            op::Shape("f32[9,9]")));
}
5157 
// Partially replicated variant of ScatterPartitionedOnTrivialSliceDims on 4
// devices. The expected index rebasing and local scatter are unchanged.
TEST_F(SpmdPartitioningTest,
       ScatterPartitionedOnTrivialSliceDims_PartialReplicate) {
  const absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
  %input = f32[17,9] parameter(0),
    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
  %indices = s32[2,3] parameter(1), sharding={replicated}
  %updates = f32[2,3,9] parameter(2), sharding={replicated}
  ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
      to_apply=add,
      update_window_dims={2},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=2,
      sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  auto shard_start =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  auto rebased_indices =
      op::Subtract(op::Parameter(1), AllOf(op::Broadcast(shard_start),
                                           op::Shape("s32[2,3]")));
  EXPECT_THAT(
      partitioned->entry_computation()->root_instruction(),
      AllOf(op::Scatter(op::Parameter(0), rebased_indices, op::Parameter(2)),
            op::Shape("f32[9,9]")));
}
5194 
// Reversing dimension 1, which is unpartitioned, should need no cross-device
// data movement. Each device reverses its own [2,3] row shard, which is taken
// from the padded constant by a dynamic-slice.
TEST_F(SpmdPartitioningTest, TiledReversePassthrough) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  ROOT reverse = f32[3,3]{1,0} reverse(constant), dimensions={1},
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  auto row_shard = op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
                                    op::Reshape(), op::Constant());
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[2,3]{1,0}"), op::Reverse(row_shard)));
}
5214 
// The output device order [2]1,0 mirrors the input's [2]0,1. Reversing the
// local shard should therefore already produce the right output shard, with no
// collective.
TEST_F(SpmdPartitioningTest, TiledReversePassthroughViaReversedSharding) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  param = f32[4] parameter(0), sharding={devices=[2]0,1}
  ROOT reverse = f32[4] reverse(param), dimensions={0},
    sharding={devices=[2]1,0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const auto expected =
      AllOf(op::Shape("f32[2]"), op::Reverse(op::Parameter(0)));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(), expected);
}
5230 
// The input and output use the same [2]0,1 sharding. The two shards are
// expected to be exchanged with a collective-permute before the local reverse.
TEST_F(SpmdPartitioningTest, TiledReverseSwapShards) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  param = f32[4] parameter(0), sharding={devices=[2]0,1}
  ROOT reverse = f32[4] reverse(param), dimensions={0},
    sharding={devices=[2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  auto swapped = op::CollectivePermute(op::Parameter(0));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[2]"), op::Reverse(swapped)));
}
5248 
// A size-3 dimension split across 2 devices with mirrored output sharding.
// Expected lowering:
// - a one-element halo is sliced from the local shard and received via
//   collective-permute;
// - the halo is concatenated with another slice of the local shard;
// - the concatenation is reversed.
TEST_F(SpmdPartitioningTest, TiledReverseHaloExchange) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  param = f32[3] parameter(0), sharding={devices=[2]0,1}
  ROOT reverse = f32[3] reverse(param), dimensions={0},
    sharding={devices=[2]1,0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  auto halo = AllOf(op::Shape("f32[1]"),
                    op::CollectivePermute(op::Slice(op::Parameter(0))));
  auto with_halo = op::Concatenate(halo, op::Slice(op::Parameter(0)));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[2]"), op::Reverse(with_halo)));
}
5269 
// Mixes automatically partitioned ops with a manually partitioned region
// delimited by SPMDFullToShardShape / SPMDShardToFullShape custom calls. The
// automatic side uses [2,1] tiling, and the conversions are expected to lower
// to copies of the local data.
TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  param = (f32[8,2], f32[4,2]) parameter(0), sharding={{devices=[2,1]0,1},{manual}}
  param0 = f32[8,2] get-tuple-element(param), index=0, sharding={devices=[2,1]0,1}
  param1 = f32[4,2] get-tuple-element(param), index=1, sharding={manual}
  to_shard = f32[4,2] custom-call(param0), custom_call_target="SPMDFullToShardShape", sharding={manual}
  add = f32[4,2] add(to_shard, param1), sharding={manual}
  to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1}
  mul = f32[8,2] multiply(to_full, param0), sharding={devices=[2,1]0,1}
  to_shard2 = f32[4,2] custom-call(mul), custom_call_target="SPMDFullToShardShape", sharding={manual}
  ROOT tuple = (f32[4,2]) tuple(to_shard2), sharding={{manual}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  HloInstruction* root = module->entry_computation()->root_instruction();
  // Pin each get-tuple-element to its tuple index. Without the index, p0 and
  // p1 are the same matcher and would accept either tuple element in either
  // position, so swapped operands would go undetected.
  auto p0 = op::GetTupleElement(op::Parameter(0), 0);
  auto to_shard = op::Copy(p0);
  auto p1 = op::GetTupleElement(op::Parameter(0), 1);
  auto mul = AllOf(op::Shape("f32[4,2]"),
                   op::Multiply(op::Copy(op::Add(to_shard, p1)), p0));
  EXPECT_THAT(root, op::Tuple(op::Copy(mul)));
}
5296 
// Reshards a 4-D array from [2,2,1,2] to [1,2,2,2] tiling with a permuted
// device order. The expected lowering is a reshape / all-to-all / transpose /
// reshape sequence. Its all-to-all should use 4 replica groups, i.e. it runs
// within subgroups rather than across all 8 devices.
TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8,8,8] parameter(0),
    sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7}
  ROOT %copy = f32[8,8,8,8] copy(%param0),
    sharding={devices=[1,2,2,2]0,1,4,5,2,3,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  auto reshape =
      AllOf(op::Shape("f32[4,4,2,4,4]"), op::Reshape(op::Parameter(0)));
  auto all_to_all = AllOf(op::Shape("f32[4,4,2,4,4]"), op::AllToAll(reshape));
  auto xpose = AllOf(op::Shape("f32[2,4,4,4,4]"), op::Transpose(all_to_all));
  // Fatal assertion: the operand chain below is only safe to walk once the
  // structure is known to match. A non-fatal EXPECT_THAT would let the test go
  // on and potentially dereference operands that do not exist.
  ASSERT_THAT(root,
              op::Copy(AllOf(op::Reshape(xpose), op::Shape("f32[8,4,4,4]"))));
  // copy -> reshape -> transpose -> all-to-all
  const HloInstruction* all_to_all_instr =
      root->operand(0)->operand(0)->operand(0);
  EXPECT_EQ(all_to_all_instr->replica_groups().size(), 4);
}
5322 
// Reshards f32[8,8] from a [2,4] tiling to a [4,2] tiling with a permuted
// device order on 8 devices. Expected lowering:
// copy(collective-permute(reshape(transpose(all-to-all(reshape(param0)))))).
TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard2) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0),
    sharding={devices=[2,4]0,1,2,3,4,5,6,7}
  ROOT %copy = f32[8,8] copy(%param0),
    sharding={devices=[4,2]0,1,4,5,2,3,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  // Local shard is reshaped before the all-to-all exchange.
  auto pre_exchange =
      AllOf(op::Shape("f32[2,2,2]"), op::Reshape(op::Parameter(0)));
  // After the exchange the data is transposed and folded back to 2D.
  auto post_exchange =
      AllOf(op::Shape("f32[2,4]"),
            op::Reshape(op::Transpose(op::AllToAll(pre_exchange))));
  EXPECT_THAT(root, op::Copy(op::CollectivePermute(post_exchange)));
}
5345 
// Reshards f32[8,8,8] from a [2,4,1] tiling to a [1,2,4] tiling on 8
// devices. Two consecutive reshape/all-to-all/transpose/reshape stages are
// expected, followed by a collective-permute that feeds the copy.
TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard3) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8,8] parameter(0),
    sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
  ROOT %copy = f32[8,8,8] copy(%param0),
    sharding={devices=[1,2,4]0,1,4,5,2,3,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  // One exchange stage: reshape the input to f32[4,2,4,2], all-to-all,
  // transpose, then reshape. The caller attaches the output shape.
  auto exchange_stage = [](auto input) {
    return op::Reshape(op::Transpose(op::AllToAll(
        AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(input)))));
  };

  auto first_stage =
      AllOf(op::Shape("f32[4,8,2]"), exchange_stage(op::Parameter(0)));
  auto second_stage =
      AllOf(op::Shape("f32[8,4,2]"), exchange_stage(first_stage));
  EXPECT_THAT(root, op::Copy(op::CollectivePermute(second_stage)));
}
5372 
// Both dot operands are 2x2 tiled on non-contracting and contracting dims,
// with different device orders. Each operand is expected to be gathered
// along its contracting dim (dynamic-update-slice + all-reduce) before a
// local dot producing f32[24,16].
TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting0) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[48,12] parameter(0), sharding={devices=[2,2]0,1,2,3}
  %rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,2,1,3}
  ROOT %dot = f32[48,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto lhs_shard = AllOf(op::Shape("f32[24,6]"), op::Parameter(0));
  auto rhs_shard = AllOf(op::Shape("f32[16,6]"), op::Parameter(1));
  auto lhs_gathered =
      AllOf(op::Shape("f32[24,12]"),
            op::AllReduce(op::DynamicUpdateSlice(_, lhs_shard, _, _)));
  auto rhs_gathered =
      AllOf(op::Shape("f32[16,12]"),
            op::AllReduce(op::DynamicUpdateSlice(_, rhs_shard, _, _)));
  EXPECT_THAT(root, AllOf(op::Dot(lhs_gathered, rhs_gathered),
                          op::Shape("f32[24,16]")));
}
5403 
// Both operands and the output share the [2,2]0,1,2,3 tiling. Only rhs is
// gathered along its non-contracting dim. The partial dot result is
// all-reduced and then dynamic-sliced down to the output shard.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting1) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[48,100] parameter(0), sharding={devices=[2,2]0,1,2,3}
  %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3}
  ROOT %dot = f32[48,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto lhs_shard = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
  auto rhs_gathered = AllOf(
      op::Shape("f32[32,50]"),
      op::AllReduce(op::DynamicUpdateSlice(
          _, AllOf(op::Shape("f32[16,50]"), op::Parameter(1)), _, _)));
  auto partial_dot =
      AllOf(op::Dot(lhs_shard, rhs_gathered), op::Shape("f32[24,32]"));
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[24,16]"),
                    op::DynamicSlice(op::AllReduce(partial_dot), _, _)));
}
5434 
// lhs is replicated and rhs is 2x2 tiled. lhs is expected to be
// dynamic-sliced locally. rhs is collective-permuted and then gathered along
// its contracting dim before the local dot.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting2) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[48,100] parameter(0), sharding={replicated}
  %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3}
  ROOT %dot = f32[48,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto full_lhs = AllOf(op::Shape("f32[48,100]"), op::Parameter(0));
  auto sliced_lhs =
      AllOf(op::Shape("f32[24,100]"), op::DynamicSlice(full_lhs, _, _));
  auto rhs_shard = AllOf(op::Shape("f32[16,50]"), op::Parameter(1));
  auto rhs_gathered =
      AllOf(op::Shape("f32[16,100]"),
            op::AllReduce(op::DynamicUpdateSlice(
                _, op::CollectivePermute(rhs_shard), _, _)));
  EXPECT_THAT(root, AllOf(op::Shape("f32[24,16]"),
                          op::Dot(sliced_lhs, rhs_gathered)));
}
5462 
// Uneven contracting dim (23 split two ways) with lhs partially replicated.
// Both operands are expected to be masked with select(_, x, broadcast(c)),
// and rhs is additionally collective-permuted. The dot result is all-reduced
// and then dynamic-sliced to f32[12,16].
TEST_F(SpmdPartitioningTest, Dot2DPartitionedNoncontractingAndContracting3) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[23,24] parameter(0), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[23,32] parameter(1), sharding={devices=[2,2]0,1,2,3}
  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
    lhs_contracting_dims={0}, rhs_contracting_dims={0},
    sharding={devices=[2,2]1,0,3,2}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto padded_lhs = op::Select(
      _, AllOf(op::Shape("f32[12,24]"), op::Parameter(0)),
      op::Broadcast(op::Constant()));
  auto padded_rhs = op::Select(
      _, AllOf(op::Shape("f32[12,16]"), op::Parameter(1)),
      op::Broadcast(op::Constant()));
  auto reduced_dot =
      AllOf(op::Shape("f32[24,16]"),
            op::AllReduce(op::Dot(padded_lhs, op::CollectivePermute(padded_rhs))));
  EXPECT_THAT(root, AllOf(op::Shape("f32[12,16]"),
                          op::DynamicSlice(reduced_dot, _, _)));
}
5493 
// Operands and output are tiled on batch and non-contracting dims. rhs is
// expected to be gathered along its non-contracting dim before the local
// batched dot.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
  %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
  ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto rhs_gathered = AllOf(
      op::Shape("f32[2,32,100]"),
      op::AllReduce(op::DynamicUpdateSlice(
          _, AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1)), _, _, _)));
  auto lhs_shard = AllOf(op::Shape("f32[2,12,100]"), op::Parameter(0));
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"),
                          op::Dot(lhs_shard, rhs_gathered)));
}
5520 
// lhs is tiled on batch and contracting dims, and rhs on non-contracting and
// contracting dims. rhs is expected to be resharded by an all-to-all.
// The dot result is all-reduced and then dynamic-sliced to the output
// tiling.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndContracting) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
  %rhs = f32[4,32,100] parameter(1), sharding={devices=[1,2,2]0,1,2,3}
  ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto lhs_shard = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0));
  auto rhs_shard = AllOf(op::Shape("f32[4,16,50]"), op::Parameter(1));
  // rhs goes through reshape -> all-to-all -> transpose -> reshape.
  auto rhs_resharded = AllOf(
      op::Shape("f32[2,32,50]"),
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs_shard)))));
  auto summed_dot = AllOf(op::Shape("f32[2,24,32]"),
                          op::AllReduce(op::Dot(lhs_shard, rhs_resharded)));
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"),
                          op::DynamicSlice(summed_dot, _, _, _)));
}
5550 
// lhs is tiled on batch and contracting dims, and rhs is replicated. lhs is
// expected to be resharded by an all-to-all, and rhs dynamic-sliced on its
// batch dim, before the local dot.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndContracting2) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
  %rhs = f32[4,32,100] parameter(1), sharding={replicated}
  ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto rhs_full = AllOf(op::Shape("f32[4,32,100]"), op::Parameter(1));
  auto rhs_local =
      AllOf(op::Shape("f32[2,32,100]"), op::DynamicSlice(rhs_full, _, _, _));
  auto lhs_shard = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0));
  auto lhs_resharded = AllOf(
      op::Shape("f32[2,12,100]"),
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs_shard)))));
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"),
                          op::Dot(lhs_resharded, rhs_local)));
}
5579 
// lhs is tiled on batch and contracting dims, and rhs on batch and
// non-contracting dims. lhs is expected to be gathered along its
// contracting dim before a local dot that already matches the output tiling.
TEST_F(SpmdPartitioningTest,
       Dot2DPartitionedBatchNonContractingAndContracting) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
  %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
  ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto rhs_shard = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1));
  auto lhs_gathered = AllOf(
      op::Shape("f32[2,24,100]"),
      op::AllReduce(op::DynamicUpdateSlice(
          _, AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0)), _, _, _)));
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,24,16]"),
                          op::Dot(lhs_gathered, rhs_shard)));
}
5607 
// Operands are tiled on batch dim 0 and a non-contracting dim, but the
// output is tiled on batch dim 1. After the local dot (with rhs gathered),
// the result is expected to be resharded via reshape -> all-to-all ->
// transpose -> reshape.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndReshard) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,8,24,100] parameter(0), sharding={devices=[2,1,2,1]0,1,2,3}
  %rhs = f32[4,8,32,100] parameter(1), sharding={devices=[2,1,2,1]0,1,2,3}
  ROOT %dot = f32[4,8,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
    lhs_contracting_dims={3}, rhs_contracting_dims={3},
    sharding={devices=[1,2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto lhs_shard = AllOf(op::Shape("f32[2,8,12,100]"), op::Parameter(0));
  auto rhs_gathered = AllOf(
      op::Shape("f32[2,8,32,100]"),
      op::AllReduce(op::DynamicUpdateSlice(
          _, AllOf(op::Shape("f32[2,8,16,100]"), op::Parameter(1)), _, _, _,
          _)));
  auto local_dot = AllOf(op::Shape("f32[2,8,12,32]"),
                         op::Dot(lhs_shard, rhs_gathered));

  // All three resharding steps keep the same 5D intermediate shape.
  auto split = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Reshape(local_dot));
  auto exchanged = AllOf(op::Shape("f32[2,2,4,12,32]"), op::AllToAll(split));
  auto permuted =
      AllOf(op::Shape("f32[2,2,4,12,32]"), op::Transpose(exchanged));
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[4,4,12,32]"), op::Reshape(permuted)));
}
5638 
// Operands and output share the same partially replicated batch tiling, so
// the partitioned graph is expected to be a bare local dot with no
// collectives.
TEST_F(SpmdPartitioningTest, SimpleDotPartial) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[2,24,100] parameter(0),
    sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[2,32,100] parameter(1),
    sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %dot = f32[2,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[1,24,32]"),
                    op::Dot(AllOf(op::Shape("f32[1,24,100]"), op::Parameter(0)),
                            AllOf(op::Shape("f32[1,32,100]"),
                                  op::Parameter(1)))));
}
5664 
// Operands are partially tiled on the contracting dim and the output is
// replicated. Expected lowering: a local dot followed by an all-reduce.
TEST_F(SpmdPartitioningTest, DotPartialContracting) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,100] parameter(0),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[32,100] parameter(1),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto partial_sum = AllOf(
      op::Shape("f32[24,32]"),
      op::Dot(AllOf(op::Shape("f32[24,50]"), op::Parameter(0)),
              AllOf(op::Shape("f32[32,50]"), op::Parameter(1))));
  EXPECT_THAT(root, op::AllReduce(partial_sum));
}
5690 
// Same operands as DotPartialContracting, but the output is partially tiled
// on dim 0. lhs is expected to be dynamic-sliced to f32[12,50] before the
// local dot, and the result is all-reduced.
TEST_F(SpmdPartitioningTest, DotPartialContracting2) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,100] parameter(0),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[32,100] parameter(1),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto lhs_shard = AllOf(op::Shape("f32[24,50]"), op::Parameter(0));
  auto lhs_rows =
      AllOf(op::Shape("f32[12,50]"), op::DynamicSlice(lhs_shard, _, _));
  auto rhs_shard = AllOf(op::Shape("f32[32,50]"), op::Parameter(1));
  auto partial_sum =
      AllOf(op::Shape("f32[12,32]"), op::Dot(lhs_rows, rhs_shard));
  EXPECT_THAT(root, op::AllReduce(partial_sum));
}
5719 
// 8-device variant in which the output keeps the operands' partial
// contracting-dim tiling. rhs is expected to be dynamic-sliced on its
// non-contracting dim. The local dot is all-reduced and then
// collective-permuted.
TEST_F(SpmdPartitioningTest, DotPartialContracting3) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,100] parameter(0),
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %rhs = f32[32,100] parameter(1),
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto rhs_cols =
      AllOf(op::Shape("f32[16,50]"), op::DynamicSlice(op::Parameter(1), _, _));
  auto local_dot =
      AllOf(op::Shape("f32[24,16]"),
            op::Dot(AllOf(op::Shape("f32[24,50]"), op::Parameter(0)),
                    rhs_cols));
  EXPECT_THAT(root, op::CollectivePermute(op::AllReduce(local_dot)));
}
5746 
// lhs is fully tiled (batch, non-contracting, contracting), and rhs is
// partially tiled on batch and contracting dims. Expected lowering: a local
// dot followed by an all-reduce.
TEST_F(SpmdPartitioningTest, DotBatchAndPartialContracting) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,24,100] parameter(0),
    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7}
  %rhs = f32[4,32,100] parameter(1),
    sharding={devices=[2,1,2,2]0,2,1,3,4,6,5,7 last_tile_dim_replicate}
  ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto partial_sum = AllOf(
      op::Shape("f32[2,12,32]"),
      op::Dot(AllOf(op::Shape("f32[2,12,50]"), op::Parameter(0)),
              AllOf(op::Shape("f32[2,32,50]"), op::Parameter(1))));
  EXPECT_THAT(root, op::AllReduce(partial_sum));
}
5772 
// lhs is partially tiled on a non-contracting dim, and rhs is 2x2 tiled.
// rhs is expected to be gathered along its contracting dim
// (dynamic-update-slice into a broadcast, then all-reduce). The local dot is
// the root.
TEST_F(SpmdPartitioningTest, DotPartialNonContracting) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,8,100] parameter(0),
    sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,2,1,3}
  ROOT %dot = f32[24,8,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={1},
    sharding={devices=[2,1,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto rhs_shard = AllOf(op::Shape("f32[16,50]"), op::Parameter(1));
  auto rhs_gathered = AllOf(
      op::Shape("f32[16,100]"),
      op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(_), rhs_shard, _, _)));
  auto lhs_shard = AllOf(op::Shape("f32[12,8,100]"), op::Parameter(0));
  EXPECT_THAT(root, AllOf(op::Shape("f32[12,8,16]"),
                          op::Dot(lhs_shard, rhs_gathered)));
}
5801 
// lhs is tiled on two non-contracting dims, and rhs is partially tiled on
// its non-contracting dim. lhs is expected to be gathered along dim 1
// (dynamic-update-slice into a broadcast, then all-reduce). The local dot is
// the root.
TEST_F(SpmdPartitioningTest, DotPartialNonContractingPartialMatch) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
  %rhs = f32[32,100] parameter(1),
    sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
  ROOT %dot = f32[24,8,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={1},
    sharding={devices=[2,1,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto lhs_shard = AllOf(op::Shape("f32[12,4,100]"), op::Parameter(0));
  auto lhs_gathered =
      AllOf(op::Shape("f32[12,8,100]"),
            op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(_), lhs_shard,
                                                 _, _, _)));
  auto rhs_shard = AllOf(op::Shape("f32[16,100]"), op::Parameter(1));
  EXPECT_THAT(root, AllOf(op::Shape("f32[12,8,16]"),
                          op::Dot(lhs_gathered, rhs_shard)));
}
5830 
// Two contracting dims. lhs is tiled on both, and rhs is partially tiled on
// only one. rhs is expected to be dynamic-sliced to match lhs. The local dot
// is then all-reduced twice.
TEST_F(SpmdPartitioningTest, DotPartialContractingPartialMatch) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,8,100] parameter(0), sharding={devices=[1,2,2]0,1,2,3}
  %rhs = f32[32,8,100] parameter(1),
    sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}
  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1,2}, rhs_contracting_dims={1,2},
    sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto rhs_shard = AllOf(op::Shape("f32[32,8,50]"), op::Parameter(1));
  auto rhs_matched =
      AllOf(op::Shape("f32[32,4,50]"), op::DynamicSlice(rhs_shard, _, _, _));
  auto lhs_shard = AllOf(op::Shape("f32[24,4,50]"), op::Parameter(0));
  auto local_dot =
      AllOf(op::Shape("f32[24,32]"), op::Dot(lhs_shard, rhs_matched));
  EXPECT_THAT(root, op::AllReduce(op::AllReduce(local_dot)));
}
5857 
// lhs is tiled on a non-contracting and the contracting dim, and rhs on the
// contracting and its non-contracting dim. rhs is expected to be gathered
// along its non-contracting dim. The dot result is all-reduced and then
// dynamic-sliced to the output tiling.
TEST_F(SpmdPartitioningTest, DotNonContractingPartialMatchContractingMatch) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
  %rhs = f32[100,50] parameter(1), sharding={devices=[2,2]0,2,1,3}
  ROOT %dot = f32[24,8,50] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={0},
    sharding={devices=[2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto rhs_shard = AllOf(op::Shape("f32[50,25]"), op::Parameter(1));
  auto rhs_gathered =
      AllOf(op::Shape("f32[50,50]"),
            op::AllReduce(op::DynamicUpdateSlice(_, rhs_shard, _, _)));
  auto lhs_shard = AllOf(op::Shape("f32[12,8,50]"), op::Parameter(0));
  auto partial_dot =
      AllOf(op::Shape("f32[12,8,50]"), op::Dot(lhs_shard, rhs_gathered));
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[12,4,50]"),
                    op::DynamicSlice(op::AllReduce(partial_dot), _, _, _)))
      << module->ToString();
}
5886 
// lhs is tiled on two non-contracting dims, and rhs is partially tiled on
// the contracting dim, which does not match lhs. rhs is expected to be
// gathered back to its full f32[10,50] before the local dot, which is the
// root.
TEST_F(SpmdPartitioningTest, DotLHSMutiNonContractingRHSNotMatch) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,8,10] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
  %rhs = f32[10,50] parameter(1),
    sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
  ROOT %dot = f32[24,8,50] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={0},
    sharding={devices=[2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  auto rhs_shard = AllOf(op::Shape("f32[5,50]"), op::Parameter(1));
  auto rhs_full =
      AllOf(op::Shape("f32[10,50]"),
            op::AllReduce(op::DynamicUpdateSlice(_, rhs_shard, _, _)));
  auto lhs_shard = AllOf(op::Shape("f32[12,4,10]"), op::Parameter(0));
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[12,4,50]"), op::Dot(lhs_shard, rhs_full)))
      << module->ToString();
}
5914 
// The multiply is tiled along dim 1 within manual subgroups, and the add is
// replicated within those subgroups. The tiled multiply result is expected
// to be gathered (dynamic-update-slice + all-reduce + slice) before the add.
TEST_F(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_TileToReplicate) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  constant = f32[6,3]{1,0}
    constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
  constant.1 = f32[6,3]{1,0}
    constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
   multiply = f32[6,3]{1,0} multiply(constant, constant.1),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
   ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
    sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated, manual}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  // Both multiply operands take the same form: a padded constant,
  // dynamic-sliced to the local f32[6,2] tile.
  auto tiled_constant =
      AllOf(op::Shape("f32[6,2]"),
            op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
                             op::Constant(), op::Reshape()));
  auto tiled_product = AllOf(op::Shape("f32[6,2]"),
                             op::Multiply(tiled_constant, tiled_constant));
  auto gathered_product =
      AllOf(op::Shape("f32[6,3]"),
            op::Slice(op::AllReduce(op::DynamicUpdateSlice(
                op::Broadcast(), tiled_product, op::Constant(),
                op::Reshape()))));
  EXPECT_THAT(root, AllOf(op::Shape("f32[6,3]"),
                          op::Add(gathered_product, op::Constant())));
}
5955 
// The multiply is replicated within manual subgroups, and the add is tiled
// along dim 1. Both add operands are expected to be padded and then
// dynamic-sliced to the local f32[6,2] tile.
TEST_F(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_ReplicateToTile) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  constant = f32[6,3]{1,0}
    constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
    sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
  constant.1 = f32[6,3]{1,0}
    constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
    sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
   multiply = f32[6,3]{1,0} multiply(constant, constant.1),
    sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
   ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  // Pads `operand` and dynamic-slices it to the local f32[6,2] tile.
  auto local_tile_of = [](auto operand) {
    return AllOf(op::Shape("f32[6,2]"),
                 op::DynamicSlice(op::Pad(operand, op::Constant()),
                                  op::Constant(), op::Reshape()));
  };

  auto full_product = AllOf(op::Shape("f32[6,3]"),
                            op::Multiply(op::Constant(), op::Constant()));
  EXPECT_THAT(root, AllOf(op::Shape("f32[6,2]"),
                          op::Add(local_tile_of(full_product),
                                  local_tile_of(op::Constant()))));
}
5989 
// The multiply is partially replicated ([2,1,2]) and the add is tiled
// [4,1], which does not evenly subdivide the 6 rows. The multiply result is
// expected to be extended with a one-row right halo (collective-permute of a
// slice), padded, and double dynamic-sliced to the f32[2,3] local tile.
TEST_F(SpmdPartitioningTest,
       ElementwiseTest_PartialReplicateToTiledHaloExchange) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  constant = f32[6,3]{1,0}
    constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
    sharding={replicated}
  constant.1 = f32[6,3]{1,0}
    constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
    sharding={replicated}
  multiply = f32[6,3]{1,0} multiply(constant, constant.1),
    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
    sharding={devices=[4,1]0,1,2,3}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();

  // Both multiply operands are the replicated constant sliced to 3 rows.
  auto half_constant =
      AllOf(op::Shape("f32[3,3]"),
            op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
  auto half_product = AllOf(op::Shape("f32[3,3]"),
                            op::Multiply(half_constant, half_constant));

  auto halo_row = AllOf(op::Shape("f32[1,3]"),
                        op::CollectivePermute(op::Slice(half_product)));
  auto with_halo =
      op::Pad(op::Concatenate(half_product, halo_row), op::Constant());
  auto lhs_tile = AllOf(
      op::Shape("f32[2,3]"),
      op::DynamicSlice(
          op::DynamicSlice(with_halo, op::Reshape(), op::Constant()),
          op::Subtract(), op::Subtract()));

  auto rhs_tile =
      AllOf(op::Shape("f32[2,3]"),
            op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
                             op::Reshape(), op::Constant()));

  EXPECT_THAT(root,
              AllOf(op::Shape("f32[2,3]"), op::Add(lhs_tile, rhs_tile)));
}
6036 
// Reshards a [2,2]-tiled copy into a [2,1,2] partially replicated layout.
// Expected lowering: each device's [4,4] tile is written into a broadcast
// buffer and all-reduced to form the [4,8] partially replicated shard.
TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshard) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0)
  %copy = f32[8,8] copy(%param0),
    sharding={devices=[2,2]0,1,2,3}
  ROOT %copy0 = f32[8,8] copy(%copy),
    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto param_tile =
      op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Reshape());
  auto local_tile = AllOf(op::Shape("f32[4,4]"), op::Copy(param_tile));
  auto assembled =
      op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(_), local_tile, _, _));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[4,8]"), op::Copy(assembled)));
}
6061 
// Reshards a [2,3]-tiled parameter to [1,2,3] partial replication on 6
// devices. The [4,3] local tile first goes through broadcast/update/all-reduce
// and a slice; an all-to-all wrapped in reshape/transpose/reshape then
// produces the [8,4] shard.
TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshardUnevenPartition) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0),
    sharding={devices=[2,3]0,1,2,3,4,5}
  ROOT %copy0 = f32[8,8] copy(%param0),
    sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/6));
  VLOG(1) << partitioned->ToString();

  auto input_tile = AllOf(op::Shape("f32[4,3]"), op::Parameter(0));
  // Stage 1: place the tile into a broadcast buffer, all-reduce, then slice.
  auto reduced_slice = op::Slice(op::AllReduce(
      op::DynamicUpdateSlice(op::Broadcast(), input_tile, _, _)));
  // Stage 2: all-to-all sandwiched between reshapes and a transpose.
  auto exchanged =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(reduced_slice))));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[8,4]"), op::Copy(exchanged)));
}
6085 
// Inverse of the uneven tile->partial-replicate case: the [8,4] partially
// replicated shard goes through an all-to-all (reshape/transpose/reshape),
// is padded, and is dynamic-sliced down to the [4,3] tile.
TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshardUnevenPartition) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0),
    sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
  ROOT %copy0 = f32[8,8] copy(%param0),
    sharding={devices=[2,3]0,1,2,3,4,5}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/6));
  VLOG(1) << partitioned->ToString();

  auto input_shard = AllOf(op::Shape("f32[8,4]"), op::Parameter(0));
  auto exchanged =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(input_shard))));
  auto trimmed = op::DynamicSlice(op::Pad(exchanged, _), _, _);

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[4,3]"), op::Copy(trimmed)));
}
6110 
// Reshards a [2,1,2] partially replicated copy into a [2,2] tiled copy. The
// intermediate copy holds a [4,8] row block; the matchers expect the final
// [4,4] tile to be a dynamic-slice of that block.
TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0)
  %copy = f32[8,8] copy(%param0),
    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %copy0 = f32[8,8] copy(%copy),
    sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto row_block = AllOf(
      op::Shape("f32[4,8]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                op::Constant())));
  auto tile_of_block =
      op::DynamicSlice(row_block, op::Subtract(), op::Subtract());

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[4,4]"), op::Copy(tile_of_block)));
}
6137 
// [2,2,2] -> [2,1,4] partial replication on 8 devices: dim 1 stops being
// split, so the [4,4] local piece is expected to be gathered into a [4,8]
// shard via dynamic-update-slice into a broadcast followed by an all-reduce.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshard_AllReduce) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0)
  %copy = f32[8,8] copy(param0),
    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %copy0 = f32[8,8] copy(%copy),
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));

  VLOG(1) << partitioned->ToString();

  auto source_piece = AllOf(
      op::Shape("f32[4,4]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                op::Reshape())));
  auto gathered = op::AllReduce(
      op::DynamicUpdateSlice(op::Broadcast(_), source_piece, _, _));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[4,8]"), op::Copy(gathered)));
}
6166 
// [2,1,4] -> [2,2,2] partial replication on 8 devices: dim 1 becomes split,
// which the matchers expect to be a dynamic-slice of the [4,8] local block.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshard_DynamicSlice) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0)
  %copy = f32[8,8] copy(%param0),
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %copy0 = f32[8,8] copy(%copy),
    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto source_block = AllOf(
      op::Shape("f32[4,8]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                op::Constant())));
  auto narrowed =
      op::DynamicSlice(source_block, op::Subtract(), op::Subtract());

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[4,4]"), op::Copy(narrowed)));
}
6194 
// [2,2,2] -> [1,2,4] partial replication on 8 devices. The [4,4] local piece
// is expected to be collective-permuted first, then gathered into an [8,4]
// shard via dynamic-update-slice into a broadcast plus all-reduce.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0)
  %copy = f32[8,8] copy(param0),
    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %copy0 = f32[8,8] copy(%copy),
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));

  VLOG(1) << partitioned->ToString();

  auto copied_piece = op::Copy(
      op::DynamicSlice(op::Parameter(0), op::Reshape(), op::Reshape()));
  auto permuted_piece =
      AllOf(op::Shape("f32[4,4]"), op::CollectivePermute(copied_piece));
  auto gathered = op::AllReduce(
      op::DynamicUpdateSlice(op::Broadcast(_), permuted_piece, _, _));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[8,4]"), op::Copy(gathered)));
}
6223 
// [1,2,4] -> [2,2,2] partial replication on 8 devices. The [8,4] local block
// is expected to be dynamic-sliced to [4,4] and then collective-permuted.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardCollectivePermute1) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0)
  %copy = f32[8,8] copy(%param0),
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %copy0 = f32[8,8] copy(%copy),
    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto column_block = AllOf(
      op::Shape("f32[8,4]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
                                op::Reshape())));
  auto moved_tile = op::CollectivePermute(
      op::DynamicSlice(column_block, op::Subtract(), op::Subtract()));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[4,4]"), op::Copy(moved_tile)));
}
6251 
// f32[6,3] resharded from [4,1,2] to [2,1,4] partial replication. Six rows do
// not divide evenly by four, so the source piece is taken from a padded
// parameter. A halo from a collective-permute is concatenated in front and
// sliced, and the result is gathered into [3,3] via update/all-reduce/slice.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardHaloExchange) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[6,3] parameter(0)
  %copy = f32[6,3] copy(param0),
    sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %copy0 = f32[6,3] copy(%copy),
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));

  VLOG(1) << partitioned->ToString();

  auto source_piece = AllOf(
      op::Shape("f32[2,3]"),
      op::Copy(op::DynamicSlice(op::Pad(op::Parameter(0), op::Constant()),
                                op::Reshape(), op::Constant())));
  auto halo = op::CollectivePermute(op::Slice(source_piece));
  auto windowed = AllOf(op::Shape("f32[2,3]"),
                        op::DynamicSlice(op::Concatenate(halo, source_piece),
                                         _, _));
  auto gathered = op::Slice(op::AllReduce(
      op::DynamicUpdateSlice(op::Broadcast(_), windowed, _, _)));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[3,3]"), op::Copy(gathered)));
}
6286 
// f32[6,3] resharded from [2,1,4] to [4,1,2] partial replication. The [3,3]
// source piece gets a collective-permuted halo appended, is padded and sliced
// to [4,3], and is then dynamic-sliced down to the [2,3] result.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardHaloExchange1) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[6,3] parameter(0)
  %copy = f32[6,3] copy(param0),
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %copy0 = f32[6,3] copy(%copy),
    sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));

  VLOG(1) << partitioned->ToString();

  auto source_piece = AllOf(
      op::Shape("f32[3,3]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                op::Constant())));
  auto halo = op::CollectivePermute(op::Slice(source_piece));
  auto extended =
      op::Pad(op::Concatenate(source_piece, halo), op::Constant());
  auto windowed =
      AllOf(op::Shape("f32[4,3]"), op::DynamicSlice(extended, _, _));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[2,3]"),
                    op::Copy(op::DynamicSlice(windowed, _, _))));
}
6320 
// Convolution with batch_group_count=1024 where lhs, rhs and output are all
// split 2-way on dim 3. Both operands are expected to be sliced locally to 512
// on that dim, and the convolution runs directly on the shards.
TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCount) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,1,1,2]0,1}
  %rhs = f32[16,801,1,1024] parameter(1)
  %rhs.copy = f32[16,801,1,1024] copy(%rhs),
    sharding={devices=[1,1,1,2]0,1}
  ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=f01b_i01o->01bf,batch_group_count=1024,
    window={size=801x1 pad=2_2x0_0},
    sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));

  VLOG(1) << partitioned->ToString();

  // lhs and rhs share the same shape and sharding, so one matcher fits both.
  auto operand_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                op::Constant(), op::Constant(),
                                op::Reshape())),
      op::Shape("f32[16,801,1,512]"));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Convolution(operand_shard, operand_shard),
                    op::Shape("f32[5,1,1,512]")));
}
6354 
// Batch-grouped convolution where rhs is split on dim 1 ([1,2,1,1]) while lhs
// and the output are split on dim 3. The rhs shard is expected to be moved
// onto dim-3 sharding by an all-to-all (reshape/transpose/reshape) and then
// sliced, so that it matches the lhs layout.
TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCountRHSAlignWithLHS) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,1,1,2]0,1}
  %rhs = f32[16,801,1,1024] parameter(1)
  %rhs.copy = f32[16,801,1,1024] copy(%rhs),
    sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=f01b_i01o->01bf,batch_group_count=1024,
    window={size=801x1 pad=2_2x0_0},
    sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                op::Constant(), op::Constant(),
                                op::Reshape())),
      op::Shape("f32[16,801,1,512]"));

  auto rhs_spatial_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                op::Constant(), op::Reshape(), op::Constant(),
                                op::Constant())),
      op::Shape("f32[16,401,1,1024]"));
  auto rhs_realigned = AllOf(
      op::Slice(op::Reshape(
          op::Transpose(op::AllToAll(op::Reshape(rhs_spatial_shard))))),
      op::Shape("f32[16,801,1,512]"));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Convolution(lhs_shard, rhs_realigned),
                    op::Shape("f32[5,1,1,512]")));
}
6390 
// Mirror of the RHS-align case: lhs is split on dim 1 and rhs on dim 3. The
// lhs shard is expected to be realigned to rhs via an all-to-all
// (reshape/transpose/reshape) followed by a slice.
TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCountLHSAlignWithRHS) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[16,801,1,1024] parameter(1)
  %rhs.copy = f32[16,801,1,1024] copy(%rhs),
    sharding={devices=[1,1,1,2]0,1}
  ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=f01b_i01o->01bf,batch_group_count=1024,
    window={size=801x1 pad=2_2x0_0},
    sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_spatial_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                op::Constant(), op::Reshape(), op::Constant(),
                                op::Constant())),
      op::Shape("f32[16,401,1,1024]"));
  auto lhs_realigned = AllOf(
      op::Slice(op::Reshape(
          op::Transpose(op::AllToAll(op::Reshape(lhs_spatial_shard))))),
      op::Shape("f32[16,801,1,512]"));

  auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                op::Constant(), op::Constant(),
                                op::Reshape())),
      op::Shape("f32[16,801,1,512]"));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Convolution(lhs_realigned, rhs_shard),
                    op::Shape("f32[5,1,1,512]")));
}
6426 
// Batch-grouped convolution whose operands are split on dim 3 but whose
// output asks for dim-0 sharding ([2,1,1,1]). The convolution is expected to
// run on dim-3 shards, and its [5,1,1,512] result is padded and resharded by
// an all-to-all to the [3,1,1,1024] output shard.
TEST_F(SpmdPartitioningTest,
       PartitionConvWithBathGroupCountOutputAlignWithLHS) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,1,1,2]0,1}
  %rhs = f32[16,801,1,1024] parameter(1)
  %rhs.copy = f32[16,801,1,1024] copy(%rhs),
    sharding={devices=[1,1,1,2]0,1}
  ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=f01b_i01o->01bf,batch_group_count=1024,
    window={size=801x1 pad=2_2x0_0},
    sharding={devices=[2,1,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // lhs and rhs are sliced identically, so a single matcher describes both.
  auto operand_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                op::Constant(), op::Constant(),
                                op::Reshape())),
      op::Shape("f32[16,801,1,512]"));
  auto local_conv = AllOf(op::Convolution(operand_shard, operand_shard),
                          op::Shape("f32[5,1,1,512]"));
  auto resharded_output = op::Reshape(op::Transpose(
      op::AllToAll(op::Reshape(op::Pad(local_conv, op::Constant())))));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(resharded_output, op::Shape("f32[3,1,1,1024]")));
}
6462 
// lhs is split on dim 1, rhs on dim 3, and the output wants dim-0 sharding.
// Expected: lhs is realigned to rhs via all-to-all + slice, the convolution
// runs on the shards, and the result is padded and all-to-all'd into the
// [3,1,1,1024] output shard.
TEST_F(SpmdPartitioningTest,
       PartitionConvWithBathGroupCountOutputAlignWithRHS) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[16,801,1,1024] parameter(1)
  %rhs.copy = f32[16,801,1,1024] copy(%rhs),
    sharding={devices=[1,1,1,2]0,1}
  ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=f01b_i01o->01bf,batch_group_count=1024,
    window={size=801x1 pad=2_2x0_0},
    sharding={devices=[2,1,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                op::Constant(), op::Constant(),
                                op::Reshape())),
      op::Shape("f32[16,801,1,512]"));

  auto lhs_spatial_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                op::Constant(), op::Reshape(), op::Constant(),
                                op::Constant())),
      op::Shape("f32[16,401,1,1024]"));
  auto lhs_realigned = AllOf(
      op::Slice(op::Reshape(
          op::Transpose(op::AllToAll(op::Reshape(lhs_spatial_shard))))),
      op::Shape("f32[16,801,1,512]"));

  auto local_conv = AllOf(op::Convolution(lhs_realigned, rhs_shard),
                          op::Shape("f32[5,1,1,512]"));
  auto resharded_output = op::Reshape(op::Transpose(
      op::AllToAll(op::Reshape(op::Pad(local_conv, op::Constant())))));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(resharded_output, op::Shape("f32[3,1,1,1024]")));
}
6502 
// Convolution with feature_group_count=1024 where lhs, rhs and output are all
// split 2-way on dim 3. Each operand is expected to be sliced locally on that
// dim and convolved shard-by-shard.
TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCount) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,1,1,2]0,1}
  %rhs = f32[5,1,1,1024] parameter(1)
  %rhs.copy = f32[5,1,1,1024] copy(%rhs),
    sharding={devices=[1,1,1,2]0,1}
  ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=b01f_01io->b01f,feature_group_count=1024,
    window={size=5x1 pad=2_2x0_0},
    sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // Both operands are produced by the same copy-of-dynamic-slice pattern and
  // differ only in shape.
  auto dim3_slice = op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                              op::Constant(), op::Constant(),
                                              op::Reshape()));
  auto lhs_shard = AllOf(dim3_slice, op::Shape("f32[16,801,1,512]"));
  auto rhs_shard = AllOf(dim3_slice, op::Shape("f32[5,1,1,512]"));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Convolution(lhs_shard, rhs_shard),
                    op::Shape("f32[16,801,1,512]")));
}
6535 
// Feature-grouped convolution on 32 devices. Both operands are 4-way split on
// dim 3 with 8-way replication, and the output is tiled [8,1,1,4] with a
// permuted device order. The lhs shard is expected to be further
// dynamic-sliced on dim 0 before the convolution.
TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCount2) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[64,3,1,3072] parameter(0)
  %lhs.copy = f32[64,3,1,3072] copy(%lhs),
    sharding={devices=[1,1,1,4,8]0,1,2,3,4,5,6,7,16,17,18,19,20,21,22,23,24,25
    ,26,27,28,29,30,31,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
  %rhs = f32[3,1,1,3072] parameter(1)
  %rhs.copy = f32[3,1,1,3072] copy(%rhs),
    sharding={devices=[1,1,1,4,8]0,1,2,3,4,5,6,7,16,17,18,19,20,21,22,23,24,25
    ,26,27,28,29,30,31,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
  ROOT %conv = f32[64,1,1,3072] convolution(%lhs.copy, %rhs.copy),
    dim_labels=b01f_01io->b01f,feature_group_count=3072,
    window={size=3x1},
    sharding={devices=[8,1,1,4]0,16,24,8,2,18,26,10,4,20,28,12,6,22,30,14,7,23,
    31,15,5,21,29,13,3,19,27,11,1,17,25,9}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/32));
  VLOG(1) << partitioned->ToString();

  auto lhs_feature_slice =
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                op::Constant(), op::Constant(),
                                op::Reshape()));
  auto lhs_shard = AllOf(
      op::DynamicSlice(lhs_feature_slice, op::Reshape(), op::Constant(),
                       op::Constant(), op::Constant()),
      op::Shape("f32[8,3,1,768]"));
  auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                op::Constant(), op::Constant(),
                                op::Reshape())),
      op::Shape("f32[3,1,1,768]"));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Convolution(lhs_shard, rhs_shard),
                    op::Shape("f32[8,1,1,768]")));
}
6574 
// Feature-grouped convolution where rhs is split on dim 0 ([2,1,1,1]) but lhs
// and the output are split on dim 3. The rhs shard is expected to be
// realigned to dim 3 via an all-to-all (reshape/transpose/reshape) plus a
// slice.
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountRHSAlignWithLHS) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,1,1,2]0,1}
  %rhs = f32[5,1,1,1024] parameter(1)
  %rhs.copy = f32[5,1,1,1024] copy(%rhs),
    sharding={devices=[2,1,1,1]0,1}
  ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=b01f_01io->b01f,feature_group_count=1024,
    window={size=5x1 pad=2_2x0_0},
    sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                op::Constant(), op::Constant(),
                                op::Reshape())),
      op::Shape("f32[16,801,1,512]"));

  auto rhs_dim0_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                op::Reshape(), op::Constant(), op::Constant(),
                                op::Constant())),
      op::Shape("f32[3,1,1,1024]"));
  auto rhs_realigned = AllOf(
      op::Slice(op::Reshape(
          op::Transpose(op::AllToAll(op::Reshape(rhs_dim0_shard))))),
      op::Shape("f32[5,1,1,512]"));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Convolution(lhs_shard, rhs_realigned),
                    op::Shape("f32[16,801,1,512]")));
}
6611 
// Feature-grouped convolution where lhs is split on dim 1 while rhs and the
// output are split on dim 3. The lhs shard is expected to be realigned via an
// all-to-all (reshape/transpose/reshape) plus a slice.
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountLHSAlignWithRHS) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[5,1,1,1024] parameter(1)
  %rhs.copy = f32[5,1,1,1024] copy(%rhs),
    sharding={devices=[1,1,1,2]0,1}
  ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=b01f_01io->b01f,feature_group_count=1024,
    window={size=5x1 pad=2_2x0_0},
    sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                op::Constant(), op::Constant(),
                                op::Reshape())),
      op::Shape("f32[5,1,1,512]"));

  auto lhs_spatial_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                op::Constant(), op::Reshape(), op::Constant(),
                                op::Constant())),
      op::Shape("f32[16,401,1,1024]"));
  auto lhs_realigned = AllOf(
      op::Slice(op::Reshape(
          op::Transpose(op::AllToAll(op::Reshape(lhs_spatial_shard))))),
      op::Shape("f32[16,801,1,512]"));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Convolution(lhs_realigned, rhs_shard),
                    op::Shape("f32[16,801,1,512]")));
}
6648 
// Feature-grouped convolution whose operands are split on dim 3 but whose
// output asks for dim-0 sharding ([2,1,1,1]). The convolution is expected to
// run on dim-3 shards, and its [16,801,1,512] result is resharded by an
// all-to-all (reshape/transpose/reshape) to [8,801,1,1024].
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountAlignOuputWithLHS) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,1,1,2]0,1}
  %rhs = f32[5,1,1,1024] parameter(1)
  %rhs.copy = f32[5,1,1,1024] copy(%rhs),
    sharding={devices=[1,1,1,2]0,1}
  ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=b01f_01io->b01f,feature_group_count=1024,
    window={size=5x1 pad=2_2x0_0},
    sharding={devices=[2,1,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // Both operands come from the same copy-of-dynamic-slice pattern.
  auto dim3_slice = op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                              op::Constant(), op::Constant(),
                                              op::Reshape()));
  auto lhs_shard = AllOf(dim3_slice, op::Shape("f32[16,801,1,512]"));
  auto rhs_shard = AllOf(dim3_slice, op::Shape("f32[5,1,1,512]"));
  auto local_conv = AllOf(op::Convolution(lhs_shard, rhs_shard),
                          op::Shape("f32[16,801,1,512]"));
  auto resharded_output =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(local_conv))));

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(resharded_output, op::Shape("f32[8,801,1,1024]")));
}
6684 
TEST_F(SpmdPartitioningTest,
       PartitionConvGroupOnFeatureGroupCount_RHSPartialReplicate) {
  // LHS and output are tiled on the spatial and feature dims; the RHS is tiled
  // on features and partially replicated (last_tile_dim_replicate). The local
  // convolution consumes a halo-exchanged LHS window.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,2,1,2]0,1,2,3}
  %rhs = f32[5,1,1,1024] parameter(1)
  %rhs.copy = f32[5,1,1,1024] copy(%rhs),
    sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate}
  ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=b01f_01io->b01f,feature_group_count=1024,
    window={size=5x1 pad=2_2x0_0},
    sharding={devices=[1,2,1,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto lhs_shard =
      AllOf(op::Shape("f32[16,401,1,512]"),
            op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape(),
                                      op::Constant(), op::Reshape())));
  // Left and right halos share one pattern: a slice of the local LHS shard
  // delivered by a collective-permute.
  auto halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
                    op::CollectivePermute(op::Slice(lhs_shard)));
  auto rhs_shard = AllOf(op::Shape("f32[5,1,1,512]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Constant(),
                             op::Constant(), op::Reshape())));

  auto lhs_window = op::Select(_, op::Concatenate(halo, lhs_shard, halo), _);
  EXPECT_THAT(root, AllOf(op::Shape("f32[16, 401, 1, 512]"),
                          op::Convolution(lhs_window, rhs_shard)));
}
6726 
TEST_F(SpmdPartitioningTest,
       PartitionConvGroupOnFeatureGroupCount_RHSAlignWithOutput) {
  // The RHS parameter is fully replicated, so its feature shard is produced
  // with a local dynamic-slice (no copy) to line up with the output tiling.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,2,1,2]0,1,2,3}
  %rhs = f32[5,1,1,1024] parameter(1), sharding={replicated}
  ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs),
    dim_labels=b01f_01io->b01f,feature_group_count=1024,
    window={size=5x1 pad=2_2x0_0},
    sharding={devices=[1,2,1,2]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto lhs_shard =
      AllOf(op::Shape("f32[16,401,1,512]"),
            op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape(),
                                      op::Constant(), op::Reshape())));
  // One pattern covers both the left and the right halo.
  auto halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
                    op::CollectivePermute(op::Slice(lhs_shard)));
  auto rhs_slice =
      AllOf(op::Shape("f32[5,1,1,512]"),
            op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                             op::Constant(), op::Reshape()));

  auto lhs_window = op::Select(_, op::Concatenate(halo, lhs_shard, halo), _);
  EXPECT_THAT(root, AllOf(op::Shape("f32[16, 401, 1, 512]"),
                          op::Convolution(lhs_window, rhs_slice)));
}
6765 
TEST_F(SpmdPartitioningTest,
       PartitionConvGroupOnFeatureGroupCount_LHSAlignWithOutput) {
  // The LHS arrives batch-sharded with partial replication. It is resharded
  // (pad + all-to-all) into the output's spatial/feature tiling before the
  // halo exchange feeding the local convolution.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[5,1,1,1024] parameter(1)
  %rhs.copy = f32[5,1,1,1024] copy(%rhs),
    sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate}
  ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=b01f_01io->b01f,feature_group_count=1024,
    window={size=5x1 pad=2_2x0_0},
    sharding={devices=[1,2,1,2]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto lhs_shard = AllOf(
      op::Shape("f32[8,801,1,1024]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
                                op::Constant(), op::Constant())));
  auto sliced_lhs = op::DynamicSlice(lhs_shard, op::Subtract(), op::Subtract(),
                                     op::Subtract(), op::Subtract());
  auto resharded_lhs =
      AllOf(op::Shape("f32[16,401,1,512]"),
            op::Reshape(op::Transpose(op::AllToAll(
                op::Reshape(op::Pad(sliced_lhs, op::Constant()))))));
  // Left and right halos come from the resharded LHS, via one shared pattern.
  auto halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
                    op::CollectivePermute(op::Slice(resharded_lhs)));
  auto rhs_shard = AllOf(op::Shape("f32[5,1,1,512]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Constant(),
                             op::Constant(), op::Reshape())));

  auto lhs_window =
      op::Select(_, op::Concatenate(halo, resharded_lhs, halo), _);
  EXPECT_THAT(root, AllOf(op::Shape("f32[16, 401, 1, 512]"),
                          op::Convolution(lhs_window, rhs_shard)));
}
6814 
TEST_F(SpmdPartitioningTest, PartitionConvGroupOnBatchGroupCount) {
  // batch_group_count convolution with both operands tiled on the spatial and
  // feature dims. The output is feature-tiled with partial replication, so the
  // local result is all-reduced and then collective-permuted.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,2,1,2]0,1,2,3}
  %rhs = f32[16,801,1,1024] parameter(1)
  %rhs.copy = f32[16,801,1,1024] copy(%rhs),
    sharding={devices=[1,2,1,2]0,1,2,3}
  ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=f01b_i01o->01bf,batch_group_count=1024,
    window={size=801x1 pad=2_2x0_0},
    sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  // Both operands start from the same "pad, then dynamic-slice, then copy"
  // pattern applied to their parameter.
  auto padded_param_slice =
      op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                op::Constant(), op::Reshape(), op::Constant(),
                                op::Reshape()));
  auto lhs_shard = AllOf(op::Shape("f32[16,401,1,512]"),
                         op::Select(_, padded_param_slice, _));
  auto halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
                    op::CollectivePermute(op::Slice(lhs_shard)));
  auto rhs_shard = AllOf(op::Shape("f32[16,401,1,512]"), padded_param_slice);

  auto local_conv =
      AllOf(op::Shape("f32[5,1,1,512]"),
            op::Convolution(op::Concatenate(halo, lhs_shard, halo),
                            op::Select(_, rhs_shard, _)));
  EXPECT_THAT(root, AllOf(op::Shape("f32[5,1,1,512]"),
                          op::CollectivePermute(op::AllReduce(local_conv))));
}
6857 
// NOTE(review): "Ouput" typo is part of the registered test name; kept as-is.
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountAlignOuputWithRHS) {
  // The LHS is spatially sharded and the RHS is feature sharded. The LHS is
  // first moved to feature sharding (all-to-all + slice) to match the RHS, and
  // the convolution result is then moved to the output's batch sharding.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[5,1,1,1024] parameter(1)
  %rhs.copy = f32[5,1,1,1024] copy(%rhs),
    sharding={devices=[1,1,1,2]0,1}
  ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=b01f_01io->b01f,feature_group_count=1024,
    window={size=5x1 pad=2_2x0_0},
    sharding={devices=[2,1,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto lhs_shard =
      AllOf(op::Shape("f32[16,401,1,1024]"),
            op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape(),
                                      op::Constant(), op::Constant())));
  auto rhs_shard = AllOf(op::Shape("f32[5,1,1,512]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Constant(),
                             op::Constant(), op::Reshape())));
  auto feature_sharded_lhs =
      AllOf(op::Shape("f32[16,801,1,512]"),
            op::Slice(op::Reshape(
                op::Transpose(op::AllToAll(op::Reshape(lhs_shard))))));
  auto local_conv = AllOf(op::Shape("f32[16,801,1,512]"),
                          op::Convolution(feature_sharded_lhs, rhs_shard));

  auto batch_sharded_result =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(local_conv))));
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[8,801,1,1024]"), batch_sharded_result));
}
6897 
TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCountBackProp) {
  // Grouped convolution with a reversed kernel (rhs_reversal) and `01oi`
  // kernel labels. Both operands are sharded on the grouped feature dim, so
  // the expected root is a plain local convolution with no collectives.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,1,1,2]0,1}
  %rhs = f32[5,1,1024,1] parameter(1)
  %rhs.copy = f32[5,1,1024,1] copy(%rhs),
    sharding={devices=[1,1,2,1]0,1}
  ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=b01f_01oi->b01f,feature_group_count=1024,
    window={size=5x1 pad=2_2x0_0 rhs_reversal=1x1},
    sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto lhs_shard = AllOf(op::Shape("f32[16,801,1,512]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Constant(),
                             op::Constant(), op::Reshape())));
  // The RHS is split on dimension 2 (its `o` dim), hence the Reshape offset
  // in that position.
  auto rhs_shard = AllOf(op::Shape("f32[5,1,512,1]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Constant(),
                             op::Reshape(), op::Constant())));
  EXPECT_THAT(root, AllOf(op::Shape("f32[16,801,1,512]"),
                          op::Convolution(lhs_shard, rhs_shard)));
}
6930 
TEST_F(SpmdPartitioningTest, NoReshardOnBroadcastDims) {
  // Copies whose resharding only permutes devices along broadcast dimensions
  // should become plain copies; resharding on a non-broadcast dim still needs
  // a collective-permute.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[2,3] parameter(0)
  %param1 = f32[2,3,20] parameter(1)
  %br0 = f32[20,2,20,3,20] broadcast(%param0), dimensions={1,3}, sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
  %br1 = f32[20,2,20,3,20] broadcast(%param1), dimensions={1,3,4}, sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
  %add = f32[20,2,20,3,20] add(%br0, %br1), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
  %reshape = f32[10,4,10,6,20] reshape(%br0), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
  %transpose = f32[2,3,20,20,20] transpose(%br0), dimensions={1,3,0,2,4}, sharding={devices=[1,1,2,2,2]0,1,2,3,4,5,6,7}
  %copy_add0 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]6,7,2,3,4,5,0,1}
  %copy_add1 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1}
  %copy_reshape = f32[10,4,10,6,20] copy(%reshape), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1}
  %copy_transpose = f32[2,3,20,20,20] copy(%transpose), sharding={devices=[1,1,2,2,2]7,6,3,2,5,4,0,1}
  ROOT %tuple = (f32[20,2,20,3,20], f32[20,2,20,3,20], f32[10,4,10,6,20], f32[2,3,20,20,20])
    tuple(%copy_add0, %copy_add1, %copy_reshape, %copy_transpose),
    sharding={{devices=[2,1,2,1,2]6,7,2,3,4,5,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[1,1,2,2,2]7,6,3,2,5,4,0,1}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  // Both copies of %add match the same "add of two broadcasts" operand.
  auto broadcast_sum = op::Add(op::Broadcast(_), op::Broadcast(_));

  // copy_add0: reshard touches broadcast dims only, so it is skipped.
  auto copy_add0 = op::Copy(op::Copy(broadcast_sum));
  // copy_add1: reshard also touches non-broadcast dims.
  auto copy_add1 = op::Copy(op::CollectivePermute(broadcast_sum));
  // copy_reshape: reshard touches broadcast dims only, so it is skipped.
  auto copy_reshape = op::Copy(op::Copy(op::Reshape(op::Broadcast(_))));
  // copy_transpose: reshard touches broadcast dims only, so it is skipped.
  auto copy_transpose = op::Copy(op::Copy(op::Transpose(op::Broadcast(_))));

  EXPECT_THAT(root,
              op::Tuple(copy_add0, copy_add1, copy_reshape, copy_transpose));
}
6969 
TEST_F(SpmdPartitioningTest,
       ConvolutionFilterIFOFPartitionedInputPartialReplicate) {
  // The filter is tiled on both its input-feature and output-feature dims,
  // while the input and output are feature-tiled with partial replication.
  // The local convolution is summed with an all-reduce and then permuted.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,112,112,12] parameter(0)
  %lhs.copy = f32[128,112,112,12] copy(f32[128,112,112,12] %lhs),
    sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[7,7,12,64] parameter(1)
  %rhs.copy = f32[7,7,12,64] copy(f32[7,7,12,64] %rhs),
    sharding={devices=[1,1,2,2]0,1,2,3}
  ROOT %conv = f32[128,56,56,64] convolution(
    f32[128,112,112,12] %lhs.copy,
    f32[7,7,12,64] %rhs.copy),
    window={size=7x7 stride=2x2 pad=3_3x3_3},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto input_shard = AllOf(op::Shape("f32[128,112,112,6]"),
                           op::Copy(op::DynamicSlice(
                               op::Parameter(), op::Constant(), op::Constant(),
                               op::Constant(), op::Reshape())));
  auto filter_shard = AllOf(op::Shape("f32[7,7,6,32]"),
                            op::Copy(op::DynamicSlice(
                                op::Parameter(), op::Constant(), op::Constant(),
                                op::Reshape(), op::Reshape())));

  auto partial_sum = op::Convolution(input_shard, filter_shard);
  EXPECT_THAT(root, AllOf(op::Shape("f32[128,56,56,32]"),
                          op::CollectivePermute(op::AllReduce(partial_sum))));
}
7008 
TEST_F(SpmdPartitioningTest,
       ConvolutionInputKernelNonContractingDimPartialReplicate) {
  // Both inputs are partially replicated and tiled on their last dimension.
  // The output is tiled 2x2, so the RHS shard is collective-permuted to the
  // device that needs it.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,56,56,256] parameter(0)
  %lhs.copy = f32[128,56,56,256] copy(%lhs),
  sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[128,28,28,512] parameter(1)
  %rhs.copy = f32[128,28,28,512] copy(%rhs),
  sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf,
    sharding={devices=[1,1,2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  // Both parameters are dynamic-sliced the same way: offset only on the last
  // dimension.
  auto lhs_shard = AllOf(op::Shape("f32[128,56,56,128]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Constant(),
                             op::Constant(), op::Reshape())));
  auto rhs_shard = AllOf(op::Shape("f32[128,28,28,256]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Constant(),
                             op::Constant(), op::Reshape())));

  EXPECT_THAT(root,
              AllOf(op::Shape("f32[1,1,128,256]"),
                    op::Convolution(lhs_shard, op::CollectivePermute(rhs_shard))));
}
7042 
// NOTE(review): "Parttiioned" typo is part of the registered test name; kept.
TEST_F(SpmdPartitioningTest,
       ConvolutionInputSpatialDimAndFeatureDimParttiioned) {
  // The input is tiled on one spatial dim and the feature dim. The spatial
  // split needs a one-row halo on each side. The feature split makes the
  // local convolution a partial sum that is combined with an all-reduce.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[8,210,210,12] parameter(0)
  %lhs.copy = f32[8,210,210,12] copy(f32[8,210,210,12] %lhs),
    sharding={devices=[1,2,1,2]0,1,2,3}
  %rhs = f32[3,3,12,32] parameter(1)
  %rhs.copy = f32[3,3,12,32] copy(f32[3,3,12,32] %rhs),
    sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %conv = f32[8,210,210,32] convolution(
    f32[8,210,210,12] %lhs.copy,
    f32[3,3,12,32] %rhs.copy),
    window={size=3x3 pad=1_1x1_1},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto lhs_shard = AllOf(op::Shape("f32[8,105,210,6]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Reshape(),
                             op::Constant(), op::Reshape())));
  // Left and right halos are matched by the same one-row pattern.
  auto halo =
      AllOf(op::Shape("f32[8,1,210,6]"), op::CollectivePermute(op::Slice(lhs_shard)));
  auto exchanged_lhs =
      AllOf(op::Shape("f32[8,107,210,6]"),
            op::Select(op::And(_, _), op::Concatenate(halo, lhs_shard, halo),
                       op::Broadcast(_)));
  auto rhs_shard = AllOf(op::Shape("f32[3,3,6,32]"),
                         op::Copy(op::DynamicSlice(
                             op::Parameter(), op::Constant(), op::Constant(),
                             op::Reshape(), op::Constant())));

  auto partial_conv =
      op::Convolution(exchanged_lhs, op::CollectivePermute(rhs_shard));
  EXPECT_THAT(root, AllOf(op::Shape("f32[8,105,210,32]"),
                          op::AllReduce(partial_conv)));
}
7086 
TEST_F(SpmdPartitioningTest, Fft3D) {
  // A 6-point FFT over a dimension split 2 ways. Expected pipeline: slice the
  // local half, pad it via a neighbour exchange, shuffle with dot + all-to-all,
  // run a local FFT, then read the result out of a while loop whose tuple
  // carries the local FFT multiplied by an exp() factor.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  constant = c64[1,1,6]
    constant({{{(0,0),(1,1),(2,2),(3,3),(4,4),(5,5)}}}),
    sharding={devices=[1,1,2]0,1}
  ROOT fft = c64[1,1,6] fft(c64[1,1,6] constant), fft_type=FFT, fft_length={6},
    sharding={devices=[1,1,2]0,1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto local_input =
      AllOf(op::Shape("c64[1,1,3]"),
            op::DynamicSlice(op::Constant(), op::Constant(), op::Constant(),
                             op::Reshape()));
  auto neighbour_data = op::CollectivePermute(op::Slice());
  auto padded_input =
      AllOf(op::Shape("c64[1,1,4]"),
            op::DynamicSlice(op::Concatenate(local_input, neighbour_data),
                             op::Constant(), op::Constant(), op::Reshape()));
  auto shuffled_input =
      AllOf(op::Shape("c64[1,1,3]"),
            op::Slice(op::AllToAll(op::Dot(padded_input, op::Convert()))));
  auto local_fft = AllOf(op::Shape("c64[1,1,3]"), op::Fft(shuffled_input));

  auto loop_state =
      op::Tuple(_, op::Multiply(local_fft, op::Exp()), _, _, _);
  EXPECT_THAT(root, AllOf(op::Shape("c64[1,1,3]"),
                          op::GetTupleElement(op::While(loop_state))));
}
7123 
TEST_F(SpmdPartitioningTest, DotInputsAreIdentical) {
  // The same sharded parameter feeds both convolution operands. Each side is
  // resharded independently (dynamic-update-slice + all-reduce); the RHS side
  // goes through an extra copy of the parameter.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %parameter.1 = f32[4000,4000]{1,0} parameter(0),
    sharding={devices=[2,4]0,1,2,3,4,5,6,7}
  ROOT %convolution = f32[4000,4000]{1,0} convolution(
    f32[4000,4000]{1,0} %parameter.1, f32[4000,4000]{1,0} %parameter.1),
    dim_labels=bf_io->bf, sharding={devices=[2,4]0,1,2,3,4,5,6,7}
}

)";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto local_param = AllOf(op::Shape("f32[2000, 1000]"), op::Parameter());
  auto lhs_operand =
      AllOf(op::Shape("f32[2000, 4000]"),
            op::AllReduce(op::DynamicUpdateSlice(_, local_param, _, _)));
  auto rhs_operand = AllOf(
      op::Shape("f32[4000, 1000]"),
      op::AllReduce(op::DynamicUpdateSlice(_, op::Copy(local_param), _, _)));

  EXPECT_THAT(root, AllOf(op::Shape("f32[2000, 1000]"),
                          op::Convolution(lhs_operand, rhs_operand)));
}
7152 
TEST_F(SpmdPartitioningTest, ConstantSliceReshard) {
  // A 1x1 slice of an 8-way sharded constant is reshaped to a replicated
  // scalar. Expected root: the local slice is selected and then all-reduced.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %constant.785 = f32[1,8] constant({{0,1,2,3,4,5,6,7}}),
    sharding={devices=[1,8]0,1,2,3,4,5,6,7}
  %slice.62 = f32[1,1] slice(%constant.785), slice={[0:1], [0:1]},
    sharding={devices=[1,8]0,1,2,3,4,5,6,7}
  ROOT %reshape.779 = f32[] reshape(%slice.62), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto local_slice = AllOf(op::Shape("f32[1,1]"),
                           op::Copy(op::DynamicSlice(op::Constant(), _, _)));
  auto masked_slice = op::Select(_, local_slice, _);
  EXPECT_THAT(root, op::Reshape(op::AllReduce(masked_slice)));
}
7172 
TEST_F(SpmdPartitioningTest, GatherParallelDimRedistributionOperand) {
  // The indices are sharded 8 ways on the parallel dim, but the operand is
  // tiled [4,2]. The operand is dynamic-sliced to match the index sharding
  // before a local gather.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
  %constant = s32[4] constant({0, 1, 2, 3}), sharding={replicated}
  %iota = s32[1,8,4]{2,1,0} broadcast(%constant), dimensions={2},
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto local_operand = AllOf(op::Shape("s32[1,4,2,2]"), op::DynamicSlice());
  auto local_indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
  auto local_gather = AllOf(op::Shape("s32[1,4,2,2]"),
                            op::Gather(local_operand, local_indices));

  // Replicated result: the local gather is written into place, then
  // all-reduced.
  auto placed = op::DynamicUpdateSlice(_, local_gather, _, _, _, _);
  EXPECT_THAT(root, op::AllReduce(placed));
}
7204 
TEST_F(SpmdPartitioningTest, GatherParallelDimRedistributionIndices) {
  // The operand is sharded 8 ways on dim 0 while the indices are tiled 4x2.
  // The local gather works on 2x2 pieces, and the replicated result takes two
  // nested all-reduces.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto local_operand = AllOf(op::Shape("s32[2,2,2,2]"), op::DynamicSlice());
  auto local_indices = AllOf(op::Shape("s32[2,2,2]"), op::Subtract());
  auto local_gather = AllOf(op::Shape("s32[2,2,2,2]"),
                            op::Gather(local_operand, local_indices));

  auto placed = op::DynamicUpdateSlice(_, local_gather, _, _, _, _);
  EXPECT_THAT(root, op::AllReduce(op::AllReduce(placed)));
}
7234 
TEST_F(SpmdPartitioningTest, GatherParallelDimReplicatedIndices) {
  // The indices are replicated and the operand is sharded 8 ways on dim 0.
  // The operand parameter is used directly, and the indices are adjusted with
  // a subtract before the local gather.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={replicated}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={replicated}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={replicated}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto local_operand = AllOf(op::Shape("s32[1,4,2,2]"), op::Parameter());
  auto local_indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
  auto local_gather = AllOf(op::Shape("s32[1,4,2,2]"),
                            op::Gather(local_operand, local_indices));

  auto placed = op::DynamicUpdateSlice(_, local_gather, _, _, _, _);
  EXPECT_THAT(root, op::AllReduce(placed));
}
7265 
TEST_F(SpmdPartitioningTest, GatherParallelDimReplicatedOperand) {
  // The operand is replicated and the indices are sharded 8 ways on their
  // parallel dim. The operand is dynamic-sliced locally to match.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={replicated}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto local_operand = AllOf(op::Shape("s32[1,4,2,2]"), op::DynamicSlice());
  auto local_indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
  auto local_gather = AllOf(op::Shape("s32[1,4,2,2]"),
                            op::Gather(local_operand, local_indices));

  auto placed = op::DynamicUpdateSlice(_, local_gather, _, _, _, _);
  EXPECT_THAT(root, op::AllReduce(placed));
}
7295 
TEST_F(SpmdPartitioningTest, GatherParallelDimPartialReplicatedIndices) {
  // The indices are tiled 2 ways on their parallel dim with 4-way partial
  // replication, and the operand is sharded 8 ways on dim 0. The operand
  // parameter is used directly by the local gather.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();

  auto local_operand = AllOf(op::Shape("s32[1,4,2,2]"), op::Parameter());
  auto local_indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
  auto local_gather = AllOf(op::Shape("s32[1,4,2,2]"),
                            op::Gather(local_operand, local_indices));

  auto placed = op::DynamicUpdateSlice(_, local_gather, _, _, _, _);
  EXPECT_THAT(root, op::AllReduce(placed));
}
7326 
TEST_F(SpmdPartitioningTest, GatherParallelDimPartialReplicatedOperand) {
  // Operand tiled 2-way on dim 0 with the last tile dimension (4 devices)
  // replicated; indices fully tiled 8-way on dim 1.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={
    devices=[2,1,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();

  // The gather's operand is expected to come from a dynamic-slice.
  const auto sliced_operand =
      AllOf(op::Shape("s32[1,4,2,2]"), op::DynamicSlice());
  const auto adjusted_indices =
      AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
  const auto local_gather =
      AllOf(op::Shape("s32[1,4,2,2]"),
            op::Gather(sliced_operand, adjusted_indices));
  EXPECT_THAT(entry_root, op::AllReduce(op::DynamicUpdateSlice(
                              _, local_gather, _, _, _, _)));
}
7357 
TEST_F(SpmdPartitioningTest, GatherParallelDimSwappedDimensions) {
  // Operand is tiled [4,2] on dims 0/1 while the indices are tiled [2,4] on
  // dims 1/2, i.e. the two parallel dimensions are partitioned the other way.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={
    devices=[4,2,1,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();

  // The operand reaches the gather through a collective-permute.
  const auto local_gather = AllOf(
      op::Shape("s32[4,1,2,2]"),
      op::Gather(AllOf(op::Shape("s32[4,1,2,2]"), op::CollectivePermute()),
                 AllOf(op::Shape("s32[2,4,1]"), op::Subtract())));
  const auto assembled =
      op::DynamicUpdateSlice(_, local_gather, _, _, _, _);
  // Two nested all-reduces produce the replicated output.
  EXPECT_THAT(entry_root, op::AllReduce(op::AllReduce(assembled)));
}
7388 
TEST_F(SpmdPartitioningTest, GatherMergedParalleIndexPassthrough) {
  // start_index_map={1,0} maps the index vector onto the operand dims in
  // swapped order; the operand is additionally tiled on dim 2, which is a
  // pass-through (offset) dimension of the gather.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();

  const auto sliced_operand =
      AllOf(op::Shape("s32[2,4,1,2]"), op::DynamicSlice());
  const auto adjusted_indices =
      AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
  // The local gather output keeps the operand's dim-2 tiling (size 1 of 2).
  const auto local_gather =
      AllOf(op::Shape("s32[2,4,1,2]"),
            op::Gather(sliced_operand, adjusted_indices));
  EXPECT_THAT(entry_root,
              op::AllReduce(op::AllReduce(op::DynamicUpdateSlice(
                  _, local_gather, _, _, _, _))));
}
7419 
TEST_F(SpmdPartitioningTest, GatherParalleIndexAndOperand) {
  // The gather output is requested with the same sharding as the operand, so
  // the local gather result is expected to be the root with no resharding.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();

  EXPECT_THAT(
      entry_root,
      AllOf(op::Shape("s32[2,4,1,2]"),
            op::Gather(AllOf(op::Shape("s32[2,4,1,2]"), op::Parameter(0)),
                       AllOf(op::Shape("s32[2,2,4]"), op::Subtract()))));
}
7449 
TEST_F(SpmdPartitioningTest, GatherReshardParalleIndexAndOperand) {
  // Same as GatherParalleIndexAndOperand, but the output's device assignment
  // permutes the first few devices, so the local gather result must be moved
  // with a collective-permute.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]1,0,3,2,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();

  const auto local_gather = AllOf(
      op::Shape("s32[2,4,1,2]"),
      op::Gather(AllOf(op::Shape("s32[2,4,1,2]"), op::Parameter(0)),
                 AllOf(op::Shape("s32[2,2,4]"), op::Subtract())));
  EXPECT_THAT(entry_root, op::CollectivePermute(local_gather));
}
7479 
TEST_F(SpmdPartitioningTest, GatherParalleIndexAndOperandReshard) {
  // Operand is only tiled on dim 0 (dim 2 replicated), but the output asks
  // for dim 2 to be tiled 2-way: the gather runs on the full dim 2 and is
  // then dynamic-sliced to the requested output shard.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[4,1,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();

  const auto param_shard =
      AllOf(op::Shape("s32[2,4,2,2]"), op::Parameter(0));
  const auto adjusted_indices =
      AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
  const auto local_gather = AllOf(op::Shape("s32[2,4,2,2]"),
                                  op::Gather(param_shard, adjusted_indices));
  EXPECT_THAT(entry_root, op::DynamicSlice(local_gather, _, _, _, _));
}
7509 
TEST_F(SpmdPartitioningTest, GatherMergedParallelIndexTrivialSlice) {
  // Only one component of the index vector is an iota (the other comes from
  // a parameter), and the index tensor has a trivial trailing dimension of 1.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
  %parameter.1 = s32[1,8,1]{2,1,0} parameter(1),
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota = s32[1,8,1]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,1]{2,1,0} concatenate(
    s32[1,8,1]{2,1,0} %parameter.1, s32[1,8,1]{2,1,0} %iota), dimensions={0},
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,1,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,1]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();

  const auto local_gather = AllOf(
      op::Shape("s32[2,1,2,2]"),
      op::Gather(AllOf(op::Shape("s32[2,2,2,2]"), op::Parameter()),
                 AllOf(op::Shape("s32[2,2,1]"), op::Subtract())));
  // Out-of-shard lookups are masked by a select and combined by an inner
  // all-reduce before the result is assembled and replicated.
  const auto masked_and_reduced =
      op::AllReduce(op::Select(_, _, local_gather));
  EXPECT_THAT(entry_root, op::AllReduce(op::DynamicUpdateSlice(
                              _, masked_and_reduced, _, _, _, _)));
}
7540 
TEST_F(SpmdPartitioningTest, SortTopKNonSortDimension) {
  // Top-k style sort along dim 2 where dim 0 is also tiled. The test checks
  // that the partitioned module still contains a sort whose operands have the
  // full (unpartitioned) f32/s32[2,64,32128] shape.
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.42077 (p.0.lhs.42078: f32[],
  p.0.rhs.42079: f32[], p.1.lhs.42080: s32[], p.1.rhs.42081: s32[]) -> pred[] {
  %p.0.lhs.42078 = f32[] parameter(0)
  %bitcast-convert.135 = s32[] bitcast-convert(f32[] %p.0.lhs.42078)
  %constant.45054 = s32[] constant(0)
  %compare.133 = pred[] compare(s32[] %bitcast-convert.135,
    s32[] %constant.45054), direction=LT
  %constant.45278 = u32[] constant(2147483647)
  %bitcast-convert.136 = u32[] bitcast-convert(f32[] %p.0.lhs.42078)
  %subtract.337 = u32[] subtract(u32[] %constant.45278,
    u32[] %bitcast-convert.136)
  %bitcast-convert.137 = s32[] bitcast-convert(u32[] %subtract.337)
  %select.282 = s32[] select(pred[] %compare.133, s32[] %bitcast-convert.137,
    s32[] %bitcast-convert.135)
  %p.0.rhs.42079 = f32[] parameter(1)
  %bitcast-convert.138 = s32[] bitcast-convert(f32[] %p.0.rhs.42079)
  %compare.134 = pred[] compare(s32[] %bitcast-convert.138,
    s32[] %constant.45054), direction=LT
  %bitcast-convert.139 = u32[] bitcast-convert(f32[] %p.0.rhs.42079)
  %subtract.338 = u32[] subtract(u32[] %constant.45278,
    u32[] %bitcast-convert.139)
  %bitcast-convert.140 = s32[] bitcast-convert(u32[] %subtract.338)
  %select.283 = s32[] select(pred[] %compare.134, s32[] %bitcast-convert.140,
    s32[] %bitcast-convert.138)
  %compare.135 = pred[] compare(s32[] %select.282,
    s32[] %select.283), direction=GT
  %compare.428 = pred[] compare(s32[] %select.283,
    s32[] %select.282), direction=GT
  %compare.429 = pred[] compare(pred[] %compare.135,
    pred[] %compare.428), direction=EQ
  %p.1.lhs.42080 = s32[] parameter(2)
  %p.1.rhs.42081 = s32[] parameter(3)
  %compare.430 = pred[] compare(s32[] %p.1.lhs.42080,
    s32[] %p.1.rhs.42081), direction=LT
  ROOT %select.579 = pred[] select(pred[] %compare.429,
    pred[] %compare.430, pred[] %compare.135)
}

ENTRY %module {
  %parameter.0 = f32[2,64,32128]{2,1,0} parameter(0),
     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  %iota = s32[2,64,32128]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  %sort.18 = (f32[2,64,32128]{2,1,0}, s32[2,64,32128]{2,1,0}) sort(
    f32[2,64,32128]{2,1,0} %parameter.0, s32[2,64,32128]{2,1,0} %iota),
    dimensions={2}, is_stable=true, to_apply=%compare-greater-than.42077,
    sharding={{devices=[2,1,4]0,1,2,3,4,5,6,7},
    {devices=[2,1,4]0,1,2,3,4,5,6,7}}
  output = f32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=0,
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  %slice.0 = f32[2,64,2]{2,1,0} slice(f32[2,64,32128]{2,1,0} output),
    slice={[0:2], [0:64], [0:2]},
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  output2 = s32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=1,
    sharding={replicated}
  %slice.1 = s32[2,64,2]{2,1,0} slice(s32[2,64,32128]{2,1,0} output2),
    slice={[0:2], [0:64], [0:2]},
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  ROOT output.t = (f32[2,64,2]{2,1,0},
    s32[2,64,2]{2,1,0}) tuple(slice.0, slice.1),
    sharding={{replicated}, {replicated}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));

  const HloInstruction* sort = FindInstruction(module.get(), "sort");
  // Fatal check: if no sort was found, stop here rather than passing a null
  // pointer into the shape/opcode matchers below.
  ASSERT_NE(sort, nullptr);
  auto sort_match =
      AllOf(op::Shape("(f32[2,64,32128], s32[2,64,32128])"), op::Sort(_, _));
  EXPECT_THAT(sort, sort_match);
}
7616 
TEST_F(SpmdPartitioningTest, SortTopKPropagateBaseShape) {
  // Top-k style sort with the sort dimension (dim 2) tiled 8-way; both outputs
  // are then sliced to the first two elements and requested replicated.
  absl::string_view hlo_text = R"(
HloModule module

%compare-greater-than.42077 (p.0.lhs.42078: f32[],
  p.0.rhs.42079: f32[], p.1.lhs.42080: s32[], p.1.rhs.42081: s32[]) -> pred[] {
  %p.0.lhs.42078 = f32[] parameter(0)
  %bitcast-convert.135 = s32[] bitcast-convert(f32[] %p.0.lhs.42078)
  %constant.45054 = s32[] constant(0)
  %compare.133 = pred[] compare(s32[] %bitcast-convert.135,
    s32[] %constant.45054), direction=LT
  %constant.45278 = u32[] constant(2147483647)
  %bitcast-convert.136 = u32[] bitcast-convert(f32[] %p.0.lhs.42078)
  %subtract.337 = u32[] subtract(u32[] %constant.45278,
    u32[] %bitcast-convert.136)
  %bitcast-convert.137 = s32[] bitcast-convert(u32[] %subtract.337)
  %select.282 = s32[] select(pred[] %compare.133, s32[] %bitcast-convert.137,
    s32[] %bitcast-convert.135)
  %p.0.rhs.42079 = f32[] parameter(1)
  %bitcast-convert.138 = s32[] bitcast-convert(f32[] %p.0.rhs.42079)
  %compare.134 = pred[] compare(s32[] %bitcast-convert.138,
    s32[] %constant.45054), direction=LT
  %bitcast-convert.139 = u32[] bitcast-convert(f32[] %p.0.rhs.42079)
  %subtract.338 = u32[] subtract(u32[] %constant.45278,
    u32[] %bitcast-convert.139)
  %bitcast-convert.140 = s32[] bitcast-convert(u32[] %subtract.338)
  %select.283 = s32[] select(pred[] %compare.134, s32[] %bitcast-convert.140,
    s32[] %bitcast-convert.138)
  %compare.135 = pred[] compare(s32[] %select.282,
    s32[] %select.283), direction=GT
  %compare.428 = pred[] compare(s32[] %select.283,
    s32[] %select.282), direction=GT
  %compare.429 = pred[] compare(pred[] %compare.135,
    pred[] %compare.428), direction=EQ
  %p.1.lhs.42080 = s32[] parameter(2)
  %p.1.rhs.42081 = s32[] parameter(3)
  %compare.430 = pred[] compare(s32[] %p.1.lhs.42080,
    s32[] %p.1.rhs.42081), direction=LT
  ROOT %select.579 = pred[] select(pred[] %compare.429,
    pred[] %compare.430, pred[] %compare.135)
}

ENTRY %module {
  %parameter.0 = f32[2,64,32128]{2,1,0} parameter(0),
     sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
  %iota = s32[2,64,32128]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
  %sort.18 = (f32[2,64,32128]{2,1,0}, s32[2,64,32128]{2,1,0}) sort(
    f32[2,64,32128]{2,1,0} %parameter.0, s32[2,64,32128]{2,1,0} %iota),
    dimensions={2}, is_stable=true, to_apply=%compare-greater-than.42077,
    sharding={{devices=[1,1,8]0,1,2,3,4,5,6,7},
    {devices=[1,1,8]0,1,2,3,4,5,6,7}}
  output = f32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=0,
    sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
  %slice.0 = f32[2,64,2]{2,1,0} slice(f32[2,64,32128]{2,1,0} output),
    slice={[0:2], [0:64], [0:2]},
    sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
  output2 = s32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=1,
    sharding={replicated}
  %slice.1 = s32[2,64,2]{2,1,0} slice(s32[2,64,32128]{2,1,0} output2),
    slice={[0:2], [0:64], [0:2]},
    sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
  ROOT output.t = (f32[2,64,2]{2,1,0},
    s32[2,64,2]{2,1,0}) tuple(slice.0, slice.1),
    sharding={{replicated}, {replicated}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  // Each tuple element is a slice of an all-reduce of a dynamic-update-slice;
  // the value and index branches share that structure and differ only in type.
  const auto reduced_then_sliced =
      op::Slice(op::AllReduce(op::DynamicUpdateSlice(_, _, _, _, _)));
  const auto values_branch =
      AllOf(op::Shape("f32[2,64,2]"), reduced_then_sliced);
  const auto indices_branch =
      AllOf(op::Shape("s32[2,64,2]"), reduced_then_sliced);
  EXPECT_THAT(entry_root,
              AllOf(op::Shape("(f32[2,64,2], s32[2,64,2])"),
                    op::Tuple(values_branch, indices_branch)));
}
7697 
TEST_F(SpmdPartitioningTest, GatherIndexOnlyCorrectReplacement) {
  // Replicated operand, indices tiled 2-way on dim 0 (4-way replicated); the
  // gather output and its reshape are further tiled 4-way on the next dim.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = bf16[1,8,6,6]{3,2,1,0} parameter(0),
    sharding={replicated}
  %parameter.1 = s32[2,4]{1,0} parameter(1),
     sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %gather.100 = bf16[2,1,8,1,6]{4,3,2,1,0} gather(
    bf16[1,8,6,6]{3,2,1,0} %parameter.0, s32[2,4]{1,0} %parameter.1),
    offset_dims={1,2,3,4}, collapsed_slice_dims={}, start_index_map={0,1,2,3},
    index_vector_dim=1, slice_sizes={1,8,1,6},
    sharding={devices=[2,1,4,1,1]0,1,2,3,4,5,6,7}
  %constant.45590 = s32[] constant(0), sharding={replicated}
  %broadcast.54515 = s32[2,64,1,1]{3,2,1,0} broadcast(s32[] %constant.45590),
    dimensions={},
    sharding={devices=[2,1,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %reshape.4243 = bf16[2,8,6]{2,1,0} reshape(
    bf16[2,1,8,1,6]{4,3,2,1,0} %gather.100),
    sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  // The gather runs on the full replicated operand with the local index
  // shard, then its result is dynamic-sliced before the reshape.
  const auto full_operand =
      AllOf(op::Shape("bf16[1,8,6,6]"), op::Parameter());
  const auto index_shard = AllOf(op::Shape("s32[1,4]"), op::Parameter());
  const auto sliced_gather =
      op::DynamicSlice(op::Gather(full_operand, index_shard), _, _, _, _, _);
  EXPECT_THAT(entry_root,
              AllOf(op::Shape("bf16[1,2,6]"), op::Reshape(sliced_gather)));
}
7731 
TEST_F(SpmdPartitioningTest, GatherRegressionTest1) {
  // Operand and output are tiled 8-way on the same (minor) dimension that the
  // indices select from.
  absl::string_view hlo_text = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[1,4] parameter(0), sharding={devices=[1,8]0,1,2,3,4,5,6,7}
  %iota.10 = s32[4]{0} iota(), iota_dimension=0, sharding={devices=[8]0,1,2,3,4,5,6,7}
  ROOT %gather.44 = s32[1,4]{1,0} gather(%parameter.0, %iota.10),
    offset_dims={0}, collapsed_slice_dims={1}, start_index_map={1}, index_vector_dim=1,
    slice_sizes={1,1}, sharding={devices=[1,8]0,1,2,3,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  // The root gather must read the per-device s32[1,1] operand shard.
  EXPECT_THAT(entry_root,
              op::Gather(AllOf(op::Shape("s32[1,1]"), op::Parameter()), _));
}
7750 
TEST_F(SpmdPartitioningTest, WindowedEinsumPreferMemoryFootprint) {
  // With choose_faster_windowed_einsum=false, the windowed-einsum loop emitted
  // for this convolution is expected to run 4 iterations. The sibling test
  // WindowedEinsumPreferNumberIterations uses the same HLO with the flag set.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} parameter(0),
    sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
  %parameter.1 = bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} parameter(1),
    sharding={devices=[2,2,1,2,1,1,1]0,1,2,3,4,5,6,7}
  %convolution.3 = bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0}
    convolution(bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} %parameter.0,
    bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} %parameter.1),
    window={size=1x4x176x4x4 pad=0_0x3_3x175_175x0_0x0_0
    rhs_reversal=0x1x1x0x0}, dim_labels=0b34f12_34i12o0->0b12f34,
    sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
  ROOT %reshape.3973 = bf16[128,1024,4,176,256]{4,3,2,1,0}
    reshape(bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0} %convolution.3),
    sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/8,
                           /*conv_halo_exchange_always_on_lhs=*/true,
                           /*choose_faster_windowed_einsum=*/false));
  const HloInstruction* while_inst = FindInstruction(module.get(), "while");
  // Fatal: the next line dereferences while_inst.
  ASSERT_NE(while_inst, nullptr);
  const HloComputation* cond_comp = while_inst->while_condition();
  const HloInstruction* root = cond_comp->root_instruction();
  // Fatal: Cast<> below assumes operand 1 really is a constant.
  ASSERT_THAT(root, op::Compare(_, op::Constant()));
  const HloConstantInstruction* iterations =
      Cast<HloConstantInstruction>(root->operand(1));
  // Evaluate once; dereferencing an empty optional would be undefined.
  const auto trip_count = iterations->literal().GetFirstInteger();
  ASSERT_TRUE(trip_count.has_value());
  EXPECT_EQ(*trip_count, 4);
}
7785 
TEST_F(SpmdPartitioningTest, WindowedEinsumPreferNumberIterations) {
  // Same HLO as WindowedEinsumPreferMemoryFootprint, but with
  // choose_faster_windowed_einsum=true the loop is expected to run only 2
  // iterations.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} parameter(0),
    sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
  %parameter.1 = bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} parameter(1),
    sharding={devices=[2,2,1,2,1,1,1]0,1,2,3,4,5,6,7}
  %convolution.3 = bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0}
    convolution(bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} %parameter.0,
    bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} %parameter.1),
    window={size=1x4x176x4x4 pad=0_0x3_3x175_175x0_0x0_0
    rhs_reversal=0x1x1x0x0}, dim_labels=0b34f12_34i12o0->0b12f34,
    sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
  ROOT %reshape.3973 = bf16[128,1024,4,176,256]{4,3,2,1,0}
    reshape(bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0} %convolution.3),
    sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/8,
                           /*conv_halo_exchange_always_on_lhs=*/true,
                           /*choose_faster_windowed_einsum=*/true));
  const HloInstruction* while_inst = FindInstruction(module.get(), "while");
  // Fatal: the next line dereferences while_inst.
  ASSERT_NE(while_inst, nullptr);
  const HloComputation* cond_comp = while_inst->while_condition();
  const HloInstruction* root = cond_comp->root_instruction();
  // Fatal: Cast<> below assumes operand 1 really is a constant.
  ASSERT_THAT(root, op::Compare(_, op::Constant()));
  const HloConstantInstruction* iterations =
      Cast<HloConstantInstruction>(root->operand(1));
  // Evaluate once; dereferencing an empty optional would be undefined.
  const auto trip_count = iterations->literal().GetFirstInteger();
  ASSERT_TRUE(trip_count.has_value());
  EXPECT_EQ(*trip_count, 2);
}
7820 
TEST_F(SpmdPartitioningTest, WindowedEinsumPreferNumberIterations2) {
  // 32-device variant: with choose_faster_windowed_einsum=true the windowed
  // einsum loop is expected to run 4 iterations (WindowedEinsumPreferMemory-
  // Footprint2 uses the same HLO with the flag cleared).
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = bf16[512,1024,16,36,256]{4,3,2,1,0} parameter(0)
  %lhs.copy = bf16[512,1024,16,36,256]{4,3,2,1,0} copy(%lhs),
  sharding={devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,
            18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  %rhs = bf16[512,1024,16,4,288]{4,3,2,1,0} parameter(1)
  %rhs.copy = bf16[512,1024,16,4,288]{4,3,2,1,0} copy(%rhs),
    sharding={devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,
              17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  %reshape.2556 = bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} reshape(
    bf16[512,1024,16,4,288]{4,3,2,1,0} %rhs.copy), sharding={
      devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
        20,21,22,23,24,25,26,27,28,29,30,31}
  %reshape.2570 = bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0}
    reshape(bf16[512,1024,16,36,256]{4,3,2,1,0} %lhs.copy), sharding={
    devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
             20,21,22,23,24,25,26,27,28,29,30,31}
  %convolution.10 = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
    convolution(bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0} %reshape.2570,
    bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} %reshape.2556),
    window={size=1x1x16x4x512 pad=0_0x0_0x15_15x3_3x0_0 rhs_reversal=0x0x1x1x0},
    dim_labels=4f01b23_4i23o01->01b23f4, sharding={devices=[4,1,1,4,2,1,1]0,4,8,
    12,16,20,24,28,1,5,9,13,17,21,25,29,2,6,10,14,18,22,26,30,3,7,11,15,19,23,
    27,31}
  ROOT %output = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
   copy(%convolution.10), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/32,
                           /*conv_halo_exchange_always_on_lhs=*/true,
                           /*choose_faster_windowed_einsum=*/true));
  const HloInstruction* while_inst = FindInstruction(module.get(), "while");
  // Fatal: the next line dereferences while_inst.
  ASSERT_NE(while_inst, nullptr);
  const HloComputation* cond_comp = while_inst->while_condition();
  const HloInstruction* root = cond_comp->root_instruction();
  // Fatal: Cast<> below assumes operand 1 really is a constant.
  ASSERT_THAT(root, op::Compare(_, op::Constant()));
  const HloConstantInstruction* iterations =
      Cast<HloConstantInstruction>(root->operand(1));
  // Evaluate once; dereferencing an empty optional would be undefined.
  const auto trip_count = iterations->literal().GetFirstInteger();
  ASSERT_TRUE(trip_count.has_value());
  EXPECT_EQ(*trip_count, 4);
}
7867 
TEST_F(SpmdPartitioningTest, WindowedEinsumPreferMemoryFootprint2) {
  // 32-device variant: with choose_faster_windowed_einsum=false the windowed
  // einsum loop is expected to run 8 iterations (WindowedEinsumPreferNumber-
  // Iterations2 uses the same HLO with the flag set).
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = bf16[512,1024,16,36,256]{4,3,2,1,0} parameter(0)
  %lhs.copy = bf16[512,1024,16,36,256]{4,3,2,1,0} copy(%lhs),
  sharding={devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,
            18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  %rhs = bf16[512,1024,16,4,288]{4,3,2,1,0} parameter(1)
  %rhs.copy = bf16[512,1024,16,4,288]{4,3,2,1,0} copy(%rhs),
    sharding={devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,
              17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  %reshape.2556 = bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} reshape(
    bf16[512,1024,16,4,288]{4,3,2,1,0} %rhs.copy), sharding={
      devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
        20,21,22,23,24,25,26,27,28,29,30,31}
  %reshape.2570 = bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0}
    reshape(bf16[512,1024,16,36,256]{4,3,2,1,0} %lhs.copy), sharding={
    devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
             20,21,22,23,24,25,26,27,28,29,30,31}
  %convolution.10 = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
    convolution(bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0} %reshape.2570,
    bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} %reshape.2556),
    window={size=1x1x16x4x512 pad=0_0x0_0x15_15x3_3x0_0 rhs_reversal=0x0x1x1x0},
    dim_labels=4f01b23_4i23o01->01b23f4, sharding={devices=[4,1,1,4,2,1,1]0,4,8,
    12,16,20,24,28,1,5,9,13,17,21,25,29,2,6,10,14,18,22,26,30,3,7,11,15,19,23,
    27,31}
  ROOT %output = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
   copy(%convolution.10), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/32,
                           /*conv_halo_exchange_always_on_lhs=*/true,
                           /*choose_faster_windowed_einsum=*/false));
  const HloInstruction* while_inst = FindInstruction(module.get(), "while");
  // Fatal: the next line dereferences while_inst.
  ASSERT_NE(while_inst, nullptr);
  const HloComputation* cond_comp = while_inst->while_condition();
  const HloInstruction* root = cond_comp->root_instruction();
  // Fatal: Cast<> below assumes operand 1 really is a constant.
  ASSERT_THAT(root, op::Compare(_, op::Constant()));
  const HloConstantInstruction* iterations =
      Cast<HloConstantInstruction>(root->operand(1));
  // Evaluate once; dereferencing an empty optional would be undefined.
  const auto trip_count = iterations->literal().GetFirstInteger();
  ASSERT_TRUE(trip_count.has_value());
  EXPECT_EQ(*trip_count, 8);
}
7914 
// Partitions an f32[8,2,15,4] x f32[2,15,4] dot that contracts lhs dims {2,3}
// with rhs dims {1,2}. The lhs is tiled [1,2,4,1], the rhs [2,4,1] and the
// result [2,2,2] over 8 devices. The test checks the per-partition operand
// shapes fed into the partitioned dot named "dot.1".
TEST_F(SpmdPartitioningTest, ContractingPartitionDotOperandsSlicedWrong) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[8,2,15,4] parameter(0)
  %lhs.copy = f32[8,2,15,4] copy(%lhs),
    sharding={devices=[1,2,4,1]0,1,2,3,4,5,6,7}
  %rhs = f32[2,15,4] parameter(1)
  %rhs.copy = f32[2,15,4] copy(%rhs),
    sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
  %dot = f32[8,2,2] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2},
    operand_precision={HIGH,HIGH},
    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7}
  ROOT %output = f32[8,2,2] copy(%dot), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/8,
                           /*conv_halo_exchange_always_on_lhs=*/true,
                           /*choose_faster_windowed_einsum=*/true));

  const HloInstruction* partitioned_dot =
      FindInstruction(module.get(), "dot.1");
  EXPECT_THAT(partitioned_dot, op::Dot(op::Shape("f32[4,2,4,4]"),
                                       op::Shape("f32[2,4,4]")));
}
7944 
// Batch-dot over 4 devices where the output sharding ([1,2,1,2]) differs from
// the operand shardings. The test expects the following chain:
//   - a local dot of the two parameters,
//   - an all-reduce, then a dynamic-slice down to f32[16,32,24,512],
//   - reshape -> all-to-all -> transpose -> reshape to the final
//     f32[32,16,24,512] per-partition result.
TEST_F(SpmdPartitioningTest, PartitionDotGroupOnBatchContractingReshard) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,32,24,4096] parameter(0),
    sharding={devices=[2,1,1,2]0,1,2,3}
  %rhs = f32[32,4096,1024] parameter(1),
    sharding={devices=[2,2,1]0,1,2,3}
  ROOT %dot = f32[32,32,24,1024] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={3}, rhs_contracting_dims={1},
    sharding={devices=[1,2,1,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  const auto local_dot = AllOf(op::Shape("f32[16,32,24,1024]"),
                               op::Dot(op::Parameter(0), op::Parameter(1)));
  const auto sliced_sum =
      AllOf(op::Shape("f32[16,32,24,512]"),
            op::DynamicSlice(op::AllReduce(local_dot), _, _, _, _));
  const auto resharded =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(sliced_sum))));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(resharded, op::Shape("f32[32,16,24,512]")));
}
7972 
// Scatter whose operand and result are tiled 2-way along dimension 1 (with
// 2-way replication on the last tile dim). The partitioned root must be a
// scatter producing the per-partition shape bf16[2,512].
TEST_F(SpmdPartitioningTest, PartitionPassthroughScatterCorrectOutputSharding) {
  absl::string_view hlo_string = R"(
HloModule module

%scatter_add (parameter.0: bf16[], parameter.1: bf16[]) -> bf16[] {
  %parameter.0 = bf16[] parameter(0)
  %parameter.1 = bf16[] parameter(1)
  ROOT %add = bf16[] add(bf16[] %parameter.0, bf16[] %parameter.1)
}

ENTRY entry {
  %operand = bf16[2,1024]{1,0} parameter(0),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  %indices = s32[8,512,1]{2,1,0} parameter(1),
    sharding={devices=[2,1,1,2]0,2,1,3 last_tile_dim_replicate}
  %updates = bf16[8,512,1024]{2,1,0} parameter(2),
    sharding={devices=[2,1,2]0,2,1,3}
  ROOT %scatter = bf16[2,1024]{1,0} scatter(bf16[2,1024]{1,0} %operand,
    s32[8,512,1]{2,1,0} %indices,
    bf16[8,512,1024]{2,1,0} %updates), update_window_dims={2},
    inserted_window_dims={0}, scatter_dims_to_operand_dims={0},
    index_vector_dim=2, to_apply=%scatter_add,
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("bf16[2,512]"), op::Scatter(_, _, _)));
}
8005 
IsTrivialCollectivePermute(HloInstruction * hlo)8006 bool IsTrivialCollectivePermute(HloInstruction* hlo) {
8007   if (hlo->opcode() != HloOpcode::kCollectivePermute) {
8008     return false;
8009   }
8010   if (hlo->source_target_pairs().empty()) {
8011     return true;
8012   }
8013   return absl::c_all_of(hlo->source_target_pairs(),
8014                         [](const std::pair<int64, int64>& pair) {
8015                           return pair.first == pair.second;
8016                         });
8017 }
8018 
// Builds a rotate-style concatenate of two slices of a padded, 2-way sharded
// tensor. No collective-permute in the partitioned module may be trivial,
// i.e. have no pairs or only self-send pairs. Such a permute would degenerate
// to a copy.
TEST_F(SpmdPartitioningTest, CollectivePermuteSimplifyIdentity) {
  absl::string_view hlo_string = R"(
HloModule test

ENTRY entry {
  %parameter.7 = f32[3,16] parameter(0), sharding={devices=[1,2]0,1}
  %constant.7 = f32[] constant(0)
  %pad.3 = f32[3,18] pad(f32[3,16] %parameter.7, f32[] %constant.7), padding=0_0x1_1, sharding={devices=[1,2]0,1}
  // Shift right by 16.
  %slice.8 = f32[3,16] slice(f32[3,18] %pad.3), slice={[0:3], [2:18]}, sharding={devices=[1,2]0,1}
  %slice.9 = f32[3,2] slice(f32[3,18] %pad.3), slice={[0:3], [0:2]}, sharding={devices=[1,2]0,1}
  ROOT %concatenate.6 = f32[3,18] concatenate(f32[3,16] %slice.8, f32[3,2] %slice.9), dimensions={1}, sharding={devices=[1,2]0,1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  for (auto* comp : module->computations()) {
    for (auto* instr : comp->instructions()) {
      EXPECT_FALSE(IsTrivialCollectivePermute(instr)) << instr->ToString();
    }
  }
}
8046 
// Pads a single-row slice of a 2-way sharded tensor back out along the sharded
// dimension. The partitioned module must not contain a trivial
// collective-permute, meaning one with an empty source/target pair list or
// with only self-send pairs.
TEST_F(SpmdPartitioningTest, CollectivePermuteSimplifyZero) {
  absl::string_view hlo_string = R"(
HloModule test

ENTRY entry {
  %parameter = f32[3,16,16,16,16,132]{5,4,3,2,1,0} parameter(0), sharding={devices=[1,2,1,1,1,1]0,1}
  %slice = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter), slice={[0:3], [15:16], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
  %c0 = f32[] constant(0)
  ROOT %pad = f32[3,18,16,16,16,132]{5,4,3,2,1,0} pad(f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice, f32[] %c0), padding=0_0x0_17x0_0x0_0x0_0x0_0, sharding={devices=[1,2,1,1,1,1]0,1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  for (auto* comp : module->computations()) {
    for (auto* instr : comp->instructions()) {
      EXPECT_FALSE(IsTrivialCollectivePermute(instr)) << instr->ToString();
    }
  }
}
8071 
// Wrap-padding pattern: concatenate(last row, tensor, first row) along the
// 2-way sharded dimension. The partitioned module may contain neither an
// all-reduce nor a trivial collective-permute (no pairs, or only self-sends).
TEST_F(SpmdPartitioningTest, PadWithWrapPattern) {
  absl::string_view hlo_string = R"(
HloModule xla_computation_apply_fn__4.61

ENTRY %xla_computation_apply_fn__4.61 (parameter.7: f32[3,16,16,16,16,132]) -> f32[3,18,16,16,16,132] {
  %parameter.7 = f32[3,16,16,16,16,132]{5,4,3,2,1,0} parameter(0), sharding={devices=[1,2,1,1,1,1]0,1}
  %slice.2 = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7), slice={[0:3], [15:16], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
  %slice.3 = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7), slice={[0:3], [0:1], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
  ROOT %concatenate.3 = f32[3,18,16,16,16,132]{5,4,3,2,1,0} concatenate(f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice.2, f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7, f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice.3), dimensions={1}, sharding={devices=[1,2,1,1,1,1]0,1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  for (auto* comp : module->computations()) {
    for (auto* instr : comp->instructions()) {
      EXPECT_FALSE(IsTrivialCollectivePermute(instr)) << instr->ToString();
      EXPECT_NE(instr->opcode(), HloOpcode::kAllReduce) << instr->ToString();
    }
  }
}
8097 
// Wrap-padding pattern where each wrapped edge passes through a negate before
// the concatenate. The partitioned module may contain neither an all-reduce
// nor a trivial collective-permute.
TEST_F(SpmdPartitioningTest, PadWrapWithNegatePattern) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %parameter.1 = f32[1,18] parameter(0), sharding={devices=[1,2]0,1}
  %slice.16 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [16:18]}, sharding={devices=[1,2]0,1}
  %negate.2 = f32[1,2] negate(f32[1,2] %slice.16), sharding={devices=[1,2]0,1}
  %slice.17 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [0:2]}, sharding={devices=[1,2]0,1}
  %negate.3 = f32[1,2] negate(f32[1,2] %slice.17), sharding={devices=[1,2]0,1}
  ROOT %concatenate.13 = f32[1,22] concatenate(f32[1,2] %negate.2, f32[1,18] %parameter.1, f32[1,2] %negate.3), dimensions={1}, sharding={devices=[1,2]0,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  for (auto* comp : module->computations()) {
    for (auto* instr : comp->instructions()) {
      EXPECT_FALSE(IsTrivialCollectivePermute(instr)) << instr->ToString();
      EXPECT_NE(instr->opcode(), HloOpcode::kAllReduce) << instr->ToString();
    }
  }
}
8124 
// Wrap-padding pattern where the wrapped edges pass through chains of unary
// ops. One chain is rsqrt -> sine; the other is convert(f16) -> cosine ->
// convert(f32). Besides forbidding all-reduce and trivial collective-permutes,
// the test checks that each op in those chains still consumes the expected
// predecessor after partitioning.
TEST_F(SpmdPartitioningTest, PadWrapWithMultipleModifiersPattern) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %parameter.1 = f32[1,18] parameter(0), sharding={devices=[1,2]0,1}
  %slice.16 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [16:18]}, sharding={devices=[1,2]0,1}
  %mod0.16 = f32[1,2] rsqrt(f32[1,2] %slice.16), sharding={devices=[1,2]0,1}
  %mod1.16 = f32[1,2] sine(f32[1,2] %mod0.16), sharding={devices=[1,2]0,1}
  %slice.17 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [0:2]}, sharding={devices=[1,2]0,1}
  %mod0.17 = f16[1,2] convert(f32[1,2] %slice.17), sharding={devices=[1,2]0,1}
  %mod1.17 = f16[1,2] cosine(f16[1,2] %mod0.17), sharding={devices=[1,2]0,1}
  %mod2.17 = f32[1,2] convert(f16[1,2] %mod1.17), sharding={devices=[1,2]0,1}
  ROOT %concatenate.13 = f32[1,22] concatenate(f32[1,2] %mod1.16, f32[1,18] %parameter.1, f32[1,2] %mod2.17), dimensions={1}, sharding={devices=[1,2]0,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  for (auto* comp : module->computations()) {
    for (auto* instr : comp->instructions()) {
      const HloOpcode opcode = instr->opcode();
      EXPECT_FALSE(IsTrivialCollectivePermute(instr)) << instr->ToString();
      EXPECT_NE(opcode, HloOpcode::kAllReduce) << instr->ToString();

      // The dependency checks only concern single-operand instructions.
      if (instr->operand_count() == 1) {
        const HloInstruction* input = instr->operand(0);
        switch (opcode) {
          case HloOpcode::kSin:
            EXPECT_EQ(input->opcode(), HloOpcode::kRsqrt);
            break;
          case HloOpcode::kConvert:
            // Only the convert back to f32 that closes the chain is checked.
            if (instr->shape().element_type() == F32) {
              EXPECT_EQ(input->opcode(), HloOpcode::kCos);
              EXPECT_EQ(input->shape().element_type(), F16);
            }
            break;
          case HloOpcode::kCos:
            EXPECT_EQ(input->opcode(), HloOpcode::kConvert);
            EXPECT_EQ(input->shape().element_type(), F16);
            break;
          default:
            break;
        }
      }
    }
  }
}
8171 
// Reshards an f32[1,1] parameter tiled [2,2] over 4 devices to replicated.
// The root is expected to be copy(all-reduce(select(_, param0, _))), with the
// parameter keeping its f32[1,1] shape.
TEST_F(SpmdPartitioningTest, BroadcastAsReplicate) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[1,1] parameter(0), sharding={devices=[2,2]0,1,2,3}
  ROOT %copy = f32[1,1] copy(%param0), sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  const auto local_param = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
  const auto replicated_copy =
      op::Copy(op::AllReduce(op::Select(_, local_param, _)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(replicated_copy, op::Shape("f32[1,1]")));
}
8190 
// Reshards an f32[1,2] parameter tiled [2,2] over 4 devices to replicated.
// The expected structure is:
//   - each local f32[1,1] shard goes through select + all-reduce;
//   - the result is placed with a dynamic-update-slice;
//   - a second all-reduce and a copy produce the f32[1,2] root.
TEST_F(SpmdPartitioningTest, BroadcastAsReplicate2) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[1,2] parameter(0), sharding={devices=[2,2]0,1,2,3}
  ROOT %copy = f32[1,2] copy(%param0), sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  const auto local_param = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
  const auto broadcast_piece = AllOf(
      op::AllReduce(op::Select(_, local_param, _)), op::Shape("f32[1,1]"));
  const auto assembled =
      op::AllReduce(op::DynamicUpdateSlice(_, broadcast_piece, _, _));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Copy(assembled), op::Shape("f32[1,2]")));
}
8213 
// Like BroadcastAsReplicate, but the f32[1,1] parameter is tiled 2-way on
// dimension 0 with 2-way replication (last_tile_dim_replicate). The root is
// expected to be copy(all-reduce(select(_, param0, _))).
TEST_F(SpmdPartitioningTest, BroadcastAsReplicate3) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[1,1] parameter(0),
    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %copy = f32[1,1] copy(%param0), sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  const auto local_param = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
  const auto replicated_copy =
      op::Copy(op::AllReduce(op::Select(_, local_param, _)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(replicated_copy, op::Shape("f32[1,1]")));
}
8233 
8234 }  // namespace
8235 }  // namespace spmd
8236 }  // namespace xla
8237