• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17 
18 #include <set>
19 
20 #include "tensorflow/compiler/xla/debug_options_flags.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_runner.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34 
35 namespace op = xla::testing::opcode_matchers;
36 
37 namespace xla {
38 namespace {
39 
40 using ::testing::UnorderedElementsAre;
41 
CountCopies(const HloComputation & computation)42 int64 CountCopies(const HloComputation& computation) {
43   int64_t count = 0;
44   for (const auto& instruction : computation.instructions()) {
45     if (instruction->opcode() == HloOpcode::kCopy) {
46       count++;
47     }
48   }
49   return count;
50 }
51 
CountCopies(const HloModule & module)52 int64 CountCopies(const HloModule& module) {
53   int64_t count = 0;
54   for (const auto& computation : module.computations()) {
55     count += CountCopies(*computation);
56   }
57   return count;
58 }
59 
CountControlEdges(const HloComputation & computation)60 int64 CountControlEdges(const HloComputation& computation) {
61   int64_t count = 0;
62   for (const auto& instruction : computation.instructions()) {
63     count += instruction->control_successors().size();
64   }
65   return count;
66 }
67 
CountControlEdges(const HloModule & module)68 int64 CountControlEdges(const HloModule& module) {
69   int64_t count = 0;
70   for (const auto& computation : module.computations()) {
71     count += CountControlEdges(*computation);
72   }
73   return count;
74 }
75 
76 class CopyInsertionTest : public HloTestBase {
77  protected:
InsertCopies(HloModule * module)78   void InsertCopies(HloModule* module) {
79     CopyInsertion copy_insertion;
80     ASSERT_IS_OK(copy_insertion.Run(module).status());
81   }
82 
83   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
84 };
85 
TEST_F(CopyInsertionTest, SingleParameter) {
  // The entry computation wraps one scalar parameter in a tuple. After copy
  // insertion the tuple operand is expected to be a copy of the parameter.
  auto builder = HloComputation::Builder(TestName());
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, scalar, "x"));
  HloInstruction* wrapper =
      builder.AddInstruction(HloInstruction::CreateTuple({param}));

  // Before the pass runs, the tuple is the parameter's only user.
  EXPECT_THAT(param->users(), UnorderedElementsAre(wrapper));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(param)));
}
105 
TEST_F(CopyInsertionTest, SingleConstant) {
  // The entry computation wraps one scalar constant in a tuple. After copy
  // insertion the tuple operand is expected to be a copy of the constant, and
  // that copy is the only one in the module.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* scalar_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* wrapper =
      builder.AddInstruction(HloInstruction::CreateTuple({scalar_constant}));

  // Before the pass runs, the tuple is the constant's only user.
  EXPECT_THAT(scalar_constant->users(), UnorderedElementsAre(wrapper));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(scalar_constant)));
}
126 
TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
  // Verify that kCopy instructions which change layout and exist before
  // copy-insertion remain in the graph after copy-insertion. The graph also
  // ends in a copy that keeps the shape of its operand; the expectations below
  // require that one to be gone afterwards (3 copies before, 2 after, and the
  // add becomes the root).
  auto module = CreateNewVerifiedModule();

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));

  // Feed the constant's minor-to-major order in as a major-to-minor order to
  // obtain the reversed layout for the copies.
  Shape reversed_shape = matrix->shape();
  *reversed_shape.mutable_layout() = LayoutUtil::MakeLayoutFromMajorToMinor(
      LayoutUtil::MinorToMajor(matrix->shape()));

  HloInstruction* lhs_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(reversed_shape, HloOpcode::kCopy, matrix));
  HloInstruction* rhs_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(reversed_shape, HloOpcode::kCopy, matrix));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      matrix->shape(), HloOpcode::kAdd, lhs_copy, rhs_copy));
  builder.AddInstruction(
      HloInstruction::CreateUnary(sum->shape(), HloOpcode::kCopy, sum));

  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(CountCopies(*module), 3);

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root, sum);
  EXPECT_THAT(root,
              op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
162 
TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
  // The computation has two constants and two parameters, but the output
  // tuple refers directly to only one constant and one parameter (the others
  // reach it through an add). Only the directly referenced ones should be
  // copied.
  auto builder = HloComputation::Builder(TestName());
  const Shape scalar = ShapeUtil::MakeShape(F32, {});

  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));

  HloInstruction* param_x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, scalar, "x"));
  HloInstruction* param_y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, scalar, "y"));

  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, c1, param_y));

  builder.AddInstruction(HloInstruction::CreateTuple({c2, param_x, sum}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(c2), op::Copy(param_x),
                              op::Add(c1, param_y)));
}
194 
TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
  // The root is a tuple-select between two tuples, which gives the
  // computation result an ambiguous points-to set. Verify that copies are
  // added properly: the new root is a tuple of copies of the elements of the
  // old root.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* c3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));

  // Both candidate tuples share c2 as their second element.
  HloInstruction* on_true =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* on_false =
      builder.AddInstruction(HloInstruction::CreateTuple({c3, c2}));

  HloInstruction* predicate = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  builder.AddInstruction(HloInstruction::CreateTernary(
      on_true->shape(), HloOpcode::kTupleSelect, predicate, on_true,
      on_false));

  EXPECT_THAT(c1->users(), UnorderedElementsAre(on_true));
  EXPECT_THAT(c2->users(), UnorderedElementsAre(on_true, on_false));
  EXPECT_THAT(c3->users(), UnorderedElementsAre(on_false));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloInstruction* old_root = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Tuple(op::Copy(op::GetTupleElement(old_root)),
                                  op::Copy(op::GetTupleElement(old_root))));
}
231 
TEST_F(CopyInsertionTest, BitcastParameter) {
  // A bitcast's output is the same buffer as its operand, so when a bitcast
  // of a parameter feeds the result a copy must be added.
  auto builder = HloComputation::Builder(TestName());
  const Shape flat_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape square_shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, flat_shape, "x"));
  HloInstruction* reshaped = builder.AddInstruction(
      HloInstruction::CreateBitcast(square_shape, param));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(param->users(), UnorderedElementsAre(reshaped));

  HloInstruction* old_root = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Copy(old_root));
}
253 
TEST_F(CopyInsertionTest, BitcastConstant) {
  // A bitcast's output is the same buffer as its operand, so when a bitcast
  // of a constant feeds the result a copy must be added.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* vector_constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.0, 42.0})));
  const Shape bitcast_shape = ShapeUtil::MakeShape(F32, {2});
  HloInstruction* reshaped = builder.AddInstruction(
      HloInstruction::CreateBitcast(bitcast_shape, vector_constant));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(vector_constant->users(), UnorderedElementsAre(reshaped));

  HloInstruction* old_root = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Copy(old_root));
}
276 
TEST_F(CopyInsertionTest, BitcastTupleElementParameter) {
  // Variant of BitcastParameter in which the bitcast is the single element of
  // a tuple; the copy is expected on the tuple operand.
  auto builder = HloComputation::Builder(TestName());
  const Shape flat_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape square_shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, flat_shape, "x"));
  HloInstruction* reshaped = builder.AddInstruction(
      HloInstruction::CreateBitcast(square_shape, param));
  builder.AddInstruction(HloInstruction::CreateTuple({reshaped}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(param->users(), UnorderedElementsAre(reshaped));

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(reshaped)));
}
297 
TEST_F(CopyInsertionTest, NestedTupleParameter) {
  // The computation consists of nothing but a nested tuple-shaped parameter,
  // which is therefore the root. The parameter should be deep copied (one
  // copy per leaf) and the rebuilt tuple should become the new root.
  auto builder = HloComputation::Builder(TestName());

  // Param shape is: ((F32[], S32[1,2,3]), F32[42])
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {inner_tuple_shape, ShapeUtil::MakeShape(F32, {42})});
  builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloInstruction* old_root = module->entry_computation()->root_instruction();
  EXPECT_EQ(HloOpcode::kParameter, old_root->opcode());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 3);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_NE(old_root, new_root);

  // Leaves of the inner tuple are reached through two get-tuple-elements, the
  // outer F32[42] leaf through one.
  EXPECT_THAT(
      new_root,
      op::Tuple(
          op::Tuple(
              op::Copy(op::GetTupleElement(op::GetTupleElement(old_root))),
              op::Copy(op::GetTupleElement(op::GetTupleElement(old_root)))),
          op::Copy(op::GetTupleElement(old_root))));
}
334 
TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
  // The root of the computation is a tuple element of a nested tuple-shaped
  // parameter.
  auto builder = HloComputation::Builder(TestName());

  // Param shape is: ((F32[], S32[1,2,3]), F32[42])
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {inner_tuple_shape, ShapeUtil::MakeShape(F32, {42})});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  // The computation returns element zero of the parameter, which is itself a
  // two-element tuple.
  auto inner_tuple = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(param->shape(), {0}), param, 0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(inner_tuple, module->entry_computation()->root_instruction());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      new_root,
      op::Tuple(op::Copy(op::GetTupleElement(op::GetTupleElement(param))),
                op::Copy(op::GetTupleElement(op::GetTupleElement(param)))));
}
367 
TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
  // The root is element 0 of a tuple-select between two tuples, so the
  // top-level buffer of the root has an ambiguous points-to set. Verify that
  // a shallow copy of the root is added.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));

  // The two candidate tuples hold the same constants in opposite order.
  HloInstruction* on_true =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* on_false =
      builder.AddInstruction(HloInstruction::CreateTuple({c2, c1}));

  HloInstruction* predicate = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* selected = builder.AddInstruction(
      HloInstruction::CreateTernary(on_true->shape(), HloOpcode::kTupleSelect,
                                    predicate, on_true, on_false));
  HloInstruction* first_element =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(selected->shape(), {0}), selected, 0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloInstruction* old_root = module->entry_computation()->root_instruction();
  EXPECT_EQ(first_element, old_root);

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Copy(old_root));
}
403 
404 class WhileCopyInsertionTest : public CopyInsertionTest {
405  protected:
WhileCopyInsertionTest()406   WhileCopyInsertionTest() : module_(CreateNewVerifiedModule()) {}
407 
408   // Builds a While condition computation which reads the induction variable
409   // from the tuple parameter, and returns a predicate indicating whether this
410   // value is less than the constant '10'.
411   // The parameter 'nested' specifies the loop state shape from which to
412   // read the induction variable.
BuildConditionComputation(const Shape & loop_state_shape)413   std::unique_ptr<HloComputation> BuildConditionComputation(
414       const Shape& loop_state_shape) {
415     auto builder = HloComputation::Builder(TestName() + ".Condition");
416     auto limit_const = builder.AddInstruction(
417         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10)));
418     auto loop_state = builder.AddInstruction(
419         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
420     auto induction_variable =
421         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
422             limit_const->shape(), loop_state, 0));
423     builder.AddInstruction(HloInstruction::CreateCompare(
424         condition_result_shape_, induction_variable, limit_const,
425         ComparisonDirection::kLt));
426     return builder.Build();
427   }
428 
429   // Builds a While body computation with one output tuple element dependent on
430   // both input tuple elements.
431   // EX:
432   // Body({in0, in1})
433   //   out0 = Add(in0, 1)
434   //   out1 = Add(BCast(in0), in1)
435   //   Tuple(out0, out1)
BuildDependentBodyComputation()436   std::unique_ptr<HloComputation> BuildDependentBodyComputation() {
437     auto builder = HloComputation::Builder(TestName() + ".Body");
438     // Create param instruction to access loop state.
439     auto loop_state = builder.AddInstruction(
440         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
441     // Update the induction variable GTE(0).
442     auto induction_variable =
443         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
444             induction_variable_shape_, loop_state, 0));
445     auto inc = builder.AddInstruction(
446         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
447     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
448         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
449     // Update data GTE(1).
450     auto data = builder.AddInstruction(
451         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
452     // Use 'induction_variable' in computation with no path to output tuple.
453     Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
454     auto convert = builder.AddInstruction(
455         HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
456     auto update = builder.AddInstruction(
457         HloInstruction::CreateBroadcast(data_shape_, convert, {}));
458     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
459         data_shape_, HloOpcode::kAdd, data, update));
460     // Create output Tuple.
461     builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
462     return builder.Build();
463   }
464 
465   // Builds a While body computation with two output tuple elements dependent on
466   // both input tuple elements.
467   //
468   // EX: Body({in0, in1, in2})
469   //   out0 = Add(in0, 1)
470   //   out1 = in1
471   //   out2 = in2
472   //   Tuple(out0, out1, out2)
BuildDependentBodyComputation2()473   std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
474     auto builder = HloComputation::Builder(TestName() + ".Body");
475 
476     const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
477         {induction_variable_shape_, data_shape_, data_shape_});
478 
479     auto loop_state = builder.AddInstruction(
480         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
481 
482     // Update the induction variable GTE(0).
483     auto induction_variable =
484         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
485             induction_variable_shape_, loop_state, 0));
486     auto inc = builder.AddInstruction(
487         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
488 
489     // add0 = Add(in0, 1)
490     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
491         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
492     // data1 = GTE(1).
493     HloInstruction* data1 = builder.AddInstruction(
494         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
495 
496     // data2 = GTE(2).
497     HloInstruction* data2 = builder.AddInstruction(
498         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
499 
500     // Create output Tuple.
501     builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
502 
503     return builder.Build();
504   }
505 
506   // Builds a While body computation with read-only tuple element 0.
507   // EX:
508   // Body({in0, in1})
509   //   out0 = in0
510   //   out1 = Add(BCast(in0), in1)
511   //   Tuple(out0, out1)
BuildDependentBodyOneReadOnlyComputation()512   std::unique_ptr<HloComputation> BuildDependentBodyOneReadOnlyComputation() {
513     auto builder = HloComputation::Builder(TestName() + ".Body");
514     // Create param instruction to access loop state.
515     auto loop_state = builder.AddInstruction(
516         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
517     // Update the induction variable GTE(0).
518     auto induction_variable =
519         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
520             induction_variable_shape_, loop_state, 0));
521     // Update data GTE(1).
522     auto data = builder.AddInstruction(
523         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
524 
525     // Use 'induction_variable' in computation with no path to output tuple.
526     Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
527     auto convert = builder.AddInstruction(
528         HloInstruction::CreateConvert(f32_scalar_shape, induction_variable));
529     auto update = builder.AddInstruction(
530         HloInstruction::CreateBroadcast(data_shape_, convert, {}));
531     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
532         data_shape_, HloOpcode::kAdd, data, update));
533     // Create output Tuple.
534     builder.AddInstruction(
535         HloInstruction::CreateTuple({induction_variable, add1}));
536     return builder.Build();
537   }
538 
539   // Builds a While body computation with independent outputs.
540   // EX:
541   // Body({in0, in1})
542   //   out0 = Add(in0, 1)
543   //   out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
544   //   Tuple(out0, out1)
BuildIndependentBodyComputation(bool nested=false)545   std::unique_ptr<HloComputation> BuildIndependentBodyComputation(
546       bool nested = false) {
547     auto builder = HloComputation::Builder(TestName() + ".Body");
548     // Create param instruction to access loop state.
549     const Shape& loop_state_shape =
550         nested ? nested_loop_state_shape_ : loop_state_shape_;
551 
552     auto loop_state = builder.AddInstruction(
553         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
554     // Update the induction variable GTE(0).
555     auto induction_variable =
556         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
557             induction_variable_shape_, loop_state, 0));
558     auto inc = builder.AddInstruction(
559         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
560     // add0 = Add(in0, 1)
561     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
562         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
563     // Update data GTE(1).
564     HloInstruction* data = nullptr;
565     if (nested) {
566       data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
567           nested_tuple_shape_, loop_state, 1));
568       data = builder.AddInstruction(
569           HloInstruction::CreateGetTupleElement(data_shape_, data, 0));
570     } else {
571       data = builder.AddInstruction(
572           HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
573     }
574     auto update = builder.AddInstruction(
575         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
576             {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
577     // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
578     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
579         data_shape_, HloOpcode::kAdd, data, update));
580     // Create output Tuple.
581     if (nested) {
582       auto nested_tuple =
583           builder.AddInstruction(HloInstruction::CreateTuple({add1, add1}));
584       builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple}));
585     } else {
586       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
587     }
588     return builder.Build();
589   }
590 
591   // Builds a While body computation with the following nested tuple
592   // sub-computation:
593   //                            |
594   //                    GTE(loop_state, 1)
595   //                       /           \
596   // GTE(GTE(loop_state, 1), 0)     GTE(GTE(loop_state, 1), 1)
597   //           |                              |
598   //          Add                           Reverse
599   //           |                              |
BuildNestedBodyComputation()600   std::unique_ptr<HloComputation> BuildNestedBodyComputation() {
601     auto builder = HloComputation::Builder(TestName() + ".Body");
602     // Create param instruction to access loop state.
603     auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
604         0, nested_loop_state_shape_, "loop_state"));
605     // Update GTE(0).
606     auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
607         induction_variable_shape_, loop_state, 0));
608     auto inc = builder.AddInstruction(
609         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
610     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
611         gte0->shape(), HloOpcode::kAdd, gte0, inc));
612 
613     // GTE(loop_state, 1)
614     auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
615         nested_tuple_shape_, loop_state, 1));
616     // GTE(GTE(loop_state, 1), 0) -> Add
617     auto gte10 = builder.AddInstruction(
618         HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
619     auto update10 = builder.AddInstruction(
620         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
621             {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
622     auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
623         data_shape_, HloOpcode::kAdd, gte10, update10));
624 
625     // GTE(GTE(loop_state, 1), 1) -> Reverse
626     auto gte11 = builder.AddInstruction(
627         HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1));
628     auto rev11 = builder.AddInstruction(
629         HloInstruction::CreateReverse(data_shape_, gte11, {0}));
630 
631     // Create output Tuple.
632     auto inner_tuple =
633         builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11}));
634     builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple}));
635     return builder.Build();
636   }
637 
638   // Builds a While instruction using 'condition' and 'body' sub-computations.
639   // Init operand is initialized to zeros of appropriate shape.
BuildWhileInstruction(HloComputation * condition,HloComputation * body,bool nested=false)640   HloInstruction* BuildWhileInstruction(HloComputation* condition,
641                                         HloComputation* body,
642                                         bool nested = false) {
643     auto builder = HloComputation::Builder(TestName() + ".While");
644     auto induction_var_init = builder.AddInstruction(
645         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
646 
647     auto data_init = builder.AddInstruction(
648         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
649             {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
650 
651     if (nested) {
652       auto inner_init = builder.AddInstruction(
653           HloInstruction::CreateTuple({data_init, data_init}));
654       auto loop_state_init = builder.AddInstruction(
655           HloInstruction::CreateTuple({induction_var_init, inner_init}));
656       auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
657           loop_state_init->shape(), condition, body, loop_state_init));
658       module_->AddEntryComputation(builder.Build());
659       return while_hlo;
660     }
661 
662     auto loop_state_init = builder.AddInstruction(
663         HloInstruction::CreateTuple({induction_var_init, data_init}));
664     auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
665         loop_state_shape_, condition, body, loop_state_init));
666     module_->AddEntryComputation(builder.Build());
667     return while_hlo;
668   }
669 
BuildWhileInstruction_InitPointsToConstant()670   HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
671     auto builder = HloComputation::Builder(TestName() + ".While");
672     auto data_init = builder.AddInstruction(
673         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
674             {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
675     return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
676                                                &builder);
677   }
678 
BuildWhileInstruction_InitPointsToParameter()679   HloInstruction* BuildWhileInstruction_InitPointsToParameter() {
680     auto builder = HloComputation::Builder(TestName() + ".While");
681     auto data_init = builder.AddInstruction(
682         HloInstruction::CreateParameter(0, data_shape_, "data_init"));
683     return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
684                                                &builder);
685   }
686 
BuildWhileInstruction_InitPointsToAmbiguous()687   HloInstruction* BuildWhileInstruction_InitPointsToAmbiguous() {
688     auto builder = HloComputation::Builder(TestName() + ".While");
689 
690     auto one = builder.AddInstruction(
691         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
692     auto v1 = builder.AddInstruction(
693         HloInstruction::CreateBroadcast(data_shape_, one, {}));
694     auto zero = builder.AddInstruction(
695         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
696     auto v2 = builder.AddInstruction(
697         HloInstruction::CreateBroadcast(data_shape_, zero, {}));
698 
699     auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2}));
700     auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
701 
702     auto pred = builder.AddInstruction(
703         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
704     auto data_init = builder.AddInstruction(HloInstruction::CreateTernary(
705         nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
706 
707     return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
708                                                data_init, &builder);
709   }
710 
BuildWhileInstruction_InitPointsToNonDistinct()711   HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() {
712     auto builder = HloComputation::Builder(TestName() + ".While");
713 
714     auto one = builder.AddInstruction(
715         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
716     auto one_vec = builder.AddInstruction(
717         HloInstruction::CreateBroadcast(data_shape_, one, {}));
718     auto data_init =
719         builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec}));
720 
721     return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
722                                                data_init, &builder);
723   }
724 
BuildWhileInstruction_InitPointsToInterfering()725   HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
726     auto builder = HloComputation::Builder(TestName() + ".While");
727     auto one = builder.AddInstruction(
728         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
729     auto data_init = builder.AddInstruction(
730         HloInstruction::CreateBroadcast(data_shape_, one, {}));
731     auto one_vec = builder.AddInstruction(
732         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
733             {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
734     // Take a reference to 'data_init' to make it interfere with while result.
735     auto add = builder.AddInstruction(HloInstruction::CreateBinary(
736         data_shape_, HloOpcode::kAdd, data_init, one_vec));
737 
738     auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_,
739                                                          data_init, &builder);
740 
741     // Add an additional binary operation operating on the while and the
742     // interfering add so that neither operation is dead.
743     auto gte = xla_while->parent()->AddInstruction(
744         HloInstruction::CreateGetTupleElement(
745             ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1));
746     auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary(
747         data_shape_, HloOpcode::kSubtract, add, gte));
748     auto gte0 = xla_while->parent()->AddInstruction(
749         HloInstruction::CreateGetTupleElement(
750             ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0));
751     auto tuple = xla_while->parent()->AddInstruction(
752         HloInstruction::CreateTuple({gte0, sub}));
753 
754     xla_while->parent()->set_root_instruction(tuple);
755 
756     return xla_while;
757   }
758 
BuildWhileInstructionWithCustomInit(const Shape & loop_state_shape,HloInstruction * data_init,HloComputation::Builder * builder)759   HloInstruction* BuildWhileInstructionWithCustomInit(
760       const Shape& loop_state_shape, HloInstruction* data_init,
761       HloComputation::Builder* builder) {
762     const bool nested =
763         ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
764     auto induction_var_init = builder->AddInstruction(
765         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
766     auto condition = module_->AddEmbeddedComputation(
767         BuildConditionComputation(loop_state_shape));
768     auto body = module_->AddEmbeddedComputation(
769         BuildIndependentBodyComputation(nested));
770     auto loop_state_init = builder->AddInstruction(
771         HloInstruction::CreateTuple({induction_var_init, data_init}));
772     auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile(
773         loop_state_shape, condition, body, loop_state_init));
774     module_->AddEntryComputation(builder->Build());
775     return while_hlo;
776   }
777 
// Module that the builder helpers add their entry and embedded computations
// to.
778   std::unique_ptr<HloModule> module_;
// Scalar S32 shape; element 0 of both loop state shapes below.
779   Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {});
// F32 vector of 8 elements; the data element of the flat loop state.
780   Shape data_shape_ = ShapeUtil::MakeShape(F32, {8});
// Flat loop state: (induction variable, data).
781   Shape loop_state_shape_ =
782       ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_});
// Pair of data vectors; the data element of the nested loop state.
783   Shape nested_tuple_shape_ =
784       ShapeUtil::MakeTupleShape({data_shape_, data_shape_});
// Nested loop state: (induction variable, (data, data)).
785   Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape(
786       {induction_variable_shape_, nested_tuple_shape_});
// Scalar PRED shape; presumably the result shape of the loop condition
// computations (its users are not visible here).
787   Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {});
788 };
789 
790 // Tests while body computation with independent tuple elements:
791 //
792 //   While.Body({in0, in1})
793 //     out0 = Add(in0, 1)
794 //     out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
795 //     Tuple(out0, out1)
796 //
797 // CopyInsertion pass should not generate any copies.
798 //
TEST_F(WhileCopyInsertionTest,IndependentTupleElements)799 TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
800   auto condition = module_->AddEmbeddedComputation(
801       BuildConditionComputation(loop_state_shape_));
802   auto body =
803       module_->AddEmbeddedComputation(BuildIndependentBodyComputation());
804   auto while_hlo = BuildWhileInstruction(condition, body);
805 
806   InsertCopies(module_.get());
807 
808   // Body should have no copies as the adds can be done inplace.
809   EXPECT_EQ(CountCopies(*body), 0);
810   EXPECT_EQ(CountControlEdges(*module_), 0);
811 
812   // Both init indices need copies as they are constants.
813   EXPECT_THAT(while_hlo->operand(0),
814               op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
815 }
816 
817 // Tests Copy Insertion when a while feeds another while
818 //                         PARAMETER
819 //                        |        |
820 //                        GTE(0)   GTE(1)
821 //                        |        |
822 //                        X = CreateTuple(GTE(0), GTE(1))
823 //                                 |
824 //                        WHILE(X) (root)
TEST_F(WhileCopyInsertionTest,WhileFeedingWhileThruParameterWithCopies)825 TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterWithCopies) {
826   const string& hlo_string = R"(
827 HloModule DependentTupleElements
828 
829 %DependentTupleElements.Body (loop_state.1: (s32[], f32[8])) -> (s32[], f32[8]) {
830   %loop_state.1 = (s32[], f32[8]{0}) parameter(0)
831   %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=0
832   %constant.1 = s32[] constant(1)
833   %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
834   %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=1
835   %convert = f32[] convert(s32[] %get-tuple-element.1)
836   %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
837   %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
838   ROOT %tuple = (s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1)
839 }
840 
841 %DependentTupleElements.Condition (loop_state: (s32[], f32[8])) -> pred[] {
842   %loop_state = (s32[], f32[8]{0}) parameter(0)
843   %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state), index=0
844   %constant = s32[] constant(10)
845   ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
846 }
847 
848 ENTRY %DependentTupleElements.While () -> (s32[], f32[8]) {
849   %constant.2 = s32[] constant(0)
850   %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
851   %tuple.1 = (s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3)
852   ROOT %while.1 = (s32[], f32[8]{0}) while((s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
853 }
854 )";
855   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
856   auto module_ = module_or_status.ConsumeValueOrDie();
857   auto while_hlo = module_->entry_computation()->root_instruction();
858   // module_ and while_hlo are the pre-existing module and hlo, the below
859   // code generates a clone of the existing while and replaces that while
860   // with itself. The body of the new while calls the previous while
861   HloComputation* outer_while_condition =
862       module_->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
863   HloComputation* outer_while_body =
864       module_->AddEmbeddedComputation(while_hlo->while_body()->Clone());
865   HloInstruction* outer_while =
866       while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
867           while_hlo->shape(), outer_while_condition, outer_while_body,
868           while_hlo->mutable_operand(0)));
869   HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
870   std::vector<HloInstruction*> materialized_gtes;
871   for (int i = 0; i < outer_param->shape().tuple_shapes_size(); ++i) {
872     materialized_gtes.push_back(
873         outer_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
874             outer_param->shape().tuple_shapes(i), outer_param, i)));
875   }
876   HloInstruction* dual_init = outer_while_body->AddInstruction(
877       HloInstruction::CreateTuple(materialized_gtes));
878   HloInstruction* dual_while =
879       outer_while_body->AddInstruction(HloInstruction::CreateWhile(
880           while_hlo->shape(), while_hlo->while_condition(),
881           while_hlo->while_body(), dual_init));
882   TF_CHECK_OK(outer_while_body->ReplaceInstruction(
883       outer_while_body->root_instruction(), dual_while));
884   TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));
885   InsertCopies(module_.get());
886 }
887 
888 // Tests Copy Insertion when a while feeds another while
889 //                         PARAMETER
890 //                        |        |
891 //                         \      /
892 //                           WHILE(PARAMETER) (root)
TEST_F(WhileCopyInsertionTest,WhileFeedingWhileThruParameterNoCopies)893 TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterNoCopies) {
894   const string& hlo_string = R"(
895 HloModule DependentTupleElements
896 
897 %DependentTupleElements.Body (loop_state.1: (s32[], f32[8])) -> (s32[], f32[8]) {
898   %loop_state.1 = (s32[], f32[8]{0}) parameter(0)
899   %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=0
900   %constant.1 = s32[] constant(1)
901   %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
902   %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}) %loop_state.1), index=1
903   %convert = f32[] convert(s32[] %get-tuple-element.1)
904   %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
905   %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
906   ROOT %tuple = (s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1)
907 }
908 
909 %DependentTupleElements.Condition (loop_state: (s32[], f32[8])) -> pred[] {
910   %loop_state = (s32[], f32[8]{0}) parameter(0)
911   %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}) %loop_state), index=0
912   %constant = s32[] constant(10)
913   ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
914 }
915 
916 ENTRY %DependentTupleElements.While () -> (s32[], f32[8]) {
917   %constant.2 = s32[] constant(0)
918   %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
919   %tuple.1 = (s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3)
920   ROOT %while.1 = (s32[], f32[8]{0}) while((s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
921 }
922 )";
923   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
924   auto module_ = module_or_status.ConsumeValueOrDie();
925   auto while_hlo = module_->entry_computation()->root_instruction();
926   // module_ and while_hlo are the pre-existing module and hlo, the below
927   // code generates a clone of the existing while and replaces that while
928   // with itself. The body of the new while calls the previous while
929   HloComputation* outer_while_condition =
930       module_->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
931   HloComputation* outer_while_body =
932       module_->AddEmbeddedComputation(while_hlo->while_body()->Clone());
933   HloInstruction* outer_while =
934       while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
935           while_hlo->shape(), outer_while_condition, outer_while_body,
936           while_hlo->mutable_operand(0)));
937   HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
938   HloInstruction* dual_while =
939       outer_while_body->AddInstruction(HloInstruction::CreateWhile(
940           while_hlo->shape(), while_hlo->while_condition(),
941           while_hlo->while_body(), outer_param));
942   TF_CHECK_OK(outer_while_body->ReplaceInstruction(
943       outer_while_body->root_instruction(), dual_while));
944   TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));
945   InsertCopies(module_.get());
946 }
947 
948 // Tests Copy Insertion when a while feeds another while
949 //                         PARAMETER
950 //                        |        |
951 //                         \      /
952 //                           WHILE(PARAMETER) (root)
TEST_F(WhileCopyInsertionTest,WhileFeedingWhileThruParameterBig)953 TEST_F(WhileCopyInsertionTest, WhileFeedingWhileThruParameterBig) {
954   const string& hlo_string = R"(
955 HloModule DependentTupleElements
956 
957 %DependentTupleElements.Body (loop_state.1: (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0})) -> (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) {
958   %loop_state.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) parameter(0)
959   %get-tuple-element.1 = s32[] get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state.1), index=0
960   %constant.1 = s32[] constant(1)
961   %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
962   %get-tuple-element.2 = f32[8]{0} get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state.1), index=1
963   %convert = f32[] convert(s32[] %get-tuple-element.1)
964   %broadcast = f32[8]{0} broadcast(f32[] %convert), dimensions={}
965   %add.1 = f32[8]{0} add(f32[8]{0} %get-tuple-element.2, f32[8]{0} %broadcast)
966   ROOT %tuple = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) tuple(s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1, s32[] %add, f32[8]{0} %add.1)
967 }
968 
969 %DependentTupleElements.Condition (loop_state: (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0})) -> pred[] {
970   %loop_state = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) parameter(0)
971   %get-tuple-element = s32[] get-tuple-element((s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %loop_state), index=0
972   %constant = s32[] constant(10)
973   ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
974 }
975 
976 ENTRY %DependentTupleElements.While () -> (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) {
977   %constant.2 = s32[] constant(0)
978   %constant.3 = f32[8]{0} constant({0, 0, 0, 0, 0, 0, 0, 0})
979   %tuple.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) tuple(s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3, s32[] %constant.2, f32[8]{0} %constant.3)
980   ROOT %while.1 = (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) while( (s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}, s32[], f32[8]{0}) %tuple.1), condition=%DependentTupleElements.Condition, body=%DependentTupleElements.Body
981 }
982 )";
983   auto module_or_status = ParseAndReturnVerifiedModule(hlo_string);
984   auto module_ = module_or_status.ConsumeValueOrDie();
985   auto while_hlo = module_->entry_computation()->root_instruction();
986   // module_ and while_hlo are the pre-existing module and hlo, the below
987   // code generates a clone of the existing while and replaces that while
988   // with itself. The body of the new while calls the previous while
989   HloComputation* outer_while_condition =
990       module_->AddEmbeddedComputation(while_hlo->while_condition()->Clone());
991   HloComputation* outer_while_body =
992       module_->AddEmbeddedComputation(while_hlo->while_body()->Clone());
993   HloInstruction* outer_while =
994       while_hlo->parent()->AddInstruction(HloInstruction::CreateWhile(
995           while_hlo->shape(), outer_while_condition, outer_while_body,
996           while_hlo->mutable_operand(0)));
997   HloInstruction* outer_param = outer_while_body->parameter_instruction(0);
998   std::vector<HloInstruction*> materialized_gtes;
999   for (int i = 0; i < outer_param->shape().tuple_shapes_size(); ++i) {
1000     materialized_gtes.push_back(
1001         outer_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
1002             outer_param->shape().tuple_shapes(i), outer_param, i)));
1003   }
1004   HloInstruction* dual_init = outer_while_body->AddInstruction(
1005       HloInstruction::CreateTuple(materialized_gtes));
1006   HloInstruction* dual_while =
1007       outer_while_body->AddInstruction(HloInstruction::CreateWhile(
1008           while_hlo->shape(), while_hlo->while_condition(),
1009           while_hlo->while_body(), dual_init));
1010   TF_CHECK_OK(outer_while_body->ReplaceInstruction(
1011       outer_while_body->root_instruction(), dual_while));
1012   TF_CHECK_OK(while_hlo->parent()->ReplaceInstruction(while_hlo, outer_while));
1013   InsertCopies(module_.get());
1014 }
1015 
1016 // Tests while body computation with dependent tuple elements:
1017 //
1018 //   While.Body({in0, in1})
1019 //     out0 = Add(in0, 1)
1020 //     out1 = Add(BCast(in0), in1)
1021 //     Tuple(out0, out1)
1022 //
1023 // CopyInsertion pass should convert the root instruction to:
1024 //
1025 //     Tuple(Copy(out0), out1)
1026 //
TEST_F(WhileCopyInsertionTest,DependentTupleElements)1027 TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
1028   auto condition = module_->AddEmbeddedComputation(
1029       BuildConditionComputation(loop_state_shape_));
1030   auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation());
1031   auto while_hlo = BuildWhileInstruction(condition, body);
1032 
1033   InsertCopies(module_.get());
1034 
1035   EXPECT_EQ(CountCopies(*body), 1);
1036   EXPECT_EQ(CountControlEdges(*body), 0);
1037 
1038   EXPECT_THAT(
1039       body->root_instruction(),
1040       op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast())));
1041 
1042   auto add = body->root_instruction()->operand(0);
1043   auto bcast = body->root_instruction()->operand(1)->operand(1);
1044   ASSERT_EQ(add->opcode(), HloOpcode::kAdd);
1045   ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
1046 
1047   EXPECT_THAT(while_hlo->while_body()->root_instruction(),
1048               op::Tuple(op::Add(op::Copy(), op::Constant()),
1049                         op::Add(op::GetTupleElement(),
1050                                 op::Broadcast(op::Convert(op::Copy())))));
1051 
1052   // Both init indices need copies as they are constants.
1053   EXPECT_THAT(while_hlo->operand(0),
1054               op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
1055 }
1056 
1057 // Tests while body computation with read-only tuple element 0:
1058 //
1059 //                         PARAMETER
1060 //                         /       \
1061 //                      GTE(0)     GTE(1)
1062 //                        |  \      |
1063 //                        |   BCAST |
1064 //                        |      \  |
1065 //                        |       ADD
1066 //                        |        |
1067 //                         \      /
1068 //                           TUPLE (root)
1069 //
1070 // CopyInsertion pass should not generate any copies for the while body.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly)1071 TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) {
1072   auto condition = module_->AddEmbeddedComputation(
1073       BuildConditionComputation(loop_state_shape_));
1074   auto body = module_->AddEmbeddedComputation(
1075       BuildDependentBodyOneReadOnlyComputation());
1076   BuildWhileInstruction(condition, body);
1077 
1078   InsertCopies(module_.get());
1079 
1080   // No copies or control edges should be inserted. The body is legal as is.
1081   EXPECT_EQ(CountCopies(*body), 0);
1082   EXPECT_EQ(CountControlEdges(*body), 0);
1083 }
1084 
1085 // Same as above, but with two while loops, sharing entry parameters.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly_TwoLoops_EntryParams)1086 TEST_F(WhileCopyInsertionTest,
1087        DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) {
1088   auto condition1 = module_->AddEmbeddedComputation(
1089       BuildConditionComputation(loop_state_shape_));
1090   auto condition2 = module_->AddEmbeddedComputation(
1091       BuildConditionComputation(loop_state_shape_));
1092   auto body1 = module_->AddEmbeddedComputation(
1093       BuildDependentBodyOneReadOnlyComputation());
1094   auto body2 = module_->AddEmbeddedComputation(
1095       BuildDependentBodyOneReadOnlyComputation());
1096 
1097   auto builder = HloComputation::Builder(TestName() + ".While");
1098   auto iter_param = builder.AddInstruction(
1099       HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
1100   auto data_param = builder.AddInstruction(
1101       HloInstruction::CreateParameter(1, data_shape_, "data"));
1102   auto loop_init = builder.AddInstruction(
1103       HloInstruction::CreateTuple({iter_param, data_param}));
1104 
1105   auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
1106       loop_state_shape_, condition1, body1, loop_init));
1107   auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
1108       loop_state_shape_, condition2, body2, loop_init));
1109 
1110   // Add a couple elements from each of the while so both whiles are live.
1111   auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1112       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
1113   auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1114       ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
1115   builder.AddInstruction(
1116       HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
1117 
1118   auto entry = module_->AddEntryComputation(builder.Build());
1119 
1120   InsertCopies(module_.get());
1121 
1122   // Neither body should have any copies or control edges in them.
1123   EXPECT_EQ(CountCopies(*body1), 0);
1124   EXPECT_EQ(CountCopies(*body2), 0);
1125   EXPECT_EQ(CountControlEdges(*body1), 0);
1126   EXPECT_EQ(CountControlEdges(*body2), 0);
1127 
1128   // Only two copies should be necessary. Each of the whiles should have
1129   // a copy of tuple element 1 (init value is a parameter, and the element is
1130   // not non-read-only) so each of the while bodies gets its own buffer to write
1131   // element 1 into.
1132   EXPECT_EQ(CountCopies(*entry), 2);
1133 
1134   EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
1135   EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
1136 
1137   // The two copies of element 1 should be different.
1138   EXPECT_NE(while_hlo1->operand(0)->operand(1),
1139             while_hlo2->operand(0)->operand(1));
1140 }
1141 
1142 // Same as above, but with two while loops, sharing non-parameters.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly_TwoLoops_NonParams)1143 TEST_F(WhileCopyInsertionTest,
1144        DependentTupleElements_OneReadOnly_TwoLoops_NonParams) {
1145   auto condition1 = module_->AddEmbeddedComputation(
1146       BuildConditionComputation(loop_state_shape_));
1147   auto condition2 = module_->AddEmbeddedComputation(
1148       BuildConditionComputation(loop_state_shape_));
1149   auto body1 = module_->AddEmbeddedComputation(
1150       BuildDependentBodyOneReadOnlyComputation());
1151   auto body2 = module_->AddEmbeddedComputation(
1152       BuildDependentBodyOneReadOnlyComputation());
1153 
1154   auto builder = HloComputation::Builder(TestName() + ".While");
1155   auto iter_param = builder.AddInstruction(
1156       HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
1157   auto data_param = builder.AddInstruction(
1158       HloInstruction::CreateParameter(1, data_shape_, "data"));
1159   // Add dummy ops to ensure loop_init elements aren't entry parameters.
1160   Shape f32_scalar_shape = ShapeUtil::MakeShape(F32, {});
1161   auto convert = builder.AddInstruction(
1162       HloInstruction::CreateConvert(f32_scalar_shape, iter_param));
1163   auto iter_value = builder.AddInstruction(
1164       HloInstruction::CreateUnary(convert->shape(), HloOpcode::kExp, convert));
1165   auto convert2 = builder.AddInstruction(
1166       HloInstruction::CreateConvert(induction_variable_shape_, iter_value));
1167   auto data_value = builder.AddInstruction(HloInstruction::CreateUnary(
1168       data_param->shape(), HloOpcode::kExp, data_param));
1169   auto loop_init = builder.AddInstruction(
1170       HloInstruction::CreateTuple({convert2, data_value}));
1171 
1172   auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
1173       loop_state_shape_, condition1, body1, loop_init));
1174   auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
1175       loop_state_shape_, condition2, body2, loop_init));
1176 
1177   // Add a couple elements from each of the while so both whiles are not dead.
1178   auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1179       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
1180   auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1181       ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
1182   builder.AddInstruction(
1183       HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
1184   auto entry = module_->AddEntryComputation(builder.Build());
1185 
1186   InsertCopies(module_.get());
1187 
1188   // Ideally only one copy should be necessary. One of the whiles should
1189   // have a copy of tuple element 1 (the non-read-only element) so each of the
1190   // while bodies gets its own buffer to write element 1 into. However, the
1191   // analysis isn't perfect and adds an additional copy of element 0.
1192   EXPECT_EQ(CountCopies(*entry), 2);
1193 
1194   EXPECT_THAT(while_hlo1->operand(0),
1195               op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
1196   EXPECT_THAT(while_hlo2->operand(0),
1197               op::Tuple(op::Convert(op::Exp()), op::Copy(op::Exp())));
1198 }
1199 
1200 // Tests while body computation with nested tuple elements:
1201 //
1202 //                            |
1203 //                    GTE(loop_state, 1)
1204 //                       /          \
1205 // GTE(GTE(loop_state, 1), 0)     GTE(GTE(loop_state, 1), 1)
1206 //           |                              |
1207 //          Add                           Reverse
1208 //           |                              |
1209 //
1210 // CopyInsertion pass will conceptually generate the following, but with the
1211 // actual GTE and Tuple instructions optimized away:
1212 //
1213 //                    Tuple  // old root
1214 //                   /     \
1215 //                  /       \
1216 //                GTE(0)   GTE(1)
1217 //                  |       /  \
1218 //                  |      /    \
1219 //                  |    GTE(0) GTE(1)
1220 //                  |       |    |
1221 //                  |       |   Copy
1222 //                  |       |    |
1223 //                   \      |   /
1224 //                    \    Tuple  // "inner" tuple.
1225 //                     \    /
1226 //                      \  /
1227 //                     Tuple  // new root
1228 //
TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
  // Assemble a while loop over the nested loop state using the fixture's
  // condition and nested-body builders.
  HloComputation* loop_condition = module_->AddEmbeddedComputation(
      BuildConditionComputation(nested_loop_state_shape_));
  HloComputation* loop_body =
      module_->AddEmbeddedComputation(BuildNestedBodyComputation());
  BuildWhileInstruction(loop_condition, loop_body, true);

  InsertCopies(module_.get());

  // Inside the body only the kReverse needs a copy because it cannot run
  // in-place; the remaining loop state elements come from kAdd instructions,
  // which can.
  EXPECT_EQ(CountCopies(*loop_body), 1);

  // Module-wide total: every element of the init is a constant and needs its
  // own copy, on top of the single copy in the body.
  EXPECT_EQ(CountCopies(*module_), 4);

  // The body copy may sit on either side of the kReverse (on its result or on
  // its operand); both placements are accepted.
  const HloInstruction* body_root = loop_body->root_instruction();
  const bool reverse_result_is_copied =
      body_root->operand(1)->operand(1)->opcode() == HloOpcode::kCopy;
  if (!reverse_result_is_copied) {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy()))));
  } else {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse()))));
  }
}
1259 
1260 // Tests while init instruction which points-to a constant.
1261 //
1262 //     init = Tuple(Constant(S32, {}), Constant(F32, {8}))
1263 //
1264 // CopyInsertion pass should add copies for both constants.
1265 //
TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
  auto while_instr = BuildWhileInstruction_InitPointsToConstant();

  InsertCopies(module_.get());

  // The loop body stays copy-free; the module as a whole gains two copies.
  EXPECT_EQ(CountCopies(*while_instr->while_body()), 0);
  EXPECT_EQ(CountCopies(*module_), 2);

  // Both copies wrap the constant elements of the init tuple.
  const HloInstruction* init_tuple = while_instr->operand(0);
  EXPECT_THAT(init_tuple,
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
1276 
1277 // Tests while init instruction which points-to a parameter.
1278 //
1279 //     init = Tuple(Constant(S32, {}), Parameter(F32, {8}))
1280 //
1281 // CopyInsertion pass should add copies for both the constant and parameter.
1282 //
TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
  auto while_instr = BuildWhileInstruction_InitPointsToParameter();

  InsertCopies(module_.get());

  // The loop body stays copy-free; the module as a whole gains two copies.
  EXPECT_EQ(CountCopies(*while_instr->while_body()), 0);
  EXPECT_EQ(CountCopies(*module_), 2);

  // One copy wraps the constant element and one wraps the parameter element
  // of the init tuple.
  const HloInstruction* init_tuple = while_instr->operand(0);
  EXPECT_THAT(init_tuple,
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter())));
}
1293 
1294 // Tests while init instruction which has an ambiguous points-to set.
1295 //
1296 //     select = Select(pred, tuple1, tuple2)
//     init = Tuple(Constant(S32, {}), select)
1298 //
1299 // CopyInsertion pass will conceptually generate the following, but with some of
1300 // the actual GTE and Tuple instructions optimized away:
1301 //
1302 //                    Tuple  // old init
1303 //                   /     \
1304 //                  /       \
1305 //                GTE(0)   GTE(1)
1306 //                  |       /  \
1307 //                  |      /    \
1308 //                  |    GTE(0) GTE(1)
1309 //                  |       |    |
1310 //                Copy   Copy   Copy
1311 //                  |       |    |
1312 //                   \      |   /
1313 //                    \    Tuple
1314 //                     \    /
1315 //                      \  /
1316 //                     Tuple  // new init
1317 //
TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) {
  auto while_instr = BuildWhileInstruction_InitPointsToAmbiguous();

  InsertCopies(module_.get());
  EXPECT_EQ(CountCopies(*module_), 4);

  // Three of those copies belong to the entry computation: two resolve the
  // ambiguity of the nested init elements and one copies the constant that is
  // passed in as an init element.
  EXPECT_EQ(CountCopies(*module_->entry_computation()), 3);
  EXPECT_THAT(while_instr->operand(0),
              op::Tuple(op::Copy(op::Constant()),
                        op::Tuple(op::Copy(op::GetTupleElement()),
                                  op::Copy(op::GetTupleElement()))));

  // The remaining copy is in the body, whose output buffer set is not
  // distinct: the result of one add feeds two elements of the body's output.
  // The copy may be placed on either of those elements.
  HloComputation* loop_body = while_instr->while_body();
  EXPECT_EQ(CountCopies(*loop_body), 1);

  const HloInstruction* body_root = loop_body->root_instruction();
  const bool first_nested_is_copy =
      body_root->operand(1)->operand(0)->opcode() == HloOpcode::kCopy;
  if (!first_nested_is_copy) {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
  } else {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
  }
}
1349 
1350 // Tests while init instruction which has a non-distinct points-to set.
1351 //
1352 //     init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one}))
1353 //
1354 // CopyInsertion pass will conceptually generate the following, but with some of
1355 // the actual GTE and Tuple instructions optimized away:
1356 //
1357 //                    Tuple  // old init
1358 //                   /     \
1359 //                  /       \
1360 //                GTE(0)   GTE(1)
1361 //                  |       /  \
1362 //                  |      /    \
1363 //                  |    GTE(0) GTE(1)
1364 //                  |       |    |
1365 //                Copy   Copy   Copy
1366 //                  |       |    |
1367 //                   \      |   /
1368 //                    \    Tuple
1369 //                     \    /
1370 //                      \  /
1371 //                     Tuple  // new init
1372 //
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
  auto while_instr = BuildWhileInstruction_InitPointsToNonDistinct();

  InsertCopies(module_.get());

  // Entry computation: one copy makes the two nested init elements distinct
  // and one copies the constant passed in as an init element. Which of the two
  // nested elements receives the distinctness copy is not fixed.
  EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);

  const HloInstruction* init_tuple = while_instr->operand(0);
  const bool first_init_nested_is_copy =
      init_tuple->operand(1)->operand(0)->opcode() == HloOpcode::kCopy;
  if (!first_init_nested_is_copy) {
    EXPECT_THAT(
        init_tuple,
        op::Tuple(op::Copy(op::Constant()),
                  op::Tuple(op::Broadcast(), op::Copy(op::Broadcast()))));
  } else {
    EXPECT_THAT(
        init_tuple,
        op::Tuple(op::Copy(op::Constant()),
                  op::Tuple(op::Copy(op::Broadcast()), op::Broadcast())));
  }

  // Loop body: a single copy, because the result of one add is written into
  // two elements of the body's output. Either element may carry the copy.
  HloComputation* loop_body = while_instr->while_body();
  EXPECT_EQ(CountCopies(*loop_body), 1);

  const HloInstruction* body_root = loop_body->root_instruction();
  const bool first_body_nested_is_copy =
      body_root->operand(1)->operand(0)->opcode() == HloOpcode::kCopy;
  if (!first_body_nested_is_copy) {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
  } else {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
  }
}
1413 
1414 // Tests while init instruction buffer which interferes with while result
1415 // buffer.
1416 //
1417 //     init_data = Broadcast(...)
1418 //     add_unrelated = Add(init_data) // takes a reference to cause interference
1419 //     init = Tuple(Constant(S32, {}), init_data))
1420 //
1421 // CopyInsertion pass should copy both operands.
1422 //
TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
  auto while_instr = BuildWhileInstruction_InitPointsToInterfering();

  InsertCopies(module_.get());

  // Two copies in the whole module, none of them inside the loop body.
  EXPECT_EQ(CountCopies(*module_), 2);
  EXPECT_EQ(CountCopies(*while_instr->while_body()), 0);

  // Both init tuple operands (the constant and the broadcast) are copied.
  const HloInstruction* init_tuple = while_instr->operand(0);
  EXPECT_THAT(init_tuple,
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast())));
}
1433 
1434 // Tests while init instruction buffer which has a non-distinct points-to set:
1435 //
//     init = Tuple(Parameter(S32, {}), Parameter(F32, {8}),
//                  Parameter(F32, {8}))
//
// where the second and third parameters are identical *and* the tuple is
// shared by another while instruction.
1441 //
1442 // Verifies that the resulting point-to set is distinct in the resulting Tuple
1443 // (non-identical Copys). In other words, verifies that copy sharing does not
1444 // insert identical copies to the resulting tuple.
TEST_F(WhileCopyInsertionTest,InitPointsToNonDistinctUsedByTwoWhileLoops)1445 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
1446   // Loop body that outputs tuple comprises two elements dependent on the init
1447   // tuple.
1448   const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
1449       {induction_variable_shape_, data_shape_, data_shape_});
1450 
1451   auto condition1 = module_->AddEmbeddedComputation(
1452       BuildConditionComputation(loop_state_shape));
1453   auto condition2 = module_->AddEmbeddedComputation(
1454       BuildConditionComputation(loop_state_shape));
1455   auto body1 =
1456       module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
1457   auto body2 =
1458       module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
1459 
1460   auto builder = HloComputation::Builder(TestName() + ".While");
1461 
1462   auto iter_param = builder.AddInstruction(
1463       HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
1464   auto data_param = builder.AddInstruction(
1465       HloInstruction::CreateParameter(1, data_shape_, "data"));
1466 
1467   // Loop init tuple contains two identical parameter buffers.
1468   auto loop_init = builder.AddInstruction(
1469       HloInstruction::CreateTuple({iter_param, data_param, data_param}));
1470 
1471   // Two while loops share the same loop init tuple.
1472   auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
1473       loop_state_shape, condition1, body1, loop_init));
1474   auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
1475       loop_state_shape, condition2, body2, loop_init));
1476 
1477   // Add add instruction so neither while is dead.
1478   auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1479       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
1480   auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1481       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0));
1482   builder.AddInstruction(
1483       HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
1484 
1485   module_->AddEntryComputation(builder.Build());
1486 
1487   InsertCopies(module_.get());
1488 
1489   // None of the bodies should have copies or control flow edges.
1490   EXPECT_EQ(CountCopies(*body1), 0);
1491   EXPECT_EQ(CountCopies(*body2), 0);
1492 
1493   // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally
1494   // these should not need to be copied before either while. However, copy
1495   // insertion is not able to reason about the transparency of elements through
1496   // while bodies in all circumstances so extra copies are added (b/xxx).
1497   EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
1498 
1499   EXPECT_THAT(while_hlo1->operand(0),
1500               op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
1501   EXPECT_THAT(while_hlo2->operand(0),
1502               op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
1503 }
1504 
TEST_F(CopyInsertionTest,SwizzlingWhile)1505 TEST_F(CopyInsertionTest, SwizzlingWhile) {
1506   // Test a while instruction with a body which permutes its tuple parameter
1507   // elements.
1508   auto module = CreateNewVerifiedModule();
1509   const Shape loop_state_shape =
1510       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
1511 
1512   // Body simply interchanges the two tuple elements in the loop state.
1513   auto body_builder = HloComputation::Builder("body");
1514   auto body_param = body_builder.AddInstruction(
1515       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1516   auto body_element_0 = body_builder.AddInstruction(
1517       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
1518   auto body_element_1 = body_builder.AddInstruction(
1519       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
1520   body_builder.AddInstruction(
1521       HloInstruction::CreateTuple({body_element_1, body_element_0}));
1522   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
1523 
1524   auto cond_builder = HloComputation::Builder("condition");
1525   cond_builder.AddInstruction(
1526       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1527   auto cond_constant = cond_builder.AddInstruction(
1528       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1529   cond_builder.AddInstruction(HloInstruction::CreateUnary(
1530       cond_constant->shape(), HloOpcode::kNot, cond_constant));
1531   HloComputation* condition =
1532       module->AddEmbeddedComputation(cond_builder.Build());
1533 
1534   auto builder = HloComputation::Builder(TestName());
1535   auto constant1 = builder.AddInstruction(
1536       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1537   auto constant2 = builder.AddInstruction(
1538       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
1539   auto tuple = builder.AddInstruction(
1540       HloInstruction::CreateTuple({constant1, constant2}));
1541   auto xla_while = builder.AddInstruction(
1542       HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
1543   module->AddEntryComputation(builder.Build());
1544 
1545   InsertCopies(module.get());
1546 
1547   EXPECT_EQ(CountCopies(*module), 6);
1548 
1549   // The loop state elements should be copied at the parameter and at the root
1550   // with a control edge in between (see DeepCopyAndAddControlEdges). This is
1551   // technically one more copy than is strictly necessary, but in order to have
1552   // only three copies the copies of different loop state elements must be
1553   // ordered with a control edge.
1554   EXPECT_EQ(CountCopies(*body), 4);
1555   EXPECT_EQ(CountControlEdges(*body), 2);
1556 
1557   EXPECT_THAT(body->root_instruction(),
1558               op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy())));
1559 
1560   EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
1561   EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
1562 }
1563 
TEST_F(CopyInsertionTest, CrossingParameters) {
  // The two elements of the tuple parameter are returned in swapped positions,
  // while each output index is aliased with the parameter element at the same
  // index:
  //
  //  (p0 ,  p1)
  //   | \   /|
  //   |  \ / |
  // alias X  alias
  //   |  / \ |
  //   | /   \|
  //  (p1  ,  p0)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({element1, element0}));
  module->AddEntryComputation(entry_builder.Build());

  // Alias output {0} with parameter element {0} and output {1} with parameter
  // element {1}.
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 4);
}
1596 
TEST_F(CopyInsertionTest, ParametersAliasing) {
  // Both elements of the tuple parameter are forwarded to the same output
  // position and aliased with it, so their dataflows do not interfere:
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  // alias   alias
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({element0, element1}));
  module->AddEntryComputation(entry_builder.Build());

  // Alias each output index with the parameter element at the same index.
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
  InsertCopies(module.get());

  // With every pass-through element aliased, no copies are expected.
  EXPECT_EQ(CountCopies(*module), 0);
}
1629 
TEST_F(CopyInsertionTest, ParameterWithNoAliasing) {
  // Both elements of the tuple parameter are forwarded to the output with no
  // input/output aliasing configured, so each one is expected to be copied:
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  //   |      |
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({element0, element1}));
  module->AddEntryComputation(entry_builder.Build());
  InsertCopies(module.get());

  // The root tuple now holds a copy of each parameter element.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(op::GetTupleElement(tuple_param, 0)),
                        op::Copy(op::GetTupleElement(tuple_param, 1))));

  EXPECT_EQ(CountCopies(*module), 2);
}
1662 
TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
  // Both elements of the tuple parameter are forwarded to the output, but only
  // element 0 is aliased with the result:
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  // alias    |
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({element0, element1}));
  module->AddEntryComputation(entry_builder.Build());

  // Only output {0} is aliased (with parameter element {0}).
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  // The aliased element is forwarded as-is; the non-aliased one is copied.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::GetTupleElement(tuple_param, 0),
                        op::Copy(op::GetTupleElement(tuple_param, 1))));

  EXPECT_EQ(CountCopies(*module), 1);
}
1697 
TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
  // Each element of the tuple parameter is negated independently before being
  // placed in the output; only element 0 is aliased with the result:
  //
  //   +-- (p0 ,  p1)
  //   |    |      |
  //   |    |      |
  // alias Negate  Negate
  //   |    |      |
  //   |    |      |
  //   +-- (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));

  HloInstruction* negated0 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element0));

  HloInstruction* negated1 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({negated0, negated1}));
  module->AddEntryComputation(entry_builder.Build());

  // Only output {0} is aliased (with parameter element {0}).
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  // The test expects no copies for this graph.
  EXPECT_EQ(CountCopies(*module), 0);
}
1734 
TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
  // Each element of the tuple parameter is negated; output 0 is the sum of both
  // negations and output 1 is the negation of element 1. Only element 0 is
  // aliased with the result:
  //
  //   +-- (p0 ,  p1)
  //   |    |      |
  //   |    |      |
  // alias Negate  Negate
  //   |    |      |
  //   |    Add----+
  //   |    |      |
  //   +-- (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));

  HloInstruction* negated0 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element0));

  HloInstruction* negated1 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element1));

  HloInstruction* sum = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, negated0,
                                   negated1));
  entry_builder.AddInstruction(HloInstruction::CreateTuple({sum, negated1}));
  module->AddEntryComputation(entry_builder.Build());

  // Only output {0} is aliased (with parameter element {0}).
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  InsertCopies(module.get());

  // The test expects no copies for this graph.
  EXPECT_EQ(CountCopies(*module), 0);
}
1775 
TEST_F(CopyInsertionTest,SwizzlingWhileWithOneOp)1776 TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
1777   // Test a while instruction with a body which permutes its tuple parameter
1778   // elements and applies one operation to one of the elements. The addition of
1779   // the operation (instruction) on the element makes the live range of the
1780   // respective input and output elements different than if the instruction were
1781   // not there (as in the SwizzlingWhile test above).
1782   auto module = CreateNewVerifiedModule();
1783   const Shape loop_state_shape =
1784       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
1785 
1786   // Body interchanges the two tuple elements in the loop state and negates one
1787   // of them.
1788   auto body_builder = HloComputation::Builder("body");
1789   auto body_param = body_builder.AddInstruction(
1790       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1791   auto body_element_0 = body_builder.AddInstruction(
1792       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
1793   auto body_element_1 = body_builder.AddInstruction(
1794       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
1795   auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
1796       scalar_shape_, HloOpcode::kNegate, body_element_1));
1797   body_builder.AddInstruction(
1798       HloInstruction::CreateTuple({negate, body_element_0}));
1799   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
1800 
1801   auto cond_builder = HloComputation::Builder("condition");
1802   cond_builder.AddInstruction(
1803       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1804   auto cond_constant = cond_builder.AddInstruction(
1805       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1806   cond_builder.AddInstruction(HloInstruction::CreateUnary(
1807       cond_constant->shape(), HloOpcode::kNot, cond_constant));
1808   HloComputation* condition =
1809       module->AddEmbeddedComputation(cond_builder.Build());
1810 
1811   auto builder = HloComputation::Builder(TestName());
1812   auto constant1 = builder.AddInstruction(
1813       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1814   auto constant2 = builder.AddInstruction(
1815       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
1816   auto tuple = builder.AddInstruction(
1817       HloInstruction::CreateTuple({constant1, constant2}));
1818   auto xla_while = builder.AddInstruction(
1819       HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
1820   module->AddEntryComputation(builder.Build());
1821 
1822   InsertCopies(module.get());
1823 
1824   EXPECT_EQ(CountCopies(*module), 6);
1825 
1826   // The loop state elements should be copied at the parameter and at the root
1827   // with a control edge in between (see DeepCopyAndAddControlEdges).
1828   EXPECT_EQ(CountCopies(*body), 4);
1829   EXPECT_EQ(CountControlEdges(*body), 2);
1830 
1831   EXPECT_THAT(
1832       body->root_instruction(),
1833       op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy())));
1834 
1835   EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
1836   EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
1837 }
1838 
TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
  // Exercises a while instruction whose body permutes its tuple parameter
  // elements, as in SwizzlingWhile above, except that both loop state elements
  // are initialized from one and the same constant. Interchanging two
  // identical elements is a no-op, so the body is expected to need no copies;
  // the only copies expected are the two in the entry computation.
  auto module = CreateNewVerifiedModule();
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: read both loop state elements and emit them in swapped order.
  HloComputation::Builder body_builder("body");
  HloInstruction* body_state = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* body_first = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 0));
  HloInstruction* body_second = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 1));
  body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_second, body_first}));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Condition: ignores the loop state and returns Not(false).
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* false_constant = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  cond_builder.AddInstruction(HloInstruction::CreateUnary(
      false_constant->shape(), HloOpcode::kNot, false_constant));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  // Entry: a while over a tuple that uses a single constant for both elements.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* shared_constant = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* init_tuple = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({shared_constant, shared_constant}));
  entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body,
                                  init_tuple));
  module->AddEntryComputation(entry_builder.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 2);
  EXPECT_EQ(CountCopies(*body), 0);

  EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(), op::Copy()));
}
1889 
TEST_F(CopyInsertionTest,SequentialWhiles)1890 TEST_F(CopyInsertionTest, SequentialWhiles) {
1891   // Construct a computation with a series of sequential while instructions
1892   // containing four loop state elements:
1893   //
1894   //   element 0 is passed to each while directly from an entry parameter.
1895   //
1896   //   element 1 is passed transparently in series through all the while bodies.
1897   //
1898   //   element 2 is negated in each while body. (in-place possible)
1899   //
1900   //   element 3 is reversed in each while body. (in-place not possible)
1901   //
1902   const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
1903   const Shape loop_state_shape = ShapeUtil::MakeTupleShape(
1904       {element_shape, element_shape, element_shape, element_shape});
1905 
1906   auto module = CreateNewVerifiedModule();
1907   auto builder = HloComputation::Builder(TestName());
1908   auto param_0 = builder.AddInstruction(
1909       HloInstruction::CreateParameter(0, element_shape, "param_0"));
1910   auto param_1 = builder.AddInstruction(
1911       HloInstruction::CreateParameter(1, element_shape, "param_1"));
1912   auto param_2 = builder.AddInstruction(
1913       HloInstruction::CreateParameter(2, element_shape, "param_2"));
1914   auto param_3 = builder.AddInstruction(
1915       HloInstruction::CreateParameter(3, element_shape, "param_3"));
1916 
1917   // The number of sequential kWhile instructions.
1918   const int kNumWhiles = 3;
1919 
1920   HloInstruction* prev_element_1 = param_1;
1921   HloInstruction* prev_element_2 = param_2;
1922   HloInstruction* prev_element_3 = param_3;
1923 
1924   // Vector containing all of the while instructions.
1925   std::vector<const HloInstruction*> whiles;
1926   for (int i = 0; i < kNumWhiles; ++i) {
1927     auto body_builder = HloComputation::Builder("body");
1928     auto body_param = body_builder.AddInstruction(
1929         HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1930     auto body_element_0 = body_builder.AddInstruction(
1931         HloInstruction::CreateGetTupleElement(element_shape, body_param, 0));
1932     auto body_element_1 = body_builder.AddInstruction(
1933         HloInstruction::CreateGetTupleElement(element_shape, body_param, 1));
1934     auto body_element_2 = body_builder.AddInstruction(
1935         HloInstruction::CreateGetTupleElement(element_shape, body_param, 2));
1936     auto body_element_3 = body_builder.AddInstruction(
1937         HloInstruction::CreateGetTupleElement(element_shape, body_param, 3));
1938     auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
1939         element_shape, HloOpcode::kNegate, body_element_2));
1940     auto reverse = body_builder.AddInstruction(
1941         HloInstruction::CreateReverse(element_shape, body_element_3, {0}));
1942     body_builder.AddInstruction(HloInstruction::CreateTuple(
1943         {body_element_0, body_element_1, negate, reverse}));
1944     HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
1945 
1946     auto cond_builder = HloComputation::Builder("condition");
1947     cond_builder.AddInstruction(
1948         HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1949     auto cond_constant = cond_builder.AddInstruction(
1950         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1951     cond_builder.AddInstruction(HloInstruction::CreateUnary(
1952         cond_constant->shape(), HloOpcode::kNot, cond_constant));
1953     HloComputation* condition =
1954         module->AddEmbeddedComputation(cond_builder.Build());
1955 
1956     auto while_init = builder.AddInstruction(HloInstruction::CreateTuple(
1957         {param_0, prev_element_1, prev_element_2, prev_element_3}));
1958 
1959     auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile(
1960         loop_state_shape, condition, body, while_init));
1961     whiles.push_back(xla_while);
1962     if (i != kNumWhiles - 1) {
1963       prev_element_1 = builder.AddInstruction(
1964           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1));
1965       prev_element_2 = builder.AddInstruction(
1966           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2));
1967       prev_element_3 = builder.AddInstruction(
1968           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3));
1969     }
1970   }
1971 
1972   module->AddEntryComputation(builder.Build());
1973 
1974   InsertCopies(module.get());
1975 
1976   // Each while body has one copy. And each loop state element is copied once in
1977   // the entry computation.
1978   EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles);
1979 
1980   // Each while body should have exactly one copy for element three which is an
1981   // op (kReverse) which cannot be done in place.
1982   for (const HloInstruction* xla_while : whiles) {
1983     EXPECT_EQ(CountCopies(*xla_while->while_body()), 1);
1984   }
1985 
1986   EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(),
1987                                                op::Copy(), op::Copy()));
1988   EXPECT_THAT(module->entry_computation()->root_instruction(),
1989               op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(),
1990                         op::GetTupleElement()));
1991 }
1992 
TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
  // The while body and the while condition each consist of a parameter
  // followed by a constant, so the constant is the root of each computation.
  // After copy insertion the body's constant is expected to be wrapped in a
  // copy, while the condition's constant is left as is.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* init = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));

  // Body: ignores its parameter and yields the constant 123.0.
  HloComputation::Builder body_builder("body");
  body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Condition: ignores its parameter and yields the constant false.
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  HloInstruction* loop = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape_, condition, body, init));

  module->AddEntryComputation(entry_builder.Build());

  InsertCopies(module.get());

  // Two copies in total: one on the while operand, one on the body root.
  EXPECT_EQ(CountCopies(*module), 2);

  EXPECT_THAT(loop->operand(0), op::Copy(op::Parameter()));
  EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
  EXPECT_THAT(condition->root_instruction(), op::Constant());
}
2029 
TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) {
  // A while loop whose state is an (s32[], token[]) tuple: the body increments
  // the integer and threads the token through an after-all.
  const string hlo_text = R"(
HloModule TokensShouldNotBeCopied

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %TokensShouldNotBeCopied () -> s32[] {
  %one = s32[] constant(1)
  %negative_one = s32[] negate(%one)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(parsed_module.get());

  // Tokens must never be copied, so the pass is expected to add no copies.
  EXPECT_EQ(CountCopies(*parsed_module), 0);
}
2067 
MakeTrivialCondition(const Shape & shape)2068 std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
2069   auto builder = HloComputation::Builder("trivial_condition");
2070   builder.AddInstruction(
2071       HloInstruction::CreateParameter(0, shape, "loop_state"));
2072   auto constant = builder.AddInstruction(
2073       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
2074   builder.AddInstruction(HloInstruction::CreateUnary(
2075       constant->shape(), HloOpcode::kNot, constant));
2076   return builder.Build();
2077 }
2078 
MakeBenchmarkWhileBody()2079 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() {
2080   auto builder = HloComputation::Builder("benchmark_loop_body");
2081   const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
2082   const Shape loop_state_shape =
2083       ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape});
2084   HloInstruction* param = builder.AddInstruction(
2085       HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
2086   HloInstruction* element_0 = builder.AddInstruction(
2087       HloInstruction::CreateGetTupleElement(element_shape, param, 0));
2088   HloInstruction* element_1 = builder.AddInstruction(
2089       HloInstruction::CreateGetTupleElement(element_shape, param, 1));
2090   HloInstruction* element_2 = builder.AddInstruction(
2091       HloInstruction::CreateGetTupleElement(element_shape, param, 2));
2092 
2093   HloInstruction* rev_1 = builder.AddInstruction(
2094       HloInstruction::CreateReverse(element_shape, element_1, {0}));
2095   HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary(
2096       element_shape, HloOpcode::kAdd, element_1, element_2));
2097 
2098   builder.AddInstruction(
2099       HloInstruction::CreateTuple({element_0, rev_1, add_1_2}));
2100   return builder.Build();
2101 }
2102 
BM_SequentialWhiles(::testing::benchmark::State & state)2103 void BM_SequentialWhiles(::testing::benchmark::State& state) {
2104   const int num_whiles = state.range(0);
2105 
2106   // This benchmark constructs a chain of sequential while instructions.
2107   // Timer starts automatically at the first iteration of this loop
2108   // and ends after the last one.
2109   for (auto s : state) {
2110     state.PauseTiming();
2111     HloModuleConfig config;
2112     config.set_debug_options(GetDebugOptionsFromFlags());
2113     HloModule module("BM_SequentialWhiles", config);
2114 
2115     auto builder = HloComputation::Builder("BM_SequentialWhiles");
2116     HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
2117         0, ShapeUtil::MakeShape(F32, {42}), "x"));
2118     HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
2119         1, ShapeUtil::MakeShape(F32, {42}), "y"));
2120     HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
2121         2, ShapeUtil::MakeShape(F32, {42}), "z"));
2122     HloInstruction* init =
2123         builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
2124 
2125     HloInstruction* prev_loop_state = init;
2126     for (int w = 0; w < num_whiles; ++w) {
2127       HloComputation* condition =
2128           module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2129       HloComputation* body =
2130           module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
2131       prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile(
2132           init->shape(), condition, body, prev_loop_state));
2133     }
2134     module.AddEntryComputation(builder.Build());
2135 
2136     CopyInsertion copy_insertion;
2137 
2138     state.ResumeTiming();
2139     ASSERT_IS_OK(copy_insertion.Run(&module).status());
2140     state.PauseTiming();
2141 
2142     // The entry computation should have three copies, and each body has one.
2143     ASSERT_EQ(CountCopies(module), 3 + num_whiles);
2144     state.ResumeTiming();
2145   }
2146 }
2147 
BM_ParallelWhiles(::testing::benchmark::State & state)2148 void BM_ParallelWhiles(::testing::benchmark::State& state) {
2149   const int num_whiles = state.range(0);
2150 
2151   // This benchmark constructs a fan-out of parallel while instructions.
2152   for (auto s : state) {
2153     state.PauseTiming();
2154     HloModuleConfig config;
2155     config.set_debug_options(GetDebugOptionsFromFlags());
2156     HloModule module("BM_SequentialWhiles", config);
2157 
2158     auto builder = HloComputation::Builder("BM_ParallelWhiles");
2159     HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
2160         0, ShapeUtil::MakeShape(F32, {42}), "x"));
2161     HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
2162         1, ShapeUtil::MakeShape(F32, {42}), "y"));
2163     HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
2164         2, ShapeUtil::MakeShape(F32, {42}), "z"));
2165     HloInstruction* init =
2166         builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
2167 
2168     HloInstruction* sum = nullptr;
2169     for (int w = 0; w < num_whiles; ++w) {
2170       HloComputation* condition =
2171           module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2172       HloComputation* body =
2173           module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
2174 
2175       HloInstruction* xla_while = builder.AddInstruction(
2176           HloInstruction::CreateWhile(init->shape(), condition, body, init));
2177 
2178       if (sum == nullptr) {
2179         sum = builder.AddInstruction(
2180             HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
2181       } else {
2182         HloInstruction* element_0 = builder.AddInstruction(
2183             HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
2184         sum = builder.AddInstruction(HloInstruction::CreateBinary(
2185             x->shape(), HloOpcode::kAdd, sum, element_0));
2186       }
2187     }
2188     module.AddEntryComputation(builder.Build());
2189 
2190     CopyInsertion copy_insertion;
2191 
2192     state.ResumeTiming();
2193     ASSERT_IS_OK(copy_insertion.Run(&module).status());
2194     state.PauseTiming();
2195 
2196     // Each body receives of copy of two of the parameters (the corresponding
2197     // elements in the body are modified), and there is one copy in each body.
2198     ASSERT_EQ(CountCopies(module), 3 * num_whiles);
2199   }
2200 }
2201 
MakeBenchmarkWhileBody(const int num_tuple_inputs)2202 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody(
2203     const int num_tuple_inputs) {
2204   auto builder = HloComputation::Builder("benchmark_loop_body");
2205   const Shape element_shape = ShapeUtil::MakeShape(F32, {});
2206   std::vector<Shape> input_shape(num_tuple_inputs, element_shape);
2207   const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape);
2208   HloInstruction* param = builder.AddInstruction(
2209       HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
2210   std::vector<HloInstruction*> gte_nodes(num_tuple_inputs);
2211   for (int i = 0; i < num_tuple_inputs; ++i) {
2212     gte_nodes[i] = builder.AddInstruction(
2213         HloInstruction::CreateGetTupleElement(element_shape, param, i));
2214   }
2215   builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes));
2216   return builder.Build();
2217 }
2218 
BM_ManyElementTuple(::testing::benchmark::State & state)2219 void BM_ManyElementTuple(::testing::benchmark::State& state) {
2220   const int num_tuple_inputs = state.range(0);
2221   HloModuleConfig config;
2222   config.set_debug_options(GetDebugOptionsFromFlags());
2223   CopyInsertion copy_insertion;
2224   const Shape element_shape = ShapeUtil::MakeShape(F32, {});
2225   std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
2226   for (auto s : state) {
2227     state.PauseTiming();
2228     auto builder = HloComputation::Builder("BM_ParallelWhiles");
2229     HloModule module("BM_ManyElementTuple", config);
2230     for (int j = 0; j < num_tuple_inputs; ++j) {
2231       tuple_params[j] = builder.AddInstruction(
2232           HloInstruction::CreateParameter(j, element_shape, ""));
2233     }
2234     HloInstruction* init =
2235         builder.AddInstruction(HloInstruction::CreateTuple(tuple_params));
2236     HloComputation* condition =
2237         module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2238     HloComputation* body =
2239         module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs));
2240     HloInstruction* xla_while = builder.AddInstruction(
2241         HloInstruction::CreateWhile(init->shape(), condition, body, init));
2242     builder.AddInstruction(HloInstruction::CreateGetTupleElement(
2243         ShapeUtil::MakeShape(F32, {}), xla_while, 0));
2244     module.AddEntryComputation(builder.Build());
2245     state.ResumeTiming();
2246     ASSERT_IS_OK(copy_insertion.Run(&module).status());
2247   }
2248 }
2249 
2250 BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
2251 BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
2252 BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288);
2253 
TEST_F(CopyInsertionTest, SimpleControlFlowTest) {
  // Three levels of nested while loops: the entry while (while.3) runs
  // _functionalize_body_2__.v25, which contains while.2 running
  // _functionalize_body_1__.v28, which in turn contains a while running
  // if-body.v5. The test has no expectations on the resulting copies; it only
  // requires that the module parses and verifies and that InsertCopies runs
  // on it.
  const string& hlo_string = R"(
HloModule TestModule

if-body.v5 {
  constant.3 = s32[] constant(-1)
  p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
  get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
  get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
  add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
  tuple.33 = (s32[]) tuple(add.3)
  ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
  p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
  constant.4 = s32[] constant(0)
  ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

_functionalize_body_1__.v28 {
  arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0
  constant.7 = s32[] constant(1)
  add.4 = s32[] add(get-tuple-element.68, constant.7)
  get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
  get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
  less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE
  constant.8 = s32[] constant(0)
  select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
  get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
  tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70)
  tuple.36 = (s32[]) tuple(constant.8)
  tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36)
  while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5
  get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2
  get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0
  ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73)
}

