1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
19 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
23 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28
29 namespace xla {
30 namespace spmd {
31 namespace {
32
33 using ::testing::_;
34 using ::testing::AllOf;
35 namespace op = xla::testing::opcode_matchers;
36
37 class SpmdPartitioningTest : public HloTestBase {
38 public:
PartitionComputation(absl::string_view hlo_module,int64_t num_devices,bool conv_halo_exchange_always_on_lhs=true,bool choose_faster_windowed_einsum=false)39 StatusOr<std::unique_ptr<HloModule>> PartitionComputation(
40 absl::string_view hlo_module, int64_t num_devices,
41 bool conv_halo_exchange_always_on_lhs = true,
42 bool choose_faster_windowed_einsum = false) {
43 // Some tests (BackpropFilter convs) set this flag false to test two
44 // different paths of the implementation.
45 SpmdPartitionerOptions options;
46 options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs;
47 options.allow_module_signature_change = true;
48 options.choose_faster_windowed_einsum_over_mem =
49 choose_faster_windowed_einsum;
50 auto collective_ops_creator =
51 GetDefaultCollectiveOpsCreator(num_devices, /*num_replicas=*/1);
52 // Do not use all-gather for pattern-matching purpose, as the partitioner
53 // might create reshape/transposes around it.
54 collective_ops_creator.create_cross_partition_all_gather = nullptr;
55
56 TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(
57 hlo_module, GetModuleConfigForTest()));
58 HloPassPipeline pass("spmd-partitioning");
59 pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
60 /*allow_mixed_precision=*/false);
61 pass.AddPass<SpmdPartitioner>(num_devices, /*num_replicas=*/1, options,
62 collective_ops_creator);
63 pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
64 /*allow_mixed_precision=*/false);
65 TF_RETURN_IF_ERROR(pass.Run(module.get()).status());
66 return StatusOr<std::unique_ptr<HloModule>>(std::move(module));
67 }
68 };
69
// The infeed uses a [2,1] tiling over devices 0 and 1, but partitioning
// targets 4 devices. The partitioner must reject it with an error that
// mentions the tile sharding requirement.
TEST_F(SpmdPartitioningTest, InvalidSharding) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = (f32[8,2]{1,0}, token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {maximal device=0}}
  ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0,
    sharding={maximal device=0}
})";
  auto module_status = PartitionComputation(hlo_string, /*num_devices=*/4);
  // ASSERT rather than EXPECT: if partitioning unexpectedly succeeds, the
  // message check below would only add a second, misleading failure.
  ASSERT_FALSE(module_status.status().ok());
  EXPECT_THAT(module_status.status().ToString(),
              ::testing::HasSubstr(
                  "only supports tile sharding that includes all partitions"));
}
87
// A constant pinned to device 0 is copied to a replicated result. Expect a
// compare-driven select between the constant and a broadcast, combined across
// partitions with an all-reduce.
TEST_F(SpmdPartitioningTest, SingleDeviceToReplicated) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={maximal device=0}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto select_from_owner = op::Select(op::Broadcast(op::Compare()),
                                            op::Constant(), op::Broadcast());
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Copy(op::AllReduce(select_from_owner)),
                    op::Shape("s32[2,3]")));
}
106
// A constant on device 0 copied to device 1. Expect the value to be combined
// across partitions (select + all-reduce), copied, and then wrapped by the
// original copy.
TEST_F(SpmdPartitioningTest, SingleDeviceToSingleDevice) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={maximal device=0}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  const HloInstruction* const root =
      partitioned->entry_computation()->root_instruction();
  VLOG(1) << partitioned->ToString();

  const auto select_from_owner = op::Select(op::Broadcast(op::Compare()),
                                            op::Constant(), op::Broadcast());
  const auto gathered = AllOf(op::Copy(op::AllReduce(select_from_owner)),
                              op::Shape("s32[2,3]"));
  EXPECT_THAT(root, op::Copy(gathered));
}
125
// A constant on device 0 resharded to a [2,1] tiling. Expect the value to be
// combined across partitions and then dynamic-sliced at a per-partition
// offset.
TEST_F(SpmdPartitioningTest, SingleDeviceToTiled) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={maximal device=0}
  ROOT %copy = s32[2,3]{1,0} copy(%constant),
    sharding={devices=[2,1]1,0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto combined = op::AllReduce(op::Select(
      op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
      op::Constant(), op::Broadcast()));
  const auto partition_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Copy(op::DynamicSlice(combined, partition_offset,
                                              op::Constant())),
                    op::Shape("s32[1,3]")));
}
151
// Tiled -> replicated: each partition writes its [1,3] shard into a full-size
// broadcast at its own offset, and an all-reduce combines the results.
TEST_F(SpmdPartitioningTest, TiledToReplicated) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));

  const auto local_shard = AllOf(op::Constant(), op::Shape("s32[1,3]"));
  const auto shard_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  const auto placed_shard =
      AllOf(op::DynamicUpdateSlice(op::Broadcast(), local_shard, shard_offset,
                                   op::Constant()),
            op::Shape("s32[2,3]"));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              op::Copy(op::AllReduce(placed_shard)));
}
173
// Tiled -> device 0: same as tiled -> replicated (shard placement +
// all-reduce), followed by an extra copy.
TEST_F(SpmdPartitioningTest, TiledToSingleDevice) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %constant = s32[2,3]{1,0} constant({{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  ROOT %copy = s32[2,3]{1,0} copy(%constant), sharding={maximal device=0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));

  const auto local_shard = AllOf(op::Constant(), op::Shape("s32[1,3]"));
  const auto shard_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  const auto placed_shard =
      AllOf(op::DynamicUpdateSlice(op::Broadcast(), local_shard, shard_offset,
                                   op::Constant()),
            op::Shape("s32[2,3]"));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              op::Copy(op::Copy(op::AllReduce(placed_shard))));
}
195
// Resharding dim 0 -> dim 1 with evenly divisible sizes: expect
// reshape -> all-to-all -> transpose -> reshape with no padding.
TEST_F(SpmdPartitioningTest, TiledToTiledEven) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param= s32[8,2]{1,0} parameter(0), sharding={devices=[2,1]0,1}
  ROOT %copy = s32[8,2]{1,0} copy(%param), sharding={devices=[1,2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto split_param =
      AllOf(op::Reshape(op::Parameter()), op::Shape("s32[4,2,1]"));
  const auto exchanged = op::Transpose(op::AllToAll(split_param));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Copy(op::Reshape(exchanged)), op::Shape("s32[8,1]")));
}
215
// Resharding dim 1 -> dim 0 where dim 0 (size 7) does not divide evenly
// across the 2 partitions. Expect the input to be padded, exchanged with
// all-to-all, and then sliced back down to the per-partition shape.
TEST_F(SpmdPartitioningTest, TiledToTiledUneven) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param= f32[7,31,128]{2,1,0} parameter(0), sharding={devices=[1,2,1]0,1}
  ROOT %copy = f32[7,31,128]{2,1,0} copy(%param), sharding={devices=[2,1,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  HloInstruction* root = module->entry_computation()->root_instruction();
  const auto padded_input = AllOf(op::Pad(), op::Shape("f32[8,16,128]"));
  EXPECT_THAT(
      root,
      AllOf(op::Copy(op::Slice(op::Reshape(
                op::Transpose(op::AllToAll(op::Reshape(padded_input)))))),
            // ceil(7 / 2) = 4 rows per partition on the newly sharded dim 0;
            // dim 1 is back to its full size of 31 after the slice.
            op::Shape("f32[4,31,128]")));
}
234
// Tuple elements living on device 1 are moved to device 0. Each tuple operand
// is expected to be a copy of an all-reduced, partition-id-driven select over
// the corresponding get-tuple-element.
TEST_F(SpmdPartitioningTest, GetTupleElementSwapDevice) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param.0 = (f32[2,3]{1,0}, u32[]) parameter(0),
    sharding={{maximal device=1}, {maximal device=1}}
  %gte.0 = f32[2,3]{1,0} get-tuple-element(%param.0), index=0,
    sharding={maximal device=0}
  %gte.1 = u32[] get-tuple-element(%param.0), index=1,
    sharding={maximal device=0}
  ROOT %tuple = (f32[2,3]{1,0}, u32[]) tuple(%gte.0, %gte.1),
    sharding={{maximal device=0},{maximal device=0}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* const tuple =
      partitioned->entry_computation()->root_instruction();
  ASSERT_THAT(tuple, op::Tuple());

  const auto moved_element = op::Copy(op::AllReduce(op::Select(
      op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
      op::GetTupleElement(op::Parameter()), op::Broadcast())));
  for (int64_t operand_index = 0; operand_index < 2; ++operand_index) {
    SCOPED_TRACE(::testing::Message() << "tuple operand " << operand_index);
    EXPECT_THAT(tuple->operand(operand_index), moved_element);
  }
}
264
// Replicated tuple elements resharded to a [2,1] tiling: each tuple operand is
// a dynamic-slice of its get-tuple-element at a per-partition offset.
TEST_F(SpmdPartitioningTest, GetTupleElementTiled) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  param.0 = (f32[2,3]{1,0}, u32[2,3]{1,0}) parameter(0),
    sharding={{replicated}, {replicated}}
  gte.0 = f32[2,3]{1,0} get-tuple-element(param.0), index=0,
    sharding={devices=[2,1]0,1}
  gte.1 = u32[2,3]{1,0} get-tuple-element(param.0), index=1,
    sharding={devices=[2,1]0,1}
  ROOT %tuple = (f32[2,3]{1,0}, u32[2,3]{1,0}) tuple(gte.0, gte.1),
    sharding={{devices=[2,1]0,1},{devices=[2,1]0,1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* const tuple =
      partitioned->entry_computation()->root_instruction();
  ASSERT_THAT(tuple, op::Tuple());

  const auto partition_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  const auto sliced_element = op::DynamicSlice(
      op::GetTupleElement(op::Parameter()), partition_offset, op::Constant());
  for (int64_t operand_index = 0; operand_index < 2; ++operand_index) {
    SCOPED_TRACE(::testing::Message() << "tuple operand " << operand_index);
    EXPECT_THAT(tuple->operand(operand_index), sliced_element);
  }
}
295
// An evenly tiled infeed produces a [4,2] shard per partition. Resharding it to
// device 0 places each shard into a broadcast and all-reduces.
TEST_F(SpmdPartitioningTest, TiledInfeed) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = (f32[8,2]{1,0}, token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {maximal device=0}}
  ROOT infeed.data = f32[8,2]{1,0} get-tuple-element(infeed), index=0,
    sharding={maximal device=0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));

  const auto infeed_shard = op::GetTupleElement(
      AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])")));
  const auto shard_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              op::Copy(op::AllReduce(op::DynamicUpdateSlice(
                  op::Broadcast(), infeed_shard, shard_offset,
                  op::Constant()))));
}
319
// An unevenly tiled infeed (9 rows over 2 partitions) becomes a conditional on
// the partition id. Branch 0 infeeds [5,2] directly; branch 1 infeeds [4,2]
// and pads it to [5,2].
TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = (f32[9,2]{1,0}, token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {maximal device=0}}
  ROOT infeed.data = f32[9,2]{1,0} get-tuple-element(infeed), index=0,
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* const gte =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(gte, AllOf(op::Shape("f32[5,2]"),
                         op::GetTupleElement(op::Conditional(
                             op::Convert(op::PartitionId()), op::AfterAll(),
                             op::AfterAll()))));

  const auto& branches = gte->operand(0)->called_computations();
  EXPECT_THAT(
      branches[0]->root_instruction(),
      AllOf(op::Shape("(f32[5,2], token[])"), op::Infeed(op::Parameter())));

  const auto short_infeed =
      AllOf(op::Shape("(f32[4,2], token[])"), op::Infeed(op::Parameter()));
  const auto padded_data =
      op::Pad(op::GetTupleElement(short_infeed), op::Constant());
  EXPECT_THAT(branches[1]->root_instruction(),
              AllOf(op::Shape("(f32[5,2], token[])"),
                    op::Tuple(padded_data, op::GetTupleElement(short_infeed))));
}
350
// Like UnevenTiledInfeed, but the infeed data is a tuple whose second element
// is replicated. Only the tiled [9,2] element needs padding in branch 1.
TEST_F(SpmdPartitioningTest, UnevenTiledTupleInfeed) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0),
    sharding={{devices=[2,1]0,1}, {replicated}, {maximal device=0}}
  ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed),
    index=0, sharding={{devices=[2,1]0,1}, {replicated}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* const gte =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(gte, AllOf(op::Shape("(f32[5,2], f32[2])"),
                         op::GetTupleElement(op::Conditional(
                             op::Convert(op::PartitionId()), op::AfterAll(),
                             op::AfterAll()))));

  const auto& branches = gte->operand(0)->called_computations();
  EXPECT_THAT(branches[0]->root_instruction(),
              AllOf(op::Shape("((f32[5,2], f32[2]), token[])"),
                    op::Infeed(op::Parameter())));

  const auto short_infeed = AllOf(op::Shape("((f32[4,2], f32[2]), token[])"),
                                  op::Infeed(op::Parameter()));
  const auto infeed_data = op::GetTupleElement(short_infeed);
  const auto padded_data = op::Tuple(
      op::Pad(op::GetTupleElement(infeed_data), op::Constant()),
      op::GetTupleElement(infeed_data));
  EXPECT_THAT(branches[1]->root_instruction(),
              AllOf(op::Shape("((f32[5,2], f32[2]), token[])"),
                    op::Tuple(padded_data, op::GetTupleElement(short_infeed))));
}
385
// Tuple infeed whose elements live on different devices. Each conditional
// branch infeeds only its own element (the other is an empty tuple) and
// fills the missing element with a broadcast constant.
TEST_F(SpmdPartitioningTest, MixedTupleInfeed) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  token0 = token[] after-all(), sharding={maximal device=0}
  infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0),
    sharding={{maximal device=0}, {maximal device=1}, {maximal device=0}}
  ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed),
    index=0, sharding={{maximal device=0}, {maximal device=1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* const gte =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(gte, AllOf(op::Shape("(f32[9,2], f32[2])"),
                         op::GetTupleElement(op::Conditional(
                             op::Convert(op::PartitionId()), op::AfterAll(),
                             op::AfterAll()))));

  const auto& branches = gte->operand(0)->called_computations();
  const auto filler = op::Broadcast(op::Constant());

  const auto infeed_b0 = AllOf(op::Shape("((f32[9,2], ()), token[])"),
                               op::Infeed(op::Parameter()));
  const auto data_b0 = op::Tuple(
      op::GetTupleElement(op::GetTupleElement(infeed_b0)), filler);
  EXPECT_THAT(branches[0]->root_instruction(),
              AllOf(op::Shape("((f32[9,2], f32[2]), token[])"),
                    op::Tuple(data_b0, op::GetTupleElement(infeed_b0))));

  const auto infeed_b1 =
      AllOf(op::Shape("(((), f32[2]), token[])"), op::Infeed(op::Parameter()));
  const auto data_b1 = op::Tuple(
      filler, op::GetTupleElement(op::GetTupleElement(infeed_b1)));
  EXPECT_THAT(branches[1]->root_instruction(),
              AllOf(op::Shape("((f32[9,2], f32[2]), token[])"),
                    op::Tuple(data_b1, op::GetTupleElement(infeed_b1))));
}
422
// Full reduction of an unevenly tiled [3,3] constant. Each partition selects
// between its padded [2,3] shard and a broadcast constant, using an
// iota-based comparison. It then reduces locally and all-reduces.
TEST_F(SpmdPartitioningTest, TiledToReplicatedReduce) {
  const absl::string_view hlo_text = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  constant.1 = f32[] constant(0), sharding={replicated}
  ROOT reduce = f32[] reduce(constant, constant.1), dimensions={0,1},
    to_apply=sum, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto index_compare =
      op::Compare(op::Add(op::Iota(), op::Broadcast(op::Reshape())),
                  op::Broadcast(op::Constant()));
  const auto padded_shard =
      AllOf(op::Shape("f32[2,3]{1,0}"),
            op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
                             op::Reshape(), op::Constant()));
  const auto masked_shard = op::Select(index_compare, padded_shard,
                                       op::Broadcast(op::Constant()));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              op::AllReduce(op::Reduce(masked_shard, op::Constant())));
}
456
// Elementwise ops on [2,1]-tiled data. Every constant operand, tiled or
// replicated, becomes a dynamic-slice of its padded value.
TEST_F(SpmdPartitioningTest, TiledElementwise) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
    sharding={devices=[2,1]0,1}
  constant.1 = f32[3,3]{1,0} constant({{2,2,2},{2,2,2},{2,2,2}}),
    sharding={replicated}
  multiply = f32[3,3]{1,0} multiply(constant, constant.1),
    sharding={devices=[2,1]0,1}
  ROOT add = f32[3,3]{1,0} add(multiply, constant.1),
    sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto local_slice =
      op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), op::Reshape(),
                       op::Constant());
  EXPECT_THAT(
      partitioned->entry_computation()->root_instruction(),
      AllOf(op::Shape("f32[2,3]{1,0}"),
            op::Add(op::Multiply(local_slice, local_slice), local_slice)));
}
487
// An all-reduce on tiled data stays an all-reduce over the local [2,3] shard
// of parameter 0.
TEST_F(SpmdPartitioningTest, TiledAllReduce) {
  const absl::string_view hlo_text = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  parameter = f32[3,3]{1,0} parameter(0), sharding={devices=[2,1]0,1}
  ROOT all-reduce = f32[3,3]{1,0} all-reduce(parameter), to_apply=sum,
    replica_groups={}, sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto expected =
      AllOf(op::Shape("f32[2,3]{1,0}"), op::AllReduce(op::Parameter(0)));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(), expected);
}
510
// Only the newly added broadcast dimension is sharded, so the replicated
// operand is broadcast directly to the smaller local shape.
TEST_F(SpmdPartitioningTest, BroadcastOnlyNewDimsSharded) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
    sharding={replicated}
  ROOT broadcast = f32[3,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[2,1,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto expected = AllOf(op::Shape("f32[2,4,3]{2,1,0}"),
                              op::Broadcast(op::Constant()));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(), expected);
}
528
// A dimension inherited from the operand is sharded, so the operand is
// dynamic-sliced before being broadcast.
TEST_F(SpmdPartitioningTest, BroadcastOnlyOldDimsSharded) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
    sharding={replicated}
  ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[1,2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto sliced_operand =
      op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[4,2,3]{2,1,0}"),
                    op::Broadcast(sliced_operand)));
}
547
// Both a new dimension and an operand dimension are sharded: slice the operand
// to [2,3], then broadcast to the [2,2,3] local shape.
TEST_F(SpmdPartitioningTest, BroadcastBothOldAndNewDimsSharded) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1},{1,1,1}}),
    sharding={replicated}
  ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[2,2,1]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const auto sliced_operand =
      AllOf(op::Shape("f32[2,3]{1,0}"),
            op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[2,2,3]{2,1,0}"),
                    op::Broadcast(sliced_operand)));
}
569
// Partially replicated shardings on both the operand and the broadcast. The
// local [4,2] parameter is broadcast directly to the [2,4,2] local shape.
TEST_F(SpmdPartitioningTest,
       BroadcastBothOldAndNewDimsShardedPartiallySharded) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  param = f32[4,3] parameter(0),
    sharding={devices=[1,2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
  ROOT broadcast = f32[4,4,3] broadcast(param), dimensions={1,2},
    sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  const auto local_param = AllOf(op::Shape("f32[4,2]"), op::Parameter(0));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[2,4,2]"), op::Broadcast(local_param)));
}
590
// Convolution whose first two dims are both partitioned. The LHS gets left and
// right halos from collective-permutes and is windowed with a select. The RHS
// is resharded back to its full [16,6,6,16,32] shape via
// dynamic-update-slice + all-reduce.
TEST_F(SpmdPartitioningTest,
       ConvWithParallelDimAndNonParallelSpatialDimPartitioned) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,12,12,24,32] parameter(0)
  %lhs.copy = f32[32,12,12,24,32] copy(%lhs),
    sharding={devices=[2,2,1,1,1]0,1,2,3}
  %rhs = f32[32,6,6,16,32] parameter(1)
  %rhs.copy = f32[32,6,6,16,32] copy(%rhs),
    sharding={devices=[2,2,1,1,1]0,1,2,3}
  ROOT %conv = f32[32,7,7,24,16] convolution(%lhs.copy, %rhs.copy),
    dim_labels=012bf_012oi->012bf,
    window={size=32x6x6 stride=31x1x1 lhs_dilate=32x1x1},
    sharding={devices=[2,2,1,1,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // Both operands are copies of a parameter sliced on its first two dims.
  const auto sliced_param =
      op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
                       op::Constant(), op::Constant(), op::Constant());
  const auto lhs =
      AllOf(op::Copy(sliced_param), op::Shape("f32[16,6,12,24,32]"));
  const auto rhs =
      AllOf(op::Copy(sliced_param), op::Shape("f32[16,3,6,16,32]"));

  const auto resharded_rhs = AllOf(
      op::Shape("f32[16,6,6,16,32]"),
      op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), rhs, op::Constant(),
                                           op::Reshape(), op::Constant(),
                                           op::Constant(), op::Constant())));

  const auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
                               op::Shape("f32[16,2,12,24,32]"));
  const auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
                                op::Shape("f32[16,3,12,24,32]"));
  const auto windowed_lhs = op::Select(
      op::Compare(),
      op::DynamicSlice(op::Concatenate(left_halo, lhs, right_halo),
                       op::Constant(), op::Add(), op::Constant(),
                       op::Constant(), op::Constant()),
      op::Broadcast());

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Convolution(windowed_lhs, resharded_rhs),
                    op::Shape("f32[16,4,7,24,16]")));
}
643
// The operand's tiled sharding on dim 0 maps onto broadcast dim 1. Expect a
// dynamic-slice of the constant feeding a broadcast to [4,2,3].
TEST_F(SpmdPartitioningTest, BroadcastPropagateTiledSharding) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  constant = f32[4,3]{1,0} constant({{1,1,1},{1,4,1},{1,3,1},{1,2,1}}),
    sharding={devices=[2,1]0,1}
  ROOT broadcast = f32[4,4,3]{2,1,0} broadcast(constant), dimensions={1,2},
    sharding={devices=[1,2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto sliced_operand =
      op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("f32[4,2,3]{2,1,0}"),
                    op::Broadcast(sliced_operand)));
}
662
// An outfeed pinned to device 0 becomes a conditional on the partition id.
// Branch 0 performs the outfeed; branch 1 only produces a token.
TEST_F(SpmdPartitioningTest, OutfeedSingleDevice) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = f32[1024]{0} parameter(0), sharding={maximal device=0}
  outfeed = token[] outfeed(data, token.0), sharding={maximal device=0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  HloInstruction* conditional =
      partitioned->entry_computation()->root_instruction();

  const auto branch_operand = op::Tuple(op::Parameter(0), op::AfterAll());
  EXPECT_THAT(conditional,
              AllOf(op::Shape("token[]"),
                    op::Conditional(
                        op::Compare(op::PartitionId(), op::Constant()),
                        branch_operand, branch_operand)));

  const auto outfeed_branch =
      op::Outfeed(op::GetTupleElement(op::Parameter(), 0),
                  op::GetTupleElement(op::Parameter(), 1));
  EXPECT_THAT(conditional->branch_computation(0)->root_instruction(),
              AllOf(op::Shape("token[]"), outfeed_branch));
  EXPECT_THAT(conditional->branch_computation(1)->root_instruction(),
              AllOf(op::Shape("token[]"), op::AfterAll()));
}
691
// An evenly tiled outfeed stays a plain outfeed of the local parameter shard.
TEST_F(SpmdPartitioningTest, OutfeedEvenlyTiled) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = f32[1024]{0} parameter(0), sharding={devices=[2]0,1}
  ROOT outfeed = token[] outfeed(data, token.0), sharding={devices=[2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto expected = AllOf(op::Shape("token[]"),
                              op::Outfeed(op::Parameter(), op::AfterAll()));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(), expected);
}
708
// An evenly tiled tuple outfeed stays a plain outfeed, and the explicit
// outfeed_shape layouts ({0,1} and {0}) are preserved.
TEST_F(SpmdPartitioningTest, OutfeedTupleEvenlyTiled) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
  ROOT outfeed = token[] outfeed(data, token.0),
    outfeed_shape=(f32[1024,2]{0,1}, f32[2]{0}), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();
  const HloInstruction* const outfeed =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(outfeed, AllOf(op::Shape("token[]"),
                             op::Outfeed(op::Parameter(), op::AfterAll())));

  const Layout expected_layouts[] = {LayoutUtil::MakeLayout({0, 1}),
                                     LayoutUtil::MakeLayout({0})};
  for (int element = 0; element < 2; ++element) {
    EXPECT_TRUE(LayoutUtil::Equal(
        outfeed->outfeed_shape().tuple_shapes(element).layout(),
        expected_layouts[element]));
  }
}
734
// A tuple outfeed mixing tiled and replicated elements stays a plain outfeed
// of the local parameter.
TEST_F(SpmdPartitioningTest, OutfeedReplicated) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = (f32[1024,2]{1,0}, f32[2]{0}) parameter(0), sharding={{devices=[2,1]0,1},
    {replicated}}
  ROOT outfeed = token[] outfeed(data, token.0), sharding={{devices=[2,1]0,1},
    {replicated}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto expected = AllOf(op::Shape("token[]"),
                              op::Outfeed(op::Parameter(), op::AfterAll()));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(), expected);
}
753
// An unevenly tiled tuple outfeed becomes a conditional. Branch 0 outfeeds the
// larger (512,2)/(2) shards and branch 1 the smaller (511,2)/(1) shards. Both
// keep the explicit outfeed_shape layouts ({0,1} and {0}).
TEST_F(SpmdPartitioningTest, OutfeedUnevenlyTiled) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  token.0 = token[] after-all()
  data = (f32[1023,2]{1,0}, f32[3]{0}) parameter(0), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
  outfeed = token[] outfeed(data, token.0),
    outfeed_shape=(f32[1023,2]{0,1}, f32[3]{0}), sharding={{devices=[2,1]0,1},
    {devices=[2]0,1}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* const conditional =
      partitioned->entry_computation()->root_instruction();
  const auto branch_operand = op::Tuple(op::Parameter(), op::AfterAll());
  EXPECT_THAT(conditional,
              AllOf(op::Shape("token[]"),
                    op::Conditional(op::Convert(), branch_operand,
                                    branch_operand)));

  const auto& branches = conditional->called_computations();
  const auto larger_data =
      AllOf(op::Shape("(f32[512,2], f32[2])"), op::GetTupleElement());
  EXPECT_THAT(branches[0]->root_instruction(),
              AllOf(op::Shape("token[]"),
                    op::Outfeed(larger_data, op::GetTupleElement())));

  const auto smaller_data = AllOf(op::Shape("(f32[511,2], f32[1])"), op::Tuple());
  EXPECT_THAT(branches[1]->root_instruction(),
              AllOf(op::Shape("token[]"),
                    op::Outfeed(smaller_data, op::GetTupleElement())));

  const Layout expected_layouts[] = {LayoutUtil::MakeLayout({0, 1}),
                                     LayoutUtil::MakeLayout({0})};
  for (int branch = 0; branch < 2; ++branch) {
    const HloInstruction* const branch_outfeed =
        conditional->called_computations()[branch]->root_instruction();
    for (int element = 0; element < 2; ++element) {
      EXPECT_TRUE(LayoutUtil::Equal(
          branch_outfeed->outfeed_shape().tuple_shapes(element).layout(),
          expected_layouts[element]));
    }
  }
}
806
// Reduce-window over a replicated constant whose result is tiled 2 ways.
// Each partition is expected to pad the whole input, dynamic-slice out its
// window range at a partition-dependent offset, and reduce locally.
TEST_F(SpmdPartitioningTest, ReduceWindowReplicatedInput) {
  const absl::string_view hlo_text = R"(
HloModule module

sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}

ENTRY entry {
constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}),
sharding={replicated}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[3,2]{1,0} reduce-window(constant, constant.1),
window={size=3x1 stride=2x1 pad=1_0x0_0}, to_apply=sum,
sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // The replicated input after padding.
  const auto padded_input = AllOf(op::Shape("f32[9,2]{1,0}"),
                                  op::Pad(op::Constant(), op::Constant()));
  // Start index along dim 0: partition id (reshaped) times a constant.
  const auto partition_offset = op::Multiply(op::Reshape(), op::Constant());
  const auto local_window =
      op::DynamicSlice(padded_input, partition_offset, op::Constant());

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
                          op::ReduceWindow(local_window, op::Constant())));
}
839
// Reduce-window with padding only on the high side (pad=0_1) over a 2-way
// tiled input. Each shard is expected to receive a halo from the right, pad
// and dynamic-slice the concatenation, and mask it with a select before the
// local reduce-window.
TEST_F(SpmdPartitioningTest, ReduceWindowTiledNegativeLeftHalo) {
  const absl::string_view hlo_text = R"(
HloModule module

sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}

ENTRY entry {
constant = f32[6,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1},{1,2},{2,2}}),
sharding={devices=[2,1]0,1}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT %reduce-window = f32[3,2]{1,0} reduce-window(%constant, %constant.1),
window={size=3x1 stride=2x1 pad=0_1x0_0}, to_apply=sum,
sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // This partition's slice of the tiled constant.
  const auto local_shard =
      op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  // Rows obtained from the neighbouring shard on the right.
  const auto halo_from_right =
      AllOf(op::Shape("f32[2,2]{1,0}"),
            op::CollectivePermute(op::Slice(local_shard)));
  const auto padded_with_halo =
      AllOf(op::Shape("f32[6,2]{1,0}"),
            op::Pad(op::Concatenate(local_shard, halo_from_right),
                    op::Constant()));
  const auto window_data =
      op::DynamicSlice(padded_with_halo, op::Reshape(), op::Constant());
  // Row index in the padded global input.
  const auto row_index = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  const auto masked_input =
      op::Select(op::Compare(row_index, op::Broadcast(op::Constant())),
                 window_data, op::Broadcast(op::Constant()));

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
                          op::ReduceWindow(masked_input, op::Constant())));
}
880
// Reduce-window over a 5-way tiled parameter with a large low pad (3_0).
// Per the test name, the needed halo reaches past the immediate neighbour.
// The expected lowering assembles the window from two collective-permuted
// pieces plus a slice of the local shard, then masks it.
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideHaloBeyondNeighbor) {
  const absl::string_view hlo_text = R"(
HloModule module

sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}

ENTRY entry {
param = f32[9,2] parameter(0), sharding={devices=[5,1]0,1,2,3,4}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[5,2]{1,0} reduce-window(param, constant.1),
window={size=4x1 stride=2x1 pad=3_0x0_0}, to_apply=sum,
sharding={devices=[5,1]0,1,2,3,4}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/5));
  VLOG(1) << module->ToString();

  const auto local_param = op::Parameter(0);
  // One halo piece is a slice of a remote shard; the other is a whole shard.
  const auto partial_halo =
      AllOf(op::Shape("f32[1,2]"),
            op::CollectivePermute(op::Slice(local_param)));
  const auto whole_shard_halo =
      AllOf(op::Shape("f32[2,2]"), op::CollectivePermute(local_param));
  const auto window_data =
      AllOf(op::Shape("f32[4,2]"),
            op::Concatenate(partial_halo, whole_shard_halo,
                            op::Slice(local_param)));
  const auto row_index =
      op::Add(op::Iota(), op::Broadcast(op::Multiply()));
  const auto masked_input =
      op::Select(op::Compare(row_index, op::Broadcast(op::Constant())),
                 window_data, op::Broadcast(op::Constant()));

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
                          op::ReduceWindow(masked_input, op::Constant())));
}
916
// Reduce-window over a 3-way tiled constant where only a right-hand halo is
// exchanged. The window data is masked by the And of two comparisons
// (presumably a lower and an upper bound on the padded row index).
TEST_F(SpmdPartitioningTest, ReduceWindowTiledOneSideUnequalHalo) {
  const absl::string_view hlo_text = R"(
HloModule module

sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}

ENTRY entry {
constant = f32[9,2]{1,0} constant(
{{1,1},{1,4},{2,1},{3,1},{1,2},{2,2},{4,1},{1,2},{2,1}}),
sharding={devices=[3,1]0,1,2}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[5,2]{1,0} reduce-window(constant, constant.1),
window={size=3x1 stride=2x1 pad=1_1x0_0}, to_apply=sum,
sharding={devices=[3,1]0,1,2}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/3));
  VLOG(1) << module->ToString();

  const auto local_shard =
      op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  const auto halo_from_right =
      AllOf(op::Shape("f32[2,2]{1,0}"),
            op::CollectivePermute(op::Slice(local_shard)));
  const auto padded_with_halo =
      AllOf(op::Shape("f32[7,2]{1,0}"),
            op::Pad(op::Concatenate(local_shard, halo_from_right),
                    op::Constant()));
  const auto window_data =
      op::DynamicSlice(padded_with_halo, op::Reshape(), op::Constant());
  const auto row_index = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  const auto in_bounds =
      op::And(op::Compare(row_index, op::Broadcast(op::Constant())),
              op::Compare(row_index, op::Broadcast(op::Constant())));
  const auto masked_input =
      op::Select(in_bounds, window_data, op::Broadcast(op::Constant()));

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,2]{1,0}"),
                          op::ReduceWindow(masked_input, op::Constant())));
}
959
// Reduce-window over a 2-way tiled constant with symmetric padding (2_2).
// Each shard is expected to receive one row of halo from each side before
// the data is padded, dynamic-sliced to f32[5,2], and masked.
TEST_F(SpmdPartitioningTest, ReduceWindowTiledTwoSideHalo) {
  const absl::string_view hlo_text = R"(
HloModule module

sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}

ENTRY entry {
constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}}),
sharding={devices=[2,1]0,1}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[2,2]{1,0} reduce-window(constant, constant.1),
window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum,
sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto local_shard =
      op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant());
  const auto halo_from_left =
      AllOf(op::Shape("f32[1,2]{1,0}"),
            op::CollectivePermute(op::Slice(local_shard)));
  const auto halo_from_right =
      AllOf(op::Shape("f32[1,2]{1,0}"),
            op::CollectivePermute(op::Slice(local_shard)));
  const auto padded_with_halos =
      AllOf(op::Shape("f32[6,2]{1,0}"),
            op::Pad(op::Concatenate(halo_from_left, local_shard,
                                    halo_from_right),
                    op::Constant()));
  const auto window_data =
      AllOf(op::Shape("f32[5,2]{1,0}"),
            op::DynamicSlice(padded_with_halos, op::Reshape(),
                             op::Constant()));
  const auto row_index = op::Add(
      op::Iota(), op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
  const auto in_bounds =
      op::And(op::Compare(row_index, op::Broadcast(op::Constant())),
              op::Compare(row_index, op::Broadcast(op::Constant())));
  const auto masked_input =
      op::Select(in_bounds, window_data, op::Broadcast(op::Constant()));

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[1,2]{1,0}"),
                          op::ReduceWindow(masked_input, op::Constant())));
}
1005
// 2-D reduce-window over an infeed tiled [2,2,1,1] across 4 partitions. The
// expected lowering handles dim 0 first (halo exchange, pad, dynamic-slice,
// mask). It then repeats the same steps on dim 1 using the dim-0 result.
TEST_F(SpmdPartitioningTest, ReduceWindowTiled2D) {
  const absl::string_view hlo_text = R"(
HloModule module

sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}

