1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/ar_crs_combiner.h"
17 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
18 #include "tensorflow/compiler/xla/statusor.h"
19 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
20 #include "tensorflow/core/lib/core/status_test_util.h"
21
22 namespace xla {
23 namespace {
24
25 namespace op = xla::testing::opcode_matchers;
26
27 class ArCrsCombinerTest : public HloTestBase {};
28
// Expects two separate constant instructions with identical literals to be
// reported as computing the same value, and a constant compared against the
// entry parameter to be reported as different.
TEST_F(ArCrsCombinerTest, SameValueTestBasecase) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{1, 2}, {3, 4}})
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* root = module->entry_computation()->root_instruction();
  auto* first_constant = root->mutable_operand(0);
  auto* second_constant = root->mutable_operand(1);
  auto* entry_param = module->entry_computation()->parameter_instruction(0);
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(first_constant,
                                                               entry_param));
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(
      first_constant, second_constant));
}
50
// Expects an instruction compared with itself (the same parameter used as
// both tuple operands) to be reported as computing the same value.
TEST_F(ArCrsCombinerTest, SameValueTestBasecase2) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (x: f32[]) -> (f32[], f32[]) {
  %x = f32[] parameter(0)
  ROOT %tuple = (f32[], f32[]) tuple(%x, %x)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* root = module->entry_computation()->root_instruction();
  auto* lhs = root->mutable_operand(0);
  auto* rhs = root->mutable_operand(1);
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(lhs, rhs));
}
68
// Expects two different entry parameters to be reported as not computing the
// same value.
TEST_F(ArCrsCombinerTest, SameValueTestBasecase3) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (x: f32[], y: f32[]) -> (f32[], f32[]) {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %tuple = (f32[], f32[]) tuple(%x, %y)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* root = module->entry_computation()->root_instruction();
  auto* param_x = root->mutable_operand(0);
  auto* param_y = root->mutable_operand(1);
  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(param_x, param_y));
}
87
// Expects a one-element tuple and a two-element tuple to be reported as
// different even though every element is the same constant.
TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple1 = (f32[2,2]) tuple(%constant.f32)
  %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* root = module->entry_computation()->root_instruction();
  auto* unary_tuple = root->mutable_operand(0);
  auto* binary_tuple = root->mutable_operand(1);
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(unary_tuple,
                                                               binary_tuple));
}
108
// Expects two slices of the same parameter with identical bounds ([0:1]) to
// be reported as computing the same value.
TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesMatch) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) {
  %p = f32[2] parameter(0)
  %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]}
  %slice.2 = f32[1] slice(f32[2] %p), slice={[0:1]}
  ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* root = module->entry_computation()->root_instruction();
  auto* first_slice = root->mutable_operand(0);
  auto* second_slice = root->mutable_operand(1);
  EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(first_slice,
                                                              second_slice));
}
128
// Expects two slices of the same parameter with different bounds ([0:1] vs.
// [1:2]) to be reported as not computing the same value.
TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesDontMatch) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) {
  %p = f32[2] parameter(0)
  %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]}
  %slice.2 = f32[1] slice(f32[2] %p), slice={[1:2]}
  ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* root = module->entry_computation()->root_instruction();
  auto* first_slice = root->mutable_operand(0);
  auto* second_slice = root->mutable_operand(1);
  EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(first_slice,
                                                               second_slice));
}
148
// Expects two get-tuple-element instructions that read the same index of the
// same tuple to be reported as computing the same value.
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* root = module->entry_computation()->root_instruction();
  auto* first_gte = root->mutable_operand(0);
  auto* second_gte = root->mutable_operand(1);
  EXPECT_TRUE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_gte, second_gte));
}
170
// Expects get-tuple-element instructions reading different indices (0 and 1)
// to still be reported as the same value when both tuple slots hold the same
// constant instruction.
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex1) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* root = module->entry_computation()->root_instruction();
  auto* gte_index0 = root->mutable_operand(0);
  auto* gte_index1 = root->mutable_operand(1);
  EXPECT_TRUE(
      ArCrsCombiner::TestInstructionsComputeSameValue(gte_index0, gte_index1));
}
192
// Expects get-tuple-element instructions reading different indices (0 and 1)
// to be reported as different when the two tuple slots hold constants with
// different literals.
TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex2) {
  const char* hlo_text = R"(
HloModule foobar

ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
  %p = f32[2,2] parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{2, 3}, {4, 5}})
  %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* root = module->entry_computation()->root_instruction();
  auto* gte_index0 = root->mutable_operand(0);
  auto* gte_index1 = root->mutable_operand(1);
  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(gte_index0, gte_index1));
}
215
// While loop whose init tuple holds the same constant in both slots and whose
// body adds the same constant to each slot. Expects the two adds feeding the
// body's root tuple to be reported as computing the same value.
TEST_F(ArCrsCombinerTest, SameValueTestWhile1) {
  const char* hlo_text = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* while_instr = module->entry_computation()->root_instruction();
  auto* body_root = while_instr->while_body()->root_instruction();
  auto* first_add = body_root->mutable_operand(0);
  auto* second_add = body_root->mutable_operand(1);
  EXPECT_TRUE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_add, second_add));
}
252
// Same loop body as SameValueTestWhile1, but the init tuple holds two
// constants with different literals. Expects the two adds feeding the body's
// root tuple to be reported as not computing the same value.
TEST_F(ArCrsCombinerTest, SameValueTestWhile2) {
  const char* hlo_text = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32 = f32[2,2] constant({{1, 2}, {3, 4}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32.1 = f32[2,2] constant({{3, 4}, {5, 6}})
  %constant.f32.2 = f32[2,2] constant({{3, 4}, {7, 8}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* while_instr = module->entry_computation()->root_instruction();
  auto* body_root = while_instr->while_body()->root_instruction();
  auto* first_add = body_root->mutable_operand(0);
  auto* second_add = body_root->mutable_operand(1);
  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_add, second_add));
}
290
// The init tuple holds the same constant in both slots, but the loop body
// adds a different constant to each slot. Expects the two get-tuple-element
// instructions that read the body parameter to be reported as not computing
// the same value.
TEST_F(ArCrsCombinerTest, SameValueTestWhile3) {
  const char* hlo_text = R"(
HloModule foobar

%condition (x: (f32[2,2], f32[2,2])) -> pred[] {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.0 = s32[] constant(0)
  %constant.1 = s32[] constant(1)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %constant.0), direction=GT
}

%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) {
  %x = (f32[2,2], f32[2,2]) parameter(0)
  %constant.f32.1 = f32[2,2] constant({{1, 2}, {3, 4}})
  %constant.f32.2 = f32[2,2] constant({{3, 4}, {1, 2}})
  %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0
  %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1
  %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1)
  %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32.2)
  ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2)
}

ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
  %constant.f32 = f32[2,2] constant({{3, 4}, {5, 6}})
  %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32)
  ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto* while_instr = module->entry_computation()->root_instruction();
  auto* body_root = while_instr->while_body()->root_instruction();
  // Step through each add to its first operand:
  //   %add.1 -> %get-tuple-element.1, %add.2 -> %get-tuple-element.2.
  auto* first_gte = body_root->mutable_operand(0)->mutable_operand(0);
  auto* second_gte = body_root->mutable_operand(1)->mutable_operand(0);
  EXPECT_FALSE(
      ArCrsCombiner::TestInstructionsComputeSameValue(first_gte, second_gte));
}
328
CompareReplicaGroups(const std::vector<ReplicaGroup> & groups_before,const std::vector<ReplicaGroup> & groups_after)329 void CompareReplicaGroups(const std::vector<ReplicaGroup>& groups_before,
330 const std::vector<ReplicaGroup>& groups_after) {
331 ASSERT_EQ(groups_before.size(), groups_after.size());
332 for (int i = 0; i < groups_before.size(); ++i) {
333 // Somewhat verbose way to compare the replica_ids, because EqualsProto
334 // is not available in the open-source build.
335 auto group_before = groups_before[i];
336 std::vector<int64> ids_before(group_before.replica_ids().begin(),
337 group_before.replica_ids().end());
338 auto group_after = groups_after[i];
339 std::vector<int64> ids_after(group_after.replica_ids().begin(),
340 group_after.replica_ids().end());
341 EXPECT_EQ(ids_before, ids_after);
342 }
343 }
344
// Each chain is all-reduce (all_reduce_id=1) -> convert -> all-reduce.
// Expects the pass to report a change, the root to become
// Tuple(AllReduce(Convert(Parameter)), AllReduce(Convert(Constant))), and the
// replica groups of the root's first operand to be unchanged.
TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) {
  const char* hlo_text = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
  %p = bf16[] parameter(0)
  %constant.bf16 = bf16[] constant(1)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%convert.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%convert.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Copy the groups of the root's first operand before the pass runs.
  const std::vector<ReplicaGroup> groups_before =
      module->entry_computation()
          ->root_instruction()
          ->operand(0)
          ->replica_groups();
  ArCrsCombiner combiner(2);
  const bool changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Convert(op::Parameter())),
                        op::AllReduce(op::Convert(op::Constant()))));
  const std::vector<ReplicaGroup> groups_after =
      module->entry_computation()
          ->root_instruction()
          ->operand(0)
          ->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
