1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/gpu_reduce_scatter_creator.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
19 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_parser.h"
24 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
25 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30
31 namespace xla {
32 namespace gpu {
33 namespace {
34
35 namespace op = xla::testing::opcode_matchers;
36
37 class GpuReduceScatterCreatorTest : public HloTestBase {
38 public:
RunPass(absl::string_view hlo_module,int64_t num_replicas,int64_t num_partitions,bool expect_change)39 StatusOr<std::unique_ptr<HloModule>> RunPass(absl::string_view hlo_module,
40 int64_t num_replicas,
41 int64_t num_partitions,
42 bool expect_change) {
43 HloModuleConfig config = GetModuleConfigForTest(
44 /*replica_count=*/num_replicas,
45 /*num_partitions=*/num_partitions);
46 config.set_use_spmd_partitioning(num_partitions > 1);
47 TF_ASSIGN_OR_RETURN(auto module,
48 ParseAndReturnVerifiedModule(hlo_module, config));
49 auto changed = ReduceScatterCreator().Run(module.get());
50 if (!changed.ok()) {
51 return changed.status();
52 }
53 EXPECT_EQ(changed.ValueOrDie(), expect_change);
54 return StatusOr<std::unique_ptr<HloModule>>(std::move(module));
55 }
56
AllReduceCount(std::unique_ptr<HloModule> & module)57 size_t AllReduceCount(std::unique_ptr<HloModule> &module) {
58 return absl::c_count_if(module->entry_computation()->instructions(),
59 [](const HloInstruction *inst) {
60 return inst->opcode() == HloOpcode::kAllReduce;
61 });
62 }
63 };
64
// All-reduce with empty replica_groups whose result is dynamic-sliced along
// dimension 0 at offset table[replica-id] * 4, with the table being the
// constant {0..7}. Run with 8 replicas and 1 partition, the pass is expected
// to produce a reduce-scatter of the parameter with scatter dimension 0 and to
// leave no all-reduce in the entry computation.
TEST_F(GpuReduceScatterCreatorTest, AllReplicas) {
  const absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={}, to_apply=%sum
  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
  %rid = u32[] replica-id()
  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%id)
  %slice_size = s32[] constant(4)
  %offset = s32[] multiply(%reshape, %slice_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
    dynamic_slice_sizes={4,8,128}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          RunPass(hlo_text, /*num_replicas=*/8,
                                  /*num_partitions=*/1,
                                  /*expect_change=*/true));
  const HloInstruction *root = module->entry_computation()->root_instruction();
  // Fatal on mismatch: the cast below requires a reduce-scatter root.
  ASSERT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  const auto *reduce_scatter = Cast<HloReduceScatterInstruction>(root);
  EXPECT_EQ(reduce_scatter->scatter_dimension(), 0)
      << reduce_scatter->ToString();
  EXPECT_EQ(AllReduceCount(module), 0);
}
102
// Same pattern as the plain all-replicas case, but a reshape
// (f32[32,8,128] -> f32[32,16,64]) sits between the all-reduce and the
// dynamic-slice on dimension 0. The expected result is a reshape of a
// reduce-scatter of the parameter, with no all-reduce left in the entry
// computation.
TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshape) {
  const absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={}, to_apply=%sum
  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
  %rid = u32[] replica-id()
  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%id)
  %slice_size = s32[] constant(4)
  %offset = s32[] multiply(%reshape, %slice_size)
  %zero = s32[] constant(0)
  %reshape.1 = f32[32,16,64] reshape(%all-reduce)
  ROOT %dynamic-slice = f32[4,16,64] dynamic-slice(%reshape.1, %offset, %zero, %zero),
    dynamic_slice_sizes={4,16,64}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          RunPass(hlo_text, /*num_replicas=*/8,
                                  /*num_partitions=*/1,
                                  /*expect_change=*/true));
  const HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Reshape(op::ReduceScatter(op::Parameter(0))));
  EXPECT_EQ(AllReduceCount(module), 0);
}
138
// The all-reduce result f32[336,1024] is reshaped to f32[4,84,1024] and then
// dynamic-sliced on the last dimension at offset replica-id * 128. The pass is
// still expected to fire, yielding a reshape of a reduce-scatter of the
// parameter and no remaining all-reduce in the entry computation.
TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshapeSplitDimModified) {
  const absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[336,1024] parameter(0)
  %all-reduce = f32[336,1024] all-reduce(%param), replica_groups={}, to_apply=%sum
  %rid = u32[] replica-id()
  %id = s32[] convert(%rid)
  %slice_size = s32[] constant(128)
  %offset = s32[] multiply(%id, %slice_size)
  %zero = s32[] constant(0)
  %reshape.1 = f32[4,84,1024] reshape(%all-reduce)
  ROOT %dynamic-slice = f32[4,84,128] dynamic-slice(%reshape.1, %zero, %zero, %offset),
    dynamic_slice_sizes={4,84,128}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          RunPass(hlo_text, /*num_replicas=*/8,
                                  /*num_partitions=*/1,
                                  /*expect_change=*/true));
  const HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Reshape(op::ReduceScatter(op::Parameter(0))));
  EXPECT_EQ(AllReduceCount(module), 0);
}
171
// All-reduce with empty replica_groups whose result is dynamic-sliced along
// dimension 2 at offset replica-id * 16 (the replica id is converted to s32
// directly, with no lookup table). Run with 8 replicas and 1 partition, the
// pass is expected to produce a reduce-scatter of the parameter with scatter
// dimension 2 and to leave no all-reduce in the entry computation.
TEST_F(GpuReduceScatterCreatorTest, AllReplicasDim2) {
  // The HLO contains only the instructions the slice offset depends on; no
  // lookup-table constant is declared because the offset does not use one.
  absl::string_view hlo_string = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={}, to_apply=%sum
  %rid = u32[] replica-id()
  %rid_s32 = s32[] convert(%rid)
  %slice_size = s32[] constant(16)
  %offset = s32[] multiply(%rid_s32, %slice_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[32,8,16] dynamic-slice(%all-reduce, %zero, %zero, %offset),
    dynamic_slice_sizes={32,8,16}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
                                               /*num_replicas=*/8,
                                               /*num_partitions=*/1,
                                               /*expect_change=*/true));
  // Fatal on mismatch: the cast below requires a reduce-scatter root.
  ASSERT_THAT(module->entry_computation()->root_instruction(),
              op::ReduceScatter(op::Parameter(0)));
  const auto *rs = Cast<HloReduceScatterInstruction>(
      module->entry_computation()->root_instruction());
  EXPECT_EQ(rs->scatter_dimension(), 2) << rs->ToString();
  EXPECT_EQ(AllReduceCount(module), 0);
}
208
// Negative test: same shape of computation as the plain all-replicas case, but
// the lookup table is {0,1,2,3,4,5,6,8}; its last entry is 8 rather than 7.
// The pass is expected to report no change, and the all-reduce must therefore
// still be present in the entry computation.
TEST_F(GpuReduceScatterCreatorTest, AllReplicasWrongOffsets) {
  absl::string_view hlo_string = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={}, to_apply=%sum
  %table = s32[8]{0} constant({0,1,2,3,4,5,6,8})
  %rid = u32[] replica-id()
  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%id)
  %slice_size = s32[] constant(4)
  %offset = s32[] multiply(%reshape, %slice_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
    dynamic_slice_sizes={4,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
                                               /*num_replicas=*/8,
                                               /*num_partitions=*/1,
                                               /*expect_change=*/false));
  // Beyond the "changed" flag checked by RunPass, confirm the single
  // all-reduce of the input module was not rewritten away.
  EXPECT_EQ(AllReduceCount(module), 1);
}
239
// Variant of the all-replicas pattern in which the replica-id lookup table is
// an iota instruction rather than a constant. Run with 8 replicas and 2
// partitions, the pass is expected to produce a reduce-scatter of the
// parameter and to leave no all-reduce in the entry computation.
TEST_F(GpuReduceScatterCreatorTest, AllReplicasIotaTable) {
  const absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={}, to_apply=%sum
  %table = s32[8]{0} iota(), iota_dimension=0
  %rid = u32[] replica-id()
  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%id)
  %slice_size = s32[] constant(4)
  %offset = s32[] multiply(%reshape, %slice_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
    dynamic_slice_sizes={4,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          RunPass(hlo_text, /*num_replicas=*/8,
                                  /*num_partitions=*/2,
                                  /*expect_change=*/true));
  const HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  EXPECT_EQ(AllReduceCount(module), 0);
}
273
// All-reduce over two replica subgroups, {1,3,2,0} and {4,5,6,7}. The slice
// offset is found via two lookups: %gtable maps the replica id to an index
// ({3,0,2,1,0,1,2,3}), and %table maps that index to an offset
// ({0,8,16,24}). Run with 8 replicas and 2 partitions, the pass is expected to
// produce a reduce-scatter of the parameter with no all-reduce left in the
// entry computation.
TEST_F(GpuReduceScatterCreatorTest, SubgroupedReplicas) {
  const absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum
  %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
  %rid = u32[] replica-id()
  %id = s32[1] dynamic-slice(%gtable, %rid), dynamic_slice_sizes={1}
  %reshape.0 = s32[] reshape(%id)
  %table = s32[4]{0} constant({0,8,16,24})
  %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
  %reshape.1 = s32[] reshape(%offset)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
    dynamic_slice_sizes={8,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          RunPass(hlo_text, /*num_replicas=*/8,
                                  /*num_partitions=*/2,
                                  /*expect_change=*/true));
  const HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  EXPECT_EQ(AllReduceCount(module), 0);
}
308
// All-reduce with channel_id=1 and replica_groups={{0},{1}}, whose result is
// dynamic-sliced on dimension 0 at offset table[partition-id] * 4. Run with 2
// replicas and 8 partitions, the pass is expected to produce a reduce-scatter
// of the parameter and to leave no all-reduce in the entry computation.
TEST_F(GpuReduceScatterCreatorTest, AllPartitions) {
  const absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={{0},{1}}, to_apply=%sum, channel_id=1
  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
  %pid = u32[] partition-id()
  %id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%id)
  %slice_size = s32[] constant(4)
  %offset = s32[] multiply(%reshape, %slice_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
    dynamic_slice_sizes={4,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          RunPass(hlo_text, /*num_replicas=*/2,
                                  /*num_partitions=*/8,
                                  /*expect_change=*/true));
  const HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  EXPECT_EQ(AllReduceCount(module), 0);
}
342
// All-reduce using global device ids (use_global_device_ids=true, channel_id=1)
// over the groups {1,3,2,0} and {4,5,6,7}. The HLO computes an id as
// replica-id * 4 + partition-id and maps it through %gtable and then %table to
// obtain the slice offset. Run with 2 replicas and 4 partitions, the pass is
// expected to produce a reduce-scatter of the parameter with no all-reduce
// left in the entry computation.
TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobals) {
  const absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
  %pid = u32[] partition-id()
  %rid = u32[] replica-id()
  %pcount = u32[] constant(4)
  %ridxp = u32[] multiply(%rid, %pcount)
  %gid = u32[] add(%ridxp, %pid)
  %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
  %id = s32[1] dynamic-slice(%gtable, %gid), dynamic_slice_sizes={1}
  %reshape.0 = s32[] reshape(%id)
  %table = s32[4]{0} constant({0,8,16,24})
  %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
  %reshape.1 = s32[] reshape(%offset)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
    dynamic_slice_sizes={8,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          RunPass(hlo_text, /*num_replicas=*/2,
                                  /*num_partitions=*/4,
                                  /*expect_change=*/true));
  const HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  EXPECT_EQ(AllReduceCount(module), 0);
}
381
// All-reduce using global device ids over the groups {1,3,2,0} and {5,7,6,4}.
// The slice offset depends on the partition id only: pid_table[partition-id]
// * 8 with pid_table = {3,0,2,1}. Run with 2 replicas and 4 partitions, the
// pass is expected to produce a reduce-scatter of the parameter with no
// all-reduce left in the entry computation.
TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsOrthogonalReplicas) {
  const absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={{1,3,2,0},{5,7,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
  %pid = u32[] partition-id()
  %pid_table = s32[4]{0} constant({3,0,2,1})
  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%offset)
  %shard_size = s32[] constant(8)
  %mul = s32[] multiply(%reshape, %shard_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
    dynamic_slice_sizes={8,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          RunPass(hlo_text, /*num_replicas=*/2,
                                  /*num_partitions=*/4,
                                  /*expect_change=*/true));
  const HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  EXPECT_EQ(AllReduceCount(module), 0);
}
415
// Negative test: all-reduce using global device ids over the groups {1,3,2,0}
// and {7,5,6,4}, with the slice offset derived from the partition id only
// (pid_table[partition-id] * 8). The second group orders its members
// differently from the first, and the pass is expected to report no change;
// the all-reduce must therefore still be present in the entry computation.
TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsNonOrthogonalReplicas) {
  absl::string_view hlo_string = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={{1,3,2,0},{7,5,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
  %pid = u32[] partition-id()
  %pid_table = s32[4]{0} constant({3,0,2,1})
  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%offset)
  %shard_size = s32[] constant(8)
  %mul = s32[] multiply(%reshape, %shard_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
    dynamic_slice_sizes={8,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
                                               /*num_replicas=*/2,
                                               /*num_partitions=*/4,
                                               /*expect_change=*/false));
  // Beyond the "changed" flag checked by RunPass, confirm the single
  // all-reduce of the input module was not rewritten away.
  EXPECT_EQ(AllReduceCount(module), 1);
}
446
// The all-reduce operates on f32[1,7] over groups of two devices
// ({0,1},{2,3},{4,5},{6,7}), and the dynamic-slice takes 3 elements of
// dimension 1 at offset pid_table[partition-id] * 3. The split dimension (7)
// is not a multiple of the group size (2). Run with 1 replica and 8
// partitions, the pass is expected to produce a reduce-scatter whose operand
// is a slice of the parameter, leaving no all-reduce in the entry computation.
TEST_F(GpuReduceScatterCreatorTest, NonUniformSplit) {
  absl::string_view hlo_string = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[1,7]{1,0} parameter(0)
  %all-reduce = f32[1,7]{1,0} all-reduce(%param),
    replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
  %pid = u32[] partition-id()
  %pid_table = s32[8]{0} constant({0, 1, 0, 1, 0, 1, 0, 1})
  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%offset)
  %shard_size = s32[] constant(3)
  %mul = s32[] multiply(%reshape, %shard_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[1,3] dynamic-slice(%all-reduce, %zero, %mul),
    dynamic_slice_sizes={1,3}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
                                               /*num_replicas=*/1,
                                               /*num_partitions=*/8,
                                               /*expect_change=*/true));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::ReduceScatter(op::Slice(op::Parameter(0))));
  // As in the other positive tests, the original all-reduce must be gone.
  EXPECT_EQ(AllReduceCount(module), 0);
}
479
480 } // namespace
481 } // namespace gpu
482 } // namespace xla
483