ENTRY entry {
token0 = token[] after-all(), sharding={maximal device=0}
infeed = (f32[4,4,2,2]{3,2,1,0}, token[]) infeed(token0),
sharding={{devices=[2,2,1,1]0,1,2,3}, {maximal device=0}}
infeed.data = f32[4,4,2,2]{3,2,1,0} get-tuple-element(infeed), index=0,
sharding={devices=[2,2,1,1]0,1,2,3}
constant = f32[] constant(0), sharding={replicated}
ROOT reduce-window = f32[2,2,2,2]{3,2,1,0} reduce-window(infeed.data, constant),
window={size=5x5x1x1 stride=3x3x1x1 pad=2_2x2_2x0_0x0_0}, to_apply=sum,
sharding={devices=[2,2,1,1]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  // The masking select is identical for both dimensions. Only the sliced
  // window it wraps differs.
  auto mask_padding = [](auto window_data) {
    auto padded_index = op::Add(
        op::Iota(),
        op::Broadcast(op::Multiply(op::Reshape(), op::Constant())));
    return op::Select(
        op::And(op::Compare(padded_index, op::Broadcast(op::Constant())),
                op::Compare(padded_index, op::Broadcast(op::Constant()))),
        window_data, op::Broadcast(op::Constant()));
  };

  const auto local_shard = AllOf(op::Shape("f32[2,2,2,2]{3,2,1,0}"),
                                 op::GetTupleElement(op::Infeed()));

  // Dimension 0. Left and right halos have the same expected form.
  const auto row_halo = AllOf(op::Shape("f32[1,2,2,2]{3,2,1,0}"),
                              op::CollectivePermute(op::Slice(local_shard)));
  const auto rows_window = op::DynamicSlice(
      AllOf(op::Shape("f32[6,2,2,2]{3,2,1,0}"),
            op::Pad(op::Concatenate(row_halo, local_shard, row_halo),
                    op::Constant())),
      op::Reshape(), op::Constant(), op::Constant(), op::Constant());
  const auto rows_done =
      AllOf(op::Shape("f32[5,2,2,2]{3,2,1,0}"), mask_padding(rows_window));

  // Dimension 1, built on top of the dim-0 result.
  const auto col_halo = AllOf(op::Shape("f32[5,1,2,2]{3,2,1,0}"),
                              op::CollectivePermute(op::Slice(rows_done)));
  const auto cols_window = op::DynamicSlice(
      AllOf(op::Shape("f32[5,6,2,2]{3,2,1,0}"),
            op::Pad(op::Concatenate(col_halo, rows_done, col_halo),
                    op::Constant())),
      op::Constant(), op::Reshape(), op::Constant(), op::Constant());
  const auto cols_done =
      AllOf(op::Shape("f32[5,5,2,2]{3,2,1,0}"), mask_padding(cols_window));

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[1,1,2,2]{3,2,1,0}"),
                          op::ReduceWindow(cols_done, op::Constant())));
}
1072
// Convolution whose LHS and output are both split on spatial dim 1
// (b01f layout) while the RHS is replicated. The LHS shard is expected to be
// extended with halos of unequal width on each side and masked before the
// local convolution.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicated) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[128,224,224,3] parameter(0)
%lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
sharding={devices=[1,2,1,1]0,1}
%rhs = f32[7,7,3,64] parameter(1)
%rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
sharding={replicated}
ROOT %conv = f32[128,112,112,64] convolution(
f32[128,224,224,3] %lhs.copy,
f32[7,7,3,64] %rhs.copy),
window={size=7x7 stride=2x2 pad=3_3x3_3},
dim_labels=b01f_01io->b01f,
sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto lhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                      op::Reshape(), op::Constant(),
                                      op::Constant())),
            op::Shape("f32[128,112,224,3]"));
  const auto rhs_full =
      AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));

  const auto low_halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                              op::Shape("f32[128,3,224,3]"));
  const auto high_halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                               op::Shape("f32[128,2,224,3]"));
  const auto lhs_with_halos =
      op::Select(op::And(), op::Concatenate(low_halo, lhs_shard, high_halo),
                 op::Broadcast());

  const auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(lhs_with_halos, rhs_full),
                          op::Shape("f32[128,56,112,64]")));
}
1115
// Like ConvolutionLhsTiledRhsReplicated, but the LHS arrives split on dim 0
// (batch, per dim_labels b01f) while the output is split on dim 1. The LHS is
// expected to be resharded with an all-to-all before the usual halo exchange.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedNeedReshard) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[128,224,224,3] parameter(0)
%lhs.copy = f32[128,224,224,3] copy(f32[128,224,224,3] %lhs),
sharding={devices=[2,1,1,1]0,1}
%rhs = f32[7,7,3,64] parameter(1)
%rhs.copy = f32[7,7,3,64] copy(f32[7,7,3,64] %rhs),
sharding={replicated}
ROOT %conv = f32[128,112,112,64] convolution(
f32[128,224,224,3] %lhs.copy,
f32[7,7,3,64] %rhs.copy),
window={size=7x7 stride=2x2 pad=3_3x3_3},
dim_labels=b01f_01io->b01f,
sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Input sharding: split along dim 0.
  const auto lhs_dim0_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                      op::Constant(), op::Constant(),
                                      op::Constant())),
            op::Shape("f32[64,224,224,3]"));
  // Reshard to a split along dim 1 via reshape -> all-to-all -> transpose ->
  // reshape.
  const auto exchanged = AllOf(op::AllToAll(op::Reshape(lhs_dim0_shard)),
                               op::Shape("f32[64,2,112,224,3]"));
  const auto lhs_dim1_shard = AllOf(op::Reshape(op::Transpose(exchanged)),
                                    op::Shape("f32[128,112,224,3]"));

  const auto rhs_full =
      AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));

  const auto low_halo =
      AllOf(op::CollectivePermute(op::Slice(lhs_dim1_shard)),
            op::Shape("f32[128,3,224,3]"));
  const auto high_halo =
      AllOf(op::CollectivePermute(op::Slice(lhs_dim1_shard)),
            op::Shape("f32[128,2,224,3]"));
  const auto lhs_with_halos = op::Select(
      op::And(), op::Concatenate(low_halo, lhs_dim1_shard, high_halo),
      op::Broadcast());

  const auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(lhs_with_halos, rhs_full),
                          op::Shape("f32[128,56,112,64]")));
}
1164
// Same partitioning pattern as ConvolutionLhsTiledRhsReplicated, but the LHS
// uses the 01fb layout. Its first spatial dimension is therefore dim 0, which
// is the dimension that gets split and extended with halos.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsReplicatedReordered) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[224,224,3,128] parameter(0)
%lhs.copy = f32[224,224,3,128] copy(%lhs), sharding={devices=[2,1,1,1]0,1}
%rhs = f32[7,7,3,64] parameter(1)
%rhs.copy = f32[7,7,3,64] copy(%rhs), sharding={replicated}
ROOT %conv = f32[128,112,112,64] convolution(%lhs.copy, %rhs.copy),
window={size=7x7 stride=2x2 pad=3_3x3_3},
dim_labels=01fb_01io->b01f,
sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto lhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                      op::Constant(), op::Constant(),
                                      op::Constant())),
            op::Shape("f32[112,224,3,128]"));
  const auto rhs_full =
      AllOf(op::Copy(op::Parameter()), op::Shape("f32[7,7,3,64]"));

  const auto low_halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                              op::Shape("f32[3,224,3,128]"));
  const auto high_halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                               op::Shape("f32[2,224,3,128]"));
  const auto lhs_with_halos =
      op::Select(op::And(), op::Concatenate(low_halo, lhs_shard, high_halo),
                 op::Broadcast());

  const auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(lhs_with_halos, rhs_full),
                          op::Shape("f32[128,56,112,64]")));
}
1203
1204 // (stride * per_shard_window_count) % dilation == 0
TEST_F(SpmdPartitioningTest,ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated)1205 TEST_F(SpmdPartitioningTest,
1206 ConvolutionBaseDilationSameStartPatternLhsTiledRhsReplicated) {
1207 absl::string_view hlo_string = R"(
1208 HloModule module
1209
1210 ENTRY entry {
1211 %lhs = f32[128,7,7,512] parameter(0)
1212 %lhs.copy = f32[128,7,7,512] copy(%lhs),
1213 sharding={devices=[1,2,1,1]0,1}
1214 %rhs = f32[3,3,512,512] parameter(1)
1215 %rhs.copy = f32[3,3,512,512] copy(%rhs),
1216 sharding={replicated}
1217 ROOT %conv = f32[128,4,4,512] convolution(%lhs.copy, %rhs.copy),
1218 window={size=3x3 stride=4x4 pad=1_1x1_1 lhs_dilate=2x2 rhs_reversal=1x1},
1219 dim_labels=b01f_01io->b01f,
1220 sharding={devices=[1,2,1,1]0,1}
1221 })";
1222
1223 TF_ASSERT_OK_AND_ASSIGN(auto module,
1224 PartitionComputation(hlo_string, /*num_devices=*/2));
1225 VLOG(1) << module->ToString();
1226
1227 auto root = module->entry_computation()->root_instruction();
1228 // There is no halo exchange, and because the last element in the shard is not
1229 // needed (stride == 4), the LHS will be just a slice.
1230 auto sliced_lhs =
1231 AllOf(op::Slice(op::Copy(op::DynamicSlice(
1232 op::Pad(op::Parameter(), op::Constant()), op::Constant(),
1233 op::Reshape(), op::Constant(), op::Constant()))),
1234 op::Shape("f32[128,3,7,512]"));
1235 auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]"));
1236 EXPECT_THAT(root, AllOf(op::Convolution(sliced_lhs, rhs),
1237 op::Shape("f32[128,2,4,512]")));
1238 EXPECT_EQ(root->window().dimensions(0).padding_low(), 1);
1239 EXPECT_EQ(root->window().dimensions(0).padding_high(), 1);
1240 }
1241
1242 // (stride * per_shard_window_count) % dilation != 0 but stride == 1
TEST_F(SpmdPartitioningTest,ConvolutionBaseDilationStride1LhsTiledRhsReplicated)1243 TEST_F(SpmdPartitioningTest,
1244 ConvolutionBaseDilationStride1LhsTiledRhsReplicated) {
1245 absl::string_view hlo_string = R"(
1246 HloModule module
1247
1248 ENTRY entry {
1249 %lhs = f32[128,7,7,512] parameter(0)
1250 %lhs.copy = f32[128,7,7,512] copy(%lhs),
1251 sharding={devices=[1,2,1,1]0,1}
1252 %rhs = f32[3,3,512,512] parameter(1)
1253 %rhs.copy = f32[3,3,512,512] copy(%rhs),
1254 sharding={replicated}
1255 ROOT %conv = f32[128,14,14,512] convolution(%lhs.copy, %rhs.copy),
1256 window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1},
1257 dim_labels=b01f_01io->b01f,
1258 sharding={devices=[1,2,1,1]0,1}
1259 })";
1260
1261 TF_ASSERT_OK_AND_ASSIGN(auto module,
1262 PartitionComputation(hlo_string, /*num_devices=*/2));
1263 VLOG(1) << module->ToString();
1264
1265 auto root = module->entry_computation()->root_instruction();
1266 auto lhs = AllOf(op::Copy(op::DynamicSlice(
1267 op::Pad(op::Parameter(), op::Constant()), op::Constant(),
1268 op::Reshape(), op::Constant(), op::Constant())),
1269 op::Shape("f32[128,4,7,512]"));
1270 auto rhs = AllOf(op::Copy(op::Parameter()), op::Shape("f32[3,3,512,512]"));
1271
1272 auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)),
1273 op::Shape("f32[128,1,7,512]"));
1274 auto start_window = op::Multiply(op::Reshape(), op::Constant());
1275 auto start_input_element = op::Divide(start_window, op::Constant());
1276 auto dynamic_offset_for_padded_concat = op::Subtract(
1277 op::Constant(), op::Subtract(op::Multiply(op::Reshape(), op::Constant()),
1278 start_input_element));
1279 auto pre_masking =
1280 AllOf(op::Shape("f32[128,5,7,512]"),
1281 op::DynamicSlice(
1282 AllOf(op::Shape("f32[128,6,7,512]"),
1283 op::Pad(op::Concatenate(left_halo, lhs), op::Constant())),
1284 op::Constant(), dynamic_offset_for_padded_concat,
1285 op::Constant(), op::Constant()));
1286 auto masked = op::Select(
1287 op::Compare(op::Add(op::Iota(), op::Broadcast(start_input_element)),
1288 op::Broadcast(op::Constant())),
1289 pre_masking, op::Broadcast(op::Constant()));
1290 auto dynamic_offset_on_output = op::Subtract(
1291 start_window, op::Multiply(start_input_element, op::Constant()));
1292 EXPECT_THAT(root,
1293 AllOf(op::DynamicSlice(AllOf(op::Convolution(masked, rhs),
1294 op::Shape("f32[128,8,14,512]")),
1295 op::Constant(), dynamic_offset_on_output,
1296 op::Constant(), op::Constant()),
1297 op::Shape("f32[128,7,14,512]")));
1298 EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 1);
1299 EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0);
1300 }
1301
// Select-and-scatter over a 4-way tiled operand where stride == window size
// (3x2, stride 3x2), so windows do not overlap. No halo exchange is expected,
// only masking of the padded operand shard.
TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlap) {
  absl::string_view hlo_string = R"(
HloModule module

ge {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT compare = pred[] compare(a, b), direction=GE
}

sum {
c = f32[] parameter(0)
d = f32[] parameter(1)
ROOT add = f32[] add(c, d)
}

ENTRY entry {
%param = f32[11,4]{1,0} parameter(0)
%param.copy = f32[11,4] copy(%param),
sharding={devices=[4,1]0,1,2,3}
constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}),
sharding={devices=[4,1]0,1,2,3}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0},
select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  auto root = module->entry_computation()->root_instruction();
  auto source =
      AllOf(op::Shape("f32[1,2]{1,0}"),
            op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
  auto masked_data = AllOf(
      op::Shape("f32[3,4]{1,0}"),
      op::Select(
          op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply(
                                              op::Reshape(), op::Constant()))),
                      op::Broadcast(op::Constant())),
          op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                    op::Reshape(), op::Constant())),
          op::Broadcast(op::Constant())));

  // ASSERT rather than EXPECT: the padding checks below read root->window(),
  // which is only meaningful once root is known to be the select-and-scatter.
  ASSERT_THAT(root,
              AllOf(op::SelectAndScatter(masked_data, source, op::Constant()),
                    op::Shape("f32[3,4]{1,0}")));
  EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
  EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
}
1352
// Same as SelectAndScatterNoOverlap, except the operand arrives split on
// dim 1 and must be resharded (pad + all-to-all) to the dim-0 tiling of the
// result before masking.
TEST_F(SpmdPartitioningTest, SelectAndScatterNoOverlapReshard) {
  absl::string_view hlo_string = R"(
HloModule module

ge {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT compare = pred[] compare(a, b), direction=GE
}

sum {
c = f32[] parameter(0)
d = f32[] parameter(1)
ROOT add = f32[] add(c, d)
}

ENTRY entry {
%param = f32[11,4]{1,0} parameter(0)
%param.copy = f32[11,4] copy(%param),
sharding={devices=[1,4]0,1,2,3}
constant = f32[4,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8}}),
sharding={devices=[4,1]0,1,2,3}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
constant, constant.1), window={size=3x2 stride=3x2 pad=0_1x0_0},
select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  auto root = module->entry_computation()->root_instruction();
  auto source =
      AllOf(op::Shape("f32[1,2]{1,0}"),
            op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()));
  auto operand = AllOf(op::Copy(op::DynamicSlice(
                           op::Parameter(0), op::Constant(), op::Reshape())),
                       op::Shape("f32[11,1]"));
  auto reshard_operand = op::Reshape(op::Transpose(
      op::AllToAll(op::Reshape(op::Pad(operand, op::Constant())))));
  auto masked_data = AllOf(
      op::Shape("f32[3,4]{1,0}"),
      op::Select(
          op::Compare(op::Add(op::Iota(), op::Broadcast(op::Multiply(
                                              op::Reshape(), op::Constant()))),
                      op::Broadcast(op::Constant())),
          reshard_operand, op::Broadcast(op::Constant())));

  // ASSERT rather than EXPECT: the padding checks below read root->window(),
  // which is only meaningful once root is known to be the select-and-scatter.
  ASSERT_THAT(root,
              AllOf(op::SelectAndScatter(masked_data, source, op::Constant()),
                    op::Shape("f32[3,4]{1,0}")));
  EXPECT_EQ(root->window().dimensions(0).padding_low(), 0);
  EXPECT_EQ(root->window().dimensions(0).padding_high(), 0);
}
1406
// Select-and-scatter with overlapping windows (size 3, stride 2) over a 4-way
// tiled operand. Both source and operand are expected to be extended with
// halos and masked. The select-and-scatter result is then dynamic-sliced back
// to this shard's rows.
TEST_F(SpmdPartitioningTest, SelectAndScatterWithOverlap) {
  absl::string_view hlo_string = R"(
HloModule module

ge {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT compare = pred[] compare(a, b), direction=GE
}

sum {
c = f32[] parameter(0)
d = f32[] parameter(1)
ROOT add = f32[] add(c, d)
}

ENTRY entry {
%param = f32[11,4]{1,0} parameter(0)
%param.copy = f32[11,4] copy(%param),
sharding={devices=[4,1]0,1,2,3}
constant = f32[6,2]{1,0} constant({{1,2},{3,4},{1,0},{2,8},{6,6},{1,9}}),
sharding={devices=[4,1]0,1,2,3}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT select-and-scatter = f32[11,4]{1,0} select-and-scatter(param.copy,
constant, constant.1), window={size=3x2 stride=2x2 pad=1_1x0_0},
select=ge, scatter=sum, sharding={devices=[4,1]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  auto root = module->entry_computation()->root_instruction();

  auto source_shard =
      AllOf(op::Shape("f32[2,2]{1,0}"),
            op::DynamicSlice(op::Pad(), op::Reshape(), op::Constant()));
  // Max halo size is the same as the shard size, so slice is not needed.
  auto source_left_halo = op::CollectivePermute(source_shard);
  auto required_source_shard_start =
      op::Divide(op::Multiply(op::Reshape(), op::Constant()), op::Constant());
  auto source_with_halo = op::DynamicSlice(
      AllOf(op::Shape("f32[5,2]{1,0}"),
            op::Pad(op::Concatenate(source_left_halo, source_shard),
                    op::Constant())),
      op::Subtract(op::Constant(),
                   op::Subtract(op::Multiply(op::Reshape(), op::Constant()),
                                required_source_shard_start)),
      op::Constant());
  auto masked_source_with_halo = AllOf(
      op::Shape("f32[3,2]{1,0}"),
      op::Select(
          op::Compare(
              op::Add(op::Iota(), op::Broadcast(required_source_shard_start)),
              op::Broadcast(op::Constant())),
          source_with_halo, op::Broadcast(op::Constant())));

  auto data_shard =
      AllOf(op::Shape("f32[3,4]{1,0}"),
            op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Reshape(), op::Constant())));
  auto data_left_halo = AllOf(op::Shape("f32[2,4]{1,0}"),
                              op::CollectivePermute(op::Slice(data_shard)));
  auto data_right_halo = AllOf(op::Shape("f32[2,4]{1,0}"),
                               op::CollectivePermute(op::Slice(data_shard)));
  auto required_data_start_on_padded =
      op::Multiply(required_source_shard_start, op::Constant());
  auto left_halo_size = op::Subtract(
      op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant()),
      required_data_start_on_padded);
  auto data_with_halo =
      AllOf(op::Shape("f32[7,4]{1,0}"),
            op::DynamicSlice(
                AllOf(op::Shape("f32[8,4]{1,0}"),
                      op::Pad(op::Concatenate(data_left_halo, data_shard,
                                              data_right_halo),
                              op::Constant())),
                op::Subtract(op::Constant(), left_halo_size), op::Constant()));
  auto index_on_padded =
      op::Add(op::Iota(), op::Broadcast(required_data_start_on_padded));
  auto masked_data_with_halo = op::Select(
      op::And(op::Compare(index_on_padded, op::Broadcast(op::Constant())),
              op::Compare(index_on_padded, op::Broadcast(op::Constant()))),
      data_with_halo, op::Broadcast(op::Constant()));

  // ASSERT rather than EXPECT: the padding checks below dereference
  // root->operand(0)->window(), which is only meaningful once root is known
  // to be a dynamic-slice of the select-and-scatter.
  ASSERT_THAT(
      root, AllOf(op::DynamicSlice(op::SelectAndScatter(masked_data_with_halo,
                                                        masked_source_with_halo,
                                                        op::Constant()),
                                   left_halo_size, op::Constant()),
                  op::Shape("f32[3,4]{1,0}")));
  EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_low(), 0);
  EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0);
}
1499
// Both convolution operands are split the same way along spatial dim 1. The
// 56x56 window spans the full spatial extent and the output is replicated, so
// the local convolution result is expected to be combined with an all-reduce.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiled) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[128,56,56,64] parameter(0)
%lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[128,56,56,256] parameter(1)
%rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy),
window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Both operands are sliced with the partition index on dim 1.
  const auto lhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                      op::Reshape(), op::Constant(),
                                      op::Constant())),
            op::Shape("f32[128,28,56,64]"));
  const auto rhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                      op::Reshape(), op::Constant(),
                                      op::Constant())),
            op::Shape("f32[128,28,56,256]"));
  const auto partial_conv = op::Convolution(lhs_shard, rhs_shard);

  const auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::AllReduce(partial_conv),
                          op::Shape("f32[1,1,64,256]")));
}
1530
// Convolution with rhs_reversal where the two operands are tiled with
// opposite device orders. The RHS is expected to be rebuilt from a
// collective-permuted slice plus a local slice. Both operands are masked
// (mask predicate and fill value are not checked) before a convolution whose
// result is all-reduced.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowReversal) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[5,128,64] parameter(0), sharding={devices=[2,1,1]0,1}
%rhs = f32[5,128,256] parameter(1), sharding={devices=[2,1,1]1,0}
ROOT %conv = f32[1,64,256] convolution(%lhs, %rhs),
window={size=5 rhs_reversal=1}, dim_labels=0fb_0io->0bf,
sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  const auto lhs_operand =
      AllOf(op::Shape("f32[3,128,64]"), op::Select(_, op::Parameter(0), _));
  const auto rhs_realigned =
      op::Concatenate(op::CollectivePermute(op::Slice(op::Parameter(1))),
                      op::Slice(op::Parameter(1)));
  const auto rhs_operand =
      AllOf(op::Shape("f32[3,128,256]"), op::Select(_, rhs_realigned, _));

  const auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::AllReduce(
                              op::Convolution(lhs_operand, rhs_operand)),
                          op::Shape("f32[1,64,256]")));
}
1560
// The LHS is split on spatial dim 1 while the RHS is split on dim 0 (its input
// feature dim, per i01o). The LHS is expected to be resharded with an
// all-to-all so that its feature dim (dim 0, per f01b) matches the RHS split
// before the convolution + all-reduce.
TEST_F(SpmdPartitioningTest, DotLhsTiledRhsTiledWithReshard) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%lhs = f32[128,56,56,64] parameter(0)
%lhs.copy = f32[128,56,56,64] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
%rhs = f32[128,56,56,256] parameter(1)
%rhs.copy = f32[128,56,56,256] copy(%rhs), sharding={devices=[2,1,1,1]0,1}
ROOT %conv = f32[1,1,64,256] convolution(%lhs.copy, %rhs.copy),
window={size=56x56}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  auto lhs = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[128,28,56,64]"));
  auto rhs = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[64,56,56,256]"));
  auto all_to_all =
      AllOf(op::AllToAll(op::Reshape(lhs)), op::Shape("f32[2,64,28,56,64]"));
  // After the all-to-all, the LHS must be split on its feature dim (dim 0) to
  // match the RHS shard's 64 input features. The full 56x56 spatial extent is
  // restored for the 56x56 window.
  auto reshard = AllOf(op::Reshape(op::Transpose(all_to_all)),
                       op::Shape("f32[64,56,56,64]"));

  EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(reshard, rhs)),
                          op::Shape("f32[1,1,64,256]")));
}
1594
// Convolution with a dilated rhs window (rhs_dilate=2x2) and negative high
// padding. The lhs is tiled on dimension 1 and the rhs on dimension 0.
// Expected lowering:
// - the rhs is resharded via reshape -> all-to-all -> transpose -> reshape;
// - the lhs shard is sliced;
// - the partial convolutions are all-reduced.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithReshard) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,56,56,512] parameter(0)
  %lhs.copy = f32[128,56,56,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,28,28,64] parameter(1)
  %rhs.copy = f32[128,28,28,64] copy(%rhs), sharding={devices=[2,1,1,1]0,1}
  ROOT %conv = f32[1,1,512,64] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2},
    dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // Local shards carved out of the full parameters by dynamic-slicing.
  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[128,28,56,512]"));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[64,28,28,64]"));

  // Resharding of the rhs shard through an all-to-all.
  const auto rhs_exchanged = AllOf(op::AllToAll(op::Reshape(rhs_shard)),
                                   op::Shape("f32[64,2,14,28,64]"));
  const auto rhs_resharded = op::Reshape(op::Transpose(rhs_exchanged));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(op::Slice(lhs_shard),
                                                        rhs_resharded)),
                          op::Shape("f32[1,1,512,64]")));
}
1630
// Dilated-rhs convolution where both operands are tiled 4 ways on dimension 1.
// The rhs dimension 1 (size 14) does not split evenly over 4 partitions.
// Expected lowering:
// - the rhs parameter is padded to 16, dynamic-sliced and masked with a Select;
// - the lhs shard gets a CollectivePermute'd halo appended, is padded, and is
//   dynamic-sliced back to 7 rows.
TEST_F(SpmdPartitioningTest,
       ConvolutionLhsTiledRhsTiled_UnevenDilatedRHSPartitioned) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[8,28,28,8] parameter(0)
  %lhs.copy = f32[8,28,28,8] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3}
  %rhs = f32[8,14,14,64] parameter(1)
  %rhs.copy = f32[8,14,14,64] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3}
  ROOT %conv = f32[1,1,8,64] convolution(%lhs.copy, %rhs.copy),
    window={size=14x14 pad=0_-1x0_-1 rhs_dilate=2x2},
    dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[8,7,28,8]"));

  // rhs: pad 14 -> 16, take the local shard, then mask it.
  const auto rhs_padded = AllOf(op::Pad(op::Parameter(), op::Constant()),
                                op::Shape("f32[8,16,14,64]"));
  const auto rhs_shard =
      op::Copy(op::DynamicSlice(rhs_padded, op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant()));
  const auto rhs_masked =
      AllOf(op::Select(op::Compare(), rhs_shard, op::Broadcast()),
            op::Shape("f32[8,4,14,64]"));

  // lhs: append a received 2-row halo, pad, and slice the needed window.
  const auto lhs_halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                              op::Shape("f32[8,2,28,8]"));
  const auto lhs_extended =
      op::Pad(op::Concatenate(lhs_shard, lhs_halo), op::Constant());
  const auto lhs_windowed =
      AllOf(op::DynamicSlice(lhs_extended, op::Constant(), op::Reshape(),
                             op::Constant(), op::Constant()),
            op::Shape("f32[8,7,28,8]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::AllReduce(op::Convolution(lhs_windowed, rhs_masked)),
                    op::Shape("f32[1,1,8,64]")));
}
1674
// Convolution with 1_1 padding where lhs and rhs are both tiled on dimension 1,
// partitioned with conv_halo_exchange_always_on_lhs=false. The rhs shard is
// expected to be widened by a 1-row CollectivePermute'd halo on each side.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,28,28,128] parameter(0)
  %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,28,28,64] parameter(1)
  %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned,
      PartitionComputation(hlo_text, /*num_devices=*/2,
                           /*conv_halo_exchange_always_on_lhs=*/false));
  VLOG(1) << partitioned->ToString();

  // Matches a parameter dynamic-sliced on dimension 1 into `shape`.
  auto sliced_on_dim1 = [](absl::string_view shape) {
    return AllOf(
        op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                  op::Reshape(), op::Constant(),
                                  op::Constant())),
        op::Shape(shape));
  };
  const auto lhs_shard = sliced_on_dim1("f32[32,14,28,128]");
  const auto rhs_shard = sliced_on_dim1("f32[32,14,28,64]");

  // The left and right halos have identical structure and shape.
  const auto rhs_halo = AllOf(op::CollectivePermute(op::Slice(rhs_shard)),
                              op::Shape("f32[32,1,28,64]"));
  const auto rhs_with_halos =
      AllOf(op::Concatenate(rhs_halo, rhs_shard, rhs_halo),
            op::Shape("f32[32,16,28,64]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::AllReduce(op::Convolution(lhs_shard, rhs_with_halos)),
                    op::Shape("f32[3,3,128,64]")));
}
1714
// Dilated-rhs convolution (rhs_dilate=2x2, pad=3_2) with lhs and rhs tiled on
// dimension 1, partitioned with conv_halo_exchange_always_on_lhs=false. The
// rhs shard is expected to receive a 2-row halo on each side (56 -> 60 rows).
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilate) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,224,224,3] parameter(0)
  %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,112,112,64] parameter(1)
  %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy),
    window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned,
      PartitionComputation(hlo_text, /*num_devices=*/2,
                           /*conv_halo_exchange_always_on_lhs=*/false));
  VLOG(1) << partitioned->ToString();

  // Matches a parameter dynamic-sliced on dimension 1 into `shape`.
  auto sliced_on_dim1 = [](absl::string_view shape) {
    return AllOf(
        op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                  op::Reshape(), op::Constant(),
                                  op::Constant())),
        op::Shape(shape));
  };
  const auto lhs_shard = sliced_on_dim1("f32[128,112,224,3]");
  const auto rhs_shard = sliced_on_dim1("f32[128,56,112,64]");

  // Both halos are two rows of the rhs shard.
  const auto rhs_halo = AllOf(op::CollectivePermute(op::Slice(rhs_shard)),
                              op::Shape("f32[128,2,112,64]"));
  const auto rhs_with_halos =
      AllOf(op::Concatenate(rhs_halo, rhs_shard, rhs_halo),
            op::Shape("f32[128,60,112,64]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::AllReduce(op::Convolution(lhs_shard, rhs_with_halos)),
                    op::Shape("f32[7,7,3,64]")));
}
1754
// Dilated-rhs convolution with negative high padding (0_-1), both operands
// tiled on dimension 1, partitioned with conv_halo_exchange_always_on_lhs=false.
// The local shards are expected to be convolved directly, with no halo, and
// then all-reduced.
TEST_F(SpmdPartitioningTest,
       ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,56,56,256] parameter(0)
  %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,28,28,512] parameter(1)
  %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned,
      PartitionComputation(hlo_text, /*num_devices=*/2,
                           /*conv_halo_exchange_always_on_lhs=*/false));
  VLOG(1) << partitioned->ToString();

  // Matches a parameter dynamic-sliced on dimension 1 into `shape`.
  auto sliced_on_dim1 = [](absl::string_view shape) {
    return AllOf(
        op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                  op::Reshape(), op::Constant(),
                                  op::Constant())),
        op::Shape(shape));
  };

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
                              sliced_on_dim1("f32[128,28,56,256]"),
                              sliced_on_dim1("f32[128,14,28,512]"))),
                          op::Shape("f32[1,1,256,512]")));
}
1788
// Dilated-rhs convolution where the rhs dimension 1 (size 7) splits unevenly
// over 2 partitions. Partitioned with conv_halo_exchange_always_on_lhs=false.
// Expected lowering:
// - the rhs is taken from a padded parameter, masked, and prefixed with a
//   1-row halo;
// - the lhs shard is padded and dynamic-sliced to 10 rows.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowDilateUneven) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,14,14,512] parameter(0)
  %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,7,7,512] parameter(1)
  %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy),
    window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned,
      PartitionComputation(hlo_text, /*num_devices=*/2,
                           /*conv_halo_exchange_always_on_lhs=*/false));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[128,7,14,512]"));

  const auto rhs_unmasked = op::Copy(
      op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                       op::Constant(), op::Reshape(), op::Constant(),
                       op::Constant()));
  const auto rhs_shard =
      AllOf(op::Select(op::Compare(), rhs_unmasked, op::Broadcast()),
            op::Shape("f32[128,4,7,512]"));

  const auto rhs_left_halo = AllOf(op::CollectivePermute(op::Slice(rhs_shard)),
                                   op::Shape("f32[128,1,7,512]"));
  const auto rhs_operand = AllOf(op::Concatenate(rhs_left_halo, rhs_shard),
                                 op::Shape("f32[128,5,7,512]"));

  const auto lhs_operand =
      AllOf(op::DynamicSlice(op::Pad(lhs_shard, op::Constant()),
                             op::Constant(), op::Subtract(), op::Constant(),
                             op::Constant()),
            op::Shape("f32[128,10,14,512]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::AllReduce(op::Convolution(lhs_operand, rhs_operand)),
                    op::Shape("f32[3,3,512,512]")));
}
1833
// Same HLO as the rhs-halo padding case, partitioned with the default
// conv_halo_exchange_always_on_lhs=true. Here the lhs shard is expected to be
// widened by a 1-row CollectivePermute'd halo on each side.
TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding_HaloOnLhs) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,28,28,128] parameter(0)
  %lhs.copy = f32[32,28,28,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,28,28,64] parameter(1)
  %rhs.copy = f32[32,28,28,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[3,3,128,64] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=1_1x1_1}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // Matches a parameter dynamic-sliced on dimension 1 into `shape`.
  auto sliced_on_dim1 = [](absl::string_view shape) {
    return AllOf(
        op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                  op::Reshape(), op::Constant(),
                                  op::Constant())),
        op::Shape(shape));
  };
  const auto lhs_shard = sliced_on_dim1("f32[32,14,28,128]");
  const auto rhs_shard = sliced_on_dim1("f32[32,14,28,64]");

  // The left and right halos are both single rows of the lhs shard.
  const auto lhs_halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                              op::Shape("f32[32,1,28,128]"));
  const auto lhs_with_halos =
      AllOf(op::Concatenate(lhs_halo, lhs_shard, lhs_halo),
            op::Shape("f32[32,16,28,128]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::AllReduce(op::Convolution(lhs_with_halos, rhs_shard)),
                    op::Shape("f32[3,3,128,64]")));
}
1871
// Dilated-rhs convolution (pad=3_2) with the default lhs-side halo exchange.
// The lhs shard is expected to receive asymmetric halos: 3 rows on the left
// and 2 on the right (112 -> 117 rows).
TEST_F(SpmdPartitioningTest,
       ConvolutionLhsTiledRhsTiledWindowDilate_HaloOnLhs) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,224,224,3] parameter(0)
  %lhs.copy = f32[128,224,224,3] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,112,112,64] parameter(1)
  %rhs.copy = f32[128,112,112,64] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[7,7,3,64] convolution(%lhs.copy, %rhs.copy),
    window={size=112x112 pad=3_2x3_2 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // Matches a parameter dynamic-sliced on dimension 1 into `shape`.
  auto sliced_on_dim1 = [](absl::string_view shape) {
    return AllOf(
        op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                  op::Reshape(), op::Constant(),
                                  op::Constant())),
        op::Shape(shape));
  };
  const auto lhs_shard = sliced_on_dim1("f32[128,112,224,3]");
  const auto rhs_shard = sliced_on_dim1("f32[128,56,112,64]");

  const auto halo_before = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                                 op::Shape("f32[128,3,224,3]"));
  const auto halo_after = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                                op::Shape("f32[128,2,224,3]"));
  const auto lhs_with_halos =
      AllOf(op::Concatenate(halo_before, lhs_shard, halo_after),
            op::Shape("f32[128,117,224,3]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::AllReduce(op::Convolution(lhs_with_halos, rhs_shard)),
                    op::Shape("f32[7,7,3,64]")));
}
1910
// Dilated-rhs convolution with negative high padding (0_-1) and the default
// lhs-side halo exchange. No halo is expected: the lhs shard is only sliced
// before the local convolution, which is then all-reduced.
TEST_F(SpmdPartitioningTest,
       ConvolutionLhsTiledRhsTiledWindowDilateNegativeRhsPadding_HaloOnLhs) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,56,56,256] parameter(0)
  %lhs.copy = f32[128,56,56,256] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,28,28,512] parameter(1)
  %rhs.copy = f32[128,28,28,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // Matches a parameter dynamic-sliced on dimension 1 into `shape`.
  auto sliced_on_dim1 = [](absl::string_view shape) {
    return AllOf(
        op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                  op::Reshape(), op::Constant(),
                                  op::Constant())),
        op::Shape(shape));
  };
  const auto lhs_shard = sliced_on_dim1("f32[128,28,56,256]");
  const auto rhs_shard = sliced_on_dim1("f32[128,14,28,512]");

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::AllReduce(
                              op::Convolution(op::Slice(lhs_shard), rhs_shard)),
                          op::Shape("f32[1,1,256,512]")));
}
1942
// Uneven dilated-rhs convolution (rhs dimension 1 of size 7 over 2 partitions)
// with the default lhs-side halo exchange.
// Expected lowering:
// - the lhs shard gets a 1-row halo appended, is padded to 10 rows and
//   dynamic-sliced to 9;
// - the rhs is taken from a padded parameter and masked with a Select.
TEST_F(SpmdPartitioningTest,
       ConvolutionLhsTiledRhsTiledWindowDilateUneven_HaloOnLhs) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,14,14,512] parameter(0)
  %lhs.copy = f32[128,14,14,512] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[128,7,7,512] parameter(1)
  %rhs.copy = f32[128,7,7,512] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %conv = f32[3,3,512,512] convolution(%lhs.copy, %rhs.copy),
    window={size=7x7 pad=1_0x1_0 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[128,7,14,512]"));

  const auto rhs_unmasked = op::Copy(
      op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                       op::Constant(), op::Reshape(), op::Constant(),
                       op::Constant()));
  const auto rhs_shard =
      AllOf(op::Select(op::Compare(), rhs_unmasked, op::Broadcast()),
            op::Shape("f32[128,4,7,512]"));

  const auto lhs_right_halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                                    op::Shape("f32[128,1,14,512]"));
  const auto lhs_padded =
      AllOf(op::Pad(op::Concatenate(lhs_shard, lhs_right_halo), op::Constant()),
            op::Shape("f32[128,10,14,512]"));
  const auto lhs_windowed =
      AllOf(op::DynamicSlice(lhs_padded, op::Constant(), op::Reshape(),
                             op::Constant(), op::Constant()),
            op::Shape("f32[128,9,14,512]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::AllReduce(op::Convolution(lhs_windowed, rhs_shard)),
                    op::Shape("f32[3,3,512,512]")));
}
1988
// Concatenation along dimension 1 while only dimension 0 is tiled. Each
// partition is expected to concatenate its own row shards locally.
TEST_F(SpmdPartitioningTest, ConcatenateAlongNonPartitionedDimension) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[14,257] parameter(0)
  %param0.copy = f32[14,257] copy(%param0), sharding={devices=[2,1]0,1}
  %param1 = f32[14,116] parameter(1)
  %param1.copy = f32[14,116] copy(%param1), sharding={devices=[2,1]0,1}
  ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy),
    dimensions={1}, sharding={devices=[2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // Matches a parameter dynamic-sliced on dimension 0 into `shape`.
  auto row_shard = [](absl::string_view shape) {
    return AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                           op::Constant())),
                 op::Shape(shape));
  };

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Concatenate(row_shard("f32[7,257]"),
                                          row_shard("f32[7,116]")),
                          op::Shape("f32[7,373]")));
}
2016
// Concatenation along the tiled dimension 1.
// Expected lowering:
// - both operand shards are written into a broadcast buffer of the full padded
//   width (374) with two DynamicUpdateSlices;
// - the buffer is all-reduced;
// - each partition dynamic-slices out its 187-wide shard.
TEST_F(SpmdPartitioningTest, ConcatenateAlongPartitionedDimension) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[14,257] parameter(0)
  %param0.copy = f32[14,257] copy(%param0), sharding={devices=[1,2]0,1}
  %param1 = f32[14,116] parameter(1)
  %param1.copy = f32[14,116] copy(%param1), sharding={devices=[1,2]0,1}
  ROOT %concatenate = f32[14,373] concatenate(%param0.copy, %param1.copy),
    dimensions={1}, sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // 257 columns do not split evenly, so param0 is padded before slicing.
  const auto first_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape())),
            op::Shape("f32[14,129]"));
  const auto second_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape())),
      op::Shape("f32[14,58]"));

  const auto with_first = op::DynamicUpdateSlice(
      op::Broadcast(), first_shard, op::Constant(), op::Multiply());
  const auto with_both = op::DynamicUpdateSlice(
      with_first, second_shard, op::Constant(), op::Add());
  const auto full_result =
      AllOf(op::AllReduce(with_both), op::Shape("f32[14,374]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::DynamicSlice(full_result, op::Constant(),
                                           op::Multiply()),
                          op::Shape("f32[14,187]")));
}
2052
// Concatenation along dimension 1 with operands and result tiled on a 2x2
// device mesh (both dimensions partitioned).
// Expected lowering:
// - each partition writes its shards into a broadcast buffer of padded width
//   374 with two DynamicUpdateSlices;
// - the buffer is all-reduced;
// - the partition dynamic-slices out its 187-wide shard.
TEST_F(SpmdPartitioningTest, ConcatenateAlongBothDimensions) {
  // absl::string_view, matching the other tests in this file.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[14,257] parameter(0), sharding={devices=[2,2]0,1,2,3}
  %param1 = f32[14,116] parameter(1), sharding={devices=[2,2]0,1,2,3}
  ROOT %concatenate = f32[14,373] concatenate(%param0, %param1),
    dimensions={1}, sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  auto param0 = AllOf(op::Parameter(0), op::Shape("f32[7,129]"));
  auto param1 = AllOf(op::Parameter(1), op::Shape("f32[7,58]"));
  EXPECT_THAT(root, AllOf(op::DynamicSlice(
                              AllOf(op::AllReduce(op::DynamicUpdateSlice(
                                        op::DynamicUpdateSlice(
                                            op::Broadcast(), param0,
                                            op::Constant(), op::Multiply()),
                                        param1, op::Constant(), op::Add())),
                                    op::Shape("f32[7,374]")),
                              op::Constant(), op::Multiply()),
                          op::Shape("f32[7,187]")));
}
2081
// Padding only dimension 1 while the operand is tiled on dimension 2. The pad
// is expected to be applied directly to each partition's parameter shard.
TEST_F(SpmdPartitioningTest, PadAlongNonPartitionedDimension) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2]0,1}
  %const = f32[] constant(0)
  ROOT %pad = f32[128,17,257] pad(%param0, %const), padding=0_0x1_2x0_0,
    sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto* root = partitioned->entry_computation()->root_instruction();
  const auto local_param =
      AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
  const auto expected = AllOf(op::Pad(local_param, op::Constant()),
                              op::Shape("f32[128,17,129]"));
  EXPECT_THAT(root, expected);
}
2102
// High padding (0_2) along the tiled dimension 1.
// Expected lowering:
// - each partition appends a 1-column halo received through a
//   CollectivePermute;
// - the result is padded, dynamic-sliced, and masked with a Select.
TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimension) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[14,257] parameter(0), sharding={devices=[1,2]0,1}
  %const = f32[] constant(0)
  ROOT %pad = f32[14,259] pad(%param0, %const), padding=0_0x0_2,
    sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  auto param0 = AllOf(op::Parameter(), op::Shape("f32[14,129]"));
  auto after_halo_exchange =
      AllOf(op::Shape("f32[14,130]"),
            op::Concatenate(param0, op::CollectivePermute(op::Slice(param0))));
  auto pad = AllOf(op::Shape("f32[14,131]"),
                   op::Pad(after_halo_exchange, op::Constant()));
  // The 259-wide result is tiled two ways, so each partition must hold
  // ceil(259 / 2) = 130 columns; pin that shape on the root.
  EXPECT_THAT(root,
              AllOf(op::Select(_, op::DynamicSlice(pad, op::Constant(), _), _),
                    op::Shape("f32[14,130]")));
}
2127
// Pad with interior padding (2_1_2: low 2, high 1, interior 2) along the only,
// tiled, dimension.
// Expected lowering:
// - each partition rebuilds a 4-element window from a CollectivePermute'd slice
//   plus a local slice;
// - the window is padded and re-sliced;
// - the real pad is applied, and the partition's 11-element shard is sliced out.
TEST_F(SpmdPartitioningTest, PadAlongPartitionedDimensionWithInteriorPadding) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[7] parameter(0), sharding={devices=[2]0,1}
  %param1 = f32[] parameter(1), sharding={replicated}
  ROOT %pad = f32[22] pad(%param0, %param1), padding=2_1_2,
    sharding={devices=[2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  auto root = module->entry_computation()->root_instruction();

  auto param0 = AllOf(op::Parameter(), op::Shape("f32[4]"));
  auto after_halo_exchange = AllOf(
      op::Shape("f32[4]"),
      op::DynamicSlice(
          AllOf(op::Shape("f32[5]"),
                op::Pad(AllOf(op::Shape("f32[4]"),
                              op::Concatenate(
                                  op::CollectivePermute(op::Slice(param0)),
                                  op::Slice(param0))),
                        op::Parameter(1))),
          _));
  auto pad = op::Pad(after_halo_exchange, op::Parameter(1));
  // The 22-element result is tiled two ways, so each partition must hold
  // 22 / 2 = 11 elements; pin that shape on the root.
  EXPECT_THAT(root, AllOf(op::DynamicSlice(pad, _), op::Shape("f32[11]")));
}
2158
// Pad with edge and interior padding on an operand that is tiled 2 ways on
// dimension 1 and replicated across the last tile dimension
// (last_tile_dim_replicate).
// Expected lowering:
// - along dimension 1, the interior-padding lowering applies (halo rebuild,
//   pad, re-slice);
// - the root slices out an f32[27,11] shard.
TEST_F(SpmdPartitioningTest, PartialReplicatePad) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[11,7] parameter(0),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  %param1 = f32[] parameter(1), sharding={replicated}
  ROOT %pad = f32[27,22] pad(%param0, %param1), padding=2_4_1x2_1_2,
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();
  const auto* root = partitioned->entry_computation()->root_instruction();

  const auto local_param = AllOf(op::Parameter(), op::Shape("f32[11,4]"));

  // A received slice concatenated in front of a local slice.
  const auto realigned = AllOf(
      op::Shape("f32[11,4]"),
      op::Concatenate(op::CollectivePermute(op::Slice(local_param)),
                      op::Slice(local_param)));
  const auto realigned_padded =
      AllOf(op::Shape("f32[11,5]"), op::Pad(realigned, op::Parameter(1)));
  const auto window = AllOf(
      op::Shape("f32[11,4]"),
      op::DynamicSlice(realigned_padded, op::Constant(), _));

  const auto final_pad = op::Pad(window, op::Parameter(1));
  EXPECT_THAT(root, AllOf(op::DynamicSlice(final_pad, op::Constant(), _),
                          op::Shape("f32[27,11]")));
}
2191
// Slicing dimension 1 while the operand is tiled on dimension 2. Each
// partition is expected to slice its own (padded-then-sliced) shard locally.
TEST_F(SpmdPartitioningTest, SliceAlongNonPartitionedDimension) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0)
  %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,1,2]0,1}
  ROOT %slice = f32[128,11,257] slice(%param0.copy),
    slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // 257 does not split evenly in two, so the parameter is padded first.
  const auto padded_param = op::Pad(op::Parameter(), op::Constant());
  const auto local_shard = AllOf(
      op::Copy(op::DynamicSlice(padded_param, op::Constant(), op::Constant(),
                                op::Reshape())),
      op::Shape("f32[128,14,129]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Slice(local_shard), op::Shape("f32[128,11,129]")));
}
2214
// Strided slice along dimension 0 combined with an offset slice along the
// tiled dimension 2.
// Expected lowering:
// - each partition concatenates a local slice with a 2-column
//   CollectivePermute'd halo;
// - it dynamic-slices 126 columns and applies the remaining slice.
TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2]0,1}
  ROOT %slice = f32[63,14,251] slice(%param0),
    slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto local_param =
      AllOf(op::Parameter(0), op::Shape("f32[128,14,129]"));

  const auto received_halo =
      AllOf(op::CollectivePermute(op::Slice(local_param)),
            op::Shape("f32[128,14,2]"));
  const auto shifted =
      AllOf(op::Concatenate(op::Slice(local_param), received_halo),
            op::Shape("f32[128,14,129]"));
  const auto aligned = AllOf(
      op::DynamicSlice(shifted, op::Constant(), op::Constant(), op::Add()),
      op::Shape("f32[128,14,126]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Slice(aligned), op::Shape("f32[63,14,126]")));
}
2244
// Slicing out the last element of a 4-element array tiled across 4
// partitions. The result is expected to be a copy of a CollectivePermute of
// each partition's one-element shard.
TEST_F(SpmdPartitioningTest, SliceAlongPartitionedDimension2) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
  ROOT %slice = f32[1] slice(%param0),
    slice={[3:4]}, sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const auto single_element = op::Shape("f32[1]");
  const auto local_param = AllOf(op::Parameter(0), single_element);

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Copy(op::CollectivePermute(local_param)),
                          single_element));
}
2264
// Pad one element at the front with a non-zero value (2.0), pass the result
// through two copies, then slice the first four elements. The chain is
// expected to merge into a single CollectivePermute of each one-element shard,
// wrapped in a Select.
TEST_F(SpmdPartitioningTest, MergedPadThenSliceShiftRight) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
  %init = f32[] constant(2.0)
  %pad = f32[5] pad(%param0, %init), padding=1_0, sharding={devices=[4]0,1,2,3}
  %copy = f32[5] copy(%pad), sharding={devices=[4]0,1,2,3}
  %copy.1 = f32[5] copy(%copy), sharding={devices=[4]0,1,2,3}
  ROOT %slice = f32[4] slice(%copy.1), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const auto local_param = AllOf(op::Parameter(0), op::Shape("f32[1]"));
  const auto shifted = op::CollectivePermute(local_param);

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Select(_, shifted, _), op::Shape("f32[1]")));
}
2287
2288 // Same as above except that it uses zero padding, so there is no need for
2289 // masking.
TEST_F(SpmdPartitioningTest, MergedPadThenSliceShiftRightNoMasking) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[4] parameter(0), sharding={devices=[4]0,1,2,3}
  %init = f32[] constant(0)
  %pad = f32[5] pad(%param0, %init), padding=1_0, sharding={devices=[4]0,1,2,3}
  %copy = f32[5] copy(%pad), sharding={devices=[4]0,1,2,3}
  %copy.1 = f32[5] copy(%copy), sharding={devices=[4]0,1,2,3}
  ROOT %slice = f32[4] slice(%copy.1), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // With a zero pad value the root is expected to be the bare
  // CollectivePermute, with no Select around it.
  const auto* root = partitioned->entry_computation()->root_instruction();
  const auto local_param = AllOf(op::Parameter(0), op::Shape("f32[1]"));
  EXPECT_THAT(root,
              AllOf(op::CollectivePermute(local_param), op::Shape("f32[1]")));
}
2311
// Concatenating slices [10:12] and [0:10] of a 12-element array tiled 4 ways.
// Each 3-element result shard is expected to be a CollectivePermute'd slice
// concatenated in front of a local slice.
TEST_F(SpmdPartitioningTest, MergedSliceThenConcatRotateRight) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
  %param0 = f32[12] parameter(0), sharding={devices=[4]0,1,2,3}
  %slice0 = f32[2] slice(%param0), slice={[10:12]}, sharding={devices=[4]0,1,2,3}
  %slice1 = f32[10] slice(%param0), slice={[0:10]}, sharding={devices=[4]0,1,2,3}
  ROOT %concat = f32[12] concatenate(%slice0, %slice1), dimensions={0},
    sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const auto local_param = AllOf(op::Parameter(0), op::Shape("f32[3]"));
  const auto received = op::CollectivePermute(op::Slice(local_param));
  const auto kept = op::Slice(local_param);

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Concatenate(received, kept), op::Shape("f32[3]")));
}
2334
// 6 elements over 4 devices gives 2-element (padded) shards. A rotation by 2
// therefore moves whole shards, and the expected root is a single
// collective-permute of the local shard.
TEST_F(SpmdPartitioningTest,
       MergedSliceThenConcatRotateRightWithAlignedPadding) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[6] parameter(0), sharding={devices=[4]0,1,2,3}
  %slice0 = f32[2] slice(%param0), slice={[4:6]}, sharding={devices=[4]0,1,2,3}
  %slice1 = f32[4] slice(%param0), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
  ROOT %concat = f32[6] concatenate(%slice0, %slice1), dimensions={0},
    sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  auto param0 = AllOf(op::Parameter(0), op::Shape("f32[2]"));
  // Also pin the per-shard result shape, as the sibling rotate/shift tests do.
  EXPECT_THAT(root, AllOf(op::CollectivePermute(param0), op::Shape("f32[2]")));
}
2356
// 10 elements over 4 devices gives 3-element (padded) shards, so a rotation
// by 6 does not line up with shard boundaries. The expected root selects
// between a whole-shard permute and a concat of two permuted slices.
TEST_F(SpmdPartitioningTest,
       MergedSliceThenConcatRotateRightWithUnalignedPadding) {
  const absl::string_view kModuleText = R"(
HloModule module

ENTRY entry {
  %param0 = f32[10] parameter(0), sharding={devices=[4]0,1,2,3}
  %slice0 = f32[6] slice(%param0), slice={[4:10]}, sharding={devices=[4]0,1,2,3}
  %slice1 = f32[4] slice(%param0), slice={[0:4]}, sharding={devices=[4]0,1,2,3}
  ROOT %concat = f32[10] concatenate(%slice0, %slice1), dimensions={0},
    sharding={devices=[4]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kModuleText, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* result =
      partitioned->entry_computation()->root_instruction();
  const auto shard = AllOf(op::Parameter(0), op::Shape("f32[3]"));
  const auto whole_shard_rotation = op::CollectivePermute(shard);
  const auto split_rotation =
      op::Concatenate(op::CollectivePermute(op::Slice(shard)),
                      op::CollectivePermute(op::Slice(shard)));
  EXPECT_THAT(result,
              AllOf(op::Select(_, split_rotation, whole_shard_rotation),
                    op::Shape("f32[3]")));
}
2382
// Only dimension 2 is partitioned (2 ways, each tile replicated twice). Slicing
// dimension 1 should stay a purely local slice of the 129-wide shard.
TEST_F(SpmdPartitioningTest,
       PartialReplicateSliceAlongNonPartitionedDimension) {
  const absl::string_view kModuleText = R"(
HloModule module

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %slice = f32[128,11,257] slice(%param0),
    slice={[0:128:1], [2:13:1], [0:257:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kModuleText, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* result =
      partitioned->entry_computation()->root_instruction();
  const auto shard = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));
  const auto local_slice =
      AllOf(op::Slice(shard), op::Shape("f32[128,11,129]"));
  EXPECT_THAT(result, local_slice);
}
2402
// Slicing the partitioned dimension 2 with an offset of 5 requires a halo from
// the neighboring shard. The expected pattern is: concat the local data with a
// 2-element collective-permuted halo, dynamic-slice at a partition-dependent
// offset, then apply the strided slice on dimension 0.
TEST_F(SpmdPartitioningTest, PartialReplicateSliceAlongPartitionedDimension) {
  const absl::string_view kModuleText = R"(
HloModule module

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0), sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %slice = f32[63,14,251] slice(%param0),
    slice={[2:128:2], [0:14:1], [5:256:1]}, sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kModuleText, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* result =
      partitioned->entry_computation()->root_instruction();
  const auto shard = AllOf(op::Parameter(), op::Shape("f32[128,14,129]"));

  // Halo elements fetched from the neighboring shard.
  const auto halo = AllOf(op::CollectivePermute(op::Slice(shard)),
                          op::Shape("f32[128,14,2]"));
  const auto shard_with_halo = AllOf(op::Concatenate(op::Slice(shard), halo),
                                     op::Shape("f32[128,14,129]"));
  // Start offset along dimension 2, derived from the partition id.
  const auto window_start = op::Add(
      op::Multiply(
          op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())),
          op::Constant()),
      op::Constant());
  const auto window =
      AllOf(op::DynamicSlice(shard_with_halo, op::Constant(), op::Constant(),
                             window_start),
            op::Shape("f32[128,14,126]"));

  EXPECT_THAT(result,
              AllOf(op::Slice(window), op::Shape("f32[63,14,126]")));
}
2437
// The sort runs along dimension 2, while the operands are split along
// dimension 1. Each device is expected to sort its own f32/s32 [128,7,257]
// shard with no cross-device communication.
TEST_F(SpmdPartitioningTest, SortAlongNonPartitionedDimension) {
  const absl::string_view kModuleText = R"(
HloModule module

ge {
  p.0.lhs.1247 = f32[]{:T(256)} parameter(0), sharding={replicated}
  bitcast-convert = s32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated}
  constant = s32[]{:T(256)} constant(0), sharding={replicated}
  compare = pred[]{:T(256)E(32)} compare(bitcast-convert, constant), direction=LT, sharding={replicated}
  constant.1 = u32[]{:T(256)} constant(2147483647), sharding={replicated}
  bitcast-convert.1 = u32[]{:T(256)} bitcast-convert(p.0.lhs.1247), sharding={replicated}
  subtract = u32[]{:T(256)} subtract(constant.1, bitcast-convert.1), sharding={replicated}
  bitcast-convert.2 = s32[]{:T(256)} bitcast-convert(subtract), sharding={replicated}
  select = s32[]{:T(256)} select(compare, bitcast-convert.2, bitcast-convert), sharding={replicated}
  p.0.rhs.1248 = f32[]{:T(256)} parameter(1), sharding={replicated}
  bitcast-convert.3 = s32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated}
  compare.1 = pred[]{:T(256)E(32)} compare(bitcast-convert.3, constant), direction=LT, sharding={replicated}
  bitcast-convert.4 = u32[]{:T(256)} bitcast-convert(p.0.rhs.1248), sharding={replicated}
  subtract.1 = u32[]{:T(256)} subtract(constant.1, bitcast-convert.4), sharding={replicated}
  bitcast-convert.5 = s32[]{:T(256)} bitcast-convert(subtract.1), sharding={replicated}
  select.1 = s32[]{:T(256)} select(compare.1, bitcast-convert.5, bitcast-convert.3), sharding={replicated}
  compare.2 = pred[]{:T(256)E(32)} compare(select, select.1), direction=GT, sharding={replicated}
  compare.258 = pred[]{:T(256)E(32)} compare(select.1, select), direction=GT, sharding={replicated}
  compare.259 = pred[]{:T(256)E(32)} compare(compare.2, compare.258), direction=EQ, sharding={replicated}
  p.1.lhs.1249 = s32[]{:T(256)} parameter(2), sharding={replicated}
  p.1.rhs.1250 = s32[]{:T(256)} parameter(3), sharding={replicated}
  compare.260 = pred[]{:T(256)E(32)} compare(p.1.lhs.1249, p.1.rhs.1250), direction=LT, sharding={replicated}
  ROOT select.86 = pred[]{:T(256)E(32)} select(compare.259, compare.260, compare.2), sharding={replicated}
}

