• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/strings/str_replace.h"
17 #include "absl/types/span.h"
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/primitive_util.h"
20 #include "tensorflow/compiler/xla/service/gpu/nccl_test_utils.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/test_helpers.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
26 #include "tensorflow/compiler/xla/tests/test_macros.h"
27 #include "tensorflow/core/lib/core/blocking_counter.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/core/threadpool.h"
30 #include "tensorflow/core/platform/env.h"
31 
32 // Tests cross-GPU operations.
33 //
34 // This test requires at least four GPUs.  For instructions on running this
35 // within Google, see go/multi-gpu-unit-test.
36 
37 namespace xla {
38 namespace {
39 
40 using ::testing::IsSupersetOf;
41 
42 class CollectiveOpsTest : public HloTestBase {
43  public:
SetUpTestSuite()44   static void SetUpTestSuite() {
45     // Not needed structly, since this test exercises cross replica collective
46     // permute which does not use NCCL. But keeping it here for testing.
47     tensorflow::setenv("NCCL_LAUNCH_MODE", "PARALLEL", /*overwrite=*/1);
48     HloTestBase::SetUpTestSuite();
49   }
50 
51  protected:
MakeCrsModule(const Shape & shape,std::vector<std::vector<int64>> replica_groups,const HloModuleConfig & config,std::string op="add",std::string datatype="f32")52   std::unique_ptr<HloModule> MakeCrsModule(
53       const Shape& shape, std::vector<std::vector<int64>> replica_groups,
54       const HloModuleConfig& config, std::string op = "add",
55       std::string datatype = "f32") {
56     std::string hlo_template = R"(
57       HloModule test
58 
59       apply_op {
60         x = DATATYPE[] parameter(0)
61         y = DATATYPE[] parameter(1)
62         ROOT apply_op = DATATYPE[] OP(x, y)
63       }
64 
65       ENTRY test_computation {
66         p = SHAPE parameter(0)
67         p2 = SHAPE bitcast(p)
68         crs = SHAPE all-reduce(p2), replica_groups=REPLICA_GROUPS, to_apply=apply_op
69         copy = SHAPE copy(crs)
70         ROOT out = SHAPE bitcast(copy)
71       }
72     )";
73     std::vector<string> replica_group_strs;
74     for (const auto& g : replica_groups) {
75       replica_group_strs.push_back(
76           absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
77     }
78     std::string shape_str = shape.ToString(/*print_layout=*/false);
79     if (shape_str == "f32[1]") {
80       // Exercise the scalar codepath.
81       hlo_template = absl::StrReplaceAll(
82           hlo_template,
83           {{"DATATYPE[SHAPE] bitcast(p)", "DATATYPE[] bitcast(p)"},
84            {"DATATYPE[SHAPE] all-reduce", "DATATYPE[] all-reduce"},
85            {"DATATYPE[SHAPE] copy", "DATATYPE[] copy"}});
86     }
87     std::string parameterized_hlo = absl::StrReplaceAll(
88         hlo_template,
89         {{"SHAPE", shape_str},
90          {"REPLICA_GROUPS",
91           absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "))},
92          {"OP", op},
93          {"DATATYPE", datatype}});
94     return ParseAndReturnVerifiedModule(parameterized_hlo, config).ValueOrDie();
95   }
96 
97   template <typename LiteralType>
TestTwoReplicasOneOperand(std::string op,Literal input_value,Literal expected_value)98   void TestTwoReplicasOneOperand(std::string op, Literal input_value,
99                                  Literal expected_value) {
100     const int kNumReplicas = 2;
101     std::string dtype = primitive_util::LowercasePrimitiveTypeName(
102         primitive_util::NativeToPrimitiveType<LiteralType>());
103     auto config = GetModuleConfigForTest();
104     config.set_replica_count(kNumReplicas);
105     auto module = MakeCrsModule(
106         /*shape_str=*/input_value.shape(),
107         /*replica_groups=*/{}, config,
108         /*op=*/op, /*datatype=*/dtype);
109     TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
110                             ExecuteReplicated(std::move(module), {&input_value},
111                                               /*num_replicas=*/kNumReplicas,
112                                               /*use_threads=*/true));
113     for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
114       EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
115           expected_value, results[replica_idx], ErrorSpec{1e-5, 1e-5}));
116     }
117   }
118 
119   template <typename LiteralType>
TestAllOpsForReduce()120   void TestAllOpsForReduce() {
121     auto cast = [&](int value) { return static_cast<LiteralType>(value); };
122     auto to_literal = [&](absl::Span<const LiteralType> values) {
123       return LiteralUtil::CreateR1<LiteralType>(values);
124     };
125     Literal input_value = to_literal({cast(1), cast(2), cast(3)});
126     TestTwoReplicasOneOperand<LiteralType>(
127         "add",
128         /*input_value=*/input_value.Clone(),
129         /*expected_value=*/to_literal({cast(2), cast(4), cast(6)}));
130     TestTwoReplicasOneOperand<LiteralType>(
131         "multiply",
132         /*input_value=*/input_value.Clone(),
133         /*expected_value=*/to_literal({cast(1), cast(4), cast(9)}));
134     TestTwoReplicasOneOperand<LiteralType>(
135         "maximum",
136         /*input_value=*/input_value.Clone(),
137         /*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
138     TestTwoReplicasOneOperand<LiteralType>(
139         "minimum",
140         /*input_value=*/input_value.Clone(),
141         /*expected_value=*/to_literal({cast(1), cast(2), cast(3)}));
142   }
143 };
144 
145 // Returns the non-empty subsets of {0, 1, ..., n}.  For example,
146 // PowerSetOfIota(3) = {{0}, {1}, {2}, {0,1}, {0,2}, {1,2}, {0,1,2}}.
PowerSetOfIota(int64_t n)147 std::vector<std::vector<int64>> PowerSetOfIota(int64_t n) {
148   std::vector<std::vector<int64>> power_set;
149   for (int64_t i = 1; i < (1 << n); ++i) {
150     power_set.emplace_back();
151     for (int64_t j = 0; j < n; ++j) {
152       if (i & (1 << j)) {
153         power_set.back().push_back(j);
154       }
155     }
156   }
157   return power_set;
158 }
159 
160 // Makes a DeviceAssignment assigning replica-id i to devices[i].
MakeDeviceAssn(std::vector<int64> devices)161 DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {
162   DeviceAssignment assn(/*replica_count=*/devices.size(),
163                         /*computation_count=*/1);
164   for (int64_t i = 0; i < devices.size(); ++i) {
165     assn(i, 0) = devices[i];
166   }
167   return assn;
168 }
169 
170 // Shorter alias for this function.
OpenNcclChannels()171 absl::flat_hash_set<GlobalDeviceId> OpenNcclChannels() {
172   return gpu::DevicesWithOpenNcclChannels();
173 }
174 
175 template <typename T>
ToHalf(T value)176 static Eigen::half ToHalf(T value) {
177   return static_cast<Eigen::half>(value);
178 }
179 
// Sum all-reduce of a 2x2 f32 matrix over two replicas: every element doubles.
XLA_TEST_F(CollectiveOpsTest, AllReduce_sum_float32_2D) {
  Literal input = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
  Literal expected = LiteralUtil::CreateR2<float>({{2, 4}, {6, 8}});
  TestTwoReplicasOneOperand<float>("add", std::move(input),
                                   std::move(expected));
}
186 
// Sum all-reduce of a one-element f32 vector over two replicas.
XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) {
  Literal input = LiteralUtil::CreateR1<float>({1});
  Literal expected = LiteralUtil::CreateR1<float>({2});
  TestTwoReplicasOneOperand<float>("add", std::move(input),
                                   std::move(expected));
}
193 
// Element-type sweep: each test runs TestAllOpsForReduce (add, multiply,
// maximum and minimum over two replicas) for one native element type.
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
  TestAllOpsForReduce<int8>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) {
  TestAllOpsForReduce<uint8>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) {
  TestAllOpsForReduce<uint32>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) {
  TestAllOpsForReduce<int32>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) {
  TestAllOpsForReduce<int64>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) {
  TestAllOpsForReduce<uint64>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) {
  TestAllOpsForReduce<float>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) {
  TestAllOpsForReduce<double>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
  TestAllOpsForReduce<Eigen::half>();
}

XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllReduceTwoReplicasOneOperand_bfloat16)) {
  TestAllOpsForReduce<bfloat16>();
}
234 
// Sum all-reduce of the complex64 vector {1+2i, 3+4i} over two replicas.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_sum_complex64)) {
  Literal input = LiteralUtil::CreateR1<complex64>({{1, 2}, {3, 4}});
  Literal expected = LiteralUtil::CreateR1<complex64>({{2, 4}, {6, 8}});
  TestTwoReplicasOneOperand<complex64>("add", std::move(input),
                                       std::move(expected));
}
241 
// Sum all-reduce of the complex128 vector {1+2i, 3+4i} over two replicas.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_sum_complex128)) {
  Literal input = LiteralUtil::CreateR1<complex128>({{1, 2}, {3, 4}});
  Literal expected = LiteralUtil::CreateR1<complex128>({{2, 4}, {6, 8}});
  TestTwoReplicasOneOperand<complex128>("add", std::move(input),
                                        std::move(expected));
}
248 
XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
  // Both replicas receive the same input vector, so and-reducing it across
  // replicas must leave it unchanged.
  TestTwoReplicasOneOperand<bool>(
      "and",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Replicas disagree: replica 0 contributes `true` (its id equals 0) and
  // replica 1 contributes `false`, so every replica must end up with `false`.
  const char* hlo_module = R"(
    HloModule test

    apply_op {
      x = pred[] parameter(0)
      y = pred[] parameter(1)
      ROOT apply_op = pred[] and(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      c = u32[] constant(0)
      p = pred[] compare(id, c), direction=EQ
      p2 = pred[1] bitcast(p)
      crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
      copy = pred[1] copy(crs)
      ROOT out = pred[1] bitcast(copy)
    }
  )";

  static constexpr int kNumReplicas = 2;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  // Report a parse/verify failure as a test failure instead of crashing.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_module, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {},
                                            /*num_replicas=*/kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<bool>({false}),
                                       results[replica_idx]));
  }
}
289 
XLA_TEST_F(CollectiveOpsTest, AllReduceOr_Pred) {
  // Both replicas receive the same input vector, so or-reducing it across
  // replicas must leave it unchanged.
  TestTwoReplicasOneOperand<bool>(
      "or",
      /*input_value=*/LiteralUtil::CreateR1<bool>({true, false}),
      /*expected_value=*/LiteralUtil::CreateR1<bool>({true, false}));

  // Replicas disagree: replica 0 contributes `true` (its id equals 0) and
  // replica 1 contributes `false`, so every replica must end up with `true`.
  const char* hlo_module = R"(
    HloModule test

    apply_op {
      x = pred[] parameter(0)
      y = pred[] parameter(1)
      ROOT apply_op = pred[] or(x, y)
    }

    ENTRY test_computation {
      id = u32[] replica-id()
      c = u32[] constant(0)
      p = pred[] compare(id, c), direction=EQ
      p2 = pred[1] bitcast(p)
      crs = pred[1] all-reduce(p2), replica_groups={}, to_apply=apply_op
      copy = pred[1] copy(crs)
      ROOT out = pred[1] bitcast(copy)
    }
  )";

  static constexpr int kNumReplicas = 2;
  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  // Report a parse/verify failure as a test failure instead of crashing.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_module, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {},
                                            /*num_replicas=*/kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<bool>({true}),
                                       results[replica_idx]));
  }
}
330 
331 // Tries all-to-all operations across all 2^kNumDevices - 1 combinations of
332 // devices in sequence.
XLA_TEST_F(CollectiveOpsTest,AllReduce_AllCombinations)333 XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) {
334   const int64_t kNumDevices = 4;
335   const int64_t kNumElems = 1024;
336 
337   for (std::vector<int64> devices : PowerSetOfIota(kNumDevices)) {
338     SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
339                                  absl::StrJoin(devices, ", ")));
340 
341     DeviceAssignment device_assn = MakeDeviceAssn(devices);
342 
343     auto config = GetModuleConfigForTest();
344     config.set_replica_count(devices.size());
345     config.set_static_device_assignment(device_assn);
346 
347     std::vector<float> input_vec(kNumElems);
348     absl::c_iota(input_vec, 0);
349     auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
350 
351     auto module = MakeCrsModule(input_literal.shape(),
352                                 /*replica_groups=*/{}, config);
353 
354     TF_ASSERT_OK_AND_ASSIGN(
355         std::vector<Literal> results,
356         ExecuteReplicated(std::move(module), {&input_literal},
357                           /*num_replicas=*/devices.size(), &device_assn,
358                           /*run_hlo_passes=*/true, /*use_threads=*/true));
359   }
360 }
361 
362 // Check that the NCCL data structures in our all-reduce implementation are
363 // cached as we expect.
XLA_TEST_F(CollectiveOpsTest,DISABLED_ON_CPU (AllReduce_NcclChannelCaching))364 XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduce_NcclChannelCaching)) {
365   const int64_t kNumElems = 1024;
366 
367   std::vector<float> input_vec(kNumElems);
368   absl::c_iota(input_vec, 0);
369   auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
370 
371   // Create three Executables, touching devices {0,1}, {1,2}, and {0,1,2}.
372   struct ExecutableInfo {
373     std::unique_ptr<Executable> executable;
374     DeviceAssignment device_assn;
375     HloRunner::ReplicatedExecuteOptions opts;
376   };
377   std::vector<ExecutableInfo> executables;
378   for (const auto& devices :
379        std::vector<std::vector<int64>>{{0, 1}, {1, 2}, {0, 1, 2}}) {
380     executables.emplace_back();
381     auto& e = executables.back();
382 
383     e.device_assn = MakeDeviceAssn(devices);
384 
385     auto config = GetModuleConfigForTest();
386     config.set_replica_count(devices.size());
387     config.set_static_device_assignment(e.device_assn);
388     auto module = MakeCrsModule(input_literal.shape(),
389                                 /*replica_groups=*/{}, config);
390     e.executable =
391         test_runner_
392             .CreateExecutable(std::move(module), /*run_hlo_passes=*/true)
393             .ValueOrDie();
394 
395     e.opts.num_replicas = devices.size();
396     e.opts.use_threads = true;
397     e.opts.arguments.push_back(&input_literal);
398   }
399 
400   auto run_executable = [&](int64_t i) {
401     auto& e = executables[i];
402     TF_ASSERT_OK(
403         test_runner_
404             .ExecuteReplicated(e.executable.get(), e.opts, &e.device_assn)
405             .status());
406   };
407 
408   // Run the executables and check that channels are opened as we expect.
409   run_executable(0);
410   EXPECT_THAT(OpenNcclChannels(), IsSupersetOf({0, 1}));
411 
412   run_executable(2);
413   EXPECT_THAT(OpenNcclChannels(), IsSupersetOf({0, 1, 2}));
414 
415   run_executable(1);
416   EXPECT_THAT(OpenNcclChannels(), IsSupersetOf({0, 1, 2}));
417 
418   // Tear down the executables and check that channels are closed as we expect.
419   // Note that after we tear down an executable *all* the nccl channels may go
420   // away, so we rerun all of the executables that haven't been torn down.
421   executables[2].executable.reset();
422   run_executable(0);
423   run_executable(1);
424   EXPECT_THAT(OpenNcclChannels(), IsSupersetOf({0, 1, 2}));
425 
426   executables[0].executable.reset();
427   run_executable(1);
428   EXPECT_THAT(OpenNcclChannels(), IsSupersetOf({1, 2}));
429 
430   executables[1].executable.reset();
431 }
432 
433 // Runs the same executable many times concurrently.  The all-reduces should not
434 // conflict with one another.
XLA_TEST_F(CollectiveOpsTest,AllReduce_ManyConcurrentAllReduces)435 XLA_TEST_F(CollectiveOpsTest, AllReduce_ManyConcurrentAllReduces) {
436   const int64_t kNumElems = 1024;
437   const int64_t kNumThreads = 200;
438   const int64_t kRunsPerThread = 10;
439 
440   std::vector<float> input_vec(kNumElems);
441   absl::c_iota(input_vec, 0);
442   auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
443 
444   auto config = GetModuleConfigForTest();
445   config.set_replica_count(2);
446   auto executable =
447       test_runner_
448           .CreateExecutable(MakeCrsModule(input_literal.shape(),
449                                           /*replica_groups=*/{}, config),
450                             /*run_hlo_passes=*/true)
451           .ValueOrDie();
452   std::vector<int64> devices = {0, 1};
453   auto device_assn = MakeDeviceAssn(devices);
454 
455   HloRunner::ReplicatedExecuteOptions opts;
456   opts.num_replicas = devices.size();
457   opts.use_threads = true;
458   opts.arguments.push_back(&input_literal);
459 
460   tensorflow::BlockingCounter done(kNumThreads * kRunsPerThread);
461   tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), TestName(),
462                                       kNumThreads);
463   for (int64_t i = 0; i < kNumThreads * kRunsPerThread; ++i) {
464     pool.Schedule([&] {
465       TF_ASSERT_OK(
466           test_runner_.ExecuteReplicated(executable.get(), opts, &device_assn)
467               .status());
468       done.DecrementCount();
469     });
470   }
471   done.Wait();
472 }
473 
474 // Runs the same executable many times concurrently.  The all-reduces should not
475 // conflict with one another.
XLA_TEST_F(CollectiveOpsTest,AllReduce_CombinableAllReduces)476 XLA_TEST_F(CollectiveOpsTest, AllReduce_CombinableAllReduces) {
477   std::string hlo_string = R"(
478     HloModule test
479 
480     apply_op {
481       x = f32[] parameter(0)
482       y = f32[] parameter(1)
483       ROOT apply_op = f32[] add(x, y)
484     }
485 
486     ENTRY test_computation {
487       p0 = f32[5] parameter(0)
488       p1 = f32[5] parameter(1)
489       crs0 = f32[5] all-reduce(p0), replica_groups={}, to_apply=apply_op
490       crs1 = f32[5] all-reduce(p1), replica_groups={}, to_apply=apply_op
491       ROOT out = (f32[5], f32[5]) tuple(f32[5] crs0, f32[5] crs1)
492     }
493   )";
494   static constexpr int kNumReplicas = 2;
495   auto config = GetModuleConfigForTest();
496   config.set_replica_count(kNumReplicas);
497   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
498                           ParseAndReturnVerifiedModule(hlo_string, config));
499 
500   std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
501   auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
502   std::vector<float> input1_vec = {7., 3., 4., 1., 2.};
503   auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);
504 
505   TF_ASSERT_OK_AND_ASSIGN(
506       std::vector<Literal> results,
507       ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
508                         /*num_replicas=*/kNumReplicas,
509                         /*use_threads=*/true));
510   std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
511   auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
512   std::vector<float> expected1_vec = {14., 6., 8., 2., 4.};
513   auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
514   for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
515     auto rs = results[replica_idx].DecomposeTuple();
516     EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
517                                              ErrorSpec{1e-5, 1e-5}));
518     EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
519                                              ErrorSpec{1e-5, 1e-5}));
520   }
521 }
522 
523 // Runs an all-reduce with three partitions:
524 //  {0}, {1,2}, {3}
525 // meaning, the all-reduce is a nop for devices 0 and 3, and only devices 1 and
526 // 2 actually exchange data with each other.
XLA_TEST_F(CollectiveOpsTest,AllReduce_ThreeReplicaGroups)527 XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) {
528   // Test a prime number so it's not all powers of 2.
529   const int64_t kNumElems = 137;
530 
531   auto config = GetModuleConfigForTest();
532   config.set_replica_count(4);
533   std::vector<float> input_vec(kNumElems);
534   absl::c_iota(input_vec, 0);
535   auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
536   auto module = MakeCrsModule(
537       /*shape_str=*/input_literal.shape(),
538       /*replica_groups=*/{{0}, {1, 2}, {3}}, config);
539 
540   TF_ASSERT_OK_AND_ASSIGN(
541       std::vector<Literal> results,
542       ExecuteReplicated(std::move(module), {&input_literal}, /*num_replicas=*/4,
543                         /*use_threads=*/true));
544 
545   ASSERT_EQ(results.size(), 4);
546 
547   std::vector<float> input_vec_doubled;
548   for (float n : input_vec) {
549     input_vec_doubled.push_back(n * 2);
550   }
551   auto input_literal_doubled = LiteralUtil::CreateR1<float>(input_vec_doubled);
552 
553   EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[0]));
554   EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[1]));
555   EXPECT_TRUE(LiteralTestUtil::Equal(input_literal_doubled, results[2]));
556   EXPECT_TRUE(LiteralTestUtil::Equal(input_literal, results[3]));
557 }
558 
// All-reduce in which every replica sits alone in its own group, so the
// reduction is the identity and each replica gets back its own replica id.
XLA_TEST_F(CollectiveOpsTest, AllReduce_Degenerate) {
  const char* const kModuleStr = R"(
      HloModule test

      apply_op {
        x = u32[] parameter(0)
        y = u32[] parameter(1)
        ROOT apply_op = u32[] add(x, y)
      }

      ENTRY test_computation {
        id = u32[] replica-id()
        ROOT crs = u32[] all-reduce(id), replica_groups={{0},{1},{2},{3}}, to_apply=apply_op
      }
    )";
  static constexpr int kNumReplicas = 4;
  HloModuleConfig config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  for (int replica = 0; replica < kNumReplicas; ++replica) {
    const Literal& result = results[replica];
    LiteralTestUtil::ExpectR0Equal<uint32_t>(replica, result);
  }
}
589 
// Sum all-reduce of the replica ids, split into all-reduce-start and
// all-reduce-done; every replica should receive 0 + 1 + 2 + 3.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduce)) {
  const absl::string_view kModuleStr = R"(
      HloModule test

      apply_op {
        x = u32[] parameter(0)
        y = u32[] parameter(1)
        ROOT apply_op = u32[] add(x, y)
      }

      ENTRY test_computation {
        id = u32[] replica-id()
        start = (u32[], u32[]) all-reduce-start(id), to_apply=apply_op
        ROOT done = u32[] all-reduce-done(start)
      }
    )";
  static constexpr int kNumReplicas = 4;
  HloModuleConfig config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);

  // Sum of the replica ids in [0, kNumReplicas), i.e. 6.
  uint32_t expected = 0;
  for (int id = 0; id < kNumReplicas; ++id) {
    expected += id;
  }
  for (const Literal& result : results) {
    LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, result);
  }
}
621 
// Async all-reduce with two operands, the replica id and its square; every
// replica should receive the sum of ids and the sum of squared ids.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduceTwoOperands)) {
  const absl::string_view kModuleStr = R"(
      HloModule test

      apply_op {
        x = u32[] parameter(0)
        y = u32[] parameter(1)
        ROOT apply_op = u32[] add(x, y)
      }

      ENTRY test_computation {
        id = u32[] replica-id()
        id2 = u32[] multiply(id, id)
        start = ((u32[], u32[]), (u32[], u32[])) all-reduce-start(id, id2), to_apply=apply_op
        ROOT done = (u32[], u32[]) all-reduce-done(start)
      }
    )";
  static constexpr int kNumReplicas = 4;
  HloModuleConfig config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));
  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);

  // Over ids in [0, kNumReplicas): sum is 6, sum of squares is 14.
  uint32_t expected_sum = 0;
  uint32_t expected_sum_of_squares = 0;
  for (int id = 0; id < kNumReplicas; ++id) {
    expected_sum += id;
    expected_sum_of_squares += id * id;
  }
  for (Literal& result : results) {
    std::vector<Literal> parts = result.DecomposeTuple();
    LiteralTestUtil::ExpectR0Equal<uint32_t>(expected_sum, parts[0]);
    LiteralTestUtil::ExpectR0Equal<uint32_t>(expected_sum_of_squares, parts[1]);
  }
}
657 
// Each replica copies its own replica-id to the output.
XLA_TEST_F(CollectiveOpsTest, ReplicaId) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    ROOT out = u32[] copy(id)
  }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  // Build the module with `config` so its replica count matches the
  // kNumReplicas replicas it is executed on.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));

  ASSERT_EQ(results.size(), kNumReplicas);
  // `i` is uint32 so CreateR0 produces a u32 scalar matching the HLO output.
  for (uint32 i = 0; i < kNumReplicas; ++i) {
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR0(i), results[i]));
  }
}
682 
// collective-permute on four replicas with pairs 1->0, 0->1 and 2->2. Replica i
// contributes a u32[2] filled with 10 + i.
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY test_computation {
    replica = u32[] replica-id()
    ten = u32[] constant(10)
    sum = u32[] add(replica, ten)
    p = u32[2] broadcast(sum), dimensions={}
    permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}, {2,2}}
    ROOT copy = u32[2] copy(permute)
  }
  )";
  const int64_t kNumReplicas = 4;

  HloModuleConfig config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  // expected_fill[i] is the value replica i should end up with. Nothing writes
  // to replica 3, so it is memzero'ed.
  const uint32 expected_fill[] = {11, 10, 12, 0};
  for (int64_t i = 0; i < kNumReplicas; ++i) {
    const uint32 v = expected_fill[i];
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({v, v}),
                                       results[i]));
  }
}
716 
// collective-permute where every replica sends to itself, so replica i keeps
// its own u32[2] filled with 10 + i. ("Degnerate" (sic) is the existing test
// name and is kept so test filters continue to match.)
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Degnerate) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY test_computation {
    replica = u32[] replica-id()
    ten = u32[] constant(10)
    sum = u32[] add(replica, ten)
    p = u32[2] broadcast(sum), dimensions={}
    permute = u32[2] collective-permute(p), source_target_pairs={{0,0}, {1,1}, {2,2}, {3,3}}
    ROOT copy = u32[2] copy(permute)
  }
  )";
  const int64_t kNumReplicas = 4;

  HloModuleConfig config = GetModuleConfigForTest();
  config.set_replica_count(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  for (uint32 i = 0; i < kNumReplicas; ++i) {
    const uint32 v = 10 + i;
    EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1<uint32>({v, v}),
                                       results[i]));
  }
}
749 
// Replicas 0-2 send to themselves. Replica 3 appears in no source_target pair,
// so it receives nothing and its output is expected to be all zeros.
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_NoDegnerate) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    replica = u32[] replica-id()
    ten = u32[] constant(10)
    sum = u32[] add(replica, ten)
    p = u32[2] broadcast(sum), dimensions={}
    permute = u32[2] collective-permute(p), source_target_pairs={{0,0}, {1,1}, {2,2}}
    ROOT copy = u32[2] copy(permute)
  }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloText, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    // Nothing writes to replica 3, so it is memzero'ed.
    uint32 expected = 0;
    if (replica != 3) {
      expected = static_cast<uint32>(10 + replica);
    }
    LiteralTestUtil::ExpectR1Equal<uint32>({expected, expected},
                                           results[replica]);
  }
}
783 
// Rotates values one step around the ring (0->1->2->3->0). Each replica should
// therefore hold its predecessor's value, (predecessor id) + 10.
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Rotate) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    replica = u32[] replica-id()
    ten = u32[] constant(10)
    sum = u32[] add(replica, ten)
    p = u32[2] broadcast(sum), dimensions={}
    permute = u32[2] collective-permute(p), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
    ROOT copy = u32[2] copy(permute)
  }
  )";
  const int64_t kNumReplicas = 4;

  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloText, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    const int64_t sender = (replica + kNumReplicas - 1) % kNumReplicas;
    const uint32 received = static_cast<uint32>(10 + sender);
    LiteralTestUtil::ExpectR1Equal<uint32>({received, received},
                                           results[replica]);
  }
}
816 
// Tuple all-to-all over all four replicas (empty replica_groups). Each replica
// offers four 2-element chunks offset by its id. Per the expected table, replica
// r ends up with chunk r from every replica, concatenated in replica order.
XLA_TEST_F(CollectiveOpsTest, AllToAll_EmptyReplicaGroups) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[2] broadcast(id), dimensions={}
    a0 = u32[2] constant({10, 15})
    b0 = u32[2] constant({20, 25})
    c0 = u32[2] constant({30, 35})
    d0 = u32[2] constant({40, 45})
    a1 = u32[2] add(id2, a0)
    b1 = u32[2] add(id2, b0)
    c1 = u32[2] add(id2, c0)
    d1 = u32[2] add(id2, d0)
    all2all = (u32[2], u32[2], u32[2], u32[2]) all-to-all(a1, b1, c1, d1), replica_groups={}
    a_prime = u32[2] get-tuple-element(all2all), index=0
    b_prime = u32[2] get-tuple-element(all2all), index=1
    c_prime = u32[2] get-tuple-element(all2all), index=2
    d_prime = u32[2] get-tuple-element(all2all), index=3
    ROOT out = u32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
  }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloText, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  const std::vector<std::vector<uint32>> expected_per_replica = {
      {10, 15, 11, 16, 12, 17, 13, 18},
      {20, 25, 21, 26, 22, 27, 23, 28},
      {30, 35, 31, 36, 32, 37, 33, 38},
      {40, 45, 41, 46, 42, 47, 43, 48},
  };
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32>(expected_per_replica[replica],
                                           results[replica]);
  }
}
857 
// Same exchange as AllToAll_EmptyReplicaGroups, but with an explicit, reversed
// group {3,2,1,0}. The expected table encodes that a replica's position inside
// the group (not its id) selects which chunk it receives. The pieces are
// concatenated in group order: replica 0 sits at position 3 and so collects the
// d-chunks from replicas 3, 2, 1, 0.
XLA_TEST_F(CollectiveOpsTest, AllToAll_OrderedReplicaGroups) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[2] broadcast(id), dimensions={}
    a0 = u32[2] constant({10, 15})
    b0 = u32[2] constant({20, 25})
    c0 = u32[2] constant({30, 35})
    d0 = u32[2] constant({40, 45})
    a1 = u32[2] add(id2, a0)
    b1 = u32[2] add(id2, b0)
    c1 = u32[2] add(id2, c0)
    d1 = u32[2] add(id2, d0)
    all2all = (u32[2], u32[2], u32[2], u32[2]) all-to-all(a1, b1, c1, d1), replica_groups={{3,2,1,0}}
    a_prime = u32[2] get-tuple-element(all2all), index=0
    b_prime = u32[2] get-tuple-element(all2all), index=1
    c_prime = u32[2] get-tuple-element(all2all), index=2
    d_prime = u32[2] get-tuple-element(all2all), index=3
    ROOT out = u32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
  }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloText, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  const std::vector<std::vector<uint32>> expected_per_replica = {
      {43, 48, 42, 47, 41, 46, 40, 45},
      {33, 38, 32, 37, 31, 36, 30, 35},
      {23, 28, 22, 27, 21, 26, 20, 25},
      {13, 18, 12, 17, 11, 16, 10, 15},
  };
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32>(expected_per_replica[replica],
                                           results[replica]);
  }
}
898 
// Two independent all-to-all groups, {2,1} and {3,0}. Data only moves inside a
// group, and a replica's position within its group picks the chunk (a or b) it
// receives.
XLA_TEST_F(CollectiveOpsTest, AllToAll_TwoReplicaGroups) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[2] broadcast(id), dimensions={}
    a0 = u32[2] constant({10, 15})
    b0 = u32[2] constant({20, 25})
    a1 = u32[2] add(id2, a0)
    b1 = u32[2] add(id2, b0)
    all2all = (u32[2], u32[2]) all-to-all(a1, b1), replica_groups={{2,1},{3,0}}
    a_prime = u32[2] get-tuple-element(all2all), index=0
    b_prime = u32[2] get-tuple-element(all2all), index=1
    ROOT out = u32[4] concatenate(a_prime, b_prime), dimensions={0}
  }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloText, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  const std::vector<std::vector<uint32>> expected_per_replica = {
      {23, 28, 20, 25},
      {22, 27, 21, 26},
      {12, 17, 11, 16},
      {13, 18, 10, 15},
  };
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32>(expected_per_replica[replica],
                                           results[replica]);
  }
}
929 
// Array (non-tuple) all-to-all that splits and concatenates along dimension 0.
// Row r of every replica's [4,2] operand lands on replica r, stacked in replica
// order.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_SplitDimension)) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[4, 2] broadcast(id), dimensions={}
    a0 = u32[4, 2] constant({{10, 15}, {20, 25}, {30, 35}, {40, 45}})
    a1 = u32[4, 2] add(id2, a0)
    all2all = u32[4, 2] all-to-all(a1), replica_groups={{0,1,2,3}}, dimensions={0}
    ROOT out = u32[8] reshape(all2all)
  }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloText, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  const std::vector<std::vector<uint32>> expected_per_replica = {
      {10, 15, 11, 16, 12, 17, 13, 18},
      {20, 25, 21, 26, 22, 27, 23, 28},
      {30, 35, 31, 36, 32, 37, 33, 38},
      {40, 45, 41, 46, 42, 47, 43, 48},
  };
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32>(expected_per_replica[replica],
                                           results[replica]);
  }
}
960 
// All-gather along dimension 0: every replica contributes one [1,2] row
// ({10,15} + id). Every replica must end up with the same four rows stacked in
// replica order.
XLA_TEST_F(CollectiveOpsTest, AllGather_Dim0) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[1, 2] broadcast(id), dimensions={}
    a0 = u32[1, 2] constant({{10, 15}})
    a1 = u32[1, 2] add(id2, a0)
    allgather = u32[4, 2] all-gather(a1), dimensions={0}
    ROOT out = u32[8] reshape(allgather)
  }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloText, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  const std::vector<uint32> gathered = {10, 15, 11, 16, 12, 17, 13, 18};
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32>(gathered, results[replica]);
  }
}
988 
// All-gather along dimension 1: every replica contributes a [2,1] column
// ({10},{15} + id). After reshaping the [2,4] result row-major, each replica
// should see the four "10+id" values followed by the four "15+id" values.
XLA_TEST_F(CollectiveOpsTest, AllGather_Dim1) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    id2 = u32[2, 1] broadcast(id), dimensions={}
    a0 = u32[2, 1] constant({{10}, {15}})
    a1 = u32[2, 1] add(id2, a0)
    allgather = u32[2, 4] all-gather(a1), dimensions={1}
    ROOT out = u32[8] reshape(allgather)
  }
  )";
  const int64_t kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloText, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);

  const std::vector<uint32> gathered = {10, 11, 12, 13, 15, 16, 17, 18};
  for (int64_t replica = 0; replica < kNumReplicas; ++replica) {
    LiteralTestUtil::ExpectR1Equal<uint32>(gathered, results[replica]);
  }
}
1016 
// Sum all-reduce over a tuple of two differently sized operands (f32[5] and
// f32[7]). The same two input literals are passed to ExecuteReplicated, and
// with two replicas each expected element is twice the input element.
XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
  const char* const kModuleStr = R"(
    HloModule test

    apply_op {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT apply_op = f32[] add(x, y)
    }

    ENTRY test_computation {
      p0 = f32[5] parameter(0)
      p1 = f32[7] parameter(1)
      ROOT out = (f32[5], f32[7]) all-reduce(p0, p1), replica_groups={}, to_apply=apply_op
    }
  )";
  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  const std::vector<float> input0_vec = {1., 2., 3., 4., 5.};
  auto input0_literal = LiteralUtil::CreateR1<float>(input0_vec);
  const std::vector<float> input1_vec = {7., 3., 4., 1., 2., 3., 4.};
  auto input1_literal = LiteralUtil::CreateR1<float>(input1_vec);

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {&input0_literal, &input1_literal},
                        /*num_replicas=*/kNumReplicas,
                        /*use_threads=*/true));
  // Fail cleanly instead of indexing past the end if a replica's result is
  // missing.
  ASSERT_EQ(results.size(), kNumReplicas);

  const std::vector<float> expected0_vec = {2., 4., 6., 8., 10.};
  const auto expected0_literal = LiteralUtil::CreateR1<float>(expected0_vec);
  const std::vector<float> expected1_vec = {14., 6., 8., 2., 4., 6., 8.};
  const auto expected1_literal = LiteralUtil::CreateR1<float>(expected1_vec);
  for (int64_t replica_idx = 0; replica_idx < kNumReplicas; ++replica_idx) {
    auto rs = results[replica_idx].DecomposeTuple();
    // The ROOT is a 2-tuple; guard rs[0]/rs[1] below.
    ASSERT_EQ(rs.size(), 2);
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected0_literal, rs[0],
                                             ErrorSpec{1e-5, 1e-5}));
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected1_literal, rs[1],
                                             ErrorSpec{1e-5, 1e-5}));
  }
}
1063 
// Tuple all-gather whose operands have different element types (u32 and f32).
// Each replica contributes a [2,1] column filled with its id. Gathering along
// dimension 1 across two replicas gives {{0,1},{0,1}} for both outputs, which
// reshapes to {0,1,0,1}.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGatherMixedTypes)) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY test_computation {
    id = u32[] replica-id()
    p0 = u32[2, 1] broadcast(id), dimensions={}
    p1 = f32[2, 1] convert(p0)
    allgather = (u32[2, 2], f32[2, 2]) all-gather(p0, p1), dimensions={1}
    ag0 = u32[2, 2] get-tuple-element(allgather), index=0
    ag1 = f32[2, 2] get-tuple-element(allgather), index=1
    r0 = u32[4] reshape(ag0)
    r1 = f32[4] reshape(ag1)
    ROOT out = (u32[4], f32[4]) tuple(r0, r1)
  }
  )";
  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of indexing past the end if a replica's result is
  // missing.
  ASSERT_EQ(results.size(), kNumReplicas);
  for (int replica_idx = 0; replica_idx < kNumReplicas; replica_idx++) {
    auto rs = results[replica_idx].DecomposeTuple();
    // The ROOT is a 2-tuple; guard rs[0]/rs[1] below.
    ASSERT_EQ(rs.size(), 2);
    LiteralTestUtil::ExpectR1Equal<uint32>({0, 1, 0, 1}, rs[0]);
    LiteralTestUtil::ExpectR1Near<float>({0.0, 1.0, 0.0, 1.0}, rs[1],
                                         ErrorSpec{1e-5, 1e-5});
  }
}
1095 
// Sum reduce-scatter along dimension 0 with two replicas. Replica 0 feeds c0 and
// replica 1 feeds c1, so the summed vector is {11,13,...,25}. Replica 0 keeps
// the first half of that vector and replica 1 keeps the second half.
XLA_TEST_F(CollectiveOpsTest, ReduceScatter) {
  const char* const kModuleStr = R"(
  HloModule test
  add {
    lhs = u32[] parameter(0)
    rhs = u32[] parameter(1)
    ROOT add = u32[] add(lhs, rhs)
  }

  ENTRY main {
    c0 = u32[8] constant({1, 2, 3, 4, 5, 6, 7, 8})
    c1 = u32[8] constant({10, 11, 12, 13, 14, 15, 16, 17})
    zero = u32[] constant(0)
    id = u32[] replica-id()
    p = pred[] compare(id, zero), direction=EQ
    pb = pred[8] broadcast(p), dimensions={}
    // data = c0 for replica 0 and c1 for replica 1
    data = u32[8] select(pb, c0, c1)
    ROOT ars = u32[4] reduce-scatter(data), replica_groups={},
                      dimensions={0}, to_apply=add
  }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of reading results[1] out of bounds.
  ASSERT_EQ(results.size(), kNumReplicas);
  LiteralTestUtil::ExpectR1Equal<uint32>({11, 13, 15, 17}, results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32>({19, 21, 23, 25}, results[1]);
}
1131 
// Sum reduce-scatter along dimension 1 of a [2,4] operand with two replicas.
// Each replica keeps a [2,2] column slab of the summed matrix: replica 0 gets
// columns 0-1 and replica 1 gets columns 2-3. The slab is then reshaped
// row-major to u32[4].
XLA_TEST_F(CollectiveOpsTest, ReduceScatter_Dim1) {
  const char* const kModuleStr = R"(
  HloModule test
  add {
    lhs = u32[] parameter(0)
    rhs = u32[] parameter(1)
    ROOT add = u32[] add(lhs, rhs)
  }

  ENTRY main {
    c0 = u32[2, 4] constant({{ 1,  2,  3,  4}, { 5,  6,  7,  8}})
    c1 = u32[2, 4] constant({{10, 11, 12, 13}, {14, 15, 16, 17}})
    zero = u32[] constant(0)
    id = u32[] replica-id()
    p = pred[] compare(id, zero), direction=EQ
    pb = pred[2, 4] broadcast(p), dimensions={}
    // data = c0 for replica 0 and c1 for replica 1
    data = u32[2, 4] select(pb, c0, c1)
    // all-reduce result = {{11, 13, 15, 17}, {19, 21, 23, 25}}
    ars = u32[2, 2] reduce-scatter(data), replica_groups={},
                    dimensions={1}, to_apply=add
    ROOT r = u32[4] reshape(ars)
  }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of reading results[1] out of bounds.
  ASSERT_EQ(results.size(), kNumReplicas);
  LiteralTestUtil::ExpectR1Equal<uint32>({11, 13, 19, 21}, results[0]);
  LiteralTestUtil::ExpectR1Equal<uint32>({15, 17, 23, 25}, results[1]);
}
1169 
// The entry computation adds the results of two sum all-reduces. HLO passes run
// (run_hlo_passes=true); the test name suggests an all-reduce reassociation
// rewrite is expected to apply — confirm. Whatever the rewrite, both replicas
// must produce c0 + c1 + c2 + c3.
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllReduceReassociate)) {
  const char* const kModuleStr = R"(
  HloModule m
  sum {
    a = f32[] parameter(0)
    b = f32[] parameter(1)
    ROOT add.2 = f32[] add(a, b)
  }

  ENTRY main {
    c0 = f32[8] constant({  1,  2,  3,  4,  5,  6,  7,  8})
    c1 = f32[8] constant({ 11, 12, 13, 14, 15, 16, 17, 18})
    c2 = f32[8] constant({  2,  3,  4,  5,  6,  7,  8,  9})
    c3 = f32[8] constant({ 12, 13, 14, 15, 16, 17, 18, 19})
    zero = u32[] constant(0)
    id = u32[] replica-id()
    p = pred[] compare(id, zero), direction=EQ
    pb = pred[8] broadcast(p), dimensions={}
    // data0 = c0 for replica 0 and c1 for replica 1
    data0 = f32[8] select(pb, c0, c1)
    // data1 = c2 for replica 0 and c3 for replica 1
    data1 = f32[8] select(pb, c2, c3)

    ar0 = f32[8] all-reduce(data0), replica_groups={}, to_apply=sum
    ar1 = f32[8] all-reduce(data1), replica_groups={}, to_apply=sum
    ROOT add = f32[8] add(ar0, ar1)
  }
  )";
  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of reading results[1] out of bounds.
  ASSERT_EQ(results.size(), kNumReplicas);

  const ErrorSpec es{1e-5, 1e-5};
  EXPECT_TRUE(LiteralTestUtil::NearOrEqual(results[0], results[1], es));
  LiteralTestUtil::ExpectR1Near<float>(
      {26.0, 30.0, 34.0, 38.0, 42.0, 46.0, 50.0, 54.0}, results[0], es);
}
1213 
// All-gather of a broadcast value along dimension 2. That is a dimension the
// broadcast maps from its operand (dimensions={0, 2}), so each [4]-slice of the
// result concatenates replica 0's row with replica 1's row. Both replicas must
// produce identical results.
XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllGatherBroadcastReorder_NonUniform)) {
  const char* const kModuleStr = R"(
  HloModule m

  ENTRY main {
    c0 = u32[2, 3] constant({{ 1,  2,  3}, { 4, 5, 6}})
    c1 = u32[2, 3] constant({{10, 11, 12}, {13, 14, 15}})
    zero = u32[] constant(0)
    id = u32[] replica-id()
    p = pred[] compare(id, zero), direction=EQ
    pb = pred[2, 3] broadcast(p), dimensions={}
    // data = c0 for replica 0 and c1 for replica 1
    data = u32[2, 3] select(pb, c0, c1)
    bc = u32[2, 4, 3] broadcast(data), dimensions={0, 2}
    ROOT ag = u32[2, 4, 6] all-gather(bc), dimensions={2}, replica_groups={{0, 1}}
  }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of reading results[1] out of bounds.
  ASSERT_EQ(results.size(), kNumReplicas);

  EXPECT_TRUE(LiteralTestUtil::Equal(results[0], results[1]));
  LiteralTestUtil::ExpectR3Equal<uint32_t>({{{1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12},
                                             {1, 2, 3, 10, 11, 12}},
                                            {{4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15},
                                             {4, 5, 6, 13, 14, 15}}},
                                           results[0]);
}
1254 
// All-gather of a broadcast value along dimension 1. That is the dimension the
// broadcast adds (it is absent from dimensions={0, 2}), so the result stacks
// replica 0's four broadcast copies on top of replica 1's. Both replicas must
// produce identical results.
XLA_TEST_F(CollectiveOpsTest,
           DISABLED_ON_CPU(AllGatherBroadcastReorder_Uniform)) {
  const char* const kModuleStr = R"(
  HloModule m

  ENTRY main {
    c0 = u32[2, 3] constant({{ 1,  2,  3}, { 4, 5, 6}})
    c1 = u32[2, 3] constant({{10, 11, 12}, {13, 14, 15}})
    zero = u32[] constant(0)
    id = u32[] replica-id()
    p = pred[] compare(id, zero), direction=EQ
    pb = pred[2, 3] broadcast(p), dimensions={}
    // data = c0 for replica 0 and c1 for replica 1
    data = u32[2, 3] select(pb, c0, c1)
    bc = u32[2, 4, 3] broadcast(data), dimensions={0, 2}
    ROOT ag = u32[2, 8, 3] all-gather(bc), dimensions={1}, replica_groups={{0, 1}}
  }
  )";

  const int64_t kNumReplicas = 2;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<Literal> results,
      ExecuteReplicated(std::move(module), {}, kNumReplicas,
                        /*use_threads=*/true, /*run_hlo_passes=*/true));
  // Fail cleanly instead of reading results[1] out of bounds.
  ASSERT_EQ(results.size(), kNumReplicas);

  EXPECT_TRUE(LiteralTestUtil::Equal(results[0], results[1]));
  LiteralTestUtil::ExpectR3Equal<uint32_t>({{{1, 2, 3},
                                             {1, 2, 3},
                                             {1, 2, 3},
                                             {1, 2, 3},
                                             {10, 11, 12},
                                             {10, 11, 12},
                                             {10, 11, 12},
                                             {10, 11, 12}},
                                            {{4, 5, 6},
                                             {4, 5, 6},
                                             {4, 5, 6},
                                             {4, 5, 6},
                                             {13, 14, 15},
                                             {13, 14, 15},
                                             {13, 14, 15},
                                             {13, 14, 15}}},
                                           results[0]);
}
1302 
1303 }  // namespace
1304 }  // namespace xla
1305