/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

using absl::nullopt;
using ::testing::AllOf;
namespace op = xla::testing::opcode_matchers;
constexpr int64 kMaxCombineCount = 256;

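// Counts the all-reduce instructions in `module`, ignoring instructions
// inside fusion computations.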
int64 AllReduceCount(const HloModule& module) {
  int64 count = 0;
  for (HloComputation* computation : module.computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    for (HloInstruction* hlo : computation->instructions()) {
      if (hlo->opcode() == HloOpcode::kAllReduce) {
        ++count;
      }
    }
  }
  return count;
}

// inputs[i] will be some op producing a shape of size sizes_in_kib[i], which
// feeds into the i-th all-reduce op in all_reduces. Returns a tuple of the
// all_reduces.
HloInstruction* MakeCrossReplicaReductions(
    std::vector<int64> sizes_in_kib, std::vector<HloComputation*> reductions,
    std::vector<HloInstruction*>* inputs, HloComputation::Builder* b) {
  CHECK_EQ(reductions.size(), sizes_in_kib.size());
  std::vector<HloInstruction*> all_reduces;
  for (int i = 0; i < sizes_in_kib.size(); i++) {
    int64 size_in_kib = sizes_in_kib[i];
    HloComputation* reduction = reductions[i];
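    // Build an f32 input of exactly size_in_kib KiB: each element occupies
    // sizeof(float) bytes, so the shape holds
    // size_in_kib * 1024 / sizeof(float) elements.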
    auto constant = b->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
    Shape shape = ShapeUtil::MakeShape(
        F32, {static_cast<int32>(size_in_kib * 1024 / sizeof(float))});
    auto input =
        b->AddInstruction(HloInstruction::CreateBroadcast(shape, constant, {}));
    inputs->push_back(input);
    all_reduces.push_back(b->AddInstruction(HloInstruction::CreateAllReduce(
        shape, {input}, reduction, /*replica_groups=*/{},
        /*constrain_layout=*/false, /*channel_id=*/nullopt,
        /*use_global_device_ids=*/false)));
  }
  return b->AddInstruction(HloInstruction::CreateTuple(all_reduces));
}

// Creates and adds a reduction computation of the given type to the module.
HloComputation* MakeReduction(const HloOpcode type, HloModule* module) {
  HloComputation::Builder sum_builder(HloOpcodeString(type));
  auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
  auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
  sum_builder.AddInstruction(
      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {}), type, x, y));
  HloComputation* reduction =
      module->AddEmbeddedComputation(sum_builder.Build());
  return reduction;
}

// Creates replica groups for AllReduce. groups[i] represents replica ids
// for group 'i'.
std::vector<ReplicaGroup> CreateReplicaGroups(
    absl::Span<const std::vector<int64>> groups) {
  std::vector<ReplicaGroup> replica_groups(groups.size());
  for (int64 i = 0; i < groups.size(); ++i) {
    *replica_groups[i].mutable_replica_ids() = {groups[i].begin(),
                                                groups[i].end()};
  }
  return replica_groups;
}

using AllReduceCombinerTest = HloTestBase;

// Tests combination of several AllReduce instructions.
TEST_F(AllReduceCombinerTest, CombineAllReduces) {
  auto module = CreateNewVerifiedModule();
  HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  std::vector<HloInstruction*> inputs;
  auto root = MakeCrossReplicaReductions(
      {1, 2, 10, 7, 6}, {sum, sum, sum, sum, sum}, &inputs, &b);
  auto computation = module->AddEntryComputation(b.Build());

  // Run the AllReduce combiner optimization pass.
  AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), inputs.size());
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  ASSERT_EQ(AllReduceCount(*module), 1);
  EXPECT_TRUE(changed);

  ASSERT_EQ(root, computation->root_instruction());
  ASSERT_EQ(inputs.size(), root->operands().size());

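  // Each tuple element should be a get-tuple-element of the single combined
  // all-reduce, preserving the original operand shapes and order.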
  HloInstruction* combined = nullptr;
  for (int64 i = 0; i < root->operands().size(); ++i) {
    HloInstruction* hlo = root->mutable_operand(i);
    ASSERT_TRUE(hlo->opcode() == HloOpcode::kGetTupleElement);
    EXPECT_EQ(hlo->tuple_index(), i);
    EXPECT_TRUE(ShapeUtil::Equal(inputs[i]->shape(), hlo->shape()));

    if (combined == nullptr) {
      // Verify the combined all-reduce instruction.
      combined = hlo->mutable_operand(0);
      ASSERT_TRUE(combined->opcode() == HloOpcode::kAllReduce);
      EXPECT_TRUE(ShapeUtil::Equal(root->shape(), combined->shape()));
      ASSERT_EQ(combined->operands().size(), inputs.size());
    }
    EXPECT_EQ(combined, hlo->operand(0));
    EXPECT_EQ(combined->operand(i), inputs[i]);
    EXPECT_EQ(1, inputs[i]->users().size());
  }
  ASSERT_NE(combined, nullptr);
}

// Tests combination of several cross-replica reduction instructions with
// different reduction types.
TEST_F(AllReduceCombinerTest, CombineCrossReplicaReductionsInGroups) {
  auto module = CreateNewVerifiedModule();
  HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());
  HloComputation* min = MakeReduction(HloOpcode::kMinimum, module.get());
  HloComputation* max = MakeReduction(HloOpcode::kMaximum, module.get());
  HloComputation* sum_2 = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  std::vector<HloInstruction*> inputs;
  MakeCrossReplicaReductions(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
      {sum, sum_2, min, min, min, max, max, max, sum, sum_2}, &inputs, &b);
  module->AddEntryComputation(b.Build());

  // Run the AllReduce combiner optimization pass.
  AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), inputs.size());
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  ASSERT_EQ(AllReduceCount(*module), 3)
      << "expects 3 groups for 3 reduction types.";
  EXPECT_TRUE(changed);
}

// Tests that the combination threshold is respected.
TEST_F(AllReduceCombinerTest, RespectThreshold) {
  auto module = CreateNewVerifiedModule();
  HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  std::vector<HloInstruction*> inputs;
  MakeCrossReplicaReductions({8, 4}, {sum, sum}, &inputs, &b);
  module->AddEntryComputation(b.Build());

  // Run the AllReduce combiner optimization pass with a threshold one byte
  // short of the combined size of the all-reduce ops (12 KiB), so that the
  // combination cannot occur.
  {
    AllReduceCombiner combine((8 + 4) * 1024 - 1, kMaxCombineCount);
    ASSERT_EQ(AllReduceCount(*module), inputs.size());
    TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
    EXPECT_EQ(AllReduceCount(*module), inputs.size());
    EXPECT_FALSE(changed);
  }

  // Run the AllReduce combiner optimization pass again with a threshold
  // equal to the combined size, so that the combination can occur.
  {
    AllReduceCombiner combine((8 + 4) * 1024, kMaxCombineCount);
    ASSERT_EQ(AllReduceCount(*module), inputs.size());
    TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
    EXPECT_EQ(AllReduceCount(*module), 1);
    EXPECT_TRUE(changed);
  }
}

// Tests that dependent all-reduces are not combined.
TEST_F(AllReduceCombinerTest, NoDependentCombination) {
  auto module = CreateNewVerifiedModule();
  HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  auto constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
  auto all_reduce = b.AddInstruction(HloInstruction::CreateAllReduce(
      constant->shape(), {constant}, reduction, /*replica_groups=*/{},
      /*constrain_layout=*/false, /*channel_id=*/nullopt,
      /*use_global_device_ids=*/false));
  b.AddInstruction(HloInstruction::CreateAllReduce(
      constant->shape(), {all_reduce}, reduction,
      /*replica_groups=*/{}, /*constrain_layout=*/false,
      /*channel_id=*/nullopt, /*use_global_device_ids=*/false));

  module->AddEntryComputation(b.Build());

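  // Both all-reduces fit well under the size threshold, but the second one
  // consumes the result of the first, so they cannot be combined.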
  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

// Tests that AllReduce ops with different replica groups are not combined.
TEST_F(AllReduceCombinerTest, GroupAllReduce) {
  auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/4);
  HloComputation::Builder b(TestName());
  HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());

  auto constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
  auto crs0 = b.AddInstruction(HloInstruction::CreateAllReduce(
      constant->shape(), {constant}, reduction,
      CreateReplicaGroups({{0, 1}, {2, 3}}),
      /*constrain_layout=*/false,
      /*channel_id=*/nullopt, /*use_global_device_ids=*/false));
  auto crs1 = b.AddInstruction(HloInstruction::CreateAllReduce(
      constant->shape(), {constant}, reduction,
      CreateReplicaGroups({{0, 2}, {1, 3}}),
      /*constrain_layout=*/false,
      /*channel_id=*/nullopt, /*use_global_device_ids=*/false));
  b.AddInstruction(HloInstruction::CreateTuple({crs0, crs1}));

  module->AddEntryComputation(b.Build());

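  // crs0 partitions the replicas as {0,1},{2,3} while crs1 uses {0,2},{1,3};
  // all-reduces over different replica groups must stay separate.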
  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

TEST_F(AllReduceCombinerTest, DomainPreventsCombining) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  crs0 = f32[128] all-reduce(param0),
    replica_groups={}, to_apply=summit, sharding={maximal device=0}
  crs1 = f32[128] all-reduce(param1),
    replica_groups={}, to_apply=summit, sharding={maximal device=1}
  domain0 = f32[128] domain(crs0),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=0}}
  domain1 = f32[128] domain(crs1),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=1}}
  ROOT tuple = (f32[128], f32[128]) tuple(domain0, domain1),
    sharding={{maximal device=0}, {maximal device=1}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

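  // crs0 and crs1 live in domains that exit to different devices, so the
  // combiner must leave them alone.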
  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

// This test checks that two CRS instructions that are in separate domains
// but with the same domain metadata can be combined.
TEST_F(AllReduceCombinerTest, CombineFromTwoDomainsWithSameMetadata) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  param2 = f32[128] parameter(2), sharding={maximal device=1}
  crs0 = f32[128] all-reduce(param0),
    replica_groups={}, to_apply=summit, sharding={maximal device=0}
  crs1 = f32[128] all-reduce(param1),
    replica_groups={}, to_apply=summit, sharding={maximal device=1}
  crs2 = f32[128] all-reduce(param2),
    replica_groups={}, to_apply=summit, sharding={maximal device=0}
  domain0 = f32[128] domain(crs0),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
    {maximal device=0}}, exit={maximal device=0}}
  domain1 = f32[128] domain(crs1),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
    {maximal device=0}}, exit={maximal device=1}}
  domain2 = f32[128] domain(crs2),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
    {maximal device=0}}, exit={maximal device=0}}
  ROOT tuple = (f32[128], f32[128], f32[128]) tuple(domain0, domain1, domain2),
    sharding={{maximal device=0}, {maximal device=1}, {maximal device=0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

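  // crs0 and crs2 sit in domains with identical metadata (both exit on
  // device 0), so they can be combined into one all-reduce; crs1's domain
  // exits on device 1 and stays separate, leaving two all-reduces.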
  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 3);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_TRUE(changed);
}

TEST_F(AllReduceCombinerTest, DoNotCombineCrossShardAndCrossReplicaInSPMD) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  cross_shard_ar = f32[128] all-reduce(param0),
    replica_groups={{0}}, to_apply=summit, channel_id=1
  cross_replica_ar = f32[128] all-reduce(param1),
    replica_groups={{0}}, to_apply=summit, sharding={maximal device=1}
  ROOT tuple = (f32[128], f32[128]) tuple(cross_shard_ar, cross_replica_ar)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

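  // cross_shard_ar carries a channel_id (cross-shard in SPMD mode) while
  // cross_replica_ar does not; the two flavors must not be merged.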
  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

TEST_F(AllReduceCombinerTest, CrossCoreAllReduce) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  crs00 = f32[128] all-reduce(param0),
    replica_groups={{0}}, channel_id=1, to_apply=summit,
    sharding={maximal device=0}
  crs01 = f32[128] all-reduce(param1),
    replica_groups={{0}}, channel_id=1, to_apply=summit,
    sharding={maximal device=1}
  crs10 = f32[128] all-reduce(param0),
    replica_groups={{0}}, channel_id=2, to_apply=summit,
    sharding={maximal device=0}
  crs11 = f32[128] all-reduce(param1),
    replica_groups={{0}}, channel_id=2, to_apply=summit,
    sharding={maximal device=1}
  domain0 = f32[128] domain(crs00),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  ROOT add = f32[128] add(domain0, crs11),
    sharding={maximal device=1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

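  // The four all-reduces are grouped per core: the two on device 0
  // (crs00, crs10) combine, and the two on device 1 (crs01, crs11) combine,
  // each group becoming one tuple-shaped all-reduce.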
  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 4);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_TRUE(changed);

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Add(op::Domain(op::GetTupleElement(
                  AllOf(op::AllReduce(op::Parameter(0), op::Parameter(0)),
                        op::Shape("(f32[128], f32[128])")),
                  1)),
              op::GetTupleElement(
                  AllOf(op::AllReduce(op::Parameter(1), op::Parameter(1)),
                        op::Shape("(f32[128], f32[128])")),
                  0)));
}

TEST_F(AllReduceCombinerTest, CrossCombineGroupCycle) {
  const char* const hlo_string = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

%max {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT max = f32[] maximum(lhs, rhs)
}

ENTRY %comp {
  p0 = f32[128] parameter(0)
  p1 = f32[128] parameter(1)

  crs00 = f32[128] all-reduce(p0), to_apply=add
  crs10 = f32[128] all-reduce(p1), to_apply=max

  crs01 = f32[128] all-reduce(crs00), to_apply=max
  crs11 = f32[128] all-reduce(crs10), to_apply=add
  add0 = f32[128] add(crs01, crs11)

  crs02 = f32[128] all-reduce(add0), to_apply=add
  crs12 = f32[128] all-reduce(crs11), to_apply=add
  ROOT tuple = (f32[128], f32[128]) tuple(crs02, crs12)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

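  // Combining all four add all-reduces into one group would create a cycle
  // (crs02 depends on crs11 through add0), so the pass forms two add groups
  // instead: {crs00, crs11} and {crs02, crs12}. The two max all-reduces stay
  // uncombined, giving four all-reduces in total.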
  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 6);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 4);
  EXPECT_TRUE(changed);

  auto crs0 = op::AllReduce(op::Parameter(0), op::AllReduce(op::Parameter(1)));
  auto add = op::Add(op::AllReduce(op::GetTupleElement(crs0, 0)),
                     op::GetTupleElement(crs0, 1));
  auto crs1 = op::AllReduce(add, op::GetTupleElement(crs0));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::GetTupleElement(crs1, 0), op::GetTupleElement(crs1, 1)));
}

}  // namespace
}  // namespace xla