ENTRY entry {
  %param0 = f32[128,14,257] parameter(0)
  %param0.copy = f32[128,14,257] copy(%param0), sharding={devices=[1,2,1]0,1}
  %param1 = s32[128,14,257] parameter(1)
  %param1.copy = s32[128,14,257] copy(%param1), sharding={devices=[1,2,1]0,1}
  ROOT %sort.6 = (f32[128,14,257]{2,1,0:T(8,128)}, s32[128,14,257]{2,1,0:T(8,128)})
    sort(%param0.copy, %param1.copy), dimensions={2}, is_stable=true,
    to_apply=%ge, sharding={{devices=[1,2,1]0,1},{devices=[1,2,1]0,1}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kModuleText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // Builds the matcher for the local copy of a replicated parameter that has
  // been dynamic-sliced along dimension 1.
  const auto copied_shard = [](int64_t param_index, const char* local_shape) {
    return AllOf(op::Copy(op::DynamicSlice(op::Parameter(param_index),
                                           op::Constant(), op::Reshape(),
                                           op::Constant())),
                 op::Shape(local_shape));
  };

  const HloInstruction* result =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(result,
              AllOf(op::Sort(copied_shard(0, "f32[128,7,257]"),
                             copied_shard(1, "s32[128,7,257]")),
                    op::Shape("(f32[128,7,257], s32[128,7,257])")));
}
2494
// A TopK custom-call whose input is split 2 ways along the sorted dimension.
// The test expects the custom-call to see the per-device half (104832 of
// 209664) and a sort over 4000 = 2 partitions x 2000 candidates.
TEST_F(SpmdPartitioningTest, PartitionCustomCall) {
  absl::string_view hlo_string = R"(
HloModule cluster_2013453984438090939__.47

ENTRY %cluster_2013453984438090939__.47
  (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
  %custom-call = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
    custom-call(bf16[2,209664]{1,0} %copy.arg_tuple.1), custom_call_target="TopK"
  %get-tuple-element = bf16[2,2000]{1,0}
    get-tuple-element((bf16[2,2000]{1,0}, s32[2,2000]{1,0}) %custom-call),
    index=0, sharding={replicated}
  %get-tuple-element.1 = s32[2,2000]{1,0} get-tuple-element((bf16[2,2000]{1,0},
    s32[2,2000]{1,0}) %custom-call), index=1, sharding={replicated}
  ROOT %tuple.46 = (bf16[2,2000]{1,0}, s32[2,2000]{1,0})
    tuple(bf16[2,2000]{1,0} %get-tuple-element, s32[2,2000]{1,0}
    %get-tuple-element.1), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  // FindInstruction returns nullptr when no instruction has that name. Use a
  // fatal assertion so a missing instruction fails the test cleanly instead of
  // crashing the test binary on the dereference below.
  auto custom_call = FindInstruction(module.get(), "custom-call.1");
  ASSERT_NE(custom_call, nullptr);
  EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 104832);
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4000);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4000);
}
2525
// TopK input split [4,2]: 8 rows / 4 = 2 rows per device, and 32128 / 2 =
// 16064 columns per device. The expected sort operands are [2, 4], i.e.
// 2 partitions x k=2 candidates per row.
TEST_F(SpmdPartitioningTest, PartitionCustomCall_TwoPartitionedDims) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,32128] parameter(0)
  %copy.0 = f32[8,32128] copy(%param0),
    sharding={devices=[4,2]0,1,2,3,4,5,6,7}
  %custom-call = (f32[8,2]{1,0}, s32[8,2]{1,0})
    custom-call(%copy.0), custom_call_target="TopK"
  %get-tuple-element = f32[8,2]{1,0}
    get-tuple-element((f32[8,2]{1,0}, s32[8,2]{1,0}) %custom-call), index=0,
    sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %get-tuple-element.1 = s32[8,2]{1,0}
    get-tuple-element((f32[8,2]{1,0}, s32[8,2]{1,0}) %custom-call), index=1,
    sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %tuple = (f32[8,2]{1,0}, s32[8,2]{1,0})
    tuple(%get-tuple-element, %get-tuple-element.1),
    sharding={{replicated}, {replicated}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  // Fatal null checks: FindInstruction returns nullptr on a miss, and the
  // dereferences below would otherwise crash the test binary.
  auto custom_call = FindInstruction(module.get(), "custom-call.1");
  ASSERT_NE(custom_call, nullptr);
  EXPECT_EQ(custom_call->operand(0)->shape().dimensions(1), 16064);
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(0), 2);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 4);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(0), 2);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 4);
}
2558
// A sort(values, iota) followed by a [0:2000] slice along the sort dimension
// (the TopKV2 lowering). With the sort dimension split 2 ways, the test expects
// a per-device sort over 104832 elements and a second sort ("sort.1") over
// 4000 = 2 partitions x 2000 candidates.
TEST_F(SpmdPartitioningTest, PartitionSortInTopK) {
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.9: bf16[], p.0.rhs.10: bf16[], p.1.lhs.11:
    s32[], p.1.rhs.12: s32[]) -> pred[] {
  %p.1.lhs.11 = s32[] parameter(2)
  %p.1.rhs.12 = s32[] parameter(3)
  %p.0.lhs.9 = bf16[] parameter(0)
  %convert.13 = f32[] convert(bf16[] %p.0.lhs.9)
  %bitcast-convert.16 = s32[] bitcast-convert(f32[] %convert.13)
  %constant.20 = s32[] constant(0)
  %compare.21 = pred[] compare(s32[] %bitcast-convert.16, s32[] %constant.20),
    direction=LT
  %constant.15 = u32[] constant(2147483647)
  %bitcast-convert.17 = u32[] bitcast-convert(f32[] %convert.13)
  %subtract.18 = u32[] subtract(u32[] %constant.15, u32[] %bitcast-convert.17)
  %bitcast-convert.19 = s32[] bitcast-convert(u32[] %subtract.18)
  %select.22 = s32[] select(pred[] %compare.21, s32[] %bitcast-convert.19, s32[]
    %bitcast-convert.16)
  %p.0.rhs.10 = bf16[] parameter(1)
  %convert.14 = f32[] convert(bf16[] %p.0.rhs.10)
  %bitcast-convert.24 = s32[] bitcast-convert(f32[] %convert.14)
  %constant.28 = s32[] constant(0)
  %compare.29 = pred[] compare(s32[] %bitcast-convert.24, s32[] %constant.28),
    direction=LT
  %constant.23 = u32[] constant(2147483647)
  %bitcast-convert.25 = u32[] bitcast-convert(f32[] %convert.14)
  %subtract.26 = u32[] subtract(u32[] %constant.23, u32[] %bitcast-convert.25)
  %bitcast-convert.27 = s32[] bitcast-convert(u32[] %subtract.26)
  %select.30 = s32[] select(pred[] %compare.29, s32[] %bitcast-convert.27, s32[]
    %bitcast-convert.24)
  ROOT %compare.31 = pred[] compare(s32[] %select.22, s32[] %select.30),
    direction=GT
}

ENTRY entry
  (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
  %iota.7 = s32[2,209664] iota(), iota_dimension=1,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %sort.32 = (bf16[2,209664], s32[2,209664])
    sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
    dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.33 = bf16[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.34 = bf16[2,2000] slice(bf16[2,209664]
    %get-tuple-element.33), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.35 = s32[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.36 = s32[2,2000] slice(s32[2,209664]
    %get-tuple-element.35), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
    tuple(bf16[2,2000] %slice.34, s32[2,2000]
    %slice.36), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  // Fatal null checks: FindInstruction returns nullptr on a miss, and the
  // dereferences below would otherwise crash the test binary.
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
  auto final_sort = FindInstruction(module.get(), "sort.1");
  ASSERT_NE(final_sort, nullptr);
  EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
  EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
}
2633
// Same as PartitionSortInTopK, but the comparator breaks value ties with a
// select on the iota indices. The TopK partitioning is still expected
// (104832-wide local sort, 4000-wide "sort.1").
TEST_F(SpmdPartitioningTest, PartitionSortInTopKWhenComparisonWithSelect) {
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
    p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
  %p.0.lhs.2566 = bf16[] parameter(0)
  %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
  %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
  %constant.285 = s32[] constant(0)
  %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
    direction=LT
  %constant.286 = u32[] constant(2147483647)
  %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
  %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
  %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
  %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
    s32[] %bitcast-convert.48)
  %p.0.rhs.2567 = bf16[] parameter(1)
  %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
  %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
  %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
    direction=LT
  %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
  %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
  %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
  %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
    s32[] %bitcast-convert.51)
  %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
  %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
  %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
    direction=EQ
  %p.1.lhs.2586 = s32[] parameter(2)
  %p.1.rhs.2587 = s32[] parameter(3)
  %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
    direction=LT
  ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
    pred[] %compare.86)
}

ENTRY entry
  (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
  %iota.7 = s32[2,209664] iota(), iota_dimension=1,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %sort.32 = (bf16[2,209664], s32[2,209664])
    sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
    dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.33 = bf16[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.34 = bf16[2,2000] slice(bf16[2,209664]
    %get-tuple-element.33), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.35 = s32[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.36 = s32[2,2000] slice(s32[2,209664]
    %get-tuple-element.35), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
    tuple(bf16[2,2000] %slice.34, s32[2,2000]
    %slice.36), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  // Fatal null checks: FindInstruction returns nullptr on a miss, and the
  // dereferences below would otherwise crash the test binary.
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 104832);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 104832);
  auto final_sort = FindInstruction(module.get(), "sort.1");
  ASSERT_NE(final_sort, nullptr);
  EXPECT_EQ(final_sort->operand(0)->shape().dimensions(1), 4000);
  EXPECT_EQ(final_sort->operand(1)->shape().dimensions(1), 4000);
}
2712
// Negative case: the sort's second operand is a parameter rather than an iota.
// The test expects the sort to see the full 209664-wide operands, i.e. the
// TopK partitioning is not applied.
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSecondOperandIsNotIota) {
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
    p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
  %p.0.lhs.2566 = bf16[] parameter(0)
  %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
  %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
  %constant.285 = s32[] constant(0)
  %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
    direction=LT
  %constant.286 = u32[] constant(2147483647)
  %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
  %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
  %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
  %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
    s32[] %bitcast-convert.48)
  %p.0.rhs.2567 = bf16[] parameter(1)
  %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
  %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
  %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
    direction=LT
  %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
  %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
  %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
  %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
    s32[] %bitcast-convert.51)
  %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
  %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
  %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
    direction=EQ
  %p.1.lhs.2586 = s32[] parameter(2)
  %p.1.rhs.2587 = s32[] parameter(3)
  %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
    direction=LT
  ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
    pred[] %compare.86)
}

ENTRY entry {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %arg_tuple.2 = s32[2,209664] parameter(1)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
  %sort.32 = (bf16[2,209664], s32[2,209664])
    sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %arg_tuple.2),
    dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.33 = bf16[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.34 = bf16[2,2000] slice(bf16[2,209664]
    %get-tuple-element.33), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.35 = s32[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.36 = s32[2,2000] slice(s32[2,209664]
    %get-tuple-element.35), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
    tuple(bf16[2,2000] %slice.34, s32[2,2000]
    %slice.36), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  // Fatal null check: FindInstruction returns nullptr on a miss, and the
  // dereferences below would otherwise crash the test binary.
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
2786
// Negative case: the input is split along dimension 0 ([2,1]), not along the
// sort dimension 1. The test expects the sort to keep the full 209664-wide
// sort dimension.
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenNoPartitionInSortDim) {
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
    p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
  %p.0.lhs.2566 = bf16[] parameter(0)
  %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
  %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
  %constant.285 = s32[] constant(0)
  %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
    direction=LT
  %constant.286 = u32[] constant(2147483647)
  %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
  %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
  %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
  %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
    s32[] %bitcast-convert.48)
  %p.0.rhs.2567 = bf16[] parameter(1)
  %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
  %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
  %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
    direction=LT
  %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
  %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
  %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
  %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
    s32[] %bitcast-convert.51)
  %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
  %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
  %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
    direction=EQ
  %p.1.lhs.2586 = s32[] parameter(2)
  %p.1.rhs.2587 = s32[] parameter(3)
  %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
    direction=LT
  ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
    pred[] %compare.86)
}