417
// Each chain is all-reduce (all_reduce_id=1) -> bitcast -> all-reduce.
// Expects the pass to report a change, both root operands to become
// AllReduce(Bitcast(Parameter)), and the replica groups of the root's first
// operand to be unchanged.
TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) {
  // The ROOT tuple is declared as (f32[2], f32[2]) so that it agrees with its
  // f32[2] all-reduce operands and with the ENTRY signature.
  const char* module_str = R"(
HloModule foobar

%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] {
  %a = f32[2,1] parameter(0)
  %b = f32[2,1] parameter(1)
  ROOT %add = f32[2,1] add(%a, %b)
}

%sum.2 (x: f32[2], y: f32[2]) -> f32[2] {
  %x = f32[2] parameter(0)
  %y = f32[2] parameter(1)
  ROOT %add = f32[2] add(%x, %y)
}

ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
  %p = f32[2,1] parameter(0)

  %all-reduce.ar.1 = f32[2,1]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.1,
      sharding={maximal device=0}
  %bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.1)
  %all-reduce.1 = f32[2]
      all-reduce(%bitcast.1),
      replica_groups={{0,1}},
      to_apply=%sum.2,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[2,1]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.1,
      sharding={maximal device=1}
  %bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %all-reduce.ar.2)
  %all-reduce.2 = f32[2]
      all-reduce(%bitcast.2),
      replica_groups={{0,1}},
      to_apply=%sum.2,
      sharding={maximal device=1}

  ROOT %tuple = (f32[2], f32[2])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Copied (not referenced) so the pre-pass groups survive the rewrite.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(2);
  auto changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Bitcast(op::Parameter())),
                        op::AllReduce(op::Bitcast(op::Parameter()))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
