1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
17
18 #include <map>
19 #include <memory>
20
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
23 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
24 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
25 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace op = xla::testing::opcode_matchers;
37
38 namespace xla {
39 namespace {
40
41 using ::testing::UnorderedElementsAre;
42 using ::testing::UnorderedElementsAreArray;
43
44 class TuplePointsToAnalysisTest : public HloTestBase {
45 protected:
46 // Builds a module with the given entry computation and runs points to
47 // analysis.
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)48 void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
49 BuildModule(std::move(computation));
50 RunAnalysis();
51 }
52
BuildModule(std::unique_ptr<HloComputation> computation)53 void BuildModule(std::unique_ptr<HloComputation> computation) {
54 module_ = CreateNewVerifiedModule();
55 module_->AddEntryComputation(std::move(computation));
56 }
57
RunAnalysis()58 void RunAnalysis() {
59 CHECK_NOTNULL(module_.get());
60 points_to_analysis_ =
61 TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
62 }
63
64 // Returns the LogicalBuffer defined at the given instruction and
65 // index. CHECKs if no buffer is defined at that point.
GetBuffer(const HloInstruction * instruction,const ShapeIndex & index)66 const LogicalBuffer* const GetBuffer(const HloInstruction* instruction,
67 const ShapeIndex& index) {
68 const auto& pointed_to =
69 points_to_analysis_->GetPointsToSet(instruction).element(index);
70 CHECK_EQ(1, pointed_to.size());
71 CHECK_EQ(instruction, pointed_to[0]->instruction());
72 CHECK(index == pointed_to[0]->index());
73 return pointed_to[0];
74 }
75
76 // Checks that the given points-to set contains exactly (unordered) the given
77 // LogicalBuffers.
ExpectHasBuffers(const PointsToSet::BufferList & points_to_set,absl::Span<const LogicalBuffer * const> buffers)78 void ExpectHasBuffers(const PointsToSet::BufferList& points_to_set,
79 absl::Span<const LogicalBuffer* const> buffers) {
80 std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end());
81 EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec));
82 }
83
84 // Checks that the given points-to set contains exactly (unordered) the
85 // top-level buffers of the given instructions.
ExpectHasTopLevelBuffers(const PointsToSet::BufferList & points_to_set,absl::Span<HloInstruction * const> instructions)86 void ExpectHasTopLevelBuffers(
87 const PointsToSet::BufferList& points_to_set,
88 absl::Span<HloInstruction* const> instructions) {
89 PointsToSet::BufferList buffers;
90 for (auto instruction : instructions) {
91 buffers.push_back(GetBuffer(instruction, /*index=*/{}));
92 }
93 ExpectHasBuffers(points_to_set, buffers);
94 }
95
96 // Overload which takes a set instead of a vector.
ExpectHasTopLevelBuffers(const PointsToSet::BufferSet & points_to_set,absl::Span<HloInstruction * const> instructions)97 void ExpectHasTopLevelBuffers(
98 const PointsToSet::BufferSet& points_to_set,
99 absl::Span<HloInstruction* const> instructions) {
100 ExpectHasTopLevelBuffers(
101 PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()),
102 instructions);
103 }
104
105 // Checks that the buffer defined at the given instruction and index has
106 // aliases which are exactly (unordered) the given instruction/index pairs.
ExpectHasBufferAliases(const HloInstruction * instruction,const ShapeIndex & index,absl::Span<const std::pair<HloInstruction *,ShapeIndex>> expected)107 void ExpectHasBufferAliases(
108 const HloInstruction* instruction, const ShapeIndex& index,
109 absl::Span<const std::pair<HloInstruction*, ShapeIndex>> expected) {
110 const LogicalBuffer* buffer =
111 points_to_analysis_->GetBufferDefinedAt(instruction, index)
112 .ValueOrDie();
113 std::vector<BufferAlias> expected_aliases;
114 for (auto& pair : expected) {
115 expected_aliases.push_back(BufferAlias(pair.first, pair.second));
116 }
117 EXPECT_THAT(points_to_analysis_->GetBufferAliases(*buffer),
118 UnorderedElementsAreArray(expected_aliases));
119 }
120
121 std::unique_ptr<HloModule> module_;
122 std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
123 };
124
TEST_F(TuplePointsToAnalysisTest, SimpleTuple) {
  // Builds a flat tuple of two scalar constants and checks the points-to sets
  // of the constants and of the tuple.
  HloComputation::Builder builder(TestName());
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({first, second}));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& first_set = points_to_analysis_->GetPointsToSet(first);
  const PointsToSet& second_set = points_to_analysis_->GetPointsToSet(second);
  const PointsToSet& tuple_set = points_to_analysis_->GetPointsToSet(tuple);

  // Each constant points only to its own top-level buffer and has no tuple
  // sources.
  EXPECT_EQ(1, first_set.size());
  ExpectHasTopLevelBuffers(first_set.element({}), {first});
  EXPECT_TRUE(first_set.tuple_sources({}).empty());
  EXPECT_TRUE(tuple_set.IsDistinct());

  EXPECT_EQ(1, second_set.size());
  ExpectHasTopLevelBuffers(second_set.element({}), {second});
  EXPECT_TRUE(second_set.tuple_sources({}).empty());

  // The tuple has three entries: itself at the top level plus one per element.
  EXPECT_EQ(3, tuple_set.size());
  EXPECT_FALSE(tuple_set.IsAmbiguous());
  EXPECT_THAT(tuple_set.tuple_sources({}), UnorderedElementsAre(tuple));

  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(),
                           {first, second, tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({0}), {first});
  ExpectHasTopLevelBuffers(tuple_set.element({1}), {second});

  // Membership queries on the tuple's points-to set.
  const LogicalBuffer& first_buffer = *GetBuffer(first, {});
  const LogicalBuffer& second_buffer = *GetBuffer(second, {});
  EXPECT_TRUE(tuple_set.ContainsBufferAtIndex(first_buffer, {0}));
  EXPECT_TRUE(tuple_set.ContainsBufferAtIndex(second_buffer, {1}));
  EXPECT_FALSE(tuple_set.ContainsBufferAtIndex(second_buffer, {0}));
  EXPECT_TRUE(tuple_set.ContainsBuffer(first_buffer));
  EXPECT_TRUE(tuple_set.ContainsBuffer(second_buffer));
}
174
TEST_F(TuplePointsToAnalysisTest, NestedTuple) {
  // Builds outer = ((c1, c2), c3). The points-to set of the outer tuple is
  // expected to contain every element of the inner tuple's points-to set.
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* c3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  HloInstruction* outer =
      builder.AddInstruction(HloInstruction::CreateTuple({inner, c3}));

  BuildModuleAndRunAnalysis(builder.Build());

  // Every constant points to its own top-level buffer.
  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(c1).element({}), {c1});
  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(c2).element({}), {c2});
  ExpectHasTopLevelBuffers(
      points_to_analysis_->GetPointsToSet(c3).element({}), {c3});

  // Inner tuple.
  const PointsToSet& inner_set = points_to_analysis_->GetPointsToSet(inner);
  EXPECT_EQ(3, inner_set.size());
  EXPECT_FALSE(inner_set.IsAmbiguous());
  EXPECT_TRUE(inner_set.IsDistinct());
  ExpectHasTopLevelBuffers(inner_set.CreateFlattenedSet(), {c1, c2, inner});
  ExpectHasTopLevelBuffers(inner_set.element({}), {inner});
  EXPECT_THAT(inner_set.tuple_sources({}), UnorderedElementsAre(inner));

  // Outer tuple.
  const PointsToSet& outer_set = points_to_analysis_->GetPointsToSet(outer);
  EXPECT_EQ(5, outer_set.size());
  EXPECT_FALSE(outer_set.IsAmbiguous());
  ExpectHasTopLevelBuffers(outer_set.CreateFlattenedSet(),
                           {c1, c2, c3, inner, outer});

  EXPECT_THAT(outer_set.tuple_sources({}), UnorderedElementsAre(outer));
  EXPECT_THAT(outer_set.tuple_sources({0}), UnorderedElementsAre(inner));
  EXPECT_TRUE(outer_set.tuple_sources({1}).empty());

  ExpectHasTopLevelBuffers(outer_set.element({0}), {inner});
  ExpectHasTopLevelBuffers(outer_set.element({0, 0}), {c1});
  ExpectHasTopLevelBuffers(outer_set.element({0, 1}), {c2});
  ExpectHasTopLevelBuffers(outer_set.element({1}), {c3});
}
235
TEST_F(TuplePointsToAnalysisTest, GetTupleElement) {
  // Builds outer = ((c1, c2), c3) and extracts element 0 with GetTupleElement.
  // The extracted value is expected to have the same points-to set as the
  // inner tuple.
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* c3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  HloInstruction* outer =
      builder.AddInstruction(HloInstruction::CreateTuple({inner, c3}));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner->shape(), outer, 0));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& gte_set = points_to_analysis_->GetPointsToSet(gte);
  EXPECT_EQ(3, gte_set.size());
  EXPECT_FALSE(gte_set.IsAmbiguous());
  EXPECT_TRUE(gte_set.IsDistinct());
  ExpectHasTopLevelBuffers(gte_set.CreateFlattenedSet(), {c1, c2, inner});
  ExpectHasTopLevelBuffers(gte_set.element({}), {inner});

  EXPECT_THAT(gte_set.tuple_sources({}), UnorderedElementsAre(inner));
}
269
TEST_F(TuplePointsToAnalysisTest, AddDependency) {
  // An add-dependency of a constant on a token is expected to point to the
  // constant's top-level buffer only.
  HloComputation::Builder builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* dependency = builder.AddInstruction(
      HloInstruction::CreateAddDependency(data, token));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& dependency_set =
      points_to_analysis_->GetPointsToSet(dependency);
  EXPECT_EQ(1, dependency_set.size());
  EXPECT_FALSE(dependency_set.IsAmbiguous());
  EXPECT_TRUE(dependency_set.IsDistinct());
  ExpectHasTopLevelBuffers(dependency_set.CreateFlattenedSet(), {data});
}
285
TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) {
  // A tuple that holds the same constant three times: unambiguous, but not
  // distinct.
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({scalar, scalar, scalar}));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& tuple_set = points_to_analysis_->GetPointsToSet(tuple);
  EXPECT_EQ(2, tuple_set.size());
  EXPECT_FALSE(tuple_set.IsAmbiguous());
  EXPECT_FALSE(tuple_set.IsDistinct());
  ExpectHasTopLevelBuffers(tuple_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(), {scalar, tuple});
}
305
TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
  // Copies a tuple with HloOpcode::kCopy. The copy is expected to define its
  // own top-level buffer while still pointing to the original element buffers.
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* source =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(source->shape(), HloOpcode::kCopy, source));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& source_set = points_to_analysis_->GetPointsToSet(source);
  const PointsToSet& copy_set = points_to_analysis_->GetPointsToSet(copy);

  EXPECT_FALSE(copy_set.IsAmbiguous());
  EXPECT_TRUE(copy_set.IsDistinct());
  ExpectHasTopLevelBuffers(source_set.CreateFlattenedSet(), {c1, c2, source});
  ExpectHasTopLevelBuffers(copy_set.element({}), {copy});
  ExpectHasTopLevelBuffers(copy_set.CreateFlattenedSet(), {c1, c2, copy});
}
332
TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) {
  // Checks aliasing through a CopyStart/CopyDone pair: element {0} of the
  // copy-start tuple is expected to alias the copy-done output, and the
  // operand of copy-start is expected to be aliased at element {1} of the
  // copy-start tuple.
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  const Shape copy_start_shape = ShapeUtil::MakeTupleShape(
      {operand->shape(), operand->shape(), ShapeUtil::MakeShape(U32, {})});
  HloInstruction* copy_start = builder.AddInstruction(
      HloInstruction::CreateCopyStart(copy_start_shape, operand));
  HloInstruction* copy_done = builder.AddInstruction(HloInstruction::CreateUnary(
      operand->shape(), HloOpcode::kCopyDone, copy_start));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& start_set =
      points_to_analysis_->GetPointsToSet(copy_start);
  const PointsToSet& done_set = points_to_analysis_->GetPointsToSet(copy_done);

  EXPECT_FALSE(start_set.IsAmbiguous());
  EXPECT_TRUE(start_set.IsDistinct());
  EXPECT_FALSE(done_set.IsAmbiguous());
  EXPECT_TRUE(done_set.IsDistinct());

  ExpectHasTopLevelBuffers(start_set.element({}), {copy_start});
  ExpectHasBufferAliases(copy_start, {0}, {{copy_start, {0}}, {copy_done, {}}});
  ExpectHasBufferAliases(operand, {}, {{operand, {}}, {copy_start, {1}}});
}
358
TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
  // The operand of Send is expected to be forwarded to element {0} of the
  // Send output tuple.
  HloComputation::Builder builder(TestName());
  HloInstruction* payload = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send = builder.AddInstruction(
      HloInstruction::CreateSend(payload, token, /*channel_id=*/0));
  HloInstruction* send_done =
      builder.AddInstruction(HloInstruction::CreateSendDone(send));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& send_set = points_to_analysis_->GetPointsToSet(send);
  const PointsToSet& send_done_set =
      points_to_analysis_->GetPointsToSet(send_done);

  EXPECT_FALSE(send_set.IsAmbiguous());
  EXPECT_TRUE(send_set.IsDistinct());
  EXPECT_FALSE(send_done_set.IsAmbiguous());
  EXPECT_TRUE(send_done_set.IsDistinct());

  ExpectHasTopLevelBuffers(send_set.element({}), {send});
  ExpectHasTopLevelBuffers(send_set.element({0}), {payload});
  ExpectHasTopLevelBuffers(send_done_set.CreateFlattenedSet(), {send_done});
  ExpectHasBufferAliases(payload, {}, {{payload, {}}, {send, {0}}});
}
385
TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
  // Element {0} of the Recv output tuple is expected to be forwarded to
  // element {0} of RecvDone.
  HloComputation::Builder builder(TestName());
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  const Shape received_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  HloInstruction* recv = builder.AddInstruction(
      HloInstruction::CreateRecv(received_shape, token, /*channel_id=*/0));
  HloInstruction* recv_done =
      builder.AddInstruction(HloInstruction::CreateRecvDone(recv));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& recv_set = points_to_analysis_->GetPointsToSet(recv);
  const PointsToSet& recv_done_set =
      points_to_analysis_->GetPointsToSet(recv_done);

  EXPECT_FALSE(recv_set.IsAmbiguous());
  EXPECT_TRUE(recv_set.IsDistinct());
  EXPECT_FALSE(recv_done_set.IsAmbiguous());
  EXPECT_TRUE(recv_done_set.IsDistinct());

  ExpectHasTopLevelBuffers(recv_set.element({}), {recv});
  ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {0}}});
}
405
TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
  // Selects between (c1, c2) and (c2, c2). The result is expected to be an
  // ambiguous points-to set holding the union of both operands.
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* on_true =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* on_false =
      builder.AddInstruction(HloInstruction::CreateTuple({c2, c2}));
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      on_true->shape(), HloOpcode::kTupleSelect, pred, on_true, on_false));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& select_set = points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(3, select_set.size());
  EXPECT_TRUE(select_set.IsAmbiguous());
  EXPECT_FALSE(select_set.IsDistinct());
  ExpectHasTopLevelBuffers(select_set.element({}), {select});
  ExpectHasTopLevelBuffers(select_set.element({0}), {c1, c2});
  ExpectHasTopLevelBuffers(select_set.element({1}), {c2});
  ExpectHasTopLevelBuffers(select_set.CreateFlattenedSet(), {c1, c2, select});
}
436
TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) {
  // Selects between two tuple-shaped parameters and copies the result, then
  // checks which buffers each element of the parameter, the select and the
  // copy points to.
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeShape(U32, {5})});

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, tuple_shape, "param1"));
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple_shape, HloOpcode::kTupleSelect, pred, param0, param1));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& param0_set = points_to_analysis_->GetPointsToSet(param0);
  const PointsToSet& select_set = points_to_analysis_->GetPointsToSet(select);
  const PointsToSet& copy_set = points_to_analysis_->GetPointsToSet(copy);

  // Every element of a tuple parameter points to the buffer the parameter
  // itself defines at that index.
  ExpectHasBuffers(param0_set.element({}), {GetBuffer(param0, {})});
  ExpectHasBuffers(param0_set.element({0}), {GetBuffer(param0, {0})});
  ExpectHasBuffers(param0_set.element({1}), {GetBuffer(param0, {1})});

  // The subelements of the select point to the matching subelements of both
  // parameters. The top-level buffer is not aliased, because the select
  // instruction creates it.
  ExpectHasBuffers(select_set.element({}), {GetBuffer(select, {})});
  ExpectHasBuffers(select_set.element({0}),
                   {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
  ExpectHasBuffers(select_set.element({1}),
                   {GetBuffer(param0, {1}), GetBuffer(param1, {1})});

  // Apart from its own top-level buffer, the copy matches the select.
  ExpectHasBuffers(copy_set.element({}), {GetBuffer(copy, {})});
  ExpectHasBuffers(copy_set.element({0}),
                   {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
  ExpectHasBuffers(copy_set.element({1}),
                   {GetBuffer(param0, {1}), GetBuffer(param1, {1})});
}
484
TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) {
  // Selects between two tuples built from the same elements. The result is
  // expected to be unambiguous.
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* on_true =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* on_false =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      on_true->shape(), HloOpcode::kTupleSelect, pred, on_true, on_false));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& select_set = points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(3, select_set.size());
  EXPECT_FALSE(select_set.IsAmbiguous());
  EXPECT_TRUE(select_set.IsDistinct());
  ExpectHasTopLevelBuffers(select_set.element({}), {select});
  ExpectHasTopLevelBuffers(select_set.element({0}), {c1});
  ExpectHasTopLevelBuffers(select_set.element({1}), {c2});
  ExpectHasTopLevelBuffers(select_set.CreateFlattenedSet(), {c1, c2, select});
}
514
TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
  // Selects between ((c1, c2)) and ((c2, c2)) and checks the points-to sets
  // and tuple sources at every nesting level.
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner_a =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* inner_b =
      builder.AddInstruction(HloInstruction::CreateTuple({c2, c2}));
  HloInstruction* outer_a =
      builder.AddInstruction(HloInstruction::CreateTuple({inner_a}));
  HloInstruction* outer_b =
      builder.AddInstruction(HloInstruction::CreateTuple({inner_b}));
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      outer_a->shape(), HloOpcode::kTupleSelect, pred, outer_a, outer_b));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& select_set = points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(5, select_set.size());
  EXPECT_TRUE(select_set.IsAmbiguous());
  EXPECT_FALSE(select_set.IsDistinct());

  // Points-to set at each index.
  ExpectHasTopLevelBuffers(select_set.element({}), {select});
  ExpectHasTopLevelBuffers(select_set.element({0}), {inner_a, inner_b});
  ExpectHasTopLevelBuffers(select_set.element({0, 0}), {c1, c2});
  ExpectHasTopLevelBuffers(select_set.element({0, 1}), {c2});

  // Tuple sources at each index; the leaf elements have none.
  EXPECT_THAT(select_set.tuple_sources({}),
              UnorderedElementsAre(outer_a, outer_b));
  EXPECT_THAT(select_set.tuple_sources({0}),
              UnorderedElementsAre(inner_a, inner_b));
  EXPECT_EQ(0, select_set.tuple_sources({0, 0}).size());
  EXPECT_EQ(0, select_set.tuple_sources({0, 1}).size());
}
561
TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) {
  // A bitcast aliases its operand, so a tuple holding a bitcast element is
  // expected to have the bitcast's operand in its points-to set.
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(c2->shape(), c2));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, bitcast}));

  BuildModuleAndRunAnalysis(builder.Build());

  // The bitcast points to the buffer of its operand.
  const PointsToSet& bitcast_set = points_to_analysis_->GetPointsToSet(bitcast);
  EXPECT_EQ(1, bitcast_set.size());
  ExpectHasTopLevelBuffers(bitcast_set.element({}), {c2});
  EXPECT_TRUE(bitcast_set.tuple_sources({}).empty());

  const PointsToSet& tuple_set = points_to_analysis_->GetPointsToSet(tuple);
  EXPECT_EQ(3, tuple_set.size());
  EXPECT_FALSE(tuple_set.IsAmbiguous());
  EXPECT_THAT(tuple_set.tuple_sources({}), UnorderedElementsAre(tuple));

  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(), {c1, c2, tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({0}), {c1});
  ExpectHasTopLevelBuffers(tuple_set.element({1}), {c2});
}
598
TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
  // Copies a tuple-shaped constant with kCopy and checks that the elements of
  // the copy point into the nested elements of the constant.
  HloComputation::Builder builder(TestName());
  Literal matrix = LiteralUtil::CreateR2<float>({{1.0}, {2.0}});
  Literal vector = LiteralUtil::CreateR1<float>({2.0, 42});
  HloInstruction* tuple_constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::MakeTuple({&matrix, &vector})));
  HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary(
      tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));

  BuildModuleAndRunAnalysis(builder.Build());

  const PointsToSet& copy_set = points_to_analysis_->GetPointsToSet(copy);

  ExpectHasBuffers(copy_set.element({}), {GetBuffer(copy, {})});
  ExpectHasBuffers(copy_set.element({0}), {GetBuffer(tuple_constant, {0})});
  ExpectHasBuffers(copy_set.element({1}), {GetBuffer(tuple_constant, {1})});
}
620
TEST_F(TuplePointsToAnalysisTest, BufferAliases) {
  // Builds outer = ((c1, c2), c2), so that c2 appears at more than one
  // position, and checks the alias set of every defined buffer.
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* outer =
      builder.AddInstruction(HloInstruction::CreateTuple({inner, c2}));

  BuildModuleAndRunAnalysis(builder.Build());

  ExpectHasBufferAliases(c1, /*index=*/{},
                         {{c1, {}}, {inner, {0}}, {outer, {0, 0}}});
  ExpectHasBufferAliases(
      c2, /*index=*/{},
      {{c2, {}}, {inner, {1}}, {outer, {0, 1}}, {outer, {1}}});
  ExpectHasBufferAliases(inner, /*index=*/{}, {{inner, {}}, {outer, {0}}});
  ExpectHasBufferAliases(outer, /*index=*/{}, {{outer, {}}});
}
646
TEST_F(TuplePointsToAnalysisTest, CustomCall) {
  // A custom call with a two-element tuple output whose element {1} is
  // declared to alias operand 0. Checks the alias sets seen through
  // get-tuple-element users of the call.
  using OperandAlias = std::pair<int64, ShapeIndex>;
  using OutputAlias = std::pair<ShapeIndex, OperandAlias>;

  const Shape data_shape = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* ccall = builder.AddInstruction(HloInstruction::CreateCustomCall(
      ShapeUtil::MakeTupleShape({data_shape, data_shape}), {operand},
      "TestOp"));
  // Output element {1} aliases operand 0 at its top-level index.
  Cast<HloCustomCallInstruction>(ccall)->set_output_to_operand_aliasing(
      {OutputAlias{ShapeIndex{1}, OperandAlias(0, {})}});
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, ccall, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, ccall, 1));

  BuildModuleAndRunAnalysis(builder.Build());

  ExpectHasBufferAliases(ccall, /*index=*/{0}, {{gte0, {}}, {ccall, {0}}});
  ExpectHasBufferAliases(operand, /*index=*/{},
                         {{operand, {}}, {gte1, {}}, {ccall, {1}}});
}
669
670 class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
671 protected:
672 // Builds a computation, runs instruction fusion HloPass, runs points-to
673 // analysis, then checks for expected results (see unit test cases for
674 // example computation graphs).
Run(const bool add_additional_gte0_user)675 void Run(const bool add_additional_gte0_user) {
676 Shape input_shape = ShapeUtil::MakeShape(F32, {8});
677 Shape update_shape = ShapeUtil::MakeShape(F32, {3});
678 Shape starts_shape = ShapeUtil::MakeShape(S32, {});
679 Shape tuple_shape =
680 ShapeUtil::MakeTupleShape({input_shape, update_shape, starts_shape});
681
682 auto builder = HloComputation::Builder(TestName());
683 // Create tuple-shaped parameter.
684 auto tuple_param0 = builder.AddInstruction(
685 HloInstruction::CreateParameter(0, tuple_shape, "param0"));
686 // Create 'tuple_element1' = GetTupleElement(tuple_param0, 1).
687 auto tuple_element1 = builder.AddInstruction(
688 HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1));
689 auto ones = builder.AddInstruction(HloInstruction::CreateConstant(
690 LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f})));
691 // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones)
692 auto update = builder.AddInstruction(HloInstruction::CreateBinary(
693 update_shape, HloOpcode::kAdd, tuple_element1, ones));
694 // Create 'input' = GetTupleElement(tuple_param0, 0).
695 auto input = builder.AddInstruction(
696 HloInstruction::CreateGetTupleElement(input_shape, tuple_param0, 0));
697
698 if (add_additional_gte0_user) {
699 // Create 'slice' as an additional user of 'input'.
700 auto slice = builder.AddInstruction(
701 HloInstruction::CreateSlice(update_shape, input, {0}, {3}, {1}));
702 // Modify 'update' to take 'slice' output.
703 update = builder.AddInstruction(HloInstruction::CreateBinary(
704 update_shape, HloOpcode::kAdd, update, slice));
705 }
706
707 // Create slice 'starts' = GetTupleElement(tuple_param0, 2).
708 auto starts = builder.AddInstruction(
709 HloInstruction::CreateGetTupleElement(starts_shape, tuple_param0, 2));
710 // Update 'input' with 'update' at dynamic 'starts' indices.
711 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
712 input_shape, input, update, {starts}));
713
714 // Build computation and add it to module as entry computation.
715 BuildModule(builder.Build());
716 // Run instruction fusion HloPass.
717 EXPECT_TRUE(InstructionFusion(InstructionFusion::IsExpensive)
718 .Run(module_.get())
719 .ValueOrDie());
720 // Get computation root instruction (should be a kFusion).
721 auto* fusion = module_->entry_computation()->root_instruction();
722 EXPECT_THAT(fusion, op::Fusion(tuple_param0));
723 // Run points-to analysis (should include fused instructions from 'fusion').
724 RunAnalysis();
725
726 // Check points-to set of fusion parameter associated with 'tuple_param0'.
727 auto* fusion_param = GetFusionParameterForOperand(fusion, tuple_param0);
728 ExpectHasBuffers(
729 points_to_analysis_->GetPointsToSet(fusion_param).element({}),
730 {GetBuffer(fusion_param, {})});
731 ExpectHasBuffers(
732 points_to_analysis_->GetPointsToSet(fusion_param).element({0}),
733 {GetBuffer(fusion_param, {0})});
734 ExpectHasBuffers(
735 points_to_analysis_->GetPointsToSet(fusion_param).element({1}),
736 {GetBuffer(fusion_param, {1})});
737 ExpectHasBuffers(
738 points_to_analysis_->GetPointsToSet(fusion_param).element({2}),
739 {GetBuffer(fusion_param, {2})});
740
741 // Check that Gte at tuple_index = 0 points-to fusion_param({0})
742 auto fused_gte0 = GetUniqueFusionParameterUserAt(fusion_param, 0);
743 ExpectHasBuffers(
744 points_to_analysis_->GetPointsToSet(fused_gte0).element({}),
745 {GetBuffer(fusion_param, {0})});
746 // Check that Gte at tuple_index = 1 points-to fusion_param({1})
747 auto fused_gte1 = GetUniqueFusionParameterUserAt(fusion_param, 1);
748 ExpectHasBuffers(
749 points_to_analysis_->GetPointsToSet(fused_gte1).element({}),
750 {GetBuffer(fusion_param, {1})});
751 // Check that Gte at tuple_index = 2 points-to fusion_param({2})
752 auto fused_gte2 = GetUniqueFusionParameterUserAt(fusion_param, 2);
753 ExpectHasBuffers(
754 points_to_analysis_->GetPointsToSet(fused_gte2).element({}),
755 {GetBuffer(fusion_param, {2})});
756
757 // Check buffer aliases of 'fusion_param' at shape index {0}.
758 ExpectHasBufferAliases(fusion_param, /*index=*/{0},
759 {{fusion_param, {0}}, {fused_gte0, {}}});
760 // Check buffer aliases of 'fusion_param' at shape index {1}.
761 ExpectHasBufferAliases(fusion_param, /*index=*/{1},
762 {{fusion_param, {1}}, {fused_gte1, {}}});
763 // Check buffer aliases of 'fusion_param' at shape index {2}.
764 ExpectHasBufferAliases(fusion_param, /*index=*/{2},
765 {{fusion_param, {2}}, {fused_gte2, {}}});
766
767 // Check number of users of 'fusion_param' aliases at shape index {0}.
768 ExpectNumUsersOfAliases(fusion_param, {0},
769 add_additional_gte0_user ? 2 : 1);
770 }
771
772 // Returns fusion parameter (from 'fusion.fused_instructions') corresponding
773 // to fusion 'operand'.
GetFusionParameterForOperand(HloInstruction * fusion,HloInstruction * operand)774 HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion,
775 HloInstruction* operand) {
776 auto it = absl::c_find_if(
777 fusion->fused_instructions(), [&](const HloInstruction* fused) {
778 return fused->opcode() == HloOpcode::kParameter &&
779 fusion->operand(fused->parameter_number()) == operand;
780 });
781 CHECK(it != fusion->fused_instructions().end());
782 return *it;
783 }
784
785 // Returns all users of 'fusion_paran' at 'tuple_index'.
GetFusionParameterUsersAt(HloInstruction * fusion_param,int64 tuple_index)786 std::vector<HloInstruction*> GetFusionParameterUsersAt(
787 HloInstruction* fusion_param, int64 tuple_index) {
788 CHECK(fusion_param->shape().IsTuple());
789 std::vector<HloInstruction*> users_at_tuple_index;
790 for (auto user : fusion_param->users()) {
791 CHECK_EQ(HloOpcode::kGetTupleElement, user->opcode());
792 if (user->tuple_index() == tuple_index) {
793 users_at_tuple_index.push_back(user);
794 }
795 }
796 return users_at_tuple_index;
797 }
798
799 // Returns the unique user of 'fusion_param' at 'tuple_index'.
GetUniqueFusionParameterUserAt(HloInstruction * fusion_param,int64 tuple_index)800 HloInstruction* GetUniqueFusionParameterUserAt(HloInstruction* fusion_param,
801 int64 tuple_index) {
802 std::vector<HloInstruction*> users =
803 GetFusionParameterUsersAt(fusion_param, tuple_index);
804 CHECK_EQ(1, users.size());
805 return users[0];
806 }
807
808 // Checks that the count of all users of all aliases of 'instruction' at
809 // 'index' match 'expected_num_users'.
ExpectNumUsersOfAliases(const HloInstruction * instruction,const ShapeIndex & index,const int64 expected_num_users)810 void ExpectNumUsersOfAliases(const HloInstruction* instruction,
811 const ShapeIndex& index,
812 const int64 expected_num_users) {
813 const auto* buffer = GetBuffer(instruction, index);
814 int64 num_users = 0;
815 for (const auto& alias : points_to_analysis_->GetBufferAliases(*buffer)) {
816 for (auto user : alias.instruction()->users()) {
817 if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
818 // Gte instructions only access the top-level buffer of their operand.
819 continue;
820 }
821 ++num_users;
822 }
823 }
824 EXPECT_EQ(expected_num_users, num_users);
825 }
826 };
827
828 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users.
829 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices.
830 // Tests that there is a single user of the aliases of tuple-shaped fusion
831 // parameter 0 at shape index {0}.
832 //
833 // Param0 Const
834 // \ /
835 // Fusion
836 // / \
837 // FusionParam0 FusionParam1
838 // / | \ |
839 // Gte(0) Gte(2) Gte(1) /
840 // \ | \ /
841 // \ | Add
842 // \ | /
843 // \0 |2 /1
844 // DynamicUpdateSlice // fused root.
845 //
TEST_F(FusionPointsToAnalysisTest,FusionParam0OneUser)846 TEST_F(FusionPointsToAnalysisTest, FusionParam0OneUser) {
847 Run(/*add_additional_gte0_user=*/false);
848 }
849
850 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users.
851 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices.
852 // Tests that there are two users of the aliases of tuple-shaped fusion
853 // parameter 0 at shape index {0}.
854 //
855 // Param0 Const
856 // \ /
857 // Fusion
858 // / \
859 // FusionParam0 FusionParam1
860 // / | \ |
861 // Gte(2) Gte(0) Gte(1) /
862 // \ | \ /
863 // \ |\ Add
864 // \ | \ /
865 // | | Slice /
866 // | | \ /
867 // | | Add
868 // | | |
869 // |2 |0 |1
870 // DynamicUpdateSlice // fused root.
871 //
TEST_F(FusionPointsToAnalysisTest,FusionParam0TwoUsers)872 TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) {
873 Run(/*add_additional_gte0_user=*/true);
874 }
875
876 class PointsToAnalysisTestBase : public HloTestBase {
877 protected:
BuildModule(std::unique_ptr<HloComputation> computation)878 void BuildModule(std::unique_ptr<HloComputation> computation) {
879 module_ = CreateNewVerifiedModule();
880 computation_ = module_->AddEntryComputation(std::move(computation));
881 }
882
RunAnalysis()883 void RunAnalysis() {
884 CHECK_NOTNULL(module_.get());
885 points_to_analysis_ =
886 TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
887 }
888
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)889 void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
890 BuildModule(std::move(computation));
891 RunAnalysis();
892 }
893
894 std::unique_ptr<HloModule> module_;
895 HloComputation* computation_ = nullptr;
896 std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
897 };
898
// Fixture for the TuplePointsToAnalysis::DoesNotUseOperandBuffer tests; it
// adds nothing on top of PointsToAnalysisTestBase.
class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
900
// Checks DoesNotUseOperandBuffer for GetTupleElement users of a tuple-shaped
// parameter.
TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
  HloComputation::Builder builder(TestName());

  const Shape vector_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({vector_shape, vector_shape});

  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vector_shape, tuple, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vector_shape, tuple, 1));
  // Root: sum of the two extracted elements.
  builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape, HloOpcode::kAdd, gte0, gte1));

  BuildModuleAndRunAnalysis(builder.Build());

  // A GetTupleElement only accesses the top-level buffer of its operand, so
  // the element buffers count as unused while the top-level buffer is used.
  EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0));
  EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1));
  EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0));
  EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1));
}
923
// Checks DoesNotUseOperandBuffer for a loop fusion that wraps a
// DynamicUpdateSlice of tuple element 1.
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
  HloComputation::Builder builder(TestName());

  const Shape vector_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({vector_shape, vector_shape});

  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vector_shape, tuple, 0));
  HloInstruction* element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vector_shape, tuple, 1));

  // DynamicUpdateSlice that writes three values into tuple element 1,
  // starting at index 2.
  HloInstruction* start_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* new_values = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* updated_element1 =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          vector_shape, element1, new_values, {start_index}));
  // Root: (element 0, updated element 1).
  builder.AddInstruction(
      HloInstruction::CreateTuple({element0, updated_element1}));

  BuildModule(builder.Build());
  // Fuse the update together with its inputs, including the GTE of element 1,
  // before running the analysis.
  auto fusion = computation_->CreateFusionInstruction(
      {updated_element1, start_index, new_values, element1},
      HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // The fusion instruction never uses tuple element 0, but does use element 1.
  EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
  EXPECT_FALSE(
      points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
}
957 } // namespace
958 } // namespace xla
959