ENTRY entry
  (arg_tuple.1: ()) -> (bf16[2,2000], s32[2,2000]) {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[2,1]0,1}
  %iota.7 = s32[2,209664] iota(), iota_dimension=1,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %sort.32 = (bf16[2,209664], s32[2,209664])
    sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
    dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.33 = bf16[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.34 = bf16[2,2000] slice(bf16[2,209664]
    %get-tuple-element.33), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.35 = s32[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.36 = s32[2,2000] slice(s32[2,209664]
    %get-tuple-element.35), slice={[0:2], [0:2000]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  ROOT %tuple.46 = (bf16[2,2000], s32[2,2000])
    tuple(bf16[2,2000] %slice.34, s32[2,2000]
    %slice.36), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  // Fatal null check: FindInstruction returns nullptr on a miss, and the
  // dereferences below would otherwise crash the test binary.
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
2862
// Negative case: the post-sort slices cut dimension 0 ([0:1]) and keep the
// whole sort dimension. The test expects the sort to keep the full
// 209664-wide sort dimension.
TEST_F(SpmdPartitioningTest, NoPartitionSortInTopKWhenSliceInOtherDim) {
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.8 (p.0.lhs.2566: bf16[],
    p.0.rhs.2567: bf16[], p.1.lhs.2586: s32[], p.1.rhs.2587: s32[]) -> pred[] {
  %p.0.lhs.2566 = bf16[] parameter(0)
  %convert.164 = f32[] convert(bf16[] %p.0.lhs.2566)
  %bitcast-convert.48 = s32[] bitcast-convert(f32[] %convert.164)
  %constant.285 = s32[] constant(0)
  %compare.84 = pred[] compare(s32[] %bitcast-convert.48, s32[] %constant.285),
    direction=LT
  %constant.286 = u32[] constant(2147483647)
  %bitcast-convert.49 = u32[] bitcast-convert(f32[] %convert.164)
  %subtract.84 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.49)
  %bitcast-convert.50 = s32[] bitcast-convert(u32[] %subtract.84)
  %select.40 = s32[] select(pred[] %compare.84, s32[] %bitcast-convert.50,
    s32[] %bitcast-convert.48)
  %p.0.rhs.2567 = bf16[] parameter(1)
  %convert.165 = f32[] convert(bf16[] %p.0.rhs.2567)
  %bitcast-convert.51 = s32[] bitcast-convert(f32[] %convert.165)
  %compare.85 = pred[] compare(s32[] %bitcast-convert.51, s32[] %constant.285),
    direction=LT
  %bitcast-convert.52 = u32[] bitcast-convert(f32[] %convert.165)
  %subtract.85 = u32[] subtract(u32[] %constant.286, u32[] %bitcast-convert.52)
  %bitcast-convert.53 = s32[] bitcast-convert(u32[] %subtract.85)
  %select.41 = s32[] select(pred[] %compare.85, s32[] %bitcast-convert.53,
    s32[] %bitcast-convert.51)
  %compare.86 = pred[] compare(s32[] %select.40, s32[] %select.41), direction=GT
  %compare.1645 = pred[] compare(s32[] %select.41, s32[] %select.40), direction=GT
  %compare.1646 = pred[] compare(pred[] %compare.86, pred[] %compare.1645),
    direction=EQ
  %p.1.lhs.2586 = s32[] parameter(2)
  %p.1.rhs.2587 = s32[] parameter(3)
  %compare.1647 = pred[] compare(s32[] %p.1.lhs.2586, s32[] %p.1.rhs.2587),
    direction=LT
  ROOT %select.1054 = pred[] select(pred[] %compare.1646, pred[] %compare.1647,
    pred[] %compare.86)
}

ENTRY entry {
  %arg_tuple.1 = bf16[2,209664] parameter(0)
  %copy.arg_tuple.1 = bf16[2,209664] copy(%arg_tuple.1), sharding={devices=[1,2]0,1}
  %iota.7 = s32[2,209664] iota(), iota_dimension=1,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %sort.32 = (bf16[2,209664], s32[2,209664])
    sort(bf16[2,209664] %copy.arg_tuple.1, s32[2,209664] %iota.7),
    dimensions={1}, is_stable=true, to_apply=%compare-greater-than.8,
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.33 = bf16[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=0, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.34 = bf16[1,209664] slice(bf16[2,209664]
    %get-tuple-element.33), slice={[0:1], [0:209664]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  %get-tuple-element.35 = s32[2,209664]
    get-tuple-element((bf16[2,209664], s32[2,209664]) %sort.32),
    index=1, metadata={op_type="TopKV2" op_name="TopKV2"}
  %slice.36 = s32[1,209664] slice(s32[2,209664]
    %get-tuple-element.35), slice={[0:1], [0:209664]},
    metadata={op_type="TopKV2" op_name="TopKV2"}
  ROOT %tuple.46 = (bf16[1,209664], s32[1,209664])
    tuple(bf16[1,209664] %slice.34, s32[1,209664]
    %slice.36), sharding={{replicated}, {replicated}},
    metadata={op_name="XLA_Retvals"}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  // Fatal null check: FindInstruction returns nullptr on a miss, and the
  // dereferences below would otherwise crash the test binary.
  auto sort = FindInstruction(module.get(), "sort");
  ASSERT_NE(sort, nullptr);
  EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664);
  EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664);
}
2937
// Operand dimension 1 is sharded, and the transpose moves it to output
// dimension 2, which is exactly the output's sharded dimension. The transpose
// is therefore expected to apply directly to the local shard.
TEST_F(SpmdPartitioningTest, ShardableTranspose) {
  const absl::string_view kModuleText = R"(
HloModule module

ENTRY entry {
  %param0 = f32[16,38,38,4] parameter(0)
  %param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
  ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
    dimensions={0,3,1,2}, sharding={devices=[1,1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kModuleText, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* result =
      partitioned->entry_computation()->root_instruction();
  const auto local_slice =
      op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                       op::Constant(), op::Constant());
  const auto local_operand =
      AllOf(op::Copy(local_slice), op::Shape("f32[16,19,38,4]"));
  EXPECT_THAT(result, AllOf(op::Transpose(local_operand),
                            op::Shape("f32[16,4,19,38]")));
}
2960
// Operand tiled [4,2] on dimensions 0 and 1. The output tile assignment
// ([2,1,4,1] with the device order permuted) is expected to match the
// transposed operand tiling, so no reshard should be needed.
TEST_F(SpmdPartitioningTest, MultiDimensionShardedTranspose) {
  const absl::string_view kModuleText = R"(
HloModule module

ENTRY entry {
  %param0 = f32[16,38,38,4] parameter(0)
  %param0.copy = f32[16,38,38,4] copy(%param0),
    sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
  ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy),
    dimensions={1,3,0,2}, sharding={devices=[2,1,4,1]0,2,4,6,1,3,5,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kModuleText, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  const HloInstruction* result =
      partitioned->entry_computation()->root_instruction();
  const auto local_slice =
      op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
                       op::Constant(), op::Constant());
  const auto local_operand =
      AllOf(op::Copy(local_slice), op::Shape("f32[4,19,38,4]"));
  EXPECT_THAT(result, AllOf(op::Transpose(local_operand),
                            op::Shape("f32[19,4,4,38]")));
}
2984
TEST_F(SpmdPartitioningTest, NonShardableTranspose) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%param0 = f32[16,38,38,4] parameter(0)
%param0.copy = f32[16,38,38,4] copy(%param0), sharding={devices=[1,2,1,1]0,1}
ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
dimensions={0,3,1,2}, sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Output dimension 1 comes from operand dimension 3, but the operand is
  // sharded on its dimension 1. The operand therefore has to be resharded
  // onto dimension 3 (reshape -> all-to-all -> transpose -> reshape) before
  // the local transpose. Previously this matcher was built but never used,
  // so the reshard was not actually verified.
  auto reshard =
      AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape()))),
            op::Shape("f32[16,38,38,2]"));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Transpose(reshard), op::Shape("f32[16,2,38,38]")));
}
3005
TEST_F(SpmdPartitioningTest, PartialReplicateShardableTranspose) {
  // Same as ShardableTranspose, but each tile is replicated twice
  // (last_tile_dim_replicate); the tiled dimension still lines up with the
  // output sharding after the transpose.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%param0 = f32[16,38,38,4] parameter(0)
%param0.copy = f32[16,38,38,4] copy(%param0),
sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
dimensions={0,3,1,2},
sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto local_operand = AllOf(
      op::Shape("f32[16,19,38,4]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[16,4,19,38]"),
                          op::Transpose(local_operand)));
}
3030
TEST_F(SpmdPartitioningTest, PartialReplicateNonShardableTranspose) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%param0 = f32[16,38,38,4] parameter(0)
%param0.copy = f32[16,38,38,4] copy(%param0),
sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy),
dimensions={0,3,1,2},
sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  // Output dimension 1 maps to operand dimension 3, while the operand's tiled
  // dimension is 1. The operand is resharded onto dimension 3 via an
  // all-to-all (reshape -> all-to-all -> transpose -> reshape) before the
  // local transpose. This matcher was previously built but never used.
  auto reshard =
      AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape()))),
            op::Shape("f32[16,38,38,2]"));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Transpose(reshard), op::Shape("f32[16,2,38,38]")));
}
3053
TEST_F(SpmdPartitioningTest, PartialReplicateMultiDimensionShardedTranspose) {
  // Two tiled operand dimensions plus a replicated last tile dimension; the
  // output device order matches the transposed tiling, so no reshard occurs.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%param0 = f32[16,38,38,4] parameter(0)
%param0.copy = f32[16,38,38,4] copy(%param0),
sharding={devices=[2,2,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy),
dimensions={1,3,0,2},
sharding={devices=[2,1,2,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto sliced_operand = op::Copy(
      op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(),
                       op::Constant(), op::Constant()));
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[19,4,8,38]"),
                    op::Transpose(AllOf(op::Shape("f32[8,19,38,4]"),
                                        sliced_operand))));
}
3078
TEST_F(SpmdPartitioningTest, ShardableReshape) {
  // Splitting the unsharded minor dimension (324 -> 4x81) leaves the sharded
  // dimension 0 untouched, so the reshape is applied to the local shard.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%param0 = f32[38,38,324] parameter(0)
%param0.copy = f32[38,38,324] copy(%param0), sharding={devices=[2,1,1]0,1}
ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
sharding={devices=[2,1,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto* root = partitioned->entry_computation()->root_instruction();
  auto local_input = op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                               op::Constant(), op::Constant()));
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[19,38,4,81]"),
                    op::Reshape(AllOf(op::Shape("f32[19,38,324]"),
                                      local_input))));
}
3101
TEST_F(SpmdPartitioningTest, ReshapeWithReshard) {
  // The input is sharded on dimension 0 but the output wants dimension 1, so
  // the input is first resharded with an all-to-all and then reshaped.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
ROOT %reshape = f32[38,38,4,81] reshape(%param0),
sharding={devices=[1,2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto all_to_all = op::AllToAll(op::Reshape(op::Parameter(0)));
  auto resharded_input = op::Reshape(op::Transpose(all_to_all));
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[38,19,4,81]"),
                          op::Reshape(resharded_input)));
}
3122
TEST_F(SpmdPartitioningTest, ReshapeWithReshard2) {
  // The reshape is done locally on the dimension-0 shard first; the result is
  // then resharded onto dimension 3 with an all-to-all.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%param0 = f32[38,38,324] parameter(0), sharding={devices=[2,1,1]0,1}
ROOT %reshape = f32[38,38,2,162] reshape(%param0),
sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto* root = partitioned->entry_computation()->root_instruction();
  auto reshaped_shard =
      AllOf(op::Shape("f32[19,38,2,162]"), op::Reshape(op::Parameter(0)));
  auto output_reshard =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(reshaped_shard))));
  EXPECT_THAT(root, AllOf(output_reshard, op::Shape("f32[38,38,2,81]")));
}
3144
TEST_F(SpmdPartitioningTest, PartialReplicateShardableReshape) {
  // Partially replicated variant of ShardableReshape: the tiled dimension 0 is
  // unaffected by the reshape, so the local shard is reshaped directly.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%param0 = f32[38,38,324] parameter(0)
%param0.copy = f32[38,38,324] copy(%param0),
sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy),
sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto* root = partitioned->entry_computation()->root_instruction();
  auto local_input = op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                               op::Constant(), op::Constant()));
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[19,38,4,81]"),
                    op::Reshape(AllOf(op::Shape("f32[19,38,324]"),
                                      local_input))));
}
3168
TEST_F(SpmdPartitioningTest, ReshapeMergeDimsWithHaloExchange) {
  // Merging the sharded dimension (7) with the minor dimension (10) does not
  // split evenly across shards. The local reshape yields 8 rows, and a halo
  // exchange plus a dynamic-slice trims each shard to its 7 rows.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%input = s32[2,3,7,10] parameter(0), sharding={devices=[1,1,2,1]0,1}
ROOT %reshape = s32[3,2,1,14,5] reshape(%input),
sharding={devices=[1,1,1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto local_reshape =
      AllOf(op::Shape("s32[3,2,1,8,5]"), op::Reshape(op::Parameter(0)));
  auto from_neighbor = op::CollectivePermute(op::Slice(local_reshape));
  auto with_halo = op::Concatenate(from_neighbor, op::Slice(local_reshape));
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("s32[3,2,1,7,5]"),
                          op::DynamicSlice(with_halo, _, _, _, _, _)));
}
3191
TEST_F(SpmdPartitioningTest, PartialReplicateReshapeMergeDimsWithHaloExchange) {
  // Partially replicated variant: same local reshape, halo exchange and
  // trimming dynamic-slice as in the fully tiled case.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%input = s32[2,3,7,10] parameter(0),
sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
ROOT %reshape = s32[3,2,1,14,5] reshape(%input),
sharding={devices=[1,1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto local_reshape =
      AllOf(op::Shape("s32[3,2,1,8,5]"), op::Reshape(op::Parameter(0)));
  auto from_neighbor = op::CollectivePermute(op::Slice(local_reshape));
  auto with_halo = op::Concatenate(from_neighbor, op::Slice(local_reshape));
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("s32[3,2,1,7,5]"),
                          op::DynamicSlice(with_halo, _, _, _, _, _)));
}
3215
3216 // Produces an invalid module after transformation.
TEST_F(SpmdPartitioningTest, InceptionV3_4_way_ReduceWindowDilated) {
  // Partitions a reduce-window with lhs_dilate=3 on spatial dimension 1 across
  // 4 devices. Globally, the 5 input rows dilate to 13, padding 4+4 gives 21,
  // and a size-5 window yields 17 output rows. The test checks the
  // halo-exchange, masking and final-slice pattern the partitioner emits.
  absl::string_view hlo_string = R"(
HloModule module

sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}

ENTRY entry {
%param0 = f32[128,5,5,768] parameter(0)
%param0.copy = f32[128,5,5,768] copy(%param0),
sharding={devices=[1,4,1,1]0,1,2,3}
%constant.1 = f32[] constant(0), sharding={replicated}
ROOT %rw = f32[128,17,17,768] reduce-window(%param0.copy, %constant.1),
window={size=1x5x5x1 pad=0_0x4_4x4_4x0_0 lhs_dilate=1x3x3x1},
to_apply=sum, sharding={devices=[1,4,1,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  // Per-device input shard: the parameter is padded (presumably so dimension
  // 1 divides evenly across the 4 devices; TODO confirm) and then dynamically
  // sliced along dimension 1.
  auto input_shard = op::Copy(op::DynamicSlice(
      op::Pad(op::Parameter(0), op::Constant()), op::Constant(), op::Reshape(),
      op::Constant(), op::Constant()));
  // Index arithmetic on the partition offset. The variable names record the
  // constants the author expected (e.g. id * 4 + 1). The matchers only check
  // op structure, not the constant values. The op::Reshape() leaves are
  // presumably the reshaped partition id; verify against the partitioner.
  auto id_mul4_add1 =
      op::Add(op::Multiply(op::Reshape(), op::Constant()), op::Constant());
  auto id_mul5 = op::Multiply(op::Reshape(), op::Constant());
  auto id_mul5_add1_div3 =
      op::Divide(op::Add(id_mul5, op::Constant()), op::Constant());
  // Halo exchange: a shard received via collective-permute is concatenated
  // with the local shard (f32[128,4,5,768]). A 3-row window is then sliced at
  // a partition-dependent offset.
  auto before_masking = AllOf(
      op::Shape("f32[128,3,5,768]"),
      op::DynamicSlice(
          AllOf(
              op::Shape("f32[128,4,5,768]"),
              op::Concatenate(op::CollectivePermute(input_shard), input_shard)),
          op::Constant(),
          op::Subtract(op::Constant(),
                       op::Subtract(id_mul4_add1, id_mul5_add1_div3)),
          op::Constant(), op::Constant()));
  // Rows outside an iota-based range check are replaced with a broadcast
  // constant (presumably the reduce-window init value; TODO confirm).
  auto masked = op::Select(
      op::And(op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)),
                          op::Broadcast(op::Constant())),
              op::Compare(op::Add(op::Iota(), op::Broadcast(id_mul5_add1_div3)),
                          op::Broadcast(op::Constant()))),
      before_masking, op::Broadcast(op::Constant()));
  // The local reduce-window produces 7 rows from the masked, halo-extended
  // input.
  auto rw = AllOf(op::Shape("f32[128,7,17,768]"),
                  op::ReduceWindow(masked, op::Constant()));
  // The final dynamic-slice keeps this partition's 5 output rows
  // (17 rows over 4 devices, rounded up).
  auto final_slice_index = op::Subtract(
      id_mul5,
      op::Add(op::Multiply(id_mul5_add1_div3, op::Constant()), op::Constant()));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[128,5,17,768]"),
                    op::DynamicSlice(rw, op::Constant(), final_slice_index,
                                     op::Constant(), op::Constant())));
}
3276
TEST_F(SpmdPartitioningTest, TiledToTiledReduce) {
  // Only the kept dimension 3 is sharded, and the output is sharded the same
  // way, so a purely local reduce is enough (no cross-partition collective).
  absl::string_view hlo_string = R"(
HloModule module

sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}

ENTRY entry {
%param0 = f32[4,32,32,128] parameter(0)
%param0.copy = f32[4,32,32,128] copy(%param0),
sharding={devices=[1,1,1,2]0,1}
%constant.1 = f32[] constant(0), sharding={replicated}
%reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2},
to_apply=%sum, sharding={devices=[2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto local_input = op::Copy(
      op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                       op::Constant(), op::Reshape()));
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[64]"),
                    op::Reduce(AllOf(op::Shape("f32[4,32,32,64]"),
                                     local_input),
                               op::Constant())));
}
3309
TEST_F(SpmdPartitioningTest, PartialTiledToPartialTiledReduce) {
  // The reduced dimension 0 is tiled, so the local reduce result must be
  // combined across partitions with an all-reduce.
  absl::string_view hlo_string = R"(
HloModule module

sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}

ENTRY entry {
%param0 = f32[4,4] parameter(0),
sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
%constant.1 = f32[] constant(0), sharding={replicated}
ROOT %reduce = f32[4] reduce(%param0, %constant.1), dimensions={0},
to_apply=%sum,
sharding={devices=[2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto local_reduce = op::Reduce(op::Parameter(0), op::Constant());
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[2]"), op::AllReduce(local_reduce)));
}
3338
TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce) {
  // A variadic (value, index) reduce over the unsharded dimension 1; both
  // operands and outputs are tiled on dimension 0, so the reduce stays local.
  absl::string_view hlo_string = R"(
HloModule module

%minmax_func {
%lhs_value = f32[] parameter(0)
%rhs_value = f32[] parameter(2)
%compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT
%select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value)
%lhs_index = s32[] parameter(1)
%rhs_index = s32[] parameter(3)
%select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index)
ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5)
}

ENTRY %main {
%param0 = f32[28,10] parameter(0), sharding={devices=[2,1]0,1}
%param1 = s32[28,10] parameter(1), sharding={devices=[2,1]0,1}
%init0 = f32[] parameter(2)
%init1 = s32[] parameter(3)
ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1),
dimensions={1}, to_apply=%minmax_func,
sharding={{devices=[2]0,1}, {devices=[2]0,1}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto* root = partitioned->entry_computation()->root_instruction();
  auto local_reduce = op::Reduce(op::Parameter(0), op::Parameter(1),
                                 op::Parameter(2), op::Parameter(3));
  EXPECT_THAT(root,
              AllOf(op::Shape("(f32[14], s32[14])"), local_reduce));
}
3373
TEST_F(SpmdPartitioningTest, TiledToPartiallyTiledTupleReduce) {
  // The reduced dimension 1 is tiled 4 ways. The local partial results are
  // collected across those partitions and reduced a second time.
  absl::string_view hlo_string = R"(
HloModule module

%minmax_func {
%lhs_value = f32[] parameter(0)
%rhs_value = f32[] parameter(2)
%compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT
%select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value)
%lhs_index = s32[] parameter(1)
%rhs_index = s32[] parameter(3)
%select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index)
ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5)
}

