• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17 
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
26 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/test.h"
36 
37 namespace xla {
38 namespace {
39 
40 using ::testing::ElementsAre;
41 using ::testing::UnorderedElementsAre;
42 
43 // Test is parameterized on a bool which is whether the dataflow analysis is
44 // performed with SSA form.
45 class HloDataflowAnalysisTest : public HloTestBase,
46                                 public ::testing::WithParamInterface<bool> {
47  protected:
HloDataflowAnalysisTest()48   HloDataflowAnalysisTest() : module_(CreateNewVerifiedModule()) {}
49 
50   // Run dataflow analysis on the member module. For convenience returns a
51   // reference to the generated analysis stored in analysis_.
RunAnalysis(bool ssa_form,bool bitcast_defines_value=false)52   const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
53                                          bool bitcast_defines_value = false) {
54     FlattenCallGraph flatten;
55     EXPECT_TRUE(flatten.Run(module_.get()).ok());
56     analysis_ =
57         HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value)
58             .ConsumeValueOrDie();
59     return *analysis_;
60   }
61 
62   // Return a vector of the HloValues at the given program position.
HloValuesAt(const HloInstruction * instruction,const ShapeIndex & index={})63   std::vector<HloValue> HloValuesAt(const HloInstruction* instruction,
64                                     const ShapeIndex& index = {}) {
65     CHECK(analysis_ != nullptr);
66     std::vector<HloValue> values;
67     for (const HloValue* value :
68          analysis_->GetValueSet(instruction, index).values()) {
69       values.push_back(*value);
70     }
71     return values;
72   }
73 
74   // Returns true if the top-level values for instructions 'a' and 'b' may
75   // interfere. Precondition: 'a' and 'b' define array-shaped values.
InstructionsMayInterfere(const HloOrdering & ordering,const HloInstruction * a,const HloInstruction * b)76   bool InstructionsMayInterfere(const HloOrdering& ordering,
77                                 const HloInstruction* a,
78                                 const HloInstruction* b) {
79     EXPECT_FALSE(a->shape().IsTuple());
80     EXPECT_FALSE(b->shape().IsTuple());
81     return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
82                                  analysis_->GetValueDefinedAt(b), *analysis_);
83   }
84 
CreateR0F32UnaryOpComputation(HloOpcode opcode)85   std::unique_ptr<HloComputation> CreateR0F32UnaryOpComputation(
86       HloOpcode opcode) {
87     HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode));
88     HloInstruction* param0 = builder.AddInstruction(
89         HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
90     builder.AddInstruction(
91         HloInstruction::CreateUnary(scalar_shape_, opcode, param0));
92     return builder.Build();
93   }
94 
95   std::unique_ptr<HloModule> module_;
96   std::unique_ptr<HloDataflowAnalysis> analysis_;
97 
98   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
99   const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42});
100   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
101       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
102 };
103 
TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
  // Smallest interesting graph: an Add whose operands are two scalar
  // constants.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, lhs, rhs));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // One value per instruction.
  EXPECT_EQ(analysis.values().size(), 3);
  EXPECT_TRUE(analysis.ValueIsDefinedAt(lhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(rhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sum));

  const HloValue& lhs_value = analysis.GetValueDefinedAt(lhs);
  const HloValue& rhs_value = analysis.GetValueDefinedAt(rhs);
  const HloValue& sum_value = analysis.GetValueDefinedAt(sum);

  // Positions are trivial: no instruction in this graph forwards a value, so
  // each value lives only where it is defined.
  EXPECT_THAT(lhs_value.positions(),
              UnorderedElementsAre(HloPosition{lhs, {}}));
  EXPECT_THAT(rhs_value.positions(),
              UnorderedElementsAre(HloPosition{rhs, {}}));
  EXPECT_THAT(sum_value.positions(),
              UnorderedElementsAre(HloPosition{sum, {}}));

  // Each constant feeds exactly one operand of the add; the add has no users.
  EXPECT_THAT(lhs_value.uses(), UnorderedElementsAre(HloUse{sum, 0, {}}));
  EXPECT_THAT(rhs_value.uses(), UnorderedElementsAre(HloUse{sum, 1, {}}));
  EXPECT_TRUE(sum_value.uses().empty());

  // Only the add's value is live out of the module.
  EXPECT_FALSE(lhs_value.live_out_of_module());
  EXPECT_FALSE(rhs_value.live_out_of_module());
  EXPECT_TRUE(sum_value.live_out_of_module());
}
146 
TEST_P(HloDataflowAnalysisTest,TupleAndGtes)147 TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
148   // Verify the dataflow through a Tuple and GetTupleElement instructions.
149   auto builder = HloComputation::Builder(TestName());
150   auto param0 = builder.AddInstruction(
151       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
152   auto param1 = builder.AddInstruction(
153       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
154   auto tuple =
155       builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
156   auto gte0 = builder.AddInstruction(
157       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 0));
158   auto gte1 = builder.AddInstruction(
159       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
160   auto add = builder.AddInstruction(
161       HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1));
162   module_->AddEntryComputation(builder.Build());
163   SCOPED_TRACE(module_->ToString());
164 
165   bool ssa_form = GetParam();
166   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
167 
168   // The two params, tuple, and add should each define one value.
169   EXPECT_EQ(analysis.values().size(), 4);
170 
171   EXPECT_TRUE(analysis.ValueIsDefinedAt(param0));
172   EXPECT_TRUE(analysis.ValueIsDefinedAt(param1));
173   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{}));
174   EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0}));
175   EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1}));
176   EXPECT_FALSE(analysis.ValueIsDefinedAt(gte0));
177   EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1));
178   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
179 
180   // Verify the positions of the values.
181   EXPECT_THAT(
182       analysis.GetValueDefinedAt(param0).positions(),
183       UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
184                            HloPosition{gte0, {}}));
185   EXPECT_THAT(
186       analysis.GetValueDefinedAt(param1).positions(),
187       UnorderedElementsAre(HloPosition{param1, {}}, HloPosition{tuple, {1}},
188                            HloPosition{gte1, {}}));
189   EXPECT_THAT(analysis.GetValueDefinedAt(tuple).positions(),
190               UnorderedElementsAre(HloPosition{tuple, {}}));
191 
192   // Verify uses. Of interest is that a GetTupleElement instruction is only a
193   // use of the top-level value in the tuple operand.
194   EXPECT_THAT(analysis.GetValueDefinedAt(param0).uses(),
195               UnorderedElementsAre(HloUse{add, 0, {}}));
196   EXPECT_THAT(analysis.GetValueDefinedAt(param1).uses(),
197               UnorderedElementsAre(HloUse{add, 1, {}}));
198   EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(),
199               UnorderedElementsAre(HloUse{gte0, 0, {}}, HloUse{gte1, 0, {}}));
200   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
201 }
202 
TEST_P(HloDataflowAnalysisTest, NestedTuple) {
  // Dataflow through a tuple nested inside another tuple and then unpacked
  // again by two GetTupleElement instructions.
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* outer =
      builder.AddInstruction(HloInstruction::CreateTuple({inner, inner, c1}));
  HloInstruction* gte_inner = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner->shape(), outer, 1));
  HloInstruction* gte_scalar = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, gte_inner, 0));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 4);

  // The first constant is reachable at every position it was forwarded to.
  EXPECT_THAT(
      analysis.GetValueDefinedAt(c1).positions(),
      UnorderedElementsAre(HloPosition{c1, {}}, HloPosition{inner, {0}},
                           HloPosition{outer, {0, 0}},
                           HloPosition{outer, {1, 0}}, HloPosition{outer, {2}},
                           HloPosition{gte_inner, {0}},
                           HloPosition{gte_scalar, {}}));
  // The first constant has a single use, at the final GetTupleElement (which
  // reads it at index {0} of its operand); the second constant is never used.
  EXPECT_THAT(analysis.GetValueDefinedAt(c1, /*index=*/{}).uses(),
              UnorderedElementsAre(HloUse{gte_scalar, 0, {0}}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(c2).uses().empty());

  // The top-level tuple values are used by the GetTupleElement instructions.
  EXPECT_THAT(analysis.GetValueDefinedAt(inner, /*index=*/{}).uses(),
              UnorderedElementsAre(HloUse{gte_scalar, 0, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(outer, /*index=*/{}).uses(),
              UnorderedElementsAre(HloUse{gte_inner, 0, {}}));

  // Only the first constant is live out of the module.
  EXPECT_TRUE(analysis.GetValueDefinedAt(c1).live_out_of_module());
  EXPECT_FALSE(analysis.GetValueDefinedAt(c2).live_out_of_module());
  EXPECT_FALSE(
      analysis.GetValueDefinedAt(inner, /*index=*/{}).live_out_of_module());
  EXPECT_FALSE(
      analysis.GetValueDefinedAt(outer, /*index=*/{}).live_out_of_module());
}
253 
TEST_P(HloDataflowAnalysisTest, SingleCall) {
  // One call of a subcomputation that adds its two array-shaped parameters.
  HloComputation::Builder callee_builder("Subcomputation");
  HloInstruction* callee_p0 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* callee_p1 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* callee_add =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, callee_p0, callee_p1));
  HloComputation* callee =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* call_site = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, callee));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 3);

  // Neither the callee's parameters nor the call instruction define values;
  // their values flow in from the constants and from the add respectively.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c2));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p0));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(callee_add));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(call_site));

  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);
  const HloValue& c2_value = analysis.GetValueDefinedAt(c2);
  const HloValue& add_value = analysis.GetValueDefinedAt(callee_add);

  EXPECT_EQ(analysis.GetUniqueValueAt(callee_p0), c1_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(callee_p1), c2_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(call_site), add_value);

  // Each constant is used both by the call and by the add inside the callee.
  EXPECT_THAT(c1_value.uses(),
              UnorderedElementsAre(HloUse{call_site, 0, {}},
                                   HloUse{callee_add, 0, {}}));
  EXPECT_THAT(c2_value.uses(),
              UnorderedElementsAre(HloUse{call_site, 1, {}},
                                   HloUse{callee_add, 1, {}}));

  EXPECT_TRUE(add_value.live_out_of_module());
}
304 
TEST_P(HloDataflowAnalysisTest,NestedCalls)305 TEST_P(HloDataflowAnalysisTest, NestedCalls) {
306   // Test a module with nested computations. HLO is:
307   //
308   // F32[] inner_computation(F32[] %param0, F32[] %param1):
309   //   %add = Add(%param0, %param1)
310   //
311   // F32[] outer_computation((F32[] %param0, F32[] %param1):
312   //  ;; Note that parameters are interchanged in the call.
313   //   %nested_call = Call(inner_computation, {%param1, %param0})
314   //
315   // F32[] entry:
316   //   %constant1 = Constant(1.0)
317   //   %constant2 = Constant(2.0)
318   //   %call = Call(outer_computation, {%constant1, %constant2})
319   //
320   auto inner_builder = HloComputation::Builder("InnerComputation");
321   auto inner_param0 = inner_builder.AddInstruction(
322       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
323   auto inner_param1 = inner_builder.AddInstruction(
324       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
325   auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
326       scalar_shape_, HloOpcode::kAdd, inner_param0, inner_param1));
327   HloComputation* inner_computation =
328       module_->AddEmbeddedComputation(inner_builder.Build());
329 
330   auto outer_builder = HloComputation::Builder("OuterComputation");
331   auto outer_param0 = outer_builder.AddInstruction(
332       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
333   auto outer_param1 = outer_builder.AddInstruction(
334       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
335   // Swizzle parameters.
336   auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall(
337       scalar_shape_, {outer_param1, outer_param0}, inner_computation));
338   HloComputation* outer_computation =
339       module_->AddEmbeddedComputation(outer_builder.Build());
340 
341   auto builder = HloComputation::Builder(TestName());
342   auto constant1 = builder.AddInstruction(
343       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
344   auto constant2 = builder.AddInstruction(
345       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
346   auto call = builder.AddInstruction(HloInstruction::CreateCall(
347       scalar_shape_, {constant1, constant2}, outer_computation));
348   module_->AddEntryComputation(builder.Build());
349   SCOPED_TRACE(module_->ToString());
350 
351   bool ssa_form = GetParam();
352   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
353 
354   // Only three values should be defined. Most instructions just pass through
355   // their operand values.
356   EXPECT_EQ(analysis.values().size(), 3);
357 
358   // Verify that the uses of the constants are properly swizzled by parameter
359   // permutation in nested_call.
360   EXPECT_THAT(
361       analysis.GetValueDefinedAt(constant1).uses(),
362       UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}},
363                            HloUse{add, 1, {}}));
364   EXPECT_THAT(
365       analysis.GetValueDefinedAt(constant2).uses(),
366       UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}},
367                            HloUse{add, 0, {}}));
368 
369   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
370 }
371 
TEST_P(HloDataflowAnalysisTest,SingleWhile)372 TEST_P(HloDataflowAnalysisTest, SingleWhile) {
373   // Test a simple single while instruction. The while body includes a
374   // pass-through value. HLO:
375   //
376   // body((F32[], F32[]) %tuple_param):
377   //   %add = Add(%tuple_param{0}, %tuple_param{1})
378   //   return Tuple(%tuple_param{0}, %add)
379   //
380   // condition((F32[], F32[]) %tuple_param):
381   //   return Constant(false)
382   //
383   // entry:
384   //   %constant1 = Constant(1.0)
385   //   %constant2 = Constant(2.0)
386   //   %tuple = Tuple(%constant1, %constant2)
387   //   return While(%tuple, body, condition)
388   //
389   const Shape tuple_shape =
390       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
391 
392   // Element 0 passes transparently through the body.
393   auto body_builder = HloComputation::Builder("body");
394   auto body_param = body_builder.AddInstruction(
395       HloInstruction::CreateParameter(0, tuple_shape, "param"));
396   auto body_element_0 = body_builder.AddInstruction(
397       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
398   auto body_element_1 = body_builder.AddInstruction(
399       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
400   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
401       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
402   auto body_root = body_builder.AddInstruction(
403       HloInstruction::CreateTuple({body_element_0, add}));
404   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
405 
406   // Condition computation trivially returns a constant "false".
407   auto cond_builder = HloComputation::Builder("condition");
408   auto cond_param = cond_builder.AddInstruction(
409       HloInstruction::CreateParameter(0, tuple_shape, "param"));
410   auto cond_constant = cond_builder.AddInstruction(
411       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
412   HloComputation* condition =
413       module_->AddEmbeddedComputation(cond_builder.Build());
414 
415   auto builder = HloComputation::Builder(TestName());
416   auto constant1 = builder.AddInstruction(
417       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
418   auto constant2 = builder.AddInstruction(
419       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
420   auto tuple = builder.AddInstruction(
421       HloInstruction::CreateTuple({constant1, constant2}));
422   auto xla_while = builder.AddInstruction(
423       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
424   module_->AddEntryComputation(builder.Build());
425   SCOPED_TRACE(module_->ToString());
426 
427   bool ssa_form = GetParam();
428   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
429 
430   EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
431 
432   if (ssa_form) {
433     // Element 0 of the tuple passed through the body so no phi value is
434     // defined.
435     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
436     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
437     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
438 
439     // Element 1 of the tuple should be a phi value.
440     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
441     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
442     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
443     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
444     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
445     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
446 
447     EXPECT_THAT(
448         analysis.GetValueDefinedAt(constant1).uses(),
449         UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{body_root, 0, {}},
450                              HloUse{xla_while, 0, {0}}));
451 
452     // Constant1 passes through the body and out of the module.
453     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
454     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
455                     .live_out_of_module());
456 
457     EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
458   } else {
459     // While instruction and subcomputation parameters should not define values
460     // in non-ssa form.
461     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
462     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
463     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
464     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
465     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
466     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
467 
468     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
469     EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
470   }
471 }
472 
TEST_P(HloDataflowAnalysisTest,SequentialWhiles)473 TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
474   // Test sequential while instructions. The while body includes a
475   // pass-through value. HLO:
476   //
477   // body((F32[], F32[]) %tuple_param):
478   //   %add = Add(%tuple_param{0}, %tuple_param{1})
479   //   return Tuple(%tuple_param{0}, %add)
480   //
481   // condition((F32[], F32[]) %tuple_param):
482   //   return Constant(false)
483   //
484   // entry:
485   //   %constant1 = Constant(1.0)
486   //   %constant2 = Constant(2.0)
487   //   %tuple = Tuple(%constant1, %constant2)
488   //   %while0 = While(%tuple, body, condition)
489   //   %while1 = While(%while0, body, condition)
490   //   return While(%while1, body, condition)
491   //
492   const Shape tuple_shape =
493       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
494 
495   // Element 0 passes transparently through the body.
496   auto body_builder = HloComputation::Builder("body");
497   auto body_param = body_builder.AddInstruction(
498       HloInstruction::CreateParameter(0, tuple_shape, "param"));
499   auto body_element_0 = body_builder.AddInstruction(
500       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
501   auto body_element_1 = body_builder.AddInstruction(
502       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
503   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
504       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
505   body_builder.AddInstruction(
506       HloInstruction::CreateTuple({body_element_0, add}));
507   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
508 
509   auto cond_builder = HloComputation::Builder("condition");
510   cond_builder.AddInstruction(
511       HloInstruction::CreateParameter(0, tuple_shape, "param"));
512   cond_builder.AddInstruction(
513       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
514   HloComputation* condition =
515       module_->AddEmbeddedComputation(cond_builder.Build());
516 
517   auto builder = HloComputation::Builder(TestName());
518   auto constant1 = builder.AddInstruction(
519       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
520   auto constant2 = builder.AddInstruction(
521       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
522   auto tuple = builder.AddInstruction(
523       HloInstruction::CreateTuple({constant1, constant2}));
524   auto xla_while0 = builder.AddInstruction(
525       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
526   auto xla_while1 = builder.AddInstruction(
527       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0));
528   auto xla_while2 = builder.AddInstruction(
529       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1));
530   module_->AddEntryComputation(builder.Build());
531   SCOPED_TRACE(module_->ToString());
532 
533   bool ssa_form = GetParam();
534   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
535 
536   // Element 0 is passed through all the while instructions and out of the
537   // module..
538   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
539             analysis.GetValueDefinedAt(constant1));
540   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
541             analysis.GetValueDefinedAt(constant1));
542   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
543             analysis.GetValueDefinedAt(constant1));
544   EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
545 }
546 
TEST_P(HloDataflowAnalysisTest,MultiLevelNestedWhile)547 TEST_P(HloDataflowAnalysisTest, MultiLevelNestedWhile) {
548   // Test nested while instructions. The level0 body (most inner while) and
549   // level1 body pass through the parameter, while level2 (most outer while)
550   // modifies it.
551   //
552   // level0_body((F32[]) %tuple_param):
553   //   return Tuple(%tuple_param{0})
554   //
555   // level1_body((F32[]) %tuple_param):
556   //   return While(%tuple_param{0}), body=level0
557   //
558   // level2_body((F32[]) %tuple_param):
559   //   while = While(%tuple_param{0}), body=level1
560   //.  return negate(%while{0})
561   //
562   // entry:
563   //   %constant = Constant(1.0)
564   //   %tuple = Tuple(%constant)
565   //   return While(%tuple), body=level2
566   //
567   const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_});
568   auto cond_builder = HloComputation::Builder("condition");
569   cond_builder.AddInstruction(
570       HloInstruction::CreateParameter(0, tuple_shape, "param"));
571   cond_builder.AddInstruction(
572       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
573   HloComputation* condition =
574       module_->AddEmbeddedComputation(cond_builder.Build());
575 
576   // level 0 passes transparently through the body.
577   auto level0_builder = HloComputation::Builder("level0_body");
578   auto level0_param = level0_builder.AddInstruction(
579       HloInstruction::CreateParameter(0, tuple_shape, "param"));
580   auto level0_element_0 = level0_builder.AddInstruction(
581       HloInstruction::CreateGetTupleElement(scalar_shape_, level0_param, 0));
582   auto level0_root = level0_builder.AddInstruction(
583       HloInstruction::CreateTuple({level0_element_0}));
584   HloComputation* level0_body =
585       module_->AddEmbeddedComputation(level0_builder.Build());
586 
587   // Element 1 passes transparently through the body.
588   auto level1_builder = HloComputation::Builder("level1_body");
589   auto level1_param = level1_builder.AddInstruction(
590       HloInstruction::CreateParameter(0, tuple_shape, "param"));
591   auto level1_root = level1_builder.AddInstruction(HloInstruction::CreateWhile(
592       tuple_shape, condition, level0_body, level1_param));
593   HloComputation* level1_body =
594       module_->AddEmbeddedComputation(level1_builder.Build());
595 
596   // Element 1 passes transparently through the body.
597   auto level2_builder = HloComputation::Builder("level2_body");
598   auto level2_param = level2_builder.AddInstruction(
599       HloInstruction::CreateParameter(0, tuple_shape, "param"));
600   auto level2_while = level2_builder.AddInstruction(HloInstruction::CreateWhile(
601       tuple_shape, condition, level1_body, level2_param));
602   auto level2_element_0 = level2_builder.AddInstruction(
603       HloInstruction::CreateGetTupleElement(scalar_shape_, level2_while, 0));
604   auto negate = level2_builder.AddInstruction(HloInstruction::CreateUnary(
605       scalar_shape_, HloOpcode::kNegate, level2_element_0));
606   level2_builder.AddInstruction(HloInstruction::CreateTuple({negate}));
607   HloComputation* level2_body =
608       module_->AddEmbeddedComputation(level2_builder.Build());
609 
610   auto builder = HloComputation::Builder(TestName());
611   auto constant1 = builder.AddInstruction(
612       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
613   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
614   builder.AddInstruction(
615       HloInstruction::CreateWhile(tuple_shape, condition, level2_body, tuple));
616   module_->AddEntryComputation(builder.Build());
617   SCOPED_TRACE(module_->ToString());
618 
619   bool ssa_form = GetParam();
620   if (!ssa_form) {
621     return;
622   }
623   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
624 
625   // Phi node on inner parameters and roots should have been eliminated.
626   EXPECT_FALSE(analysis.ValueIsDefinedAt(level1_param, /*index=*/{0}));
627   EXPECT_FALSE(analysis.ValueIsDefinedAt(level0_param, /*index=*/{0}));
628   EXPECT_FALSE(analysis.ValueIsDefinedAt(level1_root, /*index=*/{0}));
629   EXPECT_FALSE(analysis.ValueIsDefinedAt(level0_root, /*index=*/{0}));
630   EXPECT_TRUE(analysis.ValueIsDefinedAt(level2_param, /*index=*/{0}));
631   EXPECT_EQ(HloValuesAt(level1_param, /*index=*/{0}),
632             HloValuesAt(level2_param, /*index=*/{0}));
633   EXPECT_EQ(HloValuesAt(level0_param, /*index=*/{0}),
634             HloValuesAt(level2_param, /*index=*/{0}));
635   EXPECT_EQ(HloValuesAt(level1_root, /*index=*/{0}),
636             HloValuesAt(level2_param, /*index=*/{0}));
637   EXPECT_EQ(HloValuesAt(level0_root, /*index=*/{0}),
638             HloValuesAt(level2_param, /*index=*/{0}));
639 }
640 
TEST_P(HloDataflowAnalysisTest,NestedWhiles)641 TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
642   // Test nested while instructions. The inner body passes through element 0 of
643   // its parameter, and the outer body passes through element 1.  HLO:
644   //
645   // inner_body((F32[], F32[]) %tuple_param):
646   //   %add = Add(%tuple_param{0}, %tuple_param{1})
647   //   return Tuple(%tuple_param{0}, %add)
648   //
649   // outer_body((F32[], F32[]) %tuple_param):
650   //   %negate = Negate(%tuple_param{0})
651   //   %tuple = Tuple(%negate, %tuple_param{1})
652   //   return While(%tuple, inner_body, condition)
653   //
654   // entry:
655   //   %constant1 = Constant(1.0)
656   //   %constant2 = Constant(2.0)
657   //   %tuple = Tuple(%constant1, %constant2)
658   //   return While(%tuple, outer_body, condition)
659   //
660   const Shape tuple_shape =
661       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
662 
663   auto cond_builder = HloComputation::Builder("condition");
664   cond_builder.AddInstruction(
665       HloInstruction::CreateParameter(0, tuple_shape, "param"));
666   cond_builder.AddInstruction(
667       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
668   HloComputation* condition =
669       module_->AddEmbeddedComputation(cond_builder.Build());
670 
671   // Element 0 passes transparently through the body.
672   auto inner_builder = HloComputation::Builder("inner_body");
673   auto inner_param = inner_builder.AddInstruction(
674       HloInstruction::CreateParameter(0, tuple_shape, "param"));
675   auto inner_element_0 = inner_builder.AddInstruction(
676       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0));
677   auto inner_element_1 = inner_builder.AddInstruction(
678       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1));
679   auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
680       scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1));
681   inner_builder.AddInstruction(
682       HloInstruction::CreateTuple({inner_element_0, add}));
683   HloComputation* inner_body =
684       module_->AddEmbeddedComputation(inner_builder.Build());
685 
686   // Element 1 passes transparently through the body.
687   auto outer_builder = HloComputation::Builder("outer_body");
688   auto outer_param = outer_builder.AddInstruction(
689       HloInstruction::CreateParameter(0, tuple_shape, "param"));
690   auto outer_element_0 = outer_builder.AddInstruction(
691       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
692   auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
693       scalar_shape_, HloOpcode::kNegate, outer_element_0));
694   auto outer_element_1 = outer_builder.AddInstruction(
695       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
696   auto outer_tuple = outer_builder.AddInstruction(
697       HloInstruction::CreateTuple({negate, outer_element_1}));
698   auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
699       tuple_shape, condition, inner_body, outer_tuple));
700   HloComputation* outer_body =
701       module_->AddEmbeddedComputation(outer_builder.Build());
702 
703   auto builder = HloComputation::Builder(TestName());
704   auto constant1 = builder.AddInstruction(
705       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
706   auto constant2 = builder.AddInstruction(
707       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
708   auto tuple = builder.AddInstruction(
709       HloInstruction::CreateTuple({constant1, constant2}));
710   auto entry_while = builder.AddInstruction(
711       HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple));
712   module_->AddEntryComputation(builder.Build());
713   SCOPED_TRACE(module_->ToString());
714 
715   bool ssa_form = GetParam();
716   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
717 
718   EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
719               UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
720   if (ssa_form) {
721     EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
722     EXPECT_TRUE(
723         analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
724 
725     // Element 0 of the nested while is %negate.
726     EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
727     EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
728                 UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
729     // Element 1 is a phi value (join of %add and %constant2).
730     EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
731     EXPECT_TRUE(
732         analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());
733 
734     EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{0}));
735     EXPECT_TRUE(
736         analysis.GetValueDefinedAt(entry_while, /*index=*/{0}).is_phi());
737 
738     EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{1}));
739     EXPECT_TRUE(
740         analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
741   } else {
742     EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
743                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
744                                      analysis.GetValueDefinedAt(constant2)));
745 
746     EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}),
747                 UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
748     EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{1}),
749                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
750                                      analysis.GetValueDefinedAt(constant2)));
751 
752     EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{0}),
753                 UnorderedElementsAre(analysis.GetValueDefinedAt(negate),
754                                      analysis.GetValueDefinedAt(constant1)));
755     EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{1}),
756                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
757                                      analysis.GetValueDefinedAt(constant2)));
758   }
759 }
760 
TEST_P(HloDataflowAnalysisTest, SwizzlingWhileSharedInput) {
  // Exercises a while loop whose body swaps the two elements of its tuple
  // parameter, where both elements of the init tuple are the same constant.
  // HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   return Tuple(%tuple_param{1}, %tuple_param{0})
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %tuple = Tuple(%constant1, %constant1)
  //   return While(%tuple, body, condition)
  //
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: emit the parameter's elements in swapped order.
  HloComputation::Builder swap_builder("body");
  HloInstruction* swap_param = swap_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  HloInstruction* swap_gte0 = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, swap_param, 0));
  HloInstruction* swap_gte1 = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, swap_param, 1));
  swap_builder.AddInstruction(
      HloInstruction::CreateTuple({swap_gte1, swap_gte0}));
  HloComputation* swap_body =
      module_->AddEmbeddedComputation(swap_builder.Build());

  // Loop condition: a parameter followed by Constant(false).
  HloComputation::Builder pred_builder("condition");
  pred_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  pred_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* loop_condition =
      module_->AddEmbeddedComputation(pred_builder.Build());

  // Entry: the same constant feeds both elements of the init tuple.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* shared_constant = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* init_tuple = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({shared_constant, shared_constant}));
  entry_builder.AddInstruction(HloInstruction::CreateWhile(
      pair_shape, loop_condition, swap_body, init_tuple));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& dataflow = RunAnalysis(GetParam());

  // Regardless of analysis mode, no value is defined at element 0 of the body
  // parameter.
  EXPECT_FALSE(dataflow.ValueIsDefinedAt(swap_param, /*index=*/{0}));
}
812 
TEST_P(HloDataflowAnalysisTest,SwizzlingWhile)813 TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
814   // Test a while instruction with a body which permutes it's tuple parameter
815   // elements. HLO:
816   //
817   // body((F32[], F32[]) %tuple_param):
818   //   return Tuple(%tuple_param{1}, %tuple_param{0})
819   //
820   // condition((F32[], F32[]) %tuple_param):
821   //   return Constant(false)
822   //
823   // entry:
824   //   %constant1 = Constant(1.0)
825   //   %constant2 = Constant(2.0)
826   //   %tuple = Tuple(%constant1, %constant2)
827   //   return While(%tuple, body, condition)
828   //
829   const Shape tuple_shape =
830       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
831 
832   auto body_builder = HloComputation::Builder("body");
833   auto body_param = body_builder.AddInstruction(
834       HloInstruction::CreateParameter(0, tuple_shape, "param"));
835   auto body_element_0 = body_builder.AddInstruction(
836       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
837   auto body_element_1 = body_builder.AddInstruction(
838       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
839   body_builder.AddInstruction(
840       HloInstruction::CreateTuple({body_element_1, body_element_0}));
841   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
842 
843   auto cond_builder = HloComputation::Builder("condition");
844   auto cond_param = cond_builder.AddInstruction(
845       HloInstruction::CreateParameter(0, tuple_shape, "param"));
846   cond_builder.AddInstruction(
847       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
848   HloComputation* condition =
849       module_->AddEmbeddedComputation(cond_builder.Build());
850 
851   auto builder = HloComputation::Builder(TestName());
852   auto constant1 = builder.AddInstruction(
853       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
854   auto constant2 = builder.AddInstruction(
855       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
856   auto tuple = builder.AddInstruction(
857       HloInstruction::CreateTuple({constant1, constant2}));
858   auto xla_while = builder.AddInstruction(
859       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
860   module_->AddEntryComputation(builder.Build());
861   SCOPED_TRACE(module_->ToString());
862 
863   bool ssa_form = GetParam();
864   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
865 
866   if (ssa_form) {
867     // Element 0 and 1 in the while should both be phi values.
868     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
869     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
870     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
871     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
872 
873     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
874     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
875     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
876     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
877 
878     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
879     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
880     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
881     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
882 
883     EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
884     EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
885     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{})
886                     .live_out_of_module());
887     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
888                     .live_out_of_module());
889     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
890                     .live_out_of_module());
891   } else {
892     // Elements 0 and 1 have both constants as reaching definitions.
893     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
894                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
895                                      analysis.GetValueDefinedAt(constant2)));
896     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}),
897                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
898                                      analysis.GetValueDefinedAt(constant2)));
899     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
900     EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
901   }
902 }
903 
TEST_P(HloDataflowAnalysisTest, ArraySelect) {
  // A kSelect over array (scalar) operands: the select defines a value of its
  // own, and only that value is live out of the module.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* predicate = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* on_true = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* on_false = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* array_select =
      entry_builder.AddInstruction(HloInstruction::CreateTernary(
          scalar_shape_, HloOpcode::kSelect, predicate, on_true, on_false));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& dataflow = RunAnalysis(GetParam());

  EXPECT_TRUE(dataflow.ValueIsDefinedAt(array_select));
  EXPECT_FALSE(dataflow.GetValueDefinedAt(on_true).live_out_of_module());
  EXPECT_FALSE(dataflow.GetValueDefinedAt(on_false).live_out_of_module());
  EXPECT_TRUE(dataflow.GetValueDefinedAt(array_select).live_out_of_module());
}
927 
TEST_P(HloDataflowAnalysisTest,TupleSelect)928 TEST_P(HloDataflowAnalysisTest, TupleSelect) {
929   // Test a kTupleSelect. Non-top-level element flow through the instruction.
930   auto builder = HloComputation::Builder(TestName());
931   auto pred = builder.AddInstruction(
932       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
933   auto constant1 = builder.AddInstruction(
934       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
935   auto constant2 = builder.AddInstruction(
936       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
937   auto constant3 = builder.AddInstruction(
938       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
939   auto constant4 = builder.AddInstruction(
940       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4.0)));
941   auto tuple1 =
942       builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
943   auto tuple2 =
944       builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
945   auto tuple3 =
946       builder.AddInstruction(HloInstruction::CreateTuple({constant3}));
947   auto tuple4 =
948       builder.AddInstruction(HloInstruction::CreateTuple({constant4}));
949   const Shape tuple_shape = tuple1->shape();
950   auto select11 = builder.AddInstruction(HloInstruction::CreateTernary(
951       tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1));
952   auto select12 = builder.AddInstruction(HloInstruction::CreateTernary(
953       tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
954   auto select34 = builder.AddInstruction(HloInstruction::CreateTernary(
955       tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4));
956   auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary(
957       tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34));
958 
959   module_->AddEntryComputation(builder.Build());
960   SCOPED_TRACE(module_->ToString());
961 
962   bool ssa_form = GetParam();
963   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
964 
965   // Top-level value is always defined by a kTupleSelect.
966   EXPECT_TRUE(analysis.ValueIsDefinedAt(select11));
967   EXPECT_TRUE(analysis.ValueIsDefinedAt(select12));
968   EXPECT_TRUE(analysis.ValueIsDefinedAt(select34));
969   EXPECT_TRUE(analysis.ValueIsDefinedAt(select1234));
970 
971   EXPECT_FALSE(analysis.ValueIsDefinedAt(select11, /*index=*/{0}));
972   EXPECT_FALSE(analysis.ValueIsDefinedAt(select12, /*index=*/{0}));
973   EXPECT_FALSE(analysis.ValueIsDefinedAt(select34, /*index=*/{0}));
974   EXPECT_FALSE(analysis.ValueIsDefinedAt(select1234, /*index=*/{0}));
975 
976   EXPECT_THAT(HloValuesAt(select11, /*index=*/{0}),
977               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1)));
978   EXPECT_THAT(HloValuesAt(select12, /*index=*/{0}),
979               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
980                                    analysis.GetValueDefinedAt(constant2)));
981   EXPECT_THAT(HloValuesAt(select34, /*index=*/{0}),
982               UnorderedElementsAre(analysis.GetValueDefinedAt(constant3),
983                                    analysis.GetValueDefinedAt(constant4)));
984   EXPECT_THAT(HloValuesAt(select1234, /*index=*/{0}),
985               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
986                                    analysis.GetValueDefinedAt(constant2),
987                                    analysis.GetValueDefinedAt(constant3),
988                                    analysis.GetValueDefinedAt(constant4)));
989 
990   EXPECT_THAT(
991       analysis.GetValueDefinedAt(tuple1, /*index=*/{}).uses(),
992       UnorderedElementsAre(HloUse{select11, 1, {}}, HloUse{select11, 2, {}},
993                            HloUse{select12, 1, {}}));
994 
995   // The two constant values just pass through the Selects and are not
996   // used except at the root. They are live out however.
997   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
998               UnorderedElementsAre(HloUse{select1234, 1, {0}}));
999   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
1000               UnorderedElementsAre(HloUse{select1234, 1, {0}}));
1001   EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
1002   EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
1003 }
1004 
TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
  // kTupleSelect between two nested tuples of the form (scalar, (scalar,
  // scalar)). The two inner tuples share their second element.
  HloComputation::Builder entry_builder(TestName());

  // Appends one F32 scalar constant to the entry computation.
  auto add_scalar = [&entry_builder](float value) {
    return entry_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };

  HloInstruction* predicate = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* c1 = add_scalar(1.0f);
  HloInstruction* c2 = add_scalar(2.0f);
  HloInstruction* c3 = add_scalar(3.0f);
  HloInstruction* c4 = add_scalar(4.0f);
  HloInstruction* c5 = add_scalar(5.0f);

  // lhs = (c1, (c2, c3)), rhs = (c4, (c5, c3)).
  HloInstruction* lhs_inner =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({c2, c3}));
  HloInstruction* lhs = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({c1, lhs_inner}));
  HloInstruction* rhs_inner =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({c5, c3}));
  HloInstruction* rhs = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({c4, rhs_inner}));
  HloInstruction* nested_select =
      entry_builder.AddInstruction(HloInstruction::CreateTernary(
          lhs->shape(), HloOpcode::kTupleSelect, predicate, lhs, rhs));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& dataflow = RunAnalysis(GetParam());

  EXPECT_TRUE(dataflow.ValueIsDefinedAt(nested_select));

  // At every nested index the select holds the values of both operands.
  EXPECT_THAT(HloValuesAt(nested_select, /*index=*/{0}),
              UnorderedElementsAre(dataflow.GetValueDefinedAt(c1),
                                   dataflow.GetValueDefinedAt(c4)));
  EXPECT_THAT(HloValuesAt(nested_select, /*index=*/{1}),
              UnorderedElementsAre(dataflow.GetValueDefinedAt(lhs_inner),
                                   dataflow.GetValueDefinedAt(rhs_inner)));
  EXPECT_THAT(HloValuesAt(nested_select, /*index=*/{1, 0}),
              UnorderedElementsAre(dataflow.GetValueDefinedAt(c2),
                                   dataflow.GetValueDefinedAt(c5)));
  // The shared constant appears only once in the set.
  EXPECT_THAT(HloValuesAt(nested_select, /*index=*/{1, 1}),
              UnorderedElementsAre(dataflow.GetValueDefinedAt(c3)));
}
1051 
TEST_P(HloDataflowAnalysisTest,TupleSelectToWhile)1052 TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
1053   // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO:
1054   //
1055   // body((F32[], F32[]) %tuple_param):
1056   //   %add = Add(%tuple_param{0}, %tuple_param{1})
1057   //   return Tuple(%tuple_param{0}, %add)
1058   //
1059   // condition((F32[], F32[]) %tuple_param):
1060   //   return Constant(false)
1061   //
1062   // entry:
1063   //   %constant1 = Constant(1.0)
1064   //   %constant2 = Constant(2.0)
1065   //   %constant3 = Constant(3.0)
1066   //   %tuple1 = Tuple(%constant1)
1067   //   %tuple2 = Tuple(%constant2)
1068   //   %select = Select(%tuple1, %tuple2)
1069   //   %gte = GetTupleElement(%select, 0)
1070   //   %tuple = Tuple(%gte, %constant3)
1071   //   return While(%tuple, body, condition)
1072   //
1073   auto builder = HloComputation::Builder(TestName());
1074 
1075   const Shape tuple_shape =
1076       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
1077 
1078   // Element 0 passes transparently through the body.
1079   auto body_builder = HloComputation::Builder("body");
1080   auto body_param = body_builder.AddInstruction(
1081       HloInstruction::CreateParameter(0, tuple_shape, "param"));
1082   auto body_element_0 = body_builder.AddInstruction(
1083       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
1084   auto body_element_1 = body_builder.AddInstruction(
1085       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
1086   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
1087       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
1088   body_builder.AddInstruction(
1089       HloInstruction::CreateTuple({body_element_0, add}));
1090   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
1091 
1092   auto cond_builder = HloComputation::Builder("condition");
1093   cond_builder.AddInstruction(
1094       HloInstruction::CreateParameter(0, tuple_shape, "param"));
1095   cond_builder.AddInstruction(
1096       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1097   HloComputation* condition =
1098       module_->AddEmbeddedComputation(cond_builder.Build());
1099 
1100   auto pred = builder.AddInstruction(
1101       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1102   auto constant1 = builder.AddInstruction(
1103       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1104   auto constant2 = builder.AddInstruction(
1105       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
1106   auto constant3 = builder.AddInstruction(
1107       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
1108   auto tuple1 =
1109       builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
1110   auto tuple2 =
1111       builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
1112   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
1113       tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
1114   auto gte = builder.AddInstruction(
1115       HloInstruction::CreateGetTupleElement(scalar_shape_, select, 0));
1116   auto tuple =
1117       builder.AddInstruction(HloInstruction::CreateTuple({gte, constant3}));
1118   auto xla_while = builder.AddInstruction(
1119       HloInstruction::CreateWhile(tuple->shape(), condition, body, tuple));
1120 
1121   module_->AddEntryComputation(builder.Build());
1122   SCOPED_TRACE(module_->ToString());
1123 
1124   bool ssa_form = GetParam();
1125   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
1126 
1127   if (ssa_form) {
1128     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
1129     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
1130     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
1131     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
1132 
1133     EXPECT_FALSE(analysis.ValueIsDefinedAt(select, /*index=*/{0}));
1134 
1135     EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
1136     EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
1137     EXPECT_FALSE(analysis.GetValueDefinedAt(constant3).live_out_of_module());
1138     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
1139                     .live_out_of_module());
1140   } else {
1141     EXPECT_THAT(HloValuesAt(gte),
1142                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
1143                                      analysis.GetValueDefinedAt(constant2)));
1144     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
1145                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
1146                                      analysis.GetValueDefinedAt(constant2)));
1147     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}),
1148                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
1149                                      analysis.GetValueDefinedAt(constant3)));
1150     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
1151     EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
1152     EXPECT_TRUE(analysis.GetValueDefinedAt(constant3).live_out_of_module());
1153   }
1154 }
1155 
TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) {
  // Runs the analysis twice on Bitcast(Constant), once with each setting of
  // the bitcast_defines_value flag.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* source = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* cast = entry_builder.AddInstruction(
      HloInstruction::CreateBitcast(scalar_shape_, source));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();

  // Flag on: the bitcast defines a second value, and that value (not the
  // constant) is live out of the module.
  {
    const HloDataflowAnalysis& dataflow =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/true);

    EXPECT_EQ(dataflow.values().size(), 2);

    EXPECT_TRUE(dataflow.ValueIsDefinedAt(source));
    EXPECT_TRUE(dataflow.ValueIsDefinedAt(cast));
    EXPECT_FALSE(dataflow.GetValueDefinedAt(source).live_out_of_module());
    EXPECT_TRUE(dataflow.GetValueDefinedAt(cast).live_out_of_module());
  }

  // Flag off: only the constant defines a value, and it is live out.
  {
    const HloDataflowAnalysis& dataflow =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/false);
    EXPECT_EQ(dataflow.values().size(), 1);

    EXPECT_TRUE(dataflow.ValueIsDefinedAt(source));
    EXPECT_FALSE(dataflow.ValueIsDefinedAt(cast));
    EXPECT_TRUE(dataflow.GetValueDefinedAt(source).live_out_of_module());
  }
}
1189 
TEST_P(HloDataflowAnalysisTest, TupleCopy) {
  // A kCopy of a tuple defines only the top-level value; the element values
  // are those of the copied tuple's operands.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* first = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* second = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* pair =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({first, second}));
  HloInstruction* pair_copy = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(pair->shape(), HloOpcode::kCopy, pair));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& dataflow = RunAnalysis(GetParam());

  // Two parameters plus the top-level values of the tuple and the copy.
  EXPECT_EQ(dataflow.values().size(), 4);

  EXPECT_TRUE(dataflow.ValueIsDefinedAt(first));
  EXPECT_TRUE(dataflow.ValueIsDefinedAt(second));
  EXPECT_TRUE(dataflow.ValueIsDefinedAt(pair, /*index=*/{}));
  EXPECT_FALSE(dataflow.ValueIsDefinedAt(pair, /*index=*/{0}));
  EXPECT_FALSE(dataflow.ValueIsDefinedAt(pair, /*index=*/{1}));
  EXPECT_TRUE(dataflow.ValueIsDefinedAt(pair_copy, /*index=*/{}));
  EXPECT_FALSE(dataflow.ValueIsDefinedAt(pair_copy, /*index=*/{0}));
  EXPECT_FALSE(dataflow.ValueIsDefinedAt(pair_copy, /*index=*/{1}));

  // The copy's elements still refer to the parameter values.
  EXPECT_THAT(HloValuesAt(pair_copy, /*index=*/{0}),
              UnorderedElementsAre(dataflow.GetValueDefinedAt(first)));
  EXPECT_THAT(HloValuesAt(pair_copy, /*index=*/{1}),
              UnorderedElementsAre(dataflow.GetValueDefinedAt(second)));
  EXPECT_TRUE(
      dataflow.GetValueDefinedAt(pair_copy, /*index=*/{}).live_out_of_module());
}
1225 
TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) {
  // CopyDone does not define a value of its own: its output is the value
  // defined at element {0} of the CopyStart tuple.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* source = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // CopyStart result shape: (source shape, source shape, U32 scalar).
  const Shape start_shape = ShapeUtil::MakeTupleShape(
      {source->shape(), source->shape(), ShapeUtil::MakeShape(U32, {})});
  HloInstruction* start = entry_builder.AddInstruction(
      HloInstruction::CreateCopyStart(start_shape, source));
  HloInstruction* done = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(source->shape(), HloOpcode::kCopyDone, start));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& dataflow = RunAnalysis(GetParam());

  EXPECT_EQ(dataflow.values().size(), 4);

  // CopyStart defines its top-level value and elements {0} and {2}, but not
  // element {1}.
  EXPECT_TRUE(dataflow.ValueIsDefinedAt(start, /*index=*/{}));
  EXPECT_TRUE(dataflow.ValueIsDefinedAt(start, /*index=*/{0}));
  EXPECT_FALSE(dataflow.ValueIsDefinedAt(start, /*index=*/{1}));
  EXPECT_TRUE(dataflow.ValueIsDefinedAt(start, /*index=*/{2}));

  // CopyDone forwards CopyStart's element {0}, which is live out.
  EXPECT_FALSE(dataflow.ValueIsDefinedAt(done, /*index=*/{}));
  EXPECT_THAT(HloValuesAt(done, /*index=*/{}),
              UnorderedElementsAre(dataflow.GetValueDefinedAt(start, {0})));
  EXPECT_TRUE(
      dataflow.GetValueDefinedAt(start, /*index=*/{0}).live_out_of_module());
}
1257 
TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
  // A Send is expected to forward its data operand into element {0} of its
  // output tuple rather than defining a new value there.
  HloComputation::Builder builder(TestName());
  HloInstruction* payload = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send = builder.AddInstruction(
      HloInstruction::CreateSend(payload, token, /*channel_id=*/0));
  HloInstruction* send_done =
      builder.AddInstruction(HloInstruction::CreateSendDone(send));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 6);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(payload));

  // Send defines its tuple and elements {1} and {2}; element {0} is forwarded.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{2}));

  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));

  // The forwarded element holds exactly the parameter's value.
  EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(payload)));
}
1284 
TEST_P(HloDataflowAnalysisTest, SetDimensionSizeForwardsValue) {
  // SetDimensionSize is expected to pass its data operand's value through
  // without defining a value of its own.
  HloComputation::Builder builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* dimension_size = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(3)));
  HloInstruction* set_dimension_size =
      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          vector_shape_, data, dimension_size, 0));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // Only the parameter and the size constant define values.
  EXPECT_EQ(analysis.values().size(), 2);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(data));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(set_dimension_size));
  EXPECT_TRUE(analysis.GetValueDefinedAt(data).live_out_of_module());
}
1307 
TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
  // A RecvDone is expected to forward element {0} of its Recv operand into
  // element {0} of its own output.
  HloComputation::Builder builder(TestName());
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv = builder.AddInstruction(
      HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0));
  HloInstruction* recv_done =
      builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 7);

  // Recv defines its tuple and all three of its elements.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{2}));

  // RecvDone defines its tuple and element {1}; element {0} is forwarded.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{1}));
  EXPECT_THAT(HloValuesAt(recv_done, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
}
1336 
TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
  // Builds a straight-line chain of elementwise ops:
  //
  //   param --> negate -> exp -> log
  //
  // and checks that no two distinct values in the chain interfere.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate));
  HloInstruction* log = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(/*ssa_form=*/GetParam());

  DependencyHloOrdering ordering(module_.get());

  // The parameter against every later instruction.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, log));

  // The remaining pairs, in both argument orders where listed.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, log));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, log));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, log, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, log, exp));

  // A value is expected to interfere with itself.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, exp));
}
1372 
TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
  // Two entry parameters that are expected to interfere with each other:
  //
  //   param0 --> negate --+
  //                       +--> add
  //   param1 --> exp -----+
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, vector_shape_, "param1"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param0));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param1));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, exp));

  HloComputation* entry = module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(/*ssa_form=*/GetParam());

  // Sequential schedule in which param1 comes after negate.
  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param0, negate, param1, exp, add});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  // Entry parameters are treated as if both were defined at the very start,
  // so they interfere with each other, and param1 interferes with negate even
  // though the schedule places param1 later.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param0, param1));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, add));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, param0));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, add));

  // Both operands of add interfere with one another.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));

  // Neither operand interferes with add itself.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
}
1420 
TEST_P(HloDataflowAnalysisTest,WhileParameters_Sequential)1421 TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
1422   // Similar to MultipleEntryParameters_Sequential, but the parameter is of
1423   // while body computation. Body computation in the sequential order:
1424   //
1425   //  %constant = Constant(...)
1426   //  %exp = Exp(%constant)
1427   //  %param = Param(0)
1428   //  %add = Add(%param, %exp)  ;; Root of body
1429   //  %dead_constant = Constant(...)
1430   //  %dead_negate = Negate(%dead_constant)
1431   //
1432   // %constant and its only use %exp are ordered before 'param'. However, the
1433   // %constant and %param values still interfere because the parameter is
1434   // considered live into the while body.
1435   //
1436   // Similarly, %dead_constant and %dead_negate are ordered after the root of
1437   // the body computation %add. However, %add is liveout of the computation so
1438   // %dead_constant and %add interfere.
1439   auto body_builder = HloComputation::Builder(TestName());
1440   auto body_param = body_builder.AddInstruction(
1441       HloInstruction::CreateParameter(0, scalar_shape_, "body_param"));
1442   auto constant = body_builder.AddInstruction(
1443       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1444   auto exp = body_builder.AddInstruction(
1445       HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
1446   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
1447       scalar_shape_, HloOpcode::kAdd, exp, body_param));
1448   auto dead_constant = body_builder.AddInstruction(
1449       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1450   auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
1451       scalar_shape_, HloOpcode::kNegate, dead_constant));
1452   HloComputation* body = module_->AddEmbeddedComputation(
1453       body_builder.Build(/*root_instruction=*/add));
1454 
1455   auto cond_builder = HloComputation::Builder("condition");
1456   auto cond_param = cond_builder.AddInstruction(
1457       HloInstruction::CreateParameter(0, scalar_shape_, "cond_param"));
1458   auto cond_constant = cond_builder.AddInstruction(
1459       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1460   HloComputation* condition =
1461       module_->AddEmbeddedComputation(cond_builder.Build());
1462 
1463   auto builder = HloComputation::Builder(TestName());
1464   auto param = builder.AddInstruction(
1465       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
1466   auto xla_while = builder.AddInstruction(
1467       HloInstruction::CreateWhile(scalar_shape_, condition, body, param));
1468 
1469   auto entry = module_->AddEntryComputation(builder.Build());
1470   SCOPED_TRACE(module_->ToString());
1471   bool ssa_form = GetParam();
1472   RunAnalysis(ssa_form);
1473 
1474   HloSchedule schedule(module_.get());
1475   schedule.set_sequence(entry, {param, xla_while});
1476   schedule.set_sequence(condition, {cond_param, cond_constant});
1477   // Construct the order such that 'constant' and its use 'exp' are before
1478   // body_param.
1479   schedule.set_sequence(
1480       body, {constant, exp, body_param, add, dead_constant, dead_negate});
1481   TF_ASSERT_OK(schedule.Verify());
1482 
1483   SequentialHloOrdering ordering(schedule);
1484 
1485   // 'add' is live out of the body and will interfere with an later instructions
1486   // such as 'dead_constant' and 'dead_negate'.
1487   EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant));
1488   EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_negate));
1489 
1490   // The remaining checks test phi values defined by body and condition
1491   // parameters which only occur in the SSA form of the analysis.
1492   if (ssa_form) {
1493     // Though the ordering suggests 'constant' and 'param' should not interfere,
1494     // 'param' is live in and thus interferes with any earlier instruction of
1495     // the computation in the order (eg 'constant')'
1496     EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, constant));
1497     EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, exp));
1498     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
1499 
1500     // The following values end up in the same buffer:
1501     //  (1) the init value: 'param'
1502     //  (2) the body parameter: 'body_param'
1503     //  (3) the condition parameter: 'cond_param'
1504     //  (4) the root value of the while body: 'add'
1505     //  (5) the while value: 'xla_while'
1506     // None should interfere.
1507     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, body_param));
1508     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, cond_param));
1509     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));
1510     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, xla_while));
1511 
1512     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, cond_param));
1513     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
1514     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, xla_while));
1515 
1516     EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, add));
1517     EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, xla_while));
1518 
1519     EXPECT_FALSE(InstructionsMayInterfere(ordering, add, xla_while));
1520   }
1521 }
1522 
TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) {
  // Builds a chain with two elementwise ops followed by a non-elementwise one:
  //
  //   param --> exp -> negate -> reverse
  //
  // An elementwise op is expected not to interfere with its operand, whereas
  // the non-elementwise reverse is expected to interfere with its operand.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp));
  HloInstruction* reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(vector_shape_, negate, {0}));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(/*ssa_form=*/GetParam());

  DependencyHloOrdering ordering(module_.get());

  // The parameter is expected not to interfere with anything in the chain.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, reverse));

  // Elementwise negate vs. its operand: no interference.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate));
  // Non-elementwise reverse vs. its operand: interference.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, reverse));
}
1555 
TEST_P(HloDataflowAnalysisTest, OverlappedValues) {
  // Checks that values which are live at the same time (negate and exp)
  // interfere:
  //
  //   param --+--> negate --+
  //           |             +--> add
  //           +--> exp -----+
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, exp));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(/*ssa_form=*/GetParam());

  DependencyHloOrdering ordering(module_.get());

  // The parameter interferes with both of its users, but not with add.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));

  // negate and exp interfere with each other in either argument order...
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
  // ...but neither of them interferes with add.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
}
1590 
TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
  // Same graph as the OverlappedValues test, but interference is evaluated
  // under an explicit sequential schedule instead of a dependency ordering:
  //
  //   param --+--> negate --+
  //           |             +--> add
  //           +--> exp -----+
  //
  // Sequential order:
  //   param, negate, exp, add
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, exp));

  HloComputation* entry = module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(/*ssa_form=*/GetParam());

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param, negate, exp, add});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  // With this schedule the parameter is expected to interfere with negate
  // only, not with exp or add.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));

  // negate and exp interfere with each other in either argument order...
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate));
  // ...but neither of them interferes with add.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp));
}
1633 
TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
  // Exercises interference between values that live in different
  // computations.
  //
  // embedded_computation:
  //   %embedded_param = Param(0)
  //   %embedded_log = Log(%embedded_param)
  //
  // entry computation:
  //   %param = Param(0)
  //   %negate = Negate(%param)
  //   %exp = Exp(%param)
  //   %call = Call(embedded_computation, {%exp})
  //   %add = Add(%negate, %call)
  //
  // %negate is used after the call, so it is live across the call and is
  // expected to interfere with values inside the embedded computation.
  HloComputation::Builder embedded_builder(TestName() + "_embedded");
  HloInstruction* embedded_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "embedded_param"));
  HloInstruction* embedded_log =
      embedded_builder.AddInstruction(HloInstruction::CreateUnary(
          vector_shape_, HloOpcode::kLog, embedded_param));
  HloComputation* embedded_computation =
      module_->AddEmbeddedComputation(embedded_builder.Build());

  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(vector_shape_, {exp}, embedded_computation));
  // The final add is only needed as a user of negate and call; nothing below
  // refers to it directly.
  builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, call));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(/*ssa_form=*/GetParam());

  DependencyHloOrdering ordering(module_.get());

  // exp's only use is the call, so it is expected not to interfere with
  // values inside the embedded computation.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log));

  // negate is live across the call, so it is expected to interfere with
  // values inside the embedded computation.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
}
1685 
TEST_P(HloDataflowAnalysisTest, GetFlattenedValueSet) {
  // Parses a module whose root tuple "t" holds two copies of one parameter and
  // checks that the flattened value set of "t" contains three values.
  const char* hlo_text = R"(
HloModule test_aliasing_module

ENTRY root {
  param = s32[1000] parameter(0)
  p0 = s32[1000] copy(param)
  p1 = s32[1000] copy(param)
  ROOT t = (s32[1000], s32[1000]) tuple(p0, p1)
  })";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
  HloComputation* entry = module_->entry_computation();

  // Look the root tuple up once and stop here if it is missing, instead of
  // passing a possibly-null instruction to the analysis.
  // NOTE(review): assumes GetInstructionWithName returns nullptr when the name
  // is not found -- confirm against HloComputation.
  HloInstruction* root_tuple = entry->GetInstructionWithName("t");
  ASSERT_NE(root_tuple, nullptr);

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());
  const HloValueSet flattened = analysis.GetFlattenedValueSet(root_tuple);
  EXPECT_EQ(flattened.values().size(), 3);
}
1704 
TEST_P(HloDataflowAnalysisTest,ConditionalWithIdentity)1705 TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
1706   // Test conditional with identity computations in both true and false cases.
1707   //
1708   // true_computation(F32[] %true_param):
1709   //   return %true_param
1710   //
1711   // false_computation(F32[] %false_param):
1712   //   return %false_param
1713   //
1714   // entry:
1715   //   %pred = Constant(true)
1716   //   %constant1 = Constant(56.0)
1717   //   %constant2 = Constant(12.0)
1718   //   return Conditional(%pred, %constant1, true_computation,
1719   //                      %constant2, false_computation)
1720 
1721   auto true_builder = HloComputation::Builder(TestName() + "_true");
1722   auto true_param = true_builder.AddInstruction(
1723       HloInstruction::CreateParameter(0, scalar_shape_, "true_param"));
1724   HloComputation* true_computation =
1725       module_->AddEmbeddedComputation(true_builder.Build());
1726 
1727   auto false_builder = HloComputation::Builder(TestName() + "_false");
1728   auto false_param = false_builder.AddInstruction(
1729       HloInstruction::CreateParameter(0, scalar_shape_, "false_param"));
1730   HloComputation* false_computation =
1731       module_->AddEmbeddedComputation(false_builder.Build());
1732 
1733   auto builder = HloComputation::Builder(TestName());
1734   auto pred = builder.AddInstruction(
1735       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1736   auto constant1 = builder.AddInstruction(
1737       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
1738   auto constant2 = builder.AddInstruction(
1739       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
1740   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
1741       scalar_shape_, pred, constant1, true_computation, constant2,
1742       false_computation));
1743   module_->AddEntryComputation(builder.Build());
1744   SCOPED_TRACE(module_->ToString());
1745 
1746   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
1747 
1748   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
1749   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
1750   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
1751 
1752   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
1753   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
1754 
1755   EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
1756             analysis.GetValueDefinedAt(constant1));
1757   EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
1758             analysis.GetValueDefinedAt(constant2));
1759 
1760   EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
1761               ElementsAre(HloUse{conditional, 0, {}}));
1762   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
1763               ElementsAre(HloUse{conditional, 1, {}}));
1764   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
1765               ElementsAre(HloUse{conditional, 2, {}}));
1766 
1767   bool ssa_form = GetParam();
1768   if (ssa_form) {
1769     EXPECT_EQ(analysis.values().size(), 4);
1770     EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
1771   } else {
1772     EXPECT_EQ(analysis.values().size(), 3);
1773     EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
1774     EXPECT_THAT(HloValuesAt(conditional),
1775                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
1776                                      analysis.GetValueDefinedAt(constant2)));
1777   }
1778 }
1779 
TEST_P(HloDataflowAnalysisTest,ConditionalTakingTupleOperand)1780 TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
1781   // Test conditional with true and false computations taking a tuple operand.
1782   //
1783   // true_computation((F32[], F32[]) %true_param):
1784   //   %true_x = GetTupleElement(%true_param, 0)
1785   //   %true_y = GetTupleElement(%true_param, 1)
1786   //   return Add(%true_x, %true_y)
1787   //
1788   // false_computation((F32[], F32[]) %false_param):
1789   //   %false_x = GetTupleElement(%false_param, 0)
1790   //   %false_y = GetTupleElement(%false_param, 1)
1791   //   return Subtract(%false_x, %false_y)
1792   //
1793   // entry:
1794   //   %pred = Constant(true)
1795   //   %constant1 = Constant(56.0)
1796   //   %constant2 = Constant(12.0)
1797   //   %tuple_operand = Tuple(%constant1, %constant2)
1798   //   return Conditional(%pred, %tuple_operand, true_computation,
1799   //                      %tuple_operand, false_computation)
1800 
1801   auto true_builder = HloComputation::Builder(TestName() + "_true");
1802   auto true_param = true_builder.AddInstruction(
1803       HloInstruction::CreateParameter(0, tuple_shape_, "true_param"));
1804   auto true_x = true_builder.AddInstruction(
1805       HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0));
1806   auto true_y = true_builder.AddInstruction(
1807       HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1));
1808   auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
1809       scalar_shape_, HloOpcode::kAdd, true_x, true_y));
1810   HloComputation* true_computation =
1811       module_->AddEmbeddedComputation(true_builder.Build());
1812 
1813   auto false_builder = HloComputation::Builder(TestName() + "_false");
1814   auto false_param = false_builder.AddInstruction(
1815       HloInstruction::CreateParameter(0, tuple_shape_, "false_param"));
1816   auto false_x = false_builder.AddInstruction(
1817       HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0));
1818   auto false_y = false_builder.AddInstruction(
1819       HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1));
1820   auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary(
1821       scalar_shape_, HloOpcode::kSubtract, false_x, false_y));
1822   HloComputation* false_computation =
1823       module_->AddEmbeddedComputation(false_builder.Build());
1824 
1825   auto builder = HloComputation::Builder(TestName());
1826   auto pred = builder.AddInstruction(
1827       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1828   auto constant1 = builder.AddInstruction(
1829       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
1830   auto constant2 = builder.AddInstruction(
1831       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
1832   auto tuple_operand = builder.AddInstruction(
1833       HloInstruction::CreateTuple({constant1, constant2}));
1834   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
1835       scalar_shape_, pred, tuple_operand, true_computation, tuple_operand,
1836       false_computation));
1837   module_->AddEntryComputation(builder.Build());
1838   SCOPED_TRACE(module_->ToString());
1839 
1840   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
1841 
1842   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
1843   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
1844   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
1845   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
1846   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
1847   EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));
1848 
1849   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
1850   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
1851   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x));
1852   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y));
1853   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x));
1854   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y));
1855 
1856   EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
1857             analysis.GetValueDefinedAt(tuple_operand));
1858   EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
1859             analysis.GetValueDefinedAt(tuple_operand));
1860   EXPECT_EQ(analysis.GetUniqueValueAt(true_x),
1861             analysis.GetValueDefinedAt(constant1));
1862   EXPECT_EQ(analysis.GetUniqueValueAt(true_y),
1863             analysis.GetValueDefinedAt(constant2));
1864   EXPECT_EQ(analysis.GetUniqueValueAt(false_x),
1865             analysis.GetValueDefinedAt(constant1));
1866   EXPECT_EQ(analysis.GetUniqueValueAt(false_y),
1867             analysis.GetValueDefinedAt(constant2));
1868 
1869   EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
1870               ElementsAre(HloUse{conditional, 0, {}}));
1871   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
1872               UnorderedElementsAre(HloUse{conditional, 1, {0}},
1873                                    HloUse{conditional, 2, {0}},
1874                                    HloUse{add, 0, {}}, HloUse{sub, 0, {}}));
1875   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
1876               UnorderedElementsAre(HloUse{conditional, 1, {1}},
1877                                    HloUse{conditional, 2, {1}},
1878                                    HloUse{add, 1, {}}, HloUse{sub, 1, {}}));
1879   EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).uses(),
1880               UnorderedElementsAre(
1881                   HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}},
1882                   HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}},
1883                   HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}}));
1884 
1885   bool ssa_form = GetParam();
1886   if (ssa_form) {
1887     EXPECT_EQ(analysis.values().size(), 7);
1888     EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
1889   } else {
1890     EXPECT_EQ(analysis.values().size(), 6);
1891     EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
1892     EXPECT_THAT(HloValuesAt(conditional),
1893                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
1894                                      analysis.GetValueDefinedAt(sub)));
1895   }
1896 }
1897 
// Checks the analysis on a conditional whose true branch is a computation
// that itself contains a conditional (structure sketched in the pseudo-HLO
// below). For both parameter values the five entry constants, the operand
// tuple and the three unary branch roots are expected to define values, while
// branch parameters and the inner get-tuple-elements only forward values
// defined in the entry computation. The parameter value decides whether the
// two conditional instructions define values of their own.
TEST_P(HloDataflowAnalysisTest,NestedConditionals)1898 TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
1899   // computation1(F32[] %param1):
1900   //   %ceil = Ceil(%param1)
1901   //   return %ceil
1902   //
1903   // computation2(F32[] %param2):
1904   //   %floor = Floor(%param2)
1905   //   return %floor
1906   //
1907   // computation3(F32[] %param3):
1908   //   %negate = Negate(%param3)
1909   //   return %negate
1910   //
1911   // inner_conditional((PRED, F32[], F32[]) %param_cond):
1912   //   %pred_cond = GetTupleElement(%param_cond, 0)
1913   //   %true_operand_cond = GetTupleElement(%param_cond, 1)
1914   //   %false_operand_cond = GetTupleElement(%param_cond, 2)
1915   //   return Conditional(%pred_cond, %true_operand_cond, computation1,
1916   //                      %false_operand_cond, computation2)
1917   //
1918   // entry:
1919   //   %pred1 = Constant(true)
1920   //   %pred2 = Constant(false)
1921   //   %constant1 = Constant(1.1);
1922   //   %constant2 = Constant(2.2);
1923   //   %constant3 = Constant(3.3);
1924   //   return Conditional(%pred1, (%pred2, %constant1, %constant2),
1925   //                      inner_conditional, %constant3, computation3)
1926 
1927   auto computation1 = module_->AddEmbeddedComputation(
1928       CreateR0F32UnaryOpComputation(HloOpcode::kCeil));
1929   auto computation2 = module_->AddEmbeddedComputation(
1930       CreateR0F32UnaryOpComputation(HloOpcode::kFloor));
1931   auto computation3 = module_->AddEmbeddedComputation(
1932       CreateR0F32UnaryOpComputation(HloOpcode::kNegate));
1933 
1934   // Build inner_conditional computation.
1935   const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {});
1936   const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
1937       {scalar_bool_shape, scalar_shape_, scalar_shape_});
1938   auto inner_builder =
1939       HloComputation::Builder(TestName() + "_inner_conditional");
1940   auto param_cond = inner_builder.AddInstruction(
1941       HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond"));
1942   auto pred_cond = inner_builder.AddInstruction(
1943       HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0));
1944   auto true_operand_cond = inner_builder.AddInstruction(
1945       HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1));
1946   auto false_operand_cond = inner_builder.AddInstruction(
1947       HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2));
1948   auto inner_conditional =
1949       inner_builder.AddInstruction(HloInstruction::CreateConditional(
1950           scalar_shape_, pred_cond, true_operand_cond, computation1,
1951           false_operand_cond, computation2));
1952   auto inner_conditional_computation =
1953       module_->AddEmbeddedComputation(inner_builder.Build());
1954 
1955   // Build entry computation.
1956   auto builder = HloComputation::Builder(TestName());
1957   auto pred1 = builder.AddInstruction(
1958       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1959   auto pred2 = builder.AddInstruction(
1960       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1961   auto constant1 = builder.AddInstruction(
1962       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
1963   auto constant2 = builder.AddInstruction(
1964       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.2f)));
1965   auto constant3 = builder.AddInstruction(
1966       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.3f)));
1967   auto tuple_operand = builder.AddInstruction(
1968       HloInstruction::CreateTuple({pred2, constant1, constant2}));
1969   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
1970       scalar_shape_, pred1, tuple_operand, inner_conditional_computation,
1971       constant3, computation3));
1972   module_->AddEntryComputation(builder.Build());
1973   SCOPED_TRACE(module_->ToString());
1974 
1975   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
1976 
       // These nine instructions are expected to define a value regardless of
       // the test parameter.
