• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
17 
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/memory/memory.h"
26 #include "absl/strings/string_view.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/service/buffer_value.h"
29 #include "tensorflow/compiler/xla/service/call_graph.h"
30 #include "tensorflow/compiler/xla/service/copy_insertion.h"
31 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
32 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
36 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
37 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
38 #include "tensorflow/compiler/xla/service/hlo_parser.h"
39 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
40 #include "tensorflow/compiler/xla/shape_util.h"
41 #include "tensorflow/compiler/xla/test.h"
42 #include "tensorflow/compiler/xla/test_helpers.h"
43 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/xla_data.pb.h"
46 #include "tensorflow/core/lib/core/status_test_util.h"
47 #include "tensorflow/core/platform/macros.h"
48 
49 namespace xla {
50 namespace {
51 
52 using memory_space_assignment::PresetAssignments;
53 using ::testing::UnorderedElementsAre;
54 
55 // DFS visitor that collects the instructions referenced by a computation
56 // without descending into nested computations, i.e., only from the operands.
57 class InstructionListVisitor : public DfsHloVisitorWithDefault {
58  public:
InstructionListVisitor(const HloInstruction * root)59   explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {}
60 
DefaultAction(HloInstruction * hlo)61   Status DefaultAction(HloInstruction* hlo) override {
62     // For each instruction, just push it on the list after walking the
63     // operands.
64     instructions_.push_back(hlo);
65     VLOG(0) << "List instruction " << hlo->ToString();
66     return Status::OK();
67   }
68 
GetInstructions()69   std::vector<const HloInstruction*> GetInstructions() { return instructions_; }
70 
71  private:
72   // The instruction root of the computation.
73   const HloInstruction* root_;
74 
75   // The full set of instructions found (may be duplicates, e.g., kParameter).
76   std::vector<const HloInstruction*> instructions_;
77 
78   TF_DISALLOW_COPY_AND_ASSIGN(InstructionListVisitor);
79 };
80 
GetInstructions(HloInstruction * root)81 const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
82   InstructionListVisitor main_list(root);
83   TF_CHECK_OK(root->Accept(&main_list));
84   return main_list.GetInstructions();
85 }
86 
87 class BufferAssignmentTest : public HloTestBase {
88  protected:
~BufferAssignmentTest()89   ~BufferAssignmentTest() override {}
90 
RunBufferAssignment(HloModule * module,int64_t alignment=1)91   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
92                                                         int64_t alignment = 1) {
93     return BufferAssigner::Run(
94                module, absl::make_unique<DependencyHloOrdering>(module),
95                backend().compiler()->BufferSizeBytesFunction(),
96                [alignment](LogicalBuffer::Color) { return alignment; },
97                /*allocate_buffers_for_constants=*/true)
98         .ConsumeValueOrDie();
99   }
100 
RunBufferAssignmentNoBuffersForConstants(HloModule * module,int64_t alignment=1)101   std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants(
102       HloModule* module, int64_t alignment = 1) {
103     return BufferAssigner::Run(
104                module, absl::make_unique<DependencyHloOrdering>(module),
105                backend().compiler()->BufferSizeBytesFunction(),
106                [alignment](LogicalBuffer::Color) { return alignment; },
107                /*allocate_buffers_for_constants=*/false)
108         .ConsumeValueOrDie();
109   }
110 
RunBufferAssignmentNoBuffersReuseForAdd(HloModule * module,int64_t alignment=1)111   std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersReuseForAdd(
112       HloModule* module, int64_t alignment = 1) {
113     absl::flat_hash_set<HloOpcode> must_not_live_out = {HloOpcode::kAdd};
114 
115     return BufferAssigner::Run(
116                module, absl::make_unique<DependencyHloOrdering>(module),
117                backend().compiler()->BufferSizeBytesFunction(),
118                [alignment](LogicalBuffer::Color) { return alignment; },
119                /*allocate_buffers_for_constants=*/false,
120                /*colorer=*/BufferAssigner::DefaultColorer(),
121                /*must_not_live_out=*/must_not_live_out)
122         .ConsumeValueOrDie();
123   }
124 
RunColoredBufferAssignment(HloModule * module,BufferAssigner::Colorer colorer,int64_t alignment=1)125   std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
126       HloModule* module, BufferAssigner::Colorer colorer,
127       int64_t alignment = 1) {
128     return BufferAssigner::Run(
129                module, absl::make_unique<DependencyHloOrdering>(module),
130                backend().compiler()->BufferSizeBytesFunction(),
131                [alignment](LogicalBuffer::Color) { return alignment; },
132                /*allocate_buffers_for_constants=*/true, std::move(colorer))
133         .ConsumeValueOrDie();
134   }
135 
RunBufferAssignmentWithInstructionSequence(HloModule * module,absl::Span<HloInstruction * const> instruction_sequence,int64_t alignment=1)136   std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
137       HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
138       int64_t alignment = 1) {
139     HloSchedule schedule(module);
140     schedule.set_sequence(module->entry_computation(), instruction_sequence);
141     return BufferAssigner::Run(
142                module, absl::make_unique<SequentialHloOrdering>(schedule),
143                backend().compiler()->BufferSizeBytesFunction(),
144                [alignment](LogicalBuffer::Color) { return alignment; },
145                /*allocate_buffers_for_constants=*/true)
146         .ConsumeValueOrDie();
147   }
148 
RunBufferAssignmentWithPresetAssignments(HloModule * module,std::unique_ptr<PresetAssignments> preset_assignments,int64_t alignment=1)149   std::unique_ptr<BufferAssignment> RunBufferAssignmentWithPresetAssignments(
150       HloModule* module, std::unique_ptr<PresetAssignments> preset_assignments,
151       int64_t alignment = 1) {
152     return BufferAssigner::Run(
153                module, absl::make_unique<DependencyHloOrdering>(module),
154                backend().compiler()->BufferSizeBytesFunction(),
155                [alignment](LogicalBuffer::Color) { return alignment; },
156                /*allocate_buffers_for_constants=*/true,
157                BufferAssigner::DefaultColorer(),
158                /*must_not_live_out=*/{},
159                /*can_share_buffer=*/nullptr, std::move(preset_assignments))
160         .ConsumeValueOrDie();
161   }
162 
163   // Builds an x+1.0 computation to use in a Map.
BuildMapComputationPlus1(const string & name)164   std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) {
165     auto builder = HloComputation::Builder(name);
166     auto param =
167         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
168     auto value = builder.AddInstruction(
169         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
170     builder.AddInstruction(
171         HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
172     return builder.Build();
173   }
174 
BuildReduceComputation(const string & name)175   std::unique_ptr<HloComputation> BuildReduceComputation(const string& name) {
176     auto builder = HloComputation::Builder(name);
177     auto param =
178         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
179     auto param2 =
180         builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
181     builder.AddInstruction(
182         HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2));
183     return builder.Build();
184   }
185 
186   // Builds a simple compare-to-limit (x < 4) computation for a While.
187   //
188   // condition:
189   //   const4[s32] -----------------------------------\
190   //                                                   \
191   //   param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
192   //
BuildWhileConditionComputation(const string & name)193   std::unique_ptr<HloComputation> BuildWhileConditionComputation(
194       const string& name) {
195     auto builder = HloComputation::Builder(name);
196     auto const4 = builder.AddInstruction(
197         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
198     auto param = builder.AddInstruction(
199         HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
200     auto index = builder.AddInstruction(
201         HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
202     builder.AddInstruction(
203         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
204                                       const4, ComparisonDirection::kLt));
205     return builder.Build();
206   }
207 
208   // Builds a simple body computation for a While.
209   //
210   // body:
211   //   constv[f32[4]] --------------------------------------\
212   //                                                         \
213   //                           /--- get-tuple-elementv[1] --- addv ---\
214   //   param[(s32,f32[4])] ---|                                    tuple
215   //                           \--- get-tuple-elementc[0] --- addc ---/
216   //                                                         /
217   //   const1[s32] -----------------------------------------/
218   //
BuildWhileBodyComputation(const string & name)219   std::unique_ptr<HloComputation> BuildWhileBodyComputation(
220       const string& name) {
221     auto builder = HloComputation::Builder(name);
222     auto const1 = builder.AddInstruction(
223         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
224     auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
225         LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
226     auto param = builder.AddInstruction(
227         HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
228     auto indexc = builder.AddInstruction(
229         HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
230     auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
231         indexc->shape(), HloOpcode::kAdd, indexc, const1));
232     auto indexv = builder.AddInstruction(
233         HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
234     auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
235         constv->shape(), HloOpcode::kAdd, indexv, constv));
236     builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
237     return builder.Build();
238   }
239 
BuildR0F32UnaryOpComputation(HloOpcode opcode,const string & name)240   std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation(
241       HloOpcode opcode, const string& name) {
242     auto builder = HloComputation::Builder(name);
243     auto param =
244         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
245     builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param));
246     return builder.Build();
247   }
248 
249   // Verifies that the given instruction hlo has a valid input buffer assigned,
250   // i.e., the parameter number matches the op's.
GetAssignedInputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)251   const BufferAllocation& GetAssignedInputAllocation(
252       const BufferAssignment& buffers, HloInstruction* hlo) {
253     LOG(INFO) << "Checking input: " << hlo->ToString();
254     const BufferAllocation& buffer =
255         *buffers.GetUniqueTopLevelSlice(hlo).ConsumeValueOrDie().allocation();
256     EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number());
257     return buffer;
258   }
259 
260   // Verifies that the given instruction hlo has a valid output buffer
261   // assigned, and returns it.
GetAssignedOutputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)262   const BufferAllocation& GetAssignedOutputAllocation(
263       const BufferAssignment& buffers, HloInstruction* hlo) {
264     LOG(INFO) << "Checking output: " << hlo->ToString();
265     const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo);
266     return buffer;
267   }
268 
269   // Returns the allocation for the given instruction.
GetAllocation(const BufferAssignment & buffers,const HloInstruction * hlo,const ShapeIndex & index)270   const BufferAllocation& GetAllocation(const BufferAssignment& buffers,
271                                         const HloInstruction* hlo,
272                                         const ShapeIndex& index) {
273     return *buffers.GetUniqueSlice(hlo, index).ConsumeValueOrDie().allocation();
274   }
GetTopLevelAllocation(const BufferAssignment & buffers,const HloInstruction * hlo)275   const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers,
276                                                 const HloInstruction* hlo) {
277     return *buffers.GetUniqueTopLevelSlice(hlo)
278                 .ConsumeValueOrDie()
279                 .allocation();
280   }
281 
282   // Verifies that all instructions in the given instruction list except
283   // kConstant have assigned buffers, and returns their total size. If min_index
284   // and max_index are not nullptr, the minimum and maximum buffer indices in
285   // the assignment are written into them.
ValidateBuffers(const std::vector<const HloInstruction * > & instructions,const BufferAssignment & buffers)286   int64 ValidateBuffers(const std::vector<const HloInstruction*>& instructions,
287                         const BufferAssignment& buffers) {
288     // Verifies all instructions have buffers, and gets the index ranges.
289     for (const HloInstruction* hlo : instructions) {
290       if (!buffers.HasTopLevelAllocation(hlo)) {
291         // If `hlo` has no assigned buffer, it is either a constant or a nested
292         // parameter.
293         EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() ||
294                     HloOpcode::kParameter == hlo->opcode());
295         continue;
296       }
297     }
298 
299     // Gets the total size of all buffers assigned.
300     int64_t total_size = 0;
301     for (auto& allocation : buffers.Allocations()) {
302       total_size += allocation.size();
303     }
304     return total_size;
305   }
306 
307   // Shapes for use in the examples.
308   Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
309   Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
310   Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
311   Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10});
312   Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100});
313   Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10});
314   Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_});
315   Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
316 };
317 
318 // Returns true if the buffers assigned to instructions in "a" are distinct
319 // from the buffers assigned to those in "b" (ie, intersection is empty).
BuffersDistinct(const std::vector<const HloInstruction * > & a,const std::vector<const HloInstruction * > & b,const BufferAssignment & assignment)320 static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
321                             const std::vector<const HloInstruction*>& b,
322                             const BufferAssignment& assignment) {
323   absl::flat_hash_set<BufferAllocation::Slice> a_slices;
324   for (const HloInstruction* instruction : a) {
325     if (assignment.HasTopLevelAllocation(instruction)) {
326       a_slices.insert(
327           assignment.GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie());
328     }
329   }
330 
331   for (const HloInstruction* instruction : b) {
332     if (assignment.HasTopLevelAllocation(instruction)) {
333       if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction)
334                                 .ConsumeValueOrDie())) {
335         return false;
336       }
337     }
338   }
339   return true;
340 }
341 
342 // Tests a computation consisting of a single scalar constant node.
TEST_F(BufferAssignmentTest,ScalarConstant)343 TEST_F(BufferAssignmentTest, ScalarConstant) {
344   auto builder = HloComputation::Builder(TestName());
345   auto const0 = builder.AddInstruction(
346       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
347   auto module = CreateNewVerifiedModule();
348   module->AddEntryComputation(builder.Build());
349 
350   {
351     auto buffers = RunBufferAssignment(module.get());
352     EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
353   }
354 
355   {
356     auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
357     EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
358   }
359 }
360 
TEST_F(BufferAssignmentTest, BufferForConst) {
  // Adds two vector constants. The constants have buffers only when buffers
  // for constants are requested; their consumer has a buffer either way.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  {
    std::unique_ptr<BufferAssignment> assignment =
        RunBufferAssignment(module.get());
    EXPECT_TRUE(assignment->HasTopLevelAllocation(lhs));
    EXPECT_TRUE(assignment->HasTopLevelAllocation(rhs));
    GetAssignedOutputAllocation(*assignment, sum);
  }
  {
    std::unique_ptr<BufferAssignment> assignment =
        RunBufferAssignmentNoBuffersForConstants(module.get());
    EXPECT_FALSE(assignment->HasTopLevelAllocation(lhs));
    EXPECT_FALSE(assignment->HasTopLevelAllocation(rhs));
    GetAssignedOutputAllocation(*assignment, sum);
  }
}
387 
TEST_F(BufferAssignmentTest, HasAllocationAt) {
  // Builds a tuple of (negate, parameter, constant) and checks that
  // HasAllocationAt on each tuple index matches HasTopLevelAllocation on the
  // instruction feeding that index.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  HloInstruction* result = builder.AddInstruction(
      HloInstruction::CreateTuple({negated, input, one}));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  std::unique_ptr<BufferAssignment> assignment =
      RunBufferAssignment(module.get());

  // The tuple itself, then each element in index order.
  EXPECT_EQ(assignment->HasTopLevelAllocation(result),
            assignment->HasAllocationAt(result, /*index=*/{}));
  EXPECT_EQ(assignment->HasTopLevelAllocation(negated),
            assignment->HasAllocationAt(result, /*index=*/{0}));
  EXPECT_EQ(assignment->HasTopLevelAllocation(input),
            assignment->HasAllocationAt(result, /*index=*/{1}));
  EXPECT_EQ(assignment->HasTopLevelAllocation(one),
            assignment->HasAllocationAt(result, /*index=*/{2}));
}
415 
TEST_F(BufferAssignmentTest, BufferForOutputConst) {
  // The computation is a kCopy of a vector constant; the copy must have a
  // top-level allocation.
  HloComputation::Builder builder(TestName());
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  HloInstruction* copied = builder.AddInstruction(
      HloInstruction::CreateUnary(source->shape(), HloOpcode::kCopy, source));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  std::unique_ptr<BufferAssignment> assignment =
      RunBufferAssignment(module.get());
  // The lookup dies if the copy has no unique top-level slice.
  GetAssignedOutputAllocation(*assignment, copied);
}
430 
TEST_F(BufferAssignmentTest, Basic) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  //
  // (paramscalar is broadcast to f32[100] before feeding the mul.)
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Distinct input buffers were assigned for parameters. The helpers return
  // references into `buffers`, so bind by reference instead of copying.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}