cond_wrapper.v3.1 {
  inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
  constant.11 = s32[] constant(7)
  ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
  arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0
  get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2
  get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3
  get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4
  tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79)
  while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
  get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0
  get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1
  constant.12 = s32[] constant(1)
  add.5 = s32[] add(get-tuple-element.81, constant.12)
  get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3
  ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82)
}

cond_wrapper.v3.2 {
  inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
  constant.13 = s32[] constant(5)
  ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT
}

ENTRY TestComputation {
  arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // Report a parse/verification error as a test failure rather than aborting
  // the whole test binary, matching the other HLO-text tests in this file.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2335 
TEST_F(CopyInsertionTest, ControlFlowTest) {
  // Nested while loops where the innermost level holds two whiles in sequence:
  // the entry while (while.3) runs _functionalize_body_2__.v25, which contains
  // while.2 running _functionalize_body_1__.v28, which contains `while`
  // (body if-body.v5) feeding `while.1` (body if-body.v5.1). The test has no
  // expectations on the resulting copies; it only requires that the module
  // parses and verifies and that InsertCopies runs on it.
  const string& hlo_string = R"(
HloModule TestModule

if-body.v5 {
  constant.3 = s32[] constant(-1)
  p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
  get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
  get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
  add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
  tuple.33 = (s32[]) tuple(add.3)
  ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
  p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
  constant.4 = s32[] constant(0)
  ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

if-body.v5.1 {
  constant.5 = s32[] constant(-1)
  p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1
  get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2
  multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70)
  tuple.35 = (s32[]) tuple(multiply.1)
  ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35)
}

if-condition.v4.1 {
  p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
  constant.6 = s32[] constant(1)
  ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ
}

_functionalize_body_1__.v28 {
  arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0
  constant.7 = s32[] constant(1)
  add.4 = s32[] add(get-tuple-element.72, constant.7)
  get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
  get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
  less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE
  constant.8 = s32[] constant(0)
  select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
  get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
  tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74)
  tuple.38 = (s32[]) tuple(constant.8)
  tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38)
  while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5
  while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1
  get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2
  get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0
  ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77)
}

