• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gpu_reduce_scatter_creator.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
19 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_parser.h"
24 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
25 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 
31 namespace xla {
32 namespace gpu {
33 namespace {
34 
35 namespace op = xla::testing::opcode_matchers;
36 
37 class GpuReduceScatterCreatorTest : public HloTestBase {
38  public:
RunPass(absl::string_view hlo_module,int64_t num_replicas,int64_t num_partitions,bool expect_change)39   StatusOr<std::unique_ptr<HloModule>> RunPass(absl::string_view hlo_module,
40                                                int64_t num_replicas,
41                                                int64_t num_partitions,
42                                                bool expect_change) {
43     HloModuleConfig config = GetModuleConfigForTest(
44         /*replica_count=*/num_replicas,
45         /*num_partitions=*/num_partitions);
46     config.set_use_spmd_partitioning(num_partitions > 1);
47     TF_ASSIGN_OR_RETURN(auto module,
48                         ParseAndReturnVerifiedModule(hlo_module, config));
49     auto changed = ReduceScatterCreator().Run(module.get());
50     if (!changed.ok()) {
51       return changed.status();
52     }
53     EXPECT_EQ(changed.ValueOrDie(), expect_change);
54     return StatusOr<std::unique_ptr<HloModule>>(std::move(module));
55   }
56 
AllReduceCount(std::unique_ptr<HloModule> & module)57   size_t AllReduceCount(std::unique_ptr<HloModule> &module) {
58     return absl::c_count_if(module->entry_computation()->instructions(),
59                             [](const HloInstruction *inst) {
60                               return inst->opcode() == HloOpcode::kAllReduce;
61                             });
62   }
63 };
64 
// An all-reduce across all replicas (empty replica_groups) is followed by a
// dynamic-slice along dimension 0 at offset table[replica-id] * 4. The pass is
// expected to produce a reduce-scatter of the parameter with scatter
// dimension 0 and to leave no all-reduce behind.
TEST_F(GpuReduceScatterCreatorTest, AllReplicas) {
  absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={}, to_apply=%sum
  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
  %rid = u32[] replica-id()
  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%id)
  %slice_size = s32[] constant(4)
  %offset = s32[] multiply(%reshape, %slice_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
    dynamic_slice_sizes={4,8,128}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, RunPass(hlo_text, /*num_replicas=*/8, /*num_partitions=*/1,
                           /*expect_change=*/true));
  HloInstruction *root = module->entry_computation()->root_instruction();
  // ASSERT (not EXPECT): the Cast below requires a reduce-scatter root.
  ASSERT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  const auto *reduce_scatter = Cast<HloReduceScatterInstruction>(root);
  EXPECT_EQ(reduce_scatter->scatter_dimension(), 0)
      << reduce_scatter->ToString();
  EXPECT_EQ(AllReduceCount(module), 0);
}
102 
// Same pattern as a plain all-replica slice along dimension 0, but a reshape
// ([32,8,128] -> [32,16,64]) sits between the all-reduce and the
// dynamic-slice. The expected result is reshape(reduce-scatter(param)) with
// no all-reduce remaining.
TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshape) {
  absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={}, to_apply=%sum
  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
  %rid = u32[] replica-id()
  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%id)
  %slice_size = s32[] constant(4)
  %offset = s32[] multiply(%reshape, %slice_size)
  %zero = s32[] constant(0)
  %reshape.1 = f32[32,16,64] reshape(%all-reduce)
  ROOT %dynamic-slice = f32[4,16,64] dynamic-slice(%reshape.1, %offset, %zero, %zero),
    dynamic_slice_sizes={4,16,64}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, RunPass(hlo_text, /*num_replicas=*/8, /*num_partitions=*/1,
                           /*expect_change=*/true));
  HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Reshape(op::ReduceScatter(op::Parameter(0))));
  EXPECT_EQ(AllReduceCount(module), 0);
}
138 
// The intervening reshape ([336,1024] -> [4,84,1024]) splits the leading
// dimension, so the dimension being sliced (size 1024, offset
// replica-id * 128) moves from index 1 to index 2. The expected result is
// still reshape(reduce-scatter(param)) with no all-reduce remaining.
TEST_F(GpuReduceScatterCreatorTest, AllReplicasWithReshapeSplitDimModified) {
  absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[336,1024] parameter(0)
  %all-reduce = f32[336,1024] all-reduce(%param), replica_groups={}, to_apply=%sum
  %rid = u32[] replica-id()
  %id = s32[] convert(%rid)
  %slice_size = s32[] constant(128)
  %offset = s32[] multiply(%id, %slice_size)
  %zero = s32[] constant(0)
  %reshape.1 = f32[4,84,1024] reshape(%all-reduce)
  ROOT %dynamic-slice = f32[4,84,128] dynamic-slice(%reshape.1, %zero, %zero, %offset),
    dynamic_slice_sizes={4,84,128}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, RunPass(hlo_text, /*num_replicas=*/8, /*num_partitions=*/1,
                           /*expect_change=*/true));
  HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Reshape(op::ReduceScatter(op::Parameter(0))));
  EXPECT_EQ(AllReduceCount(module), 0);
}
171 
// The dynamic-slice splits the last dimension (index 2) at offset
// replica-id * 16. The expected result is a reduce-scatter of the parameter
// whose scatter dimension is 2, with no all-reduce remaining.
TEST_F(GpuReduceScatterCreatorTest, AllReplicasDim2) {
  absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={}, to_apply=%sum
  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
  %rid = u32[] replica-id()
  %rid_s32 = s32[] convert(%rid)
  %slice_size = s32[] constant(16)
  %offset = s32[] multiply(%rid_s32, %slice_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[32,8,16] dynamic-slice(%all-reduce, %zero, %zero, %offset),
    dynamic_slice_sizes={32,8,16}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, RunPass(hlo_text, /*num_replicas=*/8, /*num_partitions=*/1,
                           /*expect_change=*/true));
  HloInstruction *root = module->entry_computation()->root_instruction();
  // ASSERT (not EXPECT): the Cast below requires a reduce-scatter root.
  ASSERT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  const auto *reduce_scatter = Cast<HloReduceScatterInstruction>(root);
  EXPECT_EQ(reduce_scatter->scatter_dimension(), 2)
      << reduce_scatter->ToString();
  EXPECT_EQ(AllReduceCount(module), 0);
}
208 
// Negative case: the offset table's last entry is 8 instead of 7, so the
// per-replica offsets are not the sequence 0..7. The pass is expected to
// report no change.
TEST_F(GpuReduceScatterCreatorTest, AllReplicasWrongOffsets) {
  absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={}, to_apply=%sum
  %table = s32[8]{0} constant({0,1,2,3,4,5,6,8})
  %rid = u32[] replica-id()
  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%id)
  %slice_size = s32[] constant(4)
  %offset = s32[] multiply(%reshape, %slice_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
    dynamic_slice_sizes={4,8,128}
}
)";
  // The "no change" expectation is checked inside RunPass; here we only
  // require that the run itself succeeded.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, RunPass(hlo_text, /*num_replicas=*/8, /*num_partitions=*/1,
                           /*expect_change=*/false));
}
239 
// The replica-id lookup table is an iota instead of a constant. Run with 8
// replicas and 2 partitions; the expected result is a reduce-scatter of the
// parameter with no all-reduce remaining.
TEST_F(GpuReduceScatterCreatorTest, AllReplicasIotaTable) {
  absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={}, to_apply=%sum
  %table = s32[8]{0} iota(), iota_dimension=0
  %rid = u32[] replica-id()
  %id = s32[1] dynamic-slice(%table, %rid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%id)
  %slice_size = s32[] constant(4)
  %offset = s32[] multiply(%reshape, %slice_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
    dynamic_slice_sizes={4,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, RunPass(hlo_text, /*num_replicas=*/8, /*num_partitions=*/2,
                           /*expect_change=*/true));
  HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  EXPECT_EQ(AllReduceCount(module), 0);
}
273 
// The all-reduce runs over two replica subgroups, {1,3,2,0} and {4,5,6,7}.
// %gtable is indexed by replica-id and its result indexes %table, which holds
// the slice offsets {0,8,16,24}. The expected result is a reduce-scatter of
// the parameter with no all-reduce remaining.
TEST_F(GpuReduceScatterCreatorTest, SubgroupedReplicas) {
  absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum
  %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
  %rid = u32[] replica-id()
  %id = s32[1] dynamic-slice(%gtable, %rid), dynamic_slice_sizes={1}
  %reshape.0 = s32[] reshape(%id)
  %table = s32[4]{0} constant({0,8,16,24})
  %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
  %reshape.1 = s32[] reshape(%offset)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
    dynamic_slice_sizes={8,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, RunPass(hlo_text, /*num_replicas=*/8, /*num_partitions=*/2,
                           /*expect_change=*/true));
  HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  EXPECT_EQ(AllReduceCount(module), 0);
}
308 
// The all-reduce carries channel_id=1 with replica_groups={{0},{1}}, and the
// slice offset is derived from partition-id rather than replica-id. Run with
// 2 replicas and 8 partitions; the expected result is a reduce-scatter of the
// parameter with no all-reduce remaining.
TEST_F(GpuReduceScatterCreatorTest, AllPartitions) {
  absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={{0},{1}}, to_apply=%sum, channel_id=1
  %table = s32[8]{0} constant({0,1,2,3,4,5,6,7})
  %pid = u32[] partition-id()
  %id = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%id)
  %slice_size = s32[] constant(4)
  %offset = s32[] multiply(%reshape, %slice_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[4,8,128] dynamic-slice(%all-reduce, %offset, %zero, %zero),
    dynamic_slice_sizes={4,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, RunPass(hlo_text, /*num_replicas=*/2, /*num_partitions=*/8,
                           /*expect_change=*/true));
  HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  EXPECT_EQ(AllReduceCount(module), 0);
}
342 
// The all-reduce uses use_global_device_ids=true with groups {1,3,2,0} and
// {4,5,6,7}. The HLO computes a global id as replica-id * 4 + partition-id,
// looks it up in %gtable, and uses that result to index the offset table
// {0,8,16,24}. Run with 2 replicas and 4 partitions; the expected result is a
// reduce-scatter of the parameter with no all-reduce remaining.
TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobals) {
  absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={{1,3,2,0},{4,5,6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
  %pid = u32[] partition-id()
  %rid = u32[] replica-id()
  %pcount = u32[] constant(4)
  %ridxp = u32[] multiply(%rid, %pcount)
  %gid = u32[] add(%ridxp, %pid)
  %gtable = s32[8]{0} constant({3,0,2,1,0,1,2,3})
  %id = s32[1] dynamic-slice(%gtable, %gid), dynamic_slice_sizes={1}
  %reshape.0 = s32[] reshape(%id)
  %table = s32[4]{0} constant({0,8,16,24})
  %offset = s32[1] dynamic-slice(%table, %reshape.0), dynamic_slice_sizes={1}
  %reshape.1 = s32[] reshape(%offset)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %reshape.1, %zero, %zero),
    dynamic_slice_sizes={8,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, RunPass(hlo_text, /*num_replicas=*/2, /*num_partitions=*/4,
                           /*expect_change=*/true));
  HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  EXPECT_EQ(AllReduceCount(module), 0);
}
381 
// Global-device-id groups {1,3,2,0} and {5,7,6,4}: the second group is the
// first shifted by 4, and the slice offset depends on partition-id only (via
// %pid_table {3,0,2,1} times 8). Run with 2 replicas and 4 partitions; the
// expected result is a reduce-scatter of the parameter with no all-reduce
// remaining.
TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsOrthogonalReplicas) {
  absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={{1,3,2,0},{5,7,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
  %pid = u32[] partition-id()
  %pid_table = s32[4]{0} constant({3,0,2,1})
  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%offset)
  %shard_size = s32[] constant(8)
  %mul = s32[] multiply(%reshape, %shard_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
    dynamic_slice_sizes={8,8,128}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, RunPass(hlo_text, /*num_replicas=*/2, /*num_partitions=*/4,
                           /*expect_change=*/true));
  HloInstruction *root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::ReduceScatter(op::Parameter(0)));
  EXPECT_EQ(AllReduceCount(module), 0);
}
415 
// Negative case: with groups {1,3,2,0} and {7,5,6,4} the second group is not
// the first shifted by 4 (its first two members are swapped), while the slice
// offset still depends on partition-id only. The pass is expected to report
// no change.
TEST_F(GpuReduceScatterCreatorTest, SubgroupsGlobalsNonOrthogonalReplicas) {
  absl::string_view hlo_text = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[32,8,128]{2,1,0} parameter(0)
  %all-reduce = f32[32,8,128]{2,1,0} all-reduce(%param),
    replica_groups={{1,3,2,0},{7,5,6,4}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
  %pid = u32[] partition-id()
  %pid_table = s32[4]{0} constant({3,0,2,1})
  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%offset)
  %shard_size = s32[] constant(8)
  %mul = s32[] multiply(%reshape, %shard_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[8,8,128] dynamic-slice(%all-reduce, %mul, %zero, %zero),
    dynamic_slice_sizes={8,8,128}
}
)";
  // The "no change" expectation is checked inside RunPass; here we only
  // require that the run itself succeeded.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, RunPass(hlo_text, /*num_replicas=*/2, /*num_partitions=*/4,
                           /*expect_change=*/false));
}
446 
// The sliced dimension has size 7 and is split across groups of two
// partitions in slices of size 3, so it does not divide evenly. The expected
// result is a reduce-scatter applied to a slice of the parameter.
TEST_F(GpuReduceScatterCreatorTest, NonUniformSplit) {
  absl::string_view hlo_string = R"(
HloModule AllReduce

%sum {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %AllReduce {
  %param = f32[1,7]{1,0} parameter(0)
  %all-reduce = f32[1,7]{1,0} all-reduce(%param),
    replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=%sum, channel_id=1, use_global_device_ids=true
  %pid = u32[] partition-id()
  %pid_table = s32[8]{0} constant({0, 1, 0, 1, 0, 1, 0, 1})
  %offset = s32[1] dynamic-slice(%pid_table, %pid), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(%offset)
  %shard_size = s32[] constant(3)
  %mul = s32[] multiply(%reshape, %shard_size)
  %zero = s32[] constant(0)
  ROOT %dynamic-slice = f32[1,3] dynamic-slice(%all-reduce, %zero, %mul),
    dynamic_slice_sizes={1,3}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string,
                                               /*num_replicas=*/1,
                                               /*num_partitions=*/8,
                                               /*expect_change=*/true));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::ReduceScatter(op::Slice(op::Parameter(0))));
  // Like the other successful-rewrite tests in this file, also require that
  // the original all-reduce is no longer present in the entry computation.
  EXPECT_EQ(AllReduceCount(module), 0);
}
479 
480 }  // namespace
481 }  // namespace gpu
482 }  // namespace xla
483