ENTRY %main {
%param0 = f32[28,12] parameter(0), sharding={devices=[2,4]0,1,2,3,4,5,6,7}
%param1 = s32[28,12] parameter(1), sharding={devices=[2,4]0,1,2,3,4,5,6,7}
%init0 = f32[] parameter(2)
%init1 = s32[] parameter(3)
ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1),
dimensions={1}, to_apply=%minmax_func,
sharding={{devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate},
{devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  // Step 1: reduce the local [14,3] shards of both operands.
  auto local_reduce = AllOf(
      op::Shape("(f32[14], s32[14])"),
      op::Reduce(AllOf(op::Parameter(0), op::Shape("f32[14,3]")),
                 AllOf(op::Parameter(1), op::Shape("s32[14,3]")),
                 op::Parameter(2), op::Parameter(3)));

  // Step 2: each partial result is reshaped to a single column, written into
  // a [14,4] buffer via dynamic-update-slice, and all-reduced (presumably
  // acting as an all-gather of the 4 partials).
  auto gathered_values = AllOf(
      op::Shape("f32[14,4]"),
      op::AllReduce(op::DynamicUpdateSlice(
          _,
          AllOf(op::Shape("f32[14,1]"),
                op::Reshape(op::GetTupleElement(local_reduce))),
          _, _)));
  auto gathered_indices = AllOf(
      op::Shape("s32[14,4]"),
      op::AllReduce(op::DynamicUpdateSlice(
          _,
          AllOf(op::Shape("s32[14,1]"),
                op::Reshape(op::GetTupleElement(local_reduce))),
          _, _)));

  // Step 3: a second reduce combines the gathered partials.
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("(f32[14], s32[14])"),
                          op::Reduce(gathered_values, gathered_indices,
                                     op::Parameter(2), op::Parameter(3))));
}
3424
TEST_F(SpmdPartitioningTest, TiledToTiledReduceOutputReshard) {
  // The input is sharded on a reduced dimension (1), so the local reduce is
  // all-reduced to the full result and then sliced to the output shard.
  absl::string_view hlo_string = R"(
HloModule module

sum {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT add = f32[] add(a, b)
}

ENTRY entry {
%param0 = f32[4,32,32,128] parameter(0)
%param0.copy = f32[4,32,32,128] copy(%param0),
sharding={devices=[1,2,1,1]0,1}
%constant.1 = f32[] constant(0), sharding={replicated}
%reduce = f32[128] reduce(%param0.copy, %constant.1), dimensions={0,1,2},
to_apply=%sum, sharding={devices=[2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto local_input = AllOf(
      op::Shape("f32[4,16,32,128]"),
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  auto full_result = AllOf(op::Shape("f32[128]"),
                           op::AllReduce(op::Reduce(local_input,
                                                    op::Constant())));
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[64]"),
                          op::DynamicSlice(full_result, op::Reshape())));
}
3461
TEST_F(SpmdPartitioningTest, IotaAlongNonTileDimension) {
  // The iota runs along dimension 1 while dimension 2 is sharded, so each
  // partition emits a plain local iota (91 split 2 ways -> 46 per shard).
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
ROOT %iota = s32[16,80,91] iota(), iota_dimension=1,
sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("s32[16,80,46]"), op::Iota()));
}
3478
TEST_F(SpmdPartitioningTest, IotaAlongTileDimension) {
  // The iota runs along the sharded dimension, so each partition adds a
  // broadcast offset to a local iota.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
ROOT %iota = s32[16,80,91] iota(), iota_dimension=2,
sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto offset_iota = op::Add(op::Iota(), op::Broadcast());
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("s32[16,80,46]"), offset_iota));
}
3496
TEST_F(SpmdPartitioningTest, U32IotaAlongTileDimension) {
  // Unsigned variant of IotaAlongTileDimension: local iota plus broadcast
  // offset, keeping the u32 element type.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
ROOT %iota = u32[16,80,91] iota(), iota_dimension=2,
sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto offset_iota = op::Add(op::Iota(), op::Broadcast());
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("u32[16,80,46]"), offset_iota));
}
3514
TEST_F(SpmdPartitioningTest, Conditional) {
  absl::string_view hlo_string = R"(
HloModule module

Negate {
x = f32[4,5] parameter(0), sharding={replicated}
ROOT negate = f32[4,5] negate(x), sharding={replicated}
}

Identity {
y = f32[4,5] parameter(0), sharding={devices=[2,1]0,1}
ROOT copy = f32[4,5] copy(y), sharding={devices=[2,1]0,1}
}

ENTRY entry {
%param.0 = pred[] parameter(0)
%param.0.copy = pred[] copy(%param.0), sharding={maximal device=0}
%param.1 = f32[4,5] parameter(1)
%param.1.copy = f32[4,5] copy(%param.1), sharding={replicated}
%param.2 = f32[4,5] parameter(2)
%param.2.copy = f32[4,5] copy(%param.2), sharding={devices=[2,1]0,1}
ROOT cond = f32[4,5] conditional(%param.0.copy, %param.1.copy, %param.2.copy),
true_computation=Negate, false_computation=Identity,
sharding={devices=[2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // The predicate has a maximal sharding on device 0. It reaches every
  // partition through an all-reduce and remains a scalar pred. This replaces a
  // dead, malformed matcher that passed op::Shape as a second Copy operand.
  auto predicate = AllOf(op::AllReduce(), op::Shape("pred[]"));
  auto param1 = AllOf(op::Copy(op::Parameter()), op::Shape("f32[4,5]"));
  auto param2 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                                op::Constant())),
                      op::Shape("f32[2,5]"));

  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Conditional(predicate, param1, param2),
                          op::Shape("f32[2,5]")));

  // The true branch receives the replicated operand, so the negated result is
  // sliced down to the local [2,5] shard.
  auto then_branch_root = root->branch_computation(0)->root_instruction();
  EXPECT_THAT(then_branch_root,
              AllOf(op::DynamicSlice(op::Negate(op::Parameter()), op::Reshape(),
                                     op::Constant()),
                    op::Shape("f32[2,5]")));

  // The false branch operand is already sharded like the output.
  auto else_branch_root = root->branch_computation(1)->root_instruction();
  EXPECT_THAT(else_branch_root,
              AllOf(op::Copy(op::Parameter()), op::Shape("f32[2,5]")));
}
3565
TEST_F(SpmdPartitioningTest, ConditionalManual) {
  // With manual sharding everywhere, the partitioner must leave the
  // conditional and its operands untouched (full, unsliced shapes).
  absl::string_view hlo_string = R"(
HloModule module

Negate {
x = f32[4,5] parameter(0), sharding={manual}
ROOT negate = f32[4,5] negate(x), sharding={manual}
}

Identity {
y = f32[4,5] parameter(0), sharding={manual}
ROOT copy = f32[4,5] copy(y), sharding={manual}
}

ENTRY entry {
%param.0 = pred[] parameter(0), sharding={manual}
%param.1 = f32[4,5] parameter(1), sharding={manual}
%param.2 = f32[4,5] parameter(2), sharding={manual}
ROOT cond = f32[4,5] conditional(%param.0, %param.1, %param.2),
true_computation=Negate, false_computation=Identity, sharding={manual}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      AllOf(op::Shape("f32[4,5]"),
            op::Conditional(AllOf(op::Shape("pred[]"), op::Parameter(0)),
                            AllOf(op::Shape("f32[4,5]"), op::Parameter(1)),
                            AllOf(op::Shape("f32[4,5]"), op::Parameter(2)))));
}
3600
TEST_F(SpmdPartitioningTest, WhileManual) {
  // A manually sharded while loop passes through partitioning unchanged.
  absl::string_view hlo_string = R"(
HloModule module

LoopCond {
x = s32[] parameter(0), sharding={manual}
const = s32[] constant(5), sharding={manual}
ROOT lt = pred[] compare(x, const), direction=LT, sharding={manual}
}

Inc {
x = s32[] parameter(0), sharding={manual}
const = s32[] constant(1), sharding={manual}
ROOT add = s32[] add(x, const), sharding={manual}
}

ENTRY entry {
zero = s32[] parameter(0), sharding={manual}
ROOT while = s32[] while(zero), body=Inc, condition=LoopCond,
sharding={manual}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto loop_init = AllOf(op::Shape("s32[]"), op::Parameter(0));
  EXPECT_THAT(partitioned->entry_computation()->root_instruction(),
              AllOf(op::Shape("s32[]"), op::While(loop_init)));
}
3631
TEST_F(SpmdPartitioningTest, SelectAndScatter_RetinaNet) {
  // Both the operand and the source are tiled 8 ways on dimension 1. With a
  // 1x1 window and stride 2, each shard scatters locally, and dimension 0 of
  // the window needs no padding.
  absl::string_view hlo_string = R"(
HloModule module

ge {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT compare = pred[] compare(a, b), direction=GE
}

sum {
c = f32[] parameter(0)
d = f32[] parameter(1)
ROOT add = f32[] add(c, d)
}

ENTRY entry {
%param.0 = f32[32,128,384,64] parameter(0)
%param.0.copy = f32[32,128,384,64] copy(%param.0),
sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
%param.1 = f32[32,64,192,64] parameter(1)
%param.1.copy = f32[32,64,192,64] copy(%param.1),
sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
constant.1 = f32[] constant(0), sharding={replicated}
ROOT select-and-scatter = f32[32,128,384,64] select-and-scatter(param.0.copy,
%param.1.copy, constant.1), window={size=1x1x1x1 stride=1x2x2x1},
select=ge, scatter=sum, sharding={devices=[1,8,1,1]0,1,2,3,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto operand_shard = AllOf(
      op::Shape("f32[32,16,384,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));
  auto source_shard = AllOf(
      op::Shape("f32[32,8,192,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(),
                                op::Constant(), op::Constant())));

  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::SelectAndScatter(operand_shard, source_shard,
                                         op::Constant()));
  const auto& leading_window_dim = root->window().dimensions(0);
  EXPECT_EQ(leading_window_dim.padding_low(), 0);
  EXPECT_EQ(leading_window_dim.padding_high(), 0);
}
3678
TEST_F(SpmdPartitioningTest, TiledDot) {
  // Both operands are sharded on the contracting dimension and the output is
  // replicated, so the local convolution is summed with an all-reduce.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%lhs = f32[128,64] parameter(0)
%lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1}
%rhs = f32[64,256] parameter(1)
%rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1}
ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy),
dim_labels=bf_io->bf, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(
      auto partitioned,
      PartitionComputation(hlo_string, /*num_devices=*/2,
                           /*conv_halo_exchange_always_on_lhs=*/false));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard =
      AllOf(op::Shape("f32[128,32]"),
            op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                      op::Reshape())));
  auto rhs_shard =
      AllOf(op::Shape("f32[32,256]"),
            op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                      op::Constant())));
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[128,256]"),
                    op::AllReduce(op::Convolution(lhs_shard, rhs_shard))));
}
3708
TEST_F(SpmdPartitioningTest, TiledDotOutputTiled) {
  // Like TiledDot, but the output is tiled on dimension 1. The all-reduced
  // full result is sliced down to each partition's columns.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%lhs = f32[128,64] parameter(0)
%lhs.copy = f32[128,64] copy(%lhs), sharding={devices=[1,2]0,1}
%rhs = f32[64,256] parameter(1)
%rhs.copy = f32[64,256] copy(%rhs), sharding={devices=[2,1]0,1}
ROOT %conv = f32[128,256] convolution(%lhs.copy, %rhs.copy),
dim_labels=bf_io->bf, sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard =
      AllOf(op::Shape("f32[128,32]"),
            op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
                                      op::Reshape())));
  auto rhs_shard =
      AllOf(op::Shape("f32[32,256]"),
            op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(),
                                      op::Constant())));
  auto summed = AllOf(op::Shape("f32[128,256]"),
                      op::AllReduce(op::Convolution(lhs_shard, rhs_shard)));
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[128,128]"),
                    op::DynamicSlice(summed, op::Constant(), op::Reshape())));
}
3739
TEST_F(SpmdPartitioningTest, BatchPartitionedConvolution) {
  // The lhs and output are both sharded on the batch dimension ("b" at
  // position 1) and the kernel is replicated, so the convolution is local.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%lhs = f32[128,256,256] parameter(0)
%lhs.copy = f32[128,256,256] copy(%lhs), sharding={devices=[1,2,1]0,1}
%rhs = f32[256,8,1] parameter(1)
%rhs.copy = f32[256,8,1] copy(%rhs), sharding={replicated}
ROOT %conv = f32[128,256,8] convolution(%lhs.copy, %rhs.copy),
window={size=1}, dim_labels=0bf_io0->0bf, sharding={devices=[1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto batch_shard = AllOf(
      op::Shape("f32[128,128,256]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
                                op::Constant())));
  auto kernel = AllOf(op::Shape("f32[256,8,1]"), op::Copy(op::Parameter(1)));
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[128,128,8]"),
                          op::Convolution(batch_shard, kernel)));
}
3765
TEST_F(SpmdPartitioningTest, DotOutputFeaturePartitioned) {
  // The rhs is sharded on its non-contracting dimension, which becomes the
  // sharded output feature dimension, so the dot is fully local.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%lhs = f32[24,64] parameter(0)
%lhs.copy = f32[24,64] copy(%lhs), sharding={replicated}
%rhs = f32[39296,64] parameter(1)
%rhs.copy = f32[39296,64] copy(%rhs), sharding={devices=[2,1]0,1}
ROOT %dot = f32[24,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={}, rhs_batch_dims={},
lhs_contracting_dims={1}, rhs_contracting_dims={1},
sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto replicated_lhs =
      AllOf(op::Shape("f32[24,64]"), op::Copy(op::Parameter(0)));
  auto rhs_shard =
      AllOf(op::Shape("f32[19648,64]"),
            op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
                                      op::Constant())));
  auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[24,19648]"),
                          op::Dot(replicated_lhs, rhs_shard)));
}
3792
TEST_F(SpmdPartitioningTest, DotPartialDeviceOrder) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,256,4096] parameter(0), sharding={devices=[1,1,2,2]1,3,0,2 last_tile_dim_replicate}
  %rhs = f32[4096,2048] parameter(1), sharding={devices=[2,2]3,1,2,0}
  ROOT %dot = f32[16,256,2048] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={0},
    sharding={devices=[1,1,2,2]2,3,0,1 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // The parameter shards feed the dot directly, with no reshard in between.
  // The partial products are combined with an all-reduce.
  auto lhs_param = AllOf(op::Parameter(0), op::Shape("f32[16,256,2048]"));
  auto rhs_param = AllOf(op::Parameter(1), op::Shape("f32[2048,1024]"));
  auto partial_dot = op::Dot(lhs_param, rhs_param);

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::AllReduce(partial_dot),
                                op::Shape("f32[16,256,1024]")));
}
3816
TEST_F(SpmdPartitioningTest, EinsumBatchPartitioned) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64] parameter(0)
  %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1}
  %rhs = f32[32,39296,64] parameter(1)
  %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // Both operands are sliced only along the batch dim (index 0).
  auto batch_shard = [](int64_t param_index) {
    return op::Copy(op::DynamicSlice(op::Parameter(param_index),
                                     op::Reshape(), op::Constant(),
                                     op::Constant()));
  };
  auto expected_lhs = AllOf(batch_shard(0), op::Shape("f32[16,24,64]"));
  auto expected_rhs = AllOf(batch_shard(1), op::Shape("f32[16,39296,64]"));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::Dot(expected_lhs, expected_rhs),
                                op::Shape("f32[16,24,39296]")));
}
3845
TEST_F(SpmdPartitioningTest, EinsumLHSandOutputBatchPartitioned) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64] parameter(0)
  %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1}
  %rhs = f32[32,39296,64] parameter(1)
  %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_batch_shard = op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                          op::Constant(), op::Constant());
  auto expected_lhs =
      AllOf(op::Copy(lhs_batch_shard), op::Shape("f32[16,24,64]"));
  auto replicated_rhs =
      AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]"));
  // The replicated RHS is sliced locally along the batch dim so that it
  // lines up with the LHS shard.
  auto rhs_batch_slice = op::DynamicSlice(replicated_rhs, op::Reshape(),
                                          op::Constant(), op::Constant());

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::Dot(expected_lhs, rhs_batch_slice),
                                op::Shape("f32[16,24,39296]")));
}
3875
TEST_F(SpmdPartitioningTest, EinsumRHSandOutputBatchPartitioned) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64] parameter(0)
  %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[1,2,1]0,1}
  %rhs = f32[32,39296,64] parameter(1)
  %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={devices=[2,1,1]0,1}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // The LHS starts out split on dim 1 and the RHS on the batch dim.
  auto lhs_dim1_shard = op::DynamicSlice(op::Parameter(0), op::Constant(),
                                         op::Reshape(), op::Constant());
  auto expected_lhs =
      AllOf(op::Copy(lhs_dim1_shard), op::Shape("f32[32,12,64]"));
  auto rhs_batch_shard = op::DynamicSlice(op::Parameter(1), op::Reshape(),
                                          op::Constant(), op::Constant());
  auto expected_rhs =
      AllOf(op::Copy(rhs_batch_shard), op::Shape("f32[16,39296,64]"));
  // The LHS is moved onto the batch sharding through an all-to-all,
  // surrounded by reshape/transpose.
  auto lhs_resharded = op::Reshape(
      op::Transpose(op::AllToAll(op::Reshape(expected_lhs))));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::Dot(lhs_resharded, expected_rhs),
                                op::Shape("f32[16,24,39296]")));
}
3906
TEST_F(SpmdPartitioningTest, EinsumOutputBatchPartitioned) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64] parameter(0)
  %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated}
  %rhs = f32[32,39296,64] parameter(1)
  %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // Both operands are replicated copies that get sliced along the batch dim
  // after the copy, so the output comes out batch-sharded.
  auto sliced_copy_of = [](int64_t param_index) {
    return op::DynamicSlice(op::Copy(op::Parameter(param_index)),
                            op::Reshape(), op::Constant(), op::Constant());
  };
  auto lhs_slice = AllOf(sliced_copy_of(0), op::Shape("f32[16,24,64]"));
  auto rhs_slice = AllOf(sliced_copy_of(1), op::Shape("f32[16,39296,64]"));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::Dot(lhs_slice, rhs_slice),
                                op::Shape("f32[16,24,39296]")));
}
3938
TEST_F(SpmdPartitioningTest, EinsumContractingDimsPartitioned) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,1,2,2]0,1,2,3}
  %rhs = f32[32,39296,64,128] parameter(1)
  %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,1,2,2]0,1,2,3}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // Both operands are sliced on both contracting dims (indices 2 and 3).
  auto contracting_shard = [](int64_t param_index) {
    return op::Copy(op::DynamicSlice(op::Parameter(param_index),
                                     op::Constant(), op::Constant(),
                                     op::Reshape(), op::Reshape()));
  };
  auto expected_lhs =
      AllOf(contracting_shard(0), op::Shape("f32[32,24,32,64]"));
  auto expected_rhs =
      AllOf(contracting_shard(1), op::Shape("f32[32,39296,32,64]"));
  // The replicated result needs two nested all-reduces.
  auto reduced = op::AllReduce(op::AllReduce(op::Dot(expected_lhs, expected_rhs)));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(reduced, op::Shape("f32[32,24,39296]")));
}
3970
TEST_F(SpmdPartitioningTest, EinsumLHSNonContractingDimsPartitioned) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,2]0,1,2,3}
  %rhs = f32[32,39296,64] parameter(1)
  %rhs.copy = f32[32,39296,64] copy(%rhs), sharding={replicated}
  ROOT %dot = f32[32,24,128,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[1,2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // The LHS is sliced on its two non-contracting dims (indices 1 and 3).
  // The RHS stays whole.
  auto lhs_shard =
      op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
                       op::Constant(), op::Reshape());
  auto expected_lhs =
      AllOf(op::Copy(lhs_shard), op::Shape("f32[32,12,64,64]"));
  auto expected_rhs =
      AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64]"));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::Dot(expected_lhs, expected_rhs),
                                op::Shape("f32[32,12,64,39296]")));
}
3998
TEST_F(SpmdPartitioningTest, EinsumRHSNonContractingDimsPartitioned) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64] parameter(0)
  %lhs.copy = f32[32,24,64] copy(%lhs), sharding={replicated}
  %rhs = f32[32,39296,64,128] parameter(1)
  %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3}
  ROOT %dot = f32[32,24,39296,128] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[1,1,2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // The LHS stays whole. The RHS is sliced on its two non-contracting dims
  // (indices 1 and 3).
  auto expected_lhs =
      AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64]"));
  auto rhs_shard =
      op::DynamicSlice(op::Parameter(1), op::Constant(), op::Reshape(),
                       op::Constant(), op::Reshape());
  auto expected_rhs =
      AllOf(op::Copy(rhs_shard), op::Shape("f32[32,19648,64,64]"));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::Dot(expected_lhs, expected_rhs),
                                op::Shape("f32[32,24,19648,64]")));
}
4026
TEST_F(SpmdPartitioningTest, EinsumOutputLHSNonContractingDimPartitioned) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated}
  %rhs = f32[32,39296,64,128] parameter(1)
  %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto replicated_lhs =
      AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]"));
  auto replicated_rhs =
      AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]"));
  // Only the LHS is sliced locally, on its non-contracting dim 1, to match
  // the sharding of output dim 1.
  auto lhs_local_slice =
      AllOf(op::DynamicSlice(replicated_lhs, op::Constant(), op::Reshape(),
                             op::Constant(), op::Constant()),
            op::Shape("f32[32,12,64,128]"));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::Dot(lhs_local_slice, replicated_rhs),
                                op::Shape("f32[32,12,39296]")));
}
4058
TEST_F(SpmdPartitioningTest, EinsumOutputRHSNonContractingDimPartitioned) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={replicated}
  %rhs = f32[32,39296,64,128] parameter(1)
  %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={replicated}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto replicated_lhs =
      AllOf(op::Copy(op::Parameter(0)), op::Shape("f32[32,24,64,128]"));
  auto replicated_rhs =
      AllOf(op::Copy(op::Parameter(1)), op::Shape("f32[32,39296,64,128]"));
  // Only the RHS is sliced locally, on its non-contracting dim 1, to match
  // the sharding of output dim 2.
  auto rhs_local_slice =
      AllOf(op::DynamicSlice(replicated_rhs, op::Constant(), op::Reshape(),
                             op::Constant(), op::Constant()),
            op::Shape("f32[32,19648,64,128]"));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::Dot(replicated_lhs, rhs_local_slice),
                                op::Shape("f32[32,24,19648]")));
}
4089
TEST_F(SpmdPartitioningTest,
       EinsumRHSWindowedInContractingOutNonContractingPartitioned) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[320,25,64,128] parameter(0)
  %lhs.copy = f32[320,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
  %rhs = f32[320,39296,64,128] parameter(1)
  %rhs.copy = f32[320,39296,64,128] copy(%rhs),
    sharding={devices=[1,1,4,1]0,1,2,3}
  ROOT %dot = f32[320,25,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,4,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // Both operands enter the loop sliced along contracting dim 2 (64 -> 16).
  auto contracting_shard = [](int64_t param_index) {
    return op::Copy(op::DynamicSlice(op::Parameter(param_index),
                                     op::Constant(), op::Constant(),
                                     op::Reshape(), op::Constant()));
  };
  auto expected_lhs =
      AllOf(contracting_shard(0), op::Shape("f32[320,25,16,128]"));
  auto expected_rhs =
      AllOf(contracting_shard(1), op::Shape("f32[320,39296,16,128]"));
  auto loop_init = op::Tuple(expected_lhs, expected_rhs, op::Broadcast(),
                             op::Broadcast(), op::Constant());

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::GetTupleElement(op::While(loop_init)),
                                op::Shape("f32[320,7,39296]")));
  const HloInstruction* while_loop = entry_root->operand(0);

  // Matches any element of the loop-carried tuple.
  auto loop_elem = op::GetTupleElement(op::Parameter(0));

  // The condition compares a loop element (presumably the iteration counter)
  // against a constant.
  EXPECT_THAT(while_loop->while_condition()->root_instruction(),
              op::Compare(loop_elem, op::Constant()));

  // Body: the dot's first operand is a dynamic slice of a padded loop
  // element, and its result is added to an accumulator element.
  auto incremented_counter = op::Add(loop_elem, op::Constant());
  auto padded_slice = AllOf(
      op::DynamicSlice(op::Pad(loop_elem, op::Constant()), op::Constant(),
                       op::Reshape(), op::Constant(), op::Constant()),
      op::Shape("f32[320,7,16,128]"));
  auto accumulated =
      AllOf(op::Add(loop_elem, op::Dot(padded_slice, loop_elem)),
            op::Shape("f32[320,7,39296]"));
  auto conditional_window =
      op::Conditional(op::Compare(incremented_counter, op::Constant()),
                      accumulated, accumulated);
  const HloInstruction* body_root =
      while_loop->while_body()->root_instruction();
  EXPECT_THAT(body_root, op::Tuple(loop_elem, loop_elem, conditional_window,
                                   loop_elem, incremented_counter));

  // Tuple element 2 is the conditional. Its true branch collective-permutes
  // the branch parameter; its false branch returns the parameter unchanged.
  const HloInstruction* cp_conditional = body_root->operand(2);
  EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
              op::CollectivePermute(op::Parameter(0)));
  EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
              op::Parameter(0));
}
4158
TEST_F(SpmdPartitioningTest,
       EinsumRHSWindowedInContractingOutNonContractingFromBroadcast) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %constant.1 = f32[] constant(2)
  %broadcast = f32[32,25,64,128] broadcast(%constant.1), dimensions={},
    sharding={devices=[1,1,4,1]0,1,2,3}
  %add = f32[32,25,64,128] add(%broadcast, %broadcast),
    sharding={devices=[1,1,4,1]0,1,2,3}
  %rhs = f32[32,39296,64,128] parameter(0)
  %rhs.copy = f32[32,39296,64,128] copy(%rhs),
    sharding={devices=[1,1,4,1]0,1,2,3}
  ROOT %dot = f32[32,25,39296] dot(%add, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,4,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
                                                            /*num_devices=*/4));
  VLOG(1) << module->ToString();
  // Loop code motion is involved, so the full loop structure is not
  // pattern-matched. We still check that the partitioned entry result has
  // the per-partition shape: output dim 1 (size 25) split 4 ways rounds up
  // to 7 per shard.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Shape("f32[32,7,39296]"));
}
4184
TEST_F(SpmdPartitioningTest,
       EinsumLHSWindowedInContractingOutNonContractingPartitioned) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,1024,16384] parameter(0)
  %lhs.copy = f32[16,1024,16384] copy(%lhs),
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  %rhs = f32[16384,67,128] parameter(1)
  %rhs.copy = f32[16384,67,128] copy(%rhs),
    sharding={devices=[4,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
  ROOT %dot = f32[16,1024,67,128] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={0},
    sharding={devices=[2,1,4,1]0,1,2,3,4,5,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                    op::Constant(), op::Reshape());
  auto expected_lhs =
      AllOf(op::Copy(lhs_shard), op::Shape("f32[8,1024,4096]"));
  auto rhs_shard = op::DynamicSlice(op::Parameter(1), op::Reshape(),
                                    op::Constant(), op::Constant());
  auto expected_rhs =
      AllOf(op::Copy(rhs_shard), op::Shape("f32[4096,67,128]"));
  auto loop_init = op::Tuple(expected_lhs, expected_rhs, op::Broadcast(),
                             op::Broadcast(), op::Constant());

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::GetTupleElement(op::While(loop_init)),
                                op::Shape("f32[8,1024,17,128]")));
  const HloInstruction* while_loop = entry_root->operand(0);

  // Matches any element of the loop-carried tuple.
  auto loop_elem = op::GetTupleElement(op::Parameter(0));

  // The condition compares a loop element (presumably the iteration counter)
  // against a constant.
  EXPECT_THAT(while_loop->while_condition()->root_instruction(),
              op::Compare(loop_elem, op::Constant()));

  // Body: the dot's second operand is a dynamic slice of a padded loop
  // element, and its result is added to an accumulator element.
  auto incremented_counter = op::Add(loop_elem, op::Constant());
  auto padded_slice =
      AllOf(op::DynamicSlice(op::Pad(loop_elem, op::Constant()),
                             op::Constant(), op::Reshape(), op::Constant()),
            op::Shape("f32[4096,17,128]"));
  auto accumulated =
      AllOf(op::Add(loop_elem, op::Dot(loop_elem, padded_slice)),
            op::Shape("f32[8,1024,17,128]"));
  auto conditional_window =
      op::Conditional(op::Compare(incremented_counter, op::Constant()),
                      accumulated, accumulated);
  const HloInstruction* body_root =
      while_loop->while_body()->root_instruction();
  EXPECT_THAT(body_root, op::Tuple(loop_elem, loop_elem, conditional_window,
                                   loop_elem, incremented_counter));

  // Tuple element 2 is the conditional. Its true branch collective-permutes
  // the branch parameter; its false branch returns the parameter unchanged.
  const HloInstruction* cp_conditional = body_root->operand(2);
  EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
              op::CollectivePermute(op::Parameter(0)));
  EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
              op::Parameter(0));
}
4252
TEST_F(SpmdPartitioningTest,
       EinsumLHSWindowedInContractingOutNonContractingPartitioned2) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,1024,16384] parameter(0)
  %lhs.copy = f32[16,1024,16384] copy(%lhs),
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  %rhs = f32[16384,2,33,128] parameter(1)
  %rhs.copy = f32[16384,2,33,128] copy(%rhs),
    sharding={devices=[4,1,1,1,2]0,4,1,5,2,6,3,7 last_tile_dim_replicate}
  ROOT %dot = f32[16,1024,2,33,128] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={0},
    sharding={devices=[2,1,2,2,1]0,1,2,3,4,5,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                    op::Constant(), op::Reshape());
  auto expected_lhs =
      AllOf(op::Copy(lhs_shard), op::Shape("f32[8,1024,4096]"));
  auto rhs_shard =
      op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Constant(),
                       op::Constant(), op::Constant());
  auto expected_rhs =
      AllOf(op::Copy(rhs_shard), op::Shape("f32[4096,2,33,128]"));
  auto loop_init = op::Tuple(expected_lhs, expected_rhs, op::Broadcast(),
                             op::Broadcast(), op::Constant());

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::GetTupleElement(op::While(loop_init)),
                                op::Shape("f32[8,1024,1,17,128]")));
  const HloInstruction* while_loop = entry_root->operand(0);

  // Matches any element of the loop-carried tuple.
  auto loop_elem = op::GetTupleElement(op::Parameter(0));

  // The condition compares a loop element (presumably the iteration counter)
  // against a constant.
  EXPECT_THAT(while_loop->while_condition()->root_instruction(),
              op::Compare(loop_elem, op::Constant()));

  // Body: the dot's second operand is a dynamic slice of a padded loop
  // element, sliced on two dims. The result is added to an accumulator.
  auto incremented_counter = op::Add(loop_elem, op::Constant());
  auto padded_slice = AllOf(
      op::DynamicSlice(op::Pad(loop_elem, op::Constant()), op::Constant(),
                       op::Reshape(), op::Reshape(), op::Constant()),
      op::Shape("f32[4096,1,17,128]"));
  auto accumulated =
      AllOf(op::Add(loop_elem, op::Dot(loop_elem, padded_slice)),
            op::Shape("f32[8,1024,1,17,128]"));
  auto conditional_window =
      op::Conditional(op::Compare(incremented_counter, op::Constant()),
                      accumulated, accumulated);
  const HloInstruction* body_root =
      while_loop->while_body()->root_instruction();
  EXPECT_THAT(body_root, op::Tuple(loop_elem, loop_elem, conditional_window,
                                   loop_elem, incremented_counter));

  // Tuple element 2 is the conditional. Its true branch collective-permutes
  // the branch parameter; its false branch returns the parameter unchanged.
  const HloInstruction* cp_conditional = body_root->operand(2);
  EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
              op::CollectivePermute(op::Parameter(0)));
  EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
              op::Parameter(0));
}
4321
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContracting) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,39295,64,128] parameter(1)
  %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  ROOT %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard =
      op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
                       op::Constant(), op::Constant());
  auto expected_lhs =
      AllOf(op::Copy(lhs_shard), op::Shape("f32[32,12,64,128]"));
  // The odd RHS dim (39295) is padded before being sliced in half.
  auto padded_rhs = op::Pad(op::Parameter(1), op::Constant());
  auto rhs_shard = op::DynamicSlice(padded_rhs, op::Constant(), op::Reshape(),
                                    op::Constant(), op::Constant());
  auto expected_rhs =
      AllOf(op::Copy(rhs_shard), op::Shape("f32[32,19648,64,128]"));
  auto loop_init = op::Tuple(expected_lhs, expected_rhs, op::Broadcast(),
                             op::Broadcast(), op::Constant());
  // The padded loop output is sliced back to the unpadded size.
  auto padded_result = AllOf(op::GetTupleElement(op::While(loop_init)),
                             op::Shape("f32[32,12,39296]"));

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::Slice(padded_result),
                                op::Shape("f32[32,12,39295]")));
  const HloInstruction* while_loop = entry_root->operand(0)->operand(0);

  // Matches any element of the loop-carried tuple.
  auto loop_elem = op::GetTupleElement(op::Parameter(0));

  // The condition compares a loop element (presumably the iteration counter)
  // against a constant.
  EXPECT_THAT(while_loop->while_condition()->root_instruction(),
              op::Compare(loop_elem, op::Constant()));

  // Body: each iteration writes one dot result into the output buffer with a
  // dynamic-update-slice, while the window is routed through a conditional.
  auto incremented_counter = op::Add(loop_elem, op::Constant());
  auto conditional_window =
      op::Conditional(op::Compare(incremented_counter, op::Constant()),
                      loop_elem, loop_elem);
  auto window_dot = op::Dot(loop_elem, loop_elem);
  auto updated_output =
      op::DynamicUpdateSlice(loop_elem, window_dot, op::Constant(),
                             op::Constant(), op::Reshape());
  const HloInstruction* body_root =
      while_loop->while_body()->root_instruction();
  EXPECT_THAT(body_root, op::Tuple(loop_elem, conditional_window,
                                   updated_output, loop_elem,
                                   incremented_counter));

  // Tuple element 1 is the conditional. Its true branch collective-permutes
  // the branch parameter; its false branch returns the parameter unchanged.
  const HloInstruction* cp_conditional = body_root->operand(1);
  EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
              op::CollectivePermute(op::Parameter(0)));
  EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
              op::Parameter(0));
}
4385
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContracting) {
  constexpr char kHlo[] = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,24,63,128] parameter(0)
  %lhs.copy = f32[32,24,63,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,39296,63,128] parameter(1)
  %rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
  ROOT %dot = f32[32,24,39296] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(kHlo, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard =
      op::DynamicSlice(op::Parameter(0), op::Constant(), op::Reshape(),
                       op::Constant(), op::Constant());
  auto expected_lhs =
      AllOf(op::Copy(lhs_shard), op::Shape("f32[32,12,63,128]"));
  // The odd contracting dim (63) is padded, sliced to 32 per shard, and then
  // passed through a select against a broadcast constant (presumably masking
  // the padded region).
  auto padded_rhs = op::Pad(op::Parameter(1), op::Constant());
  auto rhs_shard = op::DynamicSlice(padded_rhs, op::Constant(),
                                    op::Constant(), op::Reshape(),
                                    op::Constant());
  auto expected_rhs =
      AllOf(op::Copy(rhs_shard), op::Shape("f32[32,39296,32,128]"));
  auto masked_rhs =
      op::Select(op::Compare(), expected_rhs, op::Broadcast(op::Constant()));
  auto loop_init = op::Tuple(expected_lhs, masked_rhs, op::Broadcast(),
                             op::Broadcast(), op::Constant());

  const HloInstruction* entry_root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, AllOf(op::GetTupleElement(op::While(loop_init)),
                                op::Shape("f32[32,12,39296]")));
  const HloInstruction* while_loop = entry_root->operand(0);

  // Matches any element of the loop-carried tuple.
  auto loop_elem = op::GetTupleElement(op::Parameter(0));

  // The condition compares a loop element (presumably the iteration counter)
  // against a constant.
  EXPECT_THAT(while_loop->while_condition()->root_instruction(),
              op::Compare(loop_elem, op::Constant()));

  // Body: the dot of a slice of a padded loop element is added into an
  // accumulator, while the window is routed through a conditional.
  auto incremented_counter = op::Add(loop_elem, op::Constant());
  auto conditional_window =
      op::Conditional(op::Compare(incremented_counter, op::Constant()),
                      loop_elem, loop_elem);
  auto padded_slice =
      op::DynamicSlice(op::Pad(loop_elem, op::Constant()), op::Constant(),
                       op::Constant(), op::Reshape(), op::Constant());
  auto window_dot = op::Dot(padded_slice, loop_elem);
  const HloInstruction* body_root =
      while_loop->while_body()->root_instruction();
  EXPECT_THAT(body_root,
              op::Tuple(loop_elem, conditional_window,
                        op::Add(loop_elem, window_dot), loop_elem,
                        incremented_counter));

  // Tuple element 1 is the conditional. Its true branch collective-permutes
  // the branch parameter; its false branch returns the parameter unchanged.
  const HloInstruction* cp_conditional = body_root->operand(1);
  EXPECT_THAT(cp_conditional->true_computation()->root_instruction(),
              op::CollectivePermute(op::Parameter(0)));
  EXPECT_THAT(cp_conditional->false_computation()->root_instruction(),
              op::Parameter(0));
}
4450
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce1) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,39295,64,128] parameter(1)
  %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,2,1]0,1}
  %constant = f32[] constant(0)
  %constant.1 = f32[] constant(2)
  %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
    sharding={devices=[1,2,1]0,1}
  %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
    sharding={devices=[1,2,1]0,1}
  ROOT %reduce = f32[32,24] reduce(%multiply, %constant), dimensions={2},
    to_apply=sum, sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
                                                            /*num_devices=*/2));
  VLOG(1) << module->ToString();
  // Loop code motion is involved, so the full loop structure is not
  // pattern-matched. We still check that the partitioned entry result has
  // the per-partition shape of the [1,2]-sharded reduce: 24 / 2 = 12.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Shape("f32[32,12]"));
}
4485
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedNonContractingReduce2) {
  absl::string_view hlo_string = R"(
HloModule module

sum {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT add = f32[] add(a, b)
}

ENTRY entry {
  %lhs = f32[32,24,64,128] parameter(0)
  %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[32,39295,64,128] parameter(1)
  %rhs.copy = f32[32,39295,64,128] copy(%rhs), sharding={devices=[1,2,1,1]0,1}
  %dot = f32[32,24,39295] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
    sharding={devices=[1,2,1]0,1}
  %constant = f32[] constant(0)
  %constant.1 = f32[] constant(2)
  %broadcast = f32[32,24,39295] broadcast(%constant.1), dimensions={},
    sharding={devices=[1,2,1]0,1}
  %multiply = f32[32,24,39295] multiply(%dot, %broadcast),
    sharding={devices=[1,2,1]0,1}
  ROOT %reduce = f32[32,39295] reduce(%multiply, %constant), dimensions={1},
    to_apply=sum, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string,
                                                            /*num_devices=*/2));
  VLOG(1) << module->ToString();
  // Loop code motion is involved, so the full loop structure is not
  // pattern-matched. We still check that the replicated reduce result keeps
  // its full, unpadded shape.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Shape("f32[32,39295]"));
}
4520
// Dot whose LHS is produced from a broadcast constant sharded on dimension 1,
// and whose RHS is sharded on contracting dimension 2. That dimension has size
// 63, so the two shards are uneven. Only checks that the computation
// partitions successfully.
TEST_F(SpmdPartitioningTest, EinsumRHSWindowedContractingFromBroadcast) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%rhs = f32[32,39296,63,128] parameter(0)
%rhs.copy = f32[32,39296,63,128] copy(%rhs), sharding={devices=[1,1,2,1]0,1}
%constant.1 = f32[] constant(2)
%broadcast = f32[32,24,63,128] broadcast(%constant.1), dimensions={},
sharding={devices=[1,2,1,1]0,1}
%add = f32[32,24,63,128] add(%broadcast, %broadcast),
sharding={devices=[1,2,1,1]0,1}
ROOT %dot = f32[32,24,39296] dot(%add, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
sharding={devices=[1,2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();
  // Involves loop code motion, skips pattern matching.
}
4544
// A replicated rng is expected to partition into an all-reduce of a select.
// The select keeps the rng output where a partition-id comparison holds and
// uses a broadcast constant elsewhere. Presumably this lets every partition
// see the values produced by one partition.
TEST_F(SpmdPartitioningTest, ReplicatedRng) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%lhs = s32[] parameter(0)
%lhs.copy = s32[] copy(%lhs), sharding={replicated}
%rhs = s32[] parameter(1)
%rhs.copy = s32[] copy(%rhs), sharding={replicated}
ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy),
distribution=rng_uniform, sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  // Only the structure around the rng is pinned down. The rng's operands are
  // intentionally left unmatched (`op::Rng()`), so no operand matchers are
  // built here.
  EXPECT_THAT(
      root,
      AllOf(op::AllReduce(op::Select(
                op::Broadcast(op::Compare(op::PartitionId(), op::Constant())),
                op::Rng(), op::Broadcast(op::Constant()))),
            op::Shape("s32[4]")));
}
4572
// With manual sharding everywhere, the partitioned root is expected to be an
// rng that reads the two parameters directly and keeps the full s32[4] shape.
TEST_F(SpmdPartitioningTest, ManualRng) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = s32[] parameter(0), sharding={manual}
%rhs = s32[] parameter(1), sharding={manual}
ROOT %rng = s32[4]{0} rng(%lhs, %rhs),
distribution=rng_uniform, sharding={manual}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  const auto expected = AllOf(op::Shape("s32[4]"),
                              op::Rng(op::Parameter(0), op::Parameter(1)));
  EXPECT_THAT(root, expected);
}
4592
// An rng tiled across two devices yields a local s32[2] rng.
// - The first operand (replicated) is consumed as its copy of parameter 0.
// - The second operand lives only on device 1. It is made available
//   everywhere via an all-reduce of a select between the value and a
//   broadcast constant.
TEST_F(SpmdPartitioningTest, PartitionedRng) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = s32[] parameter(0)
%lhs.copy = s32[] copy(%lhs), sharding={replicated}
%rhs = s32[] parameter(1)
%rhs.copy = s32[] copy(%rhs), sharding={maximal device=1}
ROOT %rng = s32[4]{0} rng(%lhs.copy, %rhs.copy),
distribution=rng_uniform, sharding={devices=[2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  auto first_operand =
      AllOf(op::Shape("s32[]"), op::Copy(op::Parameter(0)));
  auto second_operand =
      AllOf(op::Shape("s32[]"), op::Copy(op::Copy(op::Parameter(1))));
  auto replicated_second_operand = op::AllReduce(
      op::Select(op::Broadcast(op::Compare()), second_operand,
                 op::Broadcast(op::Constant())));

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("s32[2]"),
                          op::Rng(first_operand, replicated_second_operand)));
}
4618
// An rng partially replicated over 8 devices (2 tiles x 4 replicas) produces
// an s32[4] per tile. The rng output is masked by a comparison on a
// table-derived (u32) partition index and then all-reduced.
TEST_F(SpmdPartitioningTest, PartialReplicatedRng) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = s32[] parameter(0), sharding={replicated}
%rhs = s32[] parameter(1), sharding={replicated}
ROOT %rng = s32[8]{0} rng(%lhs, %rhs),
distribution=rng_uniform,
sharding={devices=[2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned_module->ToString();

  auto low = AllOf(op::Shape("s32[]"), op::Parameter(0));
  auto high = AllOf(op::Shape("s32[]"), op::Parameter(1));
  auto partition_index =
      AllOf(op::Shape("u32[]"),
            op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())));
  auto mask = op::Broadcast(op::Compare(partition_index, op::Constant()));
  auto masked_rng =
      op::Select(mask, op::Rng(low, high), op::Broadcast(op::Constant()));

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("s32[4]"), op::AllReduce(masked_rng)));
}
4647
// The input is sharded on dimension 0, and the dynamic slice only varies
// along dimension 1. Each partition is therefore expected to slice its local
// s32[64,64] shard directly, reusing the original dimension-1 index.
TEST_F(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%input = s32[128,64] parameter(0)
%input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1}
%index = s32[] parameter(1)
%constant = s32[] constant(0)
ROOT %dynamic-slice = s32[128,2] dynamic-slice(%input.copy, %constant, %index),
dynamic_slice_sizes={128,2}, sharding={devices=[2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  auto local_input = AllOf(
      op::Shape("s32[64,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                op::Constant())));

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("s32[64,2]"),
                          op::DynamicSlice(local_input, op::Constant(),
                                           op::Parameter(1))));
}
4673
// Input and update are both sharded on dimension 0, and the update position
// only varies along dimension 1. Each partition is expected to apply the
// dynamic-update-slice locally on its shards.
TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongNonPartitionedDimension) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%input = s32[128,64] parameter(0)
%input.copy = s32[128,64] copy(%input), sharding={devices=[2,1]0,1}
%index = s32[] parameter(1)
%constant = s32[] constant(0)
%update = s32[128,2] parameter(2)
%update.copy = s32[128,2] copy(%update), sharding={devices=[2,1]0,1}
ROOT %dynamic-update-slice = s32[128,64]
dynamic-update-slice(%input.copy, %update.copy, %constant, %index),
sharding={devices=[2,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  auto local_input = AllOf(
      op::Shape("s32[64,64]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                op::Constant())));
  auto local_update = AllOf(
      op::Shape("s32[64,2]"),
      op::Copy(op::DynamicSlice(op::Parameter(2), op::Reshape(),
                                op::Constant())));

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("s32[64,64]"),
                    op::DynamicUpdateSlice(local_input, local_update,
                                           op::Constant(), op::Parameter(1))));
}
4705
// The update position varies along dimension 1, which is the sharded
// dimension. The expected lowering has three steps:
// - The update is first made whole: an all-reduce of its shard written into a
//   broadcast, giving s32[128,2].
// - Each s32[128,32] input shard applies the dynamic-update-slice with a
//   select-computed dimension-1 offset.
// - A select chooses between the updated and the untouched shard.
TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongPartitionedDimension) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%input = s32[128,64] parameter(0)
%input.copy = s32[128,64] copy(%input), sharding={devices=[1,2]0,1}
%index = s32[] parameter(1)
%constant = s32[] constant(60)
%update = s32[128,2] parameter(2)
%update.copy = s32[128,2] copy(%update), sharding={devices=[1,2]0,1}
ROOT %dynamic-update-slice = s32[128,64]
dynamic-update-slice(%input.copy, %update.copy, %index, %constant),
sharding={devices=[1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  auto local_input = AllOf(
      op::Shape("s32[128,32]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
                                op::Reshape())));
  auto update_shard = op::Copy(
      op::DynamicSlice(op::Parameter(2), op::Constant(), op::Reshape()));
  auto full_update = AllOf(
      op::Shape("s32[128,2]"),
      op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), update_shard,
                                           op::Constant(), op::Reshape())));
  auto updated_shard = op::DynamicUpdateSlice(local_input, full_update,
                                              op::Parameter(1), op::Select());

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("s32[128,32]"),
                    op::Select(op::Broadcast(), updated_shard, local_input)));
}
4744
// Eight-way sharding on dimension 0, with a size-1 update positioned
// dynamically along that dimension. The expected lowering:
// - The update is made whole via an all-reduce of a select.
// - Each shard applies the dynamic-update-slice with a select-computed
//   dimension-0 offset.
// - A select picks between the updated and original shard.
TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongPartitionedDimension2) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%input = s32[8,790,2] parameter(0),
sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
%index = s32[] parameter(1)
%constant = s32[] constant(0)
%update = s32[1,790,2] parameter(2),
sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
ROOT %dynamic-update-slice = s32[8,790,2]
dynamic-update-slice(%input, %update, %index, %constant, %constant),
sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned_module->ToString();

  auto local_input = AllOf(op::Shape("s32[1,790,2]"), op::Parameter(0));
  auto full_update = AllOf(
      op::Shape("s32[1,790,2]"),
      op::AllReduce(op::Select(op::Broadcast(), op::Parameter(2),
                               op::Broadcast())));
  auto updated_shard = op::DynamicUpdateSlice(
      local_input, full_update, op::Select(), op::Constant(), op::Constant());

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("s32[1,790,2]"),
                    op::Select(op::Broadcast(), updated_shard, local_input)));
}
4778
// 2x2 sharding, so both the non-slice dimension 0 and the slice dimension 1
// are partitioned. The update is made whole along dimension 1 only (s32[64,2]
// per partition) and applied to the s32[64,32] shard. A select then picks
// between the updated and original shard.
TEST_F(SpmdPartitioningTest, DynamicUpdateSlicePartitionSliceAndNonSliceDims) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%input = s32[128,64] parameter(0)
%input.copy = s32[128,64] copy(%input), sharding={devices=[2,2]0,1,2,3}
%constant.0 = s32[] constant(0)
%constant.1 = s32[] constant(60)
%update = s32[128,2] parameter(1)
%update.copy = s32[128,2] copy(%update), sharding={devices=[2,2]0,1,2,3}
ROOT %dynamic-update-slice = s32[128,64]
dynamic-update-slice(%input.copy, %update.copy, %constant.0, %constant.1),
sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned_module->ToString();

  auto local_input = AllOf(
      op::Shape("s32[64,32]"),
      op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                op::Reshape())));
  auto update_shard = op::Copy(
      op::DynamicSlice(op::Parameter(1), op::Reshape(), op::Reshape()));
  auto gathered_update = AllOf(
      op::Shape("s32[64,2]"),
      op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), update_shard,
                                           op::Constant(), op::Reshape())));
  auto updated_shard = op::DynamicUpdateSlice(local_input, gathered_update,
                                              op::Constant(), op::Select());

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("s32[64,32]"),
                    op::Select(op::Broadcast(), updated_shard, local_input)));
}
4817
// The operand is sharded only on the offset (slice) dimension, so the gather
// is expected to run locally on the parameters and produce f32[3,5] per
// partition.
TEST_F(SpmdPartitioningTest, PassthroughGather) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
%indices = s32[3] parameter(1), sharding={replicated}
ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
slice_sizes={1,9}, sharding={devices=[1,2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[3,5]"),
                          op::Gather(op::Parameter(0), op::Parameter(1))));
}
4836
// Same as PassthroughGather, but the offset-dimension sharding is partially
// replicated across 4 devices. The gather still runs locally and yields
// f32[3,5].
TEST_F(SpmdPartitioningTest, PassthroughGather_PartialReplicate) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%input = f32[2,9] parameter(0),
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
%indices = s32[3] parameter(1), sharding={replicated}
ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1},
collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1,
slice_sizes={1,9}, sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned_module->ToString();

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[3,5]"),
                          op::Gather(op::Parameter(0), op::Parameter(1))));
}
4856
// The operand is replicated and the indices are sharded on their batch
// dimensions, so the gather is expected to run locally and produce f32[8,2,2]
// per partition.
TEST_F(SpmdPartitioningTest, IndexPassthroughGather) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%input = f32[2,9,8] parameter(0), sharding={replicated}
%indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0},
collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1,
slice_sizes={1,1,8}, sharding={devices=[1,2,2]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned_module->ToString();

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[8,2,2]"),
                          op::Gather(op::Parameter(0), op::Parameter(1))));
}
4875
// Same as IndexPassthroughGather, but the index sharding is partially
// replicated across 8 devices. The gather still runs locally and yields
// f32[8,2,2].
TEST_F(SpmdPartitioningTest, IndexPassthroughGather_PartialReplicate) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%input = f32[2,9,8] parameter(0), sharding={replicated}
%indices = s32[4,2,4] parameter(1),
sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0},
collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1,
slice_sizes={1,1,8},
sharding={devices=[1,2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned_module->ToString();

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[8,2,2]"),
                          op::Gather(op::Parameter(0), op::Parameter(1))));
}
4896
// The operand is sharded on dimension 0, whose slice size is 1. The expected
// lowering:
// - Indices are clamped into the partition's [offset, offset + c] window and
//   rebased by the offset.
// - The gather runs locally.
// - Rows whose index fell outside the window are replaced by a broadcast
//   constant.
// - The results are all-reduced into the replicated f32[2,3,9].
TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
%indices = s32[2,3] parameter(1), sharding={replicated}
ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2},
collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
slice_sizes={1,9}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  auto shard_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  auto shard_min = AllOf(op::Shape("s32[2,3]"), op::Broadcast(shard_offset));
  auto shard_max = AllOf(op::Shape("s32[2,3]"),
                         op::Broadcast(op::Add(shard_offset, op::Constant())));
  auto clamped_indices = op::Clamp(shard_min, op::Parameter(1), shard_max);
  auto local_gather =
      op::Gather(op::Parameter(0), op::Subtract(clamped_indices, shard_min));
  auto out_of_range = op::Or(op::Lt(op::Parameter(1), shard_min),
                             op::Gt(op::Parameter(1), shard_max));
  auto masked_gather = op::Select(op::Broadcast(out_of_range),
                                  op::Broadcast(op::Constant()), local_gather);

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[2,3,9]"), op::AllReduce(masked_gather)));
}
4925
// Same as GatherPartitionedOnTrivialSliceDims, with the operand's
// dimension-0 sharding partially replicated across 4 devices. The expected
// lowering is identical:
// - clamp and rebase the indices,
// - gather locally,
// - mask out-of-window rows,
// - all-reduce.
TEST_F(SpmdPartitioningTest,
       GatherPartitionedOnTrivialSliceDims_PartialReplicate) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%input = f32[17,9] parameter(0),
sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
%indices = s32[2,3] parameter(1), sharding={replicated}
ROOT %gather = f32[2,3,9] gather(%input, %indices), offset_dims={2},
collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2,
slice_sizes={1,9}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned_module->ToString();

  auto shard_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  auto shard_min = AllOf(op::Shape("s32[2,3]"), op::Broadcast(shard_offset));
  auto shard_max = AllOf(op::Shape("s32[2,3]"),
                         op::Broadcast(op::Add(shard_offset, op::Constant())));
  auto clamped_indices = op::Clamp(shard_min, op::Parameter(1), shard_max);
  auto local_gather =
      op::Gather(op::Parameter(0), op::Subtract(clamped_indices, shard_min));
  auto out_of_range = op::Or(op::Lt(op::Parameter(1), shard_min),
                             op::Gt(op::Parameter(1), shard_max));
  auto masked_gather = op::Select(op::Broadcast(out_of_range),
                                  op::Broadcast(op::Constant()), local_gather);

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[2,3,9]"), op::AllReduce(masked_gather)));
}
4956
// Operand and updates are sharded on the update-window dimension, so the
// scatter is expected to run locally on the parameters and produce f32[2,5]
// per partition.
TEST_F(SpmdPartitioningTest, PassthroughScatter) {
  absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
%input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1}
%indices = s32[3] parameter(1), sharding={replicated}
%updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1}
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, sharding={devices=[1,2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  auto local_scatter =
      op::Scatter(op::Parameter(0), op::Parameter(1), op::Parameter(2));
  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,5]"), local_scatter));
}
4986
// Same as PassthroughScatter, but the window-dimension sharding is partially
// replicated across 4 devices. The scatter still runs locally and yields
// f32[2,5].
TEST_F(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) {
  absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
%input = f32[2,9] parameter(0),
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
%indices = s32[3] parameter(1), sharding={replicated}
%updates = f32[3,9] parameter(2),
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1,
sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned_module->ToString();

  auto local_scatter =
      op::Scatter(op::Parameter(0), op::Parameter(1), op::Parameter(2));
  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,5]"), local_scatter));
}
5019
// Indices and updates are sharded on their scatter (batch) dimensions, while
// the operand and output are replicated. The expected lowering:
// - The operand is replaced by a broadcast constant on partitions selected by
//   a partition-id-based predicate.
// - The scatter runs locally.
// - The result goes through two nested all-reduces.
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter) {
  absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
%input = f32[2,9,8] parameter(0), sharding={replicated}
%indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
%updates = f32[4,4,8] parameter(2), sharding={devices=[2,2,1]0,1,2,3}
ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={2},
inserted_window_dims={0,1},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned_module->ToString();

  auto masked_operand =
      op::Select(op::Broadcast(op::Convert(op::PartitionId())),
                 op::Broadcast(op::Constant()), op::Parameter(0));
  auto local_scatter =
      op::Scatter(masked_operand, op::Parameter(1), op::Parameter(2));

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,9,8]"),
                          op::AllReduce(op::AllReduce(local_scatter))));
}
5053
// Same as IndexPassthroughScatter, but indices and updates are partially
// replicated across 8 devices. The operand mask is derived from a reshape
// rather than directly from the partition id.
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) {
  absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
%input = f32[2,9,8] parameter(0), sharding={replicated}
%indices = s32[4,2,4] parameter(1),
sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
%updates = f32[4,4,8] parameter(2),
sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={2},
inserted_window_dims={0,1},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned_module->ToString();

  auto masked_operand =
      op::Select(op::Broadcast(op::Convert(op::Reshape())),
                 op::Broadcast(op::Constant()), op::Parameter(0));
  auto local_scatter =
      op::Scatter(masked_operand, op::Parameter(1), op::Parameter(2));

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,9,8]"),
                          op::AllReduce(op::AllReduce(local_scatter))));
}
5089
// Same shape of test as IndexPassthroughScatter, but the scatter combiner is
// `minimum` instead of `add`. The expected structure is the same:
// - masked operand,
// - local scatter,
// - two nested all-reduces.
TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_Min) {
  absl::string_view hlo_text = R"(
HloModule module

min (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT min = f32[] minimum(lhs, rhs)
}

ENTRY entry {
%input = f32[2,9,8] parameter(0), sharding={replicated}
%indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3}
%updates = f32[4,4,8] parameter(2), sharding={devices=[2,2,1]0,1,2,3}
ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates),
to_apply=min,
update_window_dims={2},
inserted_window_dims={0,1},
scatter_dims_to_operand_dims={0,1},
index_vector_dim=1, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned_module->ToString();

  auto masked_operand =
      op::Select(op::Broadcast(op::Convert(op::PartitionId())),
                 op::Broadcast(op::Constant()), op::Parameter(0));
  auto local_scatter =
      op::Scatter(masked_operand, op::Parameter(1), op::Parameter(2));

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,9,8]"),
                          op::AllReduce(op::AllReduce(local_scatter))));
}
5123
// The operand is sharded on dimension 0 (an inserted window dimension). Each
// partition is expected to scatter into its f32[9,9] shard using indices
// rebased by the partition's table-derived offset.
TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) {
  absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
%input = f32[17,9] parameter(0), sharding={devices=[2,1]0,1}
%indices = s32[2,3] parameter(1), sharding={replicated}
%updates = f32[2,3,9] parameter(2), sharding={replicated}
ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={2},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=2, sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  auto shard_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  auto broadcast_offset =
      AllOf(op::Shape("s32[2,3]"), op::Broadcast(shard_offset));
  auto rebased_indices = op::Subtract(op::Parameter(1), broadcast_offset);

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[9,9]"),
                    op::Scatter(op::Parameter(0), rebased_indices,
                                op::Parameter(2))));
}
5157
// Same as ScatterPartitionedOnTrivialSliceDims, with the operand and output
// dimension-0 sharding partially replicated across 4 devices. Each partition
// still scatters into its f32[9,9] shard using offset-rebased indices.
TEST_F(SpmdPartitioningTest,
       ScatterPartitionedOnTrivialSliceDims_PartialReplicate) {
  absl::string_view hlo_text = R"(
HloModule module

add (lhs: f32[], rhs: f32[]) -> f32[] {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT sum = f32[] add(lhs, rhs)
}

ENTRY entry {
%input = f32[17,9] parameter(0),
sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
%indices = s32[2,3] parameter(1), sharding={replicated}
%updates = f32[2,3,9] parameter(2), sharding={replicated}
ROOT %scatter = f32[17,9] scatter(%input, %indices, %updates),
to_apply=add,
update_window_dims={2},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=2,
sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned_module->ToString();

  auto shard_offset =
      op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId()));
  auto broadcast_offset =
      AllOf(op::Shape("s32[2,3]"), op::Broadcast(shard_offset));
  auto rebased_indices = op::Subtract(op::Parameter(1), broadcast_offset);

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Shape("f32[9,9]"),
                    op::Scatter(op::Parameter(0), rebased_indices,
                                op::Parameter(2))));
}
5194
// Reversing the unsharded dimension 1 of a constant tiled on dimension 0.
// Each partition takes its (padded) dynamic slice of the constant and
// reverses it locally, giving f32[2,3].
TEST_F(SpmdPartitioningTest, TiledReversePassthrough) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
constant = f32[3,3]{1,0} constant({{1,1,1},{1,1,1},{1,1,1}}),
sharding={devices=[2,1]0,1}
ROOT reverse = f32[3,3]{1,0} reverse(constant), dimensions={1},
sharding={devices=[2,1]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  auto local_constant =
      op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), op::Reshape(),
                       op::Constant());
  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]{1,0}"),
                          op::Reverse(local_constant)));
}
5214
// The output sharding lists the devices in reverse order ({1,0}) relative to
// the input. The reverse is therefore expected to apply directly to the local
// parameter shard, with no communication.
TEST_F(SpmdPartitioningTest, TiledReversePassthroughViaReversedSharding) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
param = f32[4] parameter(0), sharding={devices=[2]0,1}
ROOT reverse = f32[4] reverse(param), dimensions={0},
sharding={devices=[2]1,0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  const auto expected =
      AllOf(op::Shape("f32[2]"), op::Reverse(op::Parameter(0)));
  EXPECT_THAT(root, expected);
}
5230
// Input and output share the same device order, so reversing the sharded
// dimension requires exchanging shards: a collective-permute followed by a
// local reverse.
TEST_F(SpmdPartitioningTest, TiledReverseSwapShards) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
param = f32[4] parameter(0), sharding={devices=[2]0,1}
ROOT reverse = f32[4] reverse(param), dimensions={0},
sharding={devices=[2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  auto swapped_shard = op::CollectivePermute(op::Parameter(0));
  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2]"), op::Reverse(swapped_shard)));
}
5248
// An odd-sized f32[3] reversed into the reversed device order. Shards do not
// line up, so a one-element halo is fetched with a collective-permute,
// concatenated with a slice of the local shard, and then reversed.
TEST_F(SpmdPartitioningTest, TiledReverseHaloExchange) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
param = f32[3] parameter(0), sharding={devices=[2]0,1}
ROOT reverse = f32[3] reverse(param), dimensions={0},
sharding={devices=[2]1,0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  auto halo = AllOf(op::Shape("f32[1]"),
                    op::CollectivePermute(op::Slice(op::Parameter(0))));
  auto with_halo = op::Concatenate(halo, op::Slice(op::Parameter(0)));

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2]"), op::Reverse(with_halo)));
}
5269
// Mixes automatic and manual sharding through the SPMDFullToShardShape /
// SPMDShardToFullShape custom calls. After partitioning:
// - Those custom calls are expected to become copies.
// - The tiled multiply operates on f32[4,2] shards of the tuple parameter's
//   elements.
TEST_F(SpmdPartitioningTest, MixWithManualPartitioning) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
param = (f32[8,2], f32[4,2]) parameter(0), sharding={{devices=[2,1]0,1},{manual}}
param0 = f32[8,2] get-tuple-element(param), index=0, sharding={devices=[2,1]0,1}
param1 = f32[4,2] get-tuple-element(param), index=1, sharding={manual}
to_shard = f32[4,2] custom-call(param0), custom_call_target="SPMDFullToShardShape", sharding={manual}
add = f32[4,2] add(to_shard, param1), sharding={manual}
to_full = f32[8,2] custom-call(add), custom_call_target="SPMDShardToFullShape", sharding={devices=[2,1]0,1}
mul = f32[8,2] multiply(to_full, param0), sharding={devices=[2,1]0,1}
to_shard2 = f32[4,2] custom-call(mul), custom_call_target="SPMDFullToShardShape", sharding={manual}
ROOT tuple = (f32[4,2]) tuple(to_shard2), sharding={{manual}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  auto tiled_element = op::GetTupleElement(op::Parameter(0));
  auto manual_element = op::GetTupleElement(op::Parameter(0));
  auto manual_sum = op::Add(op::Copy(tiled_element), manual_element);
  auto product = AllOf(op::Shape("f32[4,2]"),
                       op::Multiply(op::Copy(manual_sum), tiled_element));

  const HloInstruction* root =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(product)));
}
5296
// Resharding [2,2,1,2] -> [1,2,2,2] over 8 devices is expected to lower to a
// reshape -> all-to-all -> transpose -> reshape sequence, with the all-to-all
// split into 4 replica groups.
TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8,8,8] parameter(0),
    sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7}
  ROOT %copy = f32[8,8,8,8] copy(%param0),
    sharding={devices=[1,2,2,2]0,1,4,5,2,3,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();

  auto root = module->entry_computation()->root_instruction();
  auto reshape =
      AllOf(op::Shape("f32[4,4,2,4,4]"), op::Reshape(op::Parameter(0)));
  auto all_to_all = AllOf(op::Shape("f32[4,4,2,4,4]"), op::AllToAll(reshape));
  auto xpose = AllOf(op::Shape("f32[2,4,4,4,4]"), op::Transpose(all_to_all));
  // Fatal on purpose: the operand walk below assumes exactly this
  // copy -> reshape -> transpose -> all-to-all chain. With a non-fatal
  // EXPECT, a mismatch would fall through to dereferencing operands that may
  // not exist, or to calling replica_groups() on an instruction that is not a
  // collective.
  ASSERT_THAT(root,
              op::Copy(AllOf(op::Reshape(xpose), op::Shape("f32[8,4,4,4]"))));
  const auto* all_to_all_instr = root->operand(0)->operand(0)->operand(0);
  EXPECT_EQ(all_to_all_instr->replica_groups().size(), 4);
}
5322
// Resharding a 2-D array from [2,4] to [4,2] tiling over 8 devices is expected
// to lower to one subgroup all-to-all round followed by a collective-permute.
TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard2) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0),
    sharding={devices=[2,4]0,1,2,3,4,5,6,7}
  ROOT %copy = f32[8,8] copy(%param0),
    sharding={devices=[4,2]0,1,4,5,2,3,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  // The local parameter shard is split to f32[2,2,2] before the exchange and
  // merged back to f32[2,4] after the transpose.
  auto split_param =
      AllOf(op::Reshape(op::Parameter(0)), op::Shape("f32[2,2,2]"));
  auto merged = AllOf(op::Reshape(op::Transpose(op::AllToAll(split_param))),
                      op::Shape("f32[2,4]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Copy(op::CollectivePermute(merged)));
}
5345
// Resharding [2,4,1] -> [1,2,4] over 8 devices is expected to take two
// subgroup all-to-all rounds followed by a collective-permute.
TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard3) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8,8] parameter(0),
    sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
  ROOT %copy = f32[8,8,8] copy(%param0),
    sharding={devices=[1,2,4]0,1,4,5,2,3,6,7}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  // One exchange round: reshape `operand` to `split_shape`, all-to-all,
  // transpose, then reshape the result to `merged_shape`.
  auto exchange_round = [](auto operand, const char* split_shape,
                           const char* merged_shape) {
    auto split = AllOf(op::Shape(split_shape), op::Reshape(operand));
    return AllOf(op::Shape(merged_shape),
                 op::Reshape(op::Transpose(op::AllToAll(split))));
  };
  auto first_round =
      exchange_round(op::Parameter(0), "f32[4,2,4,2]", "f32[4,8,2]");
  auto second_round =
      exchange_round(first_round, "f32[4,2,4,2]", "f32[8,4,2]");

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Copy(op::CollectivePermute(second_round)));
}
5372
// Both operands are 2x2 tiled. Each one is expected to be re-assembled along
// its contracting dimension (dim 1) before a purely local dot.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting0) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[48,12] parameter(0), sharding={devices=[2,2]0,1,2,3}
  %rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,2,1,3}
  ROOT %dot = f32[48,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // Matches `shard` written by a dynamic-update-slice and then all-reduced
  // into an instruction of shape `gathered_shape`.
  auto gathered = [](auto shard, const char* gathered_shape) {
    return AllOf(op::AllReduce(op::DynamicUpdateSlice(_, shard, _, _)),
                 op::Shape(gathered_shape));
  };
  auto lhs_full_k =
      gathered(AllOf(op::Parameter(0), op::Shape("f32[24,6]")), "f32[24,12]");
  auto rhs_full_k =
      gathered(AllOf(op::Parameter(1), op::Shape("f32[16,6]")), "f32[16,12]");

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[24,16]"),
                          op::Dot(lhs_full_k, rhs_full_k)));
}
5403
// Same 2x2 tiling on both operands and the result. Only the RHS is gathered
// along its non-contracting dimension; the partial dot is then all-reduced and
// dynamic-sliced down to the output shard.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting1) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[48,100] parameter(0), sharding={devices=[2,2]0,1,2,3}
  %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3}
  ROOT %dot = f32[48,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[24,50]"));
  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[16,50]"));
  auto rhs_gathered =
      AllOf(op::AllReduce(op::DynamicUpdateSlice(_, rhs_shard, _, _)),
            op::Shape("f32[32,50]"));
  auto partial_dot =
      AllOf(op::Shape("f32[24,32]"), op::Dot(lhs_shard, rhs_gathered));
  auto expected = AllOf(op::DynamicSlice(op::AllReduce(partial_dot), _, _),
                        op::Shape("f32[24,16]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, expected);
}
5434
// The LHS is replicated and only needs a local dynamic-slice. The RHS is
// collective-permuted and re-assembled along its contracting dimension.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting2) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[48,100] parameter(0), sharding={replicated}
  %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3}
  ROOT %dot = f32[48,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[16,50]"));
  auto rhs_full_k = AllOf(op::AllReduce(op::DynamicUpdateSlice(
                              _, op::CollectivePermute(rhs_shard), _, _)),
                          op::Shape("f32[16,100]"));
  auto lhs_full = AllOf(op::Parameter(0), op::Shape("f32[48,100]"));
  auto lhs_local =
      AllOf(op::DynamicSlice(lhs_full, _, _), op::Shape("f32[24,100]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_local, rhs_full_k),
                          op::Shape("f32[24,16]")));
}
5462
// The contracting dimension (23) does not split evenly in two, so each 12-row
// shard is expected to pass through a select against a broadcast constant
// before the dot (presumably masking the padding row).
TEST_F(SpmdPartitioningTest, Dot2DPartitionedNoncontractingAndContracting3) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[23,24] parameter(0), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[23,32] parameter(1), sharding={devices=[2,2]0,1,2,3}
  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
    lhs_contracting_dims={0}, rhs_contracting_dims={0},
    sharding={devices=[2,2]1,0,3,2}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto masked = [](auto shard) {
    return op::Select(_, shard, op::Broadcast(op::Constant()));
  };
  auto lhs_masked = masked(AllOf(op::Parameter(0), op::Shape("f32[12,24]")));
  auto rhs_masked = masked(AllOf(op::Parameter(1), op::Shape("f32[12,16]")));
  auto partial_sum = AllOf(
      op::AllReduce(op::Dot(lhs_masked, op::CollectivePermute(rhs_masked))),
      op::Shape("f32[24,16]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::DynamicSlice(partial_sum, _, _),
                          op::Shape("f32[12,16]")));
}
5493
// Batch and non-contracting dimensions are tiled identically on both operands.
// Only the RHS non-contracting dimension needs to be gathered.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
  %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
  ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[2,12,100]"));
  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[2,16,100]"));
  auto rhs_gathered =
      AllOf(op::AllReduce(op::DynamicUpdateSlice(_, rhs_shard, _, _, _)),
            op::Shape("f32[2,32,100]"));
  auto expected_dot =
      AllOf(op::Dot(lhs_shard, rhs_gathered), op::Shape("f32[2,12,32]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, expected_dot);
}
5520
// The LHS is tiled on batch+contracting and the RHS on
// non-contracting+contracting. The RHS is re-tiled with an all-to-all so that
// its batch dimension is split, and the contracting partial sums are then
// all-reduced.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndContracting) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
  %rhs = f32[4,32,100] parameter(1), sharding={devices=[1,2,2]0,1,2,3}
  ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[2,24,50]"));
  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[4,16,50]"));
  auto rhs_resharded = AllOf(
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs_shard)))),
      op::Shape("f32[2,32,50]"));
  auto partial_dot = AllOf(op::AllReduce(op::Dot(lhs_shard, rhs_resharded)),
                           op::Shape("f32[2,24,32]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::DynamicSlice(partial_dot, _, _, _),
                          op::Shape("f32[2,12,32]")));
}
5550
// The RHS is replicated and only needs a local slice. The LHS is re-tiled via
// all-to-all from its contracting dimension onto its non-contracting dimension.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndContracting2) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
  %rhs = f32[4,32,100] parameter(1), sharding={replicated}
  ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[2,24,50]"));
  auto lhs_resharded = AllOf(
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs_shard)))),
      op::Shape("f32[2,12,100]"));
  auto rhs_full = AllOf(op::Parameter(1), op::Shape("f32[4,32,100]"));
  auto rhs_local =
      AllOf(op::DynamicSlice(rhs_full, _, _, _), op::Shape("f32[2,32,100]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_resharded, rhs_local),
                          op::Shape("f32[2,12,32]")));
}
5579
// The LHS contracting dimension is gathered so that it lines up with the RHS,
// which is tiled on batch and non-contracting dimensions.
TEST_F(SpmdPartitioningTest,
       Dot2DPartitionedBatchNonContractingAndContracting) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
  %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3}
  ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[2,24,50]"));
  auto lhs_full_k =
      AllOf(op::AllReduce(op::DynamicUpdateSlice(_, lhs_shard, _, _, _)),
            op::Shape("f32[2,24,100]"));
  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[2,16,100]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_full_k, rhs_shard),
                          op::Shape("f32[2,24,16]")));
}
5607
// The dot is computed with batch dim 0 tiled. The result is then moved to the
// requested batch-dim-1 tiling via reshape -> all-to-all -> transpose ->
// reshape.
TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndReshard) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,8,24,100] parameter(0), sharding={devices=[2,1,2,1]0,1,2,3}
  %rhs = f32[4,8,32,100] parameter(1), sharding={devices=[2,1,2,1]0,1,2,3}
  ROOT %dot = f32[4,8,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0,1}, rhs_batch_dims={0,1},
    lhs_contracting_dims={3}, rhs_contracting_dims={3},
    sharding={devices=[1,2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[2,8,12,100]"));
  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[2,8,16,100]"));
  auto rhs_gathered =
      AllOf(op::AllReduce(op::DynamicUpdateSlice(_, rhs_shard, _, _, _, _)),
            op::Shape("f32[2,8,32,100]"));
  auto local_dot =
      AllOf(op::Dot(lhs_shard, rhs_gathered), op::Shape("f32[2,8,12,32]"));

  // All three intermediate steps of the exchange share this shape.
  const char* const kSplitShape = "f32[2,2,4,12,32]";
  auto split = AllOf(op::Reshape(local_dot), op::Shape(kSplitShape));
  auto exchanged = AllOf(op::AllToAll(split), op::Shape(kSplitShape));
  auto transposed = AllOf(op::Transpose(exchanged), op::Shape(kSplitShape));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Reshape(transposed), op::Shape("f32[4,4,12,32]")));
}
5638
// Operands and result share the same partially replicated batch tiling, so the
// partitioned root is expected to be a plain local dot of the parameters.
TEST_F(SpmdPartitioningTest, SimpleDotPartial) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[2,24,100] parameter(0),
    sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[2,32,100] parameter(1),
    sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %dot = f32[2,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::Dot(AllOf(op::Parameter(0), op::Shape("f32[1,24,100]")),
                            AllOf(op::Parameter(1), op::Shape("f32[1,32,100]"))),
                    op::Shape("f32[1,24,32]")));
}
5664
// Both operands are split only along the contracting dimension, so the local
// dot produces partial sums that are all-reduced into the replicated result.
TEST_F(SpmdPartitioningTest, DotPartialContracting) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,100] parameter(0),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[32,100] parameter(1),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[24,50]"));
  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[32,50]"));
  auto local_dot =
      AllOf(op::Dot(lhs_shard, rhs_shard), op::Shape("f32[24,32]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::AllReduce(local_dot));
}
5690
// Like DotPartialContracting, but the output is tiled on dim 0. The LHS shard is
// therefore dynamic-sliced on its rows before the local dot and all-reduce.
TEST_F(SpmdPartitioningTest, DotPartialContracting2) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,100] parameter(0),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[32,100] parameter(1),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[24,50]"));
  auto lhs_rows =
      AllOf(op::DynamicSlice(lhs_shard, _, _), op::Shape("f32[12,50]"));
  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[32,50]"));
  auto local_dot = AllOf(op::Dot(lhs_rows, rhs_shard), op::Shape("f32[12,32]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::AllReduce(local_dot));
}
5719
// 8-device variant: the RHS shard is dynamic-sliced on its non-contracting
// dimension. The all-reduced partial sums are then collective-permuted into
// the output layout.
TEST_F(SpmdPartitioningTest, DotPartialContracting3) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,100] parameter(0),
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %rhs = f32[32,100] parameter(1),
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1}, rhs_contracting_dims={1},
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[24,50]"));
  auto rhs_slice =
      AllOf(op::DynamicSlice(op::Parameter(1), _, _), op::Shape("f32[16,50]"));
  auto local_dot =
      AllOf(op::Dot(lhs_shard, rhs_slice), op::Shape("f32[24,16]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::CollectivePermute(op::AllReduce(local_dot)));
}
5746
// Batch-tiled operands with a split contracting dimension: the root is
// expected to be an all-reduce of the purely local dot.
TEST_F(SpmdPartitioningTest, DotBatchAndPartialContracting) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[4,24,100] parameter(0),
    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7}
  %rhs = f32[4,32,100] parameter(1),
    sharding={devices=[2,1,2,2]0,2,1,3,4,6,5,7 last_tile_dim_replicate}
  ROOT %dot = f32[4,24,32] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={2}, rhs_contracting_dims={2},
    sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  auto local_dot =
      AllOf(op::Dot(AllOf(op::Parameter(0), op::Shape("f32[2,12,50]")),
                    AllOf(op::Parameter(1), op::Shape("f32[2,32,50]"))),
            op::Shape("f32[2,12,32]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::AllReduce(local_dot));
}
5772
// The LHS is partially replicated on its non-contracting dim. The RHS
// contracting dimension is gathered (broadcast + dynamic-update-slice +
// all-reduce) so the dot can run locally.
TEST_F(SpmdPartitioningTest, DotPartialNonContracting) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,8,100] parameter(0),
    sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,2,1,3}
  ROOT %dot = f32[24,8,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={1},
    sharding={devices=[2,1,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[16,50]"));
  auto rhs_full_k = AllOf(op::AllReduce(op::DynamicUpdateSlice(
                              op::Broadcast(_), rhs_shard, _, _)),
                          op::Shape("f32[16,100]"));
  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[12,8,100]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_shard, rhs_full_k),
                          op::Shape("f32[12,8,16]")));
}
5801
// The LHS dim 1 (size 8) is tiled but not present in the output tiling, so it
// is gathered back to full size before the local dot.
TEST_F(SpmdPartitioningTest, DotPartialNonContractingPartialMatch) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
  %rhs = f32[32,100] parameter(1),
    sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
  ROOT %dot = f32[24,8,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={1},
    sharding={devices=[2,1,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[12,4,100]"));
  auto lhs_gathered = AllOf(op::AllReduce(op::DynamicUpdateSlice(
                                op::Broadcast(_), lhs_shard, _, _, _)),
                            op::Shape("f32[12,8,100]"));
  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[16,100]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_gathered, rhs_shard),
                          op::Shape("f32[12,8,16]")));
}
5830
// Two contracting dimensions with mismatched tiling. The RHS is
// dynamic-sliced to match the LHS, and the local partial sums are expected to
// be combined by two nested all-reduces.
TEST_F(SpmdPartitioningTest, DotPartialContractingPartialMatch) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,8,100] parameter(0), sharding={devices=[1,2,2]0,1,2,3}
  %rhs = f32[32,8,100] parameter(1),
    sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate}
  ROOT %dot = f32[24,32] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={1,2}, rhs_contracting_dims={1,2},
    sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[32,8,50]"));
  auto rhs_slice =
      AllOf(op::DynamicSlice(rhs_shard, _, _, _), op::Shape("f32[32,4,50]"));
  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[24,4,50]"));
  auto local_dot = AllOf(op::Dot(lhs_shard, rhs_slice), op::Shape("f32[24,32]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::AllReduce(op::AllReduce(local_dot)));
}
5857
// The contracting tiling matches between the operands. The RHS non-contracting
// dim is gathered, and the reduced result is dynamic-sliced on dim 1 to reach
// the output tiling.
TEST_F(SpmdPartitioningTest, DotNonContractingPartialMatchContractingMatch) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3}
  %rhs = f32[100,50] parameter(1), sharding={devices=[2,2]0,2,1,3}
  ROOT %dot = f32[24,8,50] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={0},
    sharding={devices=[2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[50,25]"));
  auto rhs_gathered =
      AllOf(op::AllReduce(op::DynamicUpdateSlice(_, rhs_shard, _, _)),
            op::Shape("f32[50,50]"));
  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[12,8,50]"));
  auto local_dot =
      AllOf(op::Dot(lhs_shard, rhs_gathered), op::Shape("f32[12,8,50]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::DynamicSlice(op::AllReduce(local_dot), _, _, _),
                          op::Shape("f32[12,4,50]")))
      << partitioned->ToString();
}
5886
// The LHS has two tiled non-contracting dims. The RHS contracting dim is
// gathered so the dot runs locally with no output reshard.
// NOTE: "Muti" in the test name is a typo for "Multi"; the name is kept so
// existing test filters keep working.
TEST_F(SpmdPartitioningTest, DotLHSMutiNonContractingRHSNotMatch) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[24,8,10] parameter(0), sharding={devices=[2,2,1]0,1,2,3}
  %rhs = f32[10,50] parameter(1),
    sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}
  ROOT %dot = f32[24,8,50] dot(%lhs, %rhs),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2}, rhs_contracting_dims={0},
    sharding={devices=[2,2,1]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto rhs_shard = AllOf(op::Parameter(1), op::Shape("f32[5,50]"));
  auto rhs_gathered =
      AllOf(op::AllReduce(op::DynamicUpdateSlice(_, rhs_shard, _, _)),
            op::Shape("f32[10,50]"));
  auto lhs_shard = AllOf(op::Parameter(0), op::Shape("f32[12,4,10]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_shard, rhs_gathered),
                          op::Shape("f32[12,4,50]")))
      << partitioned->ToString();
}
5914
// Within manual subgroups, the multiply is computed on column tiles and then
// re-replicated (dynamic-update-slice + all-reduce + slice) for the add.
TEST_F(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_TileToReplicate) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  constant = f32[6,3]{1,0}
    constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
  constant.1 = f32[6,3]{1,0}
    constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
  multiply = f32[6,3]{1,0} multiply(constant, constant.1),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
  ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
    sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated, manual}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // Both multiply operands are matched by the same pattern: a padded constant
  // dynamic-sliced to a f32[6,2] column tile.
  auto column_shard =
      AllOf(op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
                             op::Constant(), op::Reshape()),
            op::Shape("f32[6,2]"));
  auto product =
      AllOf(op::Multiply(column_shard, column_shard), op::Shape("f32[6,2]"));
  auto replicated_product =
      AllOf(op::Slice(op::AllReduce(op::DynamicUpdateSlice(
                op::Broadcast(), product, op::Constant(), op::Reshape()))),
            op::Shape("f32[6,3]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Add(replicated_product, op::Constant()),
                          op::Shape("f32[6,3]")));
}
5955
// Within manual subgroups, the multiply runs on full replicated data. Both add
// operands are then cut down to f32[6,2] column tiles.
TEST_F(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_ReplicateToTile) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  constant = f32[6,3]{1,0}
    constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
    sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
  constant.1 = f32[6,3]{1,0}
    constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
    sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
  multiply = f32[6,3]{1,0} multiply(constant, constant.1),
    sharding={devices=[1,1,2,2]0,1,2,3 last_tile_dims={replicated,manual}}
  ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dims={manual}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // Pads `full` and dynamic-slices out a f32[6,2] column tile.
  auto column_shard_of = [](auto full) {
    return AllOf(op::DynamicSlice(op::Pad(full, op::Constant()),
                                  op::Constant(), op::Reshape()),
                 op::Shape("f32[6,2]"));
  };
  auto product = AllOf(op::Multiply(op::Constant(), op::Constant()),
                       op::Shape("f32[6,3]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Add(column_shard_of(product),
                                  column_shard_of(op::Constant())),
                          op::Shape("f32[6,2]")));
}
5989
// Going from a 2-way partially replicated multiply (3-row shards) to a 4-way
// tiled add (2-row shards) is expected to need a one-row halo, obtained via
// collective-permute of a slice of the local product.
TEST_F(SpmdPartitioningTest,
       ElementwiseTest_PartialReplicateToTiledHaloExchange) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  constant = f32[6,3]{1,0}
    constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}),
    sharding={replicated}
  constant.1 = f32[6,3]{1,0}
    constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}),
    sharding={replicated}
  multiply = f32[6,3]{1,0} multiply(constant, constant.1),
    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT add = f32[6,3]{1,0} add(multiply, constant.1),
    sharding={devices=[4,1]0,1,2,3}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // Both multiply operands are matched by the same pattern: a f32[3,3] row
  // slice of a constant.
  auto row_shard = AllOf(
      op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant()),
      op::Shape("f32[3,3]"));
  auto product = AllOf(op::Multiply(row_shard, row_shard), op::Shape("f32[3,3]"));
  auto right_halo = AllOf(op::CollectivePermute(op::Slice(product)),
                          op::Shape("f32[1,3]"));
  auto windowed = op::DynamicSlice(
      op::Pad(op::Concatenate(product, right_halo), op::Constant()),
      op::Reshape(), op::Constant());
  auto add_lhs = AllOf(op::DynamicSlice(windowed, op::Subtract(), op::Subtract()),
                       op::Shape("f32[2,3]"));
  auto add_rhs = AllOf(op::DynamicSlice(op::Pad(op::Constant(), op::Constant()),
                                        op::Reshape(), op::Constant()),
                       op::Shape("f32[2,3]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Add(add_lhs, add_rhs), op::Shape("f32[2,3]")));
}
6036
// A [2,2]-tiled copy is resharded to row tiling with the last tile dim
// replicated: the column tiles are gathered with dynamic-update-slice +
// all-reduce.
TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshard) {
  const absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[8,8] parameter(0)
  %copy = f32[8,8] copy(%param0),
    sharding={devices=[2,2]0,1,2,3}
  ROOT %copy0 = f32[8,8] copy(%copy),
    sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  auto tiled_copy = AllOf(op::Copy(op::DynamicSlice(op::Parameter(0),
                                                    op::Reshape(), op::Reshape())),
                          op::Shape("f32[4,4]"));
  auto expected = AllOf(op::Copy(op::AllReduce(op::DynamicUpdateSlice(
                            op::Broadcast(_), tiled_copy, _, _))),
                        op::Shape("f32[4,8]"));

  const auto* root = partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, expected);
}
6061
// Uneven reshard: f32[8,8] tiled [2,3] over six devices (f32[4,3] per shard)
// is moved to a [1,2,3] layout with the last tile dimension replicated
// (f32[8,4] per shard). The expected lowering combines shards with
// DynamicUpdateSlice + AllReduce, slices, and then redistributes with an
// AllToAll wrapped in reshape/transpose.
TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshardUnevenPartition) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%param0 = f32[8,8] parameter(0),
sharding={devices=[2,3]0,1,2,3,4,5}
ROOT %copy0 = f32[8,8] copy(%param0),
sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/6));
  VLOG(1) << partitioned->ToString();

  // The parameter already carries the source sharding, so its shard is used
  // directly.
  const auto param_shard = AllOf(op::Shape("f32[4,3]"), op::Parameter(0));
  const auto reduced = op::Slice(op::AllReduce(
      op::DynamicUpdateSlice(op::Broadcast(), param_shard, _, _)));
  const auto redistributed =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(reduced))));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[8,4]"), op::Copy(redistributed)));
}
6085
// Reshards f32[8,8] from [1,2,3] with the last tile dimension replicated
// (f32[8,4] per shard) to a [2,3] tiling over six devices (f32[4,3] per
// shard). The expected lowering is an AllToAll wrapped in reshape/transpose,
// followed by a Pad and a DynamicSlice.
TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshardUnevenPartition) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%param0 = f32[8,8] parameter(0),
sharding={devices=[1,2,3]0,1,2,3,4,5 last_tile_dim_replicate}
ROOT %copy0 = f32[8,8] copy(%param0),
sharding={devices=[2,3]0,1,2,3,4,5}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/6));
  VLOG(1) << partitioned->ToString();

  const auto param_shard = AllOf(op::Shape("f32[8,4]"), op::Parameter(0));
  const auto exchanged =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(param_shard))));
  const auto expected =
      AllOf(op::Shape("f32[4,3]"),
            op::Copy(op::DynamicSlice(op::Pad(exchanged, _), _, _)));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, expected);
}
6110
// Reshards f32[8,8] from [2,1,2] with the last tile dimension replicated
// (f32[4,8] per shard) to a [2,2] tiling (f32[4,4] per shard). The expected
// result contains no collective: each device takes a DynamicSlice of the
// replicated rows it already holds.
TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%param0 = f32[8,8] parameter(0)
%copy = f32[8,8] copy(%param0),
sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
ROOT %copy0 = f32[8,8] copy(%copy),
sharding={devices=[2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const auto replicated_rows =
      AllOf(op::Shape("f32[4,8]"),
            op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                      op::Constant())));
  const auto expected =
      AllOf(op::Shape("f32[4,4]"),
            op::Copy(op::DynamicSlice(replicated_rows, op::Subtract(),
                                      op::Subtract())));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, expected);
}
6137
// Reshards f32[8,8] on eight devices from [2,2,2] (f32[4,4] per shard) to
// [2,1,4] (f32[4,8] per shard), both with the last tile dimension replicated.
// Shards are expected to be merged with DynamicUpdateSlice into a broadcast
// followed by an AllReduce.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshard_AllReduce) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%param0 = f32[8,8] parameter(0)
%copy = f32[8,8] copy(param0),
sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %copy0 = f32[8,8] copy(%copy),
sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));

  VLOG(1) << partitioned->ToString();

  const auto source_shard =
      AllOf(op::Shape("f32[4,4]"),
            op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                      op::Reshape())));
  const auto merged = op::AllReduce(
      op::DynamicUpdateSlice(op::Broadcast(_), source_shard, _, _));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[4,8]"), op::Copy(merged)));
}
6166
// Reshards f32[8,8] on eight devices from [2,1,4] (f32[4,8] per shard) to
// [2,2,2] (f32[4,4] per shard), both with the last tile dimension replicated.
// The target is finer than the source, so a local DynamicSlice is expected.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshard_DynamicSlice) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%param0 = f32[8,8] parameter(0)
%copy = f32[8,8] copy(%param0),
sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %copy0 = f32[8,8] copy(%copy),
sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  const auto source_shard =
      AllOf(op::Shape("f32[4,8]"),
            op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                      op::Constant())));
  // The result is still partially replicated, just split along one more dim.
  const auto narrowed =
      AllOf(op::Shape("f32[4,4]"),
            op::Copy(op::DynamicSlice(source_shard, op::Subtract(),
                                      op::Subtract())));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, narrowed);
}
6194
// Reshards f32[8,8] on eight devices from [2,2,2] (f32[4,4] per shard) to
// [1,2,4] (f32[8,4] per shard), both with the last tile dimension replicated.
// The source shard is expected to pass through a CollectivePermute before the
// DynamicUpdateSlice + AllReduce merge.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardWithCollectivePermute) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%param0 = f32[8,8] parameter(0)
%copy = f32[8,8] copy(param0),
sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %copy0 = f32[8,8] copy(%copy),
sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));

  VLOG(1) << partitioned->ToString();

  const auto permuted_shard =
      AllOf(op::Shape("f32[4,4]"),
            op::CollectivePermute(op::Copy(op::DynamicSlice(
                op::Parameter(0), op::Reshape(), op::Reshape()))));
  const auto merged = op::AllReduce(
      op::DynamicUpdateSlice(op::Broadcast(_), permuted_shard, _, _));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[8,4]"), op::Copy(merged)));
}
6223
// Reshards f32[8,8] on eight devices from [1,2,4] (f32[8,4] per shard) to
// [2,2,2] (f32[4,4] per shard), both with the last tile dimension replicated.
// Each device is expected to DynamicSlice its data and then exchange it with a
// CollectivePermute.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardCollectivePermute1) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%param0 = f32[8,8] parameter(0)
%copy = f32[8,8] copy(%param0),
sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %copy0 = f32[8,8] copy(%copy),
sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));
  VLOG(1) << partitioned->ToString();

  const auto source_shard =
      AllOf(op::Shape("f32[8,4]"),
            op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
                                      op::Reshape())));
  const auto sliced =
      op::DynamicSlice(source_shard, op::Subtract(), op::Subtract());
  const auto expected = AllOf(op::Shape("f32[4,4]"),
                              op::Copy(op::CollectivePermute(sliced)));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, expected);
}
6251
// Uneven reshard of f32[6,3] on eight devices from [4,1,2] (f32[2,3] per
// shard) to [2,1,4] (f32[3,3] per shard), both with the last tile dimension
// replicated. The expected lowering performs a halo exchange:
// 1. A slice of the shard is sent via CollectivePermute and concatenated in
//    front of it.
// 2. The shard is re-windowed with DynamicSlice.
// 3. Shards are merged with DynamicUpdateSlice + AllReduce and sliced to size.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardHaloExchange) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%param0 = f32[6,3] parameter(0)
%copy = f32[6,3] copy(param0),
sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %copy0 = f32[6,3] copy(%copy),
sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));

  VLOG(1) << partitioned->ToString();

  const auto source_shard =
      AllOf(op::Shape("f32[2,3]"),
            op::Copy(op::DynamicSlice(op::Pad(op::Parameter(0), op::Constant()),
                                      op::Reshape(), op::Constant())));
  const auto halo = op::CollectivePermute(op::Slice(source_shard));
  const auto windowed =
      AllOf(op::Shape("f32[2,3]"),
            op::DynamicSlice(op::Concatenate(halo, source_shard), _, _));
  const auto merged = op::Slice(op::AllReduce(
      op::DynamicUpdateSlice(op::Broadcast(_), windowed, _, _)));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[3,3]"), op::Copy(merged)));
}
6286
// Uneven reshard of f32[6,3] on eight devices from [2,1,4] (f32[3,3] per
// shard) to [4,1,2] (f32[2,3] per shard), both with the last tile dimension
// replicated. The expected lowering appends a halo (a CollectivePermute of a
// slice) to the shard, pads, and then narrows with two DynamicSlices.
TEST_F(SpmdPartitioningTest,
       PartialReplicateToPartialReplicateReshardHaloExchange1) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%param0 = f32[6,3] parameter(0)
%copy = f32[6,3] copy(param0),
sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
ROOT %copy0 = f32[6,3] copy(%copy),
sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/8));

  VLOG(1) << partitioned->ToString();

  const auto source_shard =
      AllOf(op::Shape("f32[3,3]"),
            op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                      op::Constant())));
  const auto halo = op::CollectivePermute(op::Slice(source_shard));
  const auto padded =
      op::Pad(op::Concatenate(source_shard, halo), op::Constant());
  const auto windowed =
      AllOf(op::Shape("f32[4,3]"), op::DynamicSlice(padded, _, _));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]"),
                          op::Copy(op::DynamicSlice(windowed, _, _))));
}
6320
// Convolution with batch_group_count=1024 where lhs, rhs and the output are
// all split in two along dimension 3. Both operands are expected to be plain
// DynamicSlices of their parameters, and the convolution is expected to have
// no surrounding collectives.
// NOTE: "Bath" in the test name is a typo for "Batch"; the name is left as-is
// so the test keeps its identity.
TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCount) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,1,1,2]0,1}
%rhs = f32[16,801,1,1024] parameter(1)
%rhs.copy = f32[16,801,1,1024] copy(%rhs),
sharding={devices=[1,1,1,2]0,1}
ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=f01b_i01o->01bf,batch_group_count=1024,
window={size=801x1 pad=2_2x0_0},
sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));

  VLOG(1) << partitioned->ToString();

  // lhs and rhs have the same shape and sharding, so one matcher covers both.
  const auto operand_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[16,801,1,512]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(operand_shard, operand_shard),
                          op::Shape("f32[5,1,1,512]")));
}
6354
// Batch-group convolution (batch_group_count=1024) where lhs is split on
// dimension 3 but rhs is split on dimension 1. The padded rhs shard
// (f32[16,401,1,1024]) is expected to be resharded to match lhs through an
// AllToAll wrapped in reshape/transpose, then sliced to f32[16,801,1,512].
TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCountRHSAlignWithLHS) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,1,1,2]0,1}
%rhs = f32[16,801,1,1024] parameter(1)
%rhs.copy = f32[16,801,1,1024] copy(%rhs),
sharding={devices=[1,2,1,1]0,1}
ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=f01b_i01o->01bf,batch_group_count=1024,
window={size=801x1 pad=2_2x0_0},
sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[16,801,1,512]"));
  const auto rhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape(),
                                      op::Constant(), op::Constant())),
            op::Shape("f32[16,401,1,1024]"));
  const auto rhs_aligned =
      AllOf(op::Slice(op::Reshape(
                op::Transpose(op::AllToAll(op::Reshape(rhs_shard))))),
            op::Shape("f32[16,801,1,512]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(lhs_shard, rhs_aligned),
                          op::Shape("f32[5,1,1,512]")));
}
6390
// Batch-group convolution (batch_group_count=1024) where rhs is split on
// dimension 3 but lhs is split on dimension 1. The padded lhs shard
// (f32[16,401,1,1024]) is expected to be resharded to match rhs through an
// AllToAll wrapped in reshape/transpose, then sliced to f32[16,801,1,512].
TEST_F(SpmdPartitioningTest, PartitionConvWithBathGroupCountLHSAlignWithRHS) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,2,1,1]0,1}
%rhs = f32[16,801,1,1024] parameter(1)
%rhs.copy = f32[16,801,1,1024] copy(%rhs),
sharding={devices=[1,1,1,2]0,1}
ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=f01b_i01o->01bf,batch_group_count=1024,
window={size=801x1 pad=2_2x0_0},
sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape(),
                                      op::Constant(), op::Constant())),
            op::Shape("f32[16,401,1,1024]"));
  const auto lhs_aligned =
      AllOf(op::Slice(op::Reshape(
                op::Transpose(op::AllToAll(op::Reshape(lhs_shard))))),
            op::Shape("f32[16,801,1,512]"));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[16,801,1,512]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(lhs_aligned, rhs_shard),
                          op::Shape("f32[5,1,1,512]")));
}
6426
// Batch-group convolution where both operands are split on dimension 3 but the
// output is requested split on dimension 0. The local f32[5,1,1,512] result is
// expected to be padded and redistributed with an AllToAll (wrapped in
// reshape/transpose) into an f32[3,1,1,1024] shard.
TEST_F(SpmdPartitioningTest,
       PartitionConvWithBathGroupCountOutputAlignWithLHS) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,1,1,2]0,1}