477 
TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) {
  // If an input buffer and an output buffer alias, the input buffer can be
  // reused for other intermediate results.
  //
  // param0[100] ----- (neg1) -- (neg2)
  //    |                           |
  //    + -------- Aliased ---------+

  auto builder = HloComputation::Builder(TestName());

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p0"));
  auto neg_1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
  auto neg_2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, neg_1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Alias the whole output ({}) with parameter 0 at index {}.
  TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({}, 0, {}));

  auto buffers = RunBufferAssignment(module.get());

  // The helpers return references into `buffers`; bind by reference rather
  // than copying each BufferAllocation.
  const BufferAllocation& param_buffer =
      GetAssignedInputAllocation(*buffers, param);
  const BufferAllocation& neg_1_buffer = GetAllocation(*buffers, neg_1, {});
  const BufferAllocation& neg_2_buffer = GetAllocation(*buffers, neg_2, {});

  // Everything uses one buffer.
  EXPECT_EQ(param_buffer.index(), neg_1_buffer.index());
  EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index());
}
510 
TEST_F(BufferAssignmentTest, AddCannotReuse) {
  // Pass in a special rule to indicate that "add" cannot be live out.
  //
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignmentNoBuffersReuseForAdd(module.get());

  // Distinct input buffers were assigned for parameters. The helpers return
  // references into `buffers`, so bind by reference instead of copying.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The sub node has a valid buffer assigned, doesn't share with param0.
  const BufferAllocation& sub_buffer = GetTopLevelAllocation(*buffers, sub);
  EXPECT_NE(sub_buffer.index(), param0_buffer.index());

  // The add node must not share the sub node's buffer, because kAdd was
  // passed to buffer assignment in must_not_live_out.
  // NOTE(review): this relies on sub being the computation root (the last
  // instruction added) and therefore live out; confirm.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), sub_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}
560 
TEST_F(BufferAssignmentTest,BasicUniquelyColored)561 TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
562   // paramscalar ------- (mul) -- (add) -- (sub)
563   //                     /        /        /
564   // param0[100] -------/        /        /
565   //                            /        /
566   // param1[100] --------------/--------/
567   // The output of each op is colored with a different color, so we can not
568   // share anything.
569   auto builder = HloComputation::Builder(TestName());
570   auto paramscalar =
571       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
572   auto broadcast = builder.AddInstruction(
573       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
574   auto param0 = builder.AddInstruction(
575       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
576   auto param1 = builder.AddInstruction(
577       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
578   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
579       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
580   auto add = builder.AddInstruction(
581       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
582   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
583       f32vec100_, HloOpcode::kSubtract, add, param1));
584   auto module = CreateNewVerifiedModule();
585   module->AddEntryComputation(builder.Build());
586 
587   auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
588     int color = 0;
589     for (HloValue::Id id = 0;
590          id < alias_analysis->dataflow_analysis().values().size(); id++) {
591       auto& value = alias_analysis->dataflow_analysis().GetValue(id);
592       value.set_color(BufferValue::Color(color++));
593     }
594     return Status::OK();
595   };
596 
597   auto buffers = RunColoredBufferAssignment(module.get(), colorer);
598 
599   // Distinct input buffers were assigned for parameters.
600   BufferAllocation paramscalar_buffer =
601       GetAssignedInputAllocation(*buffers, paramscalar);
602   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
603   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
604   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
605   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
606   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
607 
608   // The mul node has a valid buffer assigned, doesn't share with input.
609   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
610   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
611 
612   // The add node can not reuse the mul node's buffer due to coloring.
613   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
614   EXPECT_NE(add_buffer.index(), mul_buffer.index());
615 
616   // The sub node has a valid output buffer assigned.
617   GetAssignedOutputAllocation(*buffers, sub);
618 
619   // Check if the HLO instructions have the correct colors in the layout.
620   EXPECT_EQ(param0->shape().layout().memory_space(), 2);
621   EXPECT_EQ(param1->shape().layout().memory_space(), 3);
622   EXPECT_EQ(mul->shape().layout().memory_space(), 4);
623   EXPECT_EQ(add->shape().layout().memory_space(), 5);
624   EXPECT_EQ(sub->shape().layout().memory_space(), 6);
625 }
626 
TEST_F(BufferAssignmentTest,BasicPartiallyColored)627 TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
628   // paramscalar ------- (mul) -- (add) -- (sub)
629   //                     /        /        /
630   // param0[100] -------/        /        /
631   //                            /        /
632   // param1[100] --------------/--------/
633   // The output of the mul and the add have the color 1, and the other buffers
634   // have the color 0, which allows the mul and add to share buffers.
635   auto builder = HloComputation::Builder(TestName());
636   auto paramscalar =
637       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
638   auto broadcast = builder.AddInstruction(
639       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
640   auto param0 = builder.AddInstruction(
641       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
642   auto param1 = builder.AddInstruction(
643       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
644   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
645       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
646   auto add = builder.AddInstruction(
647       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
648   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
649       f32vec100_, HloOpcode::kSubtract, add, param1));
650   auto module = CreateNewVerifiedModule();
651   module->AddEntryComputation(builder.Build());
652 
653   auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
654     for (HloValue::Id id = 0;
655          id < alias_analysis->dataflow_analysis().values().size(); id++) {
656       auto& value = alias_analysis->dataflow_analysis().GetValue(id);
657       auto& buffer = alias_analysis->GetBufferContainingValue(value);
658       for (const auto& alias : buffer.values()) {
659         if (alias->instruction()->opcode() == HloOpcode::kAdd ||
660             alias->instruction()->opcode() == HloOpcode::kMultiply) {
661           value.set_color(LogicalBuffer::Color(1));
662         }
663       }
664       if (!value.has_color()) {
665         value.set_color(LogicalBuffer::Color(0));
666       }
667     }
668     return Status::OK();
669   };
670 
671   auto buffers = RunColoredBufferAssignment(module.get(), colorer);
672 
673   // Distinct input buffers were assigned for parameters.
674   BufferAllocation paramscalar_buffer =
675       GetAssignedInputAllocation(*buffers, paramscalar);
676   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
677   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
678   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
679   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
680   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
681 
682   // The mul node has a valid buffer assigned, doesn't share with input.
683   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
684   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
685 
686   // The add node can reuse the mul node's buffer.
687   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
688   EXPECT_EQ(add_buffer.index(), mul_buffer.index());
689 
690   // The sub node has a valid output buffer assigned.
691   GetAssignedOutputAllocation(*buffers, sub);
692 
693   // Check if the HLO instructions have the correct colors in the layout.
694   EXPECT_EQ(mul->shape().layout().memory_space(), 1);
695   EXPECT_EQ(add->shape().layout().memory_space(), 1);
696   EXPECT_EQ(sub->shape().layout().memory_space(), 0);
697   EXPECT_EQ(param0->shape().layout().memory_space(), 0);
698   EXPECT_EQ(param1->shape().layout().memory_space(), 0);
699 }
700 
TEST_F(BufferAssignmentTest,PresetAssignments)701 TEST_F(BufferAssignmentTest, PresetAssignments) {
702   // paramscalar ------- (mul) -- (add) -- (sub)
703   //                     /        /        /
704   // param0[100] -------/        /        /
705   //                            /        /
706   // param1[100] --------------/--------/
707   // Similar to BasicPartiallyColored, but the color is set in the layout.
708   // The output of the mul and the add have the color 1 and have preset
709   // assignments, and the other buffers have the color 0, which allows the mul
710   // and add to share buffers.
711   auto builder = HloComputation::Builder(TestName());
712   auto paramscalar =
713       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
714   auto broadcast = builder.AddInstruction(
715       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
716   auto param0 = builder.AddInstruction(
717       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
718   auto param1 = builder.AddInstruction(
719       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
720   Shape f32vec100_color1 =
721       ShapeUtil::MakeShapeWithLayout(F32, {100}, {0}, {}, 0, 1);
722   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
723       f32vec100_color1, HloOpcode::kMultiply, broadcast, param0));
724   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
725       f32vec100_color1, HloOpcode::kAdd, mul, param1));
726   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
727       f32vec100_, HloOpcode::kSubtract, add, param1));
728   auto module = CreateNewVerifiedModule();
729   module->AddEntryComputation(builder.Build());
730 
731   auto preset_assignments = absl::make_unique<PresetAssignments>();
732   preset_assignments->add_chunk({mul, {}}, {/*offset=*/100, /*size=*/400});
733   preset_assignments->add_chunk({add, {}}, {/*offset=*/550, /*size=*/400});
734   preset_assignments->assignment_information_for_space(/*memory_space=*/1)
735       ->size = 950;
736 
737   auto buffers = RunBufferAssignmentWithPresetAssignments(
738       module.get(), std::move(preset_assignments));
739 
740   // Distinct input buffers were assigned for parameters.
741   BufferAllocation paramscalar_buffer =
742       GetAssignedInputAllocation(*buffers, paramscalar);
743   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
744   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
745   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
746   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
747   EXPECT_EQ(paramscalar_buffer.color(), LogicalBuffer::Color(0));
748   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
749   EXPECT_EQ(param0_buffer.color(), LogicalBuffer::Color(0));
750 
751   // The mul and add use the same preset buffer. Ensure it has the correct color
752   // and offsets.
753   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
754   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
755   EXPECT_EQ(mul_buffer, add_buffer);
756   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
757   EXPECT_EQ(mul_buffer.color(), LogicalBuffer::Color(1));
758 
759   EXPECT_EQ(mul_buffer.assigned_buffers().size(), 2);
760   for (const auto& value_and_offsetsize : mul_buffer.assigned_buffers()) {
761     if (value_and_offsetsize.first->instruction() == mul) {
762       EXPECT_EQ(value_and_offsetsize.second.offset, 100);
763       EXPECT_EQ(value_and_offsetsize.second.size, 400);
764     } else {
765       EXPECT_EQ(value_and_offsetsize.first->instruction(), add);
766       EXPECT_EQ(value_and_offsetsize.second.offset, 550);
767       EXPECT_EQ(value_and_offsetsize.second.size, 400);
768     }
769   }
770 
771   // The sub node has a valid output buffer assigned.
772   GetAssignedOutputAllocation(*buffers, sub);
773 }
774 
TEST_F(BufferAssignmentTest,PresetAssignmentsWhile)775 TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) {
776   // Tests preset assignments when there is no 1-to-1 correspondence between
777   // HloValue and HloBuffer (i.e., a while loop).
778   auto module = CreateNewVerifiedModule();
779   Shape f32vec10_color1 =
780       ShapeUtil::MakeShapeWithLayout(F32, {10}, {0}, {}, 0, 1);
781   Shape t_s32_f32v10_color1 =
782       ShapeUtil::MakeTupleShape({s32_, f32vec10_color1});
783 
784   auto cond_builder = HloComputation::Builder("WhileCond");
785   HloInstruction* cond_param = cond_builder.AddInstruction(
786       HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "cond_param"));
787   HloInstruction* cond_iter = cond_builder.AddInstruction(
788       HloInstruction::CreateGetTupleElement(s32_, cond_param, 0));
789   HloInstruction* cond_limit = cond_builder.AddInstruction(
790       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(50)));
791   cond_builder.AddInstruction(
792       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
793                                     cond_limit, ComparisonDirection::kLt));
794   HloComputation* cond_computation =
795       module->AddEmbeddedComputation(cond_builder.Build());
796 
797   auto body_builder = HloComputation::Builder("WhileBody");
798   HloInstruction* body_param = body_builder.AddInstruction(
799       HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "body_param"));
800   HloInstruction* body_iter = body_builder.AddInstruction(
801       HloInstruction::CreateGetTupleElement(s32_, body_param, 0));
802   HloInstruction* body_data = body_builder.AddInstruction(
803       HloInstruction::CreateGetTupleElement(f32vec10_color1, body_param, 1));
804   HloInstruction* body_data_increment = body_builder.AddInstruction(
805       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
806           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f})));
807   HloInstruction* body_data_next =
808       body_builder.AddInstruction(HloInstruction::CreateBinary(
809           f32vec10_color1, HloOpcode::kAdd, body_data, body_data_increment));
810   HloInstruction* body_iter_increment = body_builder.AddInstruction(
811       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
812   HloInstruction* body_iter_next =
813       body_builder.AddInstruction(HloInstruction::CreateBinary(
814           s32_, HloOpcode::kAdd, body_iter, body_iter_increment));
815   body_builder.AddInstruction(
816       HloInstruction::CreateTuple({body_iter_next, body_data_next}));
817   HloComputation* body_computation =
818       module->AddEmbeddedComputation(body_builder.Build());
819 
820   auto builder = HloComputation::Builder(TestName());
821   HloInstruction* iter = builder.AddInstruction(
822       HloInstruction::CreateParameter(0, s32_, "param_iter"));
823   HloInstruction* data = builder.AddInstruction(
824       HloInstruction::CreateParameter(1, f32vec10_, "param_data"));
825   HloInstruction* negate = builder.AddInstruction(
826       HloInstruction::CreateUnary(f32vec10_color1, HloOpcode::kNegate, data));
827   HloInstruction* tuple =
828       builder.AddInstruction(HloInstruction::CreateTuple({iter, negate}));
829   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
830       t_s32_f32v10_color1, cond_computation, body_computation, tuple));
831   HloInstruction* while_data = builder.AddInstruction(
832       HloInstruction::CreateGetTupleElement(f32vec10_color1, while_op, 1));
833   builder.AddInstruction(HloInstruction::CreateBinary(
834       f32vec10_, HloOpcode::kAdd, while_data, data));
835   module->AddEntryComputation(builder.Build());
836 
837   // Set only one preset assignment for while data and its aliases.
838   auto preset_assignments = absl::make_unique<PresetAssignments>();
839   preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
840   preset_assignments->assignment_information_for_space(/*memory_space=*/1)
841       ->size = 140;
842 
843   auto buffers = RunBufferAssignmentWithPresetAssignments(
844       module.get(), std::move(preset_assignments));
845 
846   // All assigned buffers are aliased so they should have the same offset and
847   // size.
848   const BufferAllocation& data_buffer = GetTopLevelAllocation(*buffers, negate);
849   EXPECT_EQ(data_buffer.assigned_buffers().size(), 5);
850   for (const auto& value_and_offsetsize : data_buffer.assigned_buffers()) {
851     EXPECT_EQ(value_and_offsetsize.second.offset, 100);
852     EXPECT_EQ(value_and_offsetsize.second.size, 40);
853     EXPECT_EQ(value_and_offsetsize.first->color(), LogicalBuffer::Color(1));
854   }
855 }
856 
TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
  // This is similar to the Basic test, with the difference that (sub) is
  // another user of (mul)'s result, so (mul)'s buffer cannot be reused for
  // (add)'s output.
  //
  // (The diagram avoids trailing backslashes, which would splice the following
  // line into the comment.)
  //
  // paramscalar -------+     +-----------+
  //                    |     |           |
  // param0[100] ------- (mul) -- (add) -- (sub)
  //                              /
  // param1[100] ----------------/
  //
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Input buffers were assigned for parameters. The allocations are only
  // inspected, so they are bound by const reference rather than copied.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node had a buffer allocated.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);

  // Now the add node can't reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // Log size information for inspection.
  const std::vector<const HloInstruction*> level0 = GetInstructions(sub);
  int64_t size0 = ValidateBuffers(level0, *buffers);
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() << " instructions; "
            << "total buffer size " << size0;
}
911 
TEST_F(BufferAssignmentTest, TrivialMap) {
  // This tests a trivial x+1 map as the only operation.
  //
  // param0[100x10] ---> (map x+1)
  //
  // Builds the map function.
  auto module = CreateNewVerifiedModule();
  auto map_computation =
      module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
  auto inner_last = map_computation->root_instruction();

  // Creates the main kernel and verifies instruction counts.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  auto map = builder.AddInstruction(
      HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation));
  module->AddEntryComputation(builder.Build());

  // level0 is collected from the entry root (the map), level1 from the root of
  // the embedded map computation.
  const std::vector<const HloInstruction*> level0 = GetInstructions(map);
  EXPECT_EQ(2, level0.size()) << "Invalid main kernel size";
  const std::vector<const HloInstruction*> level1 = GetInstructions(inner_last);
  EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size";

  // Assigns buffers and fetches sizes.
  auto buffers = RunBufferAssignment(module.get());
  int64_t size0 = ValidateBuffers(level0, *buffers);
  int64_t size1 = ValidateBuffers(level1, *buffers);

  // Both algorithms assign the map's buffer before processing the embedded
  // computation, so we can verify that the buffers aren't shared between them
  // by checking:
  EXPECT_TRUE(BuffersDistinct(level0, level1, *buffers))
      << "Reuse between main kernel and embedded mapping.";

  // An input buffer was assigned for the parameter. The allocation is only
  // inspected, so it is bound by const reference rather than copied.
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);

  // An output buffer was assigned for the map.
  const BufferAllocation& map_buffer =
      GetAssignedOutputAllocation(*buffers, map);
  EXPECT_NE(param0_buffer.index(), map_buffer.index());

  // The final computation node of the map is an add of an f32 param and a
  // constant.
  EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode());
  const BufferAllocation& inner_add_buffer =
      GetTopLevelAllocation(*buffers, inner_last);
  EXPECT_NE(inner_add_buffer.index(), map_buffer.index());

  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() + level1.size() << " instructions; "
            << "total buffer size " << size0 + size1;
}
966 
TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
  // A reduce must not write its output into the buffer of its input. (Such
  // reuse is unsafe in general: the reduce reshapes, and an out-of-order
  // reduction could overwrite an element before it has been read.)
  //
  // param0[100x10] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3)
  auto module = CreateNewVerifiedModule();
  HloComputation* reducer =
      module->AddEmbeddedComputation(BuildReduceComputation("f32+f32"));

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  HloInstruction* first_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, input));
  HloInstruction* second_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, first_exp));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      /*shape=*/f32vec10_,
      /*operand=*/second_exp,
      /*init_value=*/zero,
      /*dimensions_to_reduce=*/{0}, reducer));
  HloInstruction* final_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce));

  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  const std::vector<const HloInstruction*> all_instrs =
      GetInstructions(final_exp);
  ValidateBuffers(all_instrs, *buffers);

  const BufferAllocation& first_exp_alloc =
      GetTopLevelAllocation(*buffers, first_exp);
  const BufferAllocation& second_exp_alloc =
      GetTopLevelAllocation(*buffers, second_exp);
  const BufferAllocation& reduce_alloc =
      GetTopLevelAllocation(*buffers, reduce);

  // Sanity check: the second exp trivially reuses the first exp's buffer.
  EXPECT_EQ(first_exp_alloc.index(), second_exp_alloc.index());

  // The reduce gets a different buffer than its one and only operand.
  EXPECT_NE(second_exp_alloc.index(), reduce_alloc.index());
}
1013 
TEST_F(BufferAssignmentTest,ExampleWhile)1014 TEST_F(BufferAssignmentTest, ExampleWhile) {
1015   // This tests a While loop example from the ir_semantics document.
1016   //
1017   // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation.
1018   // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation.
1019   //
1020   // const3[s32] -------\
1021   // const4[f32[4]] --- tuple --- while[condition, body]
1022   //
1023   // Builds the nested condition and body.
1024   auto module = CreateNewVerifiedModule();
1025   auto condition_computation =
1026       module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4"));
1027   auto body_computation =
1028       module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update"));
1029 
1030   // Creates the main kernel and verifies instruction counts.
1031   auto builder = HloComputation::Builder(TestName());
1032   auto const3 = builder.AddInstruction(
1033       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
1034   auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
1035       LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
1036   auto tuple =
1037       builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
1038   auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
1039       t_s32_f32v4_, condition_computation, body_computation, tuple));
1040   module->AddEntryComputation(builder.Build());
1041 
1042   const std::vector<const HloInstruction*> level0 = GetInstructions(while_op);
1043   EXPECT_EQ(4, level0.size()) << "Invalid while kernel size";
1044   const std::vector<const HloInstruction*> levelc =
1045       GetInstructions(condition_computation->root_instruction());
1046   EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size";
1047   const std::vector<const HloInstruction*> levelb =
1048       GetInstructions(body_computation->root_instruction());
1049   EXPECT_EQ(8, levelb.size()) << "Invalid nested body size";
1050 
1051   // Assigns buffers and fetches sizes.
1052   auto buffers = RunBufferAssignment(module.get());
1053   int64_t size0 = ValidateBuffers(level0, *buffers);
1054   int64_t sizec = ValidateBuffers(levelc, *buffers);
1055   int64_t sizeb = ValidateBuffers(levelb, *buffers);
1056 
1057   // BufferAssignment will assign a single allocation for the following
1058   // instructions: while, while.cond.param, while.body.param, while.body.result.
1059   EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers))
1060       << "Should be reuse between main kernel and embedded condition.";
1061   EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers))
1062       << "Should be reuse between embedded condition and body.";
1063   // Expect buffer reuse between main kernel and body computation.
1064   EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers))
1065       << "Should be reuse between main kernel and embedded body.";
1066 
1067   // The final computation node of the while body is a tuple of s32 and
1068   // f32[4] adds.
1069   HloInstruction* body_root = body_computation->root_instruction();
1070   EXPECT_EQ(HloOpcode::kTuple, body_root->opcode());
1071 
1072   // Check that buffer for each subshape of 'while_op' shares allocation with
1073   // corresponding buffer from while body computation at same index.
1074   ShapeUtil::ForEachSubshape(
1075       while_op->shape(),
1076       [this, &buffers, while_op, body_root](const Shape& /*subshape*/,
1077                                             const ShapeIndex& index) {
1078         auto while_op_allocation = GetAllocation(*buffers, while_op, index);
1079         auto body_root_allocation = GetAllocation(*buffers, body_root, index);
1080         EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index());
1081       });
1082 
1083   // Log size information for inspection.
1084   LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
1085             << " for " << level0.size() + levelc.size() + levelb.size()
1086             << " instructions; total buffer size " << size0 + sizec + sizeb;
1087 }
1088 
TEST_F(BufferAssignmentTest, ExampleConditional) {
  // A scalar conditional whose true branch applies ceil and whose false branch
  // applies floor. The conditional and both branch computations are expected
  // to overlap in their buffer usage.
  auto module = CreateNewVerifiedModule();
  HloComputation* ceil_branch = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil"));
  HloComputation* floor_branch = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor"));

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* predicate = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* true_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
  HloInstruction* false_operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.4f)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          r0f32_, predicate, true_operand, ceil_branch, false_operand,
          floor_branch));
  module->AddEntryComputation(builder.Build());

  HloInstruction* ceil_root = ceil_branch->root_instruction();
  HloInstruction* floor_root = floor_branch->root_instruction();

  // Gather the instructions of the entry computation and of each branch.
  const std::vector<const HloInstruction*> entry_instrs =
      GetInstructions(conditional);
  const std::vector<const HloInstruction*> ceil_instrs =
      GetInstructions(ceil_root);
  const std::vector<const HloInstruction*> floor_instrs =
      GetInstructions(floor_root);
  EXPECT_EQ(4, entry_instrs.size());
  EXPECT_EQ(2, ceil_instrs.size());
  EXPECT_EQ(2, floor_instrs.size());

  auto buffers = RunBufferAssignment(module.get());
  ValidateBuffers(entry_instrs, *buffers);
  ValidateBuffers(ceil_instrs, *buffers);
  ValidateBuffers(floor_instrs, *buffers);

  // None of the three instruction sets may be buffer-disjoint from another.
  EXPECT_FALSE(BuffersDistinct(entry_instrs, ceil_instrs, *buffers))
      << "Should be reuse between conditional and true computation.";
  EXPECT_FALSE(BuffersDistinct(entry_instrs, floor_instrs, *buffers))
      << "Should be reuse between conditional and false computation.";
  EXPECT_FALSE(BuffersDistinct(ceil_instrs, floor_instrs, *buffers))
      << "Should be reuse between true and false computations.";

  // The conditional's allocation has the same size as each branch root's.
  const BufferAllocation& conditional_alloc =
      GetTopLevelAllocation(*buffers, conditional);
  const BufferAllocation& ceil_alloc =
      GetTopLevelAllocation(*buffers, ceil_root);
  const BufferAllocation& floor_alloc =
      GetTopLevelAllocation(*buffers, floor_root);
  EXPECT_EQ(conditional_alloc.size(), ceil_alloc.size());
  EXPECT_EQ(conditional_alloc.size(), floor_alloc.size());
}
1138 
TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
  // A straight chain of same-shaped unary ops:
  //
  // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p"));
  HloInstruction* first_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, input));
  HloInstruction* tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, first_exp));
  HloInstruction* second_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, second_exp));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Every op after the first exp (tanh, the second exp and neg) is expected to
  // land in the first exp's allocation.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(first_exp));
  const BufferAllocation& chain_alloc =
      GetTopLevelAllocation(*assignment, first_exp);
  EXPECT_EQ(chain_alloc, GetTopLevelAllocation(*assignment, tanh));
  EXPECT_EQ(chain_alloc, GetTopLevelAllocation(*assignment, second_exp));
  EXPECT_EQ(chain_alloc, GetTopLevelAllocation(*assignment, neg));
}
1164 
TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
  // A chain whose buffer size first shrinks (slice) and then grows again
  // (broadcast):
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The broadcast is expected to land in negate's buffer even though negate is
  // not one of its operands.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast and the negate share one allocation.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  const BufferAllocation& broadcast_alloc =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(broadcast_alloc, GetTopLevelAllocation(*assignment, negate));

  // The slice sits in an allocation of its own.
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, slice));
}
1194 
TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
  // Same chain as in ReuseNonOperandBuffer, except that negate is also an
  // operand of the output tuple. That keeps negate's value live until the end
  // of the computation, so its buffer must not be handed to the broadcast.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)-> (tuple)
  //                  \-----------------------------------/
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate, slice and broadcast must each get their own allocation.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
  EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
            GetTopLevelAllocation(*assignment, slice));
}
1227 
TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
  // Same chain as ReuseNonOperandBuffer, except that the negate is wrapped in
  // a tuple which is itself an operand of the root tuple. The inner tuple
  // aliases negate's buffer and lives to the end of the computation, which
  // extends negate's live range and rules out reuse by the broadcast.
  //
  // param ---> (negate) ---> (tuple) -> (slice) ---> (broadcast)-> (tuple)
  //                              \-----------------------------------/
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  HloInstruction* inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({negate}));
  // The slice reads negate's value back out of the tuple.
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec100_, inner_tuple, 0));
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, gte, {0}, {10}, {1}));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate, slice and broadcast must land in three pairwise-distinct
  // allocations.
  const BufferAllocation& negate_alloc =
      GetTopLevelAllocation(*assignment, negate);
  const BufferAllocation& slice_alloc =
      GetTopLevelAllocation(*assignment, slice);
  const BufferAllocation& broadcast_alloc =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_NE(broadcast_alloc, negate_alloc);
  EXPECT_NE(broadcast_alloc, slice_alloc);
  EXPECT_NE(negate_alloc, slice_alloc);
}
1264 
TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
  // Variant of ReuseNonOperandBuffer in which the broadcast (the computation
  // output) is smaller than the negate. Output buffers of a computation are
  // expected to be sized exactly for their value, so the broadcast must not be
  // placed in negate's larger buffer.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // 100-element negate.
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // 10-element slice.
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // 10x4 = 40-element broadcast.
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {10, 4});
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, slice, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast's allocation is shared with neither negate nor slice.
  const BufferAllocation& broadcast_alloc =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, slice));
}
1297 
TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
  // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast
  // output is exactly the same size as the negate (rather than being
  // smaller). This enables reuse of negate's buffer by the broadcast because
  // the output buffer will be sized exactly to its value.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The negate should share a buffer with broadcast.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Slice output is 10 elements.
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 10x10 = 100 elements, matching the negate.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate and broadcast should share a buffer.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}
1331 
TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
  // Variant of ReuseNonOperandBuffer in which the broadcast is smaller than
  // the negate and reaches the computation output as a tuple element. Output
  // buffers are expected to be sized exactly for their value, and that applies
  // equally to buffers aliased in the output (e.g. tuple elements), so the
  // broadcast must not be placed in negate's larger buffer.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // 100-element negate.
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // 10-element slice.
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // 10x4 = 40-element broadcast.
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {10, 4});
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, slice, {0}));
  builder.AddInstruction(HloInstruction::CreateTuple({broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast's allocation is shared with neither negate nor slice.
  const BufferAllocation& broadcast_alloc =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, slice));
}
1367 
TEST_F(BufferAssignmentTest,EmbeddedComputationBuffers)1368 TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
1369   // Verify that buffers for embedded computations are properly marked as
1370   // thread-local and that embedded parameters are not marked as
1371   // is_entry_computation_parameter.
1372   auto module = CreateNewVerifiedModule();
1373   auto vec_shape = ShapeUtil::MakeShape(F32, {42});
1374   auto scalar_shape = ShapeUtil::MakeShape(F32, {});
1375 
1376   // Create a scalar computation to use in a map.
1377   auto map_builder = HloComputation::Builder(TestName() + "_map");
1378   auto map_param = map_builder.AddInstruction(
1379       HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
1380   auto map_root = map_builder.AddInstruction(
1381       HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
1382   auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
1383 
1384   // Create a vector computation to use in a kCall.
1385   auto call_builder = HloComputation::Builder(TestName() + "_call");
1386   auto call_param = call_builder.AddInstruction(
1387       HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
1388   auto call_root = call_builder.AddInstruction(
1389       HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
1390   auto call_computation = module->AddEmbeddedComputation(call_builder.Build());
1391 
1392   // Create entry computation which kCalls call_computation and then calls map
1393   // with map_computation on the result.
1394   auto builder = HloComputation::Builder(TestName());
1395   auto param = builder.AddInstruction(
1396       HloInstruction::CreateParameter(0, vec_shape, "param"));
1397   auto call = builder.AddInstruction(
1398       HloInstruction::CreateCall(vec_shape, {param}, call_computation));
1399   auto map = builder.AddInstruction(
1400       HloInstruction::CreateMap(vec_shape, {call}, map_computation));
1401   module->AddEntryComputation(builder.Build());
1402 
1403   auto assignment = RunBufferAssignment(module.get());
1404 
1405   // Allocations for the map computation should be thread-local and not
1406   // live-out.
1407   auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
1408   EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
1409   EXPECT_FALSE(map_param_alloc.maybe_live_out());
1410   EXPECT_TRUE(map_param_alloc.is_thread_local());
1411 
1412   auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
1413   EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
1414   EXPECT_FALSE(map_root_alloc.maybe_live_out());
1415   EXPECT_TRUE(map_root_alloc.is_thread_local());
1416 
1417   // Allocations for the call computation should not be thread-local.
1418   auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param);
1419   EXPECT_TRUE(call_param_alloc.is_entry_computation_parameter());
1420   EXPECT_FALSE(call_param_alloc.maybe_live_out());
1421   EXPECT_FALSE(call_param_alloc.is_thread_local());
1422 
1423   auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root);
1424   EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
1425   EXPECT_FALSE(call_root_alloc.is_thread_local());
1426 
1427   // Entry computation allocations can be marked liveout and
1428   // is_entry_computation_parameter.
1429   auto& param_alloc = GetTopLevelAllocation(*assignment, param);
1430   EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
1431   EXPECT_FALSE(param_alloc.maybe_live_out());
1432   EXPECT_FALSE(param_alloc.is_thread_local());
1433 
1434   auto& map_alloc = GetTopLevelAllocation(*assignment, map);
1435   EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
1436   EXPECT_TRUE(map_alloc.maybe_live_out());
1437   EXPECT_FALSE(map_alloc.is_thread_local());
1438 }
1439 
TEST_F(BufferAssignmentTest, CustomCallEmbeddedComputationBuffers) {
  // Checks that buffers of a computation attached to a custom call are marked
  // thread-local.
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

  // Scalar negate computation that the custom call will reference.
  auto map_builder = HloComputation::Builder(TestName() + "_map");
  HloInstruction* map_param = map_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
  HloInstruction* map_root = map_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
  HloComputation* map_computation =
      module->AddEmbeddedComputation(map_builder.Build());

  // Entry computation: a single custom call taking map_computation.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param"));
  builder.AddInstruction(HloInstruction::CreateCustomCall(
      scalar_shape, {param}, map_computation, "call_name"));
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  // Both buffers of the embedded computation are thread-local, never
  // live-out, and not entry parameters.
  const BufferAllocation& map_param_alloc =
      GetTopLevelAllocation(*assignment, map_param);
  EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_param_alloc.maybe_live_out());
  EXPECT_TRUE(map_param_alloc.is_thread_local());

  const BufferAllocation& map_root_alloc =
      GetTopLevelAllocation(*assignment, map_root);
  EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_root_alloc.maybe_live_out());
  EXPECT_TRUE(map_root_alloc.is_thread_local());
}
1476 
TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
  // Test a computation that returns a tuple parameter.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(F32, {}),
                                 ShapeUtil::MakeShape(S32, {42})}),
      "param0"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // There should be four allocations: one for vector of pointers, and one for
  // each tuple element.
  EXPECT_EQ(4, assignment->Allocations().size());

  // Verify each buffer allocation is marked as an entry computation parameter
  // and is liveout.
  ShapeUtil::ForEachSubshape(
      tuple_param->shape(),
      [this, &assignment, tuple_param](const Shape& /*subshape*/,
                                       const ShapeIndex& index) {
        // Bind by const reference: the allocation is only inspected here, so
        // there is no reason to copy it for every subshape.
        const auto& allocation =
            GetAllocation(*assignment, tuple_param, index);
        EXPECT_TRUE(allocation.is_entry_computation_parameter());
        EXPECT_EQ(0, allocation.parameter_number());
        EXPECT_TRUE(allocation.maybe_live_out());
      });
}
1507 
TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
  // The computation output is a GetTupleElement that extracts the nested
  // tuple (element 1) from a tuple-shaped parameter.
  auto builder = HloComputation::Builder(TestName());
  const Shape nested_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {42}), ShapeUtil::MakeShape(S32, {101})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), nested_shape});
  HloInstruction* tuple_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));
  HloInstruction* tuple_element =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Only part of the parameter is live-out: the top-level tuple buffer is
  // not.
  EXPECT_FALSE(
      GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
  // Index {1} is live-out because the GetTupleElement forwards a pointer to
  // that allocation rather than defining a buffer of its own; the same holds
  // for the leaves beneath it.
  EXPECT_TRUE(
      GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
                  .maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
                  .maybe_live_out());

  // The GetTupleElement itself is the output, hence live-out.
  EXPECT_TRUE(
      GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());

  // The leaves seen through the GetTupleElement alias the parameter's leaves,
  // so their allocations must be identical.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
            GetAllocation(*assignment, tuple_element, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
            GetAllocation(*assignment, tuple_element, /*index=*/{1}));

  // Likewise, the GetTupleElement's top-level allocation is the same as the
  // allocation of the parameter element it forwards.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
            GetTopLevelAllocation(*assignment, tuple_element));
}
1555 
// TODO(b/32248867): Enable when buffer assignment gives allocations to
// constants.
TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
  // The computation consists of a single two-element tuple constant, which is
  // therefore forwarded directly to the computation output.
  auto builder = HloComputation::Builder(TestName());
  Literal first = LiteralUtil::CreateR0<int64>(0);
  Literal second = LiteralUtil::CreateR0<int64>(1);
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::MakeTuple({&first, &second})));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Three allocations are expected in total.
  EXPECT_EQ(3, assignment->Allocations().size());
}
1573 
TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
  // The computation output is a tuple-shaped custom call with no operands.
  auto builder = HloComputation::Builder(TestName());
  const Shape result_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
       ShapeUtil::MakeShape(S32, {101})});
  HloInstruction* custom_call =
      builder.AddInstruction(HloInstruction::CreateCustomCall(
          result_shape, /*operands=*/{}, /*custom_call_target=*/"foo_function"));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Three allocations are expected, and the allocation at every index of the
  // custom call (top level and both elements) may be live-out.
  EXPECT_EQ(3, assignment->Allocations().size());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
}
1593 
TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
  // The computation output is a kCall whose callee wraps its parameter in a
  // one-element tuple.
  auto module = CreateNewVerifiedModule();
  const Shape elem_shape = f32vec4_;
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});

  // Callee: tuple(sub_param).
  auto sub_builder = HloComputation::Builder(TestName() + "_sub");
  HloInstruction* sub_param = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "sub_param"));
  HloInstruction* sub_tuple =
      sub_builder.AddInstruction(HloInstruction::CreateTuple({sub_param}));
  HloComputation* sub_computation =
      module->AddEmbeddedComputation(sub_builder.Build());

  // Entry: call(sub_computation, param).
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "param"));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  EXPECT_EQ(2, assignment->Allocations().size());
  // The call's buffers are colocated with those of the callee: the top-level
  // tuple with sub_tuple, and element 0 with sub_param.
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}),
            GetAllocation(*assignment, sub_tuple, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}),
            GetAllocation(*assignment, sub_param, /*index=*/{}));

  // The entry parameter does not alias the result tuple, but it does alias
  // the callee parameter it is passed as.
  const BufferAllocation& param_alloc =
      GetTopLevelAllocation(*assignment, param);
  EXPECT_NE(param_alloc, GetTopLevelAllocation(*assignment, sub_tuple));
  EXPECT_EQ(param_alloc, GetTopLevelAllocation(*assignment, sub_param));
}
1630 
TEST_F(BufferAssignmentTest,TupleChainedCallAsOutput)1631 TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
1632   // Test a chain of calls with tuple output. The chain looks like:
1633   // A: call(B, tuple(param))
1634   // B: call(C, param)
1635   // C: call(D, param)
1636   // D: param
1637   auto module = CreateNewVerifiedModule();
1638   auto elem_shape = f32vec4_;
1639   auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
1640 
1641   auto d_builder = HloComputation::Builder(TestName() + "_d");
1642   auto d_param = d_builder.AddInstruction(
1643       HloInstruction::CreateParameter(0, tuple_shape, "d_param"));
1644   auto d_computation = d_builder.Build();
1645 
1646   auto c_builder = HloComputation::Builder(TestName() + "_c");
1647   auto c_param = c_builder.AddInstruction(
1648       HloInstruction::CreateParameter(0, tuple_shape, "c_param"));
1649   auto c_call = c_builder.AddInstruction(
1650       HloInstruction::CreateCall(tuple_shape, {c_param}, d_computation.get()));
1651   auto c_computation = c_builder.Build();
1652 
1653   auto b_builder = HloComputation::Builder(TestName() + "_b");
1654   auto b_param = b_builder.AddInstruction(
1655       HloInstruction::CreateParameter(0, tuple_shape, "b_param"));
1656   auto b_call = b_builder.AddInstruction(
1657       HloInstruction::CreateCall(tuple_shape, {b_param}, c_computation.get()));
1658   auto b_computation = b_builder.Build();
1659 
1660   auto a_builder = HloComputation::Builder(TestName());
1661   auto a_param = a_builder.AddInstruction(
1662       HloInstruction::CreateParameter(0, elem_shape, "param"));
1663   auto a_tuple =
1664       a_builder.AddInstruction(HloInstruction::CreateTuple({a_param}));
1665   auto a_call = a_builder.AddInstruction(
1666       HloInstruction::CreateCall(tuple_shape, {a_tuple}, b_computation.get()));
1667   auto a_computation = a_builder.Build();
1668 
1669   // Add the computations in an order that doesn't match the dependency
1670   // post-order, to shake out more possible bugs.
1671   module->AddEmbeddedComputation(std::move(d_computation));
1672   module->AddEmbeddedComputation(std::move(c_computation));
1673   module->AddEntryComputation(std::move(a_computation));
1674   module->AddEmbeddedComputation(std::move(b_computation));
1675 
1676   auto assignment = RunBufferAssignment(module.get());
1677 
1678   // Buffers for call are colocated with the sub-computations.
1679   EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
1680             GetAllocation(*assignment, b_call, /*index=*/{}));
1681   EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}),
1682             GetAllocation(*assignment, c_call, /*index=*/{}));
1683   EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{}),
1684             GetAllocation(*assignment, d_param, /*index=*/{}));
1685   EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{0}),
1686             GetAllocation(*assignment, b_call, /*index=*/{0}));
1687   EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{0}),
1688             GetAllocation(*assignment, c_call, /*index=*/{0}));
1689   EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}),
1690             GetAllocation(*assignment, d_param, /*index=*/{0}));
1691 
1692   EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment));
1693   EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment));
1694   EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment));
1695 
1696   EXPECT_EQ(GetAllocation(*assignment, b_param, /*index=*/{0}),
1697             GetAllocation(*assignment, c_param, /*index=*/{0}));
1698   EXPECT_EQ(GetAllocation(*assignment, c_param, /*index=*/{0}),
1699             GetAllocation(*assignment, d_param, /*index=*/{0}));
1700 }
1701 
TEST_F(BufferAssignmentTest, BitcastAsOutput) {
  // The computation output is a bitcast of its only parameter.
  auto builder = HloComputation::Builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(F32, {42});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(param->shape(), param));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // A single allocation is expected, shared by the parameter and the bitcast.
  EXPECT_EQ(1, assignment->Allocations().size());
  EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, bitcast));
}
1719 
TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
  // Test a computation with an output that has an ambiguous points-to set.
  // This is constructed using a select among tuple shapes.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4})});

  auto tuple_param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param0"));
  auto tuple_param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, tuple_shape, "param1"));
  // Parameter 2 gets its own name; it previously duplicated "param1".
  auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto select = builder.AddInstruction(
      HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect,
                                    pred_param, tuple_param0, tuple_param1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Select shallow copies one of its operands so it defines its own top-level
  // buffer and receives its own allocation. The allocation is only inspected,
  // so bind it by const reference instead of copying it.
  const auto& select_alloc = GetTopLevelAllocation(*assignment, select);
  EXPECT_EQ(1, select_alloc.assigned_buffers().size());
  EXPECT_EQ(select,
            select_alloc.assigned_buffers().begin()->first->instruction());

  // The buffer for the tuple element of the select is forwarded from one of
  // its operands, which cannot be determined statically. Therefore its slices
  // should include the slices of both of the elements in the parameters.
  auto element_slices = assignment->GetAllSlices(select, /*index=*/{0});
  EXPECT_EQ(2, element_slices.size());
  EXPECT_THAT(element_slices,
              UnorderedElementsAre(
                  assignment->GetUniqueSlice(tuple_param0, /*index=*/{0})
                      .ConsumeValueOrDie(),
                  assignment->GetUniqueSlice(tuple_param1, /*index=*/{0})
                      .ConsumeValueOrDie()));
}
1760 
// TODO(b/34669761): Remove this test when buffers are allowed to share
// allocations.
TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
  // Test a computation that wraps a scalar parameter in a tuple, reads the
  // element back with GetTupleElement, and returns a copy of it. The copy
  // must not be assigned the tuple's buffer.
  auto builder = HloComputation::Builder(TestName());
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param0"));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({param}));
  auto tuple_element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, tuple, 0));
  auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape, HloOpcode::kCopy, tuple_element));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // There should be no buffer reuse. The copy should not reuse the tuple
  // buffer.
  EXPECT_EQ(3, assignment->Allocations().size());
  EXPECT_NE(GetTopLevelAllocation(*assignment, tuple),
            GetTopLevelAllocation(*assignment, copy));
}
1785 
TEST_F(BufferAssignmentTest,OneTempAllocation)1786 TEST_F(BufferAssignmentTest, OneTempAllocation) {
1787   // Test a computation that requires multiple temp buffers, and ensure they
1788   // are combined into a single allocation.
1789   auto builder = HloComputation::Builder(TestName());
1790   Shape shape_2x3 = ShapeUtil::MakeShape(F32, {2, 3});
1791   Shape shape_2x4 = ShapeUtil::MakeShape(F32, {2, 4});
1792   Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4});
1793   Shape shape_4x4 = ShapeUtil::MakeShape(F32, {4, 4});
1794   Shape shape_5x4 = ShapeUtil::MakeShape(F32, {5, 4});
1795 
1796   // There should be separate temp buffers for dot_ab and dot_bc.
1797   auto param_a = builder.AddInstruction(
1798       HloInstruction::CreateParameter(0, shape_2x3, "param_a"));
1799   auto param_b = builder.AddInstruction(
1800       HloInstruction::CreateParameter(1, shape_3x4, "param_b"));
1801   auto param_c = builder.AddInstruction(
1802       HloInstruction::CreateParameter(2, shape_4x4, "param_c"));
1803   DotDimensionNumbers dot_dnums;
1804   dot_dnums.add_lhs_contracting_dimensions(1);
1805   dot_dnums.add_rhs_contracting_dimensions(0);
1806   PrecisionConfig precision_config;
1807   precision_config.mutable_operand_precision()->Resize(
1808       2, PrecisionConfig::DEFAULT);
1809   auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
1810       shape_2x4, param_a, param_b, dot_dnums, precision_config));
1811   auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
1812       shape_3x4, param_b, param_c, dot_dnums, precision_config));
1813   builder.AddInstruction(
1814       HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
1815 
1816   // Run buffer assignment with alignment=1.
1817   auto module = CreateNewVerifiedModule();
1818   module->AddEntryComputation(builder.Build());
1819   auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);
1820 
1821   // There are 5 allocations: 3 parameters, 1 output, and 1 temp.
1822   EXPECT_EQ(5, assignment->Allocations().size());
1823 
1824   // Ensure the temp buffers for dot_ab and dot_bc share a single allocation,
1825   // and each occupies different slices of that allocation.
1826   BufferAllocation::Slice slice_ab =
1827       assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
1828   BufferAllocation::Slice slice_bc =
1829       assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
1830   EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1831   EXPECT_NE(slice_ab, slice_bc);
1832   EXPECT_EQ(32, slice_ab.size());
1833   EXPECT_EQ(48, slice_bc.size());
1834   EXPECT_EQ(80, slice_ab.allocation()->size());
1835   EXPECT_EQ(80, slice_bc.allocation()->size());
1836 
1837   // Re-run buffer assignment with alignment=64.
1838   assignment = RunBufferAssignment(module.get(), /*alignment=*/64);
1839   EXPECT_EQ(5, assignment->Allocations().size());
1840   slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
1841   slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
1842   EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1843   EXPECT_NE(slice_ab, slice_bc);
1844   EXPECT_EQ(32, slice_ab.size());
1845   EXPECT_EQ(48, slice_bc.size());
1846   // Ensure the offsets and allocation size account for the alignment, without
1847   // assuming which buffer gets assigned first.
1848   if (slice_ab.offset() == 0) {
1849     EXPECT_EQ(64, slice_bc.offset());
1850     EXPECT_EQ(64 + 48, slice_ab.allocation()->size());
1851     EXPECT_EQ(64 + 48, slice_bc.allocation()->size());
1852   } else {
1853     EXPECT_EQ(64, slice_ab.offset());
1854     EXPECT_EQ(0, slice_bc.offset());
1855     EXPECT_EQ(64 + 32, slice_ab.allocation()->size());
1856     EXPECT_EQ(64 + 32, slice_bc.allocation()->size());
1857   }
1858 }
1859 
// Checks BufferAllocation::PeakMemoryLogicalBuffers() on a straight-line chain
// of elementwise ops: the allocation that holds `mul` must report exactly one
// buffer at peak, and that buffer must be the one defined by `sub`.
TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
  // paramscalar -(bc)- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Inspect the allocation that `mul` landed in.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  const std::vector<const HloValue*>& peak_buffers =
      mul_buffer.PeakMemoryLogicalBuffers();

  // ASSERT (not EXPECT) so that the index below is never out of range.
  ASSERT_EQ(peak_buffers.size(), 1);
  EXPECT_EQ(peak_buffers[0]->instruction(), sub);
}
1892 
// Checks which logical buffers the temp allocation reports as live at peak
// memory use when the instruction sequence is fixed by the test.
TEST_F(BufferAssignmentTest, PeakBuffers) {
  // Compute the peak liveness buffers of the following sequence:
  //
  //   %param = ...
  //   %log = log(%param)
  //   %rev = reverse(%log)
  //   %neg = neg(%param)
  //   %concat = concat(%rev, %neg)
  //   ROOT %root = slice(concat)
  //
  // In the temporary block, the set of live buffers at peak memory use should
  // be {%rev, %neg, %concat}. This occurs right at the concat itself.
  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
  auto rev = builder.AddInstruction(
      HloInstruction::CreateReverse(f32vec100_, log, {0}));
  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
  const Shape concat_shape = ShapeUtil::MakeShape(F32, {200});
  auto concat = builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0));
  // Make the root tiny so no interior nodes can share its buffer.
  auto root = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Pin the schedule explicitly; the expected peak set depends on this order.
  auto buffers = RunBufferAssignmentWithInstructionSequence(
      module.get(), {param, log, rev, neg, concat, root});

  // The temporary buffer should hold the 4 interior instructions
  // (log, rev, neg and concat).
  const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat);
  EXPECT_FALSE(buffer.IsInputOrOutput());
  EXPECT_TRUE(buffer.IsPreallocatedTempBuffer());
  ASSERT_EQ(buffer.assigned_buffers().size(), 4);

  const std::vector<const HloValue*>& peak_buffers =
      buffer.PeakMemoryLogicalBuffers();

  // The peak live set should be concat and its inputs.
  ASSERT_EQ(peak_buffers.size(), 3);
  std::vector<const HloInstruction*> peak_instructions;
  peak_instructions.reserve(peak_buffers.size());
  for (const HloValue* logical_buffer : peak_buffers) {
    peak_instructions.push_back(logical_buffer->instruction());
  }
  // Order within the peak set is not part of the contract being tested.
  EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat));
}
1945 
// Checks that a chain of two dynamic-update-slice ops is assigned the same
// top-level allocation as the tuple element they update, i.e. the updates are
// performed in place.
TEST_F(BufferAssignmentTest, InPlaceBuffer) {
  const char* hlo_text = R"(
HloModule Module

ENTRY main {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
  get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
  get-tuple-element.3 = s32[] get-tuple-element(state), index=0
  constant.2 = s32[] constant(128)
  add.5 = s32[] add(get-tuple-element.3, constant.2)
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));

  // Look the instructions up by name and fail cleanly (rather than crash on a
  // null dereference later) if the HLO text and the names ever drift apart.
  HloInstruction* param_element =
      m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
  ASSERT_NE(param_element, nullptr);
  HloInstruction* dus1 =
      m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
  ASSERT_NE(dus1, nullptr);
  HloInstruction* dus2 =
      m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9");
  ASSERT_NE(dus2, nullptr);

  auto buffers = RunBufferAssignment(m.get());

  // Both updates must live in the allocation of the tuple element they modify.
  const BufferAllocation& param_element_alloc =
      GetTopLevelAllocation(*buffers, param_element);
  const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1);
  EXPECT_EQ(param_element_alloc, dus1_alloc);
  const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2);
  EXPECT_EQ(param_element_alloc, dus2_alloc);
}
1985 
// Checks that the allocations holding the two entry constants are constant
// allocations and that no buffer defined by a copy or a conditional has been
// assigned into either of them.
TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) {
  const char* hlo_text = R"(
HloModule Module

True {
  ROOT x.0.1 = f32[] parameter(0)
}

False {
  x.0.0 = f32[] parameter(0)
  ROOT copy.1 = f32[] copy(x.0.0)
}

ENTRY main {
  pred.1.0 = pred[] parameter(0)
  constant.1.1 = f32[] constant(56)
  copy.2 = f32[] copy(constant.1.1)
  constant.1.2 = f32[] constant(12)
  ROOT conditional.1.3 = f32[] conditional(pred.1.0, copy.2, constant.1.2),
      true_computation=True, false_computation=False
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
  HloInstruction* constant_1 =
      m->entry_computation()->GetInstructionWithName("constant.1.1");
  ASSERT_NE(constant_1, nullptr);
  HloInstruction* constant_2 =
      m->entry_computation()->GetInstructionWithName("constant.1.2");
  ASSERT_NE(constant_2, nullptr);

  auto buffers = RunBufferAssignment(m.get());

  // The same checks apply to both constants, so run them in one loop; the
  // scoped trace identifies which constant a failure belongs to.
  for (const HloInstruction* constant : {constant_1, constant_2}) {
    SCOPED_TRACE(constant->name());
    const BufferAllocation& allocation =
        GetTopLevelAllocation(*buffers, constant);
    EXPECT_TRUE(allocation.is_constant());
    for (const auto& buffer_offset_pair : allocation.assigned_buffers()) {
      const HloOpcode opcode =
          buffer_offset_pair.first->instruction()->opcode();
      EXPECT_NE(opcode, HloOpcode::kCopy);
      EXPECT_NE(opcode, HloOpcode::kConditional);
    }
  }
}
2043 
2044 class WhileBufferAssignmentTest : public HloTestBase {
2045  protected:
BuildWhileConditionComputation(const string & name)2046   std::unique_ptr<HloComputation> BuildWhileConditionComputation(
2047       const string& name) {
2048     auto builder = HloComputation::Builder(name);
2049     builder.AddInstruction(
2050         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2051     auto zero = builder.AddInstruction(
2052         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
2053     auto ten = builder.AddInstruction(
2054         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
2055     builder.AddInstruction(HloInstruction::CreateCompare(
2056         ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt));
2057     return builder.Build();
2058   }
2059 
BuildWhileBodyComputation(const string & name)2060   std::unique_ptr<HloComputation> BuildWhileBodyComputation(
2061       const string& name) {
2062     auto builder = HloComputation::Builder(name);
2063     auto loop_state = builder.AddInstruction(
2064         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2065     auto input = builder.AddInstruction(
2066         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
2067     auto weights = builder.AddInstruction(
2068         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
2069     auto output = builder.AddInstruction(HloInstruction::CreateBinary(
2070         data_shape_, HloOpcode::kMultiply, input, weights));
2071     builder.AddInstruction(
2072         HloInstruction::CreateTuple({input, weights, output}));
2073     return builder.Build();
2074   }
2075 
RunBufferAssignment(HloModule * module,int64_t alignment=1)2076   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
2077                                                         int64_t alignment = 1) {
2078     HloSchedule schedule =
2079         ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie();
2080     return BufferAssigner::Run(
2081                module, absl::make_unique<SequentialHloOrdering>(schedule),
2082                ByteSizeOf,
2083                [alignment](LogicalBuffer::Color) { return alignment; },
2084                /*allocate_buffers_for_constants=*/true)
2085         .ConsumeValueOrDie();
2086   }
2087 
ByteSizeOf(const BufferValue & buffer)2088   static int64 ByteSizeOf(const BufferValue& buffer) {
2089     return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
2090   }
2091 
2092   Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
2093   Shape loop_state_shape_ =
2094       ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
2095 };
2096 
RunCopyInsertion(HloModule * module)2097 static void RunCopyInsertion(HloModule* module) {
2098   CopyInsertion copy_insertion;
2099   EXPECT_IS_OK(copy_insertion.Run(module).status());
2100 }
2101 
TEST_F(WhileBufferAssignmentTest,TwoForwardWhileLoops)2102 TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
2103   auto module = CreateNewVerifiedModule();
2104   auto builder = HloComputation::Builder("entry");
2105 
2106   auto input0 = builder.AddInstruction(
2107       HloInstruction::CreateParameter(0, data_shape_, "input0"));
2108   auto weights0 = builder.AddInstruction(
2109       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2110   auto weights1 = builder.AddInstruction(
2111       HloInstruction::CreateParameter(2, data_shape_, "weights1"));
2112 
2113   auto zero = builder.AddInstruction(
2114       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2115   auto output0 = builder.AddInstruction(
2116       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2117   auto output1 = builder.AddInstruction(
2118       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2119 
2120   auto cond0 =
2121       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2122   auto body0 =
2123       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2124 
2125   auto tuple0 = builder.AddInstruction(
2126       HloInstruction::CreateTuple({input0, weights0, output0}));
2127   auto while0 = builder.AddInstruction(
2128       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2129 
2130   auto cond1 =
2131       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2132   auto body1 =
2133       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2134   auto input1 = builder.AddInstruction(
2135       HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
2136   auto tuple1 = builder.AddInstruction(
2137       HloInstruction::CreateTuple({input1, weights1, output1}));
2138   auto while1 = builder.AddInstruction(
2139       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
2140 
2141   module->AddEntryComputation(builder.Build());
2142   RunCopyInsertion(module.get());
2143   auto assignment = RunBufferAssignment(module.get());
2144 
2145   // Verify 'input0' and read-only use while0{0} alias.
2146   EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(),
2147             assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie());
2148   // Verify 'weights0' and read-only use while0{1} alias.
2149   EXPECT_EQ(assignment->GetUniqueSlice(weights0, {}).ConsumeValueOrDie(),
2150             assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie());
2151   // Verify 'while0{2}' and read-only use while1{0} alias.
2152   EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
2153             assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
2154   // Verify 'weights1' and read-only use while1{1} alias.
2155   EXPECT_EQ(assignment->GetUniqueSlice(weights1, {}).ConsumeValueOrDie(),
2156             assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
2157 }
2158 
2159 // Tests that two colocated buffer sets are not merged if an entry parameter
2160 // buffer belongs to either of the colocation sets (b/73267882).
2161 //
2162 // %param --> %while.0 --> %mul --> %while.1 --> %broadcast
2163 //
2164 // %while.0 body just forwards the init value, so the loop carried variable
2165 // remains the constant, whereas %while.1 changes the loop carried variable.
TEST_F(WhileBufferAssignmentTest,ColocatedBufferWithEntryParameter)2166 TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithEntryParameter) {
2167   const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2168 
2169   const char* module_str = R"(
2170 HloModule test_module
2171 
2172 %cond.v0 {
2173   %param = s32[] parameter(0)
2174   ROOT %constant = pred[] constant(true)
2175 }
2176 
2177 %cond.v1 {
2178   %param.0 = s32[] parameter(0)
2179   ROOT %constant.0 = pred[] constant(true)
2180 }
2181 
2182 %body.v0 {
2183   ROOT %param.1 = s32[] parameter(0)
2184 }
2185 
2186 %body.v1 {
2187   %param.2 = s32[] parameter(0)
2188   ROOT add = s32[] add(%param.2, %param.2)
2189 }
2190 
2191 ENTRY %test_module {
2192   %param.3 = s32[] parameter(0)
2193   %while.0 = s32[] while(%param.3), condition=%cond.v0, body=%body.v0
2194   %mul = s32[] multiply(%while.0, %while.0)
2195   %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
2196   ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
2197 })";
2198 
2199   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
2200 
2201   // Run CopyInsertion and check if the graph constructed above doesn't need
2202   // any copies inserted for BufferAssignment to run.
2203   int64_t instruction_count = m->instruction_count();
2204   CopyInsertion copy_insertion;
2205   ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
2206   ASSERT_EQ(instruction_count, m->instruction_count());
2207 
2208   // Get the instructions in the module.
2209   const HloInstruction* bcast = m->entry_computation()->root_instruction();
2210   const HloInstruction* param =
2211       m->entry_computation()->parameter_instruction(0);
2212   ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
2213   const HloInstruction* while1 = bcast->operand(0);
2214   ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
2215   const HloInstruction* while0 = while1->operand(0)->operand(0);
2216   ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
2217 
2218   // Run buffer assignment.
2219   auto assignment = RunBufferAssignment(m.get());
2220   TF_ASSERT_OK_AND_ASSIGN(auto slice_param,
2221                           assignment->GetUniqueSlice(param, {}));
2222   TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2223                           assignment->GetUniqueSlice(while0, {}));
2224   TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2225                           assignment->GetUniqueSlice(while1, {}));
2226 
2227   // The parameter slice is part of the while0's colocation set (init value),
2228   // but not merged into the while1's colocation set.
2229   EXPECT_EQ(slice_param, slice_while0);
2230   EXPECT_NE(slice_param, slice_while1);
2231 }
2232 
// Same graph shape as a while -> multiply -> while chain, but the first loop
// is initialised from a constant instead of an entry parameter:
//
// %constant.42 --> %while.0 --> %mul --> %while.1 --> %broadcast
//
// The constant must share a slice with %while.0 (whose body only forwards its
// init value) and must not share one with %while.1.
TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) {
  const char* module_str = R"(
HloModule test_module

%cond.v0 {
  %param = s32[] parameter(0)
  ROOT %constant = pred[] constant(true)
}

%cond.v1 {
  %param.0 = s32[] parameter(0)
  ROOT %constant.0 = pred[] constant(true)
}

%body.v0 {
  ROOT %param.1 = s32[] parameter(0)
}

%body.v1 {
  %param.2 = s32[] parameter(0)
  ROOT add = s32[] add(%param.2, %param.2)
}

ENTRY %test_module {
  %constant.42 = s32[] constant(42)
  %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0
  %mul = s32[] multiply(%while.0, %while.0)
  %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  // Run CopyInsertion and check if the graph constructed above doesn't need
  // any copies inserted for BufferAssignment to run.
  int64_t instruction_count = m->instruction_count();
  CopyInsertion copy_insertion;
  ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
  ASSERT_EQ(instruction_count, m->instruction_count());

  // Get the instructions in the module by walking back from the root:
  // bcast -> while.1 -> mul -> while.0.
  const HloInstruction* bcast = m->entry_computation()->root_instruction();
  const HloInstruction* constant =
      m->entry_computation()->GetInstructionWithName("constant.42");
  // Fail cleanly instead of dereferencing null if the name lookup misses.
  ASSERT_NE(constant, nullptr);
  ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
  const HloInstruction* while1 = bcast->operand(0);
  ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
  const HloInstruction* while0 = while1->operand(0)->operand(0);
  ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);

  // Run buffer assignment.
  auto assignment = RunBufferAssignment(m.get());
  TF_ASSERT_OK_AND_ASSIGN(auto slice_constant,
                          assignment->GetUniqueSlice(constant, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
                          assignment->GetUniqueSlice(while0, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
                          assignment->GetUniqueSlice(while1, {}));

  // The constant slice is part of the while0's colocation set (init value), but
  // not merged into the while1's colocation set.
  EXPECT_EQ(slice_constant, slice_while0);
  EXPECT_NE(slice_constant, slice_while1);
}
2299 
2300 // Tests that the colocated buffers for while instructions are properly assigned
2301 // during buffer assignment such that the result tuple elements are not assigned
2302 // to the same buffer.
2303 //
2304 // %infeed --> %while.0 --> %while.1 --+
2305 //                                     +-- %tuple
2306 //   %zero -->   %add   --> %while.2 --+
2307 //
2308 // Execution Order:
2309 // %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple
2310 //
2311 // The HLO computation used in this test requires specific ordering to expose
2312 // the bug (b/72496031). During buffer assignment, the visitation order of
2313 // colocated buffers is %while.2 -> while.0 -> while.1, and the buffer
2314 // assignment was coalescing the colocated buffers for all 3 while instructions,
2315 // therefore assigning the same buffer to the two result tuple elements.
TEST_F(WhileBufferAssignmentTest,ColocatedBuffers)2316 TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
2317   const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2318 
2319   // Builds a condition computation: x -> x < 4
2320   auto build_cond = [&]() {
2321     auto builder = HloComputation::Builder("cond");
2322     auto const4 = builder.AddInstruction(
2323         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
2324     auto param =
2325         builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
2326     builder.AddInstruction(
2327         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
2328                                       const4, ComparisonDirection::kLt));
2329     return builder.Build();
2330   };
2331 
2332   // Builds a body computation: x -> x + 9
2333   auto build_body = [&]() {
2334     auto builder = HloComputation::Builder("body");
2335     auto const9 = builder.AddInstruction(
2336         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(9)));
2337     auto param =
2338         builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
2339     builder.AddInstruction(
2340         HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9));
2341     return builder.Build();
2342   };
2343 
2344   // Build the entry computation as described in the comment above.
2345   auto module = CreateNewVerifiedModule();
2346   auto builder = HloComputation::Builder("entry");
2347 
2348   auto token = builder.AddInstruction(HloInstruction::CreateToken());
2349   auto infeed =
2350       builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, ""));
2351   auto infeed_data = builder.AddInstruction(
2352       HloInstruction::CreateGetTupleElement(r0s32, infeed, 0));
2353   auto cond0 = module->AddEmbeddedComputation(build_cond());
2354   auto body0 = module->AddEmbeddedComputation(build_body());
2355   auto while0 = builder.AddInstruction(
2356       HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data));
2357 
2358   auto cond1 = module->AddEmbeddedComputation(build_cond());
2359   auto body1 = module->AddEmbeddedComputation(build_body());
2360   auto while1 = builder.AddInstruction(
2361       HloInstruction::CreateWhile(r0s32, cond1, body1, while0));
2362 
2363   auto zero = builder.AddInstruction(
2364       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
2365   auto add = builder.AddInstruction(
2366       HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
2367   auto cond2 = module->AddEmbeddedComputation(build_cond());
2368   auto body2 = module->AddEmbeddedComputation(build_body());
2369   auto while2 = builder.AddInstruction(
2370       HloInstruction::CreateWhile(r0s32, cond2, body2, add));
2371 
2372   auto tuple =
2373       builder.AddInstruction(HloInstruction::CreateTuple({while2, while1}));
2374   module->AddEntryComputation(builder.Build());
2375 
2376   // Run CopyInsertion and check if the graph constructed above doesn't need
2377   // any copies inserted for BufferAssignment to run.
2378   int64_t instruction_count = module->instruction_count();
2379   CopyInsertion copy_insertion;
2380   ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
2381   ASSERT_EQ(instruction_count, module->instruction_count());
2382 
2383   // Create a sequential order among all the instructions in the entry
2384   // computation, since the issue this test stresses depends on the order the
2385   // nodes are traversed during BufferAssignment.
2386   TF_ASSERT_OK_AND_ASSIGN(
2387       HloSchedule schedule,
2388       ScheduleModule(module.get(), [](const BufferValue& buffer) {
2389         return ShapeUtil::ByteSizeOf(buffer.shape(),
2390                                      /*pointer_size=*/sizeof(void*));
2391       }));
2392   schedule.set_sequence(
2393       module->entry_computation(),
2394       {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
2395   TF_ASSERT_OK(schedule.Verify());
2396 
2397   TF_ASSERT_OK_AND_ASSIGN(
2398       auto assignment,
2399       BufferAssigner::Run(
2400           module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
2401           backend().compiler()->BufferSizeBytesFunction(),
2402           [](LogicalBuffer::Color) { return 1; },
2403           /*allocate_buffers_for_constants=*/true));
2404 
2405   // The result tuple elements must be assigned with different buffers.
2406   TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
2407   TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1}));
2408   EXPECT_NE(slice0, slice1);
2409 
2410   // while0 and while1 result buffers must be equal to slice1.
2411   TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2412                           assignment->GetUniqueSlice(while0, {}));
2413   TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2414                           assignment->GetUniqueSlice(while1, {}));
2415   EXPECT_EQ(slice1, slice_while0);
2416   EXPECT_EQ(slice1, slice_while1);
2417 
2418   // while2 result buffer must be equal to slice0.
2419   TF_ASSERT_OK_AND_ASSIGN(auto slice_while2,
2420                           assignment->GetUniqueSlice(while2, {}));
2421   EXPECT_EQ(slice0, slice_while2);
2422 }
2423 
// Feeds the full result tuple of one while loop straight into a second while
// loop and checks that every element of the two loop states is assigned the
// same slice.
TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder("entry");

  auto input0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));

  // Zero-filled initial value for the output slot of the loop state.
  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  auto cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body0 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));

  auto cond1 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body1 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  // The second loop consumes the first loop's result tuple directly.
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));

  module->AddEntryComputation(builder.Build());
  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // while0 and while1 buffers should be completely aligned.
  EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(),
            assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie());
  EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(),
            assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie());
  EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(),
            assignment->GetUniqueSlice(while1, {2}).ConsumeValueOrDie());
}
2468 
// Calls the same sub-computation (param + 1.0) from two call sites, flattens
// the call graph, and checks that the two call results get distinct buffers.
TEST_F(BufferAssignmentTest, TwoCalls) {
  auto module = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});

  // Callee: param -> param + 1.0
  HloComputation* sub_computation;
  {
    auto builder = HloComputation::Builder(TestName() + "_sub_comp");
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, r0f32, "param"));
    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
    sub_computation = module->AddEmbeddedComputation(builder.Build(add));
  }

  // Entry: add2 = call(3.0) + (call(2.0) + 2.0)
  auto builder = HloComputation::Builder(TestName());
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  auto call1 = builder.AddInstruction(
      HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
  auto call2 = builder.AddInstruction(
      HloInstruction::CreateCall(r0f32, {constant3}, sub_computation));
  auto add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2));
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1));
  module->AddEntryComputation(builder.Build(add2));

  {
    // The callee has two call sites, so flattening is expected to report that
    // it changed the module.
    FlattenCallGraph flatten;
    TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
    EXPECT_TRUE(result);
    // NOTE(review): the built call graph is never read; presumably kept as a
    // check that the flattened module still yields a call graph - confirm.
    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  }

  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
}
2510 
// Parses a module whose entry computation calls `Callee` with a nested tuple
// operand and an array operand, runs buffer assignment, and checks that each
// callee parameter buffer (at every tuple index) reports the same allocation
// as the corresponding value feeding the call in the entry computation.
TEST_F(BufferAssignmentTest, CallParamCoAllocation) {
  const char* hlo_text = R"(
HloModule CallParamCoAllocation

Callee {
  param0 = (f32[100],(f32[200],f32[300])) parameter(0)
  param1 = s32[20] parameter(1)
  ROOT constant = f32[] constant(1)
}

ENTRY Main {
  entry_param0 = f32[100] parameter(0)
  entry_param1 = s32[20]  parameter(1)
  custom_call = (f32[200],f32[300]) custom-call(), custom_call_target="call-target"
  call_op0 = (f32[100],(f32[200],f32[300])) tuple(entry_param0, custom_call)
  ROOT call_result = f32[] call(call_op0, entry_param1), to_apply=Callee
}
)";

  HloModuleConfig config;
  config.set_debug_options(GetDebugOptionsFromFlags());
  TF_ASSERT_OK_AND_ASSIGN(auto m,
                          ParseAndReturnVerifiedModule(hlo_text, config));

  auto buffers = RunBufferAssignment(m.get());

  HloComputation* main = m->entry_computation();
  HloComputation* callee = m->GetComputationWithName("Callee");
  // Must be fatal: `callee` is dereferenced immediately below, so a failed
  // lookup has to abort the test instead of crashing the test binary.
  ASSERT_NE(callee, nullptr);

  HloInstruction* param0 = callee->parameter_instruction(0);
  HloInstruction* param1 = callee->parameter_instruction(1);

  HloInstruction* entry_param0 = main->parameter_instruction(0);
  HloInstruction* entry_param1 = main->parameter_instruction(1);
  HloInstruction* custom_call = main->GetInstructionWithName("custom_call");
  // Same reasoning: the pointer is handed to GetAllocation() below.
  ASSERT_NE(custom_call, nullptr);

  // Array operands: the entry parameters and the callee's view of them.
  EXPECT_EQ(GetAllocation(*buffers, entry_param0, {}),
            GetAllocation(*buffers, param0, {0}));
  EXPECT_EQ(GetAllocation(*buffers, entry_param1, {}),
            GetAllocation(*buffers, param1, {}));

  // Nested tuple produced by the custom call: the tuple itself and both of
  // its elements, seen through index {1} of the callee's first parameter.
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {}),
            GetAllocation(*buffers, param0, {1}));
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {0}),
            GetAllocation(*buffers, param0, {1, 0}));
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {1}),
            GetAllocation(*buffers, param0, {1, 1}));
}
2560 
// Runs buffer assignment on a small five-instruction module with a fixed
// instruction order and compares BufferAssignment::BufferInfoString() against
// a golden CSV dump (one row per buffer; columns as named in the header row).
TEST_F(BufferAssignmentTest, BufferInfoStringTest) {
  absl::string_view hlo_text = R"(
HloModule test_module

ENTRY %test_module {
  %param.0 = s32[1024]{0} parameter(0)
  %param.1 = s32[1024]{0} parameter(1)
  %mul = s32[1024]{0} multiply(%param.0, %param.1)
  %add = s32[1024]{0} add(%mul, %param.0)
  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[1024] %add), dimensions={0}
})";

  absl::string_view expected_csv =
      R"(buffer_id,buffer_name,offset,size,definition_time,end_time,num_uses,use_times,use_names