1977   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1));
1978   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2));
1979   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
1980   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
1981   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3));
1982   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
1983   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction()));
1984   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction()));
1985   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction()));
1986 
       // Branch parameters define nothing themselves; each is expected to carry
       // the value of the operand that reaches its conditional.
1987   auto computation1_param = computation1->parameter_instruction(0);
1988   auto computation2_param = computation2->parameter_instruction(0);
1989   auto computation3_param = computation3->parameter_instruction(0);
1990   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param));
1991   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param));
1992   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param));
1993   EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param),
1994             analysis.GetValueDefinedAt(constant1));
1995   EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param),
1996             analysis.GetValueDefinedAt(constant2));
1997   EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param),
1998             analysis.GetValueDefinedAt(constant3));
1999 
       // The inner computation's tuple parameter and its get-tuple-elements are
       // expected to resolve to the entry tuple and its three elements.
2000   EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond));
2001   EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond));
2002   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond));
2003   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond));
2004   EXPECT_EQ(analysis.GetUniqueValueAt(param_cond),
2005             analysis.GetValueDefinedAt(tuple_operand));
2006   EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond),
2007             analysis.GetValueDefinedAt(pred2));
2008   EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond),
2009             analysis.GetValueDefinedAt(constant1));
2010   EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond),
2011             analysis.GetValueDefinedAt(constant2));
2012 
       // With ssa_form set, each of the two conditionals defines a value on top
       // of the nine above (11 total). Otherwise neither conditional defines
       // one, and the value set at each conditional is the set of branch roots
       // that can reach it.
