1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
17
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/memory/memory.h"
26 #include "absl/strings/string_view.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/service/buffer_value.h"
29 #include "tensorflow/compiler/xla/service/call_graph.h"
30 #include "tensorflow/compiler/xla/service/copy_insertion.h"
31 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
32 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
36 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
37 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
38 #include "tensorflow/compiler/xla/service/hlo_parser.h"
39 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
40 #include "tensorflow/compiler/xla/shape_util.h"
41 #include "tensorflow/compiler/xla/test.h"
42 #include "tensorflow/compiler/xla/test_helpers.h"
43 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/xla_data.pb.h"
46 #include "tensorflow/core/lib/core/status_test_util.h"
47 #include "tensorflow/core/platform/macros.h"
48
49 namespace xla {
50 namespace {
51
52 using memory_space_assignment::PresetAssignments;
53 using ::testing::UnorderedElementsAre;
54
55 // DFS visitor that collects the instructions referenced by a computation
56 // without descending into nested computations, i.e., only from the operands.
57 class InstructionListVisitor : public DfsHloVisitorWithDefault {
58 public:
InstructionListVisitor(const HloInstruction * root)59 explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {}
60
DefaultAction(HloInstruction * hlo)61 Status DefaultAction(HloInstruction* hlo) override {
62 // For each instruction, just push it on the list after walking the
63 // operands.
64 instructions_.push_back(hlo);
65 VLOG(0) << "List instruction " << hlo->ToString();
66 return Status::OK();
67 }
68
GetInstructions()69 std::vector<const HloInstruction*> GetInstructions() { return instructions_; }
70
71 private:
72 // The instruction root of the computation.
73 const HloInstruction* root_;
74
75 // The full set of instructions found (may be duplicates, e.g., kParameter).
76 std::vector<const HloInstruction*> instructions_;
77
78 TF_DISALLOW_COPY_AND_ASSIGN(InstructionListVisitor);
79 };
80
GetInstructions(HloInstruction * root)81 const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
82 InstructionListVisitor main_list(root);
83 TF_CHECK_OK(root->Accept(&main_list));
84 return main_list.GetInstructions();
85 }
86
87 class BufferAssignmentTest : public HloTestBase {
88 protected:
~BufferAssignmentTest()89 ~BufferAssignmentTest() override {}
90
RunBufferAssignment(HloModule * module,int64_t alignment=1)91 std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
92 int64_t alignment = 1) {
93 return BufferAssigner::Run(
94 module, absl::make_unique<DependencyHloOrdering>(module),
95 backend().compiler()->BufferSizeBytesFunction(),
96 [alignment](LogicalBuffer::Color) { return alignment; },
97 /*allocate_buffers_for_constants=*/true)
98 .ConsumeValueOrDie();
99 }
100
RunBufferAssignmentNoBuffersForConstants(HloModule * module,int64_t alignment=1)101 std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants(
102 HloModule* module, int64_t alignment = 1) {
103 return BufferAssigner::Run(
104 module, absl::make_unique<DependencyHloOrdering>(module),
105 backend().compiler()->BufferSizeBytesFunction(),
106 [alignment](LogicalBuffer::Color) { return alignment; },
107 /*allocate_buffers_for_constants=*/false)
108 .ConsumeValueOrDie();
109 }
110
RunBufferAssignmentNoBuffersReuseForAdd(HloModule * module,int64_t alignment=1)111 std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersReuseForAdd(
112 HloModule* module, int64_t alignment = 1) {
113 absl::flat_hash_set<HloOpcode> must_not_live_out = {HloOpcode::kAdd};
114
115 return BufferAssigner::Run(
116 module, absl::make_unique<DependencyHloOrdering>(module),
117 backend().compiler()->BufferSizeBytesFunction(),
118 [alignment](LogicalBuffer::Color) { return alignment; },
119 /*allocate_buffers_for_constants=*/false,
120 /*colorer=*/BufferAssigner::DefaultColorer(),
121 /*must_not_live_out=*/must_not_live_out)
122 .ConsumeValueOrDie();
123 }
124
RunColoredBufferAssignment(HloModule * module,BufferAssigner::Colorer colorer,int64_t alignment=1)125 std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
126 HloModule* module, BufferAssigner::Colorer colorer,
127 int64_t alignment = 1) {
128 return BufferAssigner::Run(
129 module, absl::make_unique<DependencyHloOrdering>(module),
130 backend().compiler()->BufferSizeBytesFunction(),
131 [alignment](LogicalBuffer::Color) { return alignment; },
132 /*allocate_buffers_for_constants=*/true, std::move(colorer))
133 .ConsumeValueOrDie();
134 }
135
RunBufferAssignmentWithInstructionSequence(HloModule * module,absl::Span<HloInstruction * const> instruction_sequence,int64_t alignment=1)136 std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
137 HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
138 int64_t alignment = 1) {
139 HloSchedule schedule(module);
140 schedule.set_sequence(module->entry_computation(), instruction_sequence);
141 return BufferAssigner::Run(
142 module, absl::make_unique<SequentialHloOrdering>(schedule),
143 backend().compiler()->BufferSizeBytesFunction(),
144 [alignment](LogicalBuffer::Color) { return alignment; },
145 /*allocate_buffers_for_constants=*/true)
146 .ConsumeValueOrDie();
147 }
148
RunBufferAssignmentWithPresetAssignments(HloModule * module,std::unique_ptr<PresetAssignments> preset_assignments,int64_t alignment=1)149 std::unique_ptr<BufferAssignment> RunBufferAssignmentWithPresetAssignments(
150 HloModule* module, std::unique_ptr<PresetAssignments> preset_assignments,
151 int64_t alignment = 1) {
152 return BufferAssigner::Run(
153 module, absl::make_unique<DependencyHloOrdering>(module),
154 backend().compiler()->BufferSizeBytesFunction(),
155 [alignment](LogicalBuffer::Color) { return alignment; },
156 /*allocate_buffers_for_constants=*/true,
157 BufferAssigner::DefaultColorer(),
158 /*must_not_live_out=*/{},
159 /*can_share_buffer=*/nullptr, std::move(preset_assignments))
160 .ConsumeValueOrDie();
161 }
162
163 // Builds an x+1.0 computation to use in a Map.
BuildMapComputationPlus1(const string & name)164 std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) {
165 auto builder = HloComputation::Builder(name);
166 auto param =
167 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
168 auto value = builder.AddInstruction(
169 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
170 builder.AddInstruction(
171 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
172 return builder.Build();
173 }
174
BuildReduceComputation(const string & name)175 std::unique_ptr<HloComputation> BuildReduceComputation(const string& name) {
176 auto builder = HloComputation::Builder(name);
177 auto param =
178 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
179 auto param2 =
180 builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
181 builder.AddInstruction(
182 HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2));
183 return builder.Build();
184 }
185
186 // Builds a simple compare-to-limit (x < 4) computation for a While.
187 //
188 // condition:
189 // const4[s32] -----------------------------------\
190 // \
191 // param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
192 //
BuildWhileConditionComputation(const string & name)193 std::unique_ptr<HloComputation> BuildWhileConditionComputation(
194 const string& name) {
195 auto builder = HloComputation::Builder(name);
196 auto const4 = builder.AddInstruction(
197 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
198 auto param = builder.AddInstruction(
199 HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
200 auto index = builder.AddInstruction(
201 HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
202 builder.AddInstruction(
203 HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
204 const4, ComparisonDirection::kLt));
205 return builder.Build();
206 }
207
208 // Builds a simple body computation for a While.
209 //
210 // body:
211 // constv[f32[4]] --------------------------------------\
212 // \
213 // /--- get-tuple-elementv[1] --- addv ---\
214 // param[(s32,f32[4])] ---| tuple
215 // \--- get-tuple-elementc[0] --- addc ---/
216 // /
217 // const1[s32] -----------------------------------------/
218 //
BuildWhileBodyComputation(const string & name)219 std::unique_ptr<HloComputation> BuildWhileBodyComputation(
220 const string& name) {
221 auto builder = HloComputation::Builder(name);
222 auto const1 = builder.AddInstruction(
223 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
224 auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
225 LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
226 auto param = builder.AddInstruction(
227 HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
228 auto indexc = builder.AddInstruction(
229 HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
230 auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
231 indexc->shape(), HloOpcode::kAdd, indexc, const1));
232 auto indexv = builder.AddInstruction(
233 HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
234 auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
235 constv->shape(), HloOpcode::kAdd, indexv, constv));
236 builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
237 return builder.Build();
238 }
239
BuildR0F32UnaryOpComputation(HloOpcode opcode,const string & name)240 std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation(
241 HloOpcode opcode, const string& name) {
242 auto builder = HloComputation::Builder(name);
243 auto param =
244 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
245 builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param));
246 return builder.Build();
247 }
248
249 // Verifies that the given instruction hlo has a valid input buffer assigned,
250 // i.e., the parameter number matches the op's.
GetAssignedInputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)251 const BufferAllocation& GetAssignedInputAllocation(
252 const BufferAssignment& buffers, HloInstruction* hlo) {
253 LOG(INFO) << "Checking input: " << hlo->ToString();
254 const BufferAllocation& buffer =
255 *buffers.GetUniqueTopLevelSlice(hlo).ConsumeValueOrDie().allocation();
256 EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number());
257 return buffer;
258 }
259
260 // Verifies that the given instruction hlo has a valid output buffer
261 // assigned, and returns it.
GetAssignedOutputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)262 const BufferAllocation& GetAssignedOutputAllocation(
263 const BufferAssignment& buffers, HloInstruction* hlo) {
264 LOG(INFO) << "Checking output: " << hlo->ToString();
265 const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo);
266 return buffer;
267 }
268
269 // Returns the allocation for the given instruction.
GetAllocation(const BufferAssignment & buffers,const HloInstruction * hlo,const ShapeIndex & index)270 const BufferAllocation& GetAllocation(const BufferAssignment& buffers,
271 const HloInstruction* hlo,
272 const ShapeIndex& index) {
273 return *buffers.GetUniqueSlice(hlo, index).ConsumeValueOrDie().allocation();
274 }
GetTopLevelAllocation(const BufferAssignment & buffers,const HloInstruction * hlo)275 const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers,
276 const HloInstruction* hlo) {
277 return *buffers.GetUniqueTopLevelSlice(hlo)
278 .ConsumeValueOrDie()
279 .allocation();
280 }
281
282 // Verifies that all instructions in the given instruction list except
283 // kConstant have assigned buffers, and returns their total size. If min_index
284 // and max_index are not nullptr, the minimum and maximum buffer indices in
285 // the assignment are written into them.
ValidateBuffers(const std::vector<const HloInstruction * > & instructions,const BufferAssignment & buffers)286 int64 ValidateBuffers(const std::vector<const HloInstruction*>& instructions,
287 const BufferAssignment& buffers) {
288 // Verifies all instructions have buffers, and gets the index ranges.
289 for (const HloInstruction* hlo : instructions) {
290 if (!buffers.HasTopLevelAllocation(hlo)) {
291 // If `hlo` has no assigned buffer, it is either a constant or a nested
292 // parameter.
293 EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() ||
294 HloOpcode::kParameter == hlo->opcode());
295 continue;
296 }
297 }
298
299 // Gets the total size of all buffers assigned.
300 int64_t total_size = 0;
301 for (auto& allocation : buffers.Allocations()) {
302 total_size += allocation.size();
303 }
304 return total_size;
305 }
306
307 // Shapes for use in the examples.
308 Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
309 Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
310 Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
311 Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10});
312 Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100});
313 Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10});
314 Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_});
315 Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
316 };
317
318 // Returns true if the buffers assigned to instructions in "a" are distinct
319 // from the buffers assigned to those in "b" (ie, intersection is empty).
BuffersDistinct(const std::vector<const HloInstruction * > & a,const std::vector<const HloInstruction * > & b,const BufferAssignment & assignment)320 static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
321 const std::vector<const HloInstruction*>& b,
322 const BufferAssignment& assignment) {
323 absl::flat_hash_set<BufferAllocation::Slice> a_slices;
324 for (const HloInstruction* instruction : a) {
325 if (assignment.HasTopLevelAllocation(instruction)) {
326 a_slices.insert(
327 assignment.GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie());
328 }
329 }
330
331 for (const HloInstruction* instruction : b) {
332 if (assignment.HasTopLevelAllocation(instruction)) {
333 if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction)
334 .ConsumeValueOrDie())) {
335 return false;
336 }
337 }
338 }
339 return true;
340 }
341
342 // Tests a computation consisting of a single scalar constant node.
TEST_F(BufferAssignmentTest,ScalarConstant)343 TEST_F(BufferAssignmentTest, ScalarConstant) {
344 auto builder = HloComputation::Builder(TestName());
345 auto const0 = builder.AddInstruction(
346 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
347 auto module = CreateNewVerifiedModule();
348 module->AddEntryComputation(builder.Build());
349
350 {
351 auto buffers = RunBufferAssignment(module.get());
352 EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
353 }
354
355 {
356 auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
357 EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
358 }
359 }
360
TEST_F(BufferAssignmentTest, BufferForConst) {
  // Adds two f32[4] constants. The constants have top-level allocations only
  // when buffers for constants are requested; the allocation of the add that
  // consumes them is looked up in both configurations.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  {
    // With buffers for constants.
    std::unique_ptr<BufferAssignment> assignment =
        RunBufferAssignment(module.get());
    EXPECT_TRUE(assignment->HasTopLevelAllocation(lhs));
    EXPECT_TRUE(assignment->HasTopLevelAllocation(rhs));
    GetAssignedOutputAllocation(*assignment, sum);
  }
  {
    // Without buffers for constants.
    std::unique_ptr<BufferAssignment> assignment =
        RunBufferAssignmentNoBuffersForConstants(module.get());
    EXPECT_FALSE(assignment->HasTopLevelAllocation(lhs));
    EXPECT_FALSE(assignment->HasTopLevelAllocation(rhs));
    GetAssignedOutputAllocation(*assignment, sum);
  }
}
387
TEST_F(BufferAssignmentTest, HasAllocationAt) {
  // Builds tuple(negate(param0), param0, constant) and cross-checks
  // HasAllocationAt() for the tuple itself and for each tuple element against
  // HasTopLevelAllocation() of the instruction that produces that element.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, input));
  HloInstruction* result = builder.AddInstruction(
      HloInstruction::CreateTuple({negated, input, one}));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  std::unique_ptr<BufferAssignment> assignment =
      RunBufferAssignment(module.get());

  // The tuple itself.
  EXPECT_EQ(assignment->HasTopLevelAllocation(result),
            assignment->HasAllocationAt(result, /*index=*/{}));
  // Element 0: the negate.
  EXPECT_EQ(assignment->HasTopLevelAllocation(negated),
            assignment->HasAllocationAt(result, /*index=*/{0}));
  // Element 1: the parameter.
  EXPECT_EQ(assignment->HasTopLevelAllocation(input),
            assignment->HasAllocationAt(result, /*index=*/{1}));
  // Element 2: the constant.
  EXPECT_EQ(assignment->HasTopLevelAllocation(one),
            assignment->HasAllocationAt(result, /*index=*/{2}));
}
415
TEST_F(BufferAssignmentTest, BufferForOutputConst) {
  // The computation is a copy of an f32[4] constant; the copy is the last
  // instruction added.
  HloComputation::Builder builder(TestName());
  HloInstruction* literal = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  HloInstruction* copied = builder.AddInstruction(
      HloInstruction::CreateUnary(literal->shape(), HloOpcode::kCopy, literal));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  std::unique_ptr<BufferAssignment> assignment =
      RunBufferAssignment(module.get());
  // Looking up the copy's allocation checks that it was given an output
  // buffer.
  GetAssignedOutputAllocation(*assignment, copied);
}
430
TEST_F(BufferAssignmentTest, Basic) {
  // Checks the default assignment for a small element-wise chain: the three
  // parameters get distinct allocations, mul does not share with param0, and
  // add reuses mul's allocation.
  //
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  //
  // (paramscalar is broadcast to f32[100] before it feeds the multiply.)
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Distinct input buffers were assigned for parameters. The allocations are
  // bound by reference (they are owned by `buffers`) rather than copied.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}
477
TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) {
  // If an input buffer and output buffer aliases, the input buffer can be
  // reused for other intermediate results.
  //
  // param0[100] ----- (neg1) -- (neg2)
  //    |                           |
  //    + -------- Aliased ---------+

  auto builder = HloComputation::Builder(TestName());

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p0"));
  auto neg_1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
  auto neg_2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, neg_1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Set up the input/output alias shown in the diagram.
  // NOTE(review): the arguments are presumably (output_index, param_number,
  // param_index) — confirm against HloInputOutputAliasConfig::SetUpAlias.
  TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({}, 0, {}));

  auto buffers = RunBufferAssignment(module.get());

  // Bind the allocations by reference (they are owned by `buffers`) instead of
  // copying them.
  const BufferAllocation& param_buffer =
      GetAssignedInputAllocation(*buffers, param);
  const BufferAllocation& neg_1_buffer = GetAllocation(*buffers, neg_1, {});
  const BufferAllocation& neg_2_buffer = GetAllocation(*buffers, neg_2, {});

  // Everything uses one buffer.
  EXPECT_EQ(param_buffer.index(), neg_1_buffer.index());
  EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index());
}
510
TEST_F(BufferAssignmentTest, AddCannotReuse) {
  // Pass in a special rule to indicate that "add" cannot be live out, and
  // check that add and sub end up in different allocations.
  //
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  //
  // (paramscalar is broadcast to f32[100] before it feeds the multiply.)
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignmentNoBuffersReuseForAdd(module.get());

  // Distinct input buffers were assigned for parameters. The allocations are
  // bound by reference (they are owned by `buffers`) rather than copied.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The sub node has a valid buffer assigned, doesn't share with param0.
  const BufferAllocation& sub_buffer = GetTopLevelAllocation(*buffers, sub);
  EXPECT_NE(sub_buffer.index(), param0_buffer.index());

  // The add node must not share the sub node's buffer, since kAdd was passed
  // to buffer assignment as must_not_live_out.
  // NOTE(review): sub is the last instruction added and presumably the
  // live-out root of the computation — confirm.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), sub_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}
560
TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  //
  // (paramscalar is broadcast to f32[100] before it feeds the multiply.)
  // The output of each op is colored with a different color, so we can not
  // share anything.
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Colorer that walks the dataflow values by id and gives value `id` the
  // color `id` (the counter starts at 0 and advances once per value).
  auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
    int color = 0;
    for (HloValue::Id id = 0;
         id < alias_analysis->dataflow_analysis().values().size(); id++) {
      auto& value = alias_analysis->dataflow_analysis().GetValue(id);
      value.set_color(BufferValue::Color(color++));
    }
    return Status::OK();
  };

  auto buffers = RunColoredBufferAssignment(module.get(), colorer);

  // Distinct input buffers were assigned for parameters. The allocations are
  // bound by reference (they are owned by `buffers`) rather than copied.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can not reuse the mul node's buffer due to coloring.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);

  // Check if the HLO instructions have the correct colors in the layout.
  // NOTE(review): the expected numbers presumably equal each instruction's
  // value id, which appears to follow creation order (paramscalar=0,
  // broadcast=1, param0=2, ...) — keep the AddInstruction order above intact.
  EXPECT_EQ(param0->shape().layout().memory_space(), 2);
  EXPECT_EQ(param1->shape().layout().memory_space(), 3);
  EXPECT_EQ(mul->shape().layout().memory_space(), 4);
  EXPECT_EQ(add->shape().layout().memory_space(), 5);
  EXPECT_EQ(sub->shape().layout().memory_space(), 6);
}
626
TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  //
  // (paramscalar is broadcast to f32[100] before it feeds the multiply.)
  // The output of the mul and the add have the color 1, and the other buffers
  // have the color 0, which allows the mul and add to share buffers.
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Colorer: a value gets color 1 if the buffer containing it holds any value
  // defined by an add or a multiply; every value still uncolored after that
  // check gets color 0.
  auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
    for (HloValue::Id id = 0;
         id < alias_analysis->dataflow_analysis().values().size(); id++) {
      auto& value = alias_analysis->dataflow_analysis().GetValue(id);
      auto& buffer = alias_analysis->GetBufferContainingValue(value);
      for (const auto& alias : buffer.values()) {
        if (alias->instruction()->opcode() == HloOpcode::kAdd ||
            alias->instruction()->opcode() == HloOpcode::kMultiply) {
          value.set_color(LogicalBuffer::Color(1));
        }
      }
      if (!value.has_color()) {
        value.set_color(LogicalBuffer::Color(0));
      }
    }
    return Status::OK();
  };

  auto buffers = RunColoredBufferAssignment(module.get(), colorer);

  // Distinct input buffers were assigned for parameters. The allocations are
  // bound by reference (they are owned by `buffers`) rather than copied.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);

  // Check if the HLO instructions have the correct colors in the layout.
  EXPECT_EQ(mul->shape().layout().memory_space(), 1);
  EXPECT_EQ(add->shape().layout().memory_space(), 1);
  EXPECT_EQ(sub->shape().layout().memory_space(), 0);
  EXPECT_EQ(param0->shape().layout().memory_space(), 0);
  EXPECT_EQ(param1->shape().layout().memory_space(), 0);
}
700
TEST_F(BufferAssignmentTest, PresetAssignments) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  //
  // Same graph as BasicPartiallyColored, but the color is carried by the shape
  // layout instead of being set by a colorer. The mul and add outputs have
  // color 1 and receive preset assignments; every other buffer has color 0.
  // The presets place mul and add at different offsets of one shared
  // allocation.
  HloComputation::Builder builder(TestName());
  HloInstruction* paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  // F32[100] shape used for the colored values. NOTE(review): the trailing 1
  // is presumably the layout memory space that becomes the buffer color; the
  // color-1 expectation on the mul/add allocation below relies on it.
  const Shape f32vec100_color1 =
      ShapeUtil::MakeShapeWithLayout(F32, {100}, {0}, {}, 0, 1);
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_color1, HloOpcode::kMultiply, broadcast, param0));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_color1, HloOpcode::kAdd, mul, param1));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Pin mul to offset 100 and add to offset 550 (400 each), and declare the
  // color-1 space large enough (950) to contain both chunks.
  auto preset_assignments = absl::make_unique<PresetAssignments>();
  preset_assignments->add_chunk({mul, {}}, {/*offset=*/100, /*size=*/400});
  preset_assignments->add_chunk({add, {}}, {/*offset=*/550, /*size=*/400});
  preset_assignments->assignment_information_for_space(/*memory_space=*/1)
      ->size = 950;

  auto buffers = RunBufferAssignmentWithPresetAssignments(
      module.get(), std::move(preset_assignments));

  // Distinct input buffers were assigned for parameters. The allocations are
  // bound by const reference (like the mul/add lookups below) so that no
  // BufferAllocation is copied just to be inspected.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_EQ(paramscalar_buffer.color(), LogicalBuffer::Color(0));
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());
  EXPECT_EQ(param0_buffer.color(), LogicalBuffer::Color(0));

  // The mul and add use the same preset buffer. Ensure it has the correct
  // color and offsets.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(mul_buffer, add_buffer);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());
  EXPECT_EQ(mul_buffer.color(), LogicalBuffer::Color(1));

  // Exactly two values live in the preset allocation: one defined by mul and
  // one defined by add, each at the offset/size requested above.
  EXPECT_EQ(mul_buffer.assigned_buffers().size(), 2);
  for (const auto& value_and_offsetsize : mul_buffer.assigned_buffers()) {
    const auto& offset_size = value_and_offsetsize.second;
    if (value_and_offsetsize.first->instruction() == mul) {
      EXPECT_EQ(offset_size.offset, 100);
      EXPECT_EQ(offset_size.size, 400);
    } else {
      EXPECT_EQ(value_and_offsetsize.first->instruction(), add);
      EXPECT_EQ(offset_size.offset, 550);
      EXPECT_EQ(offset_size.size, 400);
    }
  }

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}
774
TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) {
  // Exercises preset assignments when several HloValues end up in a single
  // HloBuffer, i.e. there is no 1-to-1 correspondence between the two (the
  // data element carried through a while loop).
  auto module = CreateNewVerifiedModule();
  Shape f32vec10_color1 =
      ShapeUtil::MakeShapeWithLayout(F32, {10}, {0}, {}, 0, 1);
  Shape t_s32_f32v10_color1 =
      ShapeUtil::MakeTupleShape({s32_, f32vec10_color1});

  // Condition: iterate while tuple element 0 is less than 50.
  HloComputation::Builder cond_builder("WhileCond");
  auto* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "cond_param"));
  auto* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(s32_, cond_param, 0));
  auto* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(50)));
  cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  auto* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  // Body: add a constant vector to the data element and 1 to the counter.
  HloComputation::Builder body_builder("WhileBody");
  auto* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "body_param"));
  auto* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(s32_, body_param, 0));
  auto* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec10_color1, body_param, 1));
  auto* body_data_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
          {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f})));
  auto* body_data_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          f32vec10_color1, HloOpcode::kAdd, body_data, body_data_increment));
  auto* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
  auto* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          s32_, HloOpcode::kAdd, body_iter, body_iter_increment));
  body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_iter_next, body_data_next}));
  auto* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  // Entry: negate the input data into a color-1 value, run it through the
  // loop, then add the loop's data result back to the original input.
  HloComputation::Builder builder(TestName());
  auto* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(0, s32_, "param_iter"));
  auto* data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec10_, "param_data"));
  auto* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec10_color1, HloOpcode::kNegate, data));
  auto* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({iter, negate}));
  auto* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      t_s32_f32v10_color1, cond_computation, body_computation, tuple));
  auto* while_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec10_color1, while_op, 1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec10_, HloOpcode::kAdd, while_data, data));
  module->AddEntryComputation(builder.Build());

  // A single preset chunk, keyed on negate, has to cover the while data and
  // all of its aliases.
  auto preset_assignments = absl::make_unique<PresetAssignments>();
  preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
  preset_assignments->assignment_information_for_space(/*memory_space=*/1)
      ->size = 140;

  auto buffers = RunBufferAssignmentWithPresetAssignments(
      module.get(), std::move(preset_assignments));

  // The five values assigned to negate's allocation alias one another, so
  // each must report the preset offset and size as well as color 1.
  const BufferAllocation& data_buffer = GetTopLevelAllocation(*buffers, negate);
  EXPECT_EQ(data_buffer.assigned_buffers().size(), 5);
  for (const auto& assigned : data_buffer.assigned_buffers()) {
    EXPECT_EQ(assigned.second.offset, 100);
    EXPECT_EQ(assigned.second.size, 40);
    EXPECT_EQ(assigned.first->color(), LogicalBuffer::Color(1));
  }
}
856
TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
  // This is similar to the Basic test, with the difference that (sub) is
  // another user of (mul)'s result, so (mul)'s buffer cannot be reused for
  // (add)'s output.
  //
  // paramscalar -------+ +----------------+
  //                    | |                |
  // param0[100] ----- (mul) -- (add) -- (sub)
  //                              |
  // param1[100] -----------------+
  //
  // (The diagram deliberately avoids lines ending in a backslash, which the
  // preprocessor would treat as a comment line continuation.)
  HloComputation::Builder builder(TestName());
  HloInstruction* paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  // Unlike the Basic test, the second operand of sub is mul (not param1).
  HloInstruction* sub = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Input buffers were assigned for parameters, and they are pairwise
  // distinct. Bound by const reference to avoid copying the allocations.
  const BufferAllocation& paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);
  const BufferAllocation& param1_buffer =
      GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node had a buffer allocated.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);

  // Now the add node can't reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // Log size information for inspection.
  const std::vector<const HloInstruction*> level0 = GetInstructions(sub);
  int64_t size0 = ValidateBuffers(level0, *buffers);
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() << " instructions; "
            << "total buffer size " << size0;
}
911
TEST_F(BufferAssignmentTest, TrivialMap) {
  // This tests a trivial x+1 map as the only operation.
  //
  // param0[100x10] ---> (map x+1)
  //
  // Builds the map function.
  auto module = CreateNewVerifiedModule();
  HloComputation* map_computation =
      module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
  HloInstruction* inner_last = map_computation->root_instruction();

  // Creates the main kernel and verifies instruction counts.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation));
  module->AddEntryComputation(builder.Build());

  const std::vector<const HloInstruction*> level0 = GetInstructions(map);
  EXPECT_EQ(2, level0.size()) << "Invalid main kernel size";
  const std::vector<const HloInstruction*> level1 = GetInstructions(inner_last);
  EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size";

  // Assigns buffers and fetches sizes.
  auto buffers = RunBufferAssignment(module.get());
  int64_t size0 = ValidateBuffers(level0, *buffers);
  int64_t size1 = ValidateBuffers(level1, *buffers);

  // The main kernel and the embedded map computation must not share any
  // buffer.
  EXPECT_TRUE(BuffersDistinct(level0, level1, *buffers))
      << "Reuse between main kernel and embedded mapping.";

  // An input buffer was assigned for the parameter. Bound by const reference
  // (as for inner_add_buffer below) so the allocation is not copied.
  const BufferAllocation& param0_buffer =
      GetAssignedInputAllocation(*buffers, param0);

  // An output buffer was assigned for the map, distinct from the input.
  const BufferAllocation& map_buffer =
      GetAssignedOutputAllocation(*buffers, map);
  EXPECT_NE(param0_buffer.index(), map_buffer.index());

  // The final computation node of the map is an add of an f32 param and a
  // constant; it must not live in the map's output buffer.
  EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode());
  const BufferAllocation& inner_add_buffer =
      GetTopLevelAllocation(*buffers, inner_last);
  EXPECT_NE(inner_add_buffer.index(), map_buffer.index());

  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() + level1.size() << " instructions; "
            << "total buffer size " << size0 + size1;
}
966
TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
  // A reduce must never write its output into its operand's buffer. (Sharing
  // is unsafe in general: the reduce reshapes, and an out-of-order reduction
  // could overwrite an element before it has been read.)
  //
  // param0[100x10] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3)
  auto module = CreateNewVerifiedModule();
  HloComputation* reduce_computation =
      module->AddEmbeddedComputation(BuildReduceComputation("f32+f32"));

  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  HloInstruction* exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0));
  HloInstruction* exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
  HloInstruction* const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      /*shape=*/f32vec10_,
      /*operand=*/exp2,
      /*init_value=*/const0,
      /*dimensions_to_reduce=*/{0}, reduce_computation));
  HloInstruction* exp3 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce));

  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  const std::vector<const HloInstruction*> instrs = GetInstructions(exp3);
  ValidateBuffers(instrs, *buffers);

  const BufferAllocation& exp1_buffer = GetTopLevelAllocation(*buffers, exp1);
  const BufferAllocation& exp2_buffer = GetTopLevelAllocation(*buffers, exp2);
  const BufferAllocation& reduce_buffer =
      GetTopLevelAllocation(*buffers, reduce);

  // Sanity check: exp2 is expected to simply take over exp1's buffer.
  EXPECT_EQ(exp1_buffer.index(), exp2_buffer.index());

  // Although exp2 is the reduce's only array operand, its buffer must not be
  // handed to the reduce.
  EXPECT_NE(exp2_buffer.index(), reduce_buffer.index());
}
1013
TEST_F(BufferAssignmentTest, ExampleWhile) {
  // While loop example taken from the ir_semantics document.
  //
  // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation.
  // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation.
  //
  // const3[s32] --------+
  //                     |
  // const4[f32[4]] --- tuple --- while[condition, body]
  //
  // Nested computations: the loop condition and the loop body.
  auto module = CreateNewVerifiedModule();
  HloComputation* condition_computation =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4"));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update"));

  // Main kernel: two constants packed into a tuple that feeds the while.
  HloComputation::Builder builder(TestName());
  HloInstruction* const3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
  HloInstruction* const4 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      t_s32_f32v4_, condition_computation, body_computation, tuple));
  module->AddEntryComputation(builder.Build());

  // Check the instruction count of each of the three computations.
  const std::vector<const HloInstruction*> level0 = GetInstructions(while_op);
  EXPECT_EQ(4, level0.size()) << "Invalid while kernel size";
  const std::vector<const HloInstruction*> levelc =
      GetInstructions(condition_computation->root_instruction());
  EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size";
  const std::vector<const HloInstruction*> levelb =
      GetInstructions(body_computation->root_instruction());
  EXPECT_EQ(8, levelb.size()) << "Invalid nested body size";

  // Run buffer assignment and collect the per-computation sizes.
  auto buffers = RunBufferAssignment(module.get());
  int64_t size0 = ValidateBuffers(level0, *buffers);
  int64_t sizec = ValidateBuffers(levelc, *buffers);
  int64_t sizeb = ValidateBuffers(levelb, *buffers);

  // One allocation is expected to serve the while, the condition parameter,
  // the body parameter and the body result, so none of the three instruction
  // sets may be buffer-disjoint from another.
  EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers))
      << "Should be reuse between main kernel and embedded condition.";
  EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers))
      << "Should be reuse between embedded condition and body.";
  EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers))
      << "Should be reuse between main kernel and embedded body.";

  // The while body ends in a tuple (of the s32 and f32[4] adds).
  HloInstruction* body_root = body_computation->root_instruction();
  EXPECT_EQ(HloOpcode::kTuple, body_root->opcode());

  // At every shape index, the while's buffer and the body root's buffer must
  // belong to the same allocation.
  ShapeUtil::ForEachSubshape(
      while_op->shape(),
      [this, &buffers, while_op, body_root](const Shape& /*subshape*/,
                                            const ShapeIndex& index) {
        auto allocation_for_while = GetAllocation(*buffers, while_op, index);
        auto allocation_for_body = GetAllocation(*buffers, body_root, index);
        EXPECT_EQ(allocation_for_while.index(), allocation_for_body.index());
      });

  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() + levelc.size() + levelb.size()
            << " instructions; total buffer size " << size0 + sizec + sizeb;
}
1088
TEST_F(BufferAssignmentTest, ExampleConditional) {
  // A scalar conditional whose two branches are single unary ops (ceil and
  // floor). The conditional and both branch computations are expected to
  // share buffers with one another.
  auto module = CreateNewVerifiedModule();
  HloComputation* true_computation = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil"));
  HloComputation* false_computation = module->AddEmbeddedComputation(
      BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor"));

  HloComputation::Builder builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
  HloInstruction* const2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.4f)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          r0f32_, pred, const1, true_computation, const2, false_computation));
  module->AddEntryComputation(builder.Build());

  // Gather the instructions of each computation and check their counts.
  HloInstruction* true_root = true_computation->root_instruction();
  HloInstruction* false_root = false_computation->root_instruction();
  const std::vector<const HloInstruction*> conditional_instrs =
      GetInstructions(conditional);
  const std::vector<const HloInstruction*> true_instrs =
      GetInstructions(true_root);
  const std::vector<const HloInstruction*> false_instrs =
      GetInstructions(false_root);
  EXPECT_EQ(4, conditional_instrs.size());
  EXPECT_EQ(2, true_instrs.size());
  EXPECT_EQ(2, false_instrs.size());

  auto buffers = RunBufferAssignment(module.get());
  ValidateBuffers(conditional_instrs, *buffers);
  ValidateBuffers(true_instrs, *buffers);
  ValidateBuffers(false_instrs, *buffers);

  // No pair of the three instruction sets may be buffer-disjoint.
  EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers))
      << "Should be reuse between conditional and true computation.";
  EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers))
      << "Should be reuse between conditional and false computation.";
  EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers))
      << "Should be reuse between true and false computations.";

  // The conditional's allocation has the same size as each branch root's.
  const BufferAllocation& conditional_buffer =
      GetTopLevelAllocation(*buffers, conditional);
  const BufferAllocation& true_buffer =
      GetTopLevelAllocation(*buffers, true_root);
  const BufferAllocation& false_buffer =
      GetTopLevelAllocation(*buffers, false_root);
  EXPECT_EQ(conditional_buffer.size(), true_buffer.size());
  EXPECT_EQ(conditional_buffer.size(), false_buffer.size());
}
1138
TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
  // A straight chain of same-shaped unary ops:
  //
  //   param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
  //
  // Every op after the parameter is expected to share one allocation.
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p"));
  HloInstruction* first_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, input));
  HloInstruction* tanh_op = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, first_exp));
  HloInstruction* second_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh_op));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, second_exp));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // tanh, the second exp and the negate all reuse the first exp's buffer.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(first_exp));
  const BufferAllocation& chain_allocation =
      GetTopLevelAllocation(*assignment, first_exp);
  EXPECT_EQ(chain_allocation, GetTopLevelAllocation(*assignment, tanh_op));
  EXPECT_EQ(chain_allocation, GetTopLevelAllocation(*assignment, second_exp));
  EXPECT_EQ(chain_allocation, GetTopLevelAllocation(*assignment, negate));
}
1164
TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
  // The chain first shrinks the buffer size (slice) and then grows it again
  // (broadcast):
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // Although negate is not an operand of broadcast, the two are expected to
  // end up in the same allocation.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast has an allocation, and it is the one negate uses.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  const BufferAllocation& bcast_allocation =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(bcast_allocation, GetTopLevelAllocation(*assignment, negate));

  // The slice sits between them and needs an allocation of its own.
  EXPECT_NE(bcast_allocation, GetTopLevelAllocation(*assignment, slice));
}
1194
TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
  // Same computation as ReuseNonOperandBuffer, except that negate is also an
  // operand of the output tuple. Its value therefore stays live until the end
  // of the computation, which rules out reuse of its buffer.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast) ---> (tuple)
  //                  |                                          |
  //                  +------------------------------------------+
  //
  // Expected: negate, slice and broadcast each get their own allocation.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Pairwise check that no two of the three instructions share an allocation.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
  EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
            GetTopLevelAllocation(*assignment, slice));
}
1227
TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
  // Same computation as ReuseNonOperandBuffer, except that negate is wrapped
  // in a tuple that lives until the end of the computation. Through that
  // alias the live range of negate's buffer is extended, which rules out
  // reuse.
  //
  //   param ---> (negate) ---> (tuple) -> (slice) ---> (broadcast) -> (tuple)
  //                               |                                      |
  //                               +--------------------------------------+
  //
  // Expected: negate, slice and broadcast each get their own allocation.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  HloInstruction* inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({negate}));
  HloInstruction* tuple_element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec100_, inner_tuple, 0));
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}, {1}));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Pairwise check that no two of the three instructions share an allocation.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
  EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
            GetTopLevelAllocation(*assignment, slice));
}
1264
TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
  // Close to ReuseNonOperandBuffer, but here the broadcast (the computation
  // output) is smaller than the negate. Output buffers of a computation
  // should be sized exactly for their value, so the larger negate buffer must
  // not be handed to the broadcast.
  //
  //   param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // Expected: broadcast shares a buffer with neither negate nor slice.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // negate: 100 elements.
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // slice: 10 elements.
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // broadcast: 10 x 4 = 40 elements.
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The output (broadcast) allocation is shared with neither predecessor.
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
            GetTopLevelAllocation(*assignment, slice));
}
1297
TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
  // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast
  // output is exactly the same size as the negate (rather than being
  // smaller). This enables reuse of negate's buffer by the broadcast because
  // the output buffer will be sized exactly to its value.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast)
  //
  // The negate should share a buffer with broadcast; the slice should not.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Slice output is 10 elements.
  auto slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 10 x 10 = 100 elements, the same count as negate.
  auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // negate and broadcast should share a buffer.
  EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
  auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));

  // Slice should have its own buffer.
  EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
}
1331
TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
  // Closely follows ReuseNonOperandBuffer, with two differences: the
  // broadcast result is smaller than the negate result, and the broadcast is
  // returned from the computation as a tuple element. Output buffers of a
  // computation should be sized exactly for their value, and that includes
  // buffers aliased in the output (e.g. tuple elements), so the broadcast must
  // not reuse negate's larger buffer.
  //
  // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
  //
  // Expectation: broadcast shares a buffer with neither negate nor slice.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  // Negate output is 100 elements.
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  // Slice output is 10 elements.
  HloInstruction* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
  // Broadcast output is 10 x 4 = 40 elements.
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {10, 4});
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, slice, {0}));
  builder.AddInstruction(HloInstruction::CreateTuple({broadcast}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // The broadcast output buffer cannot be shared.
  const BufferAllocation& broadcast_alloc =
      GetTopLevelAllocation(*assignment, broadcast);
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, negate));
  EXPECT_NE(broadcast_alloc, GetTopLevelAllocation(*assignment, slice));
}
1367
TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
  // Checks the thread-local / live-out / entry-parameter flags of allocations
  // belonging to embedded computations (one used by a map, one used by a
  // kCall) and of allocations in the entry computation that invokes them.
  auto module = CreateNewVerifiedModule();
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {42});
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

  // Scalar computation (negate) that the map applies.
  HloComputation::Builder map_builder(TestName() + "_map");
  HloInstruction* map_param = map_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
  HloInstruction* map_root = map_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
  HloComputation* map_computation =
      module->AddEmbeddedComputation(map_builder.Build());

  // Vector computation (exp) that the kCall invokes.
  HloComputation::Builder call_builder(TestName() + "_call");
  HloInstruction* call_param = call_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
  HloInstruction* call_root = call_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
  HloComputation* call_computation =
      module->AddEmbeddedComputation(call_builder.Build());

  // Entry computation: kCall call_computation on the parameter, then map
  // map_computation over the call's result.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(vec_shape, {param}, call_computation));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(vec_shape, {call}, map_computation));
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  // Map computation: allocations are thread-local, not live-out, and not
  // entry computation parameters.
  const BufferAllocation& map_param_alloc =
      GetTopLevelAllocation(*assignment, map_param);
  EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_param_alloc.maybe_live_out());
  EXPECT_TRUE(map_param_alloc.is_thread_local());

  const BufferAllocation& map_root_alloc =
      GetTopLevelAllocation(*assignment, map_root);
  EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_root_alloc.maybe_live_out());
  EXPECT_TRUE(map_root_alloc.is_thread_local());

  // Call computation: allocations are not thread-local.
  // NOTE(review): the call parameter is expected to report
  // is_entry_computation_parameter(), presumably because it shares the entry
  // parameter's allocation -- confirm.
  const BufferAllocation& call_param_alloc =
      GetTopLevelAllocation(*assignment, call_param);
  EXPECT_TRUE(call_param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(call_param_alloc.maybe_live_out());
  EXPECT_FALSE(call_param_alloc.is_thread_local());

  const BufferAllocation& call_root_alloc =
      GetTopLevelAllocation(*assignment, call_root);
  EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(call_root_alloc.is_thread_local());

  // Entry computation: the parameter is an entry computation parameter and
  // the map (the last instruction added) may be live-out.
  const BufferAllocation& param_alloc =
      GetTopLevelAllocation(*assignment, param);
  EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(param_alloc.maybe_live_out());
  EXPECT_FALSE(param_alloc.is_thread_local());

  const BufferAllocation& map_alloc = GetTopLevelAllocation(*assignment, map);
  EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
  EXPECT_TRUE(map_alloc.maybe_live_out());
  EXPECT_FALSE(map_alloc.is_thread_local());
}
1439
TEST_F(BufferAssignmentTest, CustomCallEmbeddedComputationBuffers) {
  // Checks that allocations of a computation attached to a custom call are
  // marked thread-local.
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

  // Scalar computation (negate) handed to the custom call.
  HloComputation::Builder map_builder(TestName() + "_map");
  HloInstruction* map_param = map_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
  HloInstruction* map_root = map_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
  HloComputation* map_computation =
      module->AddEmbeddedComputation(map_builder.Build());

  // Entry computation: a custom call that references map_computation.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param"));
  builder.AddInstruction(HloInstruction::CreateCustomCall(
      scalar_shape, {param}, map_computation, "call_name"));
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  // Both the parameter and the root of the embedded computation must be
  // thread-local, not live-out, and not entry computation parameters.
  const BufferAllocation& map_param_alloc =
      GetTopLevelAllocation(*assignment, map_param);
  EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_param_alloc.maybe_live_out());
  EXPECT_TRUE(map_param_alloc.is_thread_local());

  const BufferAllocation& map_root_alloc =
      GetTopLevelAllocation(*assignment, map_root);
  EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
  EXPECT_FALSE(map_root_alloc.maybe_live_out());
  EXPECT_TRUE(map_root_alloc.is_thread_local());
}
1476
TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
  // Test a computation whose only instruction, and therefore its result, is
  // a three-element tuple parameter.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0,
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(F32, {}),
                                 ShapeUtil::MakeShape(S32, {42})}),
      "param0"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // There should be four allocations: one for vector of pointers, and one for
  // each tuple element.
  EXPECT_EQ(4, assignment->Allocations().size());

  // Verify each buffer allocation is marked as an entry computation parameter
  // (for parameter number 0) and is liveout.
  ShapeUtil::ForEachSubshape(
      tuple_param->shape(),
      [this, &assignment, tuple_param](const Shape& /*subshape*/,
                                       const ShapeIndex& index) {
        // Bind by const reference: the allocation is only inspected, so
        // copying it on every visited subshape is unnecessary.
        const auto& allocation =
            GetAllocation(*assignment, tuple_param, index);
        EXPECT_TRUE(allocation.is_entry_computation_parameter());
        EXPECT_EQ(0, allocation.parameter_number());
        EXPECT_TRUE(allocation.maybe_live_out());
      });
}
1507
TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
  // Test a computation which returns a GetTupleElement of a nested tuple
  // parameter. The parameter has shape (pred[1,2,3,4], (s32[42], s32[101]))
  // and the computation returns its element 1, the inner tuple.
  HloComputation::Builder builder(TestName());
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {42}), ShapeUtil::MakeShape(S32, {101})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), inner_tuple_shape});
  HloInstruction* tuple_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));
  HloInstruction* tuple_element =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Only some of the elements of the input param are liveout: the top-level
  // tuple buffer is not.
  EXPECT_FALSE(
      GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
  // Tuple element at index={1} is live out because GetTupleElement({1})
  // forwards a pointer to this allocation (instead of defining its own
  // buffer). The same holds for the leaves below it.
  EXPECT_TRUE(
      GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
                  .maybe_live_out());
  EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
                  .maybe_live_out());

  // The GetTupleElement output is liveout.
  EXPECT_TRUE(
      GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());

  // The elements of the GetTupleElement result alias the corresponding
  // elements of the tuple parameter, so their allocations must match.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
            GetAllocation(*assignment, tuple_element, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
            GetAllocation(*assignment, tuple_element, /*index=*/{1}));

  // GetTupleElement forwards a pointer to its underlying buffer, so verify
  // that it has the same allocation as the corresponding parameter element.
  EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
            GetTopLevelAllocation(*assignment, tuple_element));
}
1555
1556 // TODO(b/32248867): Enable when buffer assignment gives allocations to
1557 // constants.
TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
  // Test that a two-element tuple constant which is forwarded to the
  // computation output is properly handled.
  HloComputation::Builder builder(TestName());
  Literal first_element = LiteralUtil::CreateR0<int64>(0);
  Literal second_element = LiteralUtil::CreateR0<int64>(1);
  builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::MakeTuple({&first_element, &second_element})));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Three allocations are expected in total.
  // NOTE(review): presumably one for the tuple itself and one per element --
  // confirm.
  EXPECT_EQ(3, assignment->Allocations().size());
}
1573
TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
  // Test a computation which returns the tuple-shaped result of an
  // operand-less custom call.
  HloComputation::Builder builder(TestName());
  const Shape result_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
                                 ShapeUtil::MakeShape(S32, {101})});
  HloInstruction* custom_call =
      builder.AddInstruction(HloInstruction::CreateCustomCall(
          result_shape, /*operands=*/{},
          /*custom_call_target=*/"foo_function"));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Three allocations are expected, and the allocation at every index of the
  // custom call's result (the tuple and both elements) may be live-out.
  EXPECT_EQ(3, assignment->Allocations().size());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out());
  EXPECT_TRUE(
      GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
}
1593
TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
  // Test a computation which returns a tuple call value: the entry
  // computation calls a sub-computation that wraps its parameter in a
  // one-element tuple.
  auto module = CreateNewVerifiedModule();
  const Shape elem_shape = f32vec4_;
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});

  // Sub-computation: tuple(sub_param).
  HloComputation::Builder sub_builder(TestName() + "_sub");
  HloInstruction* sub_param = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "sub_param"));
  HloInstruction* sub_tuple =
      sub_builder.AddInstruction(HloInstruction::CreateTuple({sub_param}));
  HloComputation* sub_computation =
      module->AddEmbeddedComputation(sub_builder.Build());

  // Entry computation: call(sub_computation, param).
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "param"));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
  module->AddEntryComputation(builder.Build());

  auto assignment = RunBufferAssignment(module.get());

  EXPECT_EQ(2, assignment->Allocations().size());
  // Buffers for call are colocated with the sub-computation: the call's
  // tuple with sub_tuple, and the call's element 0 with sub_param.
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}),
            GetAllocation(*assignment, sub_tuple, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}),
            GetAllocation(*assignment, sub_param, /*index=*/{}));

  // The entry parameter does not share an allocation with the result tuple,
  // but it does share one with the sub-computation parameter it is passed as.
  EXPECT_NE(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, sub_tuple));
  EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, sub_param));
}
1630
TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
  // Test a chain of calls with tuple output. The chain looks like:
  // A: call(B, tuple(param))
  // B: call(C, param)
  // C: call(D, param)
  // D: param
  auto module = CreateNewVerifiedModule();
  const Shape elem_shape = f32vec4_;
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});

  // D simply returns its tuple parameter.
  HloComputation::Builder d_builder(TestName() + "_d");
  HloInstruction* d_param = d_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "d_param"));
  std::unique_ptr<HloComputation> d_computation = d_builder.Build();

  // C calls D on its parameter.
  HloComputation::Builder c_builder(TestName() + "_c");
  HloInstruction* c_param = c_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "c_param"));
  HloInstruction* c_call = c_builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {c_param}, d_computation.get()));
  std::unique_ptr<HloComputation> c_computation = c_builder.Build();

  // B calls C on its parameter.
  HloComputation::Builder b_builder(TestName() + "_b");
  HloInstruction* b_param = b_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "b_param"));
  HloInstruction* b_call = b_builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {b_param}, c_computation.get()));
  std::unique_ptr<HloComputation> b_computation = b_builder.Build();

  // A (the entry computation) wraps its parameter in a tuple and calls B.
  HloComputation::Builder a_builder(TestName());
  HloInstruction* a_param = a_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "param"));
  HloInstruction* a_tuple =
      a_builder.AddInstruction(HloInstruction::CreateTuple({a_param}));
  HloInstruction* a_call = a_builder.AddInstruction(
      HloInstruction::CreateCall(tuple_shape, {a_tuple}, b_computation.get()));
  std::unique_ptr<HloComputation> a_computation = a_builder.Build();

  // Add the computations in an order that doesn't match the dependency
  // post-order, to shake out more possible bugs.
  module->AddEmbeddedComputation(std::move(d_computation));
  module->AddEmbeddedComputation(std::move(c_computation));
  module->AddEntryComputation(std::move(a_computation));
  module->AddEmbeddedComputation(std::move(b_computation));

  auto assignment = RunBufferAssignment(module.get());

  // Buffers for call are colocated with the sub-computations, both for the
  // top-level tuple (index {}) ...
  EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
            GetAllocation(*assignment, b_call, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}),
            GetAllocation(*assignment, c_call, /*index=*/{}));
  EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{}),
            GetAllocation(*assignment, d_param, /*index=*/{}));
  // ... and for the tuple element (index {0}).
  EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{0}),
            GetAllocation(*assignment, b_call, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{0}),
            GetAllocation(*assignment, c_call, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}),
            GetAllocation(*assignment, d_param, /*index=*/{0}));

  // The entry parameter's buffer is distinct from every callee parameter.
  EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment));
  EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment));
  EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment));

  // The callee parameters share the allocation of their tuple element.
  EXPECT_EQ(GetAllocation(*assignment, b_param, /*index=*/{0}),
            GetAllocation(*assignment, c_param, /*index=*/{0}));
  EXPECT_EQ(GetAllocation(*assignment, c_param, /*index=*/{0}),
            GetAllocation(*assignment, d_param, /*index=*/{0}));
}
1701
TEST_F(BufferAssignmentTest, BitcastAsOutput) {
  // Test a computation which returns a bitcast of its parameter (to the
  // parameter's own shape).
  HloComputation::Builder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(F32, {42});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(param->shape(), param));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Bitcast should get the same allocation as the param, so exactly one
  // allocation exists.
  EXPECT_EQ(1, assignment->Allocations().size());
  EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
            GetTopLevelAllocation(*assignment, bitcast));
}
1719
TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
  // Test a computation with an output that has an ambiguous points-to set.
  // This is constructed using a select among tuple shapes.
  auto builder = HloComputation::Builder(TestName());
  auto tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4})});

  auto tuple_param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param0"));
  auto tuple_param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, tuple_shape, "param1"));
  // Parameter number 2 gets its own name rather than reusing "param1".
  auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto select = builder.AddInstruction(
      HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect,
                                    pred_param, tuple_param0, tuple_param1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // Select shallow copies one of its operands so it defines its own top-level
  // buffer and receives its own allocation. The allocation is only inspected,
  // so it is bound by const reference instead of being copied.
  const auto& select_alloc = GetTopLevelAllocation(*assignment, select);
  EXPECT_EQ(1, select_alloc.assigned_buffers().size());
  EXPECT_EQ(select,
            select_alloc.assigned_buffers().begin()->first->instruction());

  // The buffer for the tuple element of the select is forwarded from one of
  // its operands, which cannot be determined statically. Therefore its slices
  // should include the slices of both of the elements in the parameters.
  auto element_slices = assignment->GetAllSlices(select, /*index=*/{0});
  EXPECT_EQ(2, element_slices.size());
  EXPECT_THAT(element_slices,
              UnorderedElementsAre(
                  assignment->GetUniqueSlice(tuple_param0, /*index=*/{0})
                      .ConsumeValueOrDie(),
                  assignment->GetUniqueSlice(tuple_param1, /*index=*/{0})
                      .ConsumeValueOrDie()));
}
1760
1761 // TODO(b/34669761): Remove this test when buffers are allowed to share
1762 // allocations.
TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
  // Test a computation that wraps a scalar parameter in a tuple, reads the
  // element back with GetTupleElement and returns a copy of it:
  //
  // param ---> (tuple) ---> (get-tuple-element) ---> (copy)
  HloComputation::Builder builder(TestName());
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param0"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param}));
  HloInstruction* tuple_element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, tuple, 0));
  HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape, HloOpcode::kCopy, tuple_element));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get());

  // There should be no buffer reuse: three allocations are expected, and the
  // copy must not land in the tuple's allocation.
  EXPECT_EQ(3, assignment->Allocations().size());
  EXPECT_NE(GetTopLevelAllocation(*assignment, tuple),
            GetTopLevelAllocation(*assignment, copy));
}
1785
TEST_F(BufferAssignmentTest, OneTempAllocation) {
  // Test a computation that requires multiple temp buffers, and ensure they
  // are combined into a single allocation. The check is done twice, with
  // alignment 1 and with alignment 64.
  HloComputation::Builder builder(TestName());
  const Shape shape_2x3 = ShapeUtil::MakeShape(F32, {2, 3});
  const Shape shape_2x4 = ShapeUtil::MakeShape(F32, {2, 4});
  const Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4});
  const Shape shape_4x4 = ShapeUtil::MakeShape(F32, {4, 4});
  const Shape shape_5x4 = ShapeUtil::MakeShape(F32, {5, 4});

  // There should be separate temp buffers for dot_ab and dot_bc.
  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape_2x3, "param_a"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape_3x4, "param_b"));
  HloInstruction* param_c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, shape_4x4, "param_c"));
  // Contract dimension 1 of the lhs with dimension 0 of the rhs.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  HloInstruction* dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
      shape_2x4, param_a, param_b, dot_dnums, precision_config));
  HloInstruction* dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
      shape_3x4, param_b, param_c, dot_dnums, precision_config));
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));

  // Run buffer assignment with alignment=1.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);

  // There are 5 allocations: 3 parameters, 1 output, and 1 temp.
  EXPECT_EQ(5, assignment->Allocations().size());

  // Ensure the temp buffers for dot_ab and dot_bc share a single allocation,
  // and each occupies different slices of that allocation. The expected sizes
  // are 32 bytes (2x4 F32) and 48 bytes (3x4 F32), 80 bytes in total.
  BufferAllocation::Slice slice_ab =
      assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
  BufferAllocation::Slice slice_bc =
      assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
  EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
  EXPECT_NE(slice_ab, slice_bc);
  EXPECT_EQ(32, slice_ab.size());
  EXPECT_EQ(48, slice_bc.size());
  EXPECT_EQ(80, slice_ab.allocation()->size());
  EXPECT_EQ(80, slice_bc.allocation()->size());

  // Re-run buffer assignment with alignment=64.
  assignment = RunBufferAssignment(module.get(), /*alignment=*/64);
  EXPECT_EQ(5, assignment->Allocations().size());
  slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie();
  slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie();
  EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
  EXPECT_NE(slice_ab, slice_bc);
  EXPECT_EQ(32, slice_ab.size());
  EXPECT_EQ(48, slice_bc.size());
  // Ensure the offsets and allocation size account for the alignment, without
  // assuming which buffer gets assigned first.
  if (slice_ab.offset() != 0) {
    // dot_bc comes first; dot_ab starts at the next 64-byte boundary.
    EXPECT_EQ(64, slice_ab.offset());
    EXPECT_EQ(0, slice_bc.offset());
    EXPECT_EQ(64 + 32, slice_ab.allocation()->size());
    EXPECT_EQ(64 + 32, slice_bc.allocation()->size());
  } else {
    // dot_ab comes first; dot_bc starts at the next 64-byte boundary.
    EXPECT_EQ(64, slice_bc.offset());
    EXPECT_EQ(64 + 48, slice_ab.allocation()->size());
    EXPECT_EQ(64 + 48, slice_bc.allocation()->size());
  }
}
1859
TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
  // Checks the peak-memory logical buffers of the allocation holding mul in
  // this graph:
  //
  // paramscalar -(bc)- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  HloComputation::Builder builder(TestName());
  HloInstruction* paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // The allocation that holds mul must report exactly one peak buffer, and
  // that buffer must be the one defined by sub.
  const BufferAllocation& mul_allocation =
      GetTopLevelAllocation(*buffers, mul);
  const std::vector<const HloValue*>& peak_values =
      mul_allocation.PeakMemoryLogicalBuffers();
  ASSERT_EQ(peak_values.size(), 1);
  EXPECT_EQ(peak_values[0]->instruction(), sub);
}
1892
TEST_F(BufferAssignmentTest, PeakBuffers) {
  // Compute the peak liveness buffers of the following sequence:
  //
  //   %param = ...
  //   %log = log(%param)
  //   %rev = reverse(%log)
  //   %neg = neg(%param)
  //   %concat = concat(%rev, %neg)
  //   ROOT %root = slice(concat)
  //
  // In the temporary block, the set of live buffers at peak memory use should
  // be {%rev, %neg, %concat}. This occurs right at the concat itself.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p"));
  HloInstruction* log = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
  HloInstruction* rev = builder.AddInstruction(
      HloInstruction::CreateReverse(f32vec100_, log, {0}));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
  const Shape concat_shape = ShapeUtil::MakeShape(F32, {200});
  HloInstruction* concat = builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0));
  // Make the root tiny so no interior nodes can share its buffer.
  const Shape root_shape = ShapeUtil::MakeShape(F32, {1});
  HloInstruction* root = builder.AddInstruction(
      HloInstruction::CreateSlice(root_shape, concat, {0}, {1}, {1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Run with the instruction order fixed to the sequence listed above.
  auto buffers = RunBufferAssignmentWithInstructionSequence(
      module.get(), {param, log, rev, neg, concat, root});

  // The temporary buffer should hold the 4 interior instructions.
  const BufferAllocation& temp_allocation =
      GetTopLevelAllocation(*buffers, concat);
  EXPECT_FALSE(temp_allocation.IsInputOrOutput());
  EXPECT_TRUE(temp_allocation.IsPreallocatedTempBuffer());
  ASSERT_EQ(temp_allocation.assigned_buffers().size(), 4);

  const std::vector<const HloValue*>& peak_values =
      temp_allocation.PeakMemoryLogicalBuffers();

  // The peak live set should be concat and its inputs.
  ASSERT_EQ(peak_values.size(), 3);
  std::vector<const HloInstruction*> defining_instructions;
  defining_instructions.reserve(peak_values.size());
  for (const HloValue* value : peak_values) {
    defining_instructions.push_back(value->instruction());
  }
  EXPECT_THAT(defining_instructions, UnorderedElementsAre(rev, neg, concat));
}
1945
TEST_F(BufferAssignmentTest, InPlaceBuffer) {
  // Two chained dynamic-update-slices operate on element 1 of the tuple
  // parameter; both are expected to land in the same allocation as the
  // get-tuple-element that feeds the first of them.
  const char* hlo_text = R"(
HloModule Module

ENTRY main {
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
get-tuple-element.3 = s32[] get-tuple-element(state), index=0
constant.2 = s32[] constant(128)
add.5 = s32[] add(get-tuple-element.3, constant.2)
constant.3 = s32[] constant(0)
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
  HloComputation* entry = m->entry_computation();
  // Operand of the first dynamic-update-slice.
  HloInstruction* dus_input =
      entry->GetInstructionWithName("get-tuple-element.4");
  HloInstruction* first_dus =
      entry->GetInstructionWithName("dynamic-update-slice.5");
  HloInstruction* second_dus =
      entry->GetInstructionWithName("dynamic-update-slice.9");

  auto buffers = RunBufferAssignment(m.get());

  const BufferAllocation& input_alloc =
      GetTopLevelAllocation(*buffers, dus_input);

  // Each dynamic-update-slice must share the allocation of the input.
  const BufferAllocation& first_dus_alloc =
      GetTopLevelAllocation(*buffers, first_dus);
  EXPECT_EQ(input_alloc, first_dus_alloc);
  const BufferAllocation& second_dus_alloc =
      GetTopLevelAllocation(*buffers, second_dus);
  EXPECT_EQ(input_alloc, second_dus_alloc);
}
1985
// Verifies that the allocations holding the two entry constants are constant
// allocations and that no buffer defined by a copy or a conditional has been
// assigned into either of them.
TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) {
  const char* hlo_text = R"(
HloModule Module

True {
  ROOT x.0.1 = f32[] parameter(0)
}

False {
  x.0.0 = f32[] parameter(0)
  ROOT copy.1 = f32[] copy(x.0.0)
}

ENTRY main {
  pred.1.0 = pred[] parameter(0)
  constant.1.1 = f32[] constant(56)
  copy.2 = f32[] copy(constant.1.1)
  constant.1.2 = f32[] constant(12)
  ROOT conditional.1.3 = f32[] conditional(pred.1.0, copy.2, constant.1.2),
      true_computation=True, false_computation=False
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
  HloComputation* const entry = m->entry_computation();
  HloInstruction* const constant_1 =
      entry->GetInstructionWithName("constant.1.1");
  HloInstruction* const constant_2 =
      entry->GetInstructionWithName("constant.1.2");

  auto buffers = RunBufferAssignment(m.get());

  // The same checks apply to both constants, in declaration order.
  for (const HloInstruction* constant : {constant_1, constant_2}) {
    const BufferAllocation& constant_alloc =
        GetTopLevelAllocation(*buffers, constant);
    EXPECT_TRUE(constant_alloc.is_constant());
    for (const auto& buffer_offset_pair : constant_alloc.assigned_buffers()) {
      const HloOpcode defining_opcode =
          buffer_offset_pair.first->instruction()->opcode();
      EXPECT_NE(defining_opcode, HloOpcode::kCopy);
      EXPECT_NE(defining_opcode, HloOpcode::kConditional);
    }
  }
}
2043
2044 class WhileBufferAssignmentTest : public HloTestBase {
2045 protected:
BuildWhileConditionComputation(const string & name)2046 std::unique_ptr<HloComputation> BuildWhileConditionComputation(
2047 const string& name) {
2048 auto builder = HloComputation::Builder(name);
2049 builder.AddInstruction(
2050 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2051 auto zero = builder.AddInstruction(
2052 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
2053 auto ten = builder.AddInstruction(
2054 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
2055 builder.AddInstruction(HloInstruction::CreateCompare(
2056 ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt));
2057 return builder.Build();
2058 }
2059
BuildWhileBodyComputation(const string & name)2060 std::unique_ptr<HloComputation> BuildWhileBodyComputation(
2061 const string& name) {
2062 auto builder = HloComputation::Builder(name);
2063 auto loop_state = builder.AddInstruction(
2064 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2065 auto input = builder.AddInstruction(
2066 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
2067 auto weights = builder.AddInstruction(
2068 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
2069 auto output = builder.AddInstruction(HloInstruction::CreateBinary(
2070 data_shape_, HloOpcode::kMultiply, input, weights));
2071 builder.AddInstruction(
2072 HloInstruction::CreateTuple({input, weights, output}));
2073 return builder.Build();
2074 }
2075
RunBufferAssignment(HloModule * module,int64_t alignment=1)2076 std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
2077 int64_t alignment = 1) {
2078 HloSchedule schedule =
2079 ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie();
2080 return BufferAssigner::Run(
2081 module, absl::make_unique<SequentialHloOrdering>(schedule),
2082 ByteSizeOf,
2083 [alignment](LogicalBuffer::Color) { return alignment; },
2084 /*allocate_buffers_for_constants=*/true)
2085 .ConsumeValueOrDie();
2086 }
2087
ByteSizeOf(const BufferValue & buffer)2088 static int64 ByteSizeOf(const BufferValue& buffer) {
2089 return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
2090 }
2091
2092 Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
2093 Shape loop_state_shape_ =
2094 ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
2095 };
2096
RunCopyInsertion(HloModule * module)2097 static void RunCopyInsertion(HloModule* module) {
2098 CopyInsertion copy_insertion;
2099 EXPECT_IS_OK(copy_insertion.Run(module).status());
2100 }
2101
// Chains two while loops, the second consuming element 2 of the first, and
// checks that the loop-state elements the bodies only read alias the buffers
// that feed them.
TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder("entry");

  auto input0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));
  auto weights1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, data_shape_, "weights1"));

  // Both output slots start as a broadcast of the same zero constant.
  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));
  auto output1 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  // First loop: (input0, weights0, output0).
  auto cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body0 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));

  // Second loop: (while0{2}, weights1, output1).
  auto cond1 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body1 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  auto input1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({input1, weights1, output1}));
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));

  module->AddEntryComputation(builder.Build());
  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // Fetches the unique slice for `instruction` at `index`, dying if there is
  // none.
  const auto slice_of = [&assignment](const HloInstruction* instruction,
                                      const ShapeIndex& index) {
    return assignment->GetUniqueSlice(instruction, index).ConsumeValueOrDie();
  };

  // 'input0' and its read-only use while0{0} alias.
  EXPECT_EQ(slice_of(input0, {}), slice_of(while0, {0}));
  // 'weights0' and its read-only use while0{1} alias.
  EXPECT_EQ(slice_of(weights0, {}), slice_of(while0, {1}));
  // 'while0{2}' and its read-only use while1{0} alias.
  EXPECT_EQ(slice_of(while0, {2}), slice_of(while1, {0}));
  // 'weights1' and its read-only use while1{1} alias.
  EXPECT_EQ(slice_of(weights1, {}), slice_of(while1, {1}));
}
2158
2159 // Tests that two colocated buffer sets are not merged if an entry parameter
2160 // buffer belongs to either of the colocation sets (b/73267882).
2161 //
2162 // %param --> %while.0 --> %mul --> %while.1 --> %broadcast
2163 //
2164 // %while.0 body just forwards the init value, so the loop carried variable
2165 // remains the constant, whereas %while.1 changes the loop carried variable.
TEST_F(WhileBufferAssignmentTest,ColocatedBufferWithEntryParameter)2166 TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithEntryParameter) {
2167 const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2168
2169 const char* module_str = R"(
2170 HloModule test_module
2171
2172 %cond.v0 {
2173 %param = s32[] parameter(0)
2174 ROOT %constant = pred[] constant(true)
2175 }
2176
2177 %cond.v1 {
2178 %param.0 = s32[] parameter(0)
2179 ROOT %constant.0 = pred[] constant(true)
2180 }
2181
2182 %body.v0 {
2183 ROOT %param.1 = s32[] parameter(0)
2184 }
2185
2186 %body.v1 {
2187 %param.2 = s32[] parameter(0)
2188 ROOT add = s32[] add(%param.2, %param.2)
2189 }
2190
2191 ENTRY %test_module {
2192 %param.3 = s32[] parameter(0)
2193 %while.0 = s32[] while(%param.3), condition=%cond.v0, body=%body.v0
2194 %mul = s32[] multiply(%while.0, %while.0)
2195 %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
2196 ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
2197 })";
2198
2199 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
2200
2201 // Run CopyInsertion and check if the graph constructed above doesn't need
2202 // any copies inserted for BufferAssignment to run.
2203 int64_t instruction_count = m->instruction_count();
2204 CopyInsertion copy_insertion;
2205 ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
2206 ASSERT_EQ(instruction_count, m->instruction_count());
2207
2208 // Get the instructions in the module.
2209 const HloInstruction* bcast = m->entry_computation()->root_instruction();
2210 const HloInstruction* param =
2211 m->entry_computation()->parameter_instruction(0);
2212 ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
2213 const HloInstruction* while1 = bcast->operand(0);
2214 ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
2215 const HloInstruction* while0 = while1->operand(0)->operand(0);
2216 ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
2217
2218 // Run buffer assignment.
2219 auto assignment = RunBufferAssignment(m.get());
2220 TF_ASSERT_OK_AND_ASSIGN(auto slice_param,
2221 assignment->GetUniqueSlice(param, {}));
2222 TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2223 assignment->GetUniqueSlice(while0, {}));
2224 TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2225 assignment->GetUniqueSlice(while1, {}));
2226
2227 // The parameter slice is part of the while0's colocation set (init value),
2228 // but not merged into the while1's colocation set.
2229 EXPECT_EQ(slice_param, slice_while0);
2230 EXPECT_NE(slice_param, slice_while1);
2231 }
2232
// Tests that a constant used as the init value of a forwarding while loop
// shares that loop's slice, but is not merged into the colocation set of a
// later while loop that modifies its loop-carried variable.
//
// %constant.42 --> %while.0 --> %mul --> %while.1 --> %broadcast
TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) {
  const char* module_str = R"(
HloModule test_module

%cond.v0 {
  %param = s32[] parameter(0)
  ROOT %constant = pred[] constant(true)
}

%cond.v1 {
  %param.0 = s32[] parameter(0)
  ROOT %constant.0 = pred[] constant(true)
}

%body.v0 {
  ROOT %param.1 = s32[] parameter(0)
}

%body.v1 {
  %param.2 = s32[] parameter(0)
  ROOT add = s32[] add(%param.2, %param.2)
}

ENTRY %test_module {
  %constant.42 = s32[] constant(42)
  %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0
  %mul = s32[] multiply(%while.0, %while.0)
  %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  // Run CopyInsertion and check that the graph above does not need any copies
  // inserted for BufferAssignment to run: the instruction count must not
  // change.
  int64_t instruction_count = m->instruction_count();
  CopyInsertion copy_insertion;
  ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
  ASSERT_EQ(instruction_count, m->instruction_count());

  // Walk back from the root to locate both while instructions, asserting the
  // opcode at each step so a malformed graph stops the test early.
  const HloInstruction* bcast = m->entry_computation()->root_instruction();
  const HloInstruction* constant =
      m->entry_computation()->GetInstructionWithName("constant.42");
  ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
  const HloInstruction* while1 = bcast->operand(0);
  ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
  // while1's operand is %mul, whose operand is %while.0.
  const HloInstruction* while0 = while1->operand(0)->operand(0);
  ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);

  // Run buffer assignment.
  auto assignment = RunBufferAssignment(m.get());
  TF_ASSERT_OK_AND_ASSIGN(auto slice_constant,
                          assignment->GetUniqueSlice(constant, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
                          assignment->GetUniqueSlice(while0, {}));
  TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
                          assignment->GetUniqueSlice(while1, {}));

  // The constant slice is part of while0's colocation set (init value), but
  // must not be merged into while1's colocation set.
  EXPECT_EQ(slice_constant, slice_while0);
  EXPECT_NE(slice_constant, slice_while1);
}
2299
2300 // Tests that the colocated buffers for while instructions are properly assigned
2301 // during buffer assignment such that the result tuple elements are not assigned
2302 // to the same buffer.
2303 //
2304 // %infeed --> %while.0 --> %while.1 --+
2305 // +-- %tuple
2306 // %zero --> %add --> %while.2 --+
2307 //
2308 // Execution Order:
2309 // %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple
2310 //
2311 // The HLO computation used in this test requires specific ordering to expose
2312 // the bug (b/72496031). During buffer assignment, the visitation order of
2313 // colocated buffers is %while.2 -> while.0 -> while.1, and the buffer
2314 // assignment was coalescing the colocated buffers for all 3 while instructions,
2315 // therefore assigning the same buffer to the two result tuple elements.
TEST_F(WhileBufferAssignmentTest,ColocatedBuffers)2316 TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
2317 const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2318
2319 // Builds a condition computation: x -> x < 4
2320 auto build_cond = [&]() {
2321 auto builder = HloComputation::Builder("cond");
2322 auto const4 = builder.AddInstruction(
2323 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
2324 auto param =
2325 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
2326 builder.AddInstruction(
2327 HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
2328 const4, ComparisonDirection::kLt));
2329 return builder.Build();
2330 };
2331
2332 // Builds a body computation: x -> x + 9
2333 auto build_body = [&]() {
2334 auto builder = HloComputation::Builder("body");
2335 auto const9 = builder.AddInstruction(
2336 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(9)));
2337 auto param =
2338 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
2339 builder.AddInstruction(
2340 HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9));
2341 return builder.Build();
2342 };
2343
2344 // Build the entry computation as described in the comment above.
2345 auto module = CreateNewVerifiedModule();
2346 auto builder = HloComputation::Builder("entry");
2347
2348 auto token = builder.AddInstruction(HloInstruction::CreateToken());
2349 auto infeed =
2350 builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, ""));
2351 auto infeed_data = builder.AddInstruction(
2352 HloInstruction::CreateGetTupleElement(r0s32, infeed, 0));
2353 auto cond0 = module->AddEmbeddedComputation(build_cond());
2354 auto body0 = module->AddEmbeddedComputation(build_body());
2355 auto while0 = builder.AddInstruction(
2356 HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data));
2357
2358 auto cond1 = module->AddEmbeddedComputation(build_cond());
2359 auto body1 = module->AddEmbeddedComputation(build_body());
2360 auto while1 = builder.AddInstruction(
2361 HloInstruction::CreateWhile(r0s32, cond1, body1, while0));
2362
2363 auto zero = builder.AddInstruction(
2364 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
2365 auto add = builder.AddInstruction(
2366 HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
2367 auto cond2 = module->AddEmbeddedComputation(build_cond());
2368 auto body2 = module->AddEmbeddedComputation(build_body());
2369 auto while2 = builder.AddInstruction(
2370 HloInstruction::CreateWhile(r0s32, cond2, body2, add));
2371
2372 auto tuple =
2373 builder.AddInstruction(HloInstruction::CreateTuple({while2, while1}));
2374 module->AddEntryComputation(builder.Build());
2375
2376 // Run CopyInsertion and check if the graph constructed above doesn't need
2377 // any copies inserted for BufferAssignment to run.
2378 int64_t instruction_count = module->instruction_count();
2379 CopyInsertion copy_insertion;
2380 ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
2381 ASSERT_EQ(instruction_count, module->instruction_count());
2382
2383 // Create a sequential order among all the instructions in the entry
2384 // computation, since the issue this test stresses depends on the order the
2385 // nodes are traversed during BufferAssignment.
2386 TF_ASSERT_OK_AND_ASSIGN(
2387 HloSchedule schedule,
2388 ScheduleModule(module.get(), [](const BufferValue& buffer) {
2389 return ShapeUtil::ByteSizeOf(buffer.shape(),
2390 /*pointer_size=*/sizeof(void*));
2391 }));
2392 schedule.set_sequence(
2393 module->entry_computation(),
2394 {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
2395 TF_ASSERT_OK(schedule.Verify());
2396
2397 TF_ASSERT_OK_AND_ASSIGN(
2398 auto assignment,
2399 BufferAssigner::Run(
2400 module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
2401 backend().compiler()->BufferSizeBytesFunction(),
2402 [](LogicalBuffer::Color) { return 1; },
2403 /*allocate_buffers_for_constants=*/true));
2404
2405 // The result tuple elements must be assigned with different buffers.
2406 TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
2407 TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1}));
2408 EXPECT_NE(slice0, slice1);
2409
2410 // while0 and while1 result buffers must be equal to slice1.
2411 TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2412 assignment->GetUniqueSlice(while0, {}));
2413 TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2414 assignment->GetUniqueSlice(while1, {}));
2415 EXPECT_EQ(slice1, slice_while0);
2416 EXPECT_EQ(slice1, slice_while1);
2417
2418 // while2 result buffer must be equal to slice0.
2419 TF_ASSERT_OK_AND_ASSIGN(auto slice_while2,
2420 assignment->GetUniqueSlice(while2, {}));
2421 EXPECT_EQ(slice0, slice_while2);
2422 }
2423
// Feeds the whole result of one while loop straight into a second one and
// checks that every loop-state element of the two loops gets the same slice.
TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder("entry");

  auto input0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  // First loop: (input0, weights0, output0).
  auto cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body0 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));

  // Second loop takes the first loop's entire state as its init value.
  auto cond1 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body1 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));

  module->AddEntryComputation(builder.Build());
  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // Fetches the unique slice for `instruction` at `index`, dying if there is
  // none.
  const auto slice_of = [&assignment](const HloInstruction* instruction,
                                      const ShapeIndex& index) {
    return assignment->GetUniqueSlice(instruction, index).ConsumeValueOrDie();
  };

  // while0 and while1 buffers should be completely aligned.
  EXPECT_EQ(slice_of(while0, {0}), slice_of(while1, {0}));
  EXPECT_EQ(slice_of(while0, {1}), slice_of(while1, {1}));
  EXPECT_EQ(slice_of(while0, {2}), slice_of(while1, {2}));
}
2468
// Calls the same sub-computation twice from the entry computation, flattens
// the call graph, and checks that the two call results get distinct buffers.
TEST_F(BufferAssignmentTest, TwoCalls) {
  auto module = CreateNewVerifiedModule();
  Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});

  // Sub-computation: param -> param + 1.0.
  HloComputation::Builder sub_builder(TestName() + "_sub_comp");
  auto sub_param = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param"));
  auto sub_one = sub_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto sub_add = sub_builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32, HloOpcode::kAdd, sub_param, sub_one));
  HloComputation* sub_computation =
      module->AddEmbeddedComputation(sub_builder.Build(sub_add));

  // Entry computation: call(3.0) + (call(2.0) + 2.0).
  HloComputation::Builder builder(TestName());
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  auto call1 = builder.AddInstruction(
      HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
  auto call2 = builder.AddInstruction(
      HloInstruction::CreateCall(r0f32, {constant3}, sub_computation));
  auto add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2));
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1));
  module->AddEntryComputation(builder.Build(add2));

  {
    // Flattening is expected to report that it changed the module.
    FlattenCallGraph flatten;
    TF_ASSERT_OK_AND_ASSIGN(bool changed, flatten.Run(module.get()));
    EXPECT_TRUE(changed);
    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
  }

  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
}
2510
// Checks that the parameters of a called computation are given the same
// allocations as the call operands that feed them, including every level of a
// nested tuple operand.
TEST_F(BufferAssignmentTest, CallParamCoAllocation) {
  const char* hlo_text = R"(
HloModule CallParamCoAllocation

Callee {
  param0 = (f32[100],(f32[200],f32[300])) parameter(0)
  param1 = s32[20] parameter(1)
  ROOT constant = f32[] constant(1)
}

ENTRY Main {
  entry_param0 = f32[100] parameter(0)
  entry_param1 = s32[20] parameter(1)
  custom_call = (f32[200],f32[300]) custom-call(), custom_call_target="call-target"
  call_op0 = (f32[100],(f32[200],f32[300])) tuple(entry_param0, custom_call)
  ROOT call_result = f32[] call(call_op0, entry_param1), to_apply=Callee
}
)";

  HloModuleConfig config;
  config.set_debug_options(GetDebugOptionsFromFlags());
  TF_ASSERT_OK_AND_ASSIGN(auto m,
                          ParseAndReturnVerifiedModule(hlo_text, config));

  auto buffers = RunBufferAssignment(m.get());

  HloComputation* main = m->entry_computation();
  HloComputation* callee = m->GetComputationWithName("Callee");
  // Fatal check: `callee` is dereferenced immediately below, so a failed
  // lookup must stop the test instead of crashing on a null pointer.
  ASSERT_NE(callee, nullptr);

  HloInstruction* param0 = callee->parameter_instruction(0);
  HloInstruction* param1 = callee->parameter_instruction(1);

  HloInstruction* entry_param0 = main->parameter_instruction(0);
  HloInstruction* entry_param1 = main->parameter_instruction(1);
  HloInstruction* custom_call = main->GetInstructionWithName("custom_call");
  // The name lookup result is handed to GetAllocation below; stop here if the
  // instruction was not found.
  ASSERT_NE(custom_call, nullptr);

  // Array operands: entry parameters line up with the callee parameters
  // (element 0 of the tuple parameter, and the second parameter).
  EXPECT_EQ(GetAllocation(*buffers, entry_param0, {}),
            GetAllocation(*buffers, param0, {0}));
  EXPECT_EQ(GetAllocation(*buffers, entry_param1, {}),
            GetAllocation(*buffers, param1, {}));

  // Nested tuple operand: the custom-call result and both of its elements
  // line up with element 1 of the callee's tuple parameter.
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {}),
            GetAllocation(*buffers, param0, {1}));
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {0}),
            GetAllocation(*buffers, param0, {1, 0}));
  EXPECT_EQ(GetAllocation(*buffers, custom_call, {1}),
            GetAllocation(*buffers, param0, {1, 1}));
}
2560
// Runs buffer assignment over a fixed instruction sequence and compares the
// CSV text returned by BufferInfoString() against a reference table.
TEST_F(BufferAssignmentTest, BufferInfoStringTest) {
  absl::string_view module_str = R"(
HloModule test_module

ENTRY %test_module {
  %param.0 = s32[1024]{0} parameter(0)
  %param.1 = s32[1024]{0} parameter(1)
  %mul = s32[1024]{0} multiply(%param.0, %param.1)
  %add = s32[1024]{0} add(%mul, %param.0)
  ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[1024] %add), dimensions={0}
})";

  // Expected output: a header row followed by one row per buffer. The rows
  // must start in column 0 because the comparison below is exact.
  absl::string_view reference_str =
      R"(buffer_id,buffer_name,offset,size,definition_time,end_time,num_uses,use_times,use_names