cond_wrapper.v3.1 {
  inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
  constant.11 = s32[] constant(7)
  ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
  arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0
  get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2
  get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3
  get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4
  tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82)
  while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
  get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0
  get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1
  constant.12 = s32[] constant(1)
  add.5 = s32[] add(get-tuple-element.84, constant.12)
  get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3
  ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85)
}

cond_wrapper.v3.2 {
  inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
  constant.13 = s32[] constant(5)
  ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT
}

ENTRY TestComputation {
  arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // Report a parse/verification error as a test failure rather than aborting
  // the whole test binary, matching the other HLO-text tests in this file.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2435 
TEST_F(CopyInsertionTest, NestedWhiles) {
  // Checks that no unnecessary copies are left behind by copy insertion for
  // trivial nested whiles (b/112472605).
  const string hlo_text = R"(
HloModule TestModule

cond.inner {
  ROOT param.cond.inner = pred[] parameter(0)
}

body.inner {
  param.body.inner = pred[] parameter(0)
  ROOT not = pred[] not(param.body.inner)
}

cond.outer {
  ROOT param.cond.outer = pred[] parameter(0)
}

body.outer {
  param.cond.outer = pred[] parameter(0)
  ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
}

ENTRY TestComputation {
  entry_param = pred[] parameter(0)
  ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(parsed_module.get());

  // Exactly one copy is expected, located in the entry computation: the
  // operand of the outer while becomes a copy of the entry parameter.
  EXPECT_EQ(CountCopies(*parsed_module), 1);
  EXPECT_THAT(parsed_module->entry_computation()->root_instruction(),
              op::While(op::Copy(op::Parameter())));
}
2475 
TEST_F(CopyInsertionTest, NestedWhileAndConditional2) {
  // A while loop whose body runs a conditional. Both branches return a tuple
  // of their unmodified parameter and a value computed from it, and both tuple
  // elements are carried in the loop state.
  const string hlo_text = R"(
HloModule TestModule

on_true
 {
  v1 = f32[2] parameter(0)
  v2 = f32[2] add(v1,v1)
  ROOT t1 = (f32[2], f32[2]) tuple(v1,v2)
}

on_false
 {
  v1 = f32[2] parameter(0)
  v2 = f32[2] multiply(v1,v1)
  ROOT t2 = (f32[2], f32[2]) tuple(v1,v2)
}

cond.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  pred.1 = pred[] get-tuple-element(param.1), index=0
  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
  if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
  e1 = f32[2] get-tuple-element(if), index=0
  e2 = f32[2] get-tuple-element(if), index=1
  ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
}

ENTRY TestComputation {
  entry_param.1 = pred[] parameter(0)
  float_param = f32[2] parameter(1)
  entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
  ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(parsed_module.get());
  VLOG(2) << parsed_module->ToString() << "\n";

  // Three copies are expected: the uses in the conditional force one extra
  // copy to stay inside the loop.
  EXPECT_EQ(CountCopies(*parsed_module), 3);
}
2524 
TEST_F(CopyInsertionTest, NestedWhileAndConditional) {
  // A while loop whose body runs a conditional. Each branch returns a single
  // array computed from its parameter (add or multiply), and that result is
  // carried in the loop state.
  const string hlo_text = R"(
HloModule TestModule

on_true
 {
  v1 = f32[2] parameter(0)
  ROOT v2 = f32[2] add(v1,v1)
}

on_false
 {
  v1 = f32[2] parameter(0)
  ROOT v2 = f32[2] multiply(v1,v1)
}

cond.outer {
  param.1 = (pred[], f32[2]) parameter(0)
  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
  param.1 = (pred[], f32[2]) parameter(0)
  pred.1 = pred[] get-tuple-element(param.1), index=0
  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
  if = f32[2] conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
  ROOT res = (pred[], f32[2]) tuple(pred.1,if)
}

ENTRY TestComputation {
  entry_param.1 = pred[] parameter(0)
  float_param = f32[2] parameter(1)
  entry_param = (pred[], f32[2]) tuple(entry_param.1, float_param)
  ROOT while = (pred[], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(parsed_module.get());
  VLOG(2) << parsed_module->ToString() << "\n";

  // Only two copies are expected, at the entry and the exit of the
  // computation.
  EXPECT_EQ(CountCopies(*parsed_module), 2);
}
2570 
// Runs copy insertion on an entry computation whose root tuple returns a
// fusion result, both entry parameters and a negate, with outputs {1} and {3}
// aliased to parameters 0 and 1. No copy is expected to remain.
TEST_F(CopyInsertionTest, FixpointComputationRequired) {
  const string hlo_text = R"(
HloModule Module

fused_computation {
  param0 = f32[3,3,96,1] parameter(0)
  param1 = f32[] parameter(1)
  broadcast = f32[3,3,96,1] broadcast(f32[] param1), dimensions={}
  ROOT %add.0 = f32[3,3,96,1] add(f32[3,3,96,1] param0, f32[3,3,96,1] broadcast)
}

ENTRY entry_computation {
  arg0 = f32[3,3,96,1] parameter(0)
  arg1 = f32[] parameter(1)
  fusion = f32[3,3,96,1] fusion(f32[3,3,96,1] arg0, f32[] arg1),
    kind=kLoop, calls=fused_computation
  negate = f32[] negate(f32[] arg1)
  ROOT tuple = (f32[3,3,96,1], f32[3,3,96,1], f32[], f32[]) tuple(
    f32[3,3,96,1] fusion,
    f32[3,3,96,1] arg0,
    f32[] negate,
    f32[] arg1)
}
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // The input/output aliasing is configured by hand here; normally the
  // alias_passthrough_params pass would set it up.
  auto& alias_config = hlo_module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{1},
                                       /*param_number=*/0,
                                       /*param_index=*/{}));
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{3},
                                       /*param_number=*/1,
                                       /*param_index=*/{}));

  InsertCopies(hlo_module.get());

  // The module must end up copy-free.
  EXPECT_EQ(CountCopies(*hlo_module), 0);
}
2614 
// Runs copy insertion on a module whose root tuple returns both a bitcast of
// the entry parameter and the parameter itself, with output {1} aliased to
// parameter 0. Exactly one copy is expected.
TEST_F(CopyInsertionTest, NoAliasCheckViolation) {
  const string hlo_text = R"(
HloModule cluster

ENTRY Entry {
  %arg = f32[8,28,28,1] parameter(0)
  %bitcast.2 = f32[8,1,28,28] bitcast(f32[8,28,28,1] %arg)
  ROOT %tuple.1 = (f32[8,1,28,28], f32[8,28,28,1]) tuple(f32[8,1,28,28] %bitcast.2, f32[8,28,28,1] %arg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto& alias_config = hlo_module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{1},
                                       /*param_number=*/0,
                                       /*param_index=*/{}));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2634 
// The dynamic-update-slice is the only user of `negate`, and the test expects
// copy insertion to leave the module without any copy.
TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 0);
}
2653 
// Same scenario as the unfused variant, but the dynamic-update-slice is the
// root of a loop fusion whose single operand is `negate`. No copy is expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 0);
}
2677 
// `negate` is consumed both by an add and by a dynamic-update-slice, and both
// results are returned from the entry computation. One copy is expected.
TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  add = f32[1280,1,128] add(negate, negate)
  dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2698 