2013   bool ssa_form = GetParam();
2014   if (ssa_form) {
2015     EXPECT_EQ(analysis.values().size(), 11);
2016     EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_conditional));
2017     EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
2018   } else {
2019     EXPECT_EQ(analysis.values().size(), 9);
2020     EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
2021     EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
2022     EXPECT_THAT(
2023         HloValuesAt(inner_conditional),
2024         UnorderedElementsAre(
2025             analysis.GetValueDefinedAt(computation1->root_instruction()),
2026             analysis.GetValueDefinedAt(computation2->root_instruction())));
2027     EXPECT_THAT(
2028         HloValuesAt(conditional),
2029         UnorderedElementsAre(
2030             analysis.GetValueDefinedAt(computation1->root_instruction()),
2031             analysis.GetValueDefinedAt(computation2->root_instruction()),
2032             analysis.GetValueDefinedAt(computation3->root_instruction())));
2033   }
2034 }
2035 
// Parses a small module whose root is an add-dependency of a parameter and an
// after-all token, then checks that the analysis reports exactly two values
// and that none is defined at the root.
// NOTE(review): this body never reads GetParam(); HloDataflowAnalysis::Run is
// called with only the module argument, so every instantiation of the suite
// exercises the same configuration -- confirm whether that is intended.
TEST_P(HloDataflowAnalysisTest,AddDependency)2036 TEST_P(HloDataflowAnalysisTest, AddDependency) {
2037   string module_string = R"(
2038 HloModule AddDependency
2039 ENTRY %AddDependency (p: f32[3]) -> f32[3] {
2040   %p = f32[3] parameter(0)
2041   %token0 = token[] after-all()
2042   ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0)
2043 }
2044 )";
2045   TF_ASSERT_OK_AND_ASSIGN(
2046       std::unique_ptr<HloModule> module,
2047       ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest()));
2048 
2049   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> analysis,
2050                           HloDataflowAnalysis::Run(*module));
       // Sanity-check that the parsed root really is the add-dependency.