0,"<0 param.0 @0>",0,4096,0,5,2,"2;3","mul, operand 0;add, operand 1"
1,"<1 param.1 @0>",0,4096,1,5,1,"2","mul, operand 1"
2,"<2 mul @0>",0,4096,2,3,1,"3","add, operand 0"
3,"<3 add @0>",0,4096,3,4,1,"4","bcast, operand 0"
4,"<4 bcast @0>",0,4194304,4,5,0,"",""
)";

  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
  HloModule* const hlo_module = m.get();
  HloInstruction* const param0 = FindInstruction(hlo_module, "param.0");
  HloInstruction* const param1 = FindInstruction(hlo_module, "param.1");
  HloInstruction* const mul = FindInstruction(hlo_module, "mul");
  HloInstruction* const add = FindInstruction(hlo_module, "add");
  HloInstruction* const bcast = FindInstruction(hlo_module, "bcast");

  // Run buffer assignment with the instructions in program order.
  auto assignment = RunBufferAssignmentWithInstructionSequence(
      hlo_module, {param0, param1, mul, add, bcast});

  const std::string actual_str = assignment->BufferInfoString();
  EXPECT_EQ(actual_str, reference_str);
}
2595
// Builds two independent while loops that use the same condition/body pair,
// schedules while1's chain ahead of while0's, and checks that the two loops
// are still assigned distinct buffers (b/38494731).
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  // Loop-state inputs for while0.
  auto input0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  // Loop-state inputs for while1.
  auto input1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, data_shape_, "input1"));
  auto weights1 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, data_shape_, "weights1"));
  auto output1 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, one, {}));

  auto cond =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));

  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({input1, weights1, output1}));

  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0));
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));

  // Root: while0{0} + while1{1}.
  auto gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
  auto gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
  auto root_add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1));

  module->AddEntryComputation(builder.Build());

  {
    // Flattening is expected to report that it changed the module.
    FlattenCallGraph flatten;
    TF_ASSERT_OK_AND_ASSIGN(bool changed, flatten.Run(module.get()));
    EXPECT_TRUE(changed);
  }

  RunCopyInsertion(module.get());

  HloSchedule schedule =
      ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie();

  // To trigger b/38494731 the root computation needs a specific schedule, so
  // its entry is overwritten with a hand-written sequence. The init operands
  // are read from the while instructions themselves at this point, after
  // FlattenCallGraph and CopyInsertion have run.
  HloInstruction* const while1_init = while1->mutable_operand(0);
  HloInstruction* const while0_init = while0->mutable_operand(0);
  schedule.set_sequence(
      module->entry_computation(),
      {input1, weights1, one, output1, while1_init, while1, input0, weights0,
       zero, output0, while0_init, while0, gte0, gte1, root_add});

  // A failure here means the sequence above is bogus and this test itself is
  // buggy.
  TF_ASSERT_OK(schedule.Verify());

  auto ordering = absl::make_unique<SequentialHloOrdering>(schedule);
  auto assignment =
      BufferAssigner::Run(module.get(), std::move(ordering), ByteSizeOf,
                          [](LogicalBuffer::Color) { return 1; },
                          /*allocate_buffers_for_constants=*/true)
          .ConsumeValueOrDie();

  EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}