// The dynamic-update-slice operates directly on the entry parameter and is the
// root of the entry computation. One copy is expected.
TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2716 
// `negate` is the operand of a fusion rooted at a dynamic-update-slice and is
// also returned directly in the root tuple. One copy is expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  add = f32[1280,1,128] add(negate, negate)
  fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2742 
// Two dynamic-update-slices are chained on element 1 of a tuple-shaped entry
// parameter, and the final result is returned in the root tuple. One copy is
// expected.
TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

ENTRY main {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
  get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
  get-tuple-element.3 = s32[] get-tuple-element(state), index=0
  constant.2 = s32[] constant(128)
  add.5 = s32[] add(get-tuple-element.3, constant.2)
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2766 
// Two fusions rooted at dynamic-update-slices are chained: fusion1 updates
// `negate`, and fusion2 updates fusion1's result using a slice of `negate`.
// One copy is expected.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) {
  const absl::string_view hlo_text = R"(
HloModule Module

fused_computation.1 {
  param0 = f32[1280,1,128] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}

fused_computation.2 {
  param0 = f32[1280,1,128] parameter(0)
  param1 = f32[1280,1,128] parameter(1)
  slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate = f32[1280,1,128] negate(param)
  add = f32[1280,1,128] add(negate, negate)
  fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1
  ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2800 
// A multi-output fusion produces two dynamic-update-slice results (on negate1
// and negate2), and both of those operands are read again after the fusion.
// Two copies are expected.
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  param1 = f32[1280,1,128] parameter(1)
  param2 = f32[1280,1,128] parameter(2)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  add.1 = f32[1280,1,128] add(param0, param0)
  dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate0 = f32[1280,1,128] negate(param)
  negate1 = f32[1280,1,128] negate(param)
  negate2 = f32[1280,1,128] negate(param)
  fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
  gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
  gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
  gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
  add0 = f32[1280,1,128] add(negate0, gte0)
  add1 = f32[1280,1,128] add(negate1, gte1)
  add2 = f32[1280,1,128] add(negate2, gte2)
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 2);
}
2839 
// Variant of the two-copy multi-output fusion test in which negate1 is not
// read after the fusion (add1 only consumes gte1), so a single copy is
// expected.
// NOTE(review): the remaining copy is presumably for negate2, which is both
// updated in place inside the fusion and read by add2 -- confirm.
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) {
  const absl::string_view hlo_text = R"(
HloModule Module

fused_computation {
  param0 = f32[1280,1,128] parameter(0)
  param1 = f32[1280,1,128] parameter(1)
  param2 = f32[1280,1,128] parameter(2)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  add.1 = f32[1280,1,128] add(param0, param0)
  dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
}

ENTRY main {
  param = f32[1280,1,128] parameter(0)
  negate0 = f32[1280,1,128] negate(param)
  negate1 = f32[1280,1,128] negate(param)
  negate2 = f32[1280,1,128] negate(param)
  fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
  gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
  gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
  gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
  add0 = f32[1280,1,128] add(negate0, gte0)
  add1 = f32[1280,1,128] add(gte1, gte1)
  add2 = f32[1280,1,128] add(negate2, gte2)
  ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());
  EXPECT_EQ(CountCopies(*hlo_module), 1);
}
2879 
// A multi-output fusion concatenates two independent results and slices them
// back apart; the entry root returns bitcasts of both outputs. Output {0} is
// aliased to parameter 0 and output {1} to parameter 3. No copy is expected.
TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) {
  const string hlo_text = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      add0 = f32[10, 20] add(p0, p1)
      sub0 = f32[10, 10] subtract(p2, p3)
      reshape0 = f32[200] reshape(add0)
      reshape1 = f32[100] reshape(sub0)
      concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
      slice0 = f32[200] slice(concat0), slice={[0:200]}
      slice1 = f32[100] slice(concat0), slice={[200:300]}
      ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
    }

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
      gte0 = f32[200] get-tuple-element(fusion), index=0
      gte1 = f32[100] get-tuple-element(fusion), index=1
      bitcast0 = f32[10,20] bitcast(gte0)
      bitcast1 = f32[10,10] bitcast(gte1)
      ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  auto& alias_config = hlo_module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{0},
                                       /*param_number=*/0,
                                       /*param_index=*/{}));
  ASSERT_IS_OK(alias_config.SetUpAlias(/*output_index=*/{1},
                                       /*param_number=*/3,
                                       /*param_index=*/{}));

  InsertCopies(hlo_module.get());

  // The module must end up copy-free.
  EXPECT_EQ(CountCopies(*hlo_module), 0);
}
2929 
// Runs copy insertion on a while loop whose body contains a conditional, where
// the false branch itself contains a second conditional. Four copies are
// expected to remain.
TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
  const string hlo_text = R"(
