• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
17 
18 #include <map>
19 #include <memory>
20 
21 #include "tensorflow/compiler/xla/literal_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
23 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
24 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
25 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace op = xla::testing::opcode_matchers;
37 
38 namespace xla {
39 namespace {
40 
41 using ::testing::UnorderedElementsAre;
42 using ::testing::UnorderedElementsAreArray;
43 
44 class TuplePointsToAnalysisTest : public HloTestBase {
45  protected:
46   // Builds a module with the given entry computation and runs points to
47   // analysis.
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)48   void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
49     BuildModule(std::move(computation));
50     RunAnalysis();
51   }
52 
BuildModule(std::unique_ptr<HloComputation> computation)53   void BuildModule(std::unique_ptr<HloComputation> computation) {
54     module_ = CreateNewVerifiedModule();
55     module_->AddEntryComputation(std::move(computation));
56   }
57 
RunAnalysis()58   void RunAnalysis() {
59     CHECK_NOTNULL(module_.get());
60     points_to_analysis_ =
61         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
62   }
63 
64   // Returns the LogicalBuffer defined at the given instruction and
65   // index. CHECKs if no buffer is defined at that point.
GetBuffer(const HloInstruction * instruction,const ShapeIndex & index)66   const LogicalBuffer* const GetBuffer(const HloInstruction* instruction,
67                                        const ShapeIndex& index) {
68     const auto& pointed_to =
69         points_to_analysis_->GetPointsToSet(instruction).element(index);
70     CHECK_EQ(1, pointed_to.size());
71     CHECK_EQ(instruction, pointed_to[0]->instruction());
72     CHECK(index == pointed_to[0]->index());
73     return pointed_to[0];
74   }
75 
76   // Checks that the given points-to set contains exactly (unordered) the given
77   // LogicalBuffers.
ExpectHasBuffers(const PointsToSet::BufferList & points_to_set,absl::Span<const LogicalBuffer * const> buffers)78   void ExpectHasBuffers(const PointsToSet::BufferList& points_to_set,
79                         absl::Span<const LogicalBuffer* const> buffers) {
80     std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end());
81     EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec));
82   }
83 
84   // Checks that the given points-to set contains exactly (unordered) the
85   // top-level buffers of the given instructions.
ExpectHasTopLevelBuffers(const PointsToSet::BufferList & points_to_set,absl::Span<HloInstruction * const> instructions)86   void ExpectHasTopLevelBuffers(
87       const PointsToSet::BufferList& points_to_set,
88       absl::Span<HloInstruction* const> instructions) {
89     PointsToSet::BufferList buffers;
90     for (auto instruction : instructions) {
91       buffers.push_back(GetBuffer(instruction, /*index=*/{}));
92     }
93     ExpectHasBuffers(points_to_set, buffers);
94   }
95 
96   // Overload which takes a set instead of a vector.
ExpectHasTopLevelBuffers(const PointsToSet::BufferSet & points_to_set,absl::Span<HloInstruction * const> instructions)97   void ExpectHasTopLevelBuffers(
98       const PointsToSet::BufferSet& points_to_set,
99       absl::Span<HloInstruction* const> instructions) {
100     ExpectHasTopLevelBuffers(
101         PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()),
102         instructions);
103   }
104 
105   // Checks that the buffer defined at the given instruction and index has
106   // aliases which are exactly (unordered) the given instruction/index pairs.
ExpectHasBufferAliases(const HloInstruction * instruction,const ShapeIndex & index,absl::Span<const std::pair<HloInstruction *,ShapeIndex>> expected)107   void ExpectHasBufferAliases(
108       const HloInstruction* instruction, const ShapeIndex& index,
109       absl::Span<const std::pair<HloInstruction*, ShapeIndex>> expected) {
110     const LogicalBuffer* buffer =
111         points_to_analysis_->GetBufferDefinedAt(instruction, index)
112             .ValueOrDie();
113     std::vector<BufferAlias> expected_aliases;
114     for (auto& pair : expected) {
115       expected_aliases.push_back(BufferAlias(pair.first, pair.second));
116     }
117     EXPECT_THAT(points_to_analysis_->GetBufferAliases(*buffer),
118                 UnorderedElementsAreArray(expected_aliases));
119   }
120 
121   std::unique_ptr<HloModule> module_;
122   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
123 };
124 
TEST_F(TuplePointsToAnalysisTest, SimpleTuple) {
  // A flat tuple of two scalar constants. Each constant points only at its own
  // buffer. The tuple defines its own top-level buffer and points at the
  // constants' buffers at {0} and {1}.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));

  BuildModuleAndRunAnalysis(builder.Build());

  const auto& c1_set = points_to_analysis_->GetPointsToSet(constant1);
  const auto& c2_set = points_to_analysis_->GetPointsToSet(constant2);
  const auto& tuple_set = points_to_analysis_->GetPointsToSet(tuple);

  // Leaf constants: a single buffer each, no tuple sources.
  EXPECT_EQ(1, c1_set.size());
  ExpectHasTopLevelBuffers(c1_set.element({}), {constant1});
  EXPECT_TRUE(c1_set.tuple_sources({}).empty());
  EXPECT_TRUE(tuple_set.IsDistinct());

  EXPECT_EQ(1, c2_set.size());
  ExpectHasTopLevelBuffers(c2_set.element({}), {constant2});
  EXPECT_TRUE(c2_set.tuple_sources({}).empty());

  // The tuple itself.
  EXPECT_EQ(3, tuple_set.size());
  EXPECT_FALSE(tuple_set.IsAmbiguous());
  EXPECT_THAT(tuple_set.tuple_sources({}), UnorderedElementsAre(tuple));

  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(),
                           {constant1, constant2, tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({0}), {constant1});
  ExpectHasTopLevelBuffers(tuple_set.element({1}), {constant2});

  // Membership queries: per-index and anywhere in the set.
  const auto& buffer1 = *GetBuffer(constant1, {});
  const auto& buffer2 = *GetBuffer(constant2, {});
  EXPECT_TRUE(tuple_set.ContainsBufferAtIndex(buffer1, {0}));
  EXPECT_TRUE(tuple_set.ContainsBufferAtIndex(buffer2, {1}));
  EXPECT_FALSE(tuple_set.ContainsBufferAtIndex(buffer2, {0}));
  EXPECT_TRUE(tuple_set.ContainsBuffer(buffer1));
  EXPECT_TRUE(tuple_set.ContainsBuffer(buffer2));
}
174 
TEST_F(TuplePointsToAnalysisTest,NestedTuple)175 TEST_F(TuplePointsToAnalysisTest, NestedTuple) {
176   // Create a (nested) tuple containing an inner tuple. The points-to set of the
177   // outer tuple should contain all elements of the points-to set of the inner
178   // tuple.
179   auto builder = HloComputation::Builder(TestName());
180   auto constant1 = builder.AddInstruction(
181       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
182   auto constant2 = builder.AddInstruction(
183       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
184   auto inner_tuple = builder.AddInstruction(
185       HloInstruction::CreateTuple({constant1, constant2}));
186 
187   auto constant3 = builder.AddInstruction(
188       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
189   auto tuple = builder.AddInstruction(
190       HloInstruction::CreateTuple({inner_tuple, constant3}));
191 
192   BuildModuleAndRunAnalysis(builder.Build());
193   ExpectHasTopLevelBuffers(
194       points_to_analysis_->GetPointsToSet(constant1).element({}), {constant1});
195   ExpectHasTopLevelBuffers(
196       points_to_analysis_->GetPointsToSet(constant2).element({}), {constant2});
197   ExpectHasTopLevelBuffers(
198       points_to_analysis_->GetPointsToSet(constant3).element({}), {constant3});
199 
200   EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(inner_tuple).size());
201   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(inner_tuple).IsAmbiguous());
202   EXPECT_TRUE(points_to_analysis_->GetPointsToSet(inner_tuple).IsDistinct());
203   ExpectHasTopLevelBuffers(
204       points_to_analysis_->GetPointsToSet(inner_tuple).CreateFlattenedSet(),
205       {constant1, constant2, inner_tuple});
206   ExpectHasTopLevelBuffers(
207       points_to_analysis_->GetPointsToSet(inner_tuple).element({}),
208       {inner_tuple});
209   EXPECT_THAT(
210       points_to_analysis_->GetPointsToSet(inner_tuple).tuple_sources({}),
211       UnorderedElementsAre(inner_tuple));
212 
213   EXPECT_EQ(5, points_to_analysis_->GetPointsToSet(tuple).size());
214   EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous());
215   ExpectHasTopLevelBuffers(
216       points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
217       {constant1, constant2, constant3, inner_tuple, tuple});
218 
219   EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}),
220               UnorderedElementsAre(tuple));
221   EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({0}),
222               UnorderedElementsAre(inner_tuple));
223   EXPECT_TRUE(
224       points_to_analysis_->GetPointsToSet(tuple).tuple_sources({1}).empty());
225 
226   ExpectHasTopLevelBuffers(
227       points_to_analysis_->GetPointsToSet(tuple).element({0}), {inner_tuple});
228   ExpectHasTopLevelBuffers(
229       points_to_analysis_->GetPointsToSet(tuple).element({0, 0}), {constant1});
230   ExpectHasTopLevelBuffers(
231       points_to_analysis_->GetPointsToSet(tuple).element({0, 1}), {constant2});
232   ExpectHasTopLevelBuffers(
233       points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant3});
234 }
235 
TEST_F(TuplePointsToAnalysisTest, GetTupleElement) {
  // Extracts the inner tuple of ((constant1, constant2), constant3) with a
  // GetTupleElement at index 0. The GTE's points-to set should match the
  // inner tuple's, including its top-level buffer.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, constant3}));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner_tuple->shape(), tuple, 0));

  BuildModuleAndRunAnalysis(builder.Build());

  const auto& gte_set = points_to_analysis_->GetPointsToSet(gte);
  EXPECT_EQ(3, gte_set.size());
  EXPECT_FALSE(gte_set.IsAmbiguous());
  EXPECT_TRUE(gte_set.IsDistinct());
  ExpectHasTopLevelBuffers(gte_set.CreateFlattenedSet(),
                           {constant1, constant2, inner_tuple});
  ExpectHasTopLevelBuffers(gte_set.element({}), {inner_tuple});
  EXPECT_THAT(gte_set.tuple_sources({}), UnorderedElementsAre(inner_tuple));
}
269 
TEST_F(TuplePointsToAnalysisTest, AddDependency) {
  // AddDependency(constant, token): the result's points-to set holds exactly
  // the constant's buffer.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* add_dependency = builder.AddInstruction(
      HloInstruction::CreateAddDependency(constant, token));
  BuildModuleAndRunAnalysis(builder.Build());

  const auto& dependency_set =
      points_to_analysis_->GetPointsToSet(add_dependency);
  EXPECT_EQ(1, dependency_set.size());
  EXPECT_FALSE(dependency_set.IsAmbiguous());
  EXPECT_TRUE(dependency_set.IsDistinct());
  ExpectHasTopLevelBuffers(dependency_set.CreateFlattenedSet(), {constant});
}
285 
TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) {
  // A tuple that lists the same constant three times. Each position is
  // unambiguous, but the same buffer appears at several indices, so the set
  // is not distinct.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant, constant, constant}));

  BuildModuleAndRunAnalysis(builder.Build());

  const auto& tuple_set = points_to_analysis_->GetPointsToSet(tuple);
  EXPECT_EQ(2, tuple_set.size());
  EXPECT_FALSE(tuple_set.IsAmbiguous());
  EXPECT_FALSE(tuple_set.IsDistinct());
  ExpectHasTopLevelBuffers(tuple_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(), {constant, tuple});
}
305 
TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
  // Copies a tuple with HloOpcode::kCopy. The copy defines its own top-level
  // buffer, but its elements still point at the original constants' buffers.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple));

  BuildModuleAndRunAnalysis(builder.Build());

  const auto& tuple_set = points_to_analysis_->GetPointsToSet(tuple);
  const auto& copy_set = points_to_analysis_->GetPointsToSet(copy);

  EXPECT_FALSE(copy_set.IsAmbiguous());
  EXPECT_TRUE(copy_set.IsDistinct());
  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(),
                           {constant1, constant2, tuple});
  ExpectHasTopLevelBuffers(copy_set.element({}), {copy});
  ExpectHasTopLevelBuffers(copy_set.CreateFlattenedSet(),
                           {constant1, constant2, copy});
}
332 
TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) {
  // CopyStart's output tuple element {1} aliases its operand. CopyDone
  // forwards element {0} of its CopyStart operand as its own output.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  const auto& scalar_shape = constant->shape();
  HloInstruction* copy_start =
      builder.AddInstruction(HloInstruction::CreateCopyStart(
          ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape,
                                     ShapeUtil::MakeShape(U32, {})}),
          constant));
  HloInstruction* copy_done = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kCopyDone,
                                  copy_start));

  BuildModuleAndRunAnalysis(builder.Build());

  const auto& start_set = points_to_analysis_->GetPointsToSet(copy_start);
  const auto& done_set = points_to_analysis_->GetPointsToSet(copy_done);
  EXPECT_FALSE(start_set.IsAmbiguous());
  EXPECT_TRUE(start_set.IsDistinct());
  EXPECT_FALSE(done_set.IsAmbiguous());
  EXPECT_TRUE(done_set.IsDistinct());

  ExpectHasTopLevelBuffers(start_set.element({}), {copy_start});
  ExpectHasBufferAliases(copy_start, {0}, {{copy_start, {0}}, {copy_done, {}}});
  ExpectHasBufferAliases(constant, {}, {{constant, {}}, {copy_start, {1}}});
}
358 
TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
  // Send places its operand at element {0} of its output tuple. SendDone's
  // points-to set contains only its own buffer.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send = builder.AddInstruction(
      HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
  HloInstruction* send_done =
      builder.AddInstruction(HloInstruction::CreateSendDone(send));

  BuildModuleAndRunAnalysis(builder.Build());

  const auto& send_set = points_to_analysis_->GetPointsToSet(send);
  const auto& send_done_set = points_to_analysis_->GetPointsToSet(send_done);
  EXPECT_FALSE(send_set.IsAmbiguous());
  EXPECT_TRUE(send_set.IsDistinct());
  EXPECT_FALSE(send_done_set.IsAmbiguous());
  EXPECT_TRUE(send_done_set.IsDistinct());

  ExpectHasTopLevelBuffers(send_set.element({}), {send});
  ExpectHasTopLevelBuffers(send_set.element({0}), {constant});
  ExpectHasTopLevelBuffers(send_done_set.CreateFlattenedSet(), {send_done});
  ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}});
}
385 
TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
  // The buffer at Recv's tuple element {0} is also visible at RecvDone's
  // element {0}.
  HloComputation::Builder builder(TestName());
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv = builder.AddInstruction(HloInstruction::CreateRecv(
      ShapeUtil::MakeShape(F32, {1, 2, 3}), token, /*channel_id=*/0));
  HloInstruction* recv_done =
      builder.AddInstruction(HloInstruction::CreateRecvDone(recv));

  BuildModuleAndRunAnalysis(builder.Build());

  const auto& recv_set = points_to_analysis_->GetPointsToSet(recv);
  const auto& recv_done_set = points_to_analysis_->GetPointsToSet(recv_done);
  EXPECT_FALSE(recv_set.IsAmbiguous());
  EXPECT_TRUE(recv_set.IsDistinct());
  EXPECT_FALSE(recv_done_set.IsAmbiguous());
  EXPECT_TRUE(recv_done_set.IsDistinct());

  ExpectHasTopLevelBuffers(recv_set.element({}), {recv});
  ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {0}}});
}
405 
TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
  // Selects between (constant1, constant2) and (constant2, constant2). Either
  // constant may flow into {0}, so the result is ambiguous there. {1} is
  // always constant2.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* on_true = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* on_false = builder.AddInstruction(
      HloInstruction::CreateTuple({constant2, constant2}));
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      on_true->shape(), HloOpcode::kTupleSelect, pred, on_true, on_false));

  BuildModuleAndRunAnalysis(builder.Build());

  const auto& select_set = points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(3, select_set.size());
  EXPECT_TRUE(select_set.IsAmbiguous());
  EXPECT_FALSE(select_set.IsDistinct());
  ExpectHasTopLevelBuffers(select_set.element({}), {select});
  ExpectHasTopLevelBuffers(select_set.element({0}), {constant1, constant2});
  ExpectHasTopLevelBuffers(select_set.element({1}), {constant2});
  ExpectHasTopLevelBuffers(select_set.CreateFlattenedSet(),
                           {constant1, constant2, select});
}
436 
TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) {
  // Selects between two tuple-shaped parameters, then copies the result.
  // Verifies the points-to sets of the parameters, the select and the copy.
  const auto tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeShape(U32, {5})});

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, tuple_shape, "param1"));
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple_shape, HloOpcode::kTupleSelect, pred, param0, param1));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select));

  BuildModuleAndRunAnalysis(builder.Build());

  // Each position of a tuple parameter points at the buffer the parameter
  // itself defines at that index.
  const auto& param0_set = points_to_analysis_->GetPointsToSet(param0);
  ExpectHasBuffers(param0_set.element({}), {GetBuffer(param0, {})});
  ExpectHasBuffers(param0_set.element({0}), {GetBuffer(param0, {0})});
  ExpectHasBuffers(param0_set.element({1}), {GetBuffer(param0, {1})});

  // The select's elements point at the matching elements of both parameters.
  // Its top-level buffer does not alias either parameter, because the select
  // defines it.
  const auto& select_set = points_to_analysis_->GetPointsToSet(select);
  ExpectHasBuffers(select_set.element({}), {GetBuffer(select, {})});
  ExpectHasBuffers(select_set.element({0}),
                   {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
  ExpectHasBuffers(select_set.element({1}),
                   {GetBuffer(param0, {1}), GetBuffer(param1, {1})});

  // The copy matches the select everywhere except its own top-level buffer.
  const auto& copy_set = points_to_analysis_->GetPointsToSet(copy);
  ExpectHasBuffers(copy_set.element({}), {GetBuffer(copy, {})});
  ExpectHasBuffers(copy_set.element({0}),
                   {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
  ExpectHasBuffers(copy_set.element({1}),
                   {GetBuffer(param0, {1}), GetBuffer(param1, {1})});
}
484 
TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) {
  // Selects between two tuples with identical elements (constant1, constant2).
  // Each position can hold only one buffer, so the result is unambiguous and
  // distinct.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* on_true = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* on_false = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
      on_true->shape(), HloOpcode::kTupleSelect, pred, on_true, on_false));

  BuildModuleAndRunAnalysis(builder.Build());

  const auto& select_set = points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(3, select_set.size());
  EXPECT_FALSE(select_set.IsAmbiguous());
  EXPECT_TRUE(select_set.IsDistinct());
  ExpectHasTopLevelBuffers(select_set.element({}), {select});
  ExpectHasTopLevelBuffers(select_set.element({0}), {constant1});
  ExpectHasTopLevelBuffers(select_set.element({1}), {constant2});
  ExpectHasTopLevelBuffers(select_set.CreateFlattenedSet(),
                           {constant1, constant2, select});
}
514 
TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
  // Selects between ((constant1, constant2)) and ((constant2, constant2)).
  // The result is ambiguous at {0} (either inner tuple) and at {0, 0} (either
  // constant). Position {0, 1} is always constant2.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto inner_tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  auto inner_tuple2 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant2, constant2}));

  auto tuple1 =
      builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple1}));
  auto tuple2 =
      builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2}));

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));

  BuildModuleAndRunAnalysis(builder.Build());

  auto& points_to_set = points_to_analysis_->GetPointsToSet(select);
  EXPECT_EQ(5, points_to_set.size());
  EXPECT_TRUE(points_to_set.IsAmbiguous());
  EXPECT_FALSE(points_to_set.IsDistinct());

  // Verify points-to set.
  ExpectHasTopLevelBuffers(points_to_set.element({}), {select});
  ExpectHasTopLevelBuffers(points_to_set.element({0}),
                           {inner_tuple1, inner_tuple2});
  ExpectHasTopLevelBuffers(points_to_set.element({0, 0}),
                           {constant1, constant2});
  ExpectHasTopLevelBuffers(points_to_set.element({0, 1}), {constant2});

  // Verify tuple sources. The tuple-valued positions come from the selected
  // tuples. The scalar leaves must have no tuple sources; check this with
  // empty(), as the other tests do.
  EXPECT_THAT(points_to_set.tuple_sources({}),
              UnorderedElementsAre(tuple1, tuple2));
  EXPECT_THAT(points_to_set.tuple_sources({0}),
              UnorderedElementsAre(inner_tuple1, inner_tuple2));
  EXPECT_TRUE(points_to_set.tuple_sources({0, 0}).empty());
  EXPECT_TRUE(points_to_set.tuple_sources({0, 1}).empty());
}
561 
TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) {
  // A bitcast aliases its operand. A tuple with a bitcast element should
  // therefore see the bitcast's operand (constant2) in its points-to set.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(constant2->shape(), constant2));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant1, bitcast}));

  BuildModuleAndRunAnalysis(builder.Build());

  // The bitcast defines no buffer of its own.
  const auto& bitcast_set = points_to_analysis_->GetPointsToSet(bitcast);
  EXPECT_EQ(1, bitcast_set.size());
  ExpectHasTopLevelBuffers(bitcast_set.element({}), {constant2});
  EXPECT_TRUE(bitcast_set.tuple_sources({}).empty());

  const auto& tuple_set = points_to_analysis_->GetPointsToSet(tuple);
  EXPECT_EQ(3, tuple_set.size());
  EXPECT_FALSE(tuple_set.IsAmbiguous());
  EXPECT_THAT(tuple_set.tuple_sources({}), UnorderedElementsAre(tuple));

  ExpectHasTopLevelBuffers(tuple_set.CreateFlattenedSet(),
                           {constant1, constant2, tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({}), {tuple});
  ExpectHasTopLevelBuffers(tuple_set.element({0}), {constant1});
  ExpectHasTopLevelBuffers(tuple_set.element({1}), {constant2});
}
598 
TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
  // Copies a tuple-shaped constant with kCopy. The copy's elements must point
  // at the nested element buffers of the constant, while the copy defines its
  // own top-level buffer.
  HloComputation::Builder builder(TestName());
  Literal r2_element = LiteralUtil::CreateR2<float>({{1.0}, {2.0}});
  Literal r1_element = LiteralUtil::CreateR1<float>({2.0, 42});
  HloInstruction* tuple_constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::MakeTuple({&r2_element, &r1_element})));
  HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary(
      tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));

  BuildModuleAndRunAnalysis(builder.Build());

  const auto& copy_set = points_to_analysis_->GetPointsToSet(copy);
  ExpectHasBuffers(copy_set.element({}), {GetBuffer(copy, {})});
  ExpectHasBuffers(copy_set.element({0}), {GetBuffer(tuple_constant, {0})});
  ExpectHasBuffers(copy_set.element({1}), {GetBuffer(tuple_constant, {1})});
}
620 
// Builds the nested tuple ((c1, c2), c2), in which constant2 is reachable along
// two different paths, and verifies the alias set of every top-level buffer.
TEST_F(TuplePointsToAnalysisTest, BufferAliases) {
  HloComputation::Builder builder(TestName());
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner =
      builder.AddInstruction(HloInstruction::CreateTuple({one, two}));
  HloInstruction* outer =
      builder.AddInstruction(HloInstruction::CreateTuple({inner, two}));

  BuildModuleAndRunAnalysis(builder.Build());

  // constant1 appears once, nested two levels deep in the outer tuple.
  ExpectHasBufferAliases(one, /*index=*/{},
                         {{one, {}}, {inner, {0}}, {outer, {0, 0}}});
  // constant2 appears both inside the inner tuple and directly in the outer.
  ExpectHasBufferAliases(
      two, /*index=*/{},
      {{two, {}}, {inner, {1}}, {outer, {0, 1}}, {outer, {1}}});
  // The inner tuple's own buffer is aliased by element {0} of the outer tuple.
  ExpectHasBufferAliases(inner, /*index=*/{}, {{inner, {}}, {outer, {0}}});
  // Nothing else refers to the outer tuple's top-level buffer.
  ExpectHasBufferAliases(outer, /*index=*/{}, {{outer, {}}});
}
646 
// A custom call producing a 2-tuple declares that its output element {1}
// aliases operand 0. Element {0} should be a buffer the custom call defines,
// while element {1} should share the constant operand's buffer.
TEST_F(TuplePointsToAnalysisTest, CustomCall) {
  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  HloInstruction* custom_call =
      builder.AddInstruction(HloInstruction::CreateCustomCall(
          ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}), {operand},
          "TestOp"));

  // Output index {1} aliases the whole (index {}) of operand 0.
  const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>> output_alias{
      ShapeIndex{1}, std::pair<int64, ShapeIndex>(0, {})};
  Cast<HloCustomCallInstruction>(custom_call)
      ->set_output_to_operand_aliasing({output_alias});

  HloInstruction* element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, custom_call, 0));
  HloInstruction* element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, custom_call, 1));

  BuildModuleAndRunAnalysis(builder.Build());

  ExpectHasBufferAliases(custom_call, /*index=*/{0},
                         {{element0, {}}, {custom_call, {0}}});
  ExpectHasBufferAliases(operand, /*index=*/{},
                         {{operand, {}}, {element1, {}}, {custom_call, {1}}});
}
669 
670 class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
671  protected:
672   // Builds a computation, runs instruction fusion HloPass, runs points-to
673   // analysis, then checks for expected results (see unit test cases for
674   // example computation graphs).
Run(const bool add_additional_gte0_user)675   void Run(const bool add_additional_gte0_user) {
676     Shape input_shape = ShapeUtil::MakeShape(F32, {8});
677     Shape update_shape = ShapeUtil::MakeShape(F32, {3});
678     Shape starts_shape = ShapeUtil::MakeShape(S32, {});
679     Shape tuple_shape =
680         ShapeUtil::MakeTupleShape({input_shape, update_shape, starts_shape});
681 
682     auto builder = HloComputation::Builder(TestName());
683     // Create tuple-shaped parameter.
684     auto tuple_param0 = builder.AddInstruction(
685         HloInstruction::CreateParameter(0, tuple_shape, "param0"));
686     // Create 'tuple_element1' = GetTupleElement(tuple_param0, 1).
687     auto tuple_element1 = builder.AddInstruction(
688         HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1));
689     auto ones = builder.AddInstruction(HloInstruction::CreateConstant(
690         LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f})));
691     // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones)
692     auto update = builder.AddInstruction(HloInstruction::CreateBinary(
693         update_shape, HloOpcode::kAdd, tuple_element1, ones));
694     // Create 'input' = GetTupleElement(tuple_param0, 0).
695     auto input = builder.AddInstruction(
696         HloInstruction::CreateGetTupleElement(input_shape, tuple_param0, 0));
697 
698     if (add_additional_gte0_user) {
699       // Create 'slice' as an additional user of 'input'.
700       auto slice = builder.AddInstruction(
701           HloInstruction::CreateSlice(update_shape, input, {0}, {3}, {1}));
702       // Modify 'update' to take 'slice' output.
703       update = builder.AddInstruction(HloInstruction::CreateBinary(
704           update_shape, HloOpcode::kAdd, update, slice));
705     }
706 
707     // Create slice 'starts' = GetTupleElement(tuple_param0, 2).
708     auto starts = builder.AddInstruction(
709         HloInstruction::CreateGetTupleElement(starts_shape, tuple_param0, 2));
710     // Update 'input' with 'update' at dynamic 'starts' indices.
711     builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
712         input_shape, input, update, {starts}));
713 
714     // Build computation and add it to module as entry computation.
715     BuildModule(builder.Build());
716     // Run instruction fusion HloPass.
717     EXPECT_TRUE(InstructionFusion(InstructionFusion::IsExpensive)
718                     .Run(module_.get())
719                     .ValueOrDie());
720     // Get computation root instruction (should be a kFusion).
721     auto* fusion = module_->entry_computation()->root_instruction();
722     EXPECT_THAT(fusion, op::Fusion(tuple_param0));
723     // Run points-to analysis (should include fused instructions from 'fusion').
724     RunAnalysis();
725 
726     // Check points-to set of fusion parameter associated with 'tuple_param0'.
727     auto* fusion_param = GetFusionParameterForOperand(fusion, tuple_param0);
728     ExpectHasBuffers(
729         points_to_analysis_->GetPointsToSet(fusion_param).element({}),
730         {GetBuffer(fusion_param, {})});
731     ExpectHasBuffers(
732         points_to_analysis_->GetPointsToSet(fusion_param).element({0}),
733         {GetBuffer(fusion_param, {0})});
734     ExpectHasBuffers(
735         points_to_analysis_->GetPointsToSet(fusion_param).element({1}),
736         {GetBuffer(fusion_param, {1})});
737     ExpectHasBuffers(
738         points_to_analysis_->GetPointsToSet(fusion_param).element({2}),
739         {GetBuffer(fusion_param, {2})});
740 
741     // Check that Gte at tuple_index = 0 points-to fusion_param({0})
742     auto fused_gte0 = GetUniqueFusionParameterUserAt(fusion_param, 0);
743     ExpectHasBuffers(
744         points_to_analysis_->GetPointsToSet(fused_gte0).element({}),
745         {GetBuffer(fusion_param, {0})});
746     // Check that Gte at tuple_index = 1 points-to fusion_param({1})
747     auto fused_gte1 = GetUniqueFusionParameterUserAt(fusion_param, 1);
748     ExpectHasBuffers(
749         points_to_analysis_->GetPointsToSet(fused_gte1).element({}),
750         {GetBuffer(fusion_param, {1})});
751     // Check that Gte at tuple_index = 2 points-to fusion_param({2})
752     auto fused_gte2 = GetUniqueFusionParameterUserAt(fusion_param, 2);
753     ExpectHasBuffers(
754         points_to_analysis_->GetPointsToSet(fused_gte2).element({}),
755         {GetBuffer(fusion_param, {2})});
756 
757     // Check buffer aliases of 'fusion_param' at shape index {0}.
758     ExpectHasBufferAliases(fusion_param, /*index=*/{0},
759                            {{fusion_param, {0}}, {fused_gte0, {}}});
760     // Check buffer aliases of 'fusion_param' at shape index {1}.
761     ExpectHasBufferAliases(fusion_param, /*index=*/{1},
762                            {{fusion_param, {1}}, {fused_gte1, {}}});
763     // Check buffer aliases of 'fusion_param' at shape index {2}.
764     ExpectHasBufferAliases(fusion_param, /*index=*/{2},
765                            {{fusion_param, {2}}, {fused_gte2, {}}});
766 
767     // Check number of users of 'fusion_param' aliases at shape index {0}.
768     ExpectNumUsersOfAliases(fusion_param, {0},
769                             add_additional_gte0_user ? 2 : 1);
770   }
771 
772   // Returns fusion parameter (from 'fusion.fused_instructions') corresponding
773   // to fusion 'operand'.
GetFusionParameterForOperand(HloInstruction * fusion,HloInstruction * operand)774   HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion,
775                                                HloInstruction* operand) {
776     auto it = absl::c_find_if(
777         fusion->fused_instructions(), [&](const HloInstruction* fused) {
778           return fused->opcode() == HloOpcode::kParameter &&
779                  fusion->operand(fused->parameter_number()) == operand;
780         });
781     CHECK(it != fusion->fused_instructions().end());
782     return *it;
783   }
784 
785   // Returns all users of 'fusion_paran' at 'tuple_index'.
GetFusionParameterUsersAt(HloInstruction * fusion_param,int64 tuple_index)786   std::vector<HloInstruction*> GetFusionParameterUsersAt(
787       HloInstruction* fusion_param, int64 tuple_index) {
788     CHECK(fusion_param->shape().IsTuple());
789     std::vector<HloInstruction*> users_at_tuple_index;
790     for (auto user : fusion_param->users()) {
791       CHECK_EQ(HloOpcode::kGetTupleElement, user->opcode());
792       if (user->tuple_index() == tuple_index) {
793         users_at_tuple_index.push_back(user);
794       }
795     }
796     return users_at_tuple_index;
797   }
798 
799   // Returns the unique user of 'fusion_param' at 'tuple_index'.
GetUniqueFusionParameterUserAt(HloInstruction * fusion_param,int64 tuple_index)800   HloInstruction* GetUniqueFusionParameterUserAt(HloInstruction* fusion_param,
801                                                  int64 tuple_index) {
802     std::vector<HloInstruction*> users =
803         GetFusionParameterUsersAt(fusion_param, tuple_index);
804     CHECK_EQ(1, users.size());
805     return users[0];
806   }
807 
808   // Checks that the count of all users of all aliases of 'instruction' at
809   // 'index' match 'expected_num_users'.
ExpectNumUsersOfAliases(const HloInstruction * instruction,const ShapeIndex & index,const int64 expected_num_users)810   void ExpectNumUsersOfAliases(const HloInstruction* instruction,
811                                const ShapeIndex& index,
812                                const int64 expected_num_users) {
813     const auto* buffer = GetBuffer(instruction, index);
814     int64 num_users = 0;
815     for (const auto& alias : points_to_analysis_->GetBufferAliases(*buffer)) {
816       for (auto user : alias.instruction()->users()) {
817         if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
818           // Gte instructions only access the top-level buffer of their operand.
819           continue;
820         }
821         ++num_users;
822       }
823     }
824     EXPECT_EQ(expected_num_users, num_users);
825   }
826 };
827 
828 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users.
829 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices.
830 // Tests that there is a single user of the aliases of tuple-shaped fusion
831 // parameter 0 at shape index {0}.
832 //
833 //             Param0    Const
834 //                 \      /
835 //                  Fusion
836 //                 /      \
837 //        FusionParam0   FusionParam1
838 //        /     |    \       |
839 //     Gte(0) Gte(2) Gte(1)  /
840 //        \     |      \    /
841 //         \    |       Add
842 //          \   |        /
843 //           \0 |2      /1
844 //          DynamicUpdateSlice  // fused root.
845 //
TEST_F(FusionPointsToAnalysisTest, FusionParam0OneUser) {
  // GetTupleElement(param0, 0) is read only by the DynamicUpdateSlice.
  constexpr bool kAddAdditionalGte0User = false;
  Run(kAddAdditionalGte0User);
}
849 
850 // Tests the points-to set of tuple-shaped fusion parameter 0 and all GTE users.
851 // Tests the alias set of tuple-shaped fusion parameter 0 at all shape indices.
852 // Tests that there are two users of the aliases of tuple-shaped fusion
853 // parameter 0 at shape index {0}.
854 //
855 //             Param0    Const
856 //                 \      /
857 //                  Fusion
858 //                 /      \
859 //        FusionParam0   FusionParam1
860 //        /     |    \       |
861 //     Gte(2) Gte(0) Gte(1)  /
862 //        \     |      \    /
863 //         \    |\      Add
864 //          \   | \      /
865 //           |  | Slice /
866 //           |  |   \  /
867 //           |  |   Add
868 //           |  |    |
869 //           |2 |0   |1
870 //          DynamicUpdateSlice  // fused root.
871 //
TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) {
  // GetTupleElement(param0, 0) is read by the DynamicUpdateSlice and also by
  // an extra Slice that feeds the update operand.
  constexpr bool kAddAdditionalGte0User = true;
  Run(kAddAdditionalGte0User);
}
875 
876 class PointsToAnalysisTestBase : public HloTestBase {
877  protected:
BuildModule(std::unique_ptr<HloComputation> computation)878   void BuildModule(std::unique_ptr<HloComputation> computation) {
879     module_ = CreateNewVerifiedModule();
880     computation_ = module_->AddEntryComputation(std::move(computation));
881   }
882 
RunAnalysis()883   void RunAnalysis() {
884     CHECK_NOTNULL(module_.get());
885     points_to_analysis_ =
886         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
887   }
888 
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)889   void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
890     BuildModule(std::move(computation));
891     RunAnalysis();
892   }
893 
894   std::unique_ptr<HloModule> module_;
895   HloComputation* computation_ = nullptr;
896   std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
897 };
898 
899 class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
900 
// Two GetTupleElements read the elements of a tuple parameter, and their
// results are added together. A GTE touches only the top-level tuple buffer of
// its operand. So it does not use the element buffers, but it does use the
// top-level one.
TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
  HloComputation::Builder builder(TestName());

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({vec_shape, vec_shape}), "tuple"));
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, param, 0));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, param, 1));
  builder.AddInstruction(
      HloInstruction::CreateBinary(vec_shape, HloOpcode::kAdd, first, second));

  BuildModuleAndRunAnalysis(builder.Build());

  auto* analysis = points_to_analysis_.get();
  // Element buffers are not used by the GTE that extracts them...
  EXPECT_TRUE(analysis->DoesNotUseOperandBuffer(param, {0}, first));
  EXPECT_TRUE(analysis->DoesNotUseOperandBuffer(param, {1}, second));
  // ...but the top-level tuple buffer is.
  EXPECT_FALSE(analysis->DoesNotUseOperandBuffer(param, {}, first));
  EXPECT_FALSE(analysis->DoesNotUseOperandBuffer(param, {}, second));
}
923 
// Fuses a DynamicUpdateSlice of tuple element 1 into a loop fusion. The
// fusion also takes in the slice's start index, its update constant and the
// GTE that reads element 1. Element 0 bypasses the fusion and goes straight
// into the root tuple, so the fusion should not use that buffer but should
// use element 1's.
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
  HloComputation::Builder builder(TestName());

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({vec_shape, vec_shape}), "tuple"));
  HloInstruction* element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, param, 0));
  HloInstruction* element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, param, 1));

  // Overwrite part of tuple element 1, starting at a constant offset.
  HloInstruction* start_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* patch = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          vec_shape, element1, patch, {start_index}));
  builder.AddInstruction(HloInstruction::CreateTuple({element0, dus}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dus, start_index, patch, element1}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(param, {0}, fusion));
  EXPECT_FALSE(
      points_to_analysis_->DoesNotUseOperandBuffer(param, {1}, fusion));
}
957 }  // namespace
958 }  // namespace xla
959