485
// Each chain is all-reduce (all_reduce_id=1) -> multiply by a constant ->
// all-reduce. Expects the pass to report a change, both root operands to
// become AllReduce(Multiply(Parameter, Constant)), and the replica groups of
// the root's first operand to be unchanged.
TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) {
  const char* hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=0}
  %multiply.1 = f32[]
      multiply(%all-reduce.ar.1, %constant.f32),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%multiply.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=1}
  %multiply.2 = f32[]
      multiply(%all-reduce.ar.2, %constant.f32),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%multiply.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Copy the groups of the root's first operand before the pass runs.
  const std::vector<ReplicaGroup> groups_before =
      module->entry_computation()
          ->root_instruction()
          ->operand(0)
          ->replica_groups();
  ArCrsCombiner combiner(2);
  const bool changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  // Both tuple elements are expected to have the same structure.
  const auto expected_element =
      op::AllReduce(op::Multiply(op::Parameter(), op::Constant()));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(expected_element, expected_element));
  const std::vector<ReplicaGroup> groups_after =
      module->entry_computation()
          ->root_instruction()
          ->operand(0)
          ->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
553
// Each chain is all-reduce (all_reduce_id=1) -> convert -> add with a
// constant -> all-reduce. Expects the pass to report a change, both root
// operands to become AllReduce(Add(Divide(Constant, Constant), Convert)), and
// the replica groups of the root's first operand to be unchanged.
TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) {
  const char* hlo_text = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32 = f32[] constant(2)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %add.1 = f32[]
      add(%constant.f32, %convert.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %add.2 = f32[]
      add(%constant.f32, %convert.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%add.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Copy the groups of the root's first operand before the pass runs.
  const std::vector<ReplicaGroup> groups_before =
      module->entry_computation()
          ->root_instruction()
          ->operand(0)
          ->replica_groups();
  ArCrsCombiner combiner(2);
  const bool changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  // Both tuple elements are expected to have the same structure.
  const auto expected_element = op::AllReduce(
      op::Add(op::Divide(op::Constant(), op::Constant()), op::Convert()));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(expected_element, expected_element));
  const std::vector<ReplicaGroup> groups_after =
      module->entry_computation()
          ->root_instruction()
          ->operand(0)
          ->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
637
// Same shape of graph as RewriteArConvertAddCrs, except that the two chains
// add different constants (2 vs. 3). Expects the pass to report no change.
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) {
  const char* hlo_text = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.bf16 = bf16[] constant(1)
  %constant.f32.1 = f32[] constant(2)
  %constant.f32.2 = f32[] constant(3)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %add.1 = f32[]
      add(%constant.f32.1, %convert.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %add.2 = f32[]
      add(%constant.f32.2, %convert.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%add.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  ArCrsCombiner combiner(2);
  const bool changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
}
708
// Each chain feeds the all-reduce with all_reduce_id=1 directly into the
// second all-reduce, with no instruction in between. Expects the pass to
// report a change, both root operands to become AllReduce(Parameter), and the
// replica groups of the root's first operand to be unchanged.
TEST_F(ArCrsCombinerTest, ArThenCrsDontCrash) {
  const char* hlo_text = R"(
HloModule foobar

%sum.1 (a: f32[], b: f32[]) -> f32[] {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(%a, %b)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.1,
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%all-reduce.ar.1),
      replica_groups={{0,1}},
      to_apply=%sum.1,
      sharding={maximal device=0}
  %multiply.1 = f32[]
      multiply(%all-reduce.1, %constant.f32),
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.1,
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%all-reduce.ar.2),
      replica_groups={{0,1}},
      to_apply=%sum.1,
      sharding={maximal device=1}
  %multiply.2 = f32[]
      multiply(%all-reduce.2, %constant.f32),
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Copy the groups of the root's first operand before the pass runs.
  const std::vector<ReplicaGroup> groups_before =
      module->entry_computation()
          ->root_instruction()
          ->operand(0)
          ->replica_groups();
  ArCrsCombiner combiner(2);
  const bool changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  // Both tuple elements are expected to have the same structure.
  const auto expected_element = op::AllReduce(op::Parameter());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(expected_element, expected_element));
  const std::vector<ReplicaGroup> groups_after =
      module->entry_computation()
          ->root_instruction()
          ->operand(0)
          ->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
775
// Each chain is all-reduce (all_reduce_id=1) -> add constant.1 -> add
// constant.2 -> all-reduce. Expects the pass to report a change, both root
// operands to become
//   AllReduce(Add(Divide(Constant, Constant),
//                 Add(Divide(Constant, Constant), Parameter))),
// and the replica groups of the root's first operand to be unchanged.
TEST_F(ArCrsCombinerTest, RewriteMultipleAdds) {
  // The second chain is annotated with device=1 so that it agrees with the
  // root tuple's sharding for element 1 and with the other tests in this file.
  // NOTE(review): it previously said device=0, which looks like a copy-paste
  // slip; confirm the pass result does not depend on these annotations.
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.1 = f32[] constant(1)
  %constant.2 = f32[] constant(2)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %add.11 = f32[]
      add(%constant.1, %all-reduce.ar.1),
      sharding={maximal device=0}
  %add.12 = f32[]
      add(%constant.2, %add.11),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%add.12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=1}
  %add.21 = f32[]
      add(%constant.1, %all-reduce.ar.2),
      sharding={maximal device=1}
  %add.22 = f32[]
      add(%constant.2, %add.21),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%add.22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  auto crs_before =
      module->entry_computation()->root_instruction()->operands()[0];
  // Copied (not referenced) so the pre-pass groups survive the rewrite.
  auto replica_groups_before = crs_before->replica_groups();
  ArCrsCombiner combiner(2);
  auto changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::AllReduce(op::Add(
                            op::Divide(op::Constant(), op::Constant()),
                            op::Add(op::Divide(op::Constant(), op::Constant()),
                                    op::Parameter()))),
                        op::AllReduce(op::Add(
                            op::Divide(op::Constant(), op::Constant()),
                            op::Add(op::Divide(op::Constant(), op::Constant()),
                                    op::Parameter())))));
  auto crs_after =
      module->entry_computation()->root_instruction()->operands()[0];
  auto replica_groups_after = crs_after->replica_groups();
  CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