HloModule TestModule

on_true.1
 {
  ROOT v1 = f32[2] parameter(0)
}

on_false.1
 {
  v1 = f32[2] parameter(0)
  ROOT v2 = f32[2] multiply(v1,v1)
}

on_true
 {
  v1 = f32[2] parameter(0)
  v2 = f32[2] add(v1,v1)
  v3 = (f32[2],f32[2]) tuple(v1,v2)
  v4 = f32[2] get-tuple-element(v3), index=1
  v5 = f32[2] multiply(v4,v2)
   ROOT t1 = (f32[2], f32[2]) tuple(v5,v2)
}

on_false
 {
  v1 = f32[2] parameter(0)
  v2 = f32[2] multiply(v1,v1)
  pred.1 = pred[] constant(true)
  v4 = f32[2] conditional(pred.1, v1, v2), true_computation=on_true.1, false_computation=on_false.1
  v5 = f32[2] multiply(v4,v2)
  ROOT t2 = (f32[2], f32[2]) tuple(v2,v5)

}

cond.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}

body.outer {
  param.1 = (pred[], f32[2], f32[2]) parameter(0)
  pred.1 = pred[] get-tuple-element(param.1), index=0
  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
  if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
  e1 = f32[2] get-tuple-element(if), index=0
  e2 = f32[2] get-tuple-element(if), index=1
  ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
}