2051   const HloInstruction* root = module->entry_computation()->root_instruction();
2052   EXPECT_EQ(root->opcode(), HloOpcode::kAddDependency);
2053 
2054   // The after-all and parameter should define a value. Add-dependency should
2055   // not.
2056   EXPECT_EQ(analysis->values().size(), 2);
2057   EXPECT_FALSE(analysis->ValueIsDefinedAt(root));
2058 }
2059 
// Instantiates every TEST_P of HloDataflowAnalysisTest for both values of its
// bool parameter (false and true).
2060 INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation,
2061                          HloDataflowAnalysisTest,
2062                          ::testing::Values(false, true));
2063 
// Fixture shared by the TEST_F suites that derive from it. It owns the module
// under test, a pointer to that module's entry computation, and the result of
// the most recent analysis run.
2064 class HloDataflowAnalysisTestBase : public HloTestBase {
2065  protected:
       // Creates a fresh verified module and installs `computation` as its entry
       // computation, remembering the resulting pointer in computation_.
BuildModule(std::unique_ptr<HloComputation> computation)2066   void BuildModule(std::unique_ptr<HloComputation> computation) {
2067     module_ = CreateNewVerifiedModule();
2068     computation_ = module_->AddEntryComputation(std::move(computation));
2069   }
2070 
       // Runs the analysis on module_ with ssa_form and bitcast_defines_value
       // both false and stores the result in dataflow_analysis_. module_ must
       // already be set. `can_share_buffer` is forwarded unchanged to
       // HloDataflowAnalysis::Run.
       // NOTE(review): ConsumeValueOrDie presumably aborts the test if Run
       // returns an error status -- confirm.
RunAnalysis(const HloDataflowAnalysis::CanShareBuffer & can_share_buffer=nullptr)2071   void RunAnalysis(
2072       const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
2073     CHECK_NOTNULL(module_.get());
2074     dataflow_analysis_ = HloDataflowAnalysis::Run(
2075                              *module_, /*ssa_form=*/false,
2076                              /*bitcast_defines_value=*/false, can_share_buffer)
2077                              .ConsumeValueOrDie();
2078   }
2079 
       // Convenience wrapper: BuildModule(computation) followed by RunAnalysis()
       // with its default argument.
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)2080   void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
2081     BuildModule(std::move(computation));
2082     RunAnalysis();
2083   }
2084 
       // Module under test; set by BuildModule or assigned directly by a test.
