/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

using absl::nullopt;
using ::testing::AllOf;
namespace op = xla::testing::opcode_matchers;
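// Maximum number of all-reduce ops the combiner may merge into one combined
// op in these tests (the pass also respects a byte-size threshold).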
int64 kMaxCombineCount = 256;

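// Returns the number of all-reduce instructions in the module, skipping
// fusion computations.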
int64 AllReduceCount(const HloModule& module) {
  int64 count = 0;
  for (HloComputation* computation : module.computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }
    for (HloInstruction* hlo : computation->instructions()) {
      if (hlo->opcode() == HloOpcode::kAllReduce) {
        ++count;
      }
    }
  }
  return count;
}

// inputs[i] will be some op producing a shape of size sizes_in_kib[i] which
// feeds into an all-reduce op in all_reduces[i]. Returns a tuple of the
// all-reduces.
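//
// For example, MakeCrossReplicaReductions({1}, {sum}, &inputs, &b) produces
//   constant -> broadcast f32[256] -> all-reduce(sum) -> tuple,
// since 1 KiB of f32 data holds 1024 / sizeof(float) = 256 elements.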
HloInstruction* MakeCrossReplicaReductions(
    std::vector<int64> sizes_in_kib, std::vector<HloComputation*> reductions,
    std::vector<HloInstruction*>* inputs, HloComputation::Builder* b) {
  CHECK_EQ(reductions.size(), sizes_in_kib.size());
  std::vector<HloInstruction*> all_reduces;
  for (int i = 0; i < sizes_in_kib.size(); i++) {
    int64 size_in_kib = sizes_in_kib[i];
    HloComputation* reduction = reductions[i];
    auto constant = b->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
    Shape shape = ShapeUtil::MakeShape(
        F32, {static_cast<int32>(size_in_kib * 1024 / sizeof(float))});
    auto input =
        b->AddInstruction(HloInstruction::CreateBroadcast(shape, constant, {}));
    inputs->push_back(input);
    all_reduces.push_back(b->AddInstruction(HloInstruction::CreateAllReduce(
        shape, {input}, reduction, /*replica_groups=*/{},
        /*constrain_layout=*/false, /*channel_id=*/nullopt,
        /*use_global_device_ids=*/false)));
  }
  return b->AddInstruction(HloInstruction::CreateTuple(all_reduces));
}

// Create and add a reduction computation in the given type to the module.
HloComputation* MakeReduction(const HloOpcode type, HloModule* module) {
  HloComputation::Builder sum_builder(HloOpcodeString(type));
  auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
  auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
  sum_builder.AddInstruction(
      HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {}), type, x, y));
  HloComputation* reduction =
      module->AddEmbeddedComputation(sum_builder.Build());
  return reduction;
}

// Creates replica groups for AllReduce. groups[i] represents replica ids
// for group 'i'.
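//
// For example, CreateReplicaGroups({{0, 1}, {2, 3}}) yields two groups:
// replicas 0 and 1 reduce together, and replicas 2 and 3 reduce together.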
std::vector<ReplicaGroup> CreateReplicaGroups(
    absl::Span<const std::vector<int64>> groups) {
  std::vector<ReplicaGroup> replica_groups(groups.size());
  for (int64 i = 0; i < groups.size(); ++i) {
    *replica_groups[i].mutable_replica_ids() = {groups[i].begin(),
                                                groups[i].end()};
  }
  return replica_groups;
}

using AllReduceCombinerTest = HloTestBase;

// Tests combination of several AllReduce instructions.
TEST_F(AllReduceCombinerTest, CombineAllReduces) {
  auto module = CreateNewVerifiedModule();
  HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  std::vector<HloInstruction*> inputs;
  auto root = MakeCrossReplicaReductions(
      {1, 2, 10, 7, 6}, {sum, sum, sum, sum, sum}, &inputs, &b);
  auto computation = module->AddEntryComputation(b.Build());

  // Run the AllReduce combiner optimization pass.
  AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), inputs.size());
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  ASSERT_EQ(AllReduceCount(*module), 1);
  EXPECT_TRUE(changed);

  ASSERT_EQ(root, computation->root_instruction());
  ASSERT_EQ(inputs.size(), root->operands().size());

  HloInstruction* combined = nullptr;
  for (int64 i = 0; i < root->operands().size(); ++i) {
    HloInstruction* hlo = root->mutable_operand(i);
    ASSERT_TRUE(hlo->opcode() == HloOpcode::kGetTupleElement);
    EXPECT_EQ(hlo->tuple_index(), i);
    EXPECT_TRUE(ShapeUtil::Equal(inputs[i]->shape(), hlo->shape()));

    if (combined == nullptr) {
      // Verify the combined all reduce instruction.
      combined = hlo->mutable_operand(0);
      ASSERT_TRUE(combined->opcode() == HloOpcode::kAllReduce);
      EXPECT_TRUE(ShapeUtil::Equal(root->shape(), combined->shape()));
      ASSERT_EQ(combined->operands().size(), inputs.size());
    }
    EXPECT_EQ(combined, hlo->operand(0));
    EXPECT_TRUE(ShapeUtil::Equal(inputs[i]->shape(), hlo->shape()));
    EXPECT_EQ(combined->operand(i), inputs[i]);
    EXPECT_EQ(1, inputs[i]->users().size());
  }
  ASSERT_NE(combined, nullptr);
}

// Tests combination of several cross replica reduction instructions with
// different reduction types.
TEST_F(AllReduceCombinerTest, CombineCrossReplicaReductionsInGroups) {
  auto module = CreateNewVerifiedModule();
  HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());
  HloComputation* min = MakeReduction(HloOpcode::kMinimum, module.get());
  HloComputation* max = MakeReduction(HloOpcode::kMaximum, module.get());
  HloComputation* sum_2 = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  std::vector<HloInstruction*> inputs;
  MakeCrossReplicaReductions(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
      {sum, sum_2, min, min, min, max, max, max, sum, sum_2}, &inputs, &b);
  module->AddEntryComputation(b.Build());

  // Run the AllReduce combiner optimization pass.
  AllReduceCombiner combine(10 * 1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), inputs.size());
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  ASSERT_EQ(AllReduceCount(*module), 3)
      << "expects 3 groups for 3 reduction types.";
  EXPECT_TRUE(changed);
}

// Tests that the combination threshold is respected.
TEST_F(AllReduceCombinerTest, RespectThreshold) {
  auto module = CreateNewVerifiedModule();
  HloComputation* sum = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  std::vector<HloInstruction*> inputs;
  MakeCrossReplicaReductions({8, 4}, {sum, sum}, &inputs, &b);
  module->AddEntryComputation(b.Build());

  // Run the AllReduce combiner optimization pass with a threshold one byte
  // short of the combined size of the two all-reduce ops (8 KiB + 4 KiB),
  // so that the combination cannot occur.
  {
    AllReduceCombiner combine((8 + 4) * 1024 - 1, kMaxCombineCount);
    ASSERT_EQ(AllReduceCount(*module), inputs.size());
    TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
    EXPECT_EQ(AllReduceCount(*module), inputs.size());
    EXPECT_FALSE(changed);
  }

  // Run the AllReduce combiner optimization pass again with a threshold
  // exactly equal to the combined size, so that the combination can occur.
  {
    AllReduceCombiner combine((8 + 4) * 1024, kMaxCombineCount);
    ASSERT_EQ(AllReduceCount(*module), inputs.size());
    TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
    EXPECT_EQ(AllReduceCount(*module), 1);
    EXPECT_TRUE(changed);
  }
}

// Tests that dependent all reduces are not combined.
TEST_F(AllReduceCombinerTest, NoDependentCombination) {
  auto module = CreateNewVerifiedModule();
  HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());

  HloComputation::Builder b(TestName());
  auto constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
  auto all_reduce = b.AddInstruction(HloInstruction::CreateAllReduce(
      constant->shape(), {constant}, reduction, /*replica_groups=*/{},
      /*constrain_layout=*/false, /*channel_id=*/nullopt,
      /*use_global_device_ids=*/false));
  b.AddInstruction(HloInstruction::CreateAllReduce(
      constant->shape(), {all_reduce}, reduction,
      /*replica_groups=*/{}, /*constrain_layout=*/false,
      /*channel_id=*/nullopt, /*use_global_device_ids=*/false));

  module->AddEntryComputation(b.Build());

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

// Tests that AllReduce ops with different groups are not combined.
TEST_F(AllReduceCombinerTest, GroupAllReduce) {
  auto module = CreateNewVerifiedModule(TestName(), /*replica_count=*/4);
  HloComputation::Builder b(TestName());
  HloComputation* reduction = MakeReduction(HloOpcode::kAdd, module.get());

  auto constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.3)));
  auto crs0 = b.AddInstruction(HloInstruction::CreateAllReduce(
      constant->shape(), {constant}, reduction,
      CreateReplicaGroups({{0, 1}, {2, 3}}),
      /*constrain_layout=*/false,
      /*channel_id=*/nullopt, /*use_global_device_ids=*/false));
  auto crs1 = b.AddInstruction(HloInstruction::CreateAllReduce(
      constant->shape(), {constant}, reduction,
      CreateReplicaGroups({{0, 2}, {1, 3}}),
      /*constrain_layout=*/false,
      /*channel_id=*/nullopt, /*use_global_device_ids=*/false));
  b.AddInstruction(HloInstruction::CreateTuple({crs0, crs1}));

  module->AddEntryComputation(b.Build());

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

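// Tests that all-reduces separated by sharding domains with different
// domain metadata are not combined.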
TEST_F(AllReduceCombinerTest, DomainPreventsCombining) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  crs0 = f32[128] all-reduce(param0),
    replica_groups={}, to_apply=summit, sharding={maximal device=0}
  crs1 = f32[128] all-reduce(param1),
    replica_groups={}, to_apply=summit, sharding={maximal device=1}
  domain0 = f32[128] domain(crs0),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=0}}
  domain1 = f32[128] domain(crs1),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1}}, exit={maximal device=1}}
  ROOT tuple = (f32[128], f32[128]) tuple(domain0, domain1),
    sharding={{maximal device=0}, {maximal device=1}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  LOG(INFO) << "Original module:\n" << module->ToString();

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

// This test checks that two CRS instructions that are in separate domains
// but with the same domain metadata can be combined.
TEST_F(AllReduceCombinerTest, CombineFromTwoDomainsWithSameMetadata) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  param2 = f32[128] parameter(2), sharding={maximal device=1}
  crs0 = f32[128] all-reduce(param0),
    replica_groups={}, to_apply=summit, sharding={maximal device=0}
  crs1 = f32[128] all-reduce(param1),
    replica_groups={}, to_apply=summit, sharding={maximal device=1}
  crs2 = f32[128] all-reduce(param2),
    replica_groups={}, to_apply=summit, sharding={maximal device=0}
  domain0 = f32[128] domain(crs0),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
    {maximal device=0}}, exit={maximal device=0}}
  domain1 = f32[128] domain(crs1),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
    {maximal device=0}}, exit={maximal device=1}}
  domain2 = f32[128] domain(crs2),
    domain={kind="sharding", entry={{maximal device=0}, {maximal device=1},
    {maximal device=0}}, exit={maximal device=0}}
  ROOT tuple = (f32[128], f32[128], f32[128]) tuple(domain0, domain1, domain2),
    sharding={{maximal device=0}, {maximal device=1}, {maximal device=0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 3);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_TRUE(changed);
}

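// Tests that a cross-shard all-reduce (carrying a channel_id) is not
// combined with a cross-replica all-reduce in SPMD mode.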
TEST_F(AllReduceCombinerTest, DoNotCombineCrossShardAndCrossReplicaInSPMD) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  cross_shard_ar = f32[128] all-reduce(param0),
    replica_groups={{0}}, to_apply=summit, channel_id=1
  cross_replica_ar = f32[128] all-reduce(param1),
    replica_groups={{0}}, to_apply=summit, sharding={maximal device=1}
  ROOT tuple = (f32[128], f32[128]) tuple(cross_shard_ar, cross_replica_ar)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 2);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_FALSE(changed);
}

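// Tests that cross-module all-reduces sharing a channel_id are combined per
// sharding device: the four all-reduces below become two combined ops.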
TEST_F(AllReduceCombinerTest, CrossCoreAllReduce) {
  const char* const hlo_string = R"(
HloModule Module

summit {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param0 = f32[128] parameter(0), sharding={maximal device=0}
  param1 = f32[128] parameter(1), sharding={maximal device=1}
  crs00 = f32[128] all-reduce(param0),
    replica_groups={{0}}, channel_id=1, to_apply=summit,
    sharding={maximal device=0}
  crs01 = f32[128] all-reduce(param1),
    replica_groups={{0}}, channel_id=1, to_apply=summit,
    sharding={maximal device=1}
  crs10 = f32[128] all-reduce(param0),
    replica_groups={{0}}, channel_id=2, to_apply=summit,
    sharding={maximal device=0}
  crs11 = f32[128] all-reduce(param1),
    replica_groups={{0}}, channel_id=2, to_apply=summit,
    sharding={maximal device=1}
  domain0 = f32[128] domain(crs00),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  ROOT add = f32[128] add(domain0, crs11),
    sharding={maximal device=1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 4);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 2);
  EXPECT_TRUE(changed);

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Add(op::Domain(op::GetTupleElement(
                  AllOf(op::AllReduce(op::Parameter(0), op::Parameter(0)),
                        op::Shape("(f32[128], f32[128])")),
                  1)),
              op::GetTupleElement(
                  AllOf(op::AllReduce(op::Parameter(1), op::Parameter(1)),
                        op::Shape("(f32[128], f32[128])")),
                  0)));
}

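// Tests that the combiner does not create a cycle between combine groups:
// the add and max all-reduces below depend on one another, so the six
// all-reduces can only be combined down to four.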
TEST_F(AllReduceCombinerTest, CrossCombineGroupCycle) {
  const char* const hlo_string = R"(
HloModule module

%add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

%max {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] maximum(lhs, rhs)
}
ENTRY %comp {
  p0 = f32[128] parameter(0)
  p1 = f32[128] parameter(1)

  crs00 = f32[128] all-reduce(p0), to_apply=add
  crs10 = f32[128] all-reduce(p1), to_apply=max

  crs01 = f32[128] all-reduce(crs00), to_apply=max
  crs11 = f32[128] all-reduce(crs10), to_apply=add
  add0 = f32[128] add(crs01, crs11)

  crs02 = f32[128] all-reduce(add0), to_apply=add
  crs12 = f32[128] all-reduce(crs11), to_apply=add
  ROOT tuple = (f32[128], f32[128]) tuple(crs02, crs12)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
  ASSERT_EQ(AllReduceCount(*module), 6);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
  EXPECT_EQ(AllReduceCount(*module), 4);
  EXPECT_TRUE(changed);

  auto crs0 = op::AllReduce(op::Parameter(0), op::AllReduce(op::Parameter(1)));
  auto add = op::Add(op::AllReduce(op::GetTupleElement(crs0, 0)),
                     op::GetTupleElement(crs0, 1));
  auto crs1 = op::AllReduce(add, op::GetTupleElement(crs0));
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Tuple(op::GetTupleElement(crs1, 0), op::GetTupleElement(crs1, 1)));
}

}  // namespace
}  // namespace xla