ENTRY TestComputation {
  entry_param.1 = pred[] parameter(0)
  float_param = f32[2] parameter(1)
  entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
  ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());
  // One extra copy has to stay inside the loop because of the uses in the
  // conditional.
  EXPECT_EQ(CountCopies(*hlo_module), 4);
}
2994 
// Runs copy insertion (first through InsertCopies, then again through a
// CopyInsertion pass with region-based live range analysis enabled) on a
// conditional whose branch 1 forwards its parameter element through copy.1,
// while parameter.2 is also returned directly by the entry root. Verifies that
// the root tuple of branch 1 still takes a copy as its operand.
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy1) {
  const string& hlo_string = R"(
HloModule TestModule

 branch_0_comp.5.clone {
 %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
 %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
 %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
 %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
 ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
 }

 branch_1_comp.12.clone {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%copy.1)
 }

ENTRY TestComputation {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
  ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(parameter.2, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/true);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString();
  // The copy.1 must be kept due to modification in the other branch.
  // ASSERT_* is used (rather than CHECK_*) so that a violated expectation is
  // reported as a test failure instead of aborting the whole test binary; the
  // checks that guard a later dereference stay fatal.
  auto* conditional18 = FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  auto* tuple6 = conditional18->branch_computation(1)->root_instruction();
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  auto* copy1 = tuple6->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3040 
// Runs a CopyInsertion pass with region-based live range analysis enabled on a
// conditional whose branch 1 updates copy.1 via dynamic-update-slice and then
// adds the result to copy.1. Verifies that, after the pass, operand 0 of
// operand 0 of the add feeding the branch root is still a copy.
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy2) {
  const string& hlo_string = R"(
HloModule TestModule

 branch_0_comp.5.clone {
 %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
 %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
 %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
 %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
 ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
 }

 branch_1_comp.12.clone {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  %constant.1 = s32[] constant(0)
  %broadcast.6 = s32[2] broadcast(constant.1), dimensions={}
  dynamic-update-slice.5 = s32[2]{0:T(128)} dynamic-update-slice(%copy.1, %broadcast.6, %constant.1)
  %add.1 = s32[2]{0:T(128)} add(dynamic-update-slice.5, %copy.1)
  ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%add.1)
 }