2085   std::unique_ptr<HloModule> module_;
       // Entry computation of module_ (not owned); null until a module is built.
2086   HloComputation* computation_ = nullptr;
       // Result of the latest RunAnalysis() call.
2087   std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
2088 };
2089 
// Fixture for the DoesNotUseOperandBuffer tests; adds nothing to the base
// fixture beyond a distinct suite name.
2090 class DoesNotUseOperandBufferTest : public HloDataflowAnalysisTestBase {};
2091 
// Builds two get-tuple-elements over a two-element tuple parameter (plus an
// add of the two results) and queries DoesNotUseOperandBuffer for the tuple's
// element indices ({0}, {1}) versus its top-level index ({}).
TEST_F(DoesNotUseOperandBufferTest,GetTupleElement)2092 TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
2093   auto builder = HloComputation::Builder(TestName());
2094 
2095   Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
2096   auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
2097       0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple"));
2098   auto gte0 = builder.AddInstruction(
2099       HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0));
2100   auto gte1 = builder.AddInstruction(
2101       HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1));
       // The add is the last instruction added to the builder; its result is not
       // inspected by the expectations below.
2102   builder.AddInstruction(
2103       HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1));
2104 
2105   BuildModuleAndRunAnalysis(builder.Build());
2106 
2107   // GetTupleElement instructions only access the top-level buffer of their
2108   // operand.
2109   EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0));
2110   EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1));
2111   EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0));
2112   EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1));
2113 }
2114 
// Fuses a dynamic-update-slice of tuple element 1 (together with its start
// index, its update constant and the get-tuple-element feeding it) into a loop
// fusion, then checks which elements of the parameter tuple the fusion is
// reported to use.
TEST_F(DoesNotUseOperandBufferTest,FusedDynamicUpdateSlice)2115 TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
2116   auto builder = HloComputation::Builder(TestName());
2117 
2118   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
2119   auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
2120       0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
2121   auto gte0 = builder.AddInstruction(
2122       HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
2123   auto gte1 = builder.AddInstruction(
2124       HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
2125 
2126   // Create a DynamicUpdateSlice instruction of tuple element 1.
2127   auto starts = builder.AddInstruction(
2128       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
2129   auto update = builder.AddInstruction(HloInstruction::CreateConstant(
2130       LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
2131   auto dynamic_update_slice =
2132       builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2133           data_shape, gte1, update,
2134           std::initializer_list<HloInstruction*>({starts})));
2135   builder.AddInstruction(
2136       HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
2137 
       // The fusion is created after the module is built and before the analysis
       // runs, so the analysis sees the fused graph. gte1 is part of the fused
       // set, which leaves the parameter tuple as the fusion's operand.
2138   BuildModule(builder.Build());
2139   auto fusion = computation_->CreateFusionInstruction(
2140       {dynamic_update_slice, starts, update, gte1},
2141       HloInstruction::FusionKind::kLoop);
2142   RunAnalysis();
2143 
2144   // The fusion instruction never uses tuple element 0, but does use element 1.
2145   EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
2146   EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
2147 }
2148 
2149 // Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the
2150 // parameter tuple.
// The parameter's elements are first repacked, in swapped order, into a new
// tuple; the fused dynamic-update-slice then reads element 1 of that new
// tuple. Both the new tuple and the original parameter are queried.
TEST_F(DoesNotUseOperandBufferTest,IndirectUses)2151 TEST_F(DoesNotUseOperandBufferTest, IndirectUses) {
2152   auto builder = HloComputation::Builder(TestName());
2153 
2154   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
2155   auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
2156       0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
2157   auto t0 = builder.AddInstruction(
2158       HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
2159   auto t1 = builder.AddInstruction(
2160       HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));
2161   // Swap the tuple elements.
2162   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0}));
2163 
2164   auto gte0 = builder.AddInstruction(
2165       HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
2166   auto gte1 = builder.AddInstruction(
2167       HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
2168 
2169   // Create a DynamicUpdateSlice instruction of tuple element 1.
2170   auto starts = builder.AddInstruction(
2171       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
2172   auto update = builder.AddInstruction(HloInstruction::CreateConstant(
2173       LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
2174   auto dynamic_update_slice =
2175       builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2176           data_shape, gte1, update,
2177           std::initializer_list<HloInstruction*>({starts})));
2178   builder.AddInstruction(
2179       HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
2180 
       // Fuse after building the module and before running the analysis.
2181   BuildModule(builder.Build());
2182   auto fusion = computation_->CreateFusionInstruction(
2183       {dynamic_update_slice, starts, update, gte1},
2184       HloInstruction::FusionKind::kLoop);
2185   RunAnalysis();
2186 
2187   // The fusion instruction never uses tuple element 0, but does use element 1.
2188   EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
2189   EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
2190   // The same holds for the parameter tuple, except that the tuple elements
2191   // are swapped in 'tuple'.
2192   EXPECT_TRUE(
2193       dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
2194   EXPECT_FALSE(
2195       dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion));
2196 }
2197 
// Fixture for the CanShareOperandBufferWithUser tests; adds nothing to the
// base fixture beyond a distinct suite name.
2198 class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {};
2199 
// Chains exp and log over an F32[8] parameter. Each unary op has the same
// shape as its operand and is expected to be able to share that operand's
// buffer.
TEST_F(CanShareOperandBufferWithUserTest,ElementWiseSameShape)2200 TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
2201   auto builder = HloComputation::Builder(TestName());
2202 
2203   Shape shape = ShapeUtil::MakeShape(F32, {8});
2204   auto param = builder.AddInstruction(
2205       HloInstruction::CreateParameter(0, shape, "param"));
2206   auto exp = builder.AddInstruction(
2207       HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
2208   auto log = builder.AddInstruction(
2209       HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp));
2210 
2211   BuildModuleAndRunAnalysis(builder.Build());
2212 
2213   EXPECT_TRUE(
2214       dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {}));
2215   EXPECT_TRUE(
2216       dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {}));
2217 }
2218 
// Fuses a negate followed by a reverse over both dimensions into a loop
// fusion and expects that the fusion cannot share the parameter's buffer.
// NOTE(review): presumably because the reverse does not read and write the
// same element positions -- confirm against the analysis implementation.
TEST_F(CanShareOperandBufferWithUserTest,NonElementwiseLoopFusionCantAliasOperandBuffer)2219 TEST_F(CanShareOperandBufferWithUserTest,
2220        NonElementwiseLoopFusionCantAliasOperandBuffer) {
2221   auto builder = HloComputation::Builder(TestName());
2222   Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
2223 
2224   auto param0 = builder.AddInstruction(
2225       HloInstruction::CreateParameter(0, data_shape, "param0"));
2226 
2227   auto neg = builder.AddInstruction(
2228       HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0));
2229 
2230   auto reverse = builder.AddInstruction(
2231       HloInstruction::CreateReverse(data_shape, neg, {0, 1}));
2232 
       // Fuse after building the module and before running the analysis.
2233   BuildModule(builder.Build());
2234   auto fusion = computation_->CreateFusionInstruction(
2235       {reverse, neg}, HloInstruction::FusionKind::kLoop);
2236   RunAnalysis();
2237 
2238   EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
2239                                                                  fusion, {}));
2240 }
2241 
// Fuses two copies and the tuple that returns them in swapped order into one
// loop fusion. Either parameter is expected to be shareable with either
// output index of the fusion.
// NOTE(review): `data_shape` and `out_shape` are declared below but never
// used in this test.
TEST_F(CanShareOperandBufferWithUserTest,MultiOutputFusionCanAliasOperandBuffer)2242 TEST_F(CanShareOperandBufferWithUserTest,
2243        MultiOutputFusionCanAliasOperandBuffer) {
2244   auto builder = HloComputation::Builder(TestName());
2245   Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
2246 
2247   Shape in_shape = ShapeUtil::MakeShape(F32, {8});
2248   Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
2249   auto param0 = builder.AddInstruction(
2250       HloInstruction::CreateParameter(0, in_shape, "param0"));
2251   auto param1 = builder.AddInstruction(
2252       HloInstruction::CreateParameter(1, in_shape, "param1"));
2253 
2254   auto copy0 = builder.AddInstruction(
2255       HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0));
2256   auto copy1 = builder.AddInstruction(
2257       HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1));
2258 
       // Output element 0 is the copy of param1 and element 1 the copy of param0.
