1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17
18 #include <set>
19
20 #include "tensorflow/compiler/xla/debug_options_flags.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_runner.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34
35 namespace op = xla::testing::opcode_matchers;
36
37 namespace xla {
38 namespace {
39
40 using ::testing::UnorderedElementsAre;
41
CountCopies(const HloComputation & computation)42 int64 CountCopies(const HloComputation& computation) {
43 int64_t count = 0;
44 for (const auto& instruction : computation.instructions()) {
45 if (instruction->opcode() == HloOpcode::kCopy) {
46 count++;
47 }
48 }
49 return count;
50 }
51
CountCopies(const HloModule & module)52 int64 CountCopies(const HloModule& module) {
53 int64_t count = 0;
54 for (const auto& computation : module.computations()) {
55 count += CountCopies(*computation);
56 }
57 return count;
58 }
59
CountControlEdges(const HloComputation & computation)60 int64 CountControlEdges(const HloComputation& computation) {
61 int64_t count = 0;
62 for (const auto& instruction : computation.instructions()) {
63 count += instruction->control_successors().size();
64 }
65 return count;
66 }
67
CountControlEdges(const HloModule & module)68 int64 CountControlEdges(const HloModule& module) {
69 int64_t count = 0;
70 for (const auto& computation : module.computations()) {
71 count += CountControlEdges(*computation);
72 }
73 return count;
74 }
75
76 class CopyInsertionTest : public HloTestBase {
77 protected:
InsertCopies(HloModule * module)78 void InsertCopies(HloModule* module) {
79 CopyInsertion copy_insertion;
80 ASSERT_IS_OK(copy_insertion.Run(module).status());
81 }
82
83 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
84 };
85
TEST_F(CopyInsertionTest, SingleParameter) {
  // The entry computation returns a tuple that wraps one parameter. The
  // parameter must be copied before it is placed into the output tuple.
  HloComputation::Builder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(F32, {});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "x"));
  HloInstruction* result_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param}));

  EXPECT_THAT(param->users(), UnorderedElementsAre(result_tuple));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  InsertCopies(module.get());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(param)));
}
105
TEST_F(CopyInsertionTest, SingleConstant) {
  // Computation is a single constant passed into a tuple. The constant should
  // be copied before entering the tuple, so exactly one copy is expected.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant}));

  EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(constant)));
}
126
TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
  // Layout-changing kCopy instructions that exist before copy insertion must
  // still exist afterwards. The trailing copy of `sum` keeps the same shape;
  // after the pass only two copies remain and `sum` is the root again.
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder builder(TestName());
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));

  // Reading the constant's minor-to-major order as major-to-minor gives the
  // reversed layout used by the two explicit copies.
  Shape reversed_shape = matrix->shape();
  *reversed_shape.mutable_layout() = LayoutUtil::MakeLayoutFromMajorToMinor(
      LayoutUtil::MinorToMajor(matrix->shape()));

  HloInstruction* lhs_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(reversed_shape, HloOpcode::kCopy, matrix));
  HloInstruction* rhs_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(reversed_shape, HloOpcode::kCopy, matrix));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      matrix->shape(), HloOpcode::kAdd, lhs_copy, rhs_copy));
  builder.AddInstruction(
      HloInstruction::CreateUnary(sum->shape(), HloOpcode::kCopy, sum));

  module->AddEntryComputation(builder.Build());
  EXPECT_EQ(CountCopies(*module), 3);

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root, sum);
  EXPECT_THAT(root,
              op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
162
TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
  // Two constants and two parameters, of which the output tuple references
  // exactly one constant (`c2`) and one parameter (`px`) directly. Only those
  // two should be copied; the add of `c1` and `py` is expected to stay
  // uncopied.
  HloComputation::Builder builder(TestName());
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));

  HloInstruction* px = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "x"));
  HloInstruction* py = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "y"));

  HloInstruction* c1_plus_y = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, c1, py));
  builder.AddInstruction(HloInstruction::CreateTuple({c2, px, c1_plus_y}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(c2), op::Copy(px), op::Add(c1, py)));
}
194
TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
  // The root is a kTupleSelect between two tuples, so element 0 of the result
  // may come from either of two constants. Copy insertion is expected to
  // rebuild the root as a tuple of copies of the select's elements.
  HloComputation::Builder builder(TestName());
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* three = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));

  HloInstruction* lhs =
      builder.AddInstruction(HloInstruction::CreateTuple({one, two}));
  HloInstruction* rhs =
      builder.AddInstruction(HloInstruction::CreateTuple({three, two}));

  HloInstruction* predicate = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  builder.AddInstruction(HloInstruction::CreateTernary(
      lhs->shape(), HloOpcode::kTupleSelect, predicate, lhs, rhs));

  EXPECT_THAT(one->users(), UnorderedElementsAre(lhs));
  EXPECT_THAT(two->users(), UnorderedElementsAre(lhs, rhs));
  EXPECT_THAT(three->users(), UnorderedElementsAre(rhs));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloInstruction* select = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(op::GetTupleElement(select)),
                        op::Copy(op::GetTupleElement(select))));
}
231
TEST_F(CopyInsertionTest, BitcastParameter) {
  // A bitcast shares its operand's buffer. When a bitcast of a parameter is
  // the computation result, a copy of it must become the new root.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x"));
  HloInstruction* reshaped = builder.AddInstruction(
      HloInstruction::CreateBitcast(ShapeUtil::MakeShape(F32, {2, 2}), param));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  EXPECT_THAT(param->users(), UnorderedElementsAre(reshaped));

  HloInstruction* root_before = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(root_before));
}
253
TEST_F(CopyInsertionTest, BitcastConstant) {
  // A bitcast shares its operand's buffer. When a bitcast of a constant is
  // the computation result, a copy of it must become the new root.
  HloComputation::Builder builder(TestName());
  HloInstruction* literal_vec =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.0, 42.0})));
  HloInstruction* reshaped =
      builder.AddInstruction(HloInstruction::CreateBitcast(
          ShapeUtil::MakeShape(F32, {2}), literal_vec));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  EXPECT_THAT(literal_vec->users(), UnorderedElementsAre(reshaped));

  HloInstruction* root_before = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(root_before));
}
276
TEST_F(CopyInsertionTest, BitcastTupleElementParameter) {
  // Like BitcastParameter, except the bitcast of the parameter is wrapped in
  // the output tuple; that tuple element should become a copy of the bitcast.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x"));
  HloInstruction* reshaped = builder.AddInstruction(
      HloInstruction::CreateBitcast(ShapeUtil::MakeShape(F32, {2, 2}), param));
  builder.AddInstruction(HloInstruction::CreateTuple({reshaped}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  EXPECT_THAT(param->users(), UnorderedElementsAre(reshaped));

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(reshaped)));
}
297
TEST_F(CopyInsertionTest, NestedTupleParameter) {
  // The computation consists solely of a nested-tuple parameter, which is
  // therefore its root. Copy insertion must deep-copy the parameter (one copy
  // per leaf) and make the rebuilt tuple the new root.
  HloComputation::Builder builder(TestName());

  // Parameter shape: ((F32[], S32[1,2,3]), F32[42])
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {inner_tuple_shape, ShapeUtil::MakeShape(F32, {42})});
  builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloInstruction* root_before = module->entry_computation()->root_instruction();
  EXPECT_EQ(HloOpcode::kParameter, root_before->opcode());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 3);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_NE(root_before, root_after);

  const auto inner_leaf_copy =
      op::Copy(op::GetTupleElement(op::GetTupleElement(root_before)));
  EXPECT_THAT(root_after,
              op::Tuple(op::Tuple(inner_leaf_copy, inner_leaf_copy),
                        op::Copy(op::GetTupleElement(root_before))));
}
334
TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
  // The root is a get-tuple-element extracting element 0 of a nested-tuple
  // parameter. That element is itself a tuple, so each of its two leaves is
  // expected to be copied.
  HloComputation::Builder builder(TestName());

  // Parameter shape: ((F32[], S32[1,2,3]), F32[42])
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {inner_tuple_shape, ShapeUtil::MakeShape(F32, {42})});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  HloInstruction* element0 =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(param->shape(), {0}), param, 0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  EXPECT_EQ(element0, module->entry_computation()->root_instruction());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  const auto leaf_copy =
      op::Copy(op::GetTupleElement(op::GetTupleElement(param)));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(leaf_copy, leaf_copy));
}
367
TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
  // The root is element 0 of a kTupleSelect, so the root's top-level buffer
  // may be either constant. A single shallow copy of the root is expected.
  HloComputation::Builder builder(TestName());
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));

  HloInstruction* one_two =
      builder.AddInstruction(HloInstruction::CreateTuple({one, two}));
  HloInstruction* two_one =
      builder.AddInstruction(HloInstruction::CreateTuple({two, one}));

  HloInstruction* predicate = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      one_two->shape(), HloOpcode::kTupleSelect, predicate, one_two, two_one));
  HloInstruction* first =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(select->shape(), {0}), select, 0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloInstruction* root_before = module->entry_computation()->root_instruction();
  EXPECT_EQ(first, root_before);

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(root_before));
}
403
404 class WhileCopyInsertionTest : public CopyInsertionTest {
405 protected:
WhileCopyInsertionTest()406 WhileCopyInsertionTest() : module_(CreateNewVerifiedModule()) {}
407
408 // Builds a While condition computation which reads the induction variable
409 // from the tuple parameter, and returns a predicate indicating whether this
410 // value is less than the constant '10'.
411 // The parameter 'nested' specifies the loop state shape from which to
412 // read the induction variable.
BuildConditionComputation(const Shape & loop_state_shape)413 std::unique_ptr<HloComputation> BuildConditionComputation(
414 const Shape& loop_state_shape) {
415 auto builder = HloComputation::Builder(TestName() + ".Condition");
416 auto limit_const = builder.AddInstruction(
417 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10)));
418 auto loop_state = builder.AddInstruction(
419 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
420 auto induction_variable =
421 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
422 limit_const->shape(), loop_state, 0));
423 builder.AddInstruction(HloInstruction::CreateCompare(
424 condition_result_shape_, induction_variable, limit_const,
425 ComparisonDirection::kLt));
426 return builder.Build();
427 }
428
429 // Builds a While body computation with one output tuple element dependent on
430 // both input tuple elements.
431 // EX:
432 // Body({in0, in1})
433 // out0 = Add(in0, 1)
434 // out1 = Add(BCast(in0), in1)
435 // Tuple(out0, out1)
BuildDependentBodyComputation()436 std::unique_ptr<HloComputation> BuildDependentBodyComputation() {
437 auto builder = HloComputation::Builder(TestName() + ".Body");
438 // Create param instruction to access loop state.
439 auto loop_state = builder.AddInstruction(
440 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
441 // Update the induction variable GTE(0).
442 auto induction_variable =
443 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
444 induction_variable_shape_, loop_state, 0));
445 auto inc = builder.AddInstruction(
446 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
447 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
448 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
449 // Update data GTE(1).
450 auto data = builder.AddInstruction(
451 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
452 // Use 'induction_variable' in computation with no path to output tuple.
453 Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
454 auto convert = builder.AddInstruction(
455 HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
456 auto update = builder.AddInstruction(
457 HloInstruction::CreateBroadcast(data_shape_, convert, {}));
458 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
459 data_shape_, HloOpcode::kAdd, data, update));
460 // Create output Tuple.
461 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
462 return builder.Build();
463 }
464
465 // Builds a While body computation with two output tuple elements dependent on
466 // both input tuple elements.
467 //
468 // EX: Body({in0, in1, in2})
469 // out0 = Add(in0, 1)
470 // out1 = in1
471 // out2 = in2
472 // Tuple(out0, out1, out2)
BuildDependentBodyComputation2()473 std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
474 auto builder = HloComputation::Builder(TestName() + ".Body");
475
476 const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
477 {induction_variable_shape_, data_shape_, data_shape_});
478
479 auto loop_state = builder.AddInstruction(
480 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
481
482 // Update the induction variable GTE(0).
483 auto induction_variable =
484 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
485 induction_variable_shape_, loop_state, 0));
486 auto inc = builder.AddInstruction(
487 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
488
489 // add0 = Add(in0, 1)
490 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
491 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
492 // data1 = GTE(1).
493 HloInstruction* data1 = builder.AddInstruction(
494 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
495
496 // data2 = GTE(2).
497 HloInstruction* data2 = builder.AddInstruction(
498 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
499
500 // Create output Tuple.
501 builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
502
503 return builder.Build();
504 }
505
506 // Builds a While body computation with read-only tuple element 0.
507 // EX:
508 // Body({in0, in1})
509 // out0 = in0
510 // out1 = Add(BCast(in0), in1)
511 // Tuple(out0, out1)
BuildDependentBodyOneReadOnlyComputation()512 std::unique_ptr<HloComputation> BuildDependentBodyOneReadOnlyComputation() {
513 auto builder = HloComputation::Builder(TestName() + ".Body");
514 // Create param instruction to access loop state.
515 auto loop_state = builder.AddInstruction(
516 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
517 // Update the induction variable GTE(0).
518 auto induction_variable =
519 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
520 induction_variable_shape_, loop_state, 0));
521 // Update data GTE(1).
522 auto data = builder.AddInstruction(
523 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
524
525 // Use 'induction_variable' in computation with no path to output tuple.
526 Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
527 auto convert = builder.AddInstruction(
528 HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
529 auto update = builder.AddInstruction(
530 HloInstruction::CreateBroadcast(data_shape_, convert, {}));
531 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
532 data_shape_, HloOpcode::kAdd, data, update));
533 // Create output Tuple.
534 builder.AddInstruction(
535 HloInstruction::CreateTuple({induction_variable, add1}));
536 return builder.Build();
537 }
538
539 // Builds a While body computation with independent outputs.
540 // EX:
541 // Body({in0, in1})
542 // out0 = Add(in0, 1)
543 // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
544 // Tuple(out0, out1)
BuildIndependentBodyComputation(bool nested=false)545 std::unique_ptr<HloComputation> BuildIndependentBodyComputation(
546 bool nested = false) {
547 auto builder = HloComputation::Builder(TestName() + ".Body");
548 // Create param instruction to access loop state.
549 const Shape& loop_state_shape =
550 nested ? nested_loop_state_shape_ : loop_state_shape_;
551
552 auto loop_state = builder.AddInstruction(
553 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
554 // Update the induction variable GTE(0).
555 auto induction_variable =
556 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
557 induction_variable_shape_, loop_state, 0));
558 auto inc = builder.AddInstruction(
559 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
560 // add0 = Add(in0, 1)
561 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
562 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
563 // Update data GTE(1).
564 HloInstruction* data = nullptr;
565 if (nested) {
566 data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
567 nested_tuple_shape_, loop_state, 1));
568 data = builder.AddInstruction(
569 HloInstruction::CreateGetTupleElement(data_shape_, data, 0));
570 } else {
571 data = builder.AddInstruction(
572 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
573 }
574 auto update = builder.AddInstruction(
575 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
576 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
577 // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
578 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
579 data_shape_, HloOpcode::kAdd, data, update));
580 // Create output Tuple.
581 if (nested) {
582 auto nested_tuple =
583 builder.AddInstruction(HloInstruction::CreateTuple({add1, add1}));
584 builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple}));
585 } else {
586 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
587 }
588 return builder.Build();
589 }
590
591 // Builds a While body computation with the following nested tuple
592 // sub-computation:
593 // |
594 // GTE(loop_state, 1)
595 // / \
596 // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1)
597 // | |
598 // Add Reverse
599 // | |
BuildNestedBodyComputation()600 std::unique_ptr<HloComputation> BuildNestedBodyComputation() {
601 auto builder = HloComputation::Builder(TestName() + ".Body");
602 // Create param instruction to access loop state.
603 auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
604 0, nested_loop_state_shape_, "loop_state"));
605 // Update GTE(0).
606 auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
607 induction_variable_shape_, loop_state, 0));
608 auto inc = builder.AddInstruction(
609 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
610 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
611 gte0->shape(), HloOpcode::kAdd, gte0, inc));
612
613 // GTE(loop_state, 1)
614 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
615 nested_tuple_shape_, loop_state, 1));
616 // GTE(GTE(loop_state, 1), 0) -> Add
617 auto gte10 = builder.AddInstruction(
618 HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
619 auto update10 = builder.AddInstruction(
620 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
621 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
622 auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
623 data_shape_, HloOpcode::kAdd, gte10, update10));
624
625 // GTE(GTE(loop_state, 1), 1) -> Reverse
626 auto gte11 = builder.AddInstruction(
627 HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1));
628 auto rev11 = builder.AddInstruction(
629 HloInstruction::CreateReverse(data_shape_, gte11, {0}));
630
631 // Create output Tuple.
632 auto inner_tuple =
633 builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11}));
634 builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple}));
635 return builder.Build();
636 }
637
638 // Builds a While instruction using 'condition' and 'body' sub-computations.
639 // Init operand is initialized to zeros of appropriate shape.
BuildWhileInstruction(HloComputation * condition,HloComputation * body,bool nested=false)640 HloInstruction* BuildWhileInstruction(HloComputation* condition,
641 HloComputation* body,
642 bool nested = false) {
643 auto builder = HloComputation::Builder(TestName() + ".While");
644 auto induction_var_init = builder.AddInstruction(
645 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
646
647 auto data_init = builder.AddInstruction(
648 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
649 {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
650
651 if (nested) {
652 auto inner_init = builder.AddInstruction(
653 HloInstruction::CreateTuple({data_init, data_init}));
654 auto loop_state_init = builder.AddInstruction(
655 HloInstruction::CreateTuple({induction_var_init, inner_init}));
656 auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
657 loop_state_init->shape(), condition, body, loop_state_init));
658 module_->AddEntryComputation(builder.Build());
659 return while_hlo;
660 }
661
662 auto loop_state_init = builder.AddInstruction(
663 HloInstruction::CreateTuple({induction_var_init, data_init}));
664 auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
665 loop_state_shape_, condition, body, loop_state_init));
666 module_->AddEntryComputation(builder.Build());
667 return while_hlo;
668 }
669
BuildWhileInstruction_InitPointsToConstant()670 HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
671 auto builder = HloComputation::Builder(TestName() + ".While");
672 auto data_init = builder.AddInstruction(
673 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
674 {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
675 return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
676 &builder);
677 }
678
BuildWhileInstruction_InitPointsToParameter()679 HloInstruction* BuildWhileInstruction_InitPointsToParameter() {
680 auto builder = HloComputation::Builder(TestName() + ".While");
681 auto data_init = builder.AddInstruction(
682 HloInstruction::CreateParameter(0, data_shape_, "data_init"));
683 return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
684 &builder);
685 }
686
BuildWhileInstruction_InitPointsToAmbiguous()687 HloInstruction* BuildWhileInstruction_InitPointsToAmbiguous() {
688 auto builder = HloComputation::Builder(TestName() + ".While");
689
690 auto one = builder.AddInstruction(
691 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
692 auto v1 = builder.AddInstruction(
693 HloInstruction::CreateBroadcast(data_shape_, one, {}));
694 auto zero = builder.AddInstruction(
695 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
696 auto v2 = builder.AddInstruction(
697 HloInstruction::CreateBroadcast(data_shape_, zero, {}));
698
699 auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2}));
700 auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
701
702 auto pred = builder.AddInstruction(
703 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
704 auto data_init = builder.AddInstruction(HloInstruction::CreateTernary(
705 nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
706
707 return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
708 data_init, &builder);
709 }
710
BuildWhileInstruction_InitPointsToNonDistinct()711 HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() {
712 auto builder = HloComputation::Builder(TestName() + ".While");
713
714 auto one = builder.AddInstruction(
715 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
716 auto one_vec = builder.AddInstruction(
717 HloInstruction::CreateBroadcast(data_shape_, one, {}));
718 auto data_init =
719 builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec}));
720
721 return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
722 data_init, &builder);
723 }
724
BuildWhileInstruction_InitPointsToInterfering()725 HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
726 auto builder = HloComputation::Builder(TestName() + ".While");
727 auto one = builder.AddInstruction(
728 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
729 auto data_init = builder.AddInstruction(
730 HloInstruction::CreateBroadcast(data_shape_, one, {}));
731 auto one_vec = builder.AddInstruction(
732 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
733 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
734 // Take a reference to 'data_init' to make it interfere with while result.
735 auto add = builder.AddInstruction(HloInstruction::CreateBinary(
736 data_shape_, HloOpcode::kAdd, data_init, one_vec));
737
738 auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_,
739 data_init, &builder);
740
741 // Add an additional binary operation operating on the while and the
742 // interfering add so that neither operation is dead.
743 auto gte = xla_while->parent()->AddInstruction(
744 HloInstruction::CreateGetTupleElement(
745 ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1));
746 auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary(
747 data_shape_, HloOpcode::kSubtract, add, gte));
748 auto gte0 = xla_while->parent()->AddInstruction(
749 HloInstruction::CreateGetTupleElement(
750 ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0));
751 auto tuple = xla_while->parent()->AddInstruction(
752 HloInstruction::CreateTuple({gte0, sub}));
753
754 xla_while->parent()->set_root_instruction(tuple);
755
756 return xla_while;
757 }
758
BuildWhileInstructionWithCustomInit(const Shape & loop_state_shape,HloInstruction * data_init,HloComputation::Builder * builder)759 HloInstruction* BuildWhileInstructionWithCustomInit(
760 const Shape& loop_state_shape, HloInstruction* data_init,
761 HloComputation::Builder* builder) {
762 const bool nested =
763 ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
764 auto induction_var_init = builder->AddInstruction(
765 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
766 auto condition = module_->AddEmbeddedComputation(
767 BuildConditionComputation(loop_state_shape));
768 auto body = module_->AddEmbeddedComputation(
769 BuildIndependentBodyComputation(nested));
770 auto loop_state_init = builder->AddInstruction(
771 HloInstruction::CreateTuple({induction_var_init, data_init}));
772 auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile(
773 loop_state_shape, condition, body, loop_state_init));
774 module_->AddEntryComputation(builder->Build());
775 return while_hlo;
776 }
777
// Module under test. The Build* helpers add embedded and entry computations
// to it. NOTE(review): it is presumably created in the fixture constructor,
// which is outside this view — confirm.
std::unique_ptr<HloModule> module_;
// s32[] shape of the loop induction variable.
Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {});
// f32[8] shape of a single loop data element.
Shape data_shape_ = ShapeUtil::MakeShape(F32, {8});
// Flat loop state: (s32[], f32[8]).
Shape loop_state_shape_ =
    ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_});
// Pair of data elements: (f32[8], f32[8]).
Shape nested_tuple_shape_ =
    ShapeUtil::MakeTupleShape({data_shape_, data_shape_});
// Nested loop state: (s32[], (f32[8], f32[8])).
Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape(
    {induction_variable_shape_, nested_tuple_shape_});
// pred[] shape (named for the loop condition's result).
Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {});
788 };
789
790 // Tests while body computation with independent tuple elements:
791 //
792 // While.Body({in0, in1})
793 // out0 = Add(in0, 1)
794 // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
795 // Tuple(out0, out1)
796 //
797 // CopyInsertion pass should not generate any copies.
798 //
TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
  HloComputation* cond_computation = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* body_computation =
      module_->AddEmbeddedComputation(BuildIndependentBodyComputation());
  auto xla_while = BuildWhileInstruction(cond_computation, body_computation);

  InsertCopies(module_.get());

  // Each add in the body can update its operand in place: no body copies and
  // no control edges anywhere in the module.
  EXPECT_EQ(CountCopies(*body_computation), 0);
  EXPECT_EQ(CountControlEdges(*module_), 0);

  // Both init elements are constants, so each one is copied.
  EXPECT_THAT(xla_while->operand(0),
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
816
817 // Tests Copy Insertion when a while feeds another while
818 // PARAMETER
819 // | |
820 // GTE(0) GTE(1)
821 // | |
822 // X = CreateTuple(GTE(0), GTE(1))
823 // |
824 // WHILE(X) (root)
TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterWithCopies) {
  const string& hlo_string = R"(
HloModule DependentTupleElements

%DependentTupleElements.Body (loop_state.1: (s32[], f32[8])) -> (s32[], f32[8]) {
%loop_state.1 = (s32[], f32[8]{0}) parameter(0)
%get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=0
%constant.1 = s32[] constant(1)
%add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
%get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=1
%convert = f32[] convert(s32[] %get-tuple-element.1)
%broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
%add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
ROOT %tuple = (s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1)
}

%DependentTupleElements.Condition (loop_state: (s32[], f32[8])) -> pred[] {
%loop_state = (s32[], f32[8]{0}) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state), index=0
%constant = s32[] constant(10)
ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %DependentTupleElements.While () -> (s32[], f32[8]) {
%constant.2 = s32[] constant(0)
%constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
%tuple.1 = (s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3)
ROOT %while.1 = (s32[], f32[8]{0}) while((s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
}
)";
  // Parse into a local module. It is deliberately not named `module_`, so it
  // does not shadow the fixture member of that name.
  auto module = ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie();
  HloInstruction* while_hlo = module->entry_computation()->root_instruction();
  ASSERT_EQ(while_hlo->opcode(), HloOpcode::kWhile);

  // Wrap the parsed while in an outer while:
  //  - the outer loop runs clones of the parsed loop's condition and body;
  //  - the root of the cloned body is replaced by a new ("dual") while that
  //    reuses the parsed loop's condition and body, and is fed by a tuple of
  //    GTEs of the outer body's parameter;
  //  - the parsed while in the entry computation is replaced by the outer one.
  HloComputation* outer_while_condition =
      module->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
  HloComputation* outer_while_body =
      module->AddEmbeddedComputation(while_hlo->while_body()->Clone());
  HloInstruction* outer_while =
      while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), outer_while_condition, outer_while_body,
          while_hlo->mutable_operand(0)));

  HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
  std::vector<HloInstruction*> materialized_gtes;
  materialized_gtes.reserve(outer_param->shape().tuple_shapes_size());
  for (int i = 0; i < outer_param->shape().tuple_shapes_size(); ++i) {
    materialized_gtes.push_back(
        outer_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
            outer_param->shape().tuple_shapes(i), outer_param, i)));
  }
  HloInstruction* dual_init = outer_while_body->AddInstruction(
      HloInstruction::CreateTuple(materialized_gtes));
  HloInstruction* dual_while =
      outer_while_body->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), while_hlo->while_condition(),
          while_hlo->while_body(), dual_init));
  TF_CHECK_OK(outer_while_body->ReplaceInstruction(
      outer_while_body->root_instruction(), dual_while));
  TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));

  // No expectations of its own: this exercises InsertCopies on the nested
  // graph. NOTE(review): any success checks live in the fixture's InsertCopies
  // helper, which is outside this view.
  InsertCopies(module.get());
}
887
888 // Tests Copy Insertion when a while feeds another while
889 // PARAMETER
890 // | |
891 // \ /
892 // WHILE(PARAMETER) (root)
TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterNoCopies) {
  const string& hlo_string = R"(
HloModule DependentTupleElements

%DependentTupleElements.Body (loop_state.1: (s32[], f32[8])) -> (s32[], f32[8]) {
%loop_state.1 = (s32[], f32[8]{0}) parameter(0)
%get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=0
%constant.1 = s32[] constant(1)
%add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
%get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=1
%convert = f32[] convert(s32[] %get-tuple-element.1)
%broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
%add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
ROOT %tuple = (s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1)
}

%DependentTupleElements.Condition (loop_state: (s32[], f32[8])) -> pred[] {
%loop_state = (s32[], f32[8]{0}) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state), index=0
%constant = s32[] constant(10)
ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %DependentTupleElements.While () -> (s32[], f32[8]) {
%constant.2 = s32[] constant(0)
%constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
%tuple.1 = (s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3)
ROOT %while.1 = (s32[], f32[8]{0}) while((s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
}
)";
  // Parse into a local module. It is deliberately not named `module_`, so it
  // does not shadow the fixture member of that name.
  auto module = ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie();
  HloInstruction* while_hlo = module->entry_computation()->root_instruction();
  ASSERT_EQ(while_hlo->opcode(), HloOpcode::kWhile);

  // Wrap the parsed while in an outer while:
  //  - the outer loop runs clones of the parsed loop's condition and body;
  //  - the root of the cloned body is replaced by a new ("dual") while that
  //    reuses the parsed loop's condition and body and consumes the outer
  //    body's parameter directly (no intermediate GTEs/tuple);
  //  - the parsed while in the entry computation is replaced by the outer one.
  HloComputation* outer_while_condition =
      module->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
  HloComputation* outer_while_body =
      module->AddEmbeddedComputation(while_hlo->while_body()->Clone());
  HloInstruction* outer_while =
      while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), outer_while_condition, outer_while_body,
          while_hlo->mutable_operand(0)));
  HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
  HloInstruction* dual_while =
      outer_while_body->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), while_hlo->while_condition(),
          while_hlo->while_body(), outer_param));
  TF_CHECK_OK(outer_while_body->ReplaceInstruction(
      outer_while_body->root_instruction(), dual_while));
  TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));

  // No expectations of its own: this exercises InsertCopies on the nested
  // graph. NOTE(review): any success checks live in the fixture's InsertCopies
  // helper, which is outside this view.
  InsertCopies(module.get());
}
947
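// Tests Copy Insertion when a while feeds another while through a tuple of
// materialized GTEs, using a 20-element loop state:
//                 PARAMETER
//                |         |
//             GTE(0) ... GTE(19)
//                |         |
//   X = CreateTuple(GTE(0), ..., GTE(19))
//                     |
//                 WHILE(X) (root)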
TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterBig) {
  const string& hlo_string = R"(
HloModule DependentTupleElements

%DependentTupleElements.Body (loop_state.1: (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0})) -> (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) {
%loop_state.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) parameter(0)
%get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state.1), index=0
%constant.1 = s32[] constant(1)
%add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
%get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state.1), index=1
%convert = f32[] convert(s32[] %get-tuple-element.1)
%broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
%add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
ROOT %tuple = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1)
}

%DependentTupleElements.Condition (loop_state: (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0})) -> pred[] {
%loop_state = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state), index=0
%constant = s32[] constant(10)
ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %DependentTupleElements.While () -> (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) {
%constant.2 = s32[] constant(0)
%constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
%tuple.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3)
ROOT %while.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) while( (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
}
)";
  // Parse into a local module. It is deliberately not named `module_`, so it
  // does not shadow the fixture member of that name.
  auto module = ParseAndReturnVerifiedModule(hlo_string).ConsumeValueOrDie();
  HloInstruction* while_hlo = module->entry_computation()->root_instruction();
  ASSERT_EQ(while_hlo->opcode(), HloOpcode::kWhile);

  // Wrap the parsed while in an outer while:
  //  - the outer loop runs clones of the parsed loop's condition and body;
  //  - the root of the cloned body is replaced by a new ("dual") while that
  //    reuses the parsed loop's condition and body, and is fed by a tuple of
  //    GTEs of every element of the outer body's parameter;
  //  - the parsed while in the entry computation is replaced by the outer one.
  HloComputation* outer_while_condition =
      module->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
  HloComputation* outer_while_body =
      module->AddEmbeddedComputation(while_hlo->while_body()->Clone());
  HloInstruction* outer_while =
      while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), outer_while_condition, outer_while_body,
          while_hlo->mutable_operand(0)));

  HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
  std::vector<HloInstruction*> materialized_gtes;
  materialized_gtes.reserve(outer_param->shape().tuple_shapes_size());
  for (int i = 0; i < outer_param->shape().tuple_shapes_size(); ++i) {
    materialized_gtes.push_back(
        outer_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
            outer_param->shape().tuple_shapes(i), outer_param, i)));
  }
  HloInstruction* dual_init = outer_while_body->AddInstruction(
      HloInstruction::CreateTuple(materialized_gtes));
  HloInstruction* dual_while =
      outer_while_body->AddInstruction(HloInstruction::CreateWhile(
          while_hlo->shape(), while_hlo->while_condition(),
          while_hlo->while_body(), dual_init));
  TF_CHECK_OK(outer_while_body->ReplaceInstruction(
      outer_while_body->root_instruction(), dual_while));
  TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));

  // No expectations of its own: this exercises InsertCopies on the nested
  // graph. NOTE(review): any success checks live in the fixture's InsertCopies
  // helper, which is outside this view.
  InsertCopies(module.get());
}
1015
1016 // Tests while body computation with dependent tuple elements:
1017 //
1018 // While.Body({in0, in1})
1019 // out0 = Add(in0, 1)
1020 // out1 = Add(BCast(in0), in1)
1021 // Tuple(out0, out1)
1022 //
// CopyInsertion pass should insert a single copy of in0 that feeds both of its
// users, so the root instruction becomes:
//
//   Tuple(Add(Copy(in0), 1), Add(in1, BCast(Convert(Copy(in0)))))
1026 //
TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
  HloComputation* cond_computation = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* body_computation =
      module_->AddEmbeddedComputation(BuildDependentBodyComputation());
  auto xla_while = BuildWhileInstruction(cond_computation, body_computation);

  InsertCopies(module_.get());

  // Exactly one copy, and no control edges, are added to the body.
  EXPECT_EQ(CountCopies(*body_computation), 1);
  EXPECT_EQ(CountControlEdges(*body_computation), 0);

  EXPECT_THAT(
      body_computation->root_instruction(),
      op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast())));

  const HloInstruction* body_root = body_computation->root_instruction();
  const HloInstruction* counter_update = body_root->operand(0);
  const HloInstruction* broadcast = body_root->operand(1)->operand(1);
  ASSERT_EQ(counter_update->opcode(), HloOpcode::kAdd);
  ASSERT_EQ(broadcast->opcode(), HloOpcode::kBroadcast);

  // The single body copy feeds both the Add with the constant and the
  // Convert -> Broadcast used by the data update.
  EXPECT_THAT(xla_while->while_body()->root_instruction(),
              op::Tuple(op::Add(op::Copy(), op::Constant()),
                        op::Add(op::GetTupleElement(),
                                op::Broadcast(op::Convert(op::Copy())))));

  // Both init elements are constants, so each one is copied.
  EXPECT_THAT(xla_while->operand(0),
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
1056
1057 // Tests while body computation with read-only tuple element 0:
1058 //
1059 // PARAMETER
1060 // / \
1061 // GTE(0) GTE(1)
1062 // | \ |
1063 // | BCAST |
1064 // | \ |
1065 // | ADD
1066 // | |
1067 // \ /
1068 // TUPLE (root)
1069 //
1070 // CopyInsertion pass should not generate any copies for the while body.
TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) {
  HloComputation* cond_computation = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* body_computation = module_->AddEmbeddedComputation(
      BuildDependentBodyOneReadOnlyComputation());
  BuildWhileInstruction(cond_computation, body_computation);

  InsertCopies(module_.get());

  // The body is already legal, so neither copies nor control edges are added.
  EXPECT_EQ(CountCopies(*body_computation), 0);
  EXPECT_EQ(CountControlEdges(*body_computation), 0);
}
1084
1085 // Same as above, but with two while loops, sharing entry parameters.
TEST_F(WhileCopyInsertionTest,
       DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) {
  auto condition1 = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  auto condition2 = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  auto body1 = module_->AddEmbeddedComputation(
      BuildDependentBodyOneReadOnlyComputation());
  auto body2 = module_->AddEmbeddedComputation(
      BuildDependentBodyOneReadOnlyComputation());

  // Both whiles share one init tuple built directly from entry parameters.
  auto builder = HloComputation::Builder(TestName() + ".While");
  auto iter_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
  auto data_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "data"));
  auto loop_init = builder.AddInstruction(
      HloInstruction::CreateTuple({iter_param, data_param}));

  auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape_, condition1, body1, loop_init));
  auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape_, condition2, body2, loop_init));

  // Consume element 0 of each while so that neither while is dead.
  auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
  auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
  builder.AddInstruction(
      HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));

  auto entry = module_->AddEntryComputation(builder.Build());

  InsertCopies(module_.get());

  // Neither body should have any copies or control edges in it.
  EXPECT_EQ(CountCopies(*body1), 0);
  EXPECT_EQ(CountCopies(*body2), 0);
  EXPECT_EQ(CountControlEdges(*body1), 0);
  EXPECT_EQ(CountControlEdges(*body2), 0);

  // Only two copies are needed: each while gets its own copy of tuple
  // element 1. That element's init value is an entry parameter, and the body
  // writes it (it is not read-only), so each while body needs a buffer of its
  // own to write element 1 into.
  EXPECT_EQ(CountCopies(*entry), 2);

  EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
  EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);

  // The two copies of element 1 must be distinct instructions.
  EXPECT_NE(while_hlo1->operand(0)->operand(1),
            while_hlo2->operand(0)->operand(1));
}
1141
1142 // Same as above, but with two while loops, sharing non-parameters.
TEST_F(WhileCopyInsertionTest,
       DependentTupleElements_OneReadOnly_TwoLoops_NonParams) {
  auto condition1 = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  auto condition2 = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  auto body1 = module_->AddEmbeddedComputation(
      BuildDependentBodyOneReadOnlyComputation());
  auto body2 = module_->AddEmbeddedComputation(
      BuildDependentBodyOneReadOnlyComputation());

  auto builder = HloComputation::Builder(TestName() + ".While");
  auto iter_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
  auto data_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "data"));
  // Add dummy ops to ensure loop_init elements aren't entry parameters.
  Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
  auto convert = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_scalar_shape, iter_param));
  auto iter_value = builder.AddInstruction(
      HloInstruction::CreateUnary(convert->shape(), HloOpcode::kExp, convert));
  auto convert2 = builder.AddInstruction(
      HloInstruction::CreateConvert(induction_variable_shape_, iter_value));
  auto data_value = builder.AddInstruction(HloInstruction::CreateUnary(
      data_param->shape(), HloOpcode::kExp, data_param));
  auto loop_init = builder.AddInstruction(
      HloInstruction::CreateTuple({convert2, data_value}));

  auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape_, condition1, body1, loop_init));
  auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape_, condition2, body2, loop_init));

  // Consume element 0 of each while so that neither while is dead.
  auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
  auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
  builder.AddInstruction(
      HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
  auto entry = module_->AddEntryComputation(builder.Build());

  InsertCopies(module_.get());

  // Ideally only one copy would be necessary: one of the whiles needs its own
  // copy of tuple element 1 (the non-read-only element) so that the two while
  // bodies write element 1 into different buffers. However, the analysis isn't
  // perfect and copies element 1 for *both* whiles. Element 0 (read-only) is
  // not copied for either while.
  EXPECT_EQ(CountCopies(*entry), 2);

  EXPECT_THAT(while_hlo1->operand(0),
              op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
  EXPECT_THAT(while_hlo2->operand(0),
              op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
}
1199
1200 // Tests while body computation with nested tuple elements:
1201 //
1202 // |
1203 // GTE(loop_state, 1)
1204 // / \
1205 // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1)
1206 // | |
1207 // Add Reverse
1208 // | |
1209 //
1210 // CopyInsertion pass will conceptually generate the following, but with the
1211 // actual GTE and Tuple instructions optimized away:
1212 //
1213 // Tuple // old root
1214 // / \
1215 // / \
1216 // GTE(0) GTE(1)
1217 // | / \
1218 // | / \
1219 // | GTE(0) GTE(1)
1220 // | | |
1221 // | | Copy
1222 // | | |
1223 // \ | /
1224 // \ Tuple // "inner" tuple.
1225 // \ /
1226 // \ /
1227 // Tuple // new root
1228 //
TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
  auto condition = module_->AddEmbeddedComputation(
      BuildConditionComputation(nested_loop_state_shape_));
  auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation());
  BuildWhileInstruction(condition, body, /*nested=*/true);

  InsertCopies(module_.get());

  // The only copy needed in the body is for the kReverse. It cannot be done
  // in place, because the instruction cannot share a buffer with its operand.
  // The other elements of the loop state are kAdd instructions, which can be
  // done in place.
  EXPECT_EQ(CountCopies(*body), 1);

  // Each element of the init is a constant and needs a copy. The module-wide
  // total also includes the single body copy checked above.
  EXPECT_EQ(CountCopies(*module_), 4);

  // Either the kReverse itself is copied, or its operand is.
  if (body->root_instruction()->operand(1)->operand(1)->opcode() ==
      HloOpcode::kCopy) {
    EXPECT_THAT(
        body->root_instruction(),
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse()))));
  } else {
    EXPECT_THAT(
        body->root_instruction(),
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy()))));
  }
}
1259
1260 // Tests while init instruction which points-to a constant.
1261 //
1262 // init = Tuple(Constant(S32, {}), Constant(F32, {8}))
1263 //
1264 // CopyInsertion pass should add copies for both constants.
1265 //
TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
  auto xla_while = BuildWhileInstruction_InitPointsToConstant();
  InsertCopies(module_.get());

  // Nothing is copied inside the body. The module's two copies are the copies
  // of the two constant init elements.
  EXPECT_EQ(CountCopies(*xla_while->while_body()), 0);
  EXPECT_EQ(CountCopies(*module_), 2);
  EXPECT_THAT(xla_while->operand(0),
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
1276
1277 // Tests while init instruction which points-to a parameter.
1278 //
1279 // init = Tuple(Constant(S32, {}), Parameter(F32, {8}))
1280 //
1281 // CopyInsertion pass should add copies for both the constant and parameter.
1282 //
TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
  auto xla_while = BuildWhileInstruction_InitPointsToParameter();
  InsertCopies(module_.get());

  // Nothing is copied inside the body. The module's two copies cover the
  // constant init element and the parameter init element.
  EXPECT_EQ(CountCopies(*xla_while->while_body()), 0);
  EXPECT_EQ(CountCopies(*module_), 2);
  EXPECT_THAT(xla_while->operand(0),
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter())));
}
1293
1294 // Tests while init instruction which has an ambiguous points-to set.
1295 //
// select = TupleSelect(pred, tuple1, tuple2)
// init = Tuple(Constant(S32, {}), select)
1298 //
1299 // CopyInsertion pass will conceptually generate the following, but with some of
1300 // the actual GTE and Tuple instructions optimized away:
1301 //
1302 // Tuple // old init
1303 // / \
1304 // / \
1305 // GTE(0) GTE(1)
1306 // | / \
1307 // | / \
1308 // | GTE(0) GTE(1)
1309 // | | |
1310 // Copy Copy Copy
1311 // | | |
1312 // \ | /
1313 // \ Tuple
1314 // \ /
1315 // \ /
1316 // Tuple // new init
1317 //
TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) {
  auto xla_while = BuildWhileInstruction_InitPointsToAmbiguous();
  InsertCopies(module_.get());

  EXPECT_EQ(CountCopies(*module_), 4);
  // Three of the copies live in the entry computation: one per ambiguous
  // nested init element, plus one for the constant init element.
  EXPECT_EQ(CountCopies(*module_->entry_computation()), 3);
  EXPECT_THAT(xla_while->operand(0),
              op::Tuple(op::Copy(op::Constant()),
                        op::Tuple(op::Copy(op::GetTupleElement()),
                                  op::Copy(op::GetTupleElement()))));

  // The body's output buffers are not distinct: one add result is written to
  // both nested output elements. Exactly one copy fixes that, and it may land
  // on either element.
  EXPECT_EQ(CountCopies(*xla_while->while_body()), 1);
  const HloInstruction* body_root =
      xla_while->while_body()->root_instruction();
  if (body_root->operand(1)->operand(0)->opcode() == HloOpcode::kCopy) {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
  } else {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
  }
}
1349
1350 // Tests while init instruction which has a non-distinct points-to set.
1351 //
1352 // init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one}))
1353 //
1354 // CopyInsertion pass will conceptually generate the following, but with some of
1355 // the actual GTE and Tuple instructions optimized away:
1356 //
1357 // Tuple // old init
1358 // / \
1359 // / \
1360 // GTE(0) GTE(1)
1361 // | / \
1362 // | / \
1363 // | GTE(0) GTE(1)
1364 // | | |
1365 // Copy Copy Copy
1366 // | | |
1367 // \ | /
1368 // \ Tuple
1369 // \ /
1370 // \ /
1371 // Tuple // new init
1372 //
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
  auto xla_while = BuildWhileInstruction_InitPointsToNonDistinct();
  InsertCopies(module_.get());

  // The entry computation needs two copies:
  //  - one for the constant init element;
  //  - one to separate the two nested init elements, which alias the same
  //    broadcast. Either of the two may be the one that is copied.
  EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
  const HloInstruction* init = xla_while->operand(0);
  if (init->operand(1)->operand(0)->opcode() == HloOpcode::kCopy) {
    EXPECT_THAT(
        init,
        op::Tuple(op::Copy(op::Constant()),
                  op::Tuple(op::Copy(op::Broadcast()), op::Broadcast())));
  } else {
    EXPECT_THAT(
        init,
        op::Tuple(op::Copy(op::Constant()),
                  op::Tuple(op::Broadcast(), op::Copy(op::Broadcast()))));
  }

  // The body's output buffers are not distinct either: one add result is
  // written to both nested output elements. One copy fixes that, and it may
  // land on either element.
  EXPECT_EQ(CountCopies(*xla_while->while_body()), 1);
  const HloInstruction* body_root =
      xla_while->while_body()->root_instruction();
  if (body_root->operand(1)->operand(0)->opcode() == HloOpcode::kCopy) {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
  } else {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
  }
}
1413
// Tests while init instruction buffer which interferes with while result
// buffer.
//
//   init_data = Broadcast(...)
//   add_unrelated = Add(init_data)  // takes a reference to cause interference
//   init = Tuple(Constant(S32, {}), init_data)
//
// CopyInsertion pass should copy both operands.
//
TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
  auto while_hlo = BuildWhileInstruction_InitPointsToInterfering();

  InsertCopies(module_.get());

  // Both init operands are copied in the entry computation; the loop body
  // itself needs no copies.
  EXPECT_EQ(CountCopies(*module_), 2);
  EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);

  const auto expected_init =
      op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast()));
  EXPECT_THAT(while_hlo->operand(0), expected_init);
}
1433
// Tests while init instruction buffer which has a non-distinct points-to set:
//
//   init = Tuple(Parameter(S32, {}), Parameter(F32, {8}),
//                Parameter(F32, {8}))
//
// where the second and third elements are the same parameter *and* the tuple
// is shared by another while instruction.
//
// Verifies that copies are not shared between the two while loops: each while
// receives its own copy of element 0, while the duplicated parameter elements
// are passed through without being copied.
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
  // Both loops use a body whose output tuple has two elements that depend on
  // the init tuple. The shape is held by value (not by reference to a
  // temporary), matching the other tests in this file.
  const Shape loop_state_shape = ShapeUtil::MakeTupleShape(
      {induction_variable_shape_, data_shape_, data_shape_});

  auto condition1 = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape));
  auto condition2 = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape));
  auto body1 =
      module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
  auto body2 =
      module_->AddEmbeddedComputation(BuildDependentBodyComputation2());

  auto builder = HloComputation::Builder(TestName() + ".While");

  auto iter_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
  auto data_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "data"));

  // The loop init tuple holds the same parameter buffer in elements 1 and 2.
  auto loop_init = builder.AddInstruction(
      HloInstruction::CreateTuple({iter_param, data_param, data_param}));

  // Both while loops consume the same loop init tuple.
  auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape, condition1, body1, loop_init));
  auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape, condition2, body2, loop_init));

  // Add element 0 of both loop results together so neither while is dead.
  // Each GTE takes its shape from the while instruction it actually reads.
  auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
  auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
  builder.AddInstruction(
      HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));

  module_->AddEntryComputation(builder.Build());

  InsertCopies(module_.get());

  // Neither loop body should contain copies.
  EXPECT_EQ(CountCopies(*body1), 0);
  EXPECT_EQ(CountCopies(*body2), 0);

  // Elements 1 and 2 of the init tuple reach both whiles uncopied (still the
  // original parameter). Each while gets its own copy of element 0, so copies
  // are not shared between the loops and the entry computation holds two.
  EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);

  EXPECT_THAT(while_hlo1->operand(0),
              op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
  EXPECT_THAT(while_hlo2->operand(0),
              op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
}
1504
TEST_F(CopyInsertionTest, SwizzlingWhile) {
  // A while loop whose body swaps the two elements of its tuple loop state.
  auto module = CreateNewVerifiedModule();
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: (elem0, elem1) -> (elem1, elem0).
  HloComputation::Builder body_builder("body");
  HloInstruction* state = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* elem0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 0));
  HloInstruction* elem1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 1));
  body_builder.AddInstruction(HloInstruction::CreateTuple({elem1, elem0}));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Condition: ignores the loop state and computes not(false).
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* false_value = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  cond_builder.AddInstruction(HloInstruction::CreateUnary(
      false_value->shape(), HloOpcode::kNot, false_value));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  // Entry: while over the tuple (1.0, 2.0).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* one = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* two = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({one, two}));
  HloInstruction* xla_while = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, body, init));
  module->AddEntryComputation(entry_builder.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 6);

  // Inside the body each state element is copied at the parameter and again at
  // the root, with a control edge between them (see
  // DeepCopyAndAddControlEdges). Three copies would suffice in principle, but
  // only if copies of different state elements were ordered by a control edge.
  EXPECT_EQ(CountCopies(*body), 4);
  EXPECT_EQ(CountControlEdges(*body), 2);
  EXPECT_THAT(body->root_instruction(),
              op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy())));

  // Entry: both init elements are copied before entering the loop.
  EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
  EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
}
1563
TEST_F(CopyInsertionTest, CrossingParameters) {
  // The two elements of the tuple parameter swap places in the output, while
  // each output index is aliased with the parameter index of the same number:
  //
  //   (p0 , p1)
  //    | \   /|
  //    |  \ / |
  //  alias X alias
  //    |  / \ |
  //    | /   \|
  //   (p1 , p0)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "0"));
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  builder.AddInstruction(HloInstruction::CreateTuple({second, first}));
  module->AddEntryComputation(builder.Build());

  auto& alias_config = module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 4);
}
1596
TEST_F(CopyInsertionTest, ParametersAliasing) {
  // Each tuple element is passed straight through to the same output index it
  // is aliased with, so the two dataflows never interfere:
  //
  //   (p0 , p1)
  //     |    |
  //     |    |
  //   alias alias
  //     |    |
  //     |    |
  //   (p0 , p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  builder.AddInstruction(HloInstruction::CreateTuple({first, second}));
  module->AddEntryComputation(builder.Build());

  auto& alias_config = module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
  InsertCopies(module.get());

  // Fully aliased pass-through: nothing needs copying.
  EXPECT_EQ(CountCopies(*module), 0);
}
1629
TEST_F(CopyInsertionTest, ParameterWithNoAliasing) {
  // The tuple parameter is passed straight through to the output, but no
  // output index is aliased with the parameter, so both elements must be
  // copied:
  //
  //   (p0 , p1)
  //     |    |
  //     |    |
  //     |    |
  //     |    |
  //     |    |
  //   (p0 , p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  builder.AddInstruction(HloInstruction::CreateTuple({first, second}));
  module->AddEntryComputation(builder.Build());
  InsertCopies(module.get());

  const auto expected_root =
      op::Tuple(op::Copy(op::GetTupleElement(input, 0)),
                op::Copy(op::GetTupleElement(input, 1)));
  EXPECT_THAT(module->entry_computation()->root_instruction(), expected_root);

  EXPECT_EQ(CountCopies(*module), 2);
}
1662
TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
  // Only output index 0 is aliased with the parameter; index 1 is not:
  //
  //   (p0 , p1)
  //     |    |
  //     |    |
  //   alias  |
  //     |    |
  //     |    |
  //   (p0 , p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  builder.AddInstruction(HloInstruction::CreateTuple({first, second}));
  module->AddEntryComputation(builder.Build());

  auto& alias_config = module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  // Only the non-aliased element gets a copy.
  const auto expected_root = op::Tuple(
      op::GetTupleElement(input, 0), op::Copy(op::GetTupleElement(input, 1)));
  EXPECT_THAT(module->entry_computation()->root_instruction(), expected_root);

  EXPECT_EQ(CountCopies(*module), 1);
}
1697
TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
  // Each parameter element is negated independently. Only output index 0 is
  // aliased with the parameter:
  //
  //   +-- (p0 , p1)
  //   |    |    |
  //   |    |    |
  // alias Negate Negate
  //   |    |    |
  //   |    |    |
  //   +-- (p0 , p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  HloInstruction* neg_first = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, first));
  HloInstruction* neg_second = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, second));
  builder.AddInstruction(HloInstruction::CreateTuple({neg_first, neg_second}));
  module->AddEntryComputation(builder.Build());

  auto& alias_config = module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  // The outputs are freshly computed values, so no copies are needed.
  EXPECT_EQ(CountCopies(*module), 0);
}
1734
TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
  // Both parameter elements are negated. Output 0 is the sum of the two
  // negations and output 1 is the second negation. Only output index 0 is
  // aliased with the parameter:
  //
  //   +-- (p0 , p1)
  //   |    |    |
  //   |    |    |
  // alias Negate Negate
  //   |    |    |
  //   |   Add----+
  //   |    |    |
  //   +-- (p0 , p1)
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  HloInstruction* neg_first = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, first));
  HloInstruction* neg_second = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, second));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, neg_first, neg_second));
  builder.AddInstruction(HloInstruction::CreateTuple({sum, neg_second}));
  module->AddEntryComputation(builder.Build());

  auto& alias_config = module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  // Both outputs are computed values, so no copies are needed.
  EXPECT_EQ(CountCopies(*module), 0);
}
1775
TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
  // Like SwizzlingWhile, the body swaps its two loop state elements, but here
  // one of them is also negated. That extra instruction changes the live
  // ranges of the corresponding input and output elements compared with a
  // pure swap.
  auto module = CreateNewVerifiedModule();
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: (elem0, elem1) -> (-elem1, elem0).
  HloComputation::Builder body_builder("body");
  HloInstruction* state = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* elem0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 0));
  HloInstruction* elem1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 1));
  HloInstruction* negated = body_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, elem1));
  body_builder.AddInstruction(HloInstruction::CreateTuple({negated, elem0}));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Condition: ignores the loop state and computes not(false).
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* false_value = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  cond_builder.AddInstruction(HloInstruction::CreateUnary(
      false_value->shape(), HloOpcode::kNot, false_value));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  // Entry: while over the tuple (1.0, 2.0).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* one = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* two = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({one, two}));
  HloInstruction* xla_while = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, body, init));
  module->AddEntryComputation(entry_builder.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 6);

  // Inside the body each state element is copied at the parameter and again at
  // the root, with a control edge between them (see
  // DeepCopyAndAddControlEdges).
  EXPECT_EQ(CountCopies(*body), 4);
  EXPECT_EQ(CountControlEdges(*body), 2);
  EXPECT_THAT(
      body->root_instruction(),
      op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy())));

  // Entry: both init elements are copied before entering the loop.
  EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
  EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
}
1838
TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
  // Test a while instruction with a body which permutes its tuple parameter
  // elements, similar to SwizzlingWhile above. In this test, however, both
  // loop state elements of the while input are the same constant. Because the
  // two elements are identical, interchanging them is a no-op, so the while
  // body needs no copies. The entry computation still gets two copies, at its
  // root (on the elements of the while result), as checked below.
  auto module = CreateNewVerifiedModule();
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body simply interchanges the two tuple elements in the loop state.
  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  auto body_element_0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  auto body_element_1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_element_1, body_element_0}));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Condition ignores the loop state and computes not(false).
  auto cond_builder = HloComputation::Builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  auto cond_constant = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  cond_builder.AddInstruction(HloInstruction::CreateUnary(
      cond_constant->shape(), HloOpcode::kNot, cond_constant));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  // The same constant is used for both loop state elements.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
  builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());

  // All copies are in the entry computation; none are in the body.
  EXPECT_EQ(CountCopies(*module), 2);
  EXPECT_EQ(CountCopies(*body), 0);

  EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(), op::Copy()));
}
1889
TEST_F(CopyInsertionTest, SequentialWhiles) {
  // Builds a chain of while loops, each with four loop state elements:
  //   element 0: fed to every while directly from entry parameter 0;
  //   element 1: threaded unchanged through all of the while bodies;
  //   element 2: negated in every body (can be done in place);
  //   element 3: reversed in every body (cannot be done in place).
  const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
  const Shape loop_state_shape = ShapeUtil::MakeTupleShape(
      {element_shape, element_shape, element_shape, element_shape});

  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());
  HloInstruction* param_0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, element_shape, "param_0"));
  HloInstruction* param_1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, element_shape, "param_1"));
  HloInstruction* param_2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, element_shape, "param_2"));
  HloInstruction* param_3 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, element_shape, "param_3"));

  // Number of chained kWhile instructions.
  const int kNumWhiles = 3;

  // Adds one body computation: (e0, e1, e2, e3) -> (e0, e1, -e2, reverse(e3)).
  auto add_body = [&]() -> HloComputation* {
    HloComputation::Builder body_builder("body");
    HloInstruction* state = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_state_shape, "param"));
    std::vector<HloInstruction*> elements;
    for (int64_t index = 0; index < 4; ++index) {
      elements.push_back(body_builder.AddInstruction(
          HloInstruction::CreateGetTupleElement(element_shape, state, index)));
    }
    HloInstruction* negated =
        body_builder.AddInstruction(HloInstruction::CreateUnary(
            element_shape, HloOpcode::kNegate, elements[2]));
    HloInstruction* reversed = body_builder.AddInstruction(
        HloInstruction::CreateReverse(element_shape, elements[3], {0}));
    body_builder.AddInstruction(HloInstruction::CreateTuple(
        {elements[0], elements[1], negated, reversed}));
    return module->AddEmbeddedComputation(body_builder.Build());
  };

  // Adds one condition computation that computes not(false).
  auto add_condition = [&]() -> HloComputation* {
    HloComputation::Builder cond_builder("condition");
    cond_builder.AddInstruction(
        HloInstruction::CreateParameter(0, loop_state_shape, "param"));
    HloInstruction* false_value = cond_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
    cond_builder.AddInstruction(HloInstruction::CreateUnary(
        false_value->shape(), HloOpcode::kNot, false_value));
    return module->AddEmbeddedComputation(cond_builder.Build());
  };

  // Values carried from one while into the next for elements 1-3.
  HloInstruction* carried_1 = param_1;
  HloInstruction* carried_2 = param_2;
  HloInstruction* carried_3 = param_3;

  std::vector<const HloInstruction*> whiles;
  whiles.reserve(kNumWhiles);
  for (int i = 0; i < kNumWhiles; ++i) {
    HloComputation* body = add_body();
    HloComputation* condition = add_condition();

    HloInstruction* while_init = builder.AddInstruction(
        HloInstruction::CreateTuple({param_0, carried_1, carried_2, carried_3}));
    HloInstruction* xla_while =
        builder.AddInstruction(HloInstruction::CreateWhile(
            loop_state_shape, condition, body, while_init));
    whiles.push_back(xla_while);

    const bool is_last_while = (i == kNumWhiles - 1);
    if (is_last_while) {
      continue;
    }
    carried_1 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1));
    carried_2 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2));
    carried_3 = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3));
  }

  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());

  // One copy per while body, plus four copies in the entry computation.
  EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles);

  // The single body copy is for element 3, whose kReverse cannot be done in
  // place.
  for (const HloInstruction* xla_while : whiles) {
    EXPECT_EQ(CountCopies(*xla_while->while_body()), 1);
  }

  // Entry: elements 2 and 3 are copied before the first while; elements 0 and
  // 1 are copied at the root.
  EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(),
                                               op::Copy(), op::Copy()));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(),
                        op::GetTupleElement()));
}
1992
TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
  // Both the while body and the condition consist of a parameter followed by a
  // constant that becomes the root. The body's constant root must be copied;
  // the condition's root is left as a constant.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));

  // Body: ignores its parameter; root is the constant 123.0.
  HloComputation::Builder body_builder("body");
  body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Condition: ignores its parameter; root is the constant false.
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  HloInstruction* xla_while = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape_, condition, body, entry_param));

  module->AddEntryComputation(entry_builder.Build());

  InsertCopies(module.get());

  // One copy feeds the while; the other wraps the body's constant root.
  EXPECT_EQ(CountCopies(*module), 2);

  EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter()));
  EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
  EXPECT_THAT(condition->root_instruction(), op::Constant());
}
2029
TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) {
  // A while loop whose state carries a token alongside an s32 counter.
  const string hlo_text = R"(
HloModule TokensShouldNotBeCopied

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %TokensShouldNotBeCopied () -> s32[] {
  %one = s32[] constant(1)
  %negative_one = s32[] negate(%one)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(module.get());

  // Tokens are never copied, so copy insertion must leave the module alone.
  EXPECT_EQ(CountCopies(*module), 0);
}
2067
MakeTrivialCondition(const Shape & shape)2068 std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
2069 auto builder = HloComputation::Builder("trivial_condition");
2070 builder.AddInstruction(
2071 HloInstruction::CreateParameter(0, shape, "loop_state"));
2072 auto constant = builder.AddInstruction(
2073 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
2074 builder.AddInstruction(HloInstruction::CreateUnary(
2075 constant->shape(), HloOpcode::kNot, constant));
2076 return builder.Build();
2077 }
2078
MakeBenchmarkWhileBody()2079 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() {
2080 auto builder = HloComputation::Builder("benchmark_loop_body");
2081 const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
2082 const Shape loop_state_shape =
2083 ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape});
2084 HloInstruction* param = builder.AddInstruction(
2085 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
2086 HloInstruction* element_0 = builder.AddInstruction(
2087 HloInstruction::CreateGetTupleElement(element_shape, param, 0));
2088 HloInstruction* element_1 = builder.AddInstruction(
2089 HloInstruction::CreateGetTupleElement(element_shape, param, 1));
2090 HloInstruction* element_2 = builder.AddInstruction(
2091 HloInstruction::CreateGetTupleElement(element_shape, param, 2));
2092
2093 HloInstruction* rev_1 = builder.AddInstruction(
2094 HloInstruction::CreateReverse(element_shape, element_1, {0}));
2095 HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary(
2096 element_shape, HloOpcode::kAdd, element_1, element_2));
2097
2098 builder.AddInstruction(
2099 HloInstruction::CreateTuple({element_0, rev_1, add_1_2}));
2100 return builder.Build();
2101 }
2102
BM_SequentialWhiles(::testing::benchmark::State & state)2103 void BM_SequentialWhiles(::testing::benchmark::State& state) {
2104 const int num_whiles = state.range(0);
2105
2106 // This benchmark constructs a chain of sequential while instructions.
2107 // Timer starts automatically at the first iteration of this loop
2108 // and ends after the last one.
2109 for (auto s : state) {
2110 state.PauseTiming();
2111 HloModuleConfig config;
2112 config.set_debug_options(GetDebugOptionsFromFlags());
2113 HloModule module("BM_SequentialWhiles", config);
2114
2115 auto builder = HloComputation::Builder("BM_SequentialWhiles");
2116 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
2117 0, ShapeUtil::MakeShape(F32, {42}), "x"));
2118 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
2119 1, ShapeUtil::MakeShape(F32, {42}), "y"));
2120 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
2121 2, ShapeUtil::MakeShape(F32, {42}), "z"));
2122 HloInstruction* init =
2123 builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
2124
2125 HloInstruction* prev_loop_state = init;
2126 for (int w = 0; w < num_whiles; ++w) {
2127 HloComputation* condition =
2128 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2129 HloComputation* body =
2130 module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
2131 prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile(
2132 init->shape(), condition, body, prev_loop_state));
2133 }
2134 module.AddEntryComputation(builder.Build());
2135
2136 CopyInsertion copy_insertion;
2137
2138 state.ResumeTiming();
2139 ASSERT_IS_OK(copy_insertion.Run(&module).status());
2140 state.PauseTiming();
2141
2142 // The entry computation should have three copies, and each body has one.
2143 ASSERT_EQ(CountCopies(module), 3 + num_whiles);
2144 state.ResumeTiming();
2145 }
2146 }
2147
BM_ParallelWhiles(::testing::benchmark::State & state)2148 void BM_ParallelWhiles(::testing::benchmark::State& state) {
2149 const int num_whiles = state.range(0);
2150
2151 // This benchmark constructs a fan-out of parallel while instructions.
2152 for (auto s : state) {
2153 state.PauseTiming();
2154 HloModuleConfig config;
2155 config.set_debug_options(GetDebugOptionsFromFlags());
2156 HloModule module("BM_SequentialWhiles", config);
2157
2158 auto builder = HloComputation::Builder("BM_ParallelWhiles");
2159 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
2160 0, ShapeUtil::MakeShape(F32, {42}), "x"));
2161 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
2162 1, ShapeUtil::MakeShape(F32, {42}), "y"));
2163 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
2164 2, ShapeUtil::MakeShape(F32, {42}), "z"));
2165 HloInstruction* init =
2166 builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
2167
2168 HloInstruction* sum = nullptr;
2169 for (int w = 0; w < num_whiles; ++w) {
2170 HloComputation* condition =
2171 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2172 HloComputation* body =
2173 module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
2174
2175 HloInstruction* xla_while = builder.AddInstruction(
2176 HloInstruction::CreateWhile(init->shape(), condition, body, init));
2177
2178 if (sum == nullptr) {
2179 sum = builder.AddInstruction(
2180 HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
2181 } else {
2182 HloInstruction* element_0 = builder.AddInstruction(
2183 HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
2184 sum = builder.AddInstruction(HloInstruction::CreateBinary(
2185 x->shape(), HloOpcode::kAdd, sum, element_0));
2186 }
2187 }
2188 module.AddEntryComputation(builder.Build());
2189
2190 CopyInsertion copy_insertion;
2191
2192 state.ResumeTiming();
2193 ASSERT_IS_OK(copy_insertion.Run(&module).status());
2194 state.PauseTiming();
2195
2196 // Each body receives of copy of two of the parameters (the corresponding
2197 // elements in the body are modified), and there is one copy in each body.
2198 ASSERT_EQ(CountCopies(module), 3 * num_whiles);
2199 }
2200 }
2201
MakeBenchmarkWhileBody(const int num_tuple_inputs)2202 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody(
2203 const int num_tuple_inputs) {
2204 auto builder = HloComputation::Builder("benchmark_loop_body");
2205 const Shape element_shape = ShapeUtil::MakeShape(F32, {});
2206 std::vector<Shape> input_shape(num_tuple_inputs, element_shape);
2207 const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape);
2208 HloInstruction* param = builder.AddInstruction(
2209 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
2210 std::vector<HloInstruction*> gte_nodes(num_tuple_inputs);
2211 for (int i = 0; i < num_tuple_inputs; ++i) {
2212 gte_nodes[i] = builder.AddInstruction(
2213 HloInstruction::CreateGetTupleElement(element_shape, param, i));
2214 }
2215 builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes));
2216 return builder.Build();
2217 }
2218
BM_ManyElementTuple(::testing::benchmark::State & state)2219 void BM_ManyElementTuple(::testing::benchmark::State& state) {
2220 const int num_tuple_inputs = state.range(0);
2221 HloModuleConfig config;
2222 config.set_debug_options(GetDebugOptionsFromFlags());
2223 CopyInsertion copy_insertion;
2224 const Shape element_shape = ShapeUtil::MakeShape(F32, {});
2225 std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
2226 for (auto s : state) {
2227 state.PauseTiming();
2228 auto builder = HloComputation::Builder("BM_ParallelWhiles");
2229 HloModule module("BM_ManyElementTuple", config);
2230 for (int j = 0; j < num_tuple_inputs; ++j) {
2231 tuple_params[j] = builder.AddInstruction(
2232 HloInstruction::CreateParameter(j, element_shape, ""));
2233 }
2234 HloInstruction* init =
2235 builder.AddInstruction(HloInstruction::CreateTuple(tuple_params));
2236 HloComputation* condition =
2237 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2238 HloComputation* body =
2239 module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs));
2240 HloInstruction* xla_while = builder.AddInstruction(
2241 HloInstruction::CreateWhile(init->shape(), condition, body, init));
2242 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
2243 ShapeUtil::MakeShape(F32, {}), xla_while, 0));
2244 module.AddEntryComputation(builder.Build());
2245 state.ResumeTiming();
2246 ASSERT_IS_OK(copy_insertion.Run(&module).status());
2247 }
2248 }
2249
// Benchmark registrations. Each Arg() value is delivered to the benchmark as
// state.range(0). For BM_SequentialWhiles and BM_ParallelWhiles it is the
// number of while loops built. For BM_ManyElementTuple it is the number of
// elements in the while loop's state tuple.
BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288);
2253
// Smoke test: runs copy insertion over three levels of nested while loops. The
// entry while.3 has body _functionalize_body_2__.v25, which contains while.2,
// whose body _functionalize_body_1__.v28 contains another while. No copy count
// is asserted; the test only checks that parsing and InsertCopies succeed
// (whatever checks the fixture's InsertCopies performs).
TEST_F(CopyInsertionTest, SimpleControlFlowTest) {
  const absl::string_view hlo_string = R"(
HloModule TestModule

if-body.v5 {
constant.3 = s32[] constant(-1)
p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
tuple.33 = (s32[]) tuple(add.3)
ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
constant.4 = s32[] constant(0)
ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

_functionalize_body_1__.v28 {
arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0
constant.7 = s32[] constant(1)
add.4 = s32[] add(get-tuple-element.68, constant.7)
get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE
constant.8 = s32[] constant(0)
select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70)
tuple.36 = (s32[]) tuple(constant.8)
tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36)
while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5
get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2
get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0
ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73)
}

cond_wrapper.v3.1 {
inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
constant.11 = s32[] constant(7)
ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0
get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2
get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3
get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4
tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79)
while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0
get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1
constant.12 = s32[] constant(1)
add.5 = s32[] add(get-tuple-element.81, constant.12)
get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3
ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82)
}

cond_wrapper.v3.2 {
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
constant.13 = s32[] constant(5)
ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT
}

ENTRY TestComputation {
arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // A parse failure is reported as a test failure with its status message,
  // rather than aborting the whole test binary as ConsumeValueOrDie() would.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2335
// Smoke test: like SimpleControlFlowTest, but the innermost body
// (_functionalize_body_1__.v28) contains two sequential whiles, with while.1
// consuming the result of while. No copy count is asserted; the test only
// checks that parsing and InsertCopies succeed (whatever checks the fixture's
// InsertCopies performs).
TEST_F(CopyInsertionTest, ControlFlowTest) {
  const absl::string_view hlo_string = R"(
HloModule TestModule

if-body.v5 {
constant.3 = s32[] constant(-1)
p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
tuple.33 = (s32[]) tuple(add.3)
ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
constant.4 = s32[] constant(0)
ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

if-body.v5.1 {
constant.5 = s32[] constant(-1)
p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1
get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2
multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70)
tuple.35 = (s32[]) tuple(multiply.1)
ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35)
}

if-condition.v4.1 {
p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
constant.6 = s32[] constant(1)
ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ
}

_functionalize_body_1__.v28 {
arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0
constant.7 = s32[] constant(1)
add.4 = s32[] add(get-tuple-element.72, constant.7)
get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE
constant.8 = s32[] constant(0)
select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74)
tuple.38 = (s32[]) tuple(constant.8)
tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38)
while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5
while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1
get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2
get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0
ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77)
}

cond_wrapper.v3.1 {
inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
constant.11 = s32[] constant(7)
ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0
get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2
get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3
get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4
tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82)
while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0
get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1
constant.12 = s32[] constant(1)
add.5 = s32[] add(get-tuple-element.84, constant.12)
get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3
ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85)
}

cond_wrapper.v3.2 {
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
constant.13 = s32[] constant(5)
ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT
}

ENTRY TestComputation {
arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // A parse failure is reported as a test failure with its status message,
  // rather than aborting the whole test binary as ConsumeValueOrDie() would.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2435
TEST_F(CopyInsertionTest, NestedWhiles) {
  // Verify that no unnecessary copies remain after copy insertion for trivial
  // nested whiles (b/112472605). The outer while's body consists solely of an
  // inner while over the same pred[] loop state.
  const string& hlo_string = R"(
HloModule TestModule

cond.inner {
ROOT param.cond.inner = pred[] parameter(0)
}

body.inner {
param.body.inner = pred[] parameter(0)
ROOT not = pred[] not(param.body.inner)
}

cond.outer {
ROOT param.cond.outer = pred[] parameter(0)
}

body.outer {
param.cond.outer = pred[] parameter(0)
ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
}

ENTRY TestComputation {
entry_param = pred[] parameter(0)
ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());

  // There should only be a single copy inserted, and it's in the entry
  // computation, on the operand of the outer while.
  EXPECT_EQ(CountCopies(*module), 1);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::While(op::Copy(op::Parameter())));
}
2475
// The outer while's body passes loop-state element 1 as the operand to both
// branches of a conditional. Each branch returns its parameter unchanged
// alongside a value computed from it. Three copies are expected, one of which
// must stay inside the loop because of the conditional's uses.
TEST_F(CopyInsertionTest, NestedWhileAndConditional2) {
  const absl::string_view kHloText = R"(
HloModule TestModule

on_true
{
v1 = f32[2] parameter(0)
v2 = f32[2] add(v1,v1)
ROOT t1 = (f32[2], f32[2]) tuple(v1,v2)
}

on_false
{
v1 = f32[2] parameter(0)
v2 = f32[2] multiply(v1,v1)
ROOT t2 = (f32[2], f32[2]) tuple(v1,v2)
}

cond.outer {
param.1 = (pred[], f32[2], f32[2]) parameter(0)
ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
param.1 = (pred[], f32[2], f32[2]) parameter(0)
pred.1 = pred[] get-tuple-element(param.1), index=0
arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
e1 = f32[2] get-tuple-element(if), index=0
e2 = f32[2] get-tuple-element(if), index=1
ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
}

ENTRY TestComputation {
entry_param.1 = pred[] parameter(0)
float_param = f32[2] parameter(1)
entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  InsertCopies(hlo_module.get());
  VLOG(2) << hlo_module->ToString() << "\n";

  EXPECT_EQ(CountCopies(*hlo_module), 3);
}
2524
// The outer while's body feeds loop-state element 1 to a conditional whose
// branches each compute a new f32[2] value (add or multiply). That value
// replaces the element in the next loop state.
TEST_F(CopyInsertionTest, NestedWhileAndConditional) {
  const string& hlo_string = R"(
HloModule TestModule

on_true
{
v1 = f32[2] parameter(0)
ROOT v2 = f32[2] add(v1,v1)
}

on_false
{
v1 = f32[2] parameter(0)
ROOT v2 = f32[2] multiply(v1,v1)
}

cond.outer {
param.1 = (pred[], f32[2]) parameter(0)
ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
param.1 = (pred[], f32[2]) parameter(0)
pred.1 = pred[] get-tuple-element(param.1), index=0
arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
if = f32[2] conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
ROOT res = (pred[], f32[2]) tuple(pred.1,if)
}

ENTRY TestComputation {
entry_param.1 = pred[] parameter(0)
float_param = f32[2] parameter(1)
entry_param = (pred[], f32[2]) tuple(entry_param.1, float_param)
ROOT while = (pred[], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
  VLOG(2) << module->ToString() << "\n";

  // There should be only two copies inserted: one at the entry and one at the
  // exit of the computation.
  EXPECT_EQ(CountCopies(*module), 2);
}
2570
// The entry root tuple returns both entry parameters directly (arg0 at output
// {1}, arg1 at output {3}) next to values computed from them. With those
// outputs aliased to their parameters, copy insertion is expected to add no
// copies.
TEST_F(CopyInsertionTest, FixpointComputationRequired) {
  const absl::string_view kHloText = R"(
HloModule Module

fused_computation {
param0 = f32[3,3,96,1] parameter(0)
param1 = f32[] parameter(1)
broadcast = f32[3,3,96,1] broadcast(f32[] param1), dimensions={}
ROOT %add.0 = f32[3,3,96,1] add(f32[3,3,96,1] param0, f32[3,3,96,1] broadcast)
}

ENTRY entry_computation {
arg0 = f32[3,3,96,1] parameter(0)
arg1 = f32[] parameter(1)
fusion = f32[3,3,96,1] fusion(f32[3,3,96,1] arg0, f32[] arg1),
kind=kLoop, calls=fused_computation
negate = f32[] negate(f32[] arg1)
ROOT tuple = (f32[3,3,96,1], f32[3,3,96,1], f32[], f32[]) tuple(
f32[3,3,96,1] fusion,
f32[3,3,96,1] arg0,
f32[] negate,
f32[] arg1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  // Establish by hand the aliasing that the alias_passthrough_params pass
  // would normally set up.
  auto& alias_config = hlo_module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{1},
                                       /*param_number=*/0,
                                       /*param_index=*/{}));
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{3},
                                       /*param_number=*/1,
                                       /*param_index=*/{}));

  InsertCopies(hlo_module.get());

  EXPECT_EQ(CountCopies(*hlo_module), 0);
}
2614
// The root tuple returns a bitcast of the entry parameter at output {0} and the
// parameter itself at output {1}. Output {1} is aliased to parameter 0, and
// copy insertion is expected to add exactly one copy.
TEST_F(CopyInsertionTest, NoAliasCheckViolation) {
  const absl::string_view kHloText = R"(
HloModule cluster

ENTRY Entry {
%arg = f32[8,28,28,1] parameter(0)
%bitcast.2 = f32[8,1,28,28] bitcast(f32[8,28,28,1] %arg)
ROOT %tuple.1 = (f32[8,1,28,28], f32[8,28,28,1]) tuple(f32[8,1,28,28] %bitcast.2, f32[8,28,28,1] %arg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  auto& alias_config = hlo_module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{1},
                                       /*param_number=*/0,
                                       /*param_index=*/{}));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2634
// The root dynamic-update-slice writes into `negate`, which has no other users,
// so no copy is expected.
TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) {
  const absl::string_view kHloText = R"(
HloModule Module

ENTRY main {
param = f32[1280,1,128] parameter(0)
negate = f32[1280,1,128] negate(param)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 0);
}
2653
// Same as DynamicUpdateSliceNoCopy, but the dynamic-update-slice lives inside a
// loop fusion whose only operand is `negate`; no copy is expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) {
  const absl::string_view kHloText = R"(
HloModule Module

fused_computation {
param0 = f32[1280,1,128] parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

ENTRY main {
param = f32[1280,1,128] parameter(0)
negate = f32[1280,1,128] negate(param)
ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 0);
}
2677
// `negate` is read by `add` and also updated by a dynamic-update-slice, with
// both results returned. Exactly one copy is expected.
TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) {
  const absl::string_view kHloText = R"(
HloModule Module

ENTRY main {
param = f32[1280,1,128] parameter(0)
negate = f32[1280,1,128] negate(param)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
add = f32[1280,1,128] add(negate, negate)
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2698
// The root dynamic-update-slice writes directly into the entry parameter, and
// exactly one copy is expected.
TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) {
  const absl::string_view kHloText = R"(
HloModule Module

ENTRY main {
param = f32[1280,1,128] parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2716
// A fusion performs a dynamic-update-slice on `negate`, while `negate` itself
// is also returned in the root tuple. Exactly one copy is expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) {
  const absl::string_view kHloText = R"(
HloModule Module

fused_computation {
param0 = f32[1280,1,128] parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

ENTRY main {
param = f32[1280,1,128] parameter(0)
negate = f32[1280,1,128] negate(param)
add = f32[1280,1,128] add(negate, negate)
fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2742
// Two chained dynamic-update-slices start from element 1 of the entry
// parameter tuple, and their result is returned. Exactly one copy is expected.
TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) {
  const absl::string_view kHloText = R"(
HloModule Module

ENTRY main {
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
get-tuple-element.3 = s32[] get-tuple-element(state), index=0
constant.2 = s32[] constant(128)
add.5 = s32[] add(get-tuple-element.3, constant.2)
constant.3 = s32[] constant(0)
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2766
// fusion1 updates `negate` in place. fusion2 then updates fusion1's result
// using a slice that still reads `negate`. Exactly one copy is expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) {
  const absl::string_view kHloText = R"(
HloModule Module

fused_computation.1 {
param0 = f32[1280,1,128] parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

fused_computation.2 {
param0 = f32[1280,1,128] parameter(0)
param1 = f32[1280,1,128] parameter(1)
slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3)
}

ENTRY main {
param = f32[1280,1,128] parameter(0)
negate = f32[1280,1,128] negate(param)
add = f32[1280,1,128] add(negate, negate)
fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1
ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2800
// A multi-output fusion has two dynamic-update-slice outputs, on negate1 and
// negate2. Both operands are also read after the fusion (by add1 and add2),
// so two copies are required.
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) {
  const absl::string_view kHloText = R"(
HloModule Module

fused_computation {
param0 = f32[1280,1,128] parameter(0)
param1 = f32[1280,1,128] parameter(1)
param2 = f32[1280,1,128] parameter(2)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
add.1 = f32[1280,1,128] add(param0, param0)
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
}

ENTRY main {
param = f32[1280,1,128] parameter(0)
negate0 = f32[1280,1,128] negate(param)
negate1 = f32[1280,1,128] negate(param)
negate2 = f32[1280,1,128] negate(param)
fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
add0 = f32[1280,1,128] add(negate0, gte0)
add1 = f32[1280,1,128] add(negate1, gte1)
add2 = f32[1280,1,128] add(negate2, gte2)
ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 2);
}
2839
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) {
  // Same fusion as MultiOutputFusedDynamicUpdateSliceCopy, but negate1 is not
  // used after the fusion (add1 only reads gte1). Of the two operands updated
  // in place by a dynamic-update-slice (negate1 and negate2), only negate2 is
  // still read afterwards (by add2), so exactly one copy is needed. negate0
  // only feeds a plain add inside the fusion and needs no copy.
  absl::string_view hlo_string = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  param1 = f32[1280,1,128] parameter(1)
  param2 = f32[1280,1,128] parameter(2)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  add.1 = f32[1280,1,128] add(param0, param0)
  dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate0 = f32[1280,1,128] negate(param)
  negate1 = f32[1280,1,128] negate(param)
  negate2 = f32[1280,1,128] negate(param)
  fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
  gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
  gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
  gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
  add0 = f32[1280,1,128] add(negate0, gte0)
  add1 = f32[1280,1,128] add(gte1, gte1)
  add2 = f32[1280,1,128] add(negate2, gte2)
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);
}
2879
// A horizontally fused (kInput) computation whose two outputs are aliased with
// entry parameters 0 and 3 via the module's input/output alias config. No
// copies are expected.
TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) {
  // Use absl::string_view (as the other tests do) rather than binding a
  // const std::string& to a temporary built from the literal.
  absl::string_view hlo_string = R"(
HloModule test

fused_computation {
  p0 = f32[10,20] parameter(0)
  p1 = f32[10,20] parameter(1)
  p2 = f32[10,10] parameter(2)
  p3 = f32[10,10] parameter(3)
  add0 = f32[10, 20] add(p0, p1)
  sub0 = f32[10, 10] subtract(p2, p3)
  reshape0 = f32[200] reshape(add0)
  reshape1 = f32[100] reshape(sub0)
  concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
  slice0 = f32[200] slice(concat0), slice={[0:200]}
  slice1 = f32[100] slice(concat0), slice={[200:300]}
  ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
}

ENTRY test {
  p0 = f32[10,20] parameter(0)
  p1 = f32[10,20] parameter(1)
  p2 = f32[10,10] parameter(2)
  p3 = f32[10,10] parameter(3)
  fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
  gte0 = f32[200] get-tuple-element(fusion), index=0
  gte1 = f32[100] get-tuple-element(fusion), index=1
  bitcast0 = f32[10,20] bitcast(gte0)
  bitcast1 = f32[10,10] bitcast(gte1)
  ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0},
      /*param_number=*/0,
      /*param_index=*/{}));
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1},
      /*param_number=*/3,
      /*param_index=*/{}));

  InsertCopies(module.get());

  // There should be no copies inserted.
  EXPECT_EQ(CountCopies(*module), 0);
}
2929
// A while loop whose body runs a conditional (`if`); that conditional's false
// branch itself contains a nested conditional. The test pins the total number
// of copies after copy insertion.
TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
  // absl::string_view instead of a const std::string& bound to a temporary,
  // matching the other tests in this file.
  absl::string_view hlo_string = R"(
HloModule TestModule

on_true.1
{
  ROOT v1 = f32[2] parameter(0)
}

on_false.1
{
  v1 = f32[2] parameter(0)
  ROOT v2 = f32[2] multiply(v1,v1)
}

on_true
{
  v1 = f32[2] parameter(0)
  v2 = f32[2] add(v1,v1)
  v3 = (f32[2],f32[2]) tuple(v1,v2)
  v4 = f32[2] get-tuple-element(v3), index=1
  v5 = f32[2] multiply(v4,v2)
  ROOT t1 = (f32[2], f32[2]) tuple(v5,v2)
}

on_false
{
  v1 = f32[2] parameter(0)
  v2 = f32[2] multiply(v1,v1)
  pred.1 = pred[] constant(true)
  v4 = f32[2] conditional(pred.1, v1, v2), true_computation=on_true.1, false_computation=on_false.1
  v5 = f32[2] multiply(v4,v2)
  ROOT t2 = (f32[2], f32[2]) tuple(v2,v5)

}

cond.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  pred.1 = pred[] get-tuple-element(param.1), index=0
  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
  if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
  e1 = f32[2] get-tuple-element(if), index=0
  e2 = f32[2] get-tuple-element(if), index=1
  ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
}

ENTRY TestComputation {
  entry_param.1 = pred[] parameter(0)
  float_param = f32[2] parameter(1)
  entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
  ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
  // An extra copy must be kept inside the loop due to uses in the conditional
  EXPECT_EQ(CountCopies(*module), 4);
}
2994
// Branch 1 of conditional.18 receives parameter.2 (via tuple.3) and returns a
// copy of it, while the entry root also returns parameter.2 directly. The test
// checks that region-based copy insertion keeps that copy.
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy1) {
  absl::string_view hlo_string = R"(
HloModule TestModule

branch_0_comp.5.clone {
  %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
  %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
  %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
  ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
}

branch_1_comp.12.clone {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%copy.1)
}

ENTRY TestComputation {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
  ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(parameter.2, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/true);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString();
  // copy.1 must be kept. Without it, branch 1 would forward parameter.2's value
  // as the conditional's result while the entry root also returns parameter.2.
  // gtest assertions (not CHECK) report a failure for this test instead of
  // aborting the whole test binary.
  auto conditional18 = FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  auto tuple6 = conditional18->branch_computation(1)->root_instruction();
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  auto copy1 = tuple6->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3040
// Branch 1 of conditional.18 copies its input (parameter.2, via tuple.3) and
// then applies a dynamic-update-slice to that copy. The entry root also returns
// parameter.2. The test checks that region-based copy insertion keeps copy.1 as
// the operand of the dynamic-update-slice.
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy2) {
  absl::string_view hlo_string = R"(
HloModule TestModule

branch_0_comp.5.clone {
  %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
  %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
  %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
  ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
}

branch_1_comp.12.clone {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  %constant.1 = s32[] constant(0)
  %broadcast.6 = s32[2] broadcast(constant.1), dimensions={}
  dynamic-update-slice.5 = s32[2]{0:T(128)} dynamic-update-slice(%copy.1, %broadcast.6, %constant.1)
  %add.1 = s32[2]{0:T(128)} add(dynamic-update-slice.5, %copy.1)
  ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%add.1)
}

ENTRY TestComputation {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}
  %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
  ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(parameter.2, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/true);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  // copy.1 must be kept: branch 1 modifies it with a dynamic-update-slice,
  // while the branch input parameter.2 is also returned by the entry root.
  // gtest assertions (not CHECK) keep a failure local to this test.
  auto conditional18 = FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  auto tuple6 = conditional18->branch_computation(1)->root_instruction();
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  auto add1 = tuple6->operand(0);
  ASSERT_EQ(add1->opcode(), HloOpcode::kAdd);
  auto dus = add1->operand(0);
  auto copy1 = dus->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3091
// conditional.18 is the entry root, and branch 1 returns a copy of its input
// (entry parameter.2, via tuple.3). The test checks that region-based copy
// insertion keeps that copy.
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy3) {
  absl::string_view hlo_string = R"(
HloModule primitive_computation_cond.19
%branch_0_comp.5.clone (parameter.0: (s32[2])) -> (s32[2]) {
  %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
  %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
  %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
  ROOT %tuple.5 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %copy)
}

%branch_1_comp.12.clone (parameter.4: (s32[2])) -> (s32[2]) {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  ROOT %tuple.6 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %copy.1)
}

ENTRY %primitive_computation_cond.19 (parameter.1: s32[], parameter.2: s32[2], parameter.3: s32[2]) -> (s32[2]) {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  ROOT %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/true);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString();
  // The copy.1 must be kept b/c aliasing of parameter and root is not allowed.
  // gtest assertions (not CHECK) keep a failure local to this test.
  auto conditional18 = FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  auto tuple6 = conditional18->branch_computation(1)->root_instruction();
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  auto copy1 = tuple6->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3134
// Like ConditionalBranchMustCopy1, except that the entry root returns only the
// conditional's result (twice), not parameter.2. Region-based copy insertion is
// expected to drop copy.1 in branch 1.
TEST_F(CopyInsertionTest, ConditionalBranchDoNotCopy1) {
  absl::string_view hlo_string = R"(
HloModule TestModule

branch_0_comp.5.clone {
  %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
  %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
  %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
  ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
}

branch_1_comp.12.clone {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%copy.1)
}

ENTRY TestComputation {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
  ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(gte.1, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/true);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString() << "\n";

  // parameter.2 has no users besides tuple.3, and the entry root only returns
  // the conditional's result. copy.1 is therefore expected to be removed,
  // leaving branch 1's root as its parameter.
  auto conditional18 = FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  auto branch1_root = conditional18->branch_computation(1)->root_instruction();
  EXPECT_EQ(branch1_root->opcode(), HloOpcode::kParameter);
}
3178
TEST_F(CopyInsertionTest, RootInstructionNotLast) {
  // This is a test for b/189219227. When the root instruction is scheduled not
  // as the last instruction, it still lives out. So, we make sure that the copy
  // after the root cannot be removed.
  // (absl::string_view, as elsewhere in this file, instead of binding a
  // const std::string& to a temporary.)
  absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true

body2 {
  p_body2 = (f32[2]{0}) parameter(0)
  p_body2.1 = f32[2]{0} get-tuple-element(p_body2), index=0
  add.3 = f32[2]{0} add(p_body2.1, p_body2.1)
  ROOT root2 = (f32[2]{0}) tuple(add.3)
}

condition2 {
  p_cond2 = (f32[2]{0}) parameter(0)
  ROOT result = pred[] constant(true)
}

body {
  p_body = (f32[2]{0}) parameter(0)
  p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0
  ROOT root = (f32[2]{0}) tuple(p_body.1)
  copy = f32[2]{0} copy(p_body.1)
  tuple = (f32[2]{0}) tuple(copy)
  while.1 = (f32[2]{0}) while(tuple), condition=condition2, body=body2
}

condition {
  p_cond = (f32[2]{0}) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const0 = f32[2]{0} constant({1, 2})
  while_init = (f32[2]{0}) tuple(const0)
  ROOT while.0 = (f32[2]{0}) while(while_init), condition=condition, body=body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/true);
  SequentialHloOrdering ordering(module->schedule());
  ASSERT_IS_OK(copy_insertion.RemoveUnnecessaryCopies(&ordering, module.get()));
  auto while_1 = FindInstruction(module.get(), "while.1");
  EXPECT_THAT(while_1, op::While(op::Tuple(op::Copy())));
}
3227
// Two collective-permutes (collective-permute.0 and collective-permute.1) read
// the same input tuple (tuple.input) and write into the same output tuple
// (tuple.output); both results are returned by the entry root. The test
// expects four copies after copy insertion.
TEST_F(CopyInsertionTest, InPlaceCollectivePermuteCopy) {
  const absl::string_view kHloText = R"(
HloModule hlo_runner_test_0.1
ENTRY hlo_runner_test_0.1 {
  replica_id = u32[] replica-id()
  broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
  constant.1 = u32[] constant(1000)
  broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
  broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
  constant.2 = s32[] constant(0)
  constant.3 = s32[] constant(1)
  tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
  tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
  tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
  tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
  tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
  constant.4 = s32[] constant(2)
  tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
  tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
  tuple.7 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.2)
  tuple.8 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
  tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
  tuple.10 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
  collective-permute.0 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  collective-permute.1 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.10), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  ROOT tuple = ((u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}), (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)})) tuple(collective-permute.0, collective-permute.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  InsertCopies(hlo_module.get());

  constexpr int kExpectedCopies = 4;
  EXPECT_EQ(CountCopies(*hlo_module), kExpectedCopies);
}
3261
3262 } // namespace
3263 } // namespace xla
3264