ENTRY TestComputation {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}
  %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
  ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(parameter.2, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/true);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  // The copy.1 must be kept due to modification in the other branch.
  // ASSERT_* is used (rather than CHECK_*) so that a violated expectation is
  // reported as a test failure instead of aborting the whole test binary; the
  // checks that guard a later dereference stay fatal.
  auto* conditional18 = FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  auto* tuple6 = conditional18->branch_computation(1)->root_instruction();
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  auto* add1 = tuple6->operand(0);
  ASSERT_EQ(add1->opcode(), HloOpcode::kAdd);
  auto* dus = add1->operand(0);
  auto* copy1 = dus->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3091 
// Runs copy insertion (first through InsertCopies, then again through a
// CopyInsertion pass with region-based live range analysis enabled) on a
// module whose entry root is the conditional itself. Verifies that the root
// tuple of branch 1 still takes a copy as its operand.
TEST_F(CopyInsertionTest, ConditionalBranchMustCopy3) {
  const string& hlo_string = R"(
HloModule primitive_computation_cond.19
%branch_0_comp.5.clone (parameter.0: (s32[2])) -> (s32[2]) {
  %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
  %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
  %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
  ROOT %tuple.5 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %copy)
}

%branch_1_comp.12.clone (parameter.4: (s32[2])) -> (s32[2]) {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  ROOT %tuple.6 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %copy.1)
}