2259   auto tuple =
2260       builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0}));
2261 
2262   BuildModule(builder.Build());
2263   auto fusion = computation_->CreateFusionInstruction(
2264       {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop);
2265   RunAnalysis();
2266 
       // All four (parameter, output index) combinations are expected to be
       // shareable.
2267   EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
2268                                                                 fusion, {0}));
2269   EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
2270                                                                 fusion, {1}));
2271   EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
2272                                                                 fusion, {0}));
2273   EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
2274                                                                 fusion, {1}));
2275 }
2276 
// Fuses negate and exp (both unary ops on the same F32[2,2] shape) into a loop
// fusion fed by a broadcast constant, and checks buffer sharing between the
// broadcast and the fusion.
// NOTE(review): despite the "Cant" in the test name, the expectation below
// asserts that the operand buffer CAN be shared with the fusion.
TEST_F(CanShareOperandBufferWithUserTest,ElementwiseLoopFusionCantAliasOperandBuffer)2277 TEST_F(CanShareOperandBufferWithUserTest,
2278        ElementwiseLoopFusionCantAliasOperandBuffer) {
2279   auto builder = HloComputation::Builder(TestName());
2280   Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
2281 
2282   auto one = builder.AddInstruction(
2283       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2284   auto operand = builder.AddInstruction(
2285       HloInstruction::CreateBroadcast(data_shape, one, {}));
2286 
2287   auto neg = builder.AddInstruction(
2288       HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand));
2289 
2290   auto exp = builder.AddInstruction(
2291       HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg));
2292 
       // The broadcast stays outside the fusion and becomes its operand.