2675
// Chains two while loops that both read the entry parameter 'weights0' and
// makes the second loop's output the computation root. The root's allocation
// must be marked live-out and must not be an entry-parameter allocation.
TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder("entry");

  auto input0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape_, "input0"));
  auto weights0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "weights0"));

  // Both output slots start as a broadcast of the same zero constant.
  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  auto output0 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));
  auto output1 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape_, zero, {}));

  // First loop: (input0, weights0, output0).
  auto cond0 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body0 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({input0, weights0, output0}));
  auto while0 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));

  // The output of 'while0' becomes the input of 'while1'.
  auto while0_out = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));

  // Second loop: (while0{2}, weights0, output1).
  auto cond1 =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
  auto body1 =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({while0_out, weights0, output1}));
  auto while1 = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));

  // Reading the output of 'while1' last makes it live out of the computation.
  auto while1_out = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));

  module->AddEntryComputation(builder.Build());
  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // Allocation backing the root instruction.
  const BufferAllocation::Slice root_slice =
      assignment->GetUniqueTopLevelSlice(while1_out).ConsumeValueOrDie();
  const BufferAllocation* root_alloc = root_slice.allocation();

  // It must be live out ...
  EXPECT_TRUE(root_alloc->maybe_live_out());
  // ... and must not be an entry computation parameter.
  EXPECT_FALSE(root_alloc->is_entry_computation_parameter());
}
2732
// Checks that two chained dynamic-update-slice instructions in a while body
// (the second consumes the result of the first) end up in the same buffer
// allocation slice after copy insertion and buffer assignment.
TEST_F(WhileBufferAssignmentTest, WhileWithDynamicUpdateSliceShare) {
  const char* const module_text = R"(
HloModule test

while_body {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  constant.1 = f32[] constant(0)
  broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
  get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
  get-tuple-element.3 = s32[] get-tuple-element(state), index=0
  constant.2 = s32[] constant(128)
  add.5 = s32[] add(get-tuple-element.3, constant.2)
  constant.3 = s32[] constant(0)
  dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
  dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
  ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}

while_condition {
  state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
  get-tuple-element = s32[] get-tuple-element(state), index=0
  get-tuple-element.1 = s32[] constant(3)
  ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
}

ENTRY entry_computation {
  constant.7 = s32[] constant(0)
  copy.1 = s32[] copy(constant.7)
  constant.6 = f32[] constant(0)
  broadcast.6 = f32[1280,1,128]{2,1,0} broadcast(constant.6), dimensions={}
  tuple.1 = (s32[], f32[1280,1,128]{2,1,0}) tuple(copy.1, broadcast.6)
  while.0 = (s32[], f32[1280,1,128]{2,1,0}) while(tuple.1), condition=while_condition, body=while_body
  ROOT get-tuple-element.2 = s32[] get-tuple-element(while.0), index=0
}

)";
  // Parse and verify the module in one step; a parse failure aborts here.
  auto module = ParseAndReturnVerifiedModule(module_text).ConsumeValueOrDie();

  RunCopyInsertion(module.get());
  auto assignment = RunBufferAssignment(module.get());

  // Looks up the instruction called `name` in the module and returns the
  // unique top-level allocation slice assigned to it.
  auto top_level_slice_of = [&](const char* name) {
    auto* instruction = FindInstruction(module.get(), name);
    return assignment->GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie();
  };

  const auto second_dus_slice = top_level_slice_of("dynamic-update-slice.9");
  const auto first_dus_slice = top_level_slice_of("dynamic-update-slice.5");

  // Both dynamic-update-slice ops must live in the same allocation...
  EXPECT_EQ(second_dus_slice.allocation(), first_dus_slice.allocation());
  // ...and occupy exactly the same slice of it.
  EXPECT_EQ(second_dus_slice, first_dus_slice);
}
2785 } // namespace
2786 } // namespace xla
2787