%rhs = f32[16,801,1,1024] parameter(1)
%rhs.copy = f32[16,801,1,1024] copy(%rhs),
sharding={devices=[1,1,1,2]0,1}
ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=f01b_i01o->01bf,batch_group_count=1024,
window={size=801x1 pad=2_2x0_0},
sharding={devices=[2,1,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  // lhs and rhs have the same shape and sharding, so one matcher covers both.
  const auto operand_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[16,801,1,512]"));
  const auto local_conv =
      AllOf(op::Convolution(operand_shard, operand_shard),
            op::Shape("f32[5,1,1,512]"));
  const auto realigned = op::Reshape(op::Transpose(
      op::AllToAll(op::Reshape(op::Pad(local_conv, op::Constant())))));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(realigned, op::Shape("f32[3,1,1,1024]")));
}
6462
// Batch-group convolution where lhs is split on dimension 1, rhs on dimension
// 3, and the output on dimension 0. lhs is expected to be resharded to match
// rhs (AllToAll + Slice). The local f32[5,1,1,512] result is then expected to
// be padded and redistributed with an AllToAll into an f32[3,1,1,1024] shard.
TEST_F(SpmdPartitioningTest,
       PartitionConvWithBathGroupCountOutputAlignWithRHS) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,2,1,1]0,1}
%rhs = f32[16,801,1,1024] parameter(1)
%rhs.copy = f32[16,801,1,1024] copy(%rhs),
sharding={devices=[1,1,1,2]0,1}
ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=f01b_i01o->01bf,batch_group_count=1024,
window={size=801x1 pad=2_2x0_0},
sharding={devices=[2,1,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[16,801,1,512]"));
  const auto lhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape(),
                                      op::Constant(), op::Constant())),
            op::Shape("f32[16,401,1,1024]"));
  const auto lhs_aligned =
      AllOf(op::Slice(op::Reshape(
                op::Transpose(op::AllToAll(op::Reshape(lhs_shard))))),
            op::Shape("f32[16,801,1,512]"));
  const auto local_conv = AllOf(op::Convolution(lhs_aligned, rhs_shard),
                                op::Shape("f32[5,1,1,512]"));
  const auto realigned = op::Reshape(op::Transpose(
      op::AllToAll(op::Reshape(op::Pad(local_conv, op::Constant())))));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(realigned, op::Shape("f32[3,1,1,1024]")));
}
6502
// Convolution with feature_group_count=1024 where lhs, rhs and the output are
// all split in two along dimension 3. Both operands are expected to be plain
// DynamicSlices of their parameters, and the convolution is expected to have
// no surrounding collectives.
TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCount) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,1,1,2]0,1}
%rhs = f32[5,1,1,1024] parameter(1)
%rhs.copy = f32[5,1,1,1024] copy(%rhs),
sharding={devices=[1,1,1,2]0,1}
ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=b01f_01io->b01f,feature_group_count=1024,
window={size=5x1 pad=2_2x0_0},
sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[16,801,1,512]"));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[5,1,1,512]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(lhs_shard, rhs_shard),
                          op::Shape("f32[16,801,1,512]")));
}
6535
// Feature-group convolution on 32 devices. Both operands are split 4-way on
// dimension 3 and replicated 8-way in the last tile dimension. The output is
// tiled [8,1,1,4] with a permuted device order. lhs is expected to be sliced
// again on dimension 0 (down to f32[8,3,1,768]) before the local convolution.
TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCount2) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[64,3,1,3072] parameter(0)
%lhs.copy = f32[64,3,1,3072] copy(%lhs),
sharding={devices=[1,1,1,4,8]0,1,2,3,4,5,6,7,16,17,18,19,20,21,22,23,24,25
,26,27,28,29,30,31,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%rhs = f32[3,1,1,3072] parameter(1)
%rhs.copy = f32[3,1,1,3072] copy(%rhs),
sharding={devices=[1,1,1,4,8]0,1,2,3,4,5,6,7,16,17,18,19,20,21,22,23,24,25
,26,27,28,29,30,31,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
ROOT %conv = f32[64,1,1,3072] convolution(%lhs.copy, %rhs.copy),
dim_labels=b01f_01io->b01f,feature_group_count=3072,
window={size=3x1},
sharding={devices=[8,1,1,4]0,16,24,8,2,18,26,10,4,20,28,12,6,22,30,14,7,23,
31,15,5,21,29,13,3,19,27,11,1,17,25,9}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/32));
  VLOG(1) << partitioned->ToString();

  const auto lhs_feature_shard =
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape()));
  const auto lhs_shard =
      AllOf(op::DynamicSlice(lhs_feature_shard, op::Reshape(), op::Constant(),
                             op::Constant(), op::Constant()),
            op::Shape("f32[8,3,1,768]"));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[3,1,1,768]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(lhs_shard, rhs_shard),
                          op::Shape("f32[8,1,1,768]")));
}
6574
// Feature-group convolution where lhs is split on dimension 3 but rhs is split
// on dimension 0. The padded rhs shard (f32[3,1,1,1024]) is expected to be
// resharded to match lhs through an AllToAll wrapped in reshape/transpose,
// then sliced to f32[5,1,1,512].
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountRHSAlignWithLHS) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,1,1,2]0,1}
%rhs = f32[5,1,1,1024] parameter(1)
%rhs.copy = f32[5,1,1,1024] copy(%rhs),
sharding={devices=[2,1,1,1]0,1}
ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=b01f_01io->b01f,feature_group_count=1024,
window={size=5x1 pad=2_2x0_0},
sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[16,801,1,512]"));
  const auto rhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Reshape(), op::Constant(),
                                      op::Constant(), op::Constant())),
            op::Shape("f32[3,1,1,1024]"));
  const auto rhs_aligned =
      AllOf(op::Slice(op::Reshape(
                op::Transpose(op::AllToAll(op::Reshape(rhs_shard))))),
            op::Shape("f32[5,1,1,512]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(lhs_shard, rhs_aligned),
                          op::Shape("f32[16,801,1,512]")));
}
6611
// Feature-group convolution where rhs is split on dimension 3 but lhs is split
// on dimension 1. The padded lhs shard (f32[16,401,1,1024]) is expected to be
// resharded to match rhs through an AllToAll wrapped in reshape/transpose,
// then sliced to f32[16,801,1,512].
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountLHSAlignWithRHS) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,2,1,1]0,1}
%rhs = f32[5,1,1,1024] parameter(1)
%rhs.copy = f32[5,1,1,1024] copy(%rhs),
sharding={devices=[1,1,1,2]0,1}
ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=b01f_01io->b01f,feature_group_count=1024,
window={size=5x1 pad=2_2x0_0},
sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape(),
                                      op::Constant(), op::Constant())),
            op::Shape("f32[16,401,1,1024]"));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[5,1,1,512]"));
  const auto lhs_aligned =
      AllOf(op::Slice(op::Reshape(
                op::Transpose(op::AllToAll(op::Reshape(lhs_shard))))),
            op::Shape("f32[16,801,1,512]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(lhs_aligned, rhs_shard),
                          op::Shape("f32[16,801,1,512]")));
}
6648
// Feature-group convolution where both operands are split on dimension 3 but
// the output is requested split on dimension 0. The local f32[16,801,1,512]
// result is expected to be redistributed with an AllToAll (wrapped in
// reshape/transpose) into an f32[8,801,1,1024] shard.
// NOTE: "Ouput" in the test name is a typo; the name is left as-is so the test
// keeps its identity.
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountAlignOuputWithLHS) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,1,1,2]0,1}
%rhs = f32[5,1,1,1024] parameter(1)
%rhs.copy = f32[5,1,1,1024] copy(%rhs),
sharding={devices=[1,1,1,2]0,1}
ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=b01f_01io->b01f,feature_group_count=1024,
window={size=5x1 pad=2_2x0_0},
sharding={devices=[2,1,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/2));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[16,801,1,512]"));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[5,1,1,512]"));
  const auto local_conv = AllOf(op::Convolution(lhs_shard, rhs_shard),
                                op::Shape("f32[16,801,1,512]"));
  const auto realigned =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(local_conv))));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(realigned, op::Shape("f32[8,801,1,1024]")));
}
6684
// Feature-group convolution on four devices where lhs and output are tiled
// [1,2,1,2] and rhs is split on dimension 3 with partial replication. The
// spatially split lhs shard is expected to receive left/right halos via
// CollectivePermute. It is then concatenated and masked with Select before the
// local convolution.
TEST_F(SpmdPartitioningTest,
       PartitionConvGroupOnFeatureGroupCount_RHSPartialReplicate) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,2,1,2]0,1,2,3}
%rhs = f32[5,1,1,1024] parameter(1)
%rhs.copy = f32[5,1,1,1024] copy(%rhs),
sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate}
ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=b01f_01io->b01f,feature_group_count=1024,
window={size=5x1 pad=2_2x0_0},
sharding={devices=[1,2,1,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape(),
                                      op::Constant(), op::Reshape())),
            op::Shape("f32[16,401,1,512]"));
  // The left and right halos are matched by the same pattern.
  const auto halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
                          op::CollectivePermute(op::Slice(lhs_shard)));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[5,1,1,512]"));
  const auto windowed_lhs =
      op::Select(_, op::Concatenate(halo, lhs_shard, halo), _);

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(windowed_lhs, rhs_shard),
                          op::Shape("f32[16, 401, 1, 512]")));
}
6726
// Feature-group convolution on four devices where lhs and output are tiled
// [1,2,1,2] and rhs is a replicated parameter. rhs is expected to be sliced
// locally along dimension 3. lhs is expected to receive CollectivePermute
// halos that are concatenated and masked with Select.
TEST_F(SpmdPartitioningTest,
       PartitionConvGroupOnFeatureGroupCount_RHSAlignWithOutput) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,2,1,2]0,1,2,3}
%rhs = f32[5,1,1,1024] parameter(1), sharding={replicated}
ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs),
dim_labels=b01f_01io->b01f,feature_group_count=1024,
window={size=5x1 pad=2_2x0_0},
sharding={devices=[1,2,1,2]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard =
      AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                      op::Constant(), op::Reshape(),
                                      op::Constant(), op::Reshape())),
            op::Shape("f32[16,401,1,512]"));
  // The left and right halos are matched by the same pattern.
  const auto halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
                          op::CollectivePermute(op::Slice(lhs_shard)));
  // The replicated rhs parameter is sliced directly, without a Copy.
  const auto rhs_slice =
      AllOf(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                             op::Constant(), op::Reshape()),
            op::Shape("f32[5,1,1,512]"));
  const auto windowed_lhs =
      op::Select(_, op::Concatenate(halo, lhs_shard, halo), _);

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(windowed_lhs, rhs_slice),
                          op::Shape("f32[16, 401, 1, 512]")));
}
6765
// Feature-group convolution on four devices where lhs is split on dimension 0
// with partial replication, rhs is split on dimension 3 with partial
// replication, and the output is tiled [1,2,1,2]. lhs is expected to be
// re-sliced, padded, and moved to the output layout with an AllToAll. It then
// goes through the usual halo exchange (CollectivePermute + Concatenate +
// Select) before the local convolution.
TEST_F(SpmdPartitioningTest,
       PartitionConvGroupOnFeatureGroupCount_LHSAlignWithOutput) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate}