2293   BuildModule(builder.Build());
2294   auto fusion = computation_->CreateFusionInstruction(
2295       {exp, neg}, HloInstruction::FusionKind::kLoop);
2296   RunAnalysis();
2297 
2298   EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
2299                                                                 fusion, {}));
2300 }
2301 
// Fuses a dynamic-slice of `param` with a dynamic-update-slice that writes
// that slice back into `param` at the same (zero, zero) start indices. The
// fusion is expected to be able to share param's buffer.
TEST_F(CanShareOperandBufferWithUserTest,CanShareOperandWhenDynamicUpdateSliceIsFedByDynamicSliceWithSameIndex)2302 TEST_F(CanShareOperandBufferWithUserTest,
2303        CanShareOperandWhenDynamicUpdateSliceIsFedByDynamicSliceWithSameIndex) {
2304   auto builder = HloComputation::Builder(TestName());
2305   Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
2306   Shape slice_shape = ShapeUtil::MakeShape(F32, {1, 2});
2307 
2308   auto param = builder.AddInstruction(
2309       HloInstruction::CreateParameter(0, data_shape, "param0"));
2310   auto zero = builder.AddInstruction(
2311       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(0)));
       // Slice and update use the same index instruction for both dimensions.
2312   auto ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
2313       slice_shape, param, {zero, zero}, {1, 2}));
2314 
2315   auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2316       data_shape, param, ds, {zero, zero}));
2317 
2318   BuildModule(builder.Build());
2319   auto fusion = computation_->CreateFusionInstruction(
2320       {dus, ds, zero}, HloInstruction::FusionKind::kLoop);
2321   RunAnalysis();
2322 
2323   EXPECT_TRUE(
2324       dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {}));
2325 }
2326 
// Parses a module whose loop fusion slices p0 at indices (p1, p2, p3) and
// writes that slice back into p0 at the same indices. The fusion is expected
// to be able to share the buffer of entry parameter 0.
TEST_F(CanShareOperandBufferWithUserTest,DUSWithSliceWithSameIndices)2327 TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) {
2328   const char* kModule = R"(
2329     HloModule test
2330 
2331     fused_computation {
2332       p0 = f32[10,20,30] parameter(0)
2333       p1 = s32[] parameter(1)
2334       p2 = s32[] parameter(2)
2335       p3 = s32[] parameter(3)
2336       slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30}
2337       ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p2, p3)
2338     }
2339 
2340     ENTRY test {
2341       p0 = f32[10,20,30] parameter(0)
2342       p1 = s32[] parameter(1)
2343       p2 = s32[] parameter(2)
2344       p3 = s32[] parameter(3)
2345       ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
2346     }
2347   )";
       // module_ is assigned directly here instead of going through BuildModule,
       // so computation_ is left unset; the test reads the entry computation
       // from module_ instead.
2348   TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
2349   auto* fusion = module_->entry_computation()->root_instruction();
2350   auto* param = module_->entry_computation()->parameter_instruction(0);
2351 
2352   RunAnalysis();
2353   EXPECT_TRUE(
2354       dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {}));
2355 }
2356 
// Compares two F32[8] parameters for equality, producing a PRED[8] result.
// Neither operand is expected to be shareable with the result, whose shape
// differs from theirs in element type.
TEST_F(CanShareOperandBufferWithUserTest,ElementWiseDifferentShape)2357 TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
2358   auto builder = HloComputation::Builder(TestName());
2359 
2360   Shape in_shape = ShapeUtil::MakeShape(F32, {8});
2361   Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
2362   auto param0 = builder.AddInstruction(
2363       HloInstruction::CreateParameter(0, in_shape, "param0"));
2364   auto param1 = builder.AddInstruction(
2365       HloInstruction::CreateParameter(1, in_shape, "param1"));
2366   auto result = builder.AddInstruction(HloInstruction::CreateCompare(
2367       out_shape, param0, param1, ComparisonDirection::kEq));
2368 
2369   BuildModuleAndRunAnalysis(builder.Build());
2370 
2371   EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
2372                                                                  result, {}));
2373   EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
2374                                                                  result, {}));
2375 }
2376 
// Builds exp(param) followed by a copy of the result, and expects that both
// the exp and the copy can share their operand's buffer.
TEST_F(CanShareOperandBufferWithUserTest,CopyShares)2377 TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
2378   auto builder = HloComputation::Builder(TestName());
2379 
2380   Shape shape = ShapeUtil::MakeShape(F32, {8});
2381   auto param = builder.AddInstruction(
2382       HloInstruction::CreateParameter(0, shape, "param"));
2383   auto exp = builder.AddInstruction(
2384       HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
2385   auto copy = builder.AddInstruction(
2386       HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp));
2387 
2388   BuildModuleAndRunAnalysis(builder.Build());
2389 
2390   EXPECT_TRUE(
2391       dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {}));
2392   EXPECT_TRUE(
2393       dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {}));
2394 }
2395 
// Fuses a dynamic-update-slice of tuple element 1 (with its start index,
// update constant and feeding get-tuple-element) into a loop fusion, then
// checks which element of the parameter tuple may share a buffer with the
// fusion's output.
TEST_F(CanShareOperandBufferWithUserTest,FusedDynamicUpdateSlice)2396 TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
2397   auto builder = HloComputation::Builder(TestName());
2398 
2399   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
2400   auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
2401       0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
2402   auto gte0 = builder.AddInstruction(
2403       HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
2404   auto gte1 = builder.AddInstruction(
2405       HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
2406 
2407   // Create a DynamicUpdateSlice instruction of tuple element 1.
2408   auto starts = builder.AddInstruction(
2409       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
2410   auto update = builder.AddInstruction(HloInstruction::CreateConstant(
2411       LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
2412   auto dynamic_update_slice =
2413       builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2414           data_shape, gte1, update,
2415           std::initializer_list<HloInstruction*>({starts})));
2416   builder.AddInstruction(
2417       HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
2418 
       // gte1 is fused too, which leaves the parameter tuple as the fusion's
       // operand.
2419   BuildModule(builder.Build());
2420   auto fusion = computation_->CreateFusionInstruction(
2421       {dynamic_update_slice, starts, update, gte1},
2422       HloInstruction::FusionKind::kLoop);
2423   RunAnalysis();
2424 
2425   // The fusion instruction can share with tuple element 1.
2426   EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {0},
2427                                                                  fusion, {}));
2428   EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {1},
2429                                                                 fusion, {}));
2430 }
2431 
// Wraps a dynamic-update-slice in a convert to BF16 before it and a convert
// back to F32 after it, fuses all of that into a loop fusion, and expects
// that the F32 operand feeding the fusion (gte1) can share a buffer with the
// fusion's output.
TEST_F(CanShareOperandBufferWithUserTest,FusedDynamicUpdateSliceWithConvertCanShare)2432 TEST_F(CanShareOperandBufferWithUserTest,
2433        FusedDynamicUpdateSliceWithConvertCanShare) {
2434   auto builder = HloComputation::Builder(TestName());
2435 
2436   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
2437   Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8});
2438   auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
2439       0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
2440   auto gte0 = builder.AddInstruction(
2441       HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
2442   auto gte1 = builder.AddInstruction(
2443       HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
2444 
2445   auto convert1 = builder.AddInstruction(
2446       HloInstruction::CreateConvert(data_shape_bf16, gte1));
2447 
2448   // Create a DynamicUpdateSlice instruction of tuple element 1.
2449   auto starts = builder.AddInstruction(
2450       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
2451   auto update = builder.AddInstruction(HloInstruction::CreateConstant(
2452       LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
2453   auto dynamic_update_slice =
2454       builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2455           data_shape_bf16, convert1, update,
2456           std::initializer_list<HloInstruction*>({starts})));
2457 
2458   auto convert2 = builder.AddInstruction(
2459       HloInstruction::CreateConvert(data_shape, dynamic_update_slice));
2460   builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2}));
2461 
       // gte1 is deliberately left out of the fused set, so it stays in the
       // entry computation as the operand queried below.