855
// Each chain is all-reduce (all_reduce_id=1) -> subtract from a constant ->
// all-reduce. Expects the pass to report a change, both root operands to
// become AllReduce(Subtract(Divide(Constant, Constant), Parameter)), and the
// replica groups of the root's first operand to be unchanged.
TEST_F(ArCrsCombinerTest, RewriteArSubtractCrs) {
  const char* hlo_text = R"(
HloModule foobar

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %constant.f32 = f32[] constant(123)

  %all-reduce.ar.1 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=0}
  %sub.1 = f32[]
      subtract(%constant.f32, %all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%sub.1),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum.f32,
      sharding={maximal device=1}
  %sub.2 = f32[]
      subtract(%constant.f32, %all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%sub.2),
      replica_groups={{0,1}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Copy the groups of the root's first operand before the pass runs.
  const std::vector<ReplicaGroup> groups_before =
      module->entry_computation()
          ->root_instruction()
          ->operand(0)
          ->replica_groups();
  ArCrsCombiner combiner(2);
  const bool changed = combiner.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  // Both tuple elements are expected to have the same structure.
  const auto expected_element = op::AllReduce(op::Subtract(
      op::Divide(op::Constant(), op::Constant()), op::Parameter()));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(expected_element, expected_element));
  const std::vector<ReplicaGroup> groups_after =
      module->entry_computation()
          ->root_instruction()
          ->operand(0)
          ->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