0,"<0 param.0 @0>",0,4096,0,5,2,"2;3","mul, operand 0;add, operand 1"
1,"<1 param.1 @0>",0,4096,1,5,1,"2","mul, operand 1"
2,"<2 mul @0>",0,4096,2,3,1,"3","add, operand 0"
3,"<3 add @0>",0,4096,3,4,1,"4","bcast, operand 0"
4,"<4 bcast @0>",0,4194304,4,5,0,"",""
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));

  // Look the instructions up by name so the schedule below can be spelled out
  // explicitly.
  HloInstruction* const lhs_param = FindInstruction(m.get(), "param.0");
  HloInstruction* const rhs_param = FindInstruction(m.get(), "param.1");
  HloInstruction* const product = FindInstruction(m.get(), "mul");
  HloInstruction* const sum = FindInstruction(m.get(), "add");
  HloInstruction* const broadcast = FindInstruction(m.get(), "bcast");

  // Assign buffers using exactly this instruction sequence.
  auto assignment = RunBufferAssignmentWithInstructionSequence(
      m.get(), {lhs_param, rhs_param, product, sum, broadcast});
  const std::string actual_csv = assignment->BufferInfoString();

  EXPECT_EQ(actual_csv, expected_csv);
}
2595 
TEST_F(WhileBufferAssignmentTest,WhileLoopsInterferingResultRange)2596 TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
2597   auto module = CreateNewVerifiedModule();
2598   auto builder = HloComputation::Builder(TestName());
2599 
2600   auto zero = builder.AddInstruction(
2601       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2602   auto one = builder.AddInstruction(
2603       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2604 
2605   auto input0 = builder.AddInstruction(
2606       HloInstruction::CreateParameter(0, data_shape_, "input0"));
2607   auto weights0 = builder.AddInstruction(
2608       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2609   auto output0 = builder.AddInstruction(
2610       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2611 
2612   auto input1 = builder.AddInstruction(
2613       HloInstruction::CreateParameter(2, data_shape_, "input1"));
2614   auto weights1 = builder.AddInstruction(
2615       HloInstruction::CreateParameter(3, data_shape_, "weights1"));
2616   auto output1 = builder.AddInstruction(
2617       HloInstruction::CreateBroadcast(data_shape_, one, {}));
2618 
2619   auto cond =
2620       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2621   auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2622 
2623   auto tuple0 = builder.AddInstruction(
2624       HloInstruction::CreateTuple({input0, weights0, output0}));
2625   auto tuple1 = builder.AddInstruction(
2626       HloInstruction::CreateTuple({input1, weights1, output1}));
2627 
2628   auto while0 = builder.AddInstruction(
2629       HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0));
2630   auto while1 = builder.AddInstruction(
2631       HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));
2632 
2633   auto gte0 = builder.AddInstruction(
2634       HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
2635   auto gte1 = builder.AddInstruction(
2636       HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
2637   auto root_add = builder.AddInstruction(
2638       HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1));
2639 
2640   module->AddEntryComputation(builder.Build());
2641 
2642   {
2643     FlattenCallGraph flatten;
2644     TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
2645     EXPECT_TRUE(result);
2646   }
2647 
2648   RunCopyInsertion(module.get());
2649 
2650   HloSchedule schedule =
2651       ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie();
2652 
2653   // To trigger b/38494731, we want a specific Hlo schedule for the
2654   // root computation, so we overwrite that entry with a manually
2655   // crafted sequence.
2656   schedule.set_sequence(
2657       module->entry_computation(),
2658       {input1, weights1, one, output1, while1->mutable_operand(0), while1,
2659        input0, weights0, zero, output0, while0->mutable_operand(0), while0,
2660        gte0, gte1, root_add});
2661 
2662   // If this ASSERT fails, we constructed a bogus sequence above and this test
2663   // itself is buggy.
2664   TF_ASSERT_OK(schedule.Verify());
2665 
2666   auto assignment =
2667       BufferAssigner::Run(
2668           module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
2669           ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
2670           /*allocate_buffers_for_constants=*/true)
2671           .ConsumeValueOrDie();
2672 
2673   EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
2674 }
2675 
TEST_F(WhileBufferAssignmentTest,WhilesDontShareEntryParamIfLiveOut)2676 TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
2677   auto module = CreateNewVerifiedModule();
2678   auto builder = HloComputation::Builder("entry");
2679 
2680   auto input0 = builder.AddInstruction(
2681       HloInstruction::CreateParameter(0, data_shape_, "input0"));
2682   auto weights0 = builder.AddInstruction(
2683       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2684 
2685   auto zero = builder.AddInstruction(
2686       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2687   auto output0 = builder.AddInstruction(
2688       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2689   auto output1 = builder.AddInstruction(
2690       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2691 
2692   auto cond0 =
2693       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2694   auto body0 =
2695       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2696 
2697   auto tuple0 = builder.AddInstruction(
2698       HloInstruction::CreateTuple({input0, weights0, output0}));
2699   auto while0 = builder.AddInstruction(
2700       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2701 
2702   // Get output of 'while0' and feed as input to 'while1'.
2703   auto while0_out = builder.AddInstruction(
2704       HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
2705 
2706   auto cond1 =
2707       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2708   auto body1 =
2709       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2710 
2711   auto tuple1 = builder.AddInstruction(
2712       HloInstruction::CreateTuple({while0_out, weights0, output1}));
2713   auto while1 = builder.AddInstruction(
2714       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
2715 
2716   // Get output of 'while1' so that it is live out of computation.
2717   auto while1_out = builder.AddInstruction(
2718       HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
2719 
2720   module->AddEntryComputation(builder.Build());
2721   RunCopyInsertion(module.get());
2722   auto assignment = RunBufferAssignment(module.get());
2723   // Get BufferAllocation for root instruction.
2724   auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out)
2725                          .ConsumeValueOrDie()
2726                          .allocation();
2727   // Test that root instruction allocation is live out.
2728   EXPECT_TRUE(root_alloc->maybe_live_out());
2729   // Test that root instruction allocation is not an entry parameter.
2730   EXPECT_FALSE(root_alloc->is_entry_computation_parameter());
2731 }
2732 
// Parses a module whose while body applies two chained dynamic-update-slice
// ops to the loop-carried array, runs copy insertion and buffer assignment,
// and checks that both dynamic-update-slice results are assigned the same
// allocation slice.
TEST_F(WhileBufferAssignmentTest, WhileWithDynamicUpdateSliceShare) {
  const char* const hlo_string = R"(
HloModule test

while_body {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
  get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
  get-tuple-element.3 = s32[] get-tuple-element(state), index=0
  constant.2 = s32[] constant(128)
  add.5 = s32[] add(get-tuple-element.3, constant.2)
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}

while_condition {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  get-tuple-element = s32[] get-tuple-element(state), index=0
  get-tuple-element.1 = s32[] constant(3)
  ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
}

ENTRY entry_computation {
  constant.7 = s32[] constant(0)
  copy.1 = s32[] copy(constant.7)
  constant.6 = f32[] constant(0)
  broadcast.6 = f32[1280,1,128]{2,1,0} broadcast(constant.6), dimensions={}
  tuple.1 = (s32[], f32[1280,1,128]{2,1,0}) tuple(copy.1, broadcast.6)
  while.0 = (s32[], f32[1280,1,128]{2,1,0}) while(tuple.1), condition=while_condition, body=while_body
  ROOT get-tuple-element.2 = s32[] get-tuple-element(while.0), index=0
}

)";
  // Report a parse/verification error as a test failure (as the other
  // HLO-text tests in this file do) instead of dying in ConsumeValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // Look up both dynamic-update-slice instructions; stop the test here if
  // either name is missing, since the pointers are used right below.
  auto dus9 = FindInstruction(module.get(), "dynamic-update-slice.9");
  ASSERT_NE(dus9, nullptr);
  auto dus5 = FindInstruction(module.get(), "dynamic-update-slice.5");
  ASSERT_NE(dus5, nullptr);

  // Top-level allocation slice assigned to each of the two ops.
  auto dus9_alloc_slice =
      assignment->GetUniqueTopLevelSlice(dus9).ConsumeValueOrDie();
  auto dus5_alloc_slice =
      assignment->GetUniqueTopLevelSlice(dus5).ConsumeValueOrDie();

  // Test that the two dynamic-update-slice ops share the same allocation slice.
  EXPECT_EQ(dus9_alloc_slice.allocation(), dus5_alloc_slice.allocation());
  EXPECT_EQ(dus9_alloc_slice, dus5_alloc_slice);
}
2785 }  // namespace
2786 }  // namespace xla
2787