2462   BuildModule(builder.Build());
2463   auto fusion = computation_->CreateFusionInstruction(
2464       {convert2, dynamic_update_slice, starts, update, convert1},
2465       HloInstruction::FusionKind::kLoop);
2466   RunAnalysis();
2467 
2468   EXPECT_TRUE(
2469       dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {}));
2470 }
2471 
// Builds an unfused dynamic-update-slice whose data, update and start operands
// are all parameters, and checks which of the three operands may share a
// buffer with the result.
TEST_F(CanShareOperandBufferWithUserTest,DynamicUpdateSliceCanShare)2472 TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
2473   auto builder = HloComputation::Builder(TestName());
2474 
2475   Shape data_shape = ShapeUtil::MakeShape(F32, {1, 8});
2476   Shape update_shape = ShapeUtil::MakeShape(F32, {1, 4});
2477   Shape starts_shape = ShapeUtil::MakeShape(S32, {2});
2478   auto data = builder.AddInstruction(
2479       HloInstruction::CreateParameter(0, data_shape, "data"));
2480   auto update = builder.AddInstruction(
2481       HloInstruction::CreateParameter(1, update_shape, "update"));
2482   auto start = builder.AddInstruction(
2483       HloInstruction::CreateParameter(2, starts_shape, "start"));
2484 
2485   auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2486       data_shape, data, update, {start}));
2487 
2488   BuildModuleAndRunAnalysis(builder.Build());
2489 
2490   // The DynamicUpdateSlice instruction can share with the data operand, but not
2491   // with update or start.
2492   EXPECT_TRUE(
2493       dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {}));
2494   EXPECT_FALSE(
2495       dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {}));
2496   EXPECT_FALSE(
2497       dataflow_analysis_->CanShareOperandBufferWithUser(start, {}, dus, {}));
2498 }
2499 
// Parses a module whose root is a scatter of `updates` into `operand` at
// `indices`, where the combiner update_s32 returns its second parameter. The
// scatter is expected to be able to share a buffer with `operand` only, not
// with `indices` or `updates`.
TEST_F(CanShareOperandBufferWithUserTest,ScatterCanShare)2500 TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
2501   const char* hlo_text = R"(
2502     HloModule TensorFlowScatterV1
2503 
2504     update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
2505       lhs = s32[] parameter(0)
2506       ROOT rhs = s32[] parameter(1)
2507     }
2508 
2509     ENTRY main {
2510       operand = s32[3,3] parameter(0)
2511       indices = s32[2] parameter(1)
2512       updates = s32[2,3] parameter(2)
2513       ROOT scatter = s32[3,3] scatter(operand, indices, updates),
2514           to_apply=update_s32,
2515           update_window_dims={1},
2516           inserted_window_dims={0},
2517           scatter_dims_to_operand_dims={0},
2518           index_vector_dim=1
2519     }
2520   )";
       // The module comes from the parser rather than BuildModule, so
       // computation_ has to be set by hand before it is used below.
2521   TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
2522   computation_ = module_->entry_computation();
2523   RunAnalysis();
2524 
2525   HloInstruction* operand_param = computation_->parameter_instruction(0);
2526   HloInstruction* indices_param = computation_->parameter_instruction(1);
2527   HloInstruction* updates_param = computation_->parameter_instruction(2);
2528   HloInstruction* scatter = computation_->root_instruction();
2529 
2530   EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
2531       operand_param, {}, scatter, {}));
2532   EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
2533       indices_param, {}, scatter, {}));
2534   EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
2535       updates_param, {}, scatter, {}));
2536 }
2537 
// Parses a module whose entry root is a triangular-solve and checks that only
// the second operand (`b`) may share a buffer with the result.
TEST_F(CanShareOperandBufferWithUserTest, TriangularSolveCanShare) {
  const char* const kTriangularSolveModuleText = R"(
    HloModule TensorFlowTriangularSolve

    ENTRY main {
      a = f32[4,4]{1,0} parameter(0)
      b = f32[3,4]{1,0} parameter(1)
      ROOT triangular-solve = f32[3,4]{1,0} triangular-solve(a, b), lower=true,
                                              transpose_a=NO_TRANSPOSE
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(
      module_, ParseAndReturnVerifiedModule(kTriangularSolveModuleText));
  computation_ = module_->entry_computation();
  RunAnalysis();

  HloInstruction* solve_root = computation_->root_instruction();
  HloInstruction* a_param = computation_->parameter_instruction(0);
  HloInstruction* b_param = computation_->parameter_instruction(1);

  // `a` must not be shared with the output; `b` may be.
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      a_param, {}, solve_root, {}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
      b_param, {}, solve_root, {}));
}
2562 
// A single-operand sort is expected to be able to share its operand's buffer.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
  HloComputation::Builder builder(TestName());
  module_ = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* keys_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * sort_instr,
      MakeSortHlo(keys_shape, {keys_param}, -1, /*is_stable=*/false, &builder,
                  module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();

  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
      keys_param, {}, sort_instr, {}));
}
2580 
// A two-operand sort produces a tuple; each operand is expected to be
// shareable only with the tuple element at the matching index.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
  HloComputation::Builder builder(TestName());
  module_ = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape values_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* keys_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values"));
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * sort_instr,
      MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}),
                  {keys_param, values_param}, 0, /*is_stable=*/false, &builder,
                  module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();

  // Matching positions: keys <-> element 0, values <-> element 1.
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
      keys_param, {}, sort_instr, {0}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
      values_param, {}, sort_instr, {1}));
  // Mismatched positions must be rejected.
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      keys_param, {}, sort_instr, {1}));
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      values_param, {}, sort_instr, {0}));
}
2612 
// Fuses dot + add into a kOutput fusion and checks that the add's non-dot
// operand may share its buffer with the fusion.
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
  HloComputation::Builder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});

  // Two constant 2x2 matrices feeding the dot.
  HloInstruction* dot_lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  HloInstruction* dot_rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  DotDimensionNumbers dimension_numbers;
  dimension_numbers.add_lhs_contracting_dimensions(1);
  dimension_numbers.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision;
  precision.mutable_operand_precision()->Resize(2, PrecisionConfig::DEFAULT);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      matrix_shape, dot_lhs, dot_rhs, dimension_numbers, precision));

  // Second addend: a scalar 1.0 broadcast to the matrix shape.
  HloInstruction* scalar_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* addend = builder.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_shape, scalar_one, {}));

  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      matrix_shape, HloOpcode::kAdd, dot, addend));

  BuildModule(builder.Build());
  HloInstruction* output_fusion = computation_->CreateFusionInstruction(
      {sum, dot}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // The output-fused dot+add should be able to reuse the addend's buffer.
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
      addend, {}, output_fusion, {}));
}
2648 
// Builds a kOutput fusion of operand -> reverse -> add and checks that the
// fusion is not allowed to share the buffer of its operand.
TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
  HloComputation::Builder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});

  // Fusion operand: a scalar 1.0 broadcast to the matrix shape.
  HloInstruction* scalar_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* broadcast_operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_shape, scalar_one, {}));

  HloInstruction* reversed = builder.AddInstruction(
      HloInstruction::CreateReverse(matrix_shape, broadcast_operand, {0, 1}));

  HloInstruction* twos = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      matrix_shape, HloOpcode::kAdd, reversed, twos));

  BuildModule(builder.Build());
  HloInstruction* output_fusion = computation_->CreateFusionInstruction(
      {sum, twos, reversed}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // The operand is consumed through a reverse inside the fusion, so the
  // fusion must not alias the operand's buffer.
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      broadcast_operand, {}, output_fusion, {}));
}
2676 
// Runs the analysis with a caller-supplied can-share-buffer callback and
// checks that the callback's answer is honoured for a kInput fusion.
TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
  HloComputation::Builder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* scalar_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* broadcast_operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_shape, scalar_one, {}));
  // The operand is squared, so it is used twice by the multiply.
  HloInstruction* squared = builder.AddInstruction(
      HloInstruction::CreateBinary(matrix_shape, HloOpcode::kMultiply,
                                   broadcast_operand, broadcast_operand));
  HloInstruction* twos = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      matrix_shape, HloOpcode::kAdd, squared, twos));

  BuildModule(builder.Build());
  HloInstruction* input_fusion = computation_->CreateFusionInstruction(
      {sum, twos, squared}, HloInstruction::FusionKind::kInput);
  // The callback answers "can share" only when its first argument is a loop
  // fusion; the fusion built above is kInput.
  RunAnalysis(/*can_share_buffer=*/[](const HloInstruction* user,
                                      const HloInstruction*,
                                      const ShapeIndex&) {
    return user->IsLoopFusion();
  });

  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      broadcast_operand, {}, input_fusion, {}));
}
2704 
// Builds a while loop over an f32[8] value and checks that the while
// instruction may share the buffer of its init operand.
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
  module_ = CreateNewVerifiedModule();
  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pred_scalar_shape = ShapeUtil::MakeShape(PRED, {});

  // Scalar logical-and computation, used as the reducer in the condition.
  HloComputation::Builder and_builder(TestName() + ".And");
  HloInstruction* and_lhs = and_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar_shape, "p0"));
  HloInstruction* and_rhs = and_builder.AddInstruction(
      HloInstruction::CreateParameter(1, pred_scalar_shape, "p1"));
  and_builder.AddInstruction(HloInstruction::CreateBinary(
      pred_scalar_shape, HloOpcode::kAnd, and_lhs, and_rhs));
  HloComputation* and_computation =
      module_->AddEmbeddedComputation(and_builder.Build());

  // Loop condition: compare data with itself element-wise, then and-reduce
  // the result to a scalar predicate.
  HloComputation::Builder cond_builder(TestName() + ".Cond");
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* elementwise_eq =
      cond_builder.AddInstruction(HloInstruction::CreateCompare(
          ShapeUtil::MakeShape(PRED, {8}), cond_data, cond_data,
          ComparisonDirection::kEq));
  HloInstruction* reduce_init = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  cond_builder.AddInstruction(HloInstruction::CreateReduce(
      ShapeUtil::MakeShape(PRED, {}), elementwise_eq, reduce_init, {0},
      and_computation));
  HloComputation* cond_computation =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Loop body: data + data.
  HloComputation::Builder body_builder(TestName() + ".Body");
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  body_builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, body_data, body_data));
  HloComputation* body_computation =
      module_->AddEmbeddedComputation(body_builder.Build());

  // Entry computation: a single parameter fed straight into the while.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* init_data = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* while_instr =
      entry_builder.AddInstruction(HloInstruction::CreateWhile(
          data_shape, cond_computation, body_computation, init_data));
  computation_ = module_->AddEntryComputation(entry_builder.Build());

  RunAnalysis();

  // The while is expected to be able to share its init operand's buffer.
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
      init_data, {}, while_instr, {}));
}
2760 
2761 // Tests that Call can alias operand buffer if the only use of the operand
2762 // in the called computation is an elementwise instruction.
TEST_F(CanShareOperandBufferWithUserTest,CallToComputationWithFusionRoot)2763 TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
2764   Shape shape = ShapeUtil::MakeShape(F32, {8});
2765   // Build sub-computation with fusion root.
2766   auto sub_builder = HloComputation::Builder(TestName() + "_sub");
2767   auto sub_param = sub_builder.AddInstruction(
2768       HloInstruction::CreateParameter(0, shape, "sub_param"));
2769   auto one = sub_builder.AddInstruction(
2770       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2771   auto ones = sub_builder.AddInstruction(
2772       HloInstruction::CreateBroadcast(shape, one, {}));
2773   auto add = sub_builder.AddInstruction(
2774       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
2775 
2776   module_ = CreateNewVerifiedModule();
2777   auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
2778   sub_computation->CreateFusionInstruction({add, ones},
2779                                            HloInstruction::FusionKind::kLoop);
2780 
2781   // Build entry-computation with kCall which calls 'sub_computation'.
2782   auto builder = HloComputation::Builder(TestName());
2783 
2784   auto param = builder.AddInstruction(
2785       HloInstruction::CreateParameter(0, shape, "param"));
2786   auto reverse =
2787       builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
2788   auto call = builder.AddInstruction(
2789       HloInstruction::CreateCall(shape, {reverse}, sub_computation));
2790   computation_ = module_->AddEntryComputation(builder.Build());
2791 
2792   RunAnalysis();
2793 
2794   EXPECT_TRUE(
2795       dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {}));
2796 }
2797 
2798 }  // namespace
2799 }  // namespace xla
2800