%rhs = f32[5,1,1,1024] parameter(1)
%rhs.copy = f32[5,1,1,1024] copy(%rhs),
sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate}
ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=b01f_01io->b01f,feature_group_count=1024,
window={size=5x1 pad=2_2x0_0},
sharding={devices=[1,2,1,2]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(),
                                op::Constant(), op::Constant())),
      op::Shape("f32[8,801,1,1024]"));
  const auto lhs_resliced =
      op::DynamicSlice(lhs_shard, op::Subtract(), op::Subtract(),
                       op::Subtract(), op::Subtract());
  const auto lhs_aligned =
      AllOf(op::Reshape(op::Transpose(op::AllToAll(
                op::Reshape(op::Pad(lhs_resliced, op::Constant()))))),
            op::Shape("f32[16,401,1,512]"));
  // The left and right halos are matched by the same pattern.
  const auto halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
                          op::CollectivePermute(op::Slice(lhs_aligned)));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[5,1,1,512]"));
  const auto windowed_lhs =
      op::Select(_, op::Concatenate(halo, lhs_aligned, halo), _);

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Convolution(windowed_lhs, rhs_shard),
                          op::Shape("f32[16, 401, 1, 512]")));
}
6814
// Batch-group convolution on four devices where both operands are tiled
// [1,2,1,2] and the output is split on dimension 3 with partial replication.
// The lhs shard gets halos via CollectivePermute, and rhs is masked with
// Select. The local convolution result is then expected to be AllReduced and
// CollectivePermuted.
TEST_F(SpmdPartitioningTest, PartitionConvGroupOnBatchGroupCount) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY entry {
%lhs = f32[16,801,1,1024] parameter(0)
%lhs.copy = f32[16,801,1,1024] copy(%lhs),
sharding={devices=[1,2,1,2]0,1,2,3}
%rhs = f32[16,801,1,1024] parameter(1)
%rhs.copy = f32[16,801,1,1024] copy(%rhs),
sharding={devices=[1,2,1,2]0,1,2,3}
ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy),
dim_labels=f01b_i01o->01bf,batch_group_count=1024,
window={size=801x1 pad=2_2x0_0},
sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_text, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // lhs and rhs are sliced out of their padded parameters the same way.
  const auto padded_shard =
      op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
                                op::Constant(), op::Reshape(), op::Constant(),
                                op::Reshape()));
  const auto lhs_shard =
      AllOf(op::Select(_, padded_shard, _), op::Shape("f32[16,401,1,512]"));
  // The left and right halos are matched by the same pattern.
  const auto halo = AllOf(op::Shape("f32[16,2, 1, 512]"),
                          op::CollectivePermute(op::Slice(lhs_shard)));
  const auto rhs_shard = AllOf(padded_shard, op::Shape("f32[16,401,1,512]"));
  const auto local_conv =
      AllOf(op::Convolution(op::Concatenate(halo, lhs_shard, halo),
                            op::Select(_, rhs_shard, _)),
            op::Shape("f32[5,1,1,512]"));

  const HloInstruction* root =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::CollectivePermute(op::AllReduce(local_conv)),
                          op::Shape("f32[5,1,1,512]")));
}
6857
TEST_F(SpmdPartitioningTest,
       PartitionConvWithFeatureGroupCountAlignOuputWithRHS) {
  // The lhs is tiled on dim 1, the rhs on dim 3, and the output on dim 0.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,2,1,1]0,1}
  %rhs = f32[5,1,1,1024] parameter(1)
  %rhs.copy = f32[5,1,1,1024] copy(%rhs),
    sharding={devices=[1,1,1,2]0,1}
  ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=b01f_01io->b01f,feature_group_count=1024,
    window={size=5x1 pad=2_2x0_0},
    sharding={devices=[2,1,1,1]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // Local operand shards as produced from the input shardings.
  const auto lhs_shard = AllOf(op::Copy(op::DynamicSlice(
                                   op::Pad(op::Parameter(), op::Constant()),
                                   op::Constant(), op::Reshape(),
                                   op::Constant(), op::Constant())),
                               op::Shape("f32[16,401,1,1024]"));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[5,1,1,512]"));

  // The lhs is moved via all-to-all onto the feature dim so it lines up with
  // the rhs shard.
  const auto lhs_on_features =
      AllOf(op::Slice(op::Reshape(
                op::Transpose(op::AllToAll(op::Reshape(lhs_shard))))),
            op::Shape("f32[16,801,1,512]"));
  const auto local_conv = AllOf(op::Convolution(lhs_on_features, rhs_shard),
                                op::Shape("f32[16,801,1,512]"));

  // A second all-to-all moves the result to the requested batch tiling.
  EXPECT_THAT(root, AllOf(op::Reshape(op::Transpose(
                              op::AllToAll(op::Reshape(local_conv)))),
                          op::Shape("f32[8,801,1,1024]")));
}
6897
TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCountBackProp) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[16,801,1,1024] parameter(0)
  %lhs.copy = f32[16,801,1,1024] copy(%lhs),
    sharding={devices=[1,1,1,2]0,1}
  %rhs = f32[5,1,1024,1] parameter(1)
  %rhs.copy = f32[5,1,1024,1] copy(%rhs),
    sharding={devices=[1,1,2,1]0,1}
  ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy),
    dim_labels=b01f_01oi->b01f,feature_group_count=1024,
    window={size=5x1 pad=2_2x0_0 rhs_reversal=1x1},
    sharding={devices=[1,1,1,2]0,1}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The expected root is a convolution applied directly to the two local
  // shards, with no collective in between.
  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[16,801,1,512]"));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Reshape(), op::Constant())),
      op::Shape("f32[5,1,512,1]"));

  EXPECT_THAT(root, AllOf(op::Convolution(lhs_shard, rhs_shard),
                          op::Shape("f32[16,801,1,512]")));
}
6930
TEST_F(SpmdPartitioningTest, NoReshardOnBroadcastDims) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %param0 = f32[2,3] parameter(0)
  %param1 = f32[2,3,20] parameter(1)
  %br0 = f32[20,2,20,3,20] broadcast(%param0), dimensions={1,3}, sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
  %br1 = f32[20,2,20,3,20] broadcast(%param1), dimensions={1,3,4}, sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
  %add = f32[20,2,20,3,20] add(%br0, %br1), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
  %reshape = f32[10,4,10,6,20] reshape(%br0), sharding={devices=[2,1,2,1,2]0,1,2,3,4,5,6,7}
  %transpose = f32[2,3,20,20,20] transpose(%br0), dimensions={1,3,0,2,4}, sharding={devices=[1,1,2,2,2]0,1,2,3,4,5,6,7}
  %copy_add0 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]6,7,2,3,4,5,0,1}
  %copy_add1 = f32[20,2,20,3,20] copy(%add), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1}
  %copy_reshape = f32[10,4,10,6,20] copy(%reshape), sharding={devices=[2,1,2,1,2]7,6,3,2,5,4,0,1}
  %copy_transpose = f32[2,3,20,20,20] copy(%transpose), sharding={devices=[1,1,2,2,2]7,6,3,2,5,4,0,1}
  ROOT %tuple = (f32[20,2,20,3,20], f32[20,2,20,3,20], f32[10,4,10,6,20], f32[2,3,20,20,20])
    tuple(%copy_add0, %copy_add1, %copy_reshape, %copy_transpose),
    sharding={{devices=[2,1,2,1,2]6,7,2,3,4,5,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[2,1,2,1,2]7,6,3,2,5,4,0,1},{devices=[1,1,2,2,2]7,6,3,2,5,4,0,1}}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // Both copies of %add start from the same partitioned add of broadcasts.
  const auto add_of_broadcasts = op::Add(op::Broadcast(_), op::Broadcast(_));

  // copy_add0: the reshard only touches broadcast dims, so a plain copy is
  // expected instead of a collective.
  const auto expected_copy_add0 = op::Copy(op::Copy(add_of_broadcasts));
  // copy_add1: the reshard also touches non-broadcast dims, so a
  // collective-permute is expected.
  const auto expected_copy_add1 =
      op::Copy(op::CollectivePermute(add_of_broadcasts));
  // copy_reshape and copy_transpose: only broadcast dims are resharded, so no
  // collective is expected.
  const auto expected_copy_reshape =
      op::Copy(op::Copy(op::Reshape(op::Broadcast(_))));
  const auto expected_copy_transpose =
      op::Copy(op::Copy(op::Transpose(op::Broadcast(_))));

  EXPECT_THAT(root, op::Tuple(expected_copy_add0, expected_copy_add1,
                              expected_copy_reshape, expected_copy_transpose));
}
6969
TEST_F(SpmdPartitioningTest,
       ConvolutionFilterIFOFPartitionedInputPartialReplicate) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,112,112,12] parameter(0)
  %lhs.copy = f32[128,112,112,12] copy(f32[128,112,112,12] %lhs),
    sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[7,7,12,64] parameter(1)
  %rhs.copy = f32[7,7,12,64] copy(f32[7,7,12,64] %rhs),
    sharding={devices=[1,1,2,2]0,1,2,3}
  ROOT %conv = f32[128,56,56,64] convolution(
    f32[128,112,112,12] %lhs.copy,
    f32[7,7,12,64] %rhs.copy),
    window={size=7x7 stride=2x2 pad=3_3x3_3},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The lhs is halved on its last dim; the rhs is halved on both of its
  // last two dims.
  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[128,112,112,6]"));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Reshape(), op::Reshape())),
      op::Shape("f32[7,7,6,32]"));
  const auto partial_conv = op::Convolution(lhs_shard, rhs_shard);

  // Partial sums are all-reduced, then permuted to the output device order.
  EXPECT_THAT(root, AllOf(op::CollectivePermute(op::AllReduce(partial_conv)),
                          op::Shape("f32[128,56,56,32]")));
}
7008
TEST_F(SpmdPartitioningTest,
       ConvolutionInputKernelNonContractingDimPartialReplicate) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[128,56,56,256] parameter(0)
  %lhs.copy = f32[128,56,56,256] copy(%lhs),
    sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
  %rhs = f32[128,28,28,512] parameter(1)
  %rhs.copy = f32[128,28,28,512] copy(%rhs),
    sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %conv = f32[1,1,256,512] convolution(%lhs.copy, %rhs.copy),
    window={size=28x28 pad=0_-1x0_-1 rhs_dilate=2x2}, dim_labels=f01b_i01o->01bf,
    sharding={devices=[1,1,2,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // Both operands are halved on their last dim.
  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[128,56,56,128]"));
  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[128,28,28,256]"));

  // Only the rhs needs a collective-permute before the local convolution.
  EXPECT_THAT(root,
              AllOf(op::Convolution(lhs_shard, op::CollectivePermute(rhs_shard)),
                    op::Shape("f32[1,1,128,256]")));
}
7042
TEST_F(SpmdPartitioningTest,
       ConvolutionInputSpatialDimAndFeatureDimParttiioned) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[8,210,210,12] parameter(0)
  %lhs.copy = f32[8,210,210,12] copy(f32[8,210,210,12] %lhs),
    sharding={devices=[1,2,1,2]0,1,2,3}
  %rhs = f32[3,3,12,32] parameter(1)
  %rhs.copy = f32[3,3,12,32] copy(f32[3,3,12,32] %rhs),
    sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate}
  ROOT %conv = f32[8,210,210,32] convolution(
    f32[8,210,210,12] %lhs.copy,
    f32[3,3,12,32] %rhs.copy),
    window={size=3x3 pad=1_1x1_1},
    dim_labels=b01f_01io->b01f,
    sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The lhs shard is halved on dim 1 (spatial) and dim 3 (feature).
  const auto lhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(),
                                op::Constant(), op::Reshape())),
      op::Shape("f32[8,105,210,6]"));

  // One-row halos are exchanged on each side. Both sides have the same
  // expected structure, so a single matcher is reused.
  const auto halo = AllOf(op::CollectivePermute(op::Slice(lhs_shard)),
                          op::Shape("f32[8,1,210,6]"));
  const auto lhs_with_halos = AllOf(
      op::Select(op::And(_, _), op::Concatenate(halo, lhs_shard, halo),
                 op::Broadcast(_)),
      op::Shape("f32[8,107,210,6]"));

  const auto rhs_shard = AllOf(
      op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(),
                                op::Reshape(), op::Constant())),
      op::Shape("f32[3,3,6,32]"));

  EXPECT_THAT(root, AllOf(op::AllReduce(op::Convolution(
                              lhs_with_halos, op::CollectivePermute(rhs_shard))),
                          op::Shape("f32[8,105,210,32]")));
}
7086
TEST_F(SpmdPartitioningTest, Fft3D) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  constant = c64[1,1,6]
    constant({{{(0,0),(1,1),(2,2),(3,3),(4,4),(5,5)}}}),
    sharding={devices=[1,1,2]0,1}
  ROOT fft = c64[1,1,6] fft(c64[1,1,6] constant), fft_type=FFT, fft_length={6},
    sharding={devices=[1,1,2]0,1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // Local half of the constant on the FFT dim.
  const auto local_input = AllOf(op::DynamicSlice(op::Constant(), op::Constant(),
                                                  op::Constant(), op::Reshape()),
                                 op::Shape("c64[1,1,3]"));
  // The local half is extended with data from the neighbor, then re-sliced.
  const auto padded_local_input =
      AllOf(op::DynamicSlice(
                op::Concatenate(local_input,
                                op::CollectivePermute(op::Slice())),
                op::Constant(), op::Constant(), op::Reshape()),
            op::Shape("c64[1,1,4]"));
  // Data is shuffled across partitions via a dot followed by an all-to-all.
  const auto shuffled =
      AllOf(op::Slice(op::AllToAll(op::Dot(padded_local_input, op::Convert()))),
            op::Shape("c64[1,1,3]"));
  const auto local_fft = AllOf(op::Fft(shuffled), op::Shape("c64[1,1,3]"));

  // The final combination runs inside a while loop whose carried tuple holds
  // the local FFT scaled by an exponential term.
  const auto loop_init =
      op::Tuple(_, op::Multiply(local_fft, op::Exp()), _, _, _);
  EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(loop_init)),
                          op::Shape("c64[1,1,3]")));
}
7123
TEST_F(SpmdPartitioningTest, DotInputsAreIdentical) {
  // The same parameter is used as both the lhs and the rhs of a dot-like
  // convolution.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %parameter.1 = f32[4000,4000]{1,0} parameter(0),
    sharding={devices=[2,4]0,1,2,3,4,5,6,7}
  ROOT %convolution = f32[4000,4000]{1,0} convolution(
    f32[4000,4000]{1,0} %parameter.1, f32[4000,4000]{1,0} %parameter.1),
    dim_labels=bf_io->bf, sharding={devices=[2,4]0,1,2,3,4,5,6,7}
}

)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  const auto local_param = AllOf(op::Parameter(), op::Shape("f32[2000, 1000]"));

  // Each side is rebuilt from the shared shard by a dynamic-update-slice into
  // a larger buffer followed by an all-reduce; the rhs goes through a copy.
  const auto lhs_full_rows =
      AllOf(op::AllReduce(op::DynamicUpdateSlice(_, local_param, _, _)),
            op::Shape("f32[2000, 4000]"));
  const auto rhs_full_cols = AllOf(
      op::AllReduce(op::DynamicUpdateSlice(_, op::Copy(local_param), _, _)),
      op::Shape("f32[4000, 1000]"));

  EXPECT_THAT(root, AllOf(op::Convolution(lhs_full_rows, rhs_full_cols),
                          op::Shape("f32[2000, 1000]")));
}
7152
TEST_F(SpmdPartitioningTest, ConstantSliceReshard) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %constant.785 = f32[1,8] constant({{0,1,2,3,4,5,6,7}}),
    sharding={devices=[1,8]0,1,2,3,4,5,6,7}
  %slice.62 = f32[1,1] slice(%constant.785), slice={[0:1], [0:1]},
    sharding={devices=[1,8]0,1,2,3,4,5,6,7}
  ROOT %reshape.779 = f32[] reshape(%slice.62), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The local slice of the constant is masked by a select and then
  // all-reduced into the replicated result.
  const auto local_slice =
      AllOf(op::Shape("f32[1,1]"),
            op::Copy(op::DynamicSlice(op::Constant(), _, _)));
  const auto masked_slice = op::Select(_, local_slice, _);
  EXPECT_THAT(root, op::Reshape(op::AllReduce(masked_slice)));
}
7172
TEST_F(SpmdPartitioningTest, GatherParallelDimRedistributionOperand) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
  %constant = s32[4] constant({0, 1, 2, 3}), sharding={replicated}
  %iota = s32[1,8,4]{2,1,0} broadcast(%constant), dimensions={2},
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The operand is resliced, the indices are rebased with a subtract, and the
  // local gather result is written into a full buffer and all-reduced.
  const auto operand_shard =
      AllOf(op::Shape("s32[1,4,2,2]"), op::DynamicSlice());
  const auto local_indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
  const auto local_gather = AllOf(op::Shape("s32[1,4,2,2]"),
                                  op::Gather(operand_shard, local_indices));
  const auto scattered_back =
      op::DynamicUpdateSlice(_, local_gather, _, _, _, _);
  EXPECT_THAT(root, op::AllReduce(scattered_back));
}
7204
TEST_F(SpmdPartitioningTest, GatherParallelDimRedistributionIndices) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,4,2]0,1,2,3,4,5,6,7}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  const auto operand_shard =
      AllOf(op::Shape("s32[2,2,2,2]"), op::DynamicSlice());
  const auto local_indices = AllOf(op::Shape("s32[2,2,2]"), op::Subtract());
  const auto local_gather = AllOf(op::Shape("s32[2,2,2,2]"),
                                  op::Gather(operand_shard, local_indices));
  const auto scattered_back =
      op::DynamicUpdateSlice(_, local_gather, _, _, _, _);
  // Two nested all-reduces are expected on top of the update.
  EXPECT_THAT(root, op::AllReduce(op::AllReduce(scattered_back)));
}
7234
TEST_F(SpmdPartitioningTest, GatherParallelDimReplicatedIndices) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={replicated}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={replicated}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={replicated}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The operand parameter is used directly as the local shard; only the
  // indices are rebased.
  const auto operand_shard = AllOf(op::Shape("s32[1,4,2,2]"), op::Parameter());
  const auto local_indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
  const auto local_gather = AllOf(op::Shape("s32[1,4,2,2]"),
                                  op::Gather(operand_shard, local_indices));
  EXPECT_THAT(root, op::AllReduce(
                        op::DynamicUpdateSlice(_, local_gather, _, _, _, _)));
}
7265
TEST_F(SpmdPartitioningTest, GatherParallelDimReplicatedOperand) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={replicated}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The replicated operand is dynamic-sliced down to the local shard.
  const auto operand_shard =
      AllOf(op::Shape("s32[1,4,2,2]"), op::DynamicSlice());
  const auto local_indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
  const auto local_gather = AllOf(op::Shape("s32[1,4,2,2]"),
                                  op::Gather(operand_shard, local_indices));
  EXPECT_THAT(root, op::AllReduce(
                        op::DynamicUpdateSlice(_, local_gather, _, _, _, _)));
}
7295
TEST_F(SpmdPartitioningTest, GatherParallelDimPartialReplicatedIndices) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[8,1,1,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  const auto operand_shard = AllOf(op::Shape("s32[1,4,2,2]"), op::Parameter());
  const auto local_indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
  const auto local_gather = AllOf(op::Shape("s32[1,4,2,2]"),
                                  op::Gather(operand_shard, local_indices));
  // The local result is written into a full buffer and all-reduced.
  EXPECT_THAT(root, op::AllReduce(
                        op::DynamicUpdateSlice(_, local_gather, _, _, _, _)));
}
7326
TEST_F(SpmdPartitioningTest, GatherParallelDimPartialReplicatedOperand) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={
    devices=[2,1,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,8,1]0,1,2,3,4,5,6,7}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The partially replicated operand is further dynamic-sliced locally.
  const auto operand_shard =
      AllOf(op::Shape("s32[1,4,2,2]"), op::DynamicSlice());
  const auto local_indices = AllOf(op::Shape("s32[2,1,4]"), op::Subtract());
  const auto local_gather = AllOf(op::Shape("s32[1,4,2,2]"),
                                  op::Gather(operand_shard, local_indices));
  EXPECT_THAT(root, op::AllReduce(
                        op::DynamicUpdateSlice(_, local_gather, _, _, _, _)));
}
7357
TEST_F(SpmdPartitioningTest, GatherParallelDimSwappedDimensions) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0), sharding={
    devices=[4,2,1,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,2,4]0,1,2,3,4,5,6,7}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The operand tiles are moved between devices with a collective-permute
  // so they align with the index tiling.
  const auto operand_shard =
      AllOf(op::Shape("s32[4,1,2,2]"), op::CollectivePermute());
  const auto local_indices = AllOf(op::Shape("s32[2,4,1]"), op::Subtract());
  const auto local_gather = AllOf(op::Shape("s32[4,1,2,2]"),
                                  op::Gather(operand_shard, local_indices));
  const auto scattered_back =
      op::DynamicUpdateSlice(_, local_gather, _, _, _, _);
  EXPECT_THAT(root, op::AllReduce(op::AllReduce(scattered_back)));
}
7388
TEST_F(SpmdPartitioningTest, GatherMergedParalleIndexPassthrough) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The local gather keeps the operand's tiling on the offset dim (size 1 of
  // 2 in dim 2 of the result).
  const auto operand_shard =
      AllOf(op::Shape("s32[2,4,1,2]"), op::DynamicSlice());
  const auto local_indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
  const auto local_gather = AllOf(op::Shape("s32[2,4,1,2]"),
                                  op::Gather(operand_shard, local_indices));
  const auto scattered_back =
      op::DynamicUpdateSlice(_, local_gather, _, _, _, _);
  EXPECT_THAT(root, op::AllReduce(op::AllReduce(scattered_back)));
}
7419
TEST_F(SpmdPartitioningTest, GatherParalleIndexAndOperand) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The output sharding matches the operand sharding, so the local gather
  // is itself the root with no collective around it.
  const auto operand_shard =
      AllOf(op::Shape("s32[2,4,1,2]"), op::Parameter(0));
  const auto local_indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
  EXPECT_THAT(root, AllOf(op::Shape("s32[2,4,1,2]"),
                          op::Gather(operand_shard, local_indices)));
}
7449
TEST_F(SpmdPartitioningTest, GatherReshardParalleIndexAndOperand) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]1,0,3,2,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  const auto operand_shard =
      AllOf(op::Shape("s32[2,4,1,2]"), op::Parameter(0));
  const auto local_indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
  const auto local_gather = AllOf(op::Shape("s32[2,4,1,2]"),
                                  op::Gather(operand_shard, local_indices));
  // The output device order differs from the operand's, so the local result
  // is moved with a collective-permute.
  EXPECT_THAT(root, op::CollectivePermute(local_gather));
}
7479
TEST_F(SpmdPartitioningTest, GatherParalleIndexAndOperandReshard) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[4,1,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota = s32[1,8,4]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota2 = s32[1,8,4]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,4]{2,1,0} concatenate(s32[1,8,4]{2,1,0} %iota,
    s32[1,8,4]{2,1,0} %iota2), dimensions={0},
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,4,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,4]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={devices=[4,1,2,1]0,1,2,3,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));
  VLOG(1) << module->ToString();
  auto* root = module->entry_computation()->root_instruction();

  // The operand is only partially tiled, so the local gather produces the
  // full dim 2. The result is then dynamic-sliced to the output tiling.
  const auto operand_shard =
      AllOf(op::Shape("s32[2,4,2,2]"), op::Parameter(0));
  const auto local_indices = AllOf(op::Shape("s32[2,2,4]"), op::Subtract());
  const auto local_gather = AllOf(op::Shape("s32[2,4,2,2]"),
                                  op::Gather(operand_shard, local_indices));
  EXPECT_THAT(root, op::DynamicSlice(local_gather, _, _, _, _));
}
7509
TEST_F(SpmdPartitioningTest, GatherMergedParallelIndexTrivialSlice) {
  // Gather with a replicated output where one index component comes from an
  // iota and the other from a parameter.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[8,4,2,2]{3,2,1,0} parameter(0),
    sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7}
  %parameter.1 = s32[1,8,1]{2,1,0} parameter(1),
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %iota = s32[1,8,1]{2,1,0} iota(), iota_dimension=1,
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %concatenate.19 = s32[2,8,1]{2,1,0} concatenate(
    s32[1,8,1]{2,1,0} %parameter.1, s32[1,8,1]{2,1,0} %iota), dimensions={0},
    sharding={devices=[1,4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %gather.20 = s32[8,1,2,2]{3,2,1,0} gather(
    s32[8,4,2,2]{3,2,1,0} %parameter.0,
    s32[2,8,1]{2,1,0} %concatenate.19), offset_dims={2,3},
    collapsed_slice_dims={0,1}, start_index_map={1,0}, index_vector_dim=0,
    slice_sizes={1,1,2,2}, sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));

  // Per-device pieces of the partitioned gather.
  auto operand_shard = AllOf(op::Shape("s32[2,2,2,2]"), op::Parameter());
  auto index_shard = AllOf(op::Shape("s32[2,2,1]"), op::Subtract());
  auto gathered = AllOf(op::Shape("s32[2,1,2,2]"),
                        op::Gather(operand_shard, index_shard));
  // The local result is selected, all-reduced, written with a
  // dynamic-update-slice and all-reduced again to form the replicated output.
  auto masked_sum = op::AllReduce(op::Select(_, _, gathered));
  auto replicated_result =
      op::AllReduce(op::DynamicUpdateSlice(_, masked_sum, _, _, _, _));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              replicated_result);
}
7540
TEST_F(SpmdPartitioningTest, SortTopKNonSortDimension) {
  // A stable sort along dimension 2 whose inputs are split along dimension 0
  // (2 ways) and dimension 2 (4 ways), followed by a top-2 style slice. The
  // partitioned module must contain an instruction named "sort" that sorts
  // operands of the full [2,64,32128] shape.
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.42077 (p.0.lhs.42078: f32[],
  p.0.rhs.42079: f32[], p.1.lhs.42080: s32[], p.1.rhs.42081: s32[]) -> pred[] {
  %p.0.lhs.42078 = f32[] parameter(0)
  %bitcast-convert.135 = s32[] bitcast-convert(f32[] %p.0.lhs.42078)
  %constant.45054 = s32[] constant(0)
  %compare.133 = pred[] compare(s32[] %bitcast-convert.135,
    s32[] %constant.45054), direction=LT
  %constant.45278 = u32[] constant(2147483647)
  %bitcast-convert.136 = u32[] bitcast-convert(f32[] %p.0.lhs.42078)
  %subtract.337 = u32[] subtract(u32[] %constant.45278,
    u32[] %bitcast-convert.136)
  %bitcast-convert.137 = s32[] bitcast-convert(u32[] %subtract.337)
  %select.282 = s32[] select(pred[] %compare.133, s32[] %bitcast-convert.137,
    s32[] %bitcast-convert.135)
  %p.0.rhs.42079 = f32[] parameter(1)
  %bitcast-convert.138 = s32[] bitcast-convert(f32[] %p.0.rhs.42079)
  %compare.134 = pred[] compare(s32[] %bitcast-convert.138,
    s32[] %constant.45054), direction=LT
  %bitcast-convert.139 = u32[] bitcast-convert(f32[] %p.0.rhs.42079)
  %subtract.338 = u32[] subtract(u32[] %constant.45278,
    u32[] %bitcast-convert.139)
  %bitcast-convert.140 = s32[] bitcast-convert(u32[] %subtract.338)
  %select.283 = s32[] select(pred[] %compare.134, s32[] %bitcast-convert.140,
    s32[] %bitcast-convert.138)
  %compare.135 = pred[] compare(s32[] %select.282,
    s32[] %select.283), direction=GT
  %compare.428 = pred[] compare(s32[] %select.283,
    s32[] %select.282), direction=GT
  %compare.429 = pred[] compare(pred[] %compare.135,
    pred[] %compare.428), direction=EQ
  %p.1.lhs.42080 = s32[] parameter(2)
  %p.1.rhs.42081 = s32[] parameter(3)
  %compare.430 = pred[] compare(s32[] %p.1.lhs.42080,
    s32[] %p.1.rhs.42081), direction=LT
  ROOT %select.579 = pred[] select(pred[] %compare.429,
    pred[] %compare.430, pred[] %compare.135)
}

ENTRY %module {
  %parameter.0 = f32[2,64,32128]{2,1,0} parameter(0),
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  %iota = s32[2,64,32128]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  %sort.18 = (f32[2,64,32128]{2,1,0}, s32[2,64,32128]{2,1,0}) sort(
    f32[2,64,32128]{2,1,0} %parameter.0, s32[2,64,32128]{2,1,0} %iota),
    dimensions={2}, is_stable=true, to_apply=%compare-greater-than.42077,
    sharding={{devices=[2,1,4]0,1,2,3,4,5,6,7},
    {devices=[2,1,4]0,1,2,3,4,5,6,7}}
  output = f32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=0,
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  %slice.0 = f32[2,64,2]{2,1,0} slice(f32[2,64,32128]{2,1,0} output),
    slice={[0:2], [0:64], [0:2]},
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  output2 = s32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=1,
    sharding={replicated}
  %slice.1 = s32[2,64,2]{2,1,0} slice(s32[2,64,32128]{2,1,0} output2),
    slice={[0:2], [0:64], [0:2]},
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7}
  ROOT output.t = (f32[2,64,2]{2,1,0},
    s32[2,64,2]{2,1,0}) tuple(slice.0, slice.1),
    sharding={{replicated}, {replicated}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));

  const HloInstruction* sort = FindInstruction(module.get(), "sort");
  // ASSERT rather than EXPECT: the shape matcher below must not be applied to
  // a null instruction, so a missing sort has to end the test here.
  ASSERT_NE(sort, nullptr);
  auto sort_match =
      AllOf(op::Shape("(f32[2,64,32128], s32[2,64,32128])"), op::Sort(_, _));
  EXPECT_THAT(sort, sort_match);
}
7616
TEST_F(SpmdPartitioningTest, SortTopKPropagateBaseShape) {
  // Same top-2 sort pattern as above, but with the inputs split 8 ways along
  // the sort dimension only. Each tuple element of the replicated result is
  // expected to be a slice of an all-reduced dynamic-update-slice.
  absl::string_view hlo_string = R"(
HloModule module

%compare-greater-than.42077 (p.0.lhs.42078: f32[],
  p.0.rhs.42079: f32[], p.1.lhs.42080: s32[], p.1.rhs.42081: s32[]) -> pred[] {
  %p.0.lhs.42078 = f32[] parameter(0)
  %bitcast-convert.135 = s32[] bitcast-convert(f32[] %p.0.lhs.42078)
  %constant.45054 = s32[] constant(0)
  %compare.133 = pred[] compare(s32[] %bitcast-convert.135,
    s32[] %constant.45054), direction=LT
  %constant.45278 = u32[] constant(2147483647)
  %bitcast-convert.136 = u32[] bitcast-convert(f32[] %p.0.lhs.42078)
  %subtract.337 = u32[] subtract(u32[] %constant.45278,
    u32[] %bitcast-convert.136)
  %bitcast-convert.137 = s32[] bitcast-convert(u32[] %subtract.337)
  %select.282 = s32[] select(pred[] %compare.133, s32[] %bitcast-convert.137,
    s32[] %bitcast-convert.135)
  %p.0.rhs.42079 = f32[] parameter(1)
  %bitcast-convert.138 = s32[] bitcast-convert(f32[] %p.0.rhs.42079)
  %compare.134 = pred[] compare(s32[] %bitcast-convert.138,
    s32[] %constant.45054), direction=LT
  %bitcast-convert.139 = u32[] bitcast-convert(f32[] %p.0.rhs.42079)
  %subtract.338 = u32[] subtract(u32[] %constant.45278,
    u32[] %bitcast-convert.139)
  %bitcast-convert.140 = s32[] bitcast-convert(u32[] %subtract.338)
  %select.283 = s32[] select(pred[] %compare.134, s32[] %bitcast-convert.140,
    s32[] %bitcast-convert.138)
  %compare.135 = pred[] compare(s32[] %select.282,
    s32[] %select.283), direction=GT
  %compare.428 = pred[] compare(s32[] %select.283,
    s32[] %select.282), direction=GT
  %compare.429 = pred[] compare(pred[] %compare.135,
    pred[] %compare.428), direction=EQ
  %p.1.lhs.42080 = s32[] parameter(2)
  %p.1.rhs.42081 = s32[] parameter(3)
  %compare.430 = pred[] compare(s32[] %p.1.lhs.42080,
    s32[] %p.1.rhs.42081), direction=LT
  ROOT %select.579 = pred[] select(pred[] %compare.429,
    pred[] %compare.430, pred[] %compare.135)
}

ENTRY %module {
  %parameter.0 = f32[2,64,32128]{2,1,0} parameter(0),
    sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
  %iota = s32[2,64,32128]{2,1,0} iota(), iota_dimension=2,
    sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
  %sort.18 = (f32[2,64,32128]{2,1,0}, s32[2,64,32128]{2,1,0}) sort(
    f32[2,64,32128]{2,1,0} %parameter.0, s32[2,64,32128]{2,1,0} %iota),
    dimensions={2}, is_stable=true, to_apply=%compare-greater-than.42077,
    sharding={{devices=[1,1,8]0,1,2,3,4,5,6,7},
    {devices=[1,1,8]0,1,2,3,4,5,6,7}}
  output = f32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=0,
    sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
  %slice.0 = f32[2,64,2]{2,1,0} slice(f32[2,64,32128]{2,1,0} output),
    slice={[0:2], [0:64], [0:2]},
    sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
  output2 = s32[2,64,32128]{2,1,0} get-tuple-element(%sort.18), index=1,
    sharding={replicated}
  %slice.1 = s32[2,64,2]{2,1,0} slice(s32[2,64,32128]{2,1,0} output2),
    slice={[0:2], [0:64], [0:2]},
    sharding={devices=[1,1,8]0,1,2,3,4,5,6,7}
  ROOT output.t = (f32[2,64,2]{2,1,0},
    s32[2,64,2]{2,1,0}) tuple(slice.0, slice.1),
    sharding={{replicated}, {replicated}}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));

  // Both the values and the indices are sliced from the same kind of
  // all-reduced dynamic-update-slice.
  auto all_reduced_update =
      op::AllReduce(op::DynamicUpdateSlice(_, _, _, _, _));
  auto topk_values =
      AllOf(op::Shape("f32[2,64,2]"), op::Slice(all_reduced_update));
  auto topk_indices =
      AllOf(op::Shape("s32[2,64,2]"), op::Slice(all_reduced_update));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("(f32[2,64,2], s32[2,64,2])"),
                    op::Tuple(topk_values, topk_indices)));
}
7697
TEST_F(SpmdPartitioningTest, GatherIndexOnlyCorrectReplacement) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = bf16[1,8,6,6]{3,2,1,0} parameter(0),
    sharding={replicated}
  %parameter.1 = s32[2,4]{1,0} parameter(1),
    sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  %gather.100 = bf16[2,1,8,1,6]{4,3,2,1,0} gather(
    bf16[1,8,6,6]{3,2,1,0} %parameter.0, s32[2,4]{1,0} %parameter.1),
    offset_dims={1,2,3,4}, collapsed_slice_dims={}, start_index_map={0,1,2,3},
    index_vector_dim=1, slice_sizes={1,8,1,6},
    sharding={devices=[2,1,4,1,1]0,1,2,3,4,5,6,7}
  %constant.45590 = s32[] constant(0), sharding={replicated}
  %broadcast.54515 = s32[2,64,1,1]{3,2,1,0} broadcast(s32[] %constant.45590),
    dimensions={},
    sharding={devices=[2,1,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  ROOT %reshape.4243 = bf16[2,8,6]{2,1,0} reshape(
    bf16[2,1,8,1,6]{4,3,2,1,0} %gather.100),
    sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));

  // The operand stays replicated and only the indices are split, so each
  // device gathers from the full operand. It then dynamic-slices its portion
  // before the reshape that produces the [2,4,1]-sharded output.
  auto full_operand = AllOf(op::Shape("bf16[1,8,6,6]"), op::Parameter());
  auto index_shard = AllOf(op::Shape("s32[1,4]"), op::Parameter());
  auto local_gather = op::Gather(full_operand, index_shard);
  auto expected_root =
      AllOf(op::Shape("bf16[1,2,6]"),
            op::Reshape(op::DynamicSlice(local_gather, _, _, _, _, _)));
  EXPECT_THAT(module->entry_computation()->root_instruction(), expected_root);
}
7731
TEST_F(SpmdPartitioningTest, GatherRegressionTest1) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = s32[1,4] parameter(0), sharding={devices=[1,8]0,1,2,3,4,5,6,7}
  %iota.10 = s32[4]{0} iota(), iota_dimension=0, sharding={devices=[8]0,1,2,3,4,5,6,7}
  ROOT %gather.44 = s32[1,4]{1,0} gather(%parameter.0, %iota.10),
    offset_dims={0}, collapsed_slice_dims={1}, start_index_map={1}, index_vector_dim=1,
    slice_sizes={1,1}, sharding={devices=[1,8]0,1,2,3,4,5,6,7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/8));

  // A size-4 dimension split 8 ways leaves each device with a [1,1] operand
  // shard. The root must still be a gather that reads that shard directly.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Gather(AllOf(op::Shape("s32[1,1]"), op::Parameter()), _));
}
7750
TEST_F(SpmdPartitioningTest, WindowedEinsumPreferMemoryFootprint) {
  // With choose_faster_windowed_einsum=false, the windowed-einsum while loop
  // for this convolution is expected to have a constant loop bound of 4.
  // WindowedEinsumPreferNumberIterations uses the same module with the flag
  // set and expects 2.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} parameter(0),
    sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
  %parameter.1 = bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} parameter(1),
    sharding={devices=[2,2,1,2,1,1,1]0,1,2,3,4,5,6,7}
  %convolution.3 = bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0}
    convolution(bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} %parameter.0,
    bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} %parameter.1),
    window={size=1x4x176x4x4 pad=0_0x3_3x175_175x0_0x0_0
    rhs_reversal=0x1x1x0x0}, dim_labels=0b34f12_34i12o0->0b12f34,
    sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
  ROOT %reshape.3973 = bf16[128,1024,4,176,256]{4,3,2,1,0}
    reshape(bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0} %convolution.3),
    sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/8,
                           /*conv_halo_exchange_always_on_lhs=*/true,
                           /*choose_faster_windowed_einsum=*/false));

  // Each check below guards a dereference on the next line, so the checks are
  // fatal (ASSERT_*). A failed check ends the test instead of crashing on a
  // null pointer, a failed Cast, or an empty optional.
  const HloInstruction* while_inst = FindInstruction(module.get(), "while");
  ASSERT_NE(while_inst, nullptr);
  const HloInstruction* cond_root =
      while_inst->while_condition()->root_instruction();
  ASSERT_THAT(cond_root, op::Compare(_, op::Constant()));
  const auto iterations = Cast<HloConstantInstruction>(cond_root->operand(1))
                              ->literal()
                              .GetFirstInteger();
  ASSERT_TRUE(iterations.has_value());
  EXPECT_EQ(*iterations, 4);
}
7785
TEST_F(SpmdPartitioningTest, WindowedEinsumPreferNumberIterations) {
  // Same module as WindowedEinsumPreferMemoryFootprint, but with
  // choose_faster_windowed_einsum=true. The windowed-einsum while loop is
  // expected to have a constant loop bound of 2 instead of 4.
  absl::string_view hlo_string = R"(
HloModule module

ENTRY %module {
  %parameter.0 = bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} parameter(0),
    sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
  %parameter.1 = bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} parameter(1),
    sharding={devices=[2,2,1,2,1,1,1]0,1,2,3,4,5,6,7}
  %convolution.3 = bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0}
    convolution(bf16[128,1024,4,4,1152,1,1]{6,5,4,3,2,1,0} %parameter.0,
    bf16[4,4,1152,4,176,256,1]{6,5,4,3,2,1,0} %parameter.1),
    window={size=1x4x176x4x4 pad=0_0x3_3x175_175x0_0x0_0
    rhs_reversal=0x1x1x0x0}, dim_labels=0b34f12_34i12o0->0b12f34,
    sharding={devices=[4,1,2,1,1,1,1]0,1,2,3,4,5,6,7}
  ROOT %reshape.3973 = bf16[128,1024,4,176,256]{4,3,2,1,0}
    reshape(bf16[128,1024,4,176,256,1,1]{6,5,4,3,2,1,0} %convolution.3),
    sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/8,
                           /*conv_halo_exchange_always_on_lhs=*/true,
                           /*choose_faster_windowed_einsum=*/true));

  // Each check below guards a dereference on the next line, so the checks are
  // fatal (ASSERT_*). A failed check ends the test instead of crashing on a
  // null pointer, a failed Cast, or an empty optional.
  const HloInstruction* while_inst = FindInstruction(module.get(), "while");
  ASSERT_NE(while_inst, nullptr);
  const HloInstruction* cond_root =
      while_inst->while_condition()->root_instruction();
  ASSERT_THAT(cond_root, op::Compare(_, op::Constant()));
  const auto iterations = Cast<HloConstantInstruction>(cond_root->operand(1))
                              ->literal()
                              .GetFirstInteger();
  ASSERT_TRUE(iterations.has_value());
  EXPECT_EQ(*iterations, 2);
}
7820
TEST_F(SpmdPartitioningTest, WindowedEinsumPreferNumberIterations2) {
  // 32-device variant with choose_faster_windowed_einsum=true. The
  // windowed-einsum while loop is expected to have a constant loop bound of 4
  // (WindowedEinsumPreferMemoryFootprint2 expects 8 with the flag cleared).
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = bf16[512,1024,16,36,256]{4,3,2,1,0} parameter(0)
  %lhs.copy = bf16[512,1024,16,36,256]{4,3,2,1,0} copy(%lhs),
    sharding={devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,
    18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  %rhs = bf16[512,1024,16,4,288]{4,3,2,1,0} parameter(1)
  %rhs.copy = bf16[512,1024,16,4,288]{4,3,2,1,0} copy(%rhs),
    sharding={devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,
    17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  %reshape.2556 = bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} reshape(
    bf16[512,1024,16,4,288]{4,3,2,1,0} %rhs.copy), sharding={
    devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
    20,21,22,23,24,25,26,27,28,29,30,31}
  %reshape.2570 = bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0}
    reshape(bf16[512,1024,16,36,256]{4,3,2,1,0} %lhs.copy), sharding={
    devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
    20,21,22,23,24,25,26,27,28,29,30,31}
  %convolution.10 = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
    convolution(bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0} %reshape.2570,
    bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} %reshape.2556),
    window={size=1x1x16x4x512 pad=0_0x0_0x15_15x3_3x0_0 rhs_reversal=0x0x1x1x0},
    dim_labels=4f01b23_4i23o01->01b23f4, sharding={devices=[4,1,1,4,2,1,1]0,4,8,
    12,16,20,24,28,1,5,9,13,17,21,25,29,2,6,10,14,18,22,26,30,3,7,11,15,19,23,
    27,31}
  ROOT %output = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
    copy(%convolution.10), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/32,
                           /*conv_halo_exchange_always_on_lhs=*/true,
                           /*choose_faster_windowed_einsum=*/true));

  // Each check below guards a dereference on the next line, so the checks are
  // fatal (ASSERT_*). A failed check ends the test instead of crashing on a
  // null pointer, a failed Cast, or an empty optional.
  const HloInstruction* while_inst = FindInstruction(module.get(), "while");
  ASSERT_NE(while_inst, nullptr);
  const HloInstruction* cond_root =
      while_inst->while_condition()->root_instruction();
  ASSERT_THAT(cond_root, op::Compare(_, op::Constant()));
  const auto iterations = Cast<HloConstantInstruction>(cond_root->operand(1))
                              ->literal()
                              .GetFirstInteger();
  ASSERT_TRUE(iterations.has_value());
  EXPECT_EQ(*iterations, 4);
}
7867
TEST_F(SpmdPartitioningTest, WindowedEinsumPreferMemoryFootprint2) {
  // 32-device variant with choose_faster_windowed_einsum=false. The
  // windowed-einsum while loop is expected to have a constant loop bound of 8
  // (WindowedEinsumPreferNumberIterations2 expects 4 with the flag set).
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = bf16[512,1024,16,36,256]{4,3,2,1,0} parameter(0)
  %lhs.copy = bf16[512,1024,16,36,256]{4,3,2,1,0} copy(%lhs),
    sharding={devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,
    18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  %rhs = bf16[512,1024,16,4,288]{4,3,2,1,0} parameter(1)
  %rhs.copy = bf16[512,1024,16,4,288]{4,3,2,1,0} copy(%rhs),
    sharding={devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,
    17,18,19,20,21,22,23,24,25,26,27,28,29,30,31}
  %reshape.2556 = bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} reshape(
    bf16[512,1024,16,4,288]{4,3,2,1,0} %rhs.copy), sharding={
    devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
    20,21,22,23,24,25,26,27,28,29,30,31}
  %reshape.2570 = bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0}
    reshape(bf16[512,1024,16,36,256]{4,3,2,1,0} %lhs.copy), sharding={
    devices=[8,1,4,1,1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,
    20,21,22,23,24,25,26,27,28,29,30,31}
  %convolution.10 = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
    convolution(bf16[512,1024,16,36,256,1,1]{6,5,4,3,2,1,0} %reshape.2570,
    bf16[512,1024,16,4,288,1,1]{6,5,4,3,2,1,0} %reshape.2556),
    window={size=1x1x16x4x512 pad=0_0x0_0x15_15x3_3x0_0 rhs_reversal=0x0x1x1x0},
    dim_labels=4f01b23_4i23o01->01b23f4, sharding={devices=[4,1,1,4,2,1,1]0,4,8,
    12,16,20,24,28,1,5,9,13,17,21,25,29,2,6,10,14,18,22,26,30,3,7,11,15,19,23,
    27,31}
  ROOT %output = bf16[16,36,256,16,4,288,1]{6,5,4,3,2,1,0}
    copy(%convolution.10), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/32,
                           /*conv_halo_exchange_always_on_lhs=*/true,
                           /*choose_faster_windowed_einsum=*/false));

  // Each check below guards a dereference on the next line, so the checks are
  // fatal (ASSERT_*). A failed check ends the test instead of crashing on a
  // null pointer, a failed Cast, or an empty optional.
  const HloInstruction* while_inst = FindInstruction(module.get(), "while");
  ASSERT_NE(while_inst, nullptr);
  const HloInstruction* cond_root =
      while_inst->while_condition()->root_instruction();
  ASSERT_THAT(cond_root, op::Compare(_, op::Constant()));
  const auto iterations = Cast<HloConstantInstruction>(cond_root->operand(1))
                              ->literal()
                              .GetFirstInteger();
  ASSERT_TRUE(iterations.has_value());
  EXPECT_EQ(*iterations, 8);
}
7914
TEST_F(SpmdPartitioningTest, ContractingPartitionDotOperandsSlicedWrong) {
  // Dot with two contracting dimensions whose operands are sharded
  // differently from the output. The partitioned dot ("dot.1") must see
  // operands with the expected per-device shapes.
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[8,2,15,4] parameter(0)
  %lhs.copy = f32[8,2,15,4] copy(%lhs),
    sharding={devices=[1,2,4,1]0,1,2,3,4,5,6,7}
  %rhs = f32[2,15,4] parameter(1)
  %rhs.copy = f32[2,15,4] copy(%rhs),
    sharding={devices=[2,4,1]0,1,2,3,4,5,6,7}
  %dot = f32[8,2,2] dot(%lhs.copy, %rhs.copy),
    lhs_batch_dims={}, rhs_batch_dims={},
    lhs_contracting_dims={2,3}, rhs_contracting_dims={1,2},
    operand_precision={HIGH,HIGH},
    sharding={devices=[2,2,2]0,1,2,3,4,5,6,7}
  ROOT %output = f32[8,2,2] copy(%dot), sharding={replicated}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      PartitionComputation(hlo_string, /*num_devices=*/8,
                           /*conv_halo_exchange_always_on_lhs=*/true,
                           /*choose_faster_windowed_einsum=*/true));

  const HloInstruction* partitioned_dot =
      FindInstruction(module.get(), "dot.1");
  EXPECT_THAT(partitioned_dot, op::Dot(op::Shape("f32[4,2,4,4]"),
                                       op::Shape("f32[2,4,4]")));
}
7944
TEST_F(SpmdPartitioningTest, PartitionDotGroupOnBatchContractingReshard) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[32,32,24,4096] parameter(0),
    sharding={devices=[2,1,1,2]0,1,2,3}
  %rhs = f32[32,4096,1024] parameter(1),
    sharding={devices=[2,2,1]0,1,2,3}
  ROOT %dot = f32[32,32,24,1024] dot(%lhs, %rhs),
    lhs_batch_dims={0}, rhs_batch_dims={0},
    lhs_contracting_dims={3}, rhs_contracting_dims={1},
    sharding={devices=[1,2,1,2]0,1,2,3}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  // Local dot grouped on the batch dimension, then an all-reduce plus
  // dynamic-slice over the contracting-dimension partial sums.
  auto local_dot = AllOf(op::Shape("f32[16,32,24,1024]"),
                         op::Dot(op::Parameter(0), op::Parameter(1)));
  auto reduce_scattered =
      AllOf(op::Shape("f32[16,32,24,512]"),
            op::DynamicSlice(op::AllReduce(local_dot), _, _, _, _));
  // A reshape / all-to-all / transpose / reshape sequence moves the
  // dimension-0 split onto dimension 1, as the output sharding requires.
  auto resharded =
      op::Reshape(op::Transpose(op::AllToAll(op::Reshape(reduce_scattered))));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(resharded, op::Shape("f32[32,16,24,512]")));
}
7972
TEST_F(SpmdPartitioningTest, PartitionPassthroughScatterCorrectOutputSharding) {
  absl::string_view hlo_string = R"(
HloModule module

%scatter_add (parameter.0: bf16[], parameter.1: bf16[]) -> bf16[] {
  %parameter.0 = bf16[] parameter(0)
  %parameter.1 = bf16[] parameter(1)
  ROOT %add = bf16[] add(bf16[] %parameter.0, bf16[] %parameter.1)
}

ENTRY entry {
  %operand = bf16[2,1024]{1,0} parameter(0),
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
  %indices = s32[8,512,1]{2,1,0} parameter(1),
    sharding={devices=[2,1,1,2]0,2,1,3 last_tile_dim_replicate}
  %updates = bf16[8,512,1024]{2,1,0} parameter(2),
    sharding={devices=[2,1,2]0,2,1,3}
  ROOT %scatter = bf16[2,1024]{1,0} scatter(bf16[2,1024]{1,0} %operand,
    s32[8,512,1]{2,1,0} %indices,
    bf16[8,512,1024]{2,1,0} %updates), update_window_dims={2},
    inserted_window_dims={0}, scatter_dims_to_operand_dims={0},
    index_vector_dim=2, to_apply=%scatter_add,
    sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << module->ToString();

  // The root remains a scatter producing the bf16[2,512] per-device piece of
  // the output (dimension 1 split 2 ways).
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Shape("bf16[2,512]"), op::Scatter(_, _, _)));
}
8005
IsTrivialCollectivePermute(HloInstruction * hlo)8006 bool IsTrivialCollectivePermute(HloInstruction* hlo) {
8007 if (hlo->opcode() != HloOpcode::kCollectivePermute) {
8008 return false;
8009 }
8010 if (hlo->source_target_pairs().empty()) {
8011 return true;
8012 }
8013 return absl::c_all_of(hlo->source_target_pairs(),
8014 [](const std::pair<int64, int64>& pair) {
8015 return pair.first == pair.second;
8016 });
8017 }
8018
TEST_F(SpmdPartitioningTest, CollectivePermuteSimplifyIdentity) {
  absl::string_view hlo_string = R"(
HloModule test

ENTRY entry {
  %parameter.7 = f32[3,16] parameter(0), sharding={devices=[1,2]0,1}
  %constant.7 = f32[] constant(0)
  %pad.3 = f32[3,18] pad(f32[3,16] %parameter.7, f32[] %constant.7), padding=0_0x1_1, sharding={devices=[1,2]0,1}
  // Shift right by 16.
  %slice.8 = f32[3,16] slice(f32[3,18] %pad.3), slice={[0:3], [2:18]}, sharding={devices=[1,2]0,1}
  %slice.9 = f32[3,2] slice(f32[3,18] %pad.3), slice={[0:3], [0:2]}, sharding={devices=[1,2]0,1}
  ROOT %concatenate.6 = f32[3,18] concatenate(f32[3,16] %slice.8, f32[3,2] %slice.9), dimensions={1}, sharding={devices=[1,2]0,1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // No computation may contain a collective-permute with empty or
  // self-to-self pairs; such an instruction would only copy data in place.
  for (HloComputation* comp : module->computations()) {
    for (HloInstruction* inst : comp->instructions()) {
      EXPECT_FALSE(IsTrivialCollectivePermute(inst)) << inst->ToString();
    }
  }
}
8046
TEST_F(SpmdPartitioningTest, CollectivePermuteSimplifyZero) {
  absl::string_view hlo_string = R"(
HloModule test

ENTRY entry {
  %parameter = f32[3,16,16,16,16,132]{5,4,3,2,1,0} parameter(0), sharding={devices=[1,2,1,1,1,1]0,1}
  %slice = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter), slice={[0:3], [15:16], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
  %c0 = f32[] constant(0)
  ROOT %pad = f32[3,18,16,16,16,132]{5,4,3,2,1,0} pad(f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice, f32[] %c0), padding=0_0x0_17x0_0x0_0x0_0x0_0, sharding={devices=[1,2,1,1,1,1]0,1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Reject any collective-permute whose source/target pair list is empty (the
  // case this test targets) or consists only of self-to-self pairs.
  for (HloComputation* comp : module->computations()) {
    for (HloInstruction* inst : comp->instructions()) {
      EXPECT_FALSE(IsTrivialCollectivePermute(inst)) << inst->ToString();
    }
  }
}
8071
TEST_F(SpmdPartitioningTest, PadWithWrapPattern) {
  absl::string_view hlo_string = R"(
HloModule xla_computation_apply_fn__4.61

ENTRY %xla_computation_apply_fn__4.61 (parameter.7: f32[3,16,16,16,16,132]) -> f32[3,18,16,16,16,132] {
  %parameter.7 = f32[3,16,16,16,16,132]{5,4,3,2,1,0} parameter(0), sharding={devices=[1,2,1,1,1,1]0,1}
  %slice.2 = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7), slice={[0:3], [15:16], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
  %slice.3 = f32[3,1,16,16,16,132]{5,4,3,2,1,0} slice(f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7), slice={[0:3], [0:1], [0:16], [0:16], [0:16], [0:132]}, sharding={devices=[1,2,1,1,1,1]0,1}
  ROOT %concatenate.3 = f32[3,18,16,16,16,132]{5,4,3,2,1,0} concatenate(f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice.2, f32[3,16,16,16,16,132]{5,4,3,2,1,0} %parameter.7, f32[3,1,16,16,16,132]{5,4,3,2,1,0} %slice.3), dimensions={1}, sharding={devices=[1,2,1,1,1,1]0,1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // The wrap-around concatenate must be partitioned without any all-reduce
  // and without degenerate (empty or self-to-self) collective-permutes.
  for (HloComputation* comp : module->computations()) {
    for (HloInstruction* inst : comp->instructions()) {
      EXPECT_FALSE(IsTrivialCollectivePermute(inst)) << inst->ToString();
      EXPECT_NE(inst->opcode(), HloOpcode::kAllReduce) << inst->ToString();
    }
  }
}
8097
TEST_F(SpmdPartitioningTest, PadWrapWithNegatePattern) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %parameter.1 = f32[1,18] parameter(0), sharding={devices=[1,2]0,1}
  %slice.16 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [16:18]}, sharding={devices=[1,2]0,1}
  %negate.2 = f32[1,2] negate(f32[1,2] %slice.16), sharding={devices=[1,2]0,1}
  %slice.17 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [0:2]}, sharding={devices=[1,2]0,1}
  %negate.3 = f32[1,2] negate(f32[1,2] %slice.17), sharding={devices=[1,2]0,1}
  ROOT %concatenate.13 = f32[1,22] concatenate(f32[1,2] %negate.2, f32[1,18] %parameter.1, f32[1,2] %negate.3), dimensions={1}, sharding={devices=[1,2]0,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Same wrap pattern as PadWithWrapPattern, with a negate applied to each
  // wrapped edge. Still no all-reduce and no degenerate collective-permute.
  for (HloComputation* comp : module->computations()) {
    for (HloInstruction* inst : comp->instructions()) {
      EXPECT_FALSE(IsTrivialCollectivePermute(inst)) << inst->ToString();
      EXPECT_NE(inst->opcode(), HloOpcode::kAllReduce) << inst->ToString();
    }
  }
}
8124
TEST_F(SpmdPartitioningTest, PadWrapWithMultipleModifiersPattern) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
  %parameter.1 = f32[1,18] parameter(0), sharding={devices=[1,2]0,1}
  %slice.16 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [16:18]}, sharding={devices=[1,2]0,1}
  %mod0.16 = f32[1,2] rsqrt(f32[1,2] %slice.16), sharding={devices=[1,2]0,1}
  %mod1.16 = f32[1,2] sine(f32[1,2] %mod0.16), sharding={devices=[1,2]0,1}
  %slice.17 = f32[1,2] slice(f32[1,18] %parameter.1), slice={[0:1], [0:2]}, sharding={devices=[1,2]0,1}
  %mod0.17 = f16[1,2] convert(f32[1,2] %slice.17), sharding={devices=[1,2]0,1}
  %mod1.17 = f16[1,2] cosine(f16[1,2] %mod0.17), sharding={devices=[1,2]0,1}
  %mod2.17 = f32[1,2] convert(f16[1,2] %mod1.17), sharding={devices=[1,2]0,1}
  ROOT %concatenate.13 = f32[1,22] concatenate(f32[1,2] %mod1.16, f32[1,18] %parameter.1, f32[1,2] %mod2.17), dimensions={1}, sharding={devices=[1,2]0,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  // Besides the usual "no all-reduce, no degenerate collective-permute"
  // checks, every unary modifier must still consume the modifier that feeds
  // it in the original chain:
  //   sine <- rsqrt,   f32 convert <- f16 cosine,   cosine <- f16 convert.
  for (HloComputation* comp : module->computations()) {
    for (HloInstruction* inst : comp->instructions()) {
      const HloOpcode opcode = inst->opcode();
      EXPECT_FALSE(IsTrivialCollectivePermute(inst)) << inst->ToString();
      EXPECT_NE(opcode, HloOpcode::kAllReduce) << inst->ToString();
      if (inst->operand_count() == 1) {
        const HloInstruction* input = inst->operand(0);
        switch (opcode) {
          case HloOpcode::kSin:
            EXPECT_EQ(input->opcode(), HloOpcode::kRsqrt);
            break;
          case HloOpcode::kConvert:
            if (inst->shape().element_type() == F32) {
              EXPECT_EQ(input->opcode(), HloOpcode::kCos);
              EXPECT_EQ(input->shape().element_type(), F16);
            }
            break;
          case HloOpcode::kCos:
            EXPECT_EQ(input->opcode(), HloOpcode::kConvert);
            EXPECT_EQ(input->shape().element_type(), F16);
            break;
          default:
            break;
        }
      }
    }
  }
}
8171
// Copies an f32[1,1] parameter tiled 2x2 across four devices into a
// replicated result.
// The partitioned root is expected to be a Copy of an AllReduce. That
// AllReduce consumes a Select whose second operand is the per-partition
// f32[1,1] parameter.
// NOTE(review): the Select presumably masks out partitions that do not hold
// real data before the AllReduce combines them. Confirm against the
// partitioner implementation.
TEST_F(SpmdPartitioningTest, BroadcastAsReplicate) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%param0 = f32[1,1] parameter(0), sharding={devices=[2,2]0,1,2,3}
ROOT %copy = f32[1,1] copy(%param0), sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // Per-partition view of the parameter.
  auto local_param = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
  // Select + AllReduce produce the replicated value, which is then copied.
  auto replicated = op::AllReduce(op::Select(_, local_param, _));
  auto expected_root = AllOf(op::Copy(replicated), op::Shape("f32[1,1]"));

  HloInstruction* root_instr =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root_instr, expected_root);
}
8190
// Copies an f32[1,2] parameter tiled 2x2 across four devices into a
// replicated result. Each partition holds an f32[1,1] shard.
// The expected root is built in two stages:
//   1. a Select + AllReduce that keeps the f32[1,1] shard shape;
//   2. a DynamicUpdateSlice of that result followed by another AllReduce,
//      whose Copy produces the full f32[1,2] value.
TEST_F(SpmdPartitioningTest, BroadcastAsReplicate2) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%param0 = f32[1,2] parameter(0), sharding={devices=[2,2]0,1,2,3}
ROOT %copy = f32[1,2] copy(%param0), sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // Per-partition shard of the parameter.
  auto local_param = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
  // Stage 1: Select + AllReduce, still shard-shaped.
  auto broadcast_shard = AllOf(op::AllReduce(op::Select(_, local_param, _)),
                               op::Shape("f32[1,1]"));
  // Stage 2: place the shard with a DynamicUpdateSlice and combine with an
  // AllReduce.
  auto assembled =
      op::AllReduce(op::DynamicUpdateSlice(_, broadcast_shard, _, _));
  auto expected_root = AllOf(op::Copy(assembled), op::Shape("f32[1,2]"));

  HloInstruction* root_instr =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root_instr, expected_root);
}
8213
// Same as BroadcastAsReplicate, but the f32[1,1] parameter uses a partially
// replicated sharding: tile assignment [2,1,2] with last_tile_dim_replicate.
// The expected partitioned root has the same shape: Copy(AllReduce(Select(_,
// param0, _))), where param0 has the per-partition shape f32[1,1].
TEST_F(SpmdPartitioningTest, BroadcastAsReplicate3) {
  absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%param0 = f32[1,1] parameter(0),
sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}
ROOT %copy = f32[1,1] copy(%param0), sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto partitioned,
                          PartitionComputation(hlo_string, /*num_devices=*/4));
  VLOG(1) << partitioned->ToString();

  // Per-partition view of the parameter.
  auto local_param = AllOf(op::Parameter(0), op::Shape("f32[1,1]"));
  // Select + AllReduce produce the replicated value, which is then copied.
  auto replicated = op::AllReduce(op::Select(_, local_param, _));
  auto expected_root = AllOf(op::Copy(replicated), op::Shape("f32[1,1]"));

  HloInstruction* root_instr =
      partitioned->entry_computation()->root_instruction();
  EXPECT_THAT(root_instr, expected_root);
}
8233
8234 } // namespace
8235 } // namespace spmd
8236 } // namespace xla
8237