926
// Exercises the pass on a graph with two tagged all-reduces per device
// (`all_reduce_id=1` and `all_reduce_id=2`) combined by a left-nested add
// chain, add(add(ar1, const1), ar2), that feeds an all-reduce over replica
// group {0,1}. After the pass runs, each root operand is expected to match
// all-reduce(add(add(parameter, divide(constant, constant)), parameter)), and
// the replica groups of the first root operand must be the same as before the
// pass.
TEST_F(ArCrsCombinerTest, RewriteMultipleARsLeft) {
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)
  %const2 = f32[] constant(2)

  %ar11 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %add11 = f32[]
      add(%ar11, %const1),
      sharding={maximal device=0}
  %ar12 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=2,
      to_apply=%sum,
      sharding={maximal device=0}
  %add12 = f32[]
      add(%add11, %ar12),
      sharding={maximal device=0}
  %crs1 = f32[]
      all-reduce(%add12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %ar21 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=1}
  %add21 = f32[]
      add(%ar21, %const1),
      sharding={maximal device=1}
  %ar22 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=2,
      to_apply=%sum,
      sharding={maximal device=1}
  %add22 = f32[]
      add(%add21, %ar22),
      sharding={maximal device=1}
  %crs2 = f32[]
      all-reduce(%add22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%crs1, %crs2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  auto* entry = module->entry_computation();

  // Taken by value (plain `auto`, no reference) so the comparison below is
  // made against a snapshot captured before the pass touches the graph.
  auto groups_before =
      entry->root_instruction()->operands()[0]->replica_groups();

  ArCrsCombiner combiner(2);
  // Report a failing pass as an ordinary test failure carrying the status
  // message, rather than crashing inside ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);

  // Both tuple elements are expected to have the same rewritten shape.
  const auto rewritten_crs = op::AllReduce(op::Add(
      op::Add(op::Parameter(), op::Divide(op::Constant(), op::Constant())),
      op::Parameter()));
  EXPECT_THAT(entry->root_instruction(),
              op::Tuple(rewritten_crs, rewritten_crs));

  auto groups_after =
      entry->root_instruction()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
1018
// Exercises the pass on a graph with two tagged all-reduces per device
// (`all_reduce_id=1` and `all_reduce_id=2`) combined by a right-nested add
// chain, add(ar1, add(ar2, const1)), that feeds an all-reduce over replica
// group {0,1}. After the pass runs, each root operand is expected to match
// all-reduce(add(parameter, add(parameter, divide(constant, constant)))), and
// the replica groups of the first root operand must be the same as before the
// pass.
TEST_F(ArCrsCombinerTest, RewriteMultipleARsRight) {
  const char* module_str = R"(
HloModule foobar

%sum (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
  %p = f32[] parameter(0)
  %const1 = f32[] constant(1)
  %const2 = f32[] constant(2)

  %ar11 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=0}
  %ar12 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=2,
      to_apply=%sum,
      sharding={maximal device=0}
  %add11 = f32[]
      add(%ar12, %const1),
      sharding={maximal device=0}
  %add12 = f32[]
      add(%ar11, %add11),
      sharding={maximal device=0}
  %crs1 = f32[]
      all-reduce(%add12),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=0}

  %ar21 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=1,
      to_apply=%sum,
      sharding={maximal device=1}
  %ar22 = f32[]
      all-reduce(%p),
      replica_groups={{0},{1}},
      all_reduce_id=2,
      to_apply=%sum,
      sharding={maximal device=1}
  %add21 = f32[]
      add(%ar22, %const1),
      sharding={maximal device=1}
  %add22 = f32[]
      add(%ar21, %add21),
      sharding={maximal device=1}
  %crs2 = f32[]
      all-reduce(%add22),
      replica_groups={{0,1}},
      to_apply=%sum,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%crs1, %crs2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  auto* entry = module->entry_computation();

  // Taken by value (plain `auto`, no reference) so the comparison below is
  // made against a snapshot captured before the pass touches the graph.
  auto groups_before =
      entry->root_instruction()->operands()[0]->replica_groups();

  ArCrsCombiner combiner(2);
  // Report a failing pass as an ordinary test failure carrying the status
  // message, rather than crashing inside ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_TRUE(changed);

  // Both tuple elements are expected to have the same rewritten shape.
  const auto rewritten_crs = op::AllReduce(op::Add(
      op::Parameter(),
      op::Add(op::Parameter(),
              op::Divide(op::Constant(), op::Constant()))));
  EXPECT_THAT(entry->root_instruction(),
              op::Tuple(rewritten_crs, rewritten_crs));

  auto groups_after =
      entry->root_instruction()->operands()[0]->replica_groups();
  CompareReplicaGroups(groups_before, groups_after);
}
1112
// Exercises the pass on a graph in which every all-reduce (both the ones
// tagged `all_reduce_id=1` and the untagged ones that consume their converted
// results) uses the single replica group {{0}}. The pass is expected to
// report that it did not change the module.
TEST_F(ArCrsCombinerTest, OneReplicaDontRewrite) {
  const char* module_str = R"(
HloModule foobar

%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
  %a = bf16[] parameter(0)
  %b = bf16[] parameter(1)
  ROOT %add = bf16[] add(%a, %b)
}

