1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
26 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/test.h"
36
37 namespace xla {
38 namespace {
39
40 using ::testing::ElementsAre;
41 using ::testing::UnorderedElementsAre;
42
43 // Test is parameterized on a bool which is whether the dataflow analysis is
44 // performed with SSA form.
45 class HloDataflowAnalysisTest : public HloTestBase,
46 public ::testing::WithParamInterface<bool> {
47 protected:
HloDataflowAnalysisTest()48 HloDataflowAnalysisTest() : module_(CreateNewVerifiedModule()) {}
49
50 // Run dataflow analysis on the member module. For convenience returns a
51 // reference to the generated analysis stored in analysis_.
RunAnalysis(bool ssa_form,bool bitcast_defines_value=false)52 const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
53 bool bitcast_defines_value = false) {
54 FlattenCallGraph flatten;
55 EXPECT_TRUE(flatten.Run(module_.get()).ok());
56 analysis_ =
57 HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value)
58 .ConsumeValueOrDie();
59 return *analysis_;
60 }
61
62 // Return a vector of the HloValues at the given program position.
HloValuesAt(const HloInstruction * instruction,const ShapeIndex & index={})63 std::vector<HloValue> HloValuesAt(const HloInstruction* instruction,
64 const ShapeIndex& index = {}) {
65 CHECK(analysis_ != nullptr);
66 std::vector<HloValue> values;
67 for (const HloValue* value :
68 analysis_->GetValueSet(instruction, index).values()) {
69 values.push_back(*value);
70 }
71 return values;
72 }
73
74 // Returns true if the top-level values for instructions 'a' and 'b' may
75 // interfere. Precondition: 'a' and 'b' define array-shaped values.
InstructionsMayInterfere(const HloOrdering & ordering,const HloInstruction * a,const HloInstruction * b)76 bool InstructionsMayInterfere(const HloOrdering& ordering,
77 const HloInstruction* a,
78 const HloInstruction* b) {
79 EXPECT_FALSE(a->shape().IsTuple());
80 EXPECT_FALSE(b->shape().IsTuple());
81 return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
82 analysis_->GetValueDefinedAt(b), *analysis_);
83 }
84
CreateR0F32UnaryOpComputation(HloOpcode opcode)85 std::unique_ptr<HloComputation> CreateR0F32UnaryOpComputation(
86 HloOpcode opcode) {
87 HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode));
88 HloInstruction* param0 = builder.AddInstruction(
89 HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
90 builder.AddInstruction(
91 HloInstruction::CreateUnary(scalar_shape_, opcode, param0));
92 return builder.Build();
93 }
94
95 std::unique_ptr<HloModule> module_;
96 std::unique_ptr<HloDataflowAnalysis> analysis_;
97
98 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
99 const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42});
100 const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
101 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
102 };
103
TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
  // Dataflow for the simplest interesting graph: Add(Constant, Constant).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* lhs = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* rhs = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* sum =
      entry_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, lhs, rhs));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // One value per instruction, each defined at its own instruction.
  EXPECT_EQ(analysis.values().size(), 3);
  EXPECT_TRUE(analysis.ValueIsDefinedAt(lhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(rhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sum));

  const HloValue& lhs_value = analysis.GetValueDefinedAt(lhs);
  const HloValue& rhs_value = analysis.GetValueDefinedAt(rhs);
  const HloValue& sum_value = analysis.GetValueDefinedAt(sum);

  // Nothing forwards a value here, so every value lives only at the
  // instruction which defines it.
  EXPECT_THAT(lhs_value.positions(),
              UnorderedElementsAre(HloPosition{lhs, {}}));
  EXPECT_THAT(rhs_value.positions(),
              UnorderedElementsAre(HloPosition{rhs, {}}));
  EXPECT_THAT(sum_value.positions(),
              UnorderedElementsAre(HloPosition{sum, {}}));

  // Each constant feeds exactly one operand of the add; the add has no users.
  EXPECT_THAT(lhs_value.uses(), UnorderedElementsAre(HloUse{sum, 0, {}}));
  EXPECT_THAT(rhs_value.uses(), UnorderedElementsAre(HloUse{sum, 1, {}}));
  EXPECT_TRUE(sum_value.uses().empty());

  // Only the add is live out of the module.
  EXPECT_FALSE(lhs_value.live_out_of_module());
  EXPECT_FALSE(rhs_value.live_out_of_module());
  EXPECT_TRUE(sum_value.live_out_of_module());
}
146
TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
  // Dataflow through a Tuple whose elements are read back with
  // GetTupleElement instructions and then added together.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* p0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* p1 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* pair =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({p0, p1}));
  HloInstruction* first = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, pair, 0));
  HloInstruction* second = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, pair, 1));
  HloInstruction* sum = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, first,
                                   second));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // Values are defined by the two parameters, the top level of the tuple, and
  // the add. Tuple elements and GTEs only forward existing values.
  EXPECT_EQ(analysis.values().size(), 4);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(p0));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(p1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pair, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{1}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(first));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(second));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sum));

  const HloValue& p0_value = analysis.GetValueDefinedAt(p0);
  const HloValue& p1_value = analysis.GetValueDefinedAt(p1);

  // A parameter value appears at the parameter, inside the tuple, and at the
  // GTE which extracts it.
  EXPECT_THAT(p0_value.positions(),
              UnorderedElementsAre(HloPosition{p0, {}}, HloPosition{pair, {0}},
                                   HloPosition{first, {}}));
  EXPECT_THAT(p1_value.positions(),
              UnorderedElementsAre(HloPosition{p1, {}}, HloPosition{pair, {1}},
                                   HloPosition{second, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(pair).positions(),
              UnorderedElementsAre(HloPosition{pair, {}}));

  // Uses: the parameter values are used by the add, whereas a GetTupleElement
  // counts as a use of just the top-level value of its tuple operand.
  EXPECT_THAT(p0_value.uses(), UnorderedElementsAre(HloUse{sum, 0, {}}));
  EXPECT_THAT(p1_value.uses(), UnorderedElementsAre(HloUse{sum, 1, {}}));
  EXPECT_THAT(
      analysis.GetValueDefinedAt(pair, /*index=*/{}).uses(),
      UnorderedElementsAre(HloUse{first, 0, {}}, HloUse{second, 0, {}}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(sum).live_out_of_module());
}
202
TEST_P(HloDataflowAnalysisTest, NestedTuple) {
  // Dataflow through a tuple nested inside another tuple, read back with two
  // chained GetTupleElement instructions.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* outer = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({inner, inner, c1}));
  HloInstruction* inner_gte = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner->shape(), outer, 1));
  HloInstruction* root_gte = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, inner_gte, 0));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 4);

  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);

  // c1 is visible everywhere it is forwarded to: both tuples and both GTEs.
  EXPECT_THAT(c1_value.positions(),
              UnorderedElementsAre(
                  HloPosition{c1, {}}, HloPosition{inner, {0}},
                  HloPosition{outer, {0, 0}}, HloPosition{outer, {1, 0}},
                  HloPosition{outer, {2}}, HloPosition{inner_gte, {0}},
                  HloPosition{root_gte, {}}));

  // The only use of c1 is by the root GTE, through element {0} of its tuple
  // operand. c2 has no uses at all.
  EXPECT_THAT(analysis.GetValueDefinedAt(c1, /*index=*/{}).uses(),
              UnorderedElementsAre(HloUse{root_gte, 0, {0}}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(c2).uses().empty());

  // The top-level values of both tuples are used by GTE instructions.
  EXPECT_THAT(analysis.GetValueDefinedAt(inner, /*index=*/{}).uses(),
              UnorderedElementsAre(HloUse{root_gte, 0, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(outer, /*index=*/{}).uses(),
              UnorderedElementsAre(HloUse{inner_gte, 0, {}}));

  // Only c1 reaches the module output.
  EXPECT_TRUE(c1_value.live_out_of_module());
  EXPECT_FALSE(analysis.GetValueDefinedAt(c2).live_out_of_module());
  EXPECT_FALSE(
      analysis.GetValueDefinedAt(inner, /*index=*/{}).live_out_of_module());
  EXPECT_FALSE(
      analysis.GetValueDefinedAt(outer, /*index=*/{}).live_out_of_module());
}
253
TEST_P(HloDataflowAnalysisTest, SingleCall) {
  // A single call of a subcomputation which adds its two array-shaped
  // parameters.
  HloComputation::Builder callee_builder("Subcomputation");
  HloInstruction* callee_p0 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* callee_p1 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* sum =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, callee_p0, callee_p1));
  HloComputation* callee =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* call_site = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, callee));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 3);

  // Only the constants and the add define values. The callee parameters and
  // the call instruction merely carry values which flow in from elsewhere.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c2));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p0));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sum));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(call_site));

  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);
  const HloValue& c2_value = analysis.GetValueDefinedAt(c2);
  const HloValue& sum_value = analysis.GetValueDefinedAt(sum);

  // Values seen at the forwarding instructions.
  EXPECT_EQ(analysis.GetUniqueValueAt(callee_p0), c1_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(callee_p1), c2_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(call_site), sum_value);

  // Each constant is used both by the call and by the add inside the callee.
  EXPECT_THAT(c1_value.uses(), UnorderedElementsAre(HloUse{call_site, 0, {}},
                                                    HloUse{sum, 0, {}}));
  EXPECT_THAT(c2_value.uses(), UnorderedElementsAre(HloUse{call_site, 1, {}},
                                                    HloUse{sum, 1, {}}));

  EXPECT_TRUE(sum_value.live_out_of_module());
}
304
TEST_P(HloDataflowAnalysisTest, NestedCalls) {
  // A module with two levels of calls. HLO is:
  //
  // F32[] inner_computation(F32[] %param0, F32[] %param1):
  //   %add = Add(%param0, %param1)
  //
  // F32[] outer_computation(F32[] %param0, F32[] %param1):
  //   ;; Note that parameters are interchanged in the call.
  //   %nested_call = Call(inner_computation, {%param1, %param0})
  //
  // F32[] entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %call = Call(outer_computation, {%constant1, %constant2})
  //
  HloComputation::Builder inner_b("InnerComputation");
  HloInstruction* inner_p0 = inner_b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* inner_p1 = inner_b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* sum = inner_b.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, inner_p0, inner_p1));
  HloComputation* inner = module_->AddEmbeddedComputation(inner_b.Build());

  HloComputation::Builder outer_b("OuterComputation");
  HloInstruction* outer_p0 = outer_b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* outer_p1 = outer_b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  // The operands are deliberately passed to the inner call in swapped order.
  HloInstruction* inner_call = outer_b.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {outer_p1, outer_p0}, inner));
  HloComputation* outer = module_->AddEmbeddedComputation(outer_b.Build());

  HloComputation::Builder entry_b(TestName());
  HloInstruction* c1 = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* outer_call = entry_b.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, outer));
  module_->AddEntryComputation(entry_b.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // Just three values are defined; the remaining instructions simply pass
  // their operand values through.
  EXPECT_EQ(analysis.values().size(), 3);

  // The operand numbers at the inner call and at the add reflect the swap
  // performed in the outer computation.
  EXPECT_THAT(analysis.GetValueDefinedAt(c1).uses(),
              UnorderedElementsAre(HloUse{outer_call, 0, {}},
                                   HloUse{inner_call, 1, {}},
                                   HloUse{sum, 1, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(c2).uses(),
              UnorderedElementsAre(HloUse{outer_call, 1, {}},
                                   HloUse{inner_call, 0, {}},
                                   HloUse{sum, 0, {}}));

  EXPECT_TRUE(analysis.GetValueDefinedAt(sum).live_out_of_module());
}
371
TEST_P(HloDataflowAnalysisTest, SingleWhile) {
  // A single while instruction whose body forwards one of its tuple elements
  // unchanged. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   return While(%tuple, body, condition)
  //
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: element 0 is forwarded untouched, element 1 becomes the sum.
  HloComputation::Builder body_b("body");
  HloInstruction* body_state = body_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* body_e0 = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 0));
  HloInstruction* body_e1 = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 1));
  HloInstruction* sum = body_b.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, body_e0, body_e1));
  HloInstruction* body_result =
      body_b.AddInstruction(HloInstruction::CreateTuple({body_e0, sum}));
  HloComputation* body = module_->AddEmbeddedComputation(body_b.Build());

  // Condition: ignores its parameter and yields the constant "false".
  HloComputation::Builder cond_b("condition");
  HloInstruction* cond_state = cond_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* cond_false = cond_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition = module_->AddEmbeddedComputation(cond_b.Build());

  HloComputation::Builder entry_b(TestName());
  HloInstruction* c1 = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init =
      entry_b.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* loop = entry_b.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, body, init));
  module_->AddEntryComputation(entry_b.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_FALSE(analysis.GetValueDefinedAt(cond_false).live_out_of_module());

  if (!ssa_form) {
    // Without SSA form neither the while instruction nor the subcomputation
    // parameters define any value.
    EXPECT_FALSE(analysis.ValueIsDefinedAt(loop, /*index=*/{0}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(loop, /*index=*/{1}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(body_state, /*index=*/{0}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(body_state, /*index=*/{1}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_state, /*index=*/{0}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_state, /*index=*/{1}));

    EXPECT_TRUE(analysis.GetValueDefinedAt(c1).live_out_of_module());
    EXPECT_TRUE(analysis.GetValueDefinedAt(sum).live_out_of_module());
    return;
  }

  // SSA form. Element 0 is forwarded by the body, so no phi value is defined
  // for it.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(loop, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(body_state, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_state, /*index=*/{0}));

  // Element 1 is rewritten by the body, so it is a phi value at the while and
  // at both subcomputation parameters.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(loop, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(loop, /*index=*/{1}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(body_state, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(body_state, /*index=*/{1}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_state, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(cond_state, /*index=*/{1}).is_phi());

  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);
  EXPECT_THAT(c1_value.uses(),
              UnorderedElementsAre(HloUse{sum, 0, {}},
                                   HloUse{body_result, 0, {}},
                                   HloUse{loop, 0, {0}}));

  // c1 flows through the body and out of the module, as does the phi at
  // element 1 of the while. The add itself is not live out.
  EXPECT_TRUE(c1_value.live_out_of_module());
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(loop, /*index=*/{1}).live_out_of_module());

  EXPECT_FALSE(analysis.GetValueDefinedAt(sum).live_out_of_module());
}
472
TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
  // Three while instructions chained one after another, all sharing a body
  // which forwards one of its tuple elements unchanged. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   %while0 = While(%tuple, body, condition)
  //   %while1 = While(%while0, body, condition)
  //   return While(%while1, body, condition)
  //
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: element 0 is forwarded untouched, element 1 becomes the sum.
  HloComputation::Builder body_b("body");
  HloInstruction* body_state = body_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* body_e0 = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 0));
  HloInstruction* body_e1 = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 1));
  HloInstruction* sum = body_b.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, body_e0, body_e1));
  body_b.AddInstruction(HloInstruction::CreateTuple({body_e0, sum}));
  HloComputation* body = module_->AddEmbeddedComputation(body_b.Build());

  // Condition: ignores its parameter and yields the constant "false".
  HloComputation::Builder cond_b("condition");
  cond_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  cond_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition = module_->AddEmbeddedComputation(cond_b.Build());

  HloComputation::Builder entry_b(TestName());
  HloInstruction* c1 = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init =
      entry_b.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* loop0 = entry_b.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, body, init));
  HloInstruction* loop1 = entry_b.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, body, loop0));
  HloInstruction* loop2 = entry_b.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, body, loop1));
  module_->AddEntryComputation(entry_b.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // Element 0 is passed through all the while instructions and out of the
  // module.
  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);
  EXPECT_EQ(analysis.GetUniqueValueAt(loop0, /*index=*/{0}), c1_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(loop1, /*index=*/{0}), c1_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(loop2, /*index=*/{0}), c1_value);
  EXPECT_TRUE(c1_value.live_out_of_module());
}
546
TEST_P(HloDataflowAnalysisTest, MultiLevelNestedWhile) {
  // Three levels of nested while instructions. The level0 body (innermost
  // while) and the level1 body pass the parameter through, while the level2
  // body (outermost while) modifies it.
  //
  // level0_body((F32[]) %tuple_param):
  //   return Tuple(%tuple_param{0})
  //
  // level1_body((F32[]) %tuple_param):
  //   return While(%tuple_param), body=level0
  //
  // level2_body((F32[]) %tuple_param):
  //   %while = While(%tuple_param), body=level1
  //   return Tuple(Negate(%while{0}))
  //
  // entry:
  //   %constant = Constant(1.0)
  //   %tuple = Tuple(%constant)
  //   return While(%tuple), body=level2
  //
  const Shape state_shape = ShapeUtil::MakeTupleShape({scalar_shape_});

  // The condition shared by every while: ignores its parameter and yields the
  // constant "false".
  HloComputation::Builder cond_b("condition");
  cond_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  cond_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition = module_->AddEmbeddedComputation(cond_b.Build());

  // Level 0: re-tuples element 0 of its parameter, i.e. passes it through.
  HloComputation::Builder l0_b("level0_body");
  HloInstruction* l0_param = l0_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* l0_e0 = l0_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, l0_param, 0));
  HloInstruction* l0_root =
      l0_b.AddInstruction(HloInstruction::CreateTuple({l0_e0}));
  HloComputation* l0_body = module_->AddEmbeddedComputation(l0_b.Build());

  // Level 1: feeds its parameter to a while whose body is level 0 and returns
  // the result of that while.
  HloComputation::Builder l1_b("level1_body");
  HloInstruction* l1_param = l1_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* l1_root = l1_b.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, l0_body, l1_param));
  HloComputation* l1_body = module_->AddEmbeddedComputation(l1_b.Build());

  // Level 2: runs a while whose body is level 1, then negates element 0 of the
  // result, so the element is modified at this level.
  HloComputation::Builder l2_b("level2_body");
  HloInstruction* l2_param = l2_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* l2_while = l2_b.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, l1_body, l2_param));
  HloInstruction* l2_e0 = l2_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, l2_while, 0));
  HloInstruction* negated = l2_b.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, l2_e0));
  l2_b.AddInstruction(HloInstruction::CreateTuple({negated}));
  HloComputation* l2_body = module_->AddEmbeddedComputation(l2_b.Build());

  HloComputation::Builder entry_b(TestName());
  HloInstruction* c1 = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* init =
      entry_b.AddInstruction(HloInstruction::CreateTuple({c1}));
  entry_b.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, l2_body, init));
  module_->AddEntryComputation(entry_b.Build());
  SCOPED_TRACE(module_->ToString());

  // The expectations below concern phi values, so only the SSA variant of the
  // test checks anything.
  const bool ssa_form = GetParam();
  if (!ssa_form) {
    return;
  }
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // The phis on the inner parameters and roots should have been eliminated;
  // only the level 2 parameter still defines a value.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(l1_param, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(l0_param, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(l1_root, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(l0_root, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(l2_param, /*index=*/{0}));

  // Every inner position sees exactly the values of the level 2 parameter.
  const std::vector<HloValue> l2_values = HloValuesAt(l2_param, /*index=*/{0});
  EXPECT_EQ(HloValuesAt(l1_param, /*index=*/{0}), l2_values);
  EXPECT_EQ(HloValuesAt(l0_param, /*index=*/{0}), l2_values);
  EXPECT_EQ(HloValuesAt(l1_root, /*index=*/{0}), l2_values);
  EXPECT_EQ(HloValuesAt(l0_root, /*index=*/{0}), l2_values);
}
640
TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
  // Test nested while instructions. The inner body passes through element 0 of
  // its parameter, and the outer body passes through element 1. HLO:
  //
  // inner_body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // outer_body((F32[], F32[]) %tuple_param):
  //   %negate = Negate(%tuple_param{0})
  //   %tuple = Tuple(%negate, %tuple_param{1})
  //   return While(%tuple, inner_body, condition)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   return While(%tuple, outer_body, condition)
  //
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // A single condition computation (root is Constant(false)) is shared by the
  // inner and the outer while.
  auto cond_builder = HloComputation::Builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Element 0 passes transparently through the body.
  auto inner_builder = HloComputation::Builder("inner_body");
  auto inner_param = inner_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto inner_element_0 = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0));
  auto inner_element_1 = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1));
  auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1));
  inner_builder.AddInstruction(
      HloInstruction::CreateTuple({inner_element_0, add}));
  HloComputation* inner_body =
      module_->AddEmbeddedComputation(inner_builder.Build());

  // Element 1 passes transparently through the body.
  auto outer_builder = HloComputation::Builder("outer_body");
  auto outer_param = outer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto outer_element_0 = outer_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
  auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kNegate, outer_element_0));
  auto outer_element_1 = outer_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
  auto outer_tuple = outer_builder.AddInstruction(
      HloInstruction::CreateTuple({negate, outer_element_1}));
  auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, condition, inner_body, outer_tuple));
  HloComputation* outer_body =
      module_->AddEmbeddedComputation(outer_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  auto entry_while = builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // In both modes, element 0 of the inner body parameter is only ever
  // %negate.
  EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
  if (ssa_form) {
    // Element 1 of the inner body parameter is a phi value.
    EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());

    // Element 0 of the nested while is %negate. The while defines no value
    // of its own at that index, so check the value set at the while itself
    // (mirroring the non-SSA expectation below).
    EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
    EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
    // Element 1 is a phi value (join of %add and %constant2).
    EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());

    // Both elements of the entry while are phi values.
    EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{0}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(entry_while, /*index=*/{0}).is_phi());

    EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{1}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
  } else {
    // Without SSA form no phi values are expected; each position lists every
    // value which may reach it.
    EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
                                     analysis.GetValueDefinedAt(constant2)));

    EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
    EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{1}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
                                     analysis.GetValueDefinedAt(constant2)));

    EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{0}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(negate),
                                     analysis.GetValueDefinedAt(constant1)));
    EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{1}),
                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
                                     analysis.GetValueDefinedAt(constant2)));
  }
}
760
TEST_P(HloDataflowAnalysisTest, SwizzlingWhileSharedInput) {
  // Exercises a while instruction whose body permutes its tuple parameter
  // elements, where both elements of the initial tuple are the same constant.
  // HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   return Tuple(%tuple_param{1}, %tuple_param{0})
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %tuple = Tuple(%constant1, %constant1)
  //   return While(%tuple, body, condition)
  //
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: returns the loop state with its two elements swapped.
  auto swap_builder = HloComputation::Builder("body");
  HloInstruction* body_param = swap_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* first_element = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  HloInstruction* second_element = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  swap_builder.AddInstruction(
      HloInstruction::CreateTuple({second_element, first_element}));
  HloComputation* body = module_->AddEmbeddedComputation(swap_builder.Build());

  // Loop condition: root is Constant(false).
  auto predicate_builder = HloComputation::Builder("condition");
  predicate_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  predicate_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(predicate_builder.Build());

  // Entry computation: both slots of the initial tuple share %constant1.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* shared_constant = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* init_tuple = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({shared_constant, shared_constant}));
  entry_builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape, condition, body, init_tuple));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // With or without SSA form, no value should be defined at element 0 of the
  // body parameter.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
}
812
TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
  // Exercises a while instruction whose body permutes its tuple parameter
  // elements. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   return Tuple(%tuple_param{1}, %tuple_param{0})
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   return While(%tuple, body, condition)
  //
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: returns the loop state with its two elements swapped.
  auto swap_builder = HloComputation::Builder("body");
  HloInstruction* body_param = swap_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* first_element = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  HloInstruction* second_element = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  swap_builder.AddInstruction(
      HloInstruction::CreateTuple({second_element, first_element}));
  HloComputation* body = module_->AddEmbeddedComputation(swap_builder.Build());

  // Loop condition: root is Constant(false).
  auto predicate_builder = HloComputation::Builder("condition");
  HloInstruction* cond_param = predicate_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  predicate_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(predicate_builder.Build());

  // Entry computation: the loop starts from two distinct constants.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* constant1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init_tuple = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* xla_while = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body,
                                  init_tuple));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  const auto& value1 = analysis.GetValueDefinedAt(constant1);
  const auto& value2 = analysis.GetValueDefinedAt(constant2);

  if (!ssa_form) {
    // Elements 0 and 1 have both constants as reaching definitions, and both
    // constants are live out of the module.
    EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
                UnorderedElementsAre(value1, value2));
    EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}),
                UnorderedElementsAre(value1, value2));
    EXPECT_TRUE(value1.live_out_of_module());
    EXPECT_TRUE(value2.live_out_of_module());
    return;
  }

  // SSA form: elements 0 and 1 should be phi values at the body parameter...
  EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());

  // ... at the while instruction itself ...
  EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());

  // ... and at the condition parameter.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());

  // The values defined by the while are live out; the constants are not.
  EXPECT_FALSE(value1.live_out_of_module());
  EXPECT_FALSE(value2.live_out_of_module());
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(xla_while, /*index=*/{}).live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
                  .live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
                  .live_out_of_module());
}
903
TEST_P(HloDataflowAnalysisTest, ArraySelect) {
  // A kSelect over array (scalar) operands: the select defines its own value
  // and that value, rather than either operand, is live out of the module.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* predicate = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* on_true = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* on_false = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* select =
      entry_builder.AddInstruction(HloInstruction::CreateTernary(
          scalar_shape_, HloOpcode::kSelect, predicate, on_true, on_false));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(select));
  EXPECT_FALSE(analysis.GetValueDefinedAt(on_true).live_out_of_module());
  EXPECT_FALSE(analysis.GetValueDefinedAt(on_false).live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(select).live_out_of_module());
}
927
TEST_P(HloDataflowAnalysisTest, TupleSelect) {
  // Exercises kTupleSelect: only the top-level value is defined by the
  // instruction, while non-top-level elements flow through it.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));

  // Small helpers which append instructions to the entry computation.
  const auto add_constant = [&builder](float value) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };
  const auto add_singleton_tuple = [&builder](HloInstruction* element) {
    return builder.AddInstruction(HloInstruction::CreateTuple({element}));
  };

  HloInstruction* constant1 = add_constant(1.0f);
  HloInstruction* constant2 = add_constant(2.0f);
  HloInstruction* constant3 = add_constant(3.0f);
  HloInstruction* constant4 = add_constant(4.0f);
  HloInstruction* tuple1 = add_singleton_tuple(constant1);
  HloInstruction* tuple2 = add_singleton_tuple(constant2);
  HloInstruction* tuple3 = add_singleton_tuple(constant3);
  HloInstruction* tuple4 = add_singleton_tuple(constant4);

  const Shape tuple_shape = tuple1->shape();
  const auto add_tuple_select = [&builder, &tuple_shape, pred](
                                    HloInstruction* on_true,
                                    HloInstruction* on_false) {
    return builder.AddInstruction(HloInstruction::CreateTernary(
        tuple_shape, HloOpcode::kTupleSelect, pred, on_true, on_false));
  };
  HloInstruction* select11 = add_tuple_select(tuple1, tuple1);
  HloInstruction* select12 = add_tuple_select(tuple1, tuple2);
  HloInstruction* select34 = add_tuple_select(tuple3, tuple4);
  // Added last, so this select of selects is the root of the computation.
  HloInstruction* select1234 = add_tuple_select(select12, select34);

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  const auto& value1 = analysis.GetValueDefinedAt(constant1);
  const auto& value2 = analysis.GetValueDefinedAt(constant2);
  const auto& value3 = analysis.GetValueDefinedAt(constant3);
  const auto& value4 = analysis.GetValueDefinedAt(constant4);

  // Top-level value is always defined by a kTupleSelect.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(select11));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(select12));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(select34));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(select1234));

  // Element 0 is never defined by the select itself...
  EXPECT_FALSE(analysis.ValueIsDefinedAt(select11, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(select12, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(select34, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(select1234, /*index=*/{0}));

  // ... instead it holds the values of element 0 of both selected operands.
  EXPECT_THAT(HloValuesAt(select11, /*index=*/{0}),
              UnorderedElementsAre(value1));
  EXPECT_THAT(HloValuesAt(select12, /*index=*/{0}),
              UnorderedElementsAre(value1, value2));
  EXPECT_THAT(HloValuesAt(select34, /*index=*/{0}),
              UnorderedElementsAre(value3, value4));
  EXPECT_THAT(HloValuesAt(select1234, /*index=*/{0}),
              UnorderedElementsAre(value1, value2, value3, value4));

  // The top-level value of %tuple1 is used as operand 1 and operand 2 of
  // select11 and as operand 1 of select12.
  EXPECT_THAT(
      analysis.GetValueDefinedAt(tuple1, /*index=*/{}).uses(),
      UnorderedElementsAre(HloUse{select11, 1, {}}, HloUse{select11, 2, {}},
                           HloUse{select12, 1, {}}));

  // The two constant values just pass through the selects and are not used
  // except at the root. They are live out however.
  EXPECT_THAT(value1.uses(), UnorderedElementsAre(HloUse{select1234, 1, {0}}));
  EXPECT_THAT(value2.uses(), UnorderedElementsAre(HloUse{select1234, 1, {0}}));
  EXPECT_TRUE(value1.live_out_of_module());
  EXPECT_TRUE(value2.live_out_of_module());
}
1004
TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
  // Exercises kTupleSelect on nested tuples of the form (scalar, (scalar,
  // scalar)). %constant3 appears at index {1, 1} of both operands.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));

  const auto add_constant = [&builder](float value) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };
  HloInstruction* constant1 = add_constant(1.0f);
  HloInstruction* constant2 = add_constant(2.0f);
  HloInstruction* constant3 = add_constant(3.0f);
  HloInstruction* constant4 = add_constant(4.0f);
  HloInstruction* constant5 = add_constant(5.0f);

  // tuple1 = (constant1, (constant2, constant3))
  HloInstruction* inner_tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant2, constant3}));
  HloInstruction* tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, inner_tuple1}));
  // tuple2 = (constant4, (constant5, constant3))
  HloInstruction* inner_tuple2 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant5, constant3}));
  HloInstruction* tuple2 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant4, inner_tuple2}));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(select));

  // Every non-top-level index holds the values found at the same index in the
  // two operands.
  EXPECT_THAT(HloValuesAt(select, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
                                   analysis.GetValueDefinedAt(constant4)));
  EXPECT_THAT(HloValuesAt(select, /*index=*/{1}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(inner_tuple1),
                                   analysis.GetValueDefinedAt(inner_tuple2)));
  EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(constant2),
                                   analysis.GetValueDefinedAt(constant5)));
  // Both operands carry %constant3 here, so it is the only value.
  EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 1}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(constant3)));
}
1051
TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
  // Exercises a tuple-shaped kTupleSelect whose element feeds a kWhile
  // instruction. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %pred = Constant(false)
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %constant3 = Constant(3.0)
  //   %tuple1 = Tuple(%constant1)
  //   %tuple2 = Tuple(%constant2)
  //   %select = TupleSelect(%pred, %tuple1, %tuple2)
  //   %gte = GetTupleElement(%select, 0)
  //   %tuple = Tuple(%gte, %constant3)
  //   return While(%tuple, body, condition)
  //
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: element 0 passes transparently through, element 1 becomes
  // %add.
  auto loop_body_builder = HloComputation::Builder("body");
  HloInstruction* body_param = loop_body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* state_0 = loop_body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  HloInstruction* state_1 = loop_body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  HloInstruction* add =
      loop_body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, state_0, state_1));
  loop_body_builder.AddInstruction(HloInstruction::CreateTuple({state_0, add}));
  HloComputation* body =
      module_->AddEmbeddedComputation(loop_body_builder.Build());

  // Loop condition: root is Constant(false).
  auto predicate_builder = HloComputation::Builder("condition");
  predicate_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  predicate_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(predicate_builder.Build());

  // Entry computation.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  const auto add_constant = [&builder](float value) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };
  HloInstruction* constant1 = add_constant(1.0f);
  HloInstruction* constant2 = add_constant(2.0f);
  HloInstruction* constant3 = add_constant(3.0f);
  HloInstruction* tuple1 =
      builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
  HloInstruction* tuple2 =
      builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, select, 0));
  HloInstruction* init_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({gte, constant3}));
  HloInstruction* xla_while =
      builder.AddInstruction(HloInstruction::CreateWhile(
          init_tuple->shape(), condition, body, init_tuple));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  const auto& value1 = analysis.GetValueDefinedAt(constant1);
  const auto& value2 = analysis.GetValueDefinedAt(constant2);
  const auto& value3 = analysis.GetValueDefinedAt(constant3);

  if (!ssa_form) {
    // Without SSA form, the selected constants reach %gte and element 0 of
    // the while; element 1 is reached by %add and %constant3. All three
    // constants are live out.
    EXPECT_THAT(HloValuesAt(gte), UnorderedElementsAre(value1, value2));
    EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
                UnorderedElementsAre(value1, value2));
    EXPECT_THAT(
        HloValuesAt(xla_while, /*index=*/{1}),
        UnorderedElementsAre(analysis.GetValueDefinedAt(add), value3));
    EXPECT_TRUE(value1.live_out_of_module());
    EXPECT_TRUE(value2.live_out_of_module());
    EXPECT_TRUE(value3.live_out_of_module());
    return;
  }

  // SSA form: both elements of the while are phi values.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());

  // The select does not define a value for its element.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(select, /*index=*/{0}));

  // None of the constants is live out; the while's element 1 value is.
  EXPECT_FALSE(value1.live_out_of_module());
  EXPECT_FALSE(value2.live_out_of_module());
  EXPECT_FALSE(value3.live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
                  .live_out_of_module());
}
1155
TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) {
  // Exercises the bitcast_defines_value flag of the dataflow analysis on a
  // constant feeding a bitcast.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* constant = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* bitcast = entry_builder.AddInstruction(
      HloInstruction::CreateBitcast(scalar_shape_, constant));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();

  // Each run is confined to its own scope so that its analysis reference is
  // not used after the next RunAnalysis call.
  {
    // Flag on: the bitcast defines a second value, which is the live-out one.
    const HloDataflowAnalysis& with_flag =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/true);

    EXPECT_EQ(with_flag.values().size(), 2);

    EXPECT_TRUE(with_flag.ValueIsDefinedAt(constant));
    EXPECT_TRUE(with_flag.ValueIsDefinedAt(bitcast));
    EXPECT_FALSE(with_flag.GetValueDefinedAt(constant).live_out_of_module());
    EXPECT_TRUE(with_flag.GetValueDefinedAt(bitcast).live_out_of_module());
  }
  {
    // Flag off: only the constant defines a value, and it is live out.
    const HloDataflowAnalysis& without_flag =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/false);

    EXPECT_EQ(without_flag.values().size(), 1);

    EXPECT_TRUE(without_flag.ValueIsDefinedAt(constant));
    EXPECT_FALSE(without_flag.ValueIsDefinedAt(bitcast));
    EXPECT_TRUE(without_flag.GetValueDefinedAt(constant).live_out_of_module());
  }
}
1189
TEST_P(HloDataflowAnalysisTest, TupleCopy) {
  // A tuple-shaped copy should only copy (define) the top-level value; the
  // tuple elements flow through from the operand.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* param0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* param1 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* tuple = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({param0, param1}));
  HloInstruction* copy = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Four values: the two parameters plus the top-level values of the tuple
  // and of the copy.
  EXPECT_EQ(analysis.values().size(), 4);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(param0));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(param1));

  // The tuple defines only its top-level value.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1}));

  // The same holds for the copy.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(copy, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{1}));

  // The elements of the copy are the parameter values themselves.
  EXPECT_THAT(HloValuesAt(copy, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(param0)));
  EXPECT_THAT(HloValuesAt(copy, /*index=*/{1}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(param1)));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module());
}
1225
TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) {
  // A CopyDone should forward its operand tuple element at {0} to the
  // output.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* constant = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // The copy-start result is a 3-tuple: two elements with the constant's
  // shape followed by a U32 scalar.
  const Shape copy_start_shape = ShapeUtil::MakeTupleShape(
      {constant->shape(), constant->shape(), ShapeUtil::MakeShape(U32, {})});
  HloInstruction* copy_start = entry_builder.AddInstruction(
      HloInstruction::CreateCopyStart(copy_start_shape, constant));
  HloInstruction* copy_done =
      entry_builder.AddInstruction(HloInstruction::CreateUnary(
          constant->shape(), HloOpcode::kCopyDone, copy_start));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 4);

  // copy-start defines its top-level value and elements {0} and {2}, but not
  // element {1}.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{2}));

  // copy-done defines nothing; it carries copy-start's element {0} value,
  // which is therefore live out of the module.
  const auto& copied_value = analysis.GetValueDefinedAt(copy_start, {0});
  EXPECT_FALSE(analysis.ValueIsDefinedAt(copy_done, /*index=*/{}));
  EXPECT_THAT(HloValuesAt(copy_done, /*index=*/{}),
              UnorderedElementsAre(copied_value));
  EXPECT_TRUE(copied_value.live_out_of_module());
}
1257
TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
  // A Send should forward its operand to the output tuple at {0}.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* token =
      entry_builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send = entry_builder.AddInstruction(
      HloInstruction::CreateSend(param, token, /*channel_id=*/0));
  HloInstruction* send_done =
      entry_builder.AddInstruction(HloInstruction::CreateSendDone(send));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 6);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(param));

  // The send defines every position of its output except element {0}.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{2}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));

  // Element {0} of the send holds the parameter's value.
  EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(param)));
}
1284
TEST_P(HloDataflowAnalysisTest, SetDimensionSizeForwardsValue) {
  // A SetDimensionSize should not define a value of its own; the value of its
  // data operand is what ends up live out of the module.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* dimension_size = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(3)));
  HloInstruction* set_dimension_size =
      entry_builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          vector_shape_, param, dimension_size, 0));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Only two values are expected in the whole module.
  EXPECT_EQ(analysis.values().size(), 2);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(set_dimension_size));
  EXPECT_TRUE(analysis.GetValueDefinedAt(param).live_out_of_module());
}
1307
TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
  // A RecvDone should forward its operand tuple element at {0} to element {0}
  // of the output.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* token =
      entry_builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv = entry_builder.AddInstruction(
      HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0));
  HloInstruction* recv_done =
      entry_builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 7);

  // The recv defines its top-level value and all three of its elements.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{2}));

  // The recv-done defines its top-level value and element {1}, but not
  // element {0}.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{1}));

  // Element {0} of the recv-done is the recv's element {0} value, which is
  // live out of the module.
  const auto& received_value = analysis.GetValueDefinedAt(recv, {0});
  EXPECT_THAT(HloValuesAt(recv_done, /*index=*/{0}),
              UnorderedElementsAre(received_value));
  EXPECT_TRUE(received_value.live_out_of_module());
}
1336
TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
  // A simple chain of elementwise operations, checked under a dependency
  // ordering. No two distinct values should interfere.
  //
  // param --> negate -> exp -> log
  //
  auto chain_builder = HloComputation::Builder(TestName());
  HloInstruction* param = chain_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp_instr = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate));
  HloInstruction* log_instr = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp_instr));

  module_->AddEntryComputation(chain_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // The parameter does not interfere with anything downstream of it.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp_instr));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, param, log_instr));

  // Nor do the elementwise results interfere with each other, whichever
  // argument order is used.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, exp_instr));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, log_instr));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp_instr, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp_instr, log_instr));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, log_instr, negate));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, log_instr, exp_instr));

  // A value should interfere with itself.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp_instr, exp_instr));
}
1372
TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
  // Two entry parameters feeding independent unary ops that are then added:
  //
  //   param0 --> negate ---------------\
  //                param1 --> exp --> add
  //
  // The computation is checked under the explicit sequential schedule
  // {param0, negate, param1, exp, add}.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* param0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param0"));
  HloInstruction* param1 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, vector_shape_, "param1"));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param0));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param1));
  HloInstruction* add = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape_, HloOpcode::kAdd, negate,
                                   exp));

  HloComputation* entry = module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param0, negate, param1, exp, add});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  // Shorthand for querying interference under `ordering`.
  auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(ordering, a, b);
  };

  // The two entry parameters interfere with each other, as though both were
  // defined at the very start, even though param1 is scheduled after negate.
  EXPECT_TRUE(interferes(param0, param1));
  EXPECT_FALSE(interferes(param0, negate));
  EXPECT_FALSE(interferes(param0, exp));
  EXPECT_FALSE(interferes(param0, add));
  EXPECT_TRUE(interferes(param1, param0));
  EXPECT_TRUE(interferes(param1, negate));
  EXPECT_FALSE(interferes(param1, exp));
  EXPECT_FALSE(interferes(param1, add));

  // negate and exp interfere with each other.
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));

  // Neither operand of the add interferes with the add itself.
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
1420
TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
  // Counterpart of MultipleEntryParameters_Sequential in which the parameter
  // belongs to a while body. The body is scheduled in this order:
  //
  //   %constant = Constant(...)
  //   %exp = Exp(%constant)
  //   %param = Param(0)
  //   %add = Add(%exp, %param)  ;; Root of body
  //   %dead_constant = Constant(...)
  //   %dead_negate = Negate(%dead_constant)
  //
  // %constant and its single use %exp come before the parameter in the
  // schedule, yet they are expected to interfere with the parameter value
  // because the parameter is treated as live into the body.
  //
  // Likewise %dead_constant and %dead_negate come after the body root %add,
  // yet they are expected to interfere with %add because %add is live out of
  // the body computation.
  const bool ssa_form = GetParam();

  // While body; 'add' is chosen explicitly as the root.
  auto body_builder = HloComputation::Builder(TestName());
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "body_param"));
  HloInstruction* constant = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* exp = body_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
  HloInstruction* add = body_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, exp,
                                   body_param));
  HloInstruction* dead_constant = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* dead_negate = body_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate,
                                  dead_constant));
  HloComputation* body = module_->AddEmbeddedComputation(
      body_builder.Build(/*root_instruction=*/add));

  // While condition: a parameter plus a constant 'false'.
  auto cond_builder = HloComputation::Builder("condition");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "cond_param"));
  HloInstruction* cond_constant = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Entry computation: feed a parameter into the while.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  HloInstruction* xla_while = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape_, condition, body, param));

  HloComputation* entry = module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(ssa_form);

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param, xla_while});
  schedule.set_sequence(condition, {cond_param, cond_constant});
  // Schedule 'constant' and its use 'exp' ahead of body_param on purpose.
  schedule.set_sequence(
      body, {constant, exp, body_param, add, dead_constant, dead_negate});
  TF_ASSERT_OK(schedule.Verify());

  SequentialHloOrdering ordering(schedule);

  // Shorthand for querying interference under `ordering`.
  auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(ordering, a, b);
  };

  // 'add' is live out of the body, so it interferes with the instructions
  // scheduled after it ('dead_constant' and 'dead_negate').
  EXPECT_TRUE(interferes(add, dead_constant));
  EXPECT_TRUE(interferes(add, dead_negate));

  // Everything below concerns phi values defined by the body and condition
  // parameters, which exist only when the analysis runs in SSA form.
  if (ssa_form) {
    // The schedule alone would suggest 'constant' and the body parameter do
    // not overlap, but the parameter is live in and therefore interferes with
    // instructions scheduled before it (e.g. 'constant').
    EXPECT_TRUE(interferes(body_param, constant));
    EXPECT_TRUE(interferes(body_param, exp));
    EXPECT_FALSE(interferes(body_param, add));

    // These values all end up in the same buffer, so none of them should
    // interfere with another:
    //   (1) the init value: 'param'
    //   (2) the body parameter: 'body_param'
    //   (3) the condition parameter: 'cond_param'
    //   (4) the root value of the while body: 'add'
    //   (5) the while value: 'xla_while'
    EXPECT_FALSE(interferes(param, body_param));
    EXPECT_FALSE(interferes(param, cond_param));
    EXPECT_FALSE(interferes(param, add));
    EXPECT_FALSE(interferes(param, xla_while));

    EXPECT_FALSE(interferes(body_param, cond_param));
    EXPECT_FALSE(interferes(body_param, add));
    EXPECT_FALSE(interferes(body_param, xla_while));

    EXPECT_FALSE(interferes(cond_param, add));
    EXPECT_FALSE(interferes(cond_param, xla_while));

    EXPECT_FALSE(interferes(add, xla_while));
  }
}
1522
TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) {
  // A chain with two elementwise ops followed by one non-elementwise op:
  //
  //   param --> exp -> negate -> reverse
  //
  // The elementwise negate is expected not to interfere with its operand,
  // whereas the non-elementwise reverse is expected to interfere with its
  // operand. The parameter is expected not to interfere with any of the ops.
  auto chain_builder = HloComputation::Builder(TestName());
  HloInstruction* param = chain_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* exp = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* negate = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp));
  HloInstruction* reverse = chain_builder.AddInstruction(
      HloInstruction::CreateReverse(vector_shape_, negate, {0}));

  module_->AddEntryComputation(chain_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());
  // Shorthand for querying interference under `ordering`.
  auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(ordering, a, b);
  };

  EXPECT_FALSE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, negate));
  EXPECT_FALSE(interferes(param, reverse));

  // Elementwise consumer (negate): no interference with its operand.
  EXPECT_FALSE(interferes(exp, negate));
  // Non-elementwise consumer (reverse): interferes with its operand.
  EXPECT_TRUE(interferes(negate, reverse));
}
1555
TEST_P(HloDataflowAnalysisTest, OverlappedValues) {
  // Values that are live at the same time (negate and exp) must be reported
  // as interfering:
  //
  //   param --> negate -> add
  //     \---> exp -----/
  //
  auto diamond_builder = HloComputation::Builder(TestName());
  HloInstruction* param = diamond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* add = diamond_builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape_, HloOpcode::kAdd, negate,
                                   exp));

  module_->AddEntryComputation(diamond_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());
  // Shorthand for querying interference under `ordering`.
  auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(ordering, a, b);
  };

  // Under the dependency ordering the parameter interferes with both of its
  // users, but not with the final add.
  EXPECT_TRUE(interferes(param, negate));
  EXPECT_TRUE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, add));

  // negate and exp interfere with one another; neither interferes with add.
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
1590
TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
  // Same graph as the OverlappedValues test, but checked against an explicit
  // sequential schedule instead of a dependency ordering.
  //
  //   param --> negate -> add
  //     \---> exp -----/
  //
  // Sequential order:
  //   param, negate, exp, add
  //
  // With this schedule the parameter is expected to interfere with negate but
  // not with exp; the remaining expectations match OverlappedValues.
  auto diamond_builder = HloComputation::Builder(TestName());
  HloInstruction* param = diamond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* add = diamond_builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape_, HloOpcode::kAdd, negate,
                                   exp));

  HloComputation* entry =
      module_->AddEntryComputation(diamond_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param, negate, exp, add});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  // Shorthand for querying interference under `ordering`.
  auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(ordering, a, b);
  };

  EXPECT_TRUE(interferes(param, negate));
  EXPECT_FALSE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, add));

  // negate and exp interfere with one another; neither interferes with add.
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
1633
TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
  // Exercises interference between values that live in different
  // computations, using a call into an embedded computation.
  //
  // embedded_computation:
  //   %embedded_param = Param(0)
  //   %embedded_log = Log(%embedded_param)
  //
  // entry computation:
  //   %param = Param(0)
  //   %negate = Negate(%param)
  //   %exp = Exp(%param)
  //   %call = Call(embedded_computation, {%exp})
  //   %add = Add(%negate, %call)
  //
  // %negate is consumed only after the call, so it is live across the call
  // and is expected to interfere with values inside the embedded computation.
  auto callee_builder = HloComputation::Builder(TestName() + "_embedded");
  HloInstruction* embedded_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "embedded_param"));
  HloInstruction* embedded_log = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog,
                                  embedded_param));
  HloComputation* embedded_computation =
      module_->AddEmbeddedComputation(callee_builder.Build());

  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(vector_shape_, {exp}, embedded_computation));
  entry_builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, call));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // The call is the only use of exp, so exp should not interfere with values
  // inside the embedded computation.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log));

  // negate stays live across the call, so it should interfere with values
  // inside the embedded computation.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
}
1685
TEST_P(HloDataflowAnalysisTest, GetFlattenedValueSet) {
  // Parses a module whose root tuple holds two separate copies of the same
  // parameter, then checks the size of the flattened value set reported for
  // that root tuple.
  const char* hlo_text = R"(
  HloModule test_aliasing_module

  ENTRY root {
    param = s32[1000] parameter(0)
    p0 = s32[1000] copy(param)
    p1 = s32[1000] copy(param)
    ROOT t = (s32[1000], s32[1000]) tuple(p0, p1)
  })";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));

  // Look the root tuple up once and fail fast if the name does not resolve,
  // rather than handing a null instruction to the analysis.
  auto* root_tuple = module_->entry_computation()->GetInstructionWithName("t");
  ASSERT_NE(root_tuple, nullptr);

  auto& dataflow_analysis = RunAnalysis(GetParam());
  auto set = dataflow_analysis.GetFlattenedValueSet(root_tuple);
  // Three values are expected; presumably the tuple's own value plus the two
  // copies held in its elements.
  EXPECT_EQ(set.values().size(), 3);
}
1704
TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
  // Conditional whose true and false computations both simply return their
  // parameter.
  //
  // true_computation(F32[] %true_param):
  //   return %true_param
  //
  // false_computation(F32[] %false_param):
  //   return %false_param
  //
  // entry:
  //   %pred = Constant(true)
  //   %constant1 = Constant(56.0)
  //   %constant2 = Constant(12.0)
  //   return Conditional(%pred, %constant1, true_computation,
  //                      %constant2, false_computation)
  const bool ssa_form = GetParam();

  auto true_builder = HloComputation::Builder(TestName() + "_true");
  HloInstruction* true_param = true_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "true_param"));
  HloComputation* true_computation =
      module_->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_false");
  HloInstruction* false_param = false_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "false_param"));
  HloComputation* false_computation =
      module_->AddEmbeddedComputation(false_builder.Build());

  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* pred = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* constant1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
  HloInstruction* constant2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
  HloInstruction* conditional =
      entry_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, pred, constant1, true_computation, constant2,
          false_computation));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // The three entry constants define values.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));

  // The branch parameters define nothing themselves; each one sees exactly
  // the constant passed in for its branch.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));

  const auto& constant1_value = analysis.GetValueDefinedAt(constant1);
  const auto& constant2_value = analysis.GetValueDefinedAt(constant2);
  EXPECT_EQ(analysis.GetUniqueValueAt(true_param), constant1_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(false_param), constant2_value);

  // Each constant is used once, by the matching operand of the conditional.
  EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
              ElementsAre(HloUse{conditional, 0, {}}));
  EXPECT_THAT(constant1_value.uses(),
              ElementsAre(HloUse{conditional, 1, {}}));
  EXPECT_THAT(constant2_value.uses(),
              ElementsAre(HloUse{conditional, 2, {}}));

  if (!ssa_form) {
    // Without SSA the conditional defines nothing and may hold either
    // branch's constant.
    EXPECT_EQ(analysis.values().size(), 3);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
    EXPECT_THAT(HloValuesAt(conditional),
                UnorderedElementsAre(constant1_value, constant2_value));
  } else {
    // In SSA form the conditional defines one additional value.
    EXPECT_EQ(analysis.values().size(), 4);
    EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
  }
}
1779
TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
  // Conditional whose true and false computations both receive the same
  // tuple operand.
  //
  // true_computation((F32[], F32[]) %true_param):
  //   %true_x = GetTupleElement(%true_param, 0)
  //   %true_y = GetTupleElement(%true_param, 1)
  //   return Add(%true_x, %true_y)
  //
  // false_computation((F32[], F32[]) %false_param):
  //   %false_x = GetTupleElement(%false_param, 0)
  //   %false_y = GetTupleElement(%false_param, 1)
  //   return Subtract(%false_x, %false_y)
  //
  // entry:
  //   %pred = Constant(true)
  //   %constant1 = Constant(56.0)
  //   %constant2 = Constant(12.0)
  //   %tuple_operand = Tuple(%constant1, %constant2)
  //   return Conditional(%pred, %tuple_operand, true_computation,
  //                      %tuple_operand, false_computation)
  const bool ssa_form = GetParam();

  auto true_builder = HloComputation::Builder(TestName() + "_true");
  HloInstruction* true_param = true_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "true_param"));
  HloInstruction* true_x = true_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0));
  HloInstruction* true_y = true_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1));
  HloInstruction* add = true_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, true_x,
                                   true_y));
  HloComputation* true_computation =
      module_->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_false");
  HloInstruction* false_param = false_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "false_param"));
  HloInstruction* false_x = false_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0));
  HloInstruction* false_y = false_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1));
  HloInstruction* sub = false_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kSubtract,
                                   false_x, false_y));
  HloComputation* false_computation =
      module_->AddEmbeddedComputation(false_builder.Build());

  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* pred = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* constant1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
  HloInstruction* constant2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
  HloInstruction* tuple_operand = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* conditional =
      entry_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, pred, tuple_operand, true_computation, tuple_operand,
          false_computation));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Instructions that define values: the entry constants, the tuple, and the
  // root of each branch.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));

  // Branch parameters and their get-tuple-elements define nothing.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y));

  const auto& constant1_value = analysis.GetValueDefinedAt(constant1);
  const auto& constant2_value = analysis.GetValueDefinedAt(constant2);
  const auto& tuple_value = analysis.GetValueDefinedAt(tuple_operand);

  // Inside the branches, the parameters resolve to the entry tuple and the
  // get-tuple-elements resolve to the entry constants.
  EXPECT_EQ(analysis.GetUniqueValueAt(true_param), tuple_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(false_param), tuple_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(true_x), constant1_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(true_y), constant2_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(false_x), constant1_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(false_y), constant2_value);

  // Uses are recorded both at the conditional's operands and at the
  // instructions inside the branches that consume the values.
  EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
              ElementsAre(HloUse{conditional, 0, {}}));
  EXPECT_THAT(constant1_value.uses(),
              UnorderedElementsAre(HloUse{conditional, 1, {0}},
                                   HloUse{conditional, 2, {0}},
                                   HloUse{add, 0, {}}, HloUse{sub, 0, {}}));
  EXPECT_THAT(constant2_value.uses(),
              UnorderedElementsAre(HloUse{conditional, 1, {1}},
                                   HloUse{conditional, 2, {1}},
                                   HloUse{add, 1, {}}, HloUse{sub, 1, {}}));
  EXPECT_THAT(tuple_value.uses(),
              UnorderedElementsAre(
                  HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}},
                  HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}},
                  HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}}));

  if (!ssa_form) {
    // Without SSA the conditional defines nothing and may hold the root
    // value of either branch.
    EXPECT_EQ(analysis.values().size(), 6);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
    EXPECT_THAT(HloValuesAt(conditional),
                UnorderedElementsAre(analysis.GetValueDefinedAt(add),
                                     analysis.GetValueDefinedAt(sub)));
  } else {
    // In SSA form the conditional defines one additional value.
    EXPECT_EQ(analysis.values().size(), 7);
    EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
  }
}
1897
TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
  // A conditional whose true branch itself contains a conditional.
  //
  // computation1(F32[] %param1):
  //   %ceil = Ceil(%param1)
  //   return %ceil
  //
  // computation2(F32[] %param2):
  //   %floor = Floor(%param2)
  //   return %floor
  //
  // computation3(F32[] %param3):
  //   %negate = Negate(%param3)
  //   return %negate
  //
  // inner_conditional((PRED, F32[], F32[]) %param_cond):
  //   %pred_cond = GetTupleElement(%param_cond, 0)
  //   %true_operand_cond = GetTupleElement(%param_cond, 1)
  //   %false_operand_cond = GetTupleElement(%param_cond, 2)
  //   return Conditional(%pred_cond, %true_operand_cond, computation1,
  //                      %false_operand_cond, computation2)
  //
  // entry:
  //   %pred1 = Constant(true)
  //   %pred2 = Constant(false)
  //   %constant1 = Constant(1.1);
  //   %constant2 = Constant(2.2);
  //   %constant3 = Constant(3.3);
  //   return Conditional(%pred1, (%pred2, %constant1, %constant2),
  //                      inner_conditional, %constant3, computation3)
  const bool ssa_form = GetParam();

  // Leaf computations, each applying a single unary op to its parameter.
  HloComputation* computation1 = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kCeil));
  HloComputation* computation2 = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kFloor));
  HloComputation* computation3 = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kNegate));

  // Computation holding the inner conditional.
  const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {});
  const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {scalar_bool_shape, scalar_shape_, scalar_shape_});
  auto inner_builder =
      HloComputation::Builder(TestName() + "_inner_conditional");
  HloInstruction* param_cond = inner_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond"));
  HloInstruction* pred_cond = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0));
  HloInstruction* true_operand_cond = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1));
  HloInstruction* false_operand_cond = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2));
  HloInstruction* inner_conditional =
      inner_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, pred_cond, true_operand_cond, computation1,
          false_operand_cond, computation2));
  HloComputation* inner_conditional_computation =
      module_->AddEmbeddedComputation(inner_builder.Build());

  // Entry computation holding the outer conditional.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* pred1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* pred2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* constant1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* constant2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.2f)));
  HloInstruction* constant3 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.3f)));
  HloInstruction* tuple_operand = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({pred2, constant1, constant2}));
  HloInstruction* conditional =
      entry_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, pred1, tuple_operand, inner_conditional_computation,
          constant3, computation3));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Entry constants, the tuple and each leaf computation's root define values.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction()));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction()));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction()));

  // Leaf parameters define nothing; each resolves to the constant that
  // ultimately reaches it.
  const HloInstruction* computation1_param =
      computation1->parameter_instruction(0);
  const HloInstruction* computation2_param =
      computation2->parameter_instruction(0);
  const HloInstruction* computation3_param =
      computation3->parameter_instruction(0);
  EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param));
  EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param),
            analysis.GetValueDefinedAt(constant2));
  EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param),
            analysis.GetValueDefinedAt(constant3));

  // Likewise for the inner computation's parameter and its tuple elements.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond));
  EXPECT_EQ(analysis.GetUniqueValueAt(param_cond),
            analysis.GetValueDefinedAt(tuple_operand));
  EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond),
            analysis.GetValueDefinedAt(pred2));
  EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond),
            analysis.GetValueDefinedAt(constant1));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond),
            analysis.GetValueDefinedAt(constant2));

  if (!ssa_form) {
    // Without SSA neither conditional defines a value. The inner one may hold
    // the root of computation1 or computation2; the outer one may additionally
    // hold the root of computation3.
    EXPECT_EQ(analysis.values().size(), 9);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
    EXPECT_THAT(
        HloValuesAt(inner_conditional),
        UnorderedElementsAre(
            analysis.GetValueDefinedAt(computation1->root_instruction()),
            analysis.GetValueDefinedAt(computation2->root_instruction())));
    EXPECT_THAT(
        HloValuesAt(conditional),
        UnorderedElementsAre(
            analysis.GetValueDefinedAt(computation1->root_instruction()),
            analysis.GetValueDefinedAt(computation2->root_instruction()),
            analysis.GetValueDefinedAt(computation3->root_instruction())));
  } else {
    // In SSA form each conditional defines one additional value.
    EXPECT_EQ(analysis.values().size(), 11);
    EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_conditional));
    EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
  }
}
2035
// Verifies that an add-dependency instruction does not define a value of its
// own. The analysis is run in the form selected by the test parameter so that
// the two instantiations of this test are not identical runs.
TEST_P(HloDataflowAnalysisTest, AddDependency) {
  const string module_string = R"(
HloModule AddDependency
ENTRY %AddDependency (p: f32[3]) -> f32[3] {
  %p = f32[3] parameter(0)
  %token0 = token[] after-all()
  ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(module_string, GetModuleConfigForTest()));

  // The fixture is parameterized on whether the analysis uses SSA form; pass
  // the parameter through instead of always running the default form.
  // NOTE(review): the module has no while/conditional, so the expectations
  // below are presumably the same in both forms -- confirm.
  const bool ssa_form = GetParam();
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> analysis,
                          HloDataflowAnalysis::Run(*module, ssa_form));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kAddDependency);

  // The after-all and parameter should define a value. Add-dependency should
  // not.
  EXPECT_EQ(analysis->values().size(), 2);
  EXPECT_FALSE(analysis->ValueIsDefinedAt(root));
}
2059
2060 INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation,
2061 HloDataflowAnalysisTest,
2062 ::testing::Values(false, true));
2063
2064 class HloDataflowAnalysisTestBase : public HloTestBase {
2065 protected:
BuildModule(std::unique_ptr<HloComputation> computation)2066 void BuildModule(std::unique_ptr<HloComputation> computation) {
2067 module_ = CreateNewVerifiedModule();
2068 computation_ = module_->AddEntryComputation(std::move(computation));
2069 }
2070
RunAnalysis(const HloDataflowAnalysis::CanShareBuffer & can_share_buffer=nullptr)2071 void RunAnalysis(
2072 const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
2073 CHECK_NOTNULL(module_.get());
2074 dataflow_analysis_ = HloDataflowAnalysis::Run(
2075 *module_, /*ssa_form=*/false,
2076 /*bitcast_defines_value=*/false, can_share_buffer)
2077 .ConsumeValueOrDie();
2078 }
2079
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)2080 void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
2081 BuildModule(std::move(computation));
2082 RunAnalysis();
2083 }
2084
2085 std::unique_ptr<HloModule> module_;
2086 HloComputation* computation_ = nullptr;
2087 std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
2088 };
2089
2090 class DoesNotUseOperandBufferTest : public HloDataflowAnalysisTestBase {};
2091
// Checks which buffers of a tuple parameter are reported as used by its
// GetTupleElement users.
TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
  HloComputation::Builder builder(TestName());

  const Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({elem_shape, elem_shape});
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "tuple"));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1));
  // Sum the two extracted elements.
  builder.AddInstruction(
      HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1));

  BuildModuleAndRunAnalysis(builder.Build());
  HloDataflowAnalysis& analysis = *dataflow_analysis_;

  // A GetTupleElement only touches the top-level buffer (index {}) of its
  // operand, so the element buffers are reported as unused.
  EXPECT_TRUE(analysis.DoesNotUseOperandBuffer(tuple, {0}, gte0));
  EXPECT_TRUE(analysis.DoesNotUseOperandBuffer(tuple, {1}, gte1));
  EXPECT_FALSE(analysis.DoesNotUseOperandBuffer(tuple, {}, gte0));
  EXPECT_FALSE(analysis.DoesNotUseOperandBuffer(tuple, {}, gte1));
}
2114
// A loop fusion wrapping a DynamicUpdateSlice of tuple element 1 must be
// reported as using element 1 of the tuple parameter but not element 0.
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
  HloComputation::Builder builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "tuple"));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));

  // Build a DynamicUpdateSlice that writes into tuple element 1.
  HloInstruction* starts = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* update = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dynamic_update_slice =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, gte1, update,
          std::initializer_list<HloInstruction*>({starts})));
  builder.AddInstruction(
      HloInstruction::CreateTuple({gte0, dynamic_update_slice}));

  BuildModule(builder.Build());
  // Fuse the update together with its inputs, including the GTE of element 1.
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dynamic_update_slice, starts, update, gte1},
      HloInstruction::FusionKind::kLoop);
  RunAnalysis();
  HloDataflowAnalysis& analysis = *dataflow_analysis_;

  // Element 0 is never read by the fusion; element 1 is.
  EXPECT_TRUE(analysis.DoesNotUseOperandBuffer(tuple, {0}, fusion));
  EXPECT_FALSE(analysis.DoesNotUseOperandBuffer(tuple, {1}, fusion));
}
2148
2149 // Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the
2150 // parameter tuple.
TEST_F(DoesNotUseOperandBufferTest,IndirectUses)2151 TEST_F(DoesNotUseOperandBufferTest, IndirectUses) {
2152 auto builder = HloComputation::Builder(TestName());
2153
2154 Shape data_shape = ShapeUtil::MakeShape(F32, {8});
2155 auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
2156 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
2157 auto t0 = builder.AddInstruction(
2158 HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
2159 auto t1 = builder.AddInstruction(
2160 HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));
2161 // Swap the tuple elements.
2162 auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0}));
2163
2164 auto gte0 = builder.AddInstruction(
2165 HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
2166 auto gte1 = builder.AddInstruction(
2167 HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
2168
2169 // Create a DynamicUpdateSlice instruction of tuple element 1.
2170 auto starts = builder.AddInstruction(
2171 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
2172 auto update = builder.AddInstruction(HloInstruction::CreateConstant(
2173 LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
2174 auto dynamic_update_slice =
2175 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2176 data_shape, gte1, update,
2177 std::initializer_list<HloInstruction*>({starts})));
2178 builder.AddInstruction(
2179 HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
2180
2181 BuildModule(builder.Build());
2182 auto fusion = computation_->CreateFusionInstruction(
2183 {dynamic_update_slice, starts, update, gte1},
2184 HloInstruction::FusionKind::kLoop);
2185 RunAnalysis();
2186
2187 // The fusion instruction never uses tuple element 0, but does use element 1.
2188 EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
2189 EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
2190 // The same holds for the parameter tuple, except that the tuple elements
2191 // are swapped in 'tuple'.
2192 EXPECT_TRUE(
2193 dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
2194 EXPECT_FALSE(
2195 dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion));
2196 }
2197
2198 class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {};
2199
// Unary elementwise users whose shape matches their operand's shape are
// expected to be able to share the operand buffer.
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
  HloComputation::Builder builder(TestName());

  const Shape shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  // Chain: param -> exp -> log.
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
  HloInstruction* log = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp));

  BuildModuleAndRunAnalysis(builder.Build());
  HloDataflowAnalysis& analysis = *dataflow_analysis_;

  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(param, {}, exp, {}));
  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(exp, {}, log, {}));
}
2218
// A loop fusion of negate followed by reverse is expected not to be able to
// share a buffer with its operand.
TEST_F(CanShareOperandBufferWithUserTest,
       NonElementwiseLoopFusionCantAliasOperandBuffer) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "param0"));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0));
  // Reverse along both dimensions.
  HloInstruction* reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(data_shape, neg, {0, 1}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {reverse, neg}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  const bool can_share = dataflow_analysis_->CanShareOperandBufferWithUser(
      param0, {}, fusion, {});
  EXPECT_FALSE(can_share);
}
2241
// A loop fusion producing a tuple of two copies: each parameter is expected
// to be able to share a buffer with either output of the fusion.
TEST_F(CanShareOperandBufferWithUserTest,
       MultiOutputFusionCanAliasOperandBuffer) {
  auto builder = HloComputation::Builder(TestName());

  // Both parameters and both copies use the same shape. (Two shape locals
  // that were never referenced have been removed.)
  Shape in_shape = ShapeUtil::MakeShape(F32, {8});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, in_shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, in_shape, "param1"));

  auto copy0 = builder.AddInstruction(
      HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0));
  auto copy1 = builder.AddInstruction(
      HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1));

  // Note the order: output 0 is the copy of param1, output 1 the copy of
  // param0.
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0}));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // Every (parameter, output index) pairing is expected to be shareable.
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {0}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {1}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
                                                                fusion, {0}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
                                                                fusion, {1}));
}
2276
// A loop fusion of two elementwise ops (negate, exp) applied to a broadcast
// operand.
// NOTE(review): despite the "Cant" in the test name, the expectation below
// asserts that the operand buffer CAN be shared; the name looks stale.
TEST_F(CanShareOperandBufferWithUserTest,
       ElementwiseLoopFusionCantAliasOperandBuffer) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  // Operand: scalar 1.0 broadcast to data_shape.
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {exp, neg}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  const bool can_share = dataflow_analysis_->CanShareOperandBufferWithUser(
      operand, {}, fusion, {});
  EXPECT_TRUE(can_share);
}
2301
// A fused DynamicUpdateSlice whose update comes from a DynamicSlice of the
// same operand at the same start indices is expected to be able to share the
// operand buffer.
TEST_F(CanShareOperandBufferWithUserTest,
       CanShareOperandWhenDynamicUpdateSliceIsFedByDynamicSliceWithSameIndex) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {1, 2});

  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "param0"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(0)));
  // Slice and update both start at {zero, zero}.
  HloInstruction* ds = builder.AddInstruction(
      HloInstruction::CreateDynamicSlice(slice_shape, param, {zero, zero},
                                         {1, 2}));
  HloInstruction* dus = builder.AddInstruction(
      HloInstruction::CreateDynamicUpdateSlice(data_shape, param, ds,
                                               {zero, zero}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dus, ds, zero}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  const bool can_share = dataflow_analysis_->CanShareOperandBufferWithUser(
      param, {}, fusion, {});
  EXPECT_TRUE(can_share);
}
2326
// Text-HLO variant: the fused computation dynamic-slices p0 at (p1, p2, p3)
// and dynamic-update-slices the result back into p0 at the same indices. The
// fusion is expected to be able to share the buffer of entry parameter 0.
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) {
  const char* const hlo_text = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30}
      ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p2, p3)
    }

    ENTRY test {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
  HloComputation* entry = module_->entry_computation();
  HloInstruction* fusion = entry->root_instruction();
  HloInstruction* param = entry->parameter_instruction(0);

  RunAnalysis();
  const bool can_share = dataflow_analysis_->CanShareOperandBufferWithUser(
      param, {}, fusion, {});
  EXPECT_TRUE(can_share);
}
2356
// A compare whose result shape (PRED) differs from its operands' shape (F32)
// is expected not to be able to share a buffer with either operand.
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
  HloComputation::Builder builder(TestName());

  const Shape in_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, in_shape, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, in_shape, "param1"));
  HloInstruction* result = builder.AddInstruction(
      HloInstruction::CreateCompare(out_shape, param0, param1,
                                    ComparisonDirection::kEq));

  BuildModuleAndRunAnalysis(builder.Build());
  HloDataflowAnalysis& analysis = *dataflow_analysis_;

  EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(param0, {}, result, {}));
  EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(param1, {}, result, {}));
}
2376
// A copy is expected to be able to share a buffer with its operand, just as
// the exp feeding it can share with the parameter.
TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
  HloComputation::Builder builder(TestName());

  const Shape shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  // Chain: param -> exp -> copy.
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp));

  BuildModuleAndRunAnalysis(builder.Build());
  HloDataflowAnalysis& analysis = *dataflow_analysis_;

  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(param, {}, exp, {}));
  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(exp, {}, copy, {}));
}
2395
// A loop fusion wrapping a DynamicUpdateSlice of tuple element 1 is expected
// to be able to share a buffer with element 1 of the tuple parameter, but not
// with element 0.
TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
  HloComputation::Builder builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "tuple"));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));

  // Build a DynamicUpdateSlice that writes into tuple element 1.
  HloInstruction* starts = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* update = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dynamic_update_slice =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, gte1, update,
          std::initializer_list<HloInstruction*>({starts})));
  builder.AddInstruction(
      HloInstruction::CreateTuple({gte0, dynamic_update_slice}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dynamic_update_slice, starts, update, gte1},
      HloInstruction::FusionKind::kLoop);
  RunAnalysis();
  HloDataflowAnalysis& analysis = *dataflow_analysis_;

  // Sharing is possible with tuple element 1 only.
  EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(tuple, {0}, fusion, {}));
  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(tuple, {1}, fusion, {}));
}
2431
// Like FusedDynamicUpdateSlice, but the DynamicUpdateSlice is sandwiched
// between an F32->BF16 convert and a BF16->F32 convert inside the fusion. The
// fusion is still expected to be able to share a buffer with the GTE operand.
TEST_F(CanShareOperandBufferWithUserTest,
       FusedDynamicUpdateSliceWithConvertCanShare) {
  HloComputation::Builder builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "tuple"));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));

  // Convert element 1 to BF16 before updating it.
  HloInstruction* convert1 = builder.AddInstruction(
      HloInstruction::CreateConvert(data_shape_bf16, gte1));

  // Build a DynamicUpdateSlice on the converted tuple element 1.
  HloInstruction* starts = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* update = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dynamic_update_slice =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape_bf16, convert1, update,
          std::initializer_list<HloInstruction*>({starts})));

  // Convert the updated array back to F32.
  HloInstruction* convert2 = builder.AddInstruction(
      HloInstruction::CreateConvert(data_shape, dynamic_update_slice));
  builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {convert2, dynamic_update_slice, starts, update, convert1},
      HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  const bool can_share = dataflow_analysis_->CanShareOperandBufferWithUser(
      gte1, {}, fusion, {});
  EXPECT_TRUE(can_share);
}
2471
// An unfused DynamicUpdateSlice is expected to be able to share a buffer with
// the array being updated, and with neither of its other operands.
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
  HloComputation::Builder builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 8});
  const Shape update_shape = ShapeUtil::MakeShape(F32, {1, 4});
  const Shape starts_shape = ShapeUtil::MakeShape(S32, {2});
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* update = builder.AddInstruction(
      HloInstruction::CreateParameter(1, update_shape, "update"));
  HloInstruction* start = builder.AddInstruction(
      HloInstruction::CreateParameter(2, starts_shape, "start"));

  HloInstruction* dus = builder.AddInstruction(
      HloInstruction::CreateDynamicUpdateSlice(data_shape, data, update,
                                               {start}));

  BuildModuleAndRunAnalysis(builder.Build());
  HloDataflowAnalysis& analysis = *dataflow_analysis_;

  // Only the data operand is shareable; update and start are not.
  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(data, {}, dus, {}));
  EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(update, {}, dus, {}));
  EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(start, {}, dus, {}));
}
2499
// A scatter is expected to be able to share a buffer with its first operand
// (the array being scattered into), but not with the indices or the updates.
TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
  const char* const kHloText = R"(
    HloModule TensorFlowScatterV1

    update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
      lhs = s32[] parameter(0)
      ROOT rhs = s32[] parameter(1)
    }

    ENTRY main {
      operand = s32[3,3] parameter(0)
      indices = s32[2] parameter(1)
      updates = s32[2,3] parameter(2)
      ROOT scatter = s32[3,3] scatter(operand, indices, updates),
          to_apply=update_s32,
          update_window_dims={1},
          inserted_window_dims={0},
          scatter_dims_to_operand_dims={0},
          index_vector_dim=1
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kHloText));
  computation_ = module_->entry_computation();
  RunAnalysis();
  HloDataflowAnalysis& analysis = *dataflow_analysis_;

  HloInstruction* scatter = computation_->root_instruction();
  HloInstruction* operand_param = computation_->parameter_instruction(0);
  HloInstruction* indices_param = computation_->parameter_instruction(1);
  HloInstruction* updates_param = computation_->parameter_instruction(2);

  EXPECT_TRUE(
      analysis.CanShareOperandBufferWithUser(operand_param, {}, scatter, {}));
  EXPECT_FALSE(
      analysis.CanShareOperandBufferWithUser(indices_param, {}, scatter, {}));
  EXPECT_FALSE(
      analysis.CanShareOperandBufferWithUser(updates_param, {}, scatter, {}));
}
2537
// A triangular-solve is expected to be able to share a buffer with its second
// operand (b) but not with its first operand (a).
TEST_F(CanShareOperandBufferWithUserTest, TriangularSolveCanShare) {
  const char* const kHloText = R"(
    HloModule TensorFlowTriangularSolve

    ENTRY main {
      a = f32[4,4]{1,0} parameter(0)
      b = f32[3,4]{1,0} parameter(1)
      ROOT triangular-solve = f32[3,4]{1,0} triangular-solve(a, b), lower=true,
                                                   transpose_a=NO_TRANSPOSE
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kHloText));
  computation_ = module_->entry_computation();
  RunAnalysis();
  HloDataflowAnalysis& analysis = *dataflow_analysis_;

  HloInstruction* triangular_solve = computation_->root_instruction();
  HloInstruction* lhs_param = computation_->parameter_instruction(0);
  HloInstruction* rhs_param = computation_->parameter_instruction(1);

  EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(lhs_param, {},
                                                      triangular_solve, {}));
  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(rhs_param, {},
                                                     triangular_solve, {}));
}
2562
// A single-operand sort is expected to be able to share a buffer with its
// keys operand.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
  HloComputation::Builder builder(TestName());
  // The module is created up front because MakeSortHlo takes it as an
  // argument.
  module_ = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * sort,
      MakeSortHlo(keys_shape, {keys}, -1, /*is_stable=*/false, &builder,
                  module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();

  const bool can_share =
      dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {});
  EXPECT_TRUE(can_share);
}
2580
// A two-operand sort produces a tuple. Each operand is expected to be able to
// share a buffer only with the tuple entry at the matching position.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
  HloComputation::Builder builder(TestName());
  // The module is created up front because MakeSortHlo takes it as an
  // argument.
  module_ = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape values_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values"));
  const Shape sort_shape =
      ShapeUtil::MakeTupleShape({keys_shape, values_shape});
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * sort,
      MakeSortHlo(sort_shape, {keys, values}, 0, /*is_stable=*/false, &builder,
                  module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();
  HloDataflowAnalysis& analysis = *dataflow_analysis_;

  // Keys may share with tuple entry 0.
  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(keys, {}, sort, {0}));
  // Values may share with tuple entry 1.
  EXPECT_TRUE(analysis.CanShareOperandBufferWithUser(values, {}, sort, {1}));
  // Neither operand may share with the other operand's tuple entry.
  EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(keys, {}, sort, {1}));
  EXPECT_FALSE(analysis.CanShareOperandBufferWithUser(values, {}, sort, {0}));
}
2612
// An output fusion of dot followed by add is expected to be able to share a
// buffer with the add's non-dot operand.
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  // Two constant 2x2 matrices feeding the dot.
  HloInstruction* a = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  HloInstruction* b = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  // Contract dimension 1 of the lhs with dimension 0 of the rhs.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  HloInstruction* dot = builder.AddInstruction(
      HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));

  // Second add operand: scalar 1.0 broadcast to data_shape.
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* add_operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, dot, add_operand));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {add, dot}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // The output-fused dot+add should be able to reuse the buffer of
  // 'add_operand'.
  const bool can_share = dataflow_analysis_->CanShareOperandBufferWithUser(
      add_operand, {}, fusion, {});
  EXPECT_TRUE(can_share);
}
2648
// An output fusion of operand -> reverse -> add is expected not to be able to
// share a buffer with 'operand'.
TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  // Operand: scalar 1.0 broadcast to data_shape.
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  // Reverse along both dimensions.
  HloInstruction* reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(data_shape, operand, {0, 1}));

  HloInstruction* two = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {add, two, reverse}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // The fused operand->reverse->add must not alias the buffer of 'operand'.
  const bool can_share = dataflow_analysis_->CanShareOperandBufferWithUser(
      operand, {}, fusion, {});
  EXPECT_FALSE(can_share);
}
2676
// Runs the analysis with a caller-supplied can_share_buffer callback that
// answers true only for loop fusions. The fusion built here has kind kInput,
// and sharing with its operand is expected to be rejected.
TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  // Operand: scalar 1.0 broadcast to data_shape.
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));
  // mul = operand * operand; add = mul + two.
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kMultiply, operand, operand));
  HloInstruction* two = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {add, two, mul}, HloInstruction::FusionKind::kInput);

  // Custom policy: only the first argument is consulted.
  auto loop_fusions_only = [](const HloInstruction* fusion,
                              const HloInstruction*, const ShapeIndex&) {
    return fusion->IsLoopFusion();
  };
  RunAnalysis(/*can_share_buffer=*/loop_fusions_only);

  const bool can_share = dataflow_analysis_->CanShareOperandBufferWithUser(
      operand, {}, fusion, {});
  EXPECT_FALSE(can_share);
}
2704
// A while instruction is expected to be able to share a buffer with its init
// operand. The module is assembled from four computations: an AND reducer, a
// loop condition, a loop body, and the entry computation holding the while.
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
  module_ = CreateNewVerifiedModule();
  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pred_scalar_shape = ShapeUtil::MakeShape(PRED, {});

  // Reducer computation: p0 AND p1 on scalar predicates.
  HloComputation::Builder and_builder(TestName() + ".And");
  HloInstruction* p0 = and_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar_shape, "p0"));
  HloInstruction* p1 = and_builder.AddInstruction(
      HloInstruction::CreateParameter(1, pred_scalar_shape, "p1"));
  and_builder.AddInstruction(
      HloInstruction::CreateBinary(pred_scalar_shape, HloOpcode::kAnd, p0, p1));
  HloComputation* and_computation =
      module_->AddEmbeddedComputation(and_builder.Build());

  // Condition computation: compares data with itself elementwise and reduces
  // the result with the AND computation, starting from `true`.
  HloComputation::Builder cond_builder(TestName() + ".Cond");
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* compare =
      cond_builder.AddInstruction(HloInstruction::CreateCompare(
          ShapeUtil::MakeShape(PRED, {8}), cond_data, cond_data,
          ComparisonDirection::kEq));
  HloInstruction* true_value = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  cond_builder.AddInstruction(
      HloInstruction::CreateReduce(ShapeUtil::MakeShape(PRED, {}), compare,
                                   true_value, {0}, and_computation));
  HloComputation* cond_computation =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Body computation: data + data.
  HloComputation::Builder body_builder(TestName() + ".Body");
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  body_builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, body_data, body_data));
  HloComputation* body_computation =
      module_->AddEmbeddedComputation(body_builder.Build());

  // Entry computation: a while loop over a single array parameter.
  HloComputation::Builder builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* while_instr =
      builder.AddInstruction(HloInstruction::CreateWhile(
          data_shape, cond_computation, body_computation, data));
  computation_ = module_->AddEntryComputation(builder.Build());

  RunAnalysis();

  // The while may reuse the buffer of its data operand.
  const bool can_share = dataflow_analysis_->CanShareOperandBufferWithUser(
      data, {}, while_instr, {});
  EXPECT_TRUE(can_share);
}
2760
2761 // Tests that Call can alias operand buffer if the only use of the operand
2762 // in the called computation is an elementwise instruction.
TEST_F(CanShareOperandBufferWithUserTest,CallToComputationWithFusionRoot)2763 TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
2764 Shape shape = ShapeUtil::MakeShape(F32, {8});
2765 // Build sub-computation with fusion root.
2766 auto sub_builder = HloComputation::Builder(TestName() + "_sub");
2767 auto sub_param = sub_builder.AddInstruction(
2768 HloInstruction::CreateParameter(0, shape, "sub_param"));
2769 auto one = sub_builder.AddInstruction(
2770 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2771 auto ones = sub_builder.AddInstruction(
2772 HloInstruction::CreateBroadcast(shape, one, {}));
2773 auto add = sub_builder.AddInstruction(
2774 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
2775
2776 module_ = CreateNewVerifiedModule();
2777 auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
2778 sub_computation->CreateFusionInstruction({add, ones},
2779 HloInstruction::FusionKind::kLoop);
2780
2781 // Build entry-computation with kCall which calls 'sub_computation'.
2782 auto builder = HloComputation::Builder(TestName());
2783
2784 auto param = builder.AddInstruction(
2785 HloInstruction::CreateParameter(0, shape, "param"));
2786 auto reverse =
2787 builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
2788 auto call = builder.AddInstruction(
2789 HloInstruction::CreateCall(shape, {reverse}, sub_computation));
2790 computation_ = module_->AddEntryComputation(builder.Build());
2791
2792 RunAnalysis();
2793
2794 EXPECT_TRUE(
2795 dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {}));
2796 }
2797
2798 } // namespace
2799 } // namespace xla
2800