/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_test_utils.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"

// Tests cross-GPU operations.
//
// This test requires at least four GPUs. For instructions on running this
// within Google, see go/multi-gpu-unit-test.

namespace xla {
namespace {

using ::testing::IsSupersetOf;

class CollectiveOpsTest : public HloTestBase {
 public:
  static void SetUpTestSuite() {
    // Not strictly needed, since this test exercises cross-replica
    // collective-permute, which does not use NCCL. But keep it here for
    // testing.
    tensorflow::setenv("NCCL_LAUNCH_MODE", "PARALLEL", /*overwrite=*/1);
    HloTestBase::SetUpTestSuite();
  }

 protected:
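  // Builds an HLO module containing a single all-reduce by substituting
  // SHAPE, OP, DATATYPE, and REPLICA_GROUPS into the template below. For
  // example, the defaults (op="add", datatype="f32") with an f32[1024] shape
  // and empty replica groups produce:
  //   crs = f32[1024] all-reduce(p2), replica_groups={}, to_apply=apply_op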
  std::unique_ptr<HloModule> MakeCrsModule(
      const Shape& shape, std::vector<std::vector<int64>> replica_groups,
      const HloModuleConfig& config, std::string op = "add",
      std::string datatype = "f32") {
    std::string hlo_template = R"(
      HloModule test

      apply_op {
        x = DATATYPE[] parameter(0)
        y = DATATYPE[] parameter(1)
        ROOT apply_op = DATATYPE[] OP(x, y)
      }

      ENTRY test_computation {
        p = SHAPE parameter(0)
        p2 = SHAPE bitcast(p)
        crs = SHAPE all-reduce(p2), replica_groups=REPLICA_GROUPS, to_apply=apply_op
        copy = SHAPE copy(crs)
        ROOT out = SHAPE bitcast(copy)
      }
    )";
    std::vector<string> replica_group_strs;
    for (const auto& g : replica_groups) {
      replica_group_strs.push_back(
          absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
    }
    std::string shape_str = shape.ToString(/*print_layout=*/false);
    if (shape_str == "f32[1]") {
      // Exercise the scalar codepath.
      // Note: the keys must match the template text before SHAPE/DATATYPE are
      // substituted, otherwise this rewrite never fires.
      hlo_template = absl::StrReplaceAll(
          hlo_template, {{"SHAPE bitcast(p)", "DATATYPE[] bitcast(p)"},
                         {"SHAPE all-reduce", "DATATYPE[] all-reduce"},
                         {"SHAPE copy", "DATATYPE[] copy"}});
    }
    std::string parameterized_hlo = absl::StrReplaceAll(
        hlo_template,
        {{"SHAPE", shape_str},
         {"REPLICA_GROUPS",
          absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "))},
         {"OP", op},
         {"DATATYPE", datatype}});
    return ParseAndReturnVerifiedModule(parameterized_hlo, config).ValueOrDie();
  }

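  // Runs a two-replica all-reduce of `input_value` using reduction `op` and
  // checks that both replicas produce `expected_value`.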
  template <typename LiteralType>
  void TestTwoReplicasOneOperand(std::string op, Literal input_value,
                                 Literal expected_value) {
    const int kNumReplicas = 2;
    std::string dtype = primitive_util::LowercasePrimitiveTypeName(
        primitive_util::NativeToPrimitiveType<LiteralType>());
    auto config = GetModuleConfigForTest();
    config.set_replica_count(kNumReplicas);
    auto module = MakeCrsModule(
        /*shape=*/input_value.shape(),
        /*replica_groups=*/{}, config,
        /*op=*/op, /*datatype=*/dtype);
    TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                            ExecuteReplicated(std::move(module), {&input_value},
                                              /*num_replicas=*/kNumReplicas,
                                              /*use_threads=*/true));
    for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
      EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
          expected_value, results[replica_idx], ErrorSpec{1e-5, 1e-5}));
    }
  }

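  // Exercises add/multiply/maximum/minimum all-reduces over a three-element
  // R1 input on two replicas.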
  template <typename LiteralType>
  void TestAllOpsForReduce() {
    auto cast = [&](int value) { return static_cast<LiteralType>(value); };
    auto to_literal = [&](absl::Span<const LiteralType> values) {
      return LiteralUtil::CreateR1<LiteralType>(values);
    };
    Literal input_value = to_literal({cast(1), cast(2), cast(3)});
    TestTwoReplicasOneOperand<LiteralType>(
        "add",
        /*input_value=*/input_value.Clone(),
        /*expected_value=*/to_literal({cast(2), cast(4), cast(6)}));
    TestTwoReplicasOneOperand<LiteralType>(
        "multiply",
        /*input_value=*/input_value.Clone(),
        /*expected_value=*/to_literal({cast(1), cast(4), cast(9)}));
    TestTwoReplicasOneOperand<LiteralType>(
        "maximum",
        /*input_value=*/input_value.Clone(),
        /*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
    TestTwoReplicasOneOperand<LiteralType>(
        "minimum",
        /*input_value=*/input_value.Clone(),
        /*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
  }
};

// Returns the non-empty subsets of {0, 1, ..., n-1}. For example,
// PowerSetOfIota(3) = {{0}, {1}, {2}, {0,1}, {0,2}, {1,2}, {0,1,2}}.
std::vector<std::vector<int64>> PowerSetOfIota(int64_t n) {
  std::vector<std::vector<int64>> power_set;
  for (int64_t i = 1; i < (1 << n); ++i) {
    power_set.emplace_back();
    for (int64_t j = 0; j < n; ++j) {
      if (i & (1 << j)) {
        power_set.back().push_back(j);
      }
    }
  }
  return power_set;
}

// Makes a DeviceAssignment assigning replica-id i to devices[i].
DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {
  DeviceAssignment assn(/*replica_count=*/devices.size(),
                        /*computation_count=*/1);
  for (int64_t i = 0; i < devices.size(); ++i) {
    assn(i, 0) = devices[i];
  }
  return assn;
}

// Shorter alias for this function.
absl::flat_hash_set<GlobalDeviceId> OpenNcclChannels() {
  return gpu::DevicesWithOpenNcclChannels();
}

template <typename T>
static Eigen::half ToHalf(T value) {
  return static_cast<Eigen::half>(value);
}

XLA_TEST_F(CollectiveOpsTest, AllReduce_sum_float32_2D) {
  TestTwoReplicasOneOperand<float>(
      "add",
      /*input_value=*/LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
      /*expected_value=*/LiteralUtil::CreateR2<float>({{2, 4}, {6, 8}}));
}

XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) {
  TestTwoReplicasOneOperand<float>(
      "add",
      /*input_value=*/LiteralUtil::CreateR1<float>({1}),
      /*expected_value=*/LiteralUtil::CreateR1<float>({2}));
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
  TestAllOpsForReduce<int8>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) {
  TestAllOpsForReduce<uint8>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) {
  TestAllOpsForReduce<uint32>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) {
  TestAllOpsForReduce<int32>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) {
  TestAllOpsForReduce<int64>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) {
  TestAllOpsForReduce<uint64>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) {
  TestAllOpsForReduce<float>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) {
  TestAllOpsForReduce<double>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
  TestAllOpsForReduce<Eigen::half>();
}

XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllReduceTwoReplicasOneOperand_bfloat16)) {
  TestAllOpsForReduce<bfloat16>();
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_sum_complex64)) {
  TestTwoReplicasOneOperand<complex64>(
      "add",
      /*input_value=*/LiteralUtil::CreateR1<complex64>({{1, 2}, {3, 4}}),
      /*expected_value=*/LiteralUtil::CreateR1<complex64>({{2, 4}, {6, 8}}));
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_sum_complex128)) {
  TestTwoReplicasOneOperand<complex128>(
      "add",
      /*input_value=*/LiteralUtil::CreateR1<complex128>({{1, 2}, {3, 4}}),
      /*expected_value=*/LiteralUtil::CreateR1<complex128>({{2, 4}, {6, 8}}));
}

XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
  // Test with equal elements.
  TestTwoReplicasOneOperand<bool>(
      "and",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Test with {true, false}.
  const char* hlo_module = R"(
    HloModule test

    apply_op {
      x = pred[] parameter(0)
      y = pred[] parameter(1)
      ROOT apply_op = pred[] and(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      c = u32[] constant(0)
      p = pred[] compare(id, c), direction=EQ
      p2 = pred[1] bitcast(p)
      crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
      copy = pred[1] copy(crs)
      ROOT out = pred[1] bitcast(copy)
    }
  )";

  auto config = GetModuleConfigForTest();
  config.set_replica_count(2);
  auto module = ParseAndReturnVerifiedModule(hlo_module, config).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {},
                                            /*num_replicas=*/2,
                                            /*use_threads=*/true));
  for (int replica_idx = 0; replica_idx < 2; replica_idx++) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<bool>({false}),
                                       results[replica_idx]));
  }
}

XLA_TEST_F(CollectiveOpsTest, AllReduceOr_Pred) {
  // Test with equal elements.
  TestTwoReplicasOneOperand<bool>(
      "or",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Test with {true, false}.
  const char* hlo_module = R"(
    HloModule test

    apply_op {
      x = pred[] parameter(0)
      y = pred[] parameter(1)
      ROOT apply_op = pred[] or(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      c = u32[] constant(0)
      p = pred[] compare(id, c), direction=EQ
      p2 = pred[1] bitcast(p)
      crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
      copy = pred[1] copy(crs)
      ROOT out = pred[1] bitcast(copy)
    }
  )";

  auto config = GetModuleConfigForTest();
  config.set_replica_count(2);
  auto module = ParseAndReturnVerifiedModule(hlo_module, config).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {},
                                            /*num_replicas=*/2,
                                            /*use_threads=*/true));
  for (int replica_idx = 0; replica_idx < 2; replica_idx++) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<bool>({true}),
                                       results[replica_idx]));
  }
}

// Tries all-reduce operations across all 2^kNumDevices - 1 combinations of
// devices in sequence.
XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) {
  const int64_t kNumDevices = 4;
  const int64_t kNumElems = 1024;

  for (std::vector<int64> devices : PowerSetOfIota(kNumDevices)) {
    SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
                                 absl::StrJoin(devices, ", ")));

    DeviceAssignment device_assn = MakeDeviceAssn(devices);

    auto config = GetModuleConfigForTest();
    config.set_replica_count(devices.size());
    config.set_static_device_assignment(device_assn);

    std::vector<float> input_vec(kNumElems);
    absl::c_iota(input_vec, 0);
    auto input_literal = LiteralUtil::CreateR1<float>(input_vec);

    auto module = MakeCrsModule(input_literal.shape(),
                                /*replica_groups=*/{}, config);

    TF_ASSERT_OK_AND_ASSIGN(
        std::vector<Literal> results,
        ExecuteReplicated(std::move(module), {&input_literal},
                          /*num_replicas=*/devices.size(), &device_assn,
                          /*run_hlo_passes=*/true, /*use_threads=*/true));
  }
}

// Check that the NCCL data structures in our all-reduce implementation are
// cached as we expect.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_NcclChannelCaching)) {
  const int64_t kNumElems = 1024;

  std::vector<float> input_vec(kNumElems);
  absl::c_iota(input_vec, 0);
  auto input_literal = LiteralUtil::CreateR1<float>(input_vec);

  // Create three Executables, touching devices {0,1}, {1,2}, and {0,1,2}.
  struct ExecutableInfo {
    std::unique_ptr<Executable> executable;
    DeviceAssignment device_assn;
    HloRunner::ReplicatedExecuteOptions opts;
  };
  std::vector<ExecutableInfo> executables;
  for (const auto& devices :
       std::vector<std::vector<int64>>{{0, 1}, {1, 2}, {0, 1, 2}}) {
    executables.emplace_back();
    auto& e = executables.back();

    e.device_assn = MakeDeviceAssn(devices);

    auto config = GetModuleConfigForTest();
    config.set_replica_count(devices.size());
    config.set_static_device_assignment(e.device_assn);
    auto module = MakeCrsModule(input_literal.shape(),
                                /*replica_groups=*/{}, config);
    e.executable =
        test_runner_
            .CreateExecutable(std::move(module), /*run_hlo_passes=*/true)
            .ValueOrDie();

    e.opts.num_replicas = devices.size();
    e.opts.use_threads = true;
    e.opts.arguments.push_back(&input_literal);
  }

  auto run_executable = [&](int64_t i) {
    auto& e = executables[i];
    TF_ASSERT_OK(
        test_runner_
            .ExecuteReplicated(e.executable.get(), e.opts, &e.device_assn)
            .status());
  };

  // Run the executables and check that channels are opened as we expect.
  run_executable(0);
  EXPECT_THAT(OpenNcclChannels(), IsSupersetOf({0, 1}));

  run_executable(2);
  EXPECT_THAT(OpenNcclChannels(), IsSupersetOf({0, 1, 2}));

  run_executable(1);
  EXPECT_THAT(OpenNcclChannels(), IsSupersetOf({0, 1, 2}));

  // Tear down the executables and check that channels are closed as we expect.
  // Note that after we tear down an executable *all* the NCCL channels may go
  // away, so we rerun all of the executables that haven't been torn down.
  executables[2].executable.reset();
  run_executable(0);
  run_executable(1);
  EXPECT_THAT(OpenNcclChannels(), IsSupersetOf({0, 1, 2}));

  executables[0].executable.reset();
  run_executable(1);
  EXPECT_THAT(OpenNcclChannels(), IsSupersetOf({1, 2}));

  executables[1].executable.reset();
}

// Runs the same executable many times concurrently. The all-reduces should not
// conflict with one another.
XLA_TEST_F(CollectiveOpsTest, AllReduce_ManyConcurrentAllReduces) {
  const int64_t kNumElems = 1024;
  const int64_t kNumThreads = 200;
  const int64_t kRunsPerThread = 10;

  std::vector<float> input_vec(kNumElems);
  absl::c_iota(input_vec, 0);
  auto input_literal = LiteralUtil::CreateR1<float>(input_vec);

  auto config = GetModuleConfigForTest();
  config.set_replica_count(2);
  auto executable =
      test_runner_
          .CreateExecutable(MakeCrsModule(input_literal.shape(),
                                          /*replica_groups=*/{}, config),
                            /*run_hlo_passes=*/true)
          .ValueOrDie();
  std::vector<int64> devices = {0, 1};
  auto device_assn = MakeDeviceAssn(devices);

  HloRunner::ReplicatedExecuteOptions opts;
  opts.num_replicas = devices.size();
  opts.use_threads = true;
  opts.arguments.push_back(&input_literal);

  tensorflow::BlockingCounter done(kNumThreads * kRunsPerThread);
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), TestName(),
                                      kNumThreads);
  for (int64_t i = 0; i < kNumThreads * kRunsPerThread; ++i) {
    pool.Schedule([&] {
      TF_ASSERT_OK(
          test_runner_.ExecuteReplicated(executable.get(), opts, &device_assn)
              .status());
      done.DecrementCount();
    });
  }
  done.Wait();
}

// Runs a module containing two independent all-reduces, which the compiler is
// free to combine into one; both outputs should be correct on every replica.
XLA_TEST_F(CollectiveOpsTest, AllReduce_CombinableAllReduces) {
  std::string hlo_string = R"(
    HloModule test

    apply_op {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT apply_op = f32[] add(x, y)
    }

    ENTRY test_computation {
      p0 = f32[5] parameter(0)
      p1 = f32[5] parameter(1)
      crs0 = f32[5] all-reduce(p0), replica_groups={}, to_apply=apply_op
      crs1 = f32[5] all-reduce(p1), replica_groups={}, to_apply=apply_op
      ROOT out = (f32[5], f32[5]) tuple(f32[5] crs0, f32[5] crs1)
    }
  )";
  static constexpr int kNumReplicas = 2;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string, config));

  std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
  auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
  std::vector<float> input1_vec = {7., 3., 4., 1., 2.};
  auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
                        /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));
  std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
  auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
  std::vector<float> expected1_vec = {14., 6., 8., 2., 4.};
  auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    auto rs = results[replica_idx].DecomposeTuple();
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
                                             ErrorSpec{1e-5, 1e-5}));
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
                                             ErrorSpec{1e-5, 1e-5}));
  }
}

// Runs an all-reduce with three partitions:
//   {0}, {1,2}, {3}
// meaning, the all-reduce is a nop for devices 0 and 3, and only devices 1 and
// 2 actually exchange data with each other.
XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) {
  // Test a prime number so it's not all powers of 2.
  const int64_t kNumElems = 137;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(4);
  std::vector<float> input_vec(kNumElems);
  absl::c_iota(input_vec, 0);
  auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
  auto module = MakeCrsModule(
      /*shape=*/input_literal.shape(),
      /*replica_groups=*/{{0}, {1, 2}, {3}}, config);

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&input_literal}, /*num_replicas=*/4,
                        /*use_threads=*/true));

  ASSERT_EQ(results.size(), 4);

  std::vector<float> input_vec_doubled;
  for (float n : input_vec) {
    input_vec_doubled.push_back(n * 2);
  }
  auto input_literal_doubled = LiteralUtil::CreateR1<float>(input_vec_doubled);

  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[0]));
  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[1]));
  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[2]));
  EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[3]));
}

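// All replica groups are singletons, so the all-reduce reduces over a single
// value and each replica simply gets back its own replica-id.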
XLA_TEST_F(CollectiveOpsTest, AllReduce_Degenerate) {
  const char* const kModuleStr = R"(
    HloModule test

    apply_op {
      x = u32[] parameter(0)
      y = u32[] parameter(1)
      ROOT apply_op = u32[] add(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      ROOT crs = u32[] all-reduce(id), replica_groups={{0},{1},{2},{3}}, to_apply=apply_op
    }
  )";
  static constexpr int kNumReplicas = 4;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  for (int i = 0; i < kNumReplicas; ++i) {
    LiteralTestUtil::ExpectR0Equal<uint32_t>(i, results[i]);
  }
}

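// Exercises the asynchronous all-reduce-start/all-reduce-done pair. Each
// replica contributes its replica-id, so every replica should see
// 0 + 1 + 2 + 3 = 6.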
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduce)) {
  const absl::string_view kModuleStr = R"(
    HloModule test

    apply_op {
      x = u32[] parameter(0)
      y = u32[] parameter(1)
      ROOT apply_op = u32[] add(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      start = (u32[], u32[]) all-reduce-start(id), to_apply=apply_op
      ROOT done = u32[] all-reduce-done(start)
    }
  )";
  static constexpr int kNumReplicas = 4;
  HloModuleConfig config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  uint32_t expected = 6;  // sum [0,4)
  for (int i = 0; i < kNumReplicas; ++i) {
    LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, results[i]);
  }
}

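// Same as AsyncAllReduce, but with a two-operand (tuple-shaped) async
// all-reduce: operand 0 carries replica-id and operand 1 its square.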
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduceTwoOperands)) {
  const absl::string_view kModuleStr = R"(
    HloModule test

    apply_op {
      x = u32[] parameter(0)
      y = u32[] parameter(1)
      ROOT apply_op = u32[] add(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[] multiply(id, id)
      start = ((u32[], u32[]), (u32[], u32[])) all-reduce-start(id, id2), to_apply=apply_op
      ROOT done = (u32[], u32[]) all-reduce-done(start)
    }
  )";
  static constexpr int kNumReplicas = 4;
  HloModuleConfig config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  uint32_t expected0 = 6;   // sum [0,4)
  uint32_t expected1 = 14;  // sum squares [0,4)
  for (int i = 0; i < kNumReplicas; ++i) {
    std::vector<Literal> replica_results = results[i].DecomposeTuple();
    LiteralTestUtil::ExpectR0Equal<uint32_t>(expected0, replica_results[0]);
    LiteralTestUtil::ExpectR0Equal<uint32_t>(expected1, replica_results[1]);
  }
}

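// replica-id() should return a distinct id on each replica.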
XLA_TEST_F(CollectiveOpsTest, ReplicaId) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      ROOT out = u32[] copy(id)
    }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  for (uint32 i = 0; i < kNumReplicas; ++i) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR0(i), results[i]));
  }
}

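// Each replica broadcasts (replica-id + 10) and the collective-permute swaps
// replicas 0 and 1 while 2 sends to itself. Replica 3 has no source, so it
// receives zeros.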
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      replica = u32[] replica-id()
      ten = u32[] constant(10)
      sum = u32[] add(replica, ten)
      p = u32[2] broadcast(sum), dimensions={}
      permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}, {2,2}}
      ROOT copy = u32[2] copy(permute)
    }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({11, 11}),
                                     results[0]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({10, 10}),
                                     results[1]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({12, 12}),
                                     results[2]));
  // Nothing writes to replica 3, so it is memzero'ed.
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({0, 0}),
                                     results[3]));
}

XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Degenerate) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      replica = u32[] replica-id()
      ten = u32[] constant(10)
      sum = u32[] add(replica, ten)
      p = u32[2] broadcast(sum), dimensions={}
      permute = u32[2] collective-permute(p), source_target_pairs={{0,0}, {1,1}, {2,2}, {3,3}}
      ROOT copy = u32[2] copy(permute)
    }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({10, 10}),
                                     results[0]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({11, 11}),
                                     results[1]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({12, 12}),
                                     results[2]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({13, 13}),
                                     results[3]));
}

XLA_TEST_F(CollectiveOpsTest, CollectivePermute_NoDegenerate) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      replica = u32[] replica-id()
      ten = u32[] constant(10)
      sum = u32[] add(replica, ten)
      p = u32[2] broadcast(sum), dimensions={}
      permute = u32[2] collective-permute(p), source_target_pairs={{0,0}, {1,1}, {2,2}}
      ROOT copy = u32[2] copy(permute)
    }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({10, 10}),
                                     results[0]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({11, 11}),
                                     results[1]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({12, 12}),
                                     results[2]));
  // Nothing writes to replica 3, so it is memzero'ed.
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({0, 0}),
                                     results[3]));
}

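// Rotates the data one replica to the "right": replica i receives the value
// produced by replica (i - 1) mod 4.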
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Rotate) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      replica = u32[] replica-id()
      ten = u32[] constant(10)
      sum = u32[] add(replica, ten)
      p = u32[2] broadcast(sum), dimensions={}
      permute = u32[2] collective-permute(p), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
      ROOT copy = u32[2] copy(permute)
    }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({13, 13}),
                                     results[0]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({10, 10}),
                                     results[1]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({11, 11}),
                                     results[2]));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({12, 12}),
                                     results[3]));
}

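// All-to-all across all four replicas: output j on replica r is the r-th
// operand contributed by replica j, i.e. replica r gathers everyone's r-th
// operand.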
XLA_TEST_F(CollectiveOpsTest, AllToAll_EmptyReplicaGroups) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[2] broadcast(id), dimensions={}
      a0 = u32[2] constant({10, 15})
      b0 = u32[2] constant({20, 25})
      c0 = u32[2] constant({30, 35})
      d0 = u32[2] constant({40, 45})
      a1 = u32[2] add(id2, a0)
      b1 = u32[2] add(id2, b0)
      c1 = u32[2] add(id2, c0)
      d1 = u32[2] add(id2, d0)
      all2all = (u32[2], u32[2], u32[2], u32[2]) all-to-all(a1, b1, c1, d1), replica_groups={}
      a_prime = u32[2] get-tuple-element(all2all), index=0
      b_prime = u32[2] get-tuple-element(all2all), index=1
      c_prime = u32[2] get-tuple-element(all2all), index=2
      d_prime = u32[2] get-tuple-element(all2all), index=3
      ROOT out = u32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  LiteralTestUtil::ExpectR1Equal<uint32>({10, 15, 11, 16, 12, 17, 13, 18},
                                         results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32>({20, 25, 21, 26, 22, 27, 23, 28},
                                         results[1]);
  LiteralTestUtil::ExpectR1Equal<uint32>({30, 35, 31, 36, 32, 37, 33, 38},
                                         results[2]);
  LiteralTestUtil::ExpectR1Equal<uint32>({40, 45, 41, 46, 42, 47, 43, 48},
                                         results[3]);
}

XLA_TEST_F(CollectiveOpsTest, AllToAll_OrderedReplicaGroups) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[2] broadcast(id), dimensions={}
      a0 = u32[2] constant({10, 15})
      b0 = u32[2] constant({20, 25})
      c0 = u32[2] constant({30, 35})
      d0 = u32[2] constant({40, 45})
      a1 = u32[2] add(id2, a0)
      b1 = u32[2] add(id2, b0)
      c1 = u32[2] add(id2, c0)
      d1 = u32[2] add(id2, d0)
      all2all = (u32[2], u32[2], u32[2], u32[2]) all-to-all(a1, b1, c1, d1), replica_groups={{3,2,1,0}}
      a_prime = u32[2] get-tuple-element(all2all), index=0
      b_prime = u32[2] get-tuple-element(all2all), index=1
      c_prime = u32[2] get-tuple-element(all2all), index=2
      d_prime = u32[2] get-tuple-element(all2all), index=3
      ROOT out = u32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  LiteralTestUtil::ExpectR1Equal<uint32>({43, 48, 42, 47, 41, 46, 40, 45},
                                         results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32>({33, 38, 32, 37, 31, 36, 30, 35},
                                         results[1]);
  LiteralTestUtil::ExpectR1Equal<uint32>({23, 28, 22, 27, 21, 26, 20, 25},
                                         results[2]);
  LiteralTestUtil::ExpectR1Equal<uint32>({13, 18, 12, 17, 11, 16, 10, 15},
                                         results[3]);
}

XLA_TEST_F(CollectiveOpsTest, AllToAll_TwoReplicaGroups) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[2] broadcast(id), dimensions={}
      a0 = u32[2] constant({10, 15})
      b0 = u32[2] constant({20, 25})
      a1 = u32[2] add(id2, a0)
      b1 = u32[2] add(id2, b0)
      all2all = (u32[2], u32[2]) all-to-all(a1, b1), replica_groups={{2,1},{3,0}}
      a_prime = u32[2] get-tuple-element(all2all), index=0
      b_prime = u32[2] get-tuple-element(all2all), index=1
      ROOT out = u32[4] concatenate(a_prime, b_prime), dimensions={0}
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  LiteralTestUtil::ExpectR1Equal<uint32>({23, 28, 20, 25}, results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32>({22, 27, 21, 26}, results[1]);
  LiteralTestUtil::ExpectR1Equal<uint32>({12, 17, 11, 16}, results[2]);
  LiteralTestUtil::ExpectR1Equal<uint32>({13, 18, 10, 15}, results[3]);
}

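// Array variant of all-to-all: each replica's [4,2] operand is split into four
// [1,2] pieces along dimension 0, exchanged across the replica group, and
// concatenated back along the same dimension.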
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_SplitDimension)) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[4, 2] broadcast(id), dimensions={}
      a0 = u32[4, 2] constant({{10, 15}, {20, 25}, {30, 35}, {40, 45}})
      a1 = u32[4, 2] add(id2, a0)
      all2all = u32[4, 2] all-to-all(a1), replica_groups={{0,1,2,3}}, dimensions={0}
      ROOT out = u32[8] reshape(all2all)
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  LiteralTestUtil::ExpectR1Equal<uint32>({10, 15, 11, 16, 12, 17, 13, 18},
                                         results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32>({20, 25, 21, 26, 22, 27, 23, 28},
                                         results[1]);
  LiteralTestUtil::ExpectR1Equal<uint32>({30, 35, 31, 36, 32, 37, 33, 38},
                                         results[2]);
  LiteralTestUtil::ExpectR1Equal<uint32>({40, 45, 41, 46, 42, 47, 43, 48},
                                         results[3]);
}

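// Each replica contributes a distinct 1x2 row; all-gather along dimension 0
// stacks the rows in replica order, so every replica ends up with the same
// 4x2 result.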
XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[1, 2] broadcast(id), dimensions={}
      a0 = u32[1, 2] constant({{10, 15}})
      a1 = u32[1, 2] add(id2, a0)
      allgather = u32[4, 2] all-gather(a1), dimensions={0}
      ROOT out = u32[8] reshape(allgather)
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  for (const Literal& result : results) {
    LiteralTestUtil::ExpectR1Equal<uint32>({10, 15, 11, 16, 12, 17, 13, 18},
                                           result);
  }
}

XLA_TEST_F(CollectiveOpsTest, AllGather_Dim1) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      id2 = u32[2, 1] broadcast(id), dimensions={}
      a0 = u32[2, 1] constant({{10}, {15}})
      a1 = u32[2, 1] add(id2, a0)
      allgather = u32[2, 4] all-gather(a1), dimensions={1}
      ROOT out = u32[8] reshape(allgather)
    }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  for (const Literal& result : results) {
    LiteralTestUtil::ExpectR1Equal<uint32>({10, 11, 12, 13, 15, 16, 17, 18},
                                           result);
  }
}

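// Tuple-shaped all-reduce: reduces two operands of different shapes in a
// single all-reduce op.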
XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
  std::string hlo_string = R"(
    HloModule test

    apply_op {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT apply_op = f32[] add(x, y)
    }

    ENTRY test_computation {
      p0 = f32[5] parameter(0)
      p1 = f32[7] parameter(1)
      ROOT out = (f32[5], f32[7]) all-reduce(p0, p1), replica_groups={}, to_apply=apply_op
    }
  )";
  static constexpr int kNumReplicas = 2;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string, config));

  std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
  auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
  std::vector<float> input1_vec = {
      7., 3., 4., 1., 2., 3., 4.,
  };
  auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
                        /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));
  std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
  auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
  std::vector<float> expected1_vec = {14., 6., 8., 2., 4., 6., 8.};
  auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    auto rs = results[replica_idx].DecomposeTuple();
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
                                             ErrorSpec{1e-5, 1e-5}));
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
                                             ErrorSpec{1e-5, 1e-5}));
  }
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGatherMixedTypes)) {
  const char* const kModuleStr = R"(
    HloModule test
    ENTRY test_computation {
      id = u32[] replica-id()
      p0 = u32[2, 1] broadcast(id), dimensions={}
      p1 = f32[2, 1] convert(p0)
      allgather = (u32[2, 2], f32[2, 2]) all-gather(p0, p1), dimensions={1}
      ag0 = u32[2, 2] get-tuple-element(allgather), index=0
      ag1 = f32[2, 2] get-tuple-element(allgather), index=1
      r0 = u32[4] reshape(ag0)
      r1 = f32[4] reshape(ag1)
      ROOT out = (u32[4], f32[4]) tuple(r0, r1)
    }
  )";
  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    auto rs = results[replica_idx].DecomposeTuple();
    LiteralTestUtil::ExpectR1Equal<uint32>({0, 1, 0, 1}, rs[0]);
    LiteralTestUtil::ExpectR1Near<float>({0.0, 1.0, 0.0, 1.0}, rs[1],
                                         ErrorSpec{1e-5, 1e-5});
  }
}

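// reduce-scatter: the 8-element inputs are all-reduced to
// {11, 13, 15, 17, 19, 21, 23, 25}, then scattered so replica 0 keeps elements
// 0-3 and replica 1 keeps elements 4-7.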
XLA_TEST_F(CollectiveOpsTest, ReduceScatter) {
  const char* const kModuleStr = R"(
    HloModule test
    add {
      lhs = u32[] parameter(0)
      rhs = u32[] parameter(1)
      ROOT add = u32[] add(lhs, rhs)
    }

    ENTRY main {
      c0 = u32[8] constant({1, 2, 3, 4, 5, 6, 7, 8})
      c1 = u32[8] constant({10, 11, 12, 13, 14, 15, 16, 17})
      zero = u32[] constant(0)
      id = u32[] replica-id()
      p = pred[] compare(id, zero), direction=EQ
      pb = pred[8] broadcast(p), dimensions={}
      // data = c0 for replica 0 and c1 for replica 1
      data = u32[8] select(pb, c0, c1)
      ROOT ars = u32[4] reduce-scatter(data), replica_groups={},
                        dimensions={0}, to_apply=add
    }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  LiteralTestUtil::ExpectR1Equal<uint32>({11, 13, 15, 17}, results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32>({19, 21, 23, 25}, results[1]);
}

XLA_TEST_F(CollectiveOpsTest, ReduceScatter_Dim1) {
  const char* const kModuleStr = R"(
    HloModule test
    add {
      lhs = u32[] parameter(0)
      rhs = u32[] parameter(1)
      ROOT add = u32[] add(lhs, rhs)
    }

    ENTRY main {
      c0 = u32[2, 4] constant({{ 1, 2, 3, 4}, { 5, 6, 7, 8}})
      c1 = u32[2, 4] constant({{10, 11, 12, 13}, {14, 15, 16, 17}})
      zero = u32[] constant(0)
      id = u32[] replica-id()
      p = pred[] compare(id, zero), direction=EQ
      pb = pred[2, 4] broadcast(p), dimensions={}
      // data = c0 for replica 0 and c1 for replica 1
      data = u32[2, 4] select(pb, c0, c1)
      // all-reduce result = {{11, 13, 15, 17}, {19, 21, 23, 25}}
      ars = u32[2, 2] reduce-scatter(data), replica_groups={},
                      dimensions={1}, to_apply=add
      ROOT r = u32[4] reshape(ars)
    }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  LiteralTestUtil::ExpectR1Equal<uint32>({11, 13, 19, 21}, results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32>({15, 17, 23, 25}, results[1]);
}

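// With HLO passes enabled, add(all-reduce(x), all-reduce(y)) may be
// reassociated into a single all-reduce(add(x, y)); either way every replica
// should get the same, correct sum.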
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduceReassociate)) {
  const char* const kModuleStr = R"(
    HloModule m
    sum {
      a = f32[] parameter(0)
      b = f32[] parameter(1)
      ROOT add.2 = f32[] add(a, b)
    }

    ENTRY main {
      c0 = f32[8] constant({ 1, 2, 3, 4, 5, 6, 7, 8})
      c1 = f32[8] constant({ 11, 12, 13, 14, 15, 16, 17, 18})
      c2 = f32[8] constant({ 2, 3, 4, 5, 6, 7, 8, 9})
      c3 = f32[8] constant({ 12, 13, 14, 15, 16, 17, 18, 19})
      zero = u32[] constant(0)
      id = u32[] replica-id()
      p = pred[] compare(id, zero), direction=EQ
      pb = pred[8] broadcast(p), dimensions={}
      // data0 = c0 for replica 0 and c1 for replica 1
      data0 = f32[8] select(pb, c0, c1)
      // data1 = c2 for replica 0 and c3 for replica 1
      data1 = f32[8] select(pb, c2, c3)

      ar0 = f32[8] all-reduce(data0), replica_groups={}, to_apply=sum
      ar1 = f32[8] all-reduce(data1), replica_groups={}, to_apply=sum
      ROOT add = f32[8] add(ar0, ar1)
    }
  )";
  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));

  const ErrorSpec es{1e-5, 1e-5};
  EXPECT_TRUE(LiteralTestUtil::NearOrEqual(results[0], results[1], es));
  LiteralTestUtil::ExpectR1Near<float>(
      {26.0, 30.0, 34.0, 38.0, 42.0, 46.0, 50.0, 54.0}, results[0], es);
}

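// All-gathers the result of a broadcast along dimension 2, which maps back to
// a dimension of the broadcast operand, so the gathered values vary along
// that dimension (the "non-uniform" case for the all-gather/broadcast
// reorder rewrite).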
XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllGatherBroadcastReorder_NonUniform)) {
  const char* const kModuleStr = R"(
    HloModule m

    ENTRY main {
      c0 = u32[2, 3] constant({{ 1, 2, 3}, { 4, 5, 6}})
      c1 = u32[2, 3] constant({{10, 11, 12}, {13, 14, 15}})
      zero = u32[] constant(0)
      id = u32[] replica-id()
      p = pred[] compare(id, zero), direction=EQ
      pb = pred[2, 3] broadcast(p), dimensions={}
      // data = c0 for replica 0 and c1 for replica 1
      data = u32[2, 3] select(pb, c0, c1)
      bc = u32[2, 4, 3] broadcast(data), dimensions={0, 2}
      ROOT ag = u32[2, 4, 6] all-gather(bc), dimensions={2}, replica_groups={{0, 1}}
    }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));

  EXPECT_TRUE(LiteralTestUtil::Equal(results[0], results[1]));
  LiteralTestUtil::ExpectR3Equal<uint32_t>({{{1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12}},
                                            {{4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15}}},
                                           results[0]);
}

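// Same as above, but gathering along dimension 1, which was created by the
// broadcast (the "uniform" case: the gathered values are identical within
// each replica's contribution).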
XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllGatherBroadcastReorder_Uniform)) {
  const char* const kModuleStr = R"(
    HloModule m

    ENTRY main {
      c0 = u32[2, 3] constant({{ 1, 2, 3}, { 4, 5, 6}})
      c1 = u32[2, 3] constant({{10, 11, 12}, {13, 14, 15}})
      zero = u32[] constant(0)
      id = u32[] replica-id()
      p = pred[] compare(id, zero), direction=EQ
      pb = pred[2, 3] broadcast(p), dimensions={}
      // data = c0 for replica 0 and c1 for replica 1
      data = u32[2, 3] select(pb, c0, c1)
      bc = u32[2, 4, 3] broadcast(data), dimensions={0, 2}
      ROOT ag = u32[2, 8, 3] all-gather(bc), dimensions={1}, replica_groups={{0, 1}}
    }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  EXPECT_TRUE(LiteralTestUtil::Equal(results[0], results[1]));
  LiteralTestUtil::ExpectR3Equal<uint32_t>({{{1, 2, 3},
                                             {1, 2, 3},
                                             {1, 2, 3},
                                             {1, 2, 3},
                                             {10, 11, 12},
                                             {10, 11, 12},
                                             {10, 11, 12},
                                             {10, 11, 12}},
                                            {{4, 5, 6},
                                             {4, 5, 6},
                                             {4, 5, 6},
                                             {4, 5, 6},
                                             {13, 14, 15},
                                             {13, 14, 15},
                                             {13, 14, 15},
                                             {13, 14, 15}}},
                                           results[0]);
}

}  // namespace
}  // namespace xla