ENTRY %primitive_computation_cond.19 (parameter.1: s32[], parameter.2: s32[2], parameter.3: s32[2]) -> (s32[2]) {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  ROOT %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/true);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString();
  // The copy.1 must be kept b/c aliasing of parameter and root is not allowed.
  // ASSERT_* is used (rather than CHECK_*) so that a violated expectation is
  // reported as a test failure instead of aborting the whole test binary; the
  // checks that guard a later dereference stay fatal.
  auto* conditional18 = FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  auto* tuple6 = conditional18->branch_computation(1)->root_instruction();
  ASSERT_EQ(tuple6->opcode(), HloOpcode::kTuple);
  auto* copy1 = tuple6->operand(0);
  EXPECT_EQ(copy1->opcode(), HloOpcode::kCopy);
}
3134 
// Runs a CopyInsertion pass with region-based live range analysis enabled on a
// conditional whose result is the only value returned by the entry root
// (parameter.2 is not returned directly). Verifies that, after the pass, the
// root of branch 1 is a parameter instruction, i.e. neither copy.1 nor the
// tuple wrapping it remains at the root.
TEST_F(CopyInsertionTest, ConditionalBranchDoNotCopy1) {
  const string& hlo_string = R"(
HloModule TestModule

 branch_0_comp.5.clone {
 %parameter.0 = (s32[2]{0:T(128)}) parameter(0)
 %get-tuple-element = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.0), index=0
 %negate = s32[2]{0:T(128)} negate(s32[2]{0:T(128)} %get-tuple-element)
 %copy = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %negate)
 ROOT tuple.5 = (s32[2]{0:T(128)}) tuple(%copy)
 }

 branch_1_comp.12.clone {
  %parameter.4 = (s32[2]{0:T(128)}) parameter(0)
  %get-tuple-element.5 = s32[2]{0:T(128)} get-tuple-element((s32[2]{0:T(128)}) %parameter.4), index=0
  %copy.1 = s32[2]{0:T(128)} copy(s32[2]{0:T(128)} %get-tuple-element.5)
  ROOT tuple.6 = (s32[2]{0:T(128)}) tuple(%copy.1)
 }

ENTRY TestComputation {
  %parameter.1 = s32[]{:T(128)} parameter(0), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.2 = s32[2]{0:T(128)} parameter(1), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %parameter.3 = s32[2]{0:T(128)} parameter(2), metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %tuple.1 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.3)
  %tuple.3 = (s32[2]{0:T(128)}) tuple(s32[2]{0:T(128)} %parameter.2)
  %conditional.18 = (s32[2]{0:T(128)}) conditional(s32[]{:T(128)} %parameter.1, (s32[2]{0:T(128)}) %tuple.1, (s32[2]{0:T(128)}) %tuple.3), branch_computations={%branch_0_comp.5.clone, %branch_1_comp.12.clone}, metadata={op_type="cond" op_name="cond[ linear=(False, False) ]"}
  %gte.1 = s32[2]{0:T(128)} get-tuple-element(conditional.18), index=0
  ROOT tuple.4 = (s32[2]{0:T(128)},s32[2]{0:T(128)}) tuple(gte.1, gte.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  CopyInsertion copy_insertion(nullptr,
                               /*use_region_based_live_range_analysis=*/true);
  ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
  VLOG(3) << module->ToString() << "\n";

  // Here copy.1 is expected to be removed: the root of branch 1 must be the
  // branch parameter itself.
  // ASSERT_*/EXPECT_* are used (rather than CHECK_*) so that a violated
  // expectation is reported as a test failure instead of aborting the whole
  // test binary.
  auto* conditional18 = FindInstruction(module.get(), "conditional.18");
  ASSERT_NE(conditional18, nullptr);
  auto* branch_root = conditional18->branch_computation(1)->root_instruction();
  EXPECT_EQ(branch_root->opcode(), HloOpcode::kParameter);
}
3178 
// Regression test for b/189219227. A root instruction that is not scheduled
// last still lives out of its computation, so the copy that follows the root
// in `body` must survive RemoveUnnecessaryCopies.
TEST_F(CopyInsertionTest, RootInstructionNotLast) {
  const string hlo_text = R"(
HloModule module, is_scheduled=true

body2 {
  p_body2 = (f32[2]{0}) parameter(0)
  p_body2.1 = f32[2]{0} get-tuple-element(p_body2), index=0
  add.3 = f32[2]{0} add(p_body2.1, p_body2.1)
  ROOT root2 = (f32[2]{0}) tuple(add.3)
}

condition2 {
  p_cond2 = (f32[2]{0}) parameter(0)
  ROOT result = pred[] constant(true)
}

body {
  p_body = (f32[2]{0}) parameter(0)
  p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0
  ROOT root = (f32[2]{0}) tuple(p_body.1)
  copy = f32[2]{0} copy(p_body.1)
  tuple = (f32[2]{0}) tuple(copy)
  while.1 = (f32[2]{0}) while(tuple), condition=condition2, body=body2
}

condition {
  p_cond = (f32[2]{0}) parameter(0)
  ROOT result = pred[] constant(true)
}

ENTRY entry {
  const0 = f32[2]{0} constant({1, 2})
  while_init = (f32[2]{0}) tuple(const0)
  ROOT while.0 = (f32[2]{0}) while(while_init), condition=condition, body=body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  CopyInsertion pass(nullptr,
                     /*use_region_based_live_range_analysis=*/true);
  // Copy removal is driven by the schedule already attached to the module.
  SequentialHloOrdering schedule_order(hlo_module->schedule());
  ASSERT_IS_OK(pass.RemoveUnnecessaryCopies(&schedule_order, hlo_module.get()));
  // The inner while must still be fed by tuple(copy(...)).
  auto* inner_while = FindInstruction(hlo_module.get(), "while.1");
  EXPECT_THAT(inner_while, op::While(op::Tuple(op::Copy())));
}
3227 
// Parses a module containing two collective-permute instructions that are
// given the same `tuple.input` and `tuple.output` operands (they differ only in
// their last offsets operand, `tuple.9` vs. `tuple.10`), runs the fixture's
// InsertCopies on it, and expects CountCopies to report exactly 4 for the
// module.
// NOTE(review): the expected count of 4 presumably corresponds to one copy per
// output buffer of each collective-permute, since both write into the shared
// `tuple.output` -- confirm against the CopyInsertion implementation.
TEST_F(CopyInsertionTest, InPlaceCollectivePermuteCopy) {
  absl::string_view hlo_string = R"(
HloModule hlo_runner_test_0.1
ENTRY hlo_runner_test_0.1 {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
    tuple.7 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.2)
    tuple.8 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
    tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
    tuple.10 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
    collective-permute.0 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
    collective-permute.1 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.10), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
    ROOT tuple = ((u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}), (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)})) tuple(collective-permute.0, collective-permute.1)
  }
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Run copy insertion through the test fixture helper, then count the kCopy
  // instructions that remain in the module.
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 4);
}
3261 
3262 }  // namespace
3263 }  // namespace xla
3264