%sum.f32 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
  %p = bf16[] parameter(0)
  %constant.bf16 = bf16[] constant(1)

  %all-reduce.ar.1 = bf16[]
      all-reduce(%p),
      replica_groups={{0}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=0}
  %convert.1 = f32[]
      convert(%all-reduce.ar.1),
      sharding={maximal device=0}
  %all-reduce.1 = f32[]
      all-reduce(%convert.1),
      replica_groups={{0}},
      to_apply=%sum.f32,
      sharding={maximal device=0}

  %all-reduce.ar.2 = bf16[]
      all-reduce(%constant.bf16),
      replica_groups={{0}},
      all_reduce_id=1,
      to_apply=%sum.bf16,
      sharding={maximal device=1}
  %convert.2 = f32[]
      convert(%all-reduce.ar.2),
      sharding={maximal device=1}
  %all-reduce.2 = f32[]
      all-reduce(%convert.2),
      replica_groups={{0}},
      to_apply=%sum.f32,
      sharding={maximal device=1}

  ROOT %tuple = (f32[], f32[])
      tuple(%all-reduce.1, %all-reduce.2),
      sharding={{maximal device=0}, {maximal device=1}}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));

  ArCrsCombiner combiner(2);
  // Report a failing pass as an ordinary test failure carrying the status
  // message, rather than crashing inside ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get()));
  EXPECT_FALSE(changed);
}
1175
1176 } // namespace
1177 } // namespace xla
1178