1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/heap_simulator.h"
17
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "tensorflow/compiler/xla/literal.h"
24 #include "tensorflow/compiler/xla/service/buffer_value.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
29 #include "tensorflow/compiler/xla/service/hlo_value.h"
30 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace xla {
37 namespace {
38
// Test fixture for HeapSimulator::MinimumMemoryFor{Module,Computation} tests.
class MinimumMemoryForSequenceTest : public HloTestBase {};
40
// Builds a module with a while loop and checks that MinimumMemoryForModule
// accounts for the entry sequence plus the larger of the cond/body footprints.
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});

  auto cond_builder = HloComputation::Builder("WhileCond");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_data, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  // Entry params: 8 bytes (4 bytes per param), TOTAL=8
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
  // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
  // While: 8 bytes (4 bytes per element), TOTAL=32
  // Both cond and body use a max of 24 bytes, TOTAL=56
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  // Buffer sizes include an 8-byte pointer per tuple element.
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_data, cond_lt});
  schedule.set_sequence(body_computation, {body_param});
  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
  TF_ASSERT_OK(schedule.Verify());

  EXPECT_EQ(
      25,
      HeapSimulator::MinimumMemoryForModule(schedule, size_fn).value());
}
100
// Checks that MinimumMemoryForComputation adds the memory used by while
// subcomputations (via memory_by_computation) without double-counting the
// aliased while output buffer.
TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
  // HloModule SubcomputationAccounting
  //
  // %WhileBody (body_param: f32[4]) -> f32[4] {
  //   %body_param = f32[4]{0} parameter(0)
  //   %constant.1 = f32[4]{0} constant({1, 1, 1, 1})
  //   ROOT %subtract = f32[4]{0} subtract(%body_param, %constant.1)
  // }
  //
  // %WhileCond (cond_param: f32[4]) -> pred[] {
  //   %cond_param = f32[4]{0} parameter(0)
  //   %slice = f32[1]{0} slice(%cond_param), slice={[0:1]}
  //   %reshape = f32[] reshape(%slice)
  //   %constant = f32[] constant(0)
  //   ROOT %not-equal-to = pred[] compare(%reshape, %constant), direction=NE
  // }
  //
  // ENTRY %SubcomputationAccounting () -> f32[2,4] {
  //   %constant.3 = f32[2,4]{1,0} constant({{1, 2, 3, 4}, {1, 2, 3, 4}})
  //   %transpose = f32[2,4]{1,0} transpose(%constant.3), dimensions={0,1}
  //   %constant.2 = f32[4]{0} constant({1, 1, 1, 1})
  //   %while = f32[4]{0} while(%constant.2), condition=%WhileCond,
  //       body=%WhileBody
  //   %broadcast = f32[2,4]{1,0} broadcast(%while), dimensions={1}
  //   ROOT %add = f32[2,4]{1,0} add(%transpose, %broadcast)
  // }

  auto module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});

  // Condition: reshape(slice(param)) != 0. Needs 5 bytes.
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
  HloInstruction* slice =
      cond_builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1}));
  HloInstruction* reshape =
      cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
  HloInstruction* zero = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  HloInstruction* cond_comparison = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape,
                                    zero, ComparisonDirection::kNe));
  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());

  // Body: param - 1. Needs 16 bytes.
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "body_param"));
  HloInstruction* one_vector =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  HloInstruction* subtract =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          r1f32, HloOpcode::kSubtract, body_param, one_vector));
  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());

  // Entry: transpose(matrix) + bcast(while).
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* while_init =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  // Creates 16 bytes, ignoring subcomputations.
  HloInstruction* while_loop =
      builder.AddInstruction(HloInstruction::CreateWhile(
          r1f32, cond_computation, body_computation, while_init));

  // Creates 32 bytes and frees 16.
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, while_loop, {1}));

  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
          {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
  // Creates 32 bytes.
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));

  // Creates 32 bytes and frees 64.
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));

  auto entry_computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, slice, reshape, zero, cond_comparison});
  schedule.set_sequence(body_computation, {body_param, one_vector, subtract});
  schedule.set_sequence(
      entry_computation, {while_init, while_loop, bcast, matrix, transpose, add});

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  // Pretend the subcomputations need the sizes computed above.
  absl::flat_hash_map<const HloComputation*, int64_t> memory_by_computation;
  memory_by_computation[cond_computation] = 5;
  memory_by_computation[body_computation] = 16;

  std::unique_ptr<HloAliasAnalysis> alias_analysis =
      HloAliasAnalysis::Run(module.get()).value();

  // HeapSimulator accounts for subcomputations. The output buffer is aliased,
  // so we don't double count.
  EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
                    *entry_computation, schedule.sequence(entry_computation),
                    *alias_analysis, size_fn, &memory_by_computation)
                    .value());
}
220
221 const char kAlloc[] = "Alloc";
222 const char kFree[] = "Free";
223 const char kShare[] = "Share";
224 const char kFinish[] = "Finish";
225
226 // CallSequence records a sequence of Alloc/Free/Finish calls.
227 using CallSequence = std::vector<std::pair<std::string, const HloValue*>>;
228
229 // HeapCallRecorder is a dummy heap algorithm that simply records its calls.
230 class HeapCallRecorder : public HeapAlgorithm<HloValue> {
231 public:
HeapCallRecorder(CallSequence * calls)232 explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
~HeapCallRecorder()233 ~HeapCallRecorder() override {}
234
Alloc(const HloValue * buffer,int64_t size)235 void Alloc(const HloValue* buffer, int64_t size) override {
236 calls_->emplace_back(kAlloc, buffer);
237 // Instead of assigning a real offset, we set the cardinality of the Alloc
238 // call. This isn't a valid assignment, but allows us to easily test for
239 // buffer sharing.
240 const int64_t offset = result_.chunk_map.size();
241 result_.chunk_map.emplace(buffer, Chunk{offset, size});
242 }
243
ShareWith(const HloValue * buffer,const HloValue * shared,int64_t size)244 void ShareWith(const HloValue* buffer, const HloValue* shared,
245 int64_t size) override {
246 calls_->emplace_back(kShare, buffer);
247 // Instead of assigning a real offset, we set the cardinality of the Alloc
248 // call. This isn't a valid assignment, but allows us to easily test for
249 // buffer sharing.
250 const int64_t offset = result_.chunk_map[shared].offset;
251 result_.chunk_map.emplace(buffer, Chunk{offset, size});
252 }
Free(const HloValue * buffer,int64_t size)253 void Free(const HloValue* buffer, int64_t size) override {
254 calls_->emplace_back(kFree, buffer);
255 }
Finish()256 Result Finish() override {
257 calls_->emplace_back(kFinish, nullptr);
258 HeapSimulator::Result<HloValue> result;
259 result.heap_size = result_.heap_size;
260 result.heap_results.emplace_back(std::move(result_));
261 return result;
262 }
263
264 private:
265 CallSequence* calls_;
266 HeapSimulator::HeapResult<HloValue> result_;
267 };
268
269 // HeapSimulatorTracker runs the heap simulator, recording the sequence of calls
270 // made to the underlying heap algorithm. Tests compare the actual call
271 // sequence against an expected sequence.
272 class HeapSimulatorTracker {
273 public:
HeapSimulatorTracker(std::unique_ptr<HloModule> module,const std::vector<HloInstruction * > & instruction_sequence,const std::vector<HloInstruction * > & must_alias_set={},const HloDataflowAnalysis::CanShareBuffer & can_share_buffer=nullptr)274 explicit HeapSimulatorTracker(
275 std::unique_ptr<HloModule> module,
276 const std::vector<HloInstruction*>& instruction_sequence,
277 const std::vector<HloInstruction*>& must_alias_set = {},
278 const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
279 module_ = std::move(module);
280 Init(instruction_sequence, can_share_buffer);
281 }
282
283 // Constructor for testing a single entry computation.
HeapSimulatorTracker(const std::string & name,std::unique_ptr<HloComputation> entry_computation,const std::vector<HloInstruction * > & instruction_sequence,const std::vector<HloInstruction * > & must_alias_set={},const HloDataflowAnalysis::CanShareBuffer & can_share_buffer=nullptr)284 explicit HeapSimulatorTracker(
285 const std::string& name,
286 std::unique_ptr<HloComputation> entry_computation,
287 const std::vector<HloInstruction*>& instruction_sequence,
288 const std::vector<HloInstruction*>& must_alias_set = {},
289 const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
290 HloModuleConfig config;
291 module_ = std::make_unique<HloModule>(name, config);
292 module_->AddEntryComputation(std::move(entry_computation));
293 Init(instruction_sequence, can_share_buffer);
294 }
295
HeapSimulatorTracker(const std::string & name)296 explicit HeapSimulatorTracker(const std::string& name) {
297 HloModuleConfig config;
298 module_ = std::make_unique<HloModule>(name, config);
299 }
300
301 // Similar to the single entry computation constructor above, but runs the
302 // simulation over the entire module.
RunWholeModule(const std::vector<HloInstruction * > & full_module_sequence)303 void RunWholeModule(
304 const std::vector<HloInstruction*>& full_module_sequence) {
305 alias_analysis_ = HloAliasAnalysis::Run(module_.get()).value();
306
307 // Construct the module sequence grouped by computation.
308 HloSchedule schedule(module_.get());
309 absl::flat_hash_map<const HloInstruction*, int> reverse_position;
310 for (int i = 0; i < full_module_sequence.size(); ++i) {
311 HloInstruction* instruction = full_module_sequence[i];
312 schedule.GetOrCreateSequence(instruction->parent())
313 .push_back(instruction);
314 reverse_position[instruction] = full_module_sequence.size() - i;
315 }
316
317 // Hack the size_fn so that it returns a decreasing value as we step through
318 // the sequence. This lets us ensure the Alloc calls are in the sequence
319 // order. The Free calls are sorted by BufferValue.id, which is at least
320 // deterministic.
321 auto size_fn = [&reverse_position](const BufferValue& buffer) {
322 return reverse_position[buffer.instruction()];
323 };
324 auto algorithm = std::make_unique<HeapCallRecorder>(&actual_calls_);
325 result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
326 *alias_analysis_, size_fn)
327 .value();
328 }
329
module()330 HloModule* module() { return module_.get(); }
331
332 // Returns the buffer defined at the given instruction and index.
BufferAt(const HloInstruction * instruction,const ShapeIndex & index) const333 const HloValue* BufferAt(const HloInstruction* instruction,
334 const ShapeIndex& index) const {
335 return &alias_analysis_->dataflow_analysis().GetUniqueValueAt(instruction,
336 index);
337 }
338
OffsetAt(const HloInstruction * instruction,const ShapeIndex & index)339 int64_t OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) {
340 const HloValue* buffer = BufferAt(instruction, index);
341 CHECK_EQ(1, result_.heap_results.size());
342 return result_.heap_results.at(0).chunk_map.at(buffer).offset;
343 }
344
345 // Ensures the expected sequence of Alloc/Free/Finish calls was performed.
ExpectCallSequence(const CallSequence & expected) const346 void ExpectCallSequence(const CallSequence& expected) const {
347 auto to_string = [](const CallSequence& sequence) {
348 std::string output;
349 for (int64_t i = 0; i < sequence.size(); ++i) {
350 auto pair = sequence.at(i);
351 absl::StrAppendFormat(&output, "%d", i);
352 absl::StrAppendFormat(&output, " :%s", pair.first);
353 if (pair.second != nullptr) {
354 absl::StrAppendFormat(&output, " - %s{%s}\n",
355 pair.second->instruction()->name(),
356 pair.second->index().ToString());
357 }
358 }
359 return output;
360 };
361 EXPECT_EQ(expected, actual_calls_) << "Expected:\n"
362 << to_string(expected) << " \nActual:\n"
363 << to_string(actual_calls_) << "\n";
364 }
365
366 // Ensures the buffers defined by the respective (instruction,index) pairs are
367 // shared, relying on the unique offsets assigned in
368 // HeapCallRecorder::Alloc.
ExpectSharedBuffers(const HloInstruction * instruction_a,const ShapeIndex & index_a,const HloInstruction * instruction_b,const ShapeIndex & index_b)369 void ExpectSharedBuffers(const HloInstruction* instruction_a,
370 const ShapeIndex& index_a,
371 const HloInstruction* instruction_b,
372 const ShapeIndex& index_b) {
373 int64_t offset_a = OffsetAt(instruction_a, index_a);
374 int64_t offset_b = OffsetAt(instruction_b, index_b);
375 EXPECT_EQ(offset_a, offset_b);
376 }
377
378 private:
Init(const std::vector<HloInstruction * > & instruction_sequence,const HloDataflowAnalysis::CanShareBuffer & can_share_buffer)379 void Init(const std::vector<HloInstruction*>& instruction_sequence,
380 const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) {
381 // Since we're only tracking the sequence of Alloc/Free calls, the actual
382 // size of the buffers doesn't matter, so we always return 0. We rely on
383 // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls
384 // by buffer id, for determinism in the tests.
385 auto zero_size = [](const BufferValue& buffer) { return 0; };
386 auto algorithm = std::make_unique<HeapCallRecorder>(&actual_calls_);
387
388 alias_analysis_ =
389 HloAliasAnalysis::Run(module_.get(), can_share_buffer).ValueOrDie();
390
391 HeapSimulator::Options options;
392
393 result_ =
394 HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
395 HloInstructionSequence(instruction_sequence),
396 *alias_analysis_, zero_size, options)
397 .value();
398 }
399
400 std::unique_ptr<HloModule> module_;
401 std::unique_ptr<HloAliasAnalysis> alias_analysis_;
402 CallSequence actual_calls_;
403 HeapSimulator::Result<HloValue> result_;
404 };
405
406 class HeapSimulatorTest : public HloTestBase {
407 protected:
HeapSimulatorTest()408 HeapSimulatorTest() {}
~HeapSimulatorTest()409 ~HeapSimulatorTest() override {}
410
411 // Shapes for use in the examples.
412 Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {});
413 Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
414 };
415
TEST_F(HeapSimulatorTest, ScalarConstant) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  // Constant buffers are never assigned by the simulator, so the only
  // recorded call is Finish. See b/32248867.
  HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0});
  tracker.ExpectCallSequence({{kFinish, nullptr}});
}
425
TEST_F(HeapSimulatorTest, OneParam) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "param0"));

  // A single parameter which is also the output: it is allocated up front
  // and freed at the end of the simulation.
  HeapSimulatorTracker tracker(TestName(), builder.Build(), {param0});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(param0, {})},
      {kFree, tracker.BufferAt(param0, {})},
      {kFinish, nullptr},
  });
}
439
TEST_F(HeapSimulatorTest, Multiply) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  HloInstruction* paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));

  // Parameters and the output must be kept alive for the whole run, so all
  // three buffers are allocated first and freed together at the end.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kFinish, nullptr},
  });
}
463
TEST_F(HeapSimulatorTest, MultiplyAdd) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  HloInstruction* paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  HloInstruction* paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));

  // The buffer for add is the output; since add is elementwise and mul dies
  // at add, the two share a buffer (kShare instead of a fresh kAlloc).
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, add});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kShare, tracker.BufferAt(add, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(add, {})},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(add, {}, mul, {});
}
497
TEST_F(HeapSimulatorTest, FusionOutputsOnlyShareOnce) {
  // Test that only one output of a fusion node will be shared with its operand.
  auto can_share_buffer =
      [](const HloInstruction* instr, const HloInstruction* operand,
         const ShapeIndex& user_index) -> std::optional<bool> {
    return instr->opcode() == HloOpcode::kFusion &&
           operand->shape().IsArray() &&
           ShapeUtil::Equal(operand->shape(),
                            ShapeUtil::GetSubshape(instr->shape(), user_index));
  };

  HloModuleConfig config;
  auto module = std::make_unique<HloModule>(TestName(), config);

  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, paramA));

  // The fusion node has two outputs, both are eligible for being reused with
  // operand.
  auto fusion_builder = HloComputation::Builder("simple_two_way_forwarding");
  {
    auto param = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(0, f32vec4_, "x"));
    fusion_builder.AddInstruction(HloInstruction::CreateTuple({param, param}));
  }
  auto fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {negate}, fusion_computation));

  // The GTE shape must match the fusion's tuple element shape (f32[4]); the
  // module is built unverified, so a mismatch here would go undetected.
  auto element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec4_, fusion, 0));

  auto element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec4_, fusion, 1));

  auto negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element0));
  auto negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element1));

  builder.AddInstruction(HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd,
                                                      negate0, negate1));

  module->AddEntryComputation(builder.Build());
  HeapSimulatorTracker tracker(
      std::move(module),
      {paramA, negate, fusion, element0, element1, negate0, negate1}, {},
      can_share_buffer);
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(negate, {})},
      {kAlloc, tracker.BufferAt(fusion, {})},
      {kFree, tracker.BufferAt(negate, {})},
      // Only the first fusion output shares with the operand; the second one
      // gets a fresh allocation.
      {kShare, tracker.BufferAt(fusion, {0})},
      {kAlloc, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(fusion, {})},
      {kAlloc, tracker.BufferAt(negate0, {})},
      {kFree, tracker.BufferAt(fusion, {0})},
      {kFree, tracker.BufferAt(negate0, {})},
      {kAlloc, tracker.BufferAt(negate1, {})},
      {kFree, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(negate1, {})},
      {kFree, tracker.BufferAt(paramA, {})},
      {kFinish, nullptr},
  });
}
570
TEST_F(HeapSimulatorTest, FusionOutputsOnlyShareOnceOutputShortLived) {
  // Test that only one output of a fusion node will be shared with its operand.
  // This variant of the test has a fusion node that dies immediately.
  auto can_share_buffer =
      [](const HloInstruction* instr, const HloInstruction* operand,
         const ShapeIndex& user_index) -> std::optional<bool> {
    // Only fusion nodes may share; `operand` and `user_index` don't matter.
    return instr->opcode() == HloOpcode::kFusion;
  };

  HloModuleConfig config;
  auto module = std::make_unique<HloModule>(TestName(), config);

  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, paramA));

  // The fusion node has two outputs, both are eligible for being reused with
  // operand.
  auto fusion_builder = HloComputation::Builder("simple_two_way_forwarding");
  {
    auto param = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(0, f32vec4_, "x"));
    fusion_builder.AddInstruction(HloInstruction::CreateTuple({param, param}));
  }
  auto fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {negate}, fusion_computation));

  // The GTE shape must match the fusion's tuple element shape (f32[4]); the
  // module is built unverified, so a mismatch here would go undetected.
  auto element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec4_, fusion, 1));

  auto negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element1));

  module->AddEntryComputation(builder.Build());
  HeapSimulatorTracker tracker(std::move(module),
                               {paramA, negate, fusion, element1, negate1}, {},
                               can_share_buffer);
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(negate, {})},
      {kFree, tracker.BufferAt(negate, {})},
      {kShare, tracker.BufferAt(fusion, {0})},
      {kAlloc, tracker.BufferAt(fusion, {})},
      {kAlloc, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(fusion, {0})},
      {kFree, tracker.BufferAt(fusion, {})},
      {kAlloc, tracker.BufferAt(negate1, {})},
      {kFree, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(negate1, {})},
      {kFinish, nullptr},
  });
}
633
TEST_F(HeapSimulatorTest, BufferReusedOnce) {
  HeapSimulatorTracker tracker(TestName());
  auto builder = HloComputation::Builder(TestName());

  // Fusion body: tuple(exp(A), neg(A)) — two outputs, each eligible to reuse
  // the operand's buffer, but only one may.
  HloComputation::Builder fusion_builder("fusion");
  {
    auto* a_param = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(/*parameter_number=*/0, f32vec4_, "A"));
    auto* exp = fusion_builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, a_param));
    auto* negate = fusion_builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
    fusion_builder.AddInstruction(HloInstruction::CreateTuple({exp, negate}));
  }
  auto fusion_computation =
      tracker.module()->AddEmbeddedComputation(fusion_builder.Build());

  HloInstruction* a_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {neg}, fusion_computation));
  tracker.module()->AddEntryComputation(builder.Build());

  tracker.RunWholeModule({a_param, neg, fusion});

  int64_t neg_buffer = tracker.OffsetAt(neg, {});
  int64_t output_buffer_0 = tracker.OffsetAt(fusion, {0});
  int64_t output_buffer_1 = tracker.OffsetAt(fusion, {1});
  // Exactly one of the two fusion outputs may reuse neg's buffer.
  EXPECT_TRUE((neg_buffer == output_buffer_0) ^
              (neg_buffer == output_buffer_1));
}
670
TEST_F(HeapSimulatorTest, MultiplyDot) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  HloInstruction* paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  HloInstruction* paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));

  // dot is the output and is NOT elementwise, so its buffer cannot be shared
  // with mul's: expect a fresh Alloc for dot rather than a Share.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(dot, {})},
      // mul dies after dot; all params and outputs are freed at the end.
      {kFree, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot, {})},
      {kFinish, nullptr},
  });
}
706
TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  HloInstruction* paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  HloInstruction* paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));

  // add is elementwise and dot dies at add, so the output buffer for add is
  // shared with dot's buffer.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot, add});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(dot, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(dot, {})},
      {kShare, tracker.BufferAt(add, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(add, {})},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(add, {}, dot, {});
}
747
// Verifies that an intermediate buffer (mul) is freed as soon as its last
// user (dot0) completes, rather than being held until the end of the
// computation.
TEST_F(HeapSimulatorTest, MultiplyDotDot) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));

  // The buffer for dot1 is the output. No buffers can be shared. The buffer
  // for mul is freed before the end, since it's no longer used after dot0
  // finishes.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot0, dot1});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
      {kAlloc, tracker.BufferAt(dot1, {})},
      {kFree, tracker.BufferAt(dot0, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot1, {})},
      {kFinish, nullptr},
  });
}
788
// Same as MultiplyDotDot, but the root is a tuple of (dot0, dot1): both dots
// plus the tuple itself are outputs, so all three stay live to the end.
TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));

  // The buffers for dot0, dot1 and tuple are the output. No buffers can be
  // shared. The buffer for mul is freed before the end, since it's no longer
  // used after dot0 finishes.
  HeapSimulatorTracker tracker(
      TestName(), builder.Build(),
      {paramA, paramX, mul, paramY, dot0, dot1, tuple});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
      {kAlloc, tracker.BufferAt(dot1, {})},
      {kAlloc, tracker.BufferAt(tuple, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(dot1, {})},
      {kFree, tracker.BufferAt(tuple, {})},
      {kFinish, nullptr},
  });
}
834
// Verifies that elements of a tuple have independent lifetimes: the buffer
// for tuple element 0 (mul) can be freed while element 1 (add) is still
// live, and the temporary tuple buffer itself dies once both GTEs are done.
TEST_F(HeapSimulatorTest, IndependentTupleElements) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramB = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32scalar_, "paramB"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kMultiply, paramA, paramB));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kAdd, paramA, paramB));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add}));
  auto element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec4_, element0, {0}));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kSubtract, paramA, paramB));
  auto element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1));
  auto output = builder.AddInstruction(
      HloInstruction::CreateTuple({broadcast, sub, element1}));

  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramB, mul, add, tuple, element0,
                                broadcast, sub, element1, output});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramB, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(add, {})},
      {kAlloc, tracker.BufferAt(tuple, {})},
      {kAlloc, tracker.BufferAt(broadcast, {})},
      // The mul can be freed right after the broadcast happens, even though
      // The other GetTupleElement is still alive.
      {kFree, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(sub, {})},
      // The temporary tuple is now dead.
      {kFree, tracker.BufferAt(tuple, {})},
      {kAlloc, tracker.BufferAt(output, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramB, {})},
      {kFree, tracker.BufferAt(add, {})},
      {kFree, tracker.BufferAt(broadcast, {})},
      {kFree, tracker.BufferAt(sub, {})},
      {kFree, tracker.BufferAt(output, {})},
      {kFinish, nullptr},
  });
}
884
// Runs the heap simulator over a whole module (entry + while cond/body
// computations) rather than a single computation, and checks the combined
// call sequence. The while's param/result buffers alias the entry param, so
// only the param sub-buffers and the cond compare result are allocated.
TEST_F(HeapSimulatorTest, WholeModule) {
  HeapSimulatorTracker tracker(TestName());

  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});

  // While condition: lt(param[0], param[1]).
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_data, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      tracker.module()->AddEmbeddedComputation(cond_builder.Build());

  // While body: the identity (just returns its parameter).
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloComputation* body_computation =
      tracker.module()->AddEmbeddedComputation(body_builder.Build());

  // Entry: while(param).
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, param));
  tracker.module()->AddEntryComputation(builder.Build());

  tracker.RunWholeModule(
      {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt});
  tracker.ExpectCallSequence({
      // The entry computation param and while_op are allocated first.
      {kAlloc, tracker.BufferAt(param, {})},
      {kAlloc, tracker.BufferAt(param, {0})},
      {kAlloc, tracker.BufferAt(param, {1})},

      // Now the final cond less-than buffer is allocated.
      {kAlloc, tracker.BufferAt(cond_lt, {})},

      // The order of the remaining Free calls is based on the BufferValue.id,
      // which is deterministic, but not obvious.
      {kFree, tracker.BufferAt(cond_lt, {})},
      {kFree, tracker.BufferAt(param, {})},
      {kFree, tracker.BufferAt(param, {0})},
      {kFree, tracker.BufferAt(param, {1})},
      {kFinish, nullptr},
  });
}
938
939 // Base class for heap algorithm tests.
940 class HeapAlgorithmTestBase : public ::testing::Test {
941 protected:
HeapAlgorithmTestBase()942 HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
943 buffer_a_ = DummyBufferValue();
944 buffer_b_ = DummyBufferValue();
945 buffer_c_ = DummyBufferValue();
946 buffer_d_ = DummyBufferValue();
947 buffer_e_ = DummyBufferValue();
948 buffer_f_ = DummyBufferValue();
949 buffer_g_ = DummyBufferValue();
950 buffer_h_ = DummyBufferValue();
951 buffer_i_ = DummyBufferValue();
952 }
~HeapAlgorithmTestBase()953 ~HeapAlgorithmTestBase() override {}
954
955 const HloValue* buffer_a_;
956 const HloValue* buffer_b_;
957 const HloValue* buffer_c_;
958 const HloValue* buffer_d_;
959 const HloValue* buffer_e_;
960 const HloValue* buffer_f_;
961 const HloValue* buffer_g_;
962 const HloValue* buffer_h_;
963 const HloValue* buffer_i_;
964
965 private:
966 // Create a dummy HloValue to pass to the heap algorithm.
DummyBufferValue()967 const HloValue* DummyBufferValue() {
968 const HloValue::Id id = buffers_.size();
969 auto const0 = builder_.AddInstruction(
970 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
971 buffers_.emplace_back(std::make_unique<HloValue>(id, const0, ShapeIndex{}));
972 return buffers_.back().get();
973 }
974
975 HloComputation::Builder builder_;
976 std::vector<std::unique_ptr<HloValue>> buffers_;
977 };
978
class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};

// With no allocations, the reported heap size is zero.
TEST_F(NoFragmentationStatsHeapTest, Empty) {
  NoFragmentationStatsHeap<HloValue> heap;
  EXPECT_EQ(0, heap.Finish().heap_size);
}
985
// All four buffers are live simultaneously, so the fragmentation-free heap
// size is their sum: 10 + 20 + 30 + 30 = 90.
TEST_F(NoFragmentationStatsHeapTest, Simple) {
  NoFragmentationStatsHeap<HloValue> heap;
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 30);
  heap.Alloc(buffer_d_, 30);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 30);
  heap.Free(buffer_d_, 30);
  EXPECT_EQ(90, heap.Finish().heap_size);
}
998
// With interleaved Alloc/Free, the heap size is the peak live footprint
// (A+C = 10+30 = 40), not the sum of all allocations.
TEST_F(NoFragmentationStatsHeapTest, Mixed) {
  NoFragmentationStatsHeap<HloValue> heap;
  heap.Alloc(buffer_a_, 10);  // max: A

  heap.Alloc(buffer_b_, 20);  // max: A+B
  heap.Free(buffer_b_, 20);

  heap.Alloc(buffer_c_, 30);  // max: A+C
  heap.Free(buffer_c_, 30);

  heap.Alloc(buffer_d_, 5);  // max: A+C
  heap.Free(buffer_d_, 5);

  heap.Free(buffer_a_, 10);
  EXPECT_EQ(40, heap.Finish().heap_size);
}
1015
// Fixture exposing GlobalDecreasingSizeBestFitHeap's protected chunk-finding
// machinery so tests can probe candidate placement directly.
class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {
 protected:
  class InheritedGlobalDecreasingSizeBestFitHeap
      : public GlobalDecreasingSizeBestFitHeap<HloValue> {
   public:
    InheritedGlobalDecreasingSizeBestFitHeap()
        : GlobalDecreasingSizeBestFitHeap(/*alignment=*/1) {}

    // Finds a chunk candidate and returns the offset and the new heap size.
    // The candidate is remembered in chunk_candidate_ so a subsequent
    // CommitChunk() call can commit exactly this placement.
    std::pair<int64_t, int64_t> FindChunkCandidate(
        const HloValue* buffer, int64_t size, int64_t start, int64_t end,
        int64_t preferred_offset = -1) {
      buffer_interval_.buffer = buffer;
      buffer_interval_.size = size;
      buffer_interval_.start = start;
      buffer_interval_.end = end;
      chunk_candidate_ = GlobalDecreasingSizeBestFitHeap::FindChunkCandidate(
          buffer_interval_, preferred_offset);
      // The candidate must be exactly the requested size.
      EXPECT_EQ(chunk_candidate_.size, size);
      return {chunk_candidate_.offset,
              result_.UpdatedHeapSize(chunk_candidate_)};
    }

    // Commits the previously found chunk candidate.
    void CommitChunk() {
      GlobalDecreasingSizeBestFitHeap::CommitChunk(buffer_interval_,
                                                   chunk_candidate_);
    }

   private:
    BufferInterval buffer_interval_;  // interval of the last candidate query
    Chunk chunk_candidate_;           // last candidate, pending commit
  };

  InheritedGlobalDecreasingSizeBestFitHeap heap_;
};
1052
// Finishing an untouched heap yields a zero-sized heap containing a single,
// empty per-heap result.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  const HeapSimulator::Result<HloValue> finish_result = heap.Finish();
  EXPECT_EQ(0, finish_result.heap_size);
  EXPECT_EQ(1, finish_result.heap_results.size());
  EXPECT_EQ(0, finish_result.heap_results.at(0).chunk_map.size());
}
1060
// All four buffers are live for the whole interval; the algorithm places
// them largest-first (d, b, c, a), stacking to a 100-byte heap.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
  // space
  //   ^
  //   |  +---a---+
  //   |      +-------+
  //   |      +---c---+
  //   |    +-------+
  //   |    |   b   |
  //   |    +-------+
  //   |  +-------+
  //   |  |       |
  //   |  |   d   |
  //   |  +-------+
  //   -----------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 30);
  heap.Alloc(buffer_c_, 20);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 30);
  heap.Free(buffer_c_, 20);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(100, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);

  // Offsets stack largest-first: d at 0, b at 40, c at 70, a at 90.
  EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
}
1100
// Same decreasing-size placement but with alignment 20: offsets round up to
// multiples of 20 (a lands at 60, not 50), growing the heap to 120.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
  // space
  //   ^
  //   |      +-------+
  //   |      +---b---+
  //   |            +-------+
  //   |            |       |
  //   |            |   d   |
  //   |  +---a---+ +-------+
  //   |
  //   |  +-------+
  //   |  |       |
  //   |  |   c   |
  //   |  |       |
  //   |  +-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/20);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 50);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 50);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(120, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);

  // a and d can share offset 60 because their live ranges don't overlap.
  EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
}
1142
// Exercises best-fit gap selection: after a is freed, d slots into the gap
// above c (same offset 90 as a), rather than extending the heap.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
  // space
  //   ^
  //   |    +-------+
  //   |    +---b---+
  //   |         +-------+
  //   |         |   d   |
  //   | +--a--+ +-------+
  //   |      +-------+
  //   |      |       |
  //   |      |   c   |
  //   |      +-------+
  //   |         +-------+
  //   |         |       |
  //   |         |   e   |
  //   |         |       |
  //   |         +-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 40);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 30);
  heap.Alloc(buffer_e_, 50);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 40);
  heap.Free(buffer_d_, 30);
  heap.Free(buffer_e_, 50);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(140, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
  EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size);

  // a and d share offset 90 (disjoint live ranges); e takes the bottom.
  EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset);
}
1190
// ShareWith colocation: c shares a's (already freed) allocation, so a, b and
// c all fit at offset 0 and the heap never exceeds a's 40 bytes.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) {
  // space      colocate
  //   ^   +--------------+
  //   |   v              v
  //   |+------+      +-------+
  //   ||      |      |       |
  //   ||      |+----+|       |
  //   |+--a---++-b--++---c---+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 40);
  heap.Free(buffer_a_, 40);
  heap.Alloc(buffer_b_, 20);
  heap.Free(buffer_b_, 20);
  heap.ShareWith(buffer_c_, buffer_a_, 40);
  heap.Free(buffer_c_, 40);

  const HeapSimulator::Result<HloValue> finish_results = heap.Finish();
  EXPECT_EQ(1, finish_results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& heap_result =
      finish_results.heap_results.at(0);
  // Total footprint is just a's 40 bytes.
  EXPECT_EQ(40, heap_result.heap_size);

  // Every buffer keeps its requested size...
  EXPECT_EQ(40, heap_result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, heap_result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(40, heap_result.chunk_map.at(buffer_c_).size);

  // ...and all three land at offset 0, since no two overlap in time.
  EXPECT_EQ(0, heap_result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, heap_result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, heap_result.chunk_map.at(buffer_c_).offset);
}
1221
// Colocation where c overlaps b in time: c still reuses a's offset 0, and b
// must sit above it at offset 40, giving a 60-byte heap.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
  // space
  //   ^       +---------------+
  //   |       +-------b-------+
  //   |+------+      +-------+
  //   ||      |      |       |
  //   ||      |      |       | <--- colocate with a
  //   |+--a---+      +---c---+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 40);
  heap.Free(buffer_a_, 40);
  heap.Alloc(buffer_b_, 20);

  heap.ShareWith(buffer_c_, buffer_a_, 40);
  heap.Free(buffer_c_, 40);
  heap.Free(buffer_b_, 20);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(60, result.heap_size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
}
1253
// Colocation with a larger interfering buffer: b (30) takes the bottom, so
// a and its colocated partner c both end up at offset 30 above it.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) {
  // space
  //   ^+------+      +-------+
  //   ||      |      |       | <--- colocate with a
  //   |+--a---+      +---c---+
  //   |       +---------------+
  //   |       |               |
  //   |       |               |
  //   |       +-------b-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_b_, 30);

  heap.ShareWith(buffer_c_, buffer_a_, 10);
  heap.Free(buffer_c_, 10);
  heap.Free(buffer_b_, 30);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(40, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);

  EXPECT_EQ(30, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
}
1286
// Colocation where the sharing buffer c (30) is SMALLER than a (40): c still
// reuses a's offset 0, and b only needs to clear c's 30 bytes, so the heap
// is 50 rather than 60.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedDifferentSize1) {
  // space
  //   ^
  //   |       +---------------+
  //   |+------+-------b-------+
  //   ||      |      +-------+
  //   ||      |      |       | <--- colocate with a
  //   |+--a---+      +---c---+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 40);
  heap.Free(buffer_a_, 40);
  heap.Alloc(buffer_b_, 20);

  heap.ShareWith(buffer_c_, buffer_a_, 30);
  heap.Free(buffer_c_, 30);
  heap.Free(buffer_b_, 20);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(50, result.heap_size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
}
1318
// Colocation where the sharing buffer c (50) is LARGER than a (40): c still
// starts at a's offset 0, and b is pushed to offset 50 above c, giving a
// 70-byte heap.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedDifferentSize2) {
  // space
  //   ^         +-------------+
  //   |         +-----b-------+
  //   |              +-------+
  //   |+------+      |       |
  //   ||      |      |       |
  //   ||      |      |       | <--- colocate with a
  //   |+--a---+      +---c---+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 40);
  heap.Free(buffer_a_, 40);
  heap.Alloc(buffer_b_, 20);

  heap.ShareWith(buffer_c_, buffer_a_, 50);
  heap.Free(buffer_c_, 50);
  heap.Free(buffer_b_, 20);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(70, result.heap_size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(50, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
}
1351
// Probes FindChunkCandidate directly (via the inherited test heap): each
// call asks for (buffer, size, start, end[, preferred_offset]) and returns
// {offset, updated heap size}; CommitChunk then fixes the placement. The
// diagram shows the committed layout; "po:" marks the preferred offset
// requested, which is honored only when that space is free.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, ChunkCandidate) {
  // space
  //   ^
  // 35|
  //   |            +-----------+
  //   |            |           |
  // 30|            |           |
  //   |            |  po: 15   |
  //   |            |           |
  // 25|            +-----g-----+
  //   |         +-----+
  //   |         |po:20|
  // 20|         +--f--+
  //   |                                +-----+
  //   |                                |     |
  // 15|                                |     |
  //   |      +-----------------+       |po:10|
  //   |      |                 |       |     |
  // 10|      +-------c---------+       +--e--+
  //   |         +-----+  +-----------+
  //   |         |     |  |   po: 5   |
  //  5|         |     |  +-----a-----+
  //   |+-----+  |     |
  //   ||po:10|  |     |
  //  0|+--d--+  +--b--+
  //   -----------------------------------------> time
  //    0  1  2  3  4  5  6  7  8  9 10 11 12 13
  using pair = std::pair<int64_t, int64_t>;
  EXPECT_EQ(pair(5, 10), heap_.FindChunkCandidate(buffer_a_, 5, 6, 10, 5));
  heap_.CommitChunk();  // offset: 5, size: 5, start: 6, end: 10
  // Preferred offset 5 is returned.
  EXPECT_EQ(pair(0, 10), heap_.FindChunkCandidate(buffer_b_, 10, 3, 5));
  heap_.CommitChunk();  // offset: 0, size: 10, start: 3, end: 5
  EXPECT_EQ(pair(10, 15), heap_.FindChunkCandidate(buffer_c_, 5, 2, 8));
  heap_.CommitChunk();  // offset: 10, size: 5, start: 2, end: 8
  EXPECT_EQ(pair(0, 15), heap_.FindChunkCandidate(buffer_d_, 5, 0, 2, 10));
  heap_.CommitChunk();  // offset: 0, size: 5, start: 0, end: 2
  // Preferred offset 10 could not be given because it is occupied.
  EXPECT_EQ(pair(10, 20), heap_.FindChunkCandidate(buffer_e_, 10, 11, 13, 10));
  heap_.CommitChunk();  // offset: 10, size: 10, start: 11, end: 13
  // Preferred offset 10 is returned.
  EXPECT_EQ(pair(20, 25), heap_.FindChunkCandidate(buffer_f_, 5, 3, 5, 20));
  heap_.CommitChunk();  // offset: 20, size: 5, start: 3, end: 5
  // Preferred offset 20 is returned.
  EXPECT_EQ(pair(25, 35), heap_.FindChunkCandidate(buffer_g_, 10, 4, 8, 15));
  heap_.CommitChunk();  // offset: 25, size: 10, start: 4, end: 8
  // Preferred offset 15 could not be given because it is occupied.
}
1400
// Tests for the size-constrained variant, which splits allocations across
// multiple heaps when a single heap would exceed its size limit.
class ConstrainedGlobalDecreasingSizeBestFitHeapTest
    : public HeapAlgorithmTestBase {};
1403
// With a 50-byte per-heap limit, the 100 bytes of live buffers must be split
// into two heaps; a and d (10 + 40 = 50) land in heap 0.
TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
  // space
  //   ^
  //   |      +-------+
  //   |      +---c---+
  //   |    +-------+
  //   |    |   b   |
  //   |    +-------+
  //   | ................ // split into two allocations.
  //   |  +---a---+
  //   |  +-------+
  //   |  |       |
  //   |  |   d   |
  //   |  +-------+
  //   -----------------> time
  ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/50,
                                                  /*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 30);
  heap.Alloc(buffer_c_, 20);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 30);
  heap.Free(buffer_c_, 20);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> result = heap.Finish();
  // heap_size is the total across both sub-heaps.
  EXPECT_EQ(100, result.heap_size);
  EXPECT_EQ(2, result.heap_results.size());

  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_d_));
  EXPECT_EQ(10, result.heap_results[0].chunk_map.at(buffer_a_).size);
  EXPECT_EQ(40, result.heap_results[0].chunk_map.at(buffer_d_).size);
  EXPECT_EQ(40, result.heap_results[0].chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_d_).offset);
}
1441
// Constrained split combined with alignment 20: heap 0 holds c and a
// (a aligned up to offset 60, making heap 0 seventy bytes), heap 1 holds the
// rest, totalling 130.
TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest,
       DecreasingSizeWithAlignment) {
  // space
  //   ^
  //   |      +-------+
  //   |      +---b---+
  //   |            +-------+
  //   |            |       |
  //   |            |   d   |
  //   |            +-------+
  //   | ...................
  //   |  +---a---+
  //   |
  //   |  +-------+
  //   |  |       |
  //   |  |   c   |
  //   |  |       |
  //   |  +-------+
  //   ---------------------> time
  ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/70,
                                                  /*alignment=*/20);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 50);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 50);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> result = heap.Finish();
  EXPECT_EQ(130, result.heap_size);  // 70 + 60
  EXPECT_EQ(2, result.heap_results.size());

  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_c_));
  EXPECT_EQ(10, result.heap_results[0].chunk_map.at(buffer_a_).size);
  EXPECT_EQ(50, result.heap_results[0].chunk_map.at(buffer_c_).size);
  EXPECT_EQ(60, result.heap_results[0].chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_c_).offset);
}
1483
// Colocation under a size constraint: a and its colocated partner c share
// heap 0 at offset 0, while b is split off into heap 1 (40 + 20 = 60 total).
TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
  // space
  //   ^
  //   |       +---------------+
  //   |       +-------b-------+
  //   | ....................
  //   |+------+      +-------+
  //   ||      |      |       |
  //   ||      |      |       | <--- colocate with a
  //   |+--a---+      +---c---+
  //   ---------------------> time
  ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/50,
                                                  /*alignment=*/20);
  heap.Alloc(buffer_a_, 30);
  heap.Free(buffer_a_, 30);
  heap.Alloc(buffer_b_, 20);

  heap.ShareWith(buffer_c_, buffer_a_, 40);
  heap.Free(buffer_c_, 40);
  heap.Free(buffer_b_, 20);

  const HeapSimulator::Result<HloValue> result = heap.Finish();
  EXPECT_EQ(60, result.heap_size);  // 40 + 20
  EXPECT_EQ(2, result.heap_results.size());

  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_c_));
  EXPECT_EQ(30, result.heap_results[0].chunk_map.at(buffer_a_).size);
  EXPECT_EQ(40, result.heap_results[0].chunk_map.at(buffer_c_).size);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_c_).offset);
}
1516
1517 class IntervalTreeTest : public ::testing::Test {};
1518
TEST_F(IntervalTreeTest, InsertAndRemove) {
  // A single interval can be inserted and removed repeatedly; a second Remove
  // of the same interval must fail and the tree must end up empty each round.
  HeapSimulator::Chunk payload({1, 2});
  BufferIntervalTree interval_tree;
  for (int round = 0; round < 2; ++round) {
    interval_tree.Add(1, 2, payload);
    EXPECT_TRUE(interval_tree.Remove(1, 2, payload));
    EXPECT_FALSE(interval_tree.Remove(1, 2, payload));
    ASSERT_EQ(interval_tree.GetRoot(), nullptr);
  }
}
1532
TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsLeft) {
  // Tree after both inserts (subtree_end shown in parentheses):
  //      [20, 36] (45)
  //      /
  //  [1, 45] (45)
  HeapSimulator::Chunk payload({1, 2});  // Chunk contents are irrelevant here.
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(1, 45, payload);
  // Removing the left child must shrink the root's subtree_end back to 36.
  EXPECT_TRUE(interval_tree.Remove(1, 45, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 36);
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
1547
TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsRight) {
  // Tree after both inserts (subtree_end shown in parentheses):
  //  [20, 36] (45)
  //      \
  //   [21, 45] (45)
  HeapSimulator::Chunk payload({1, 2});  // Chunk contents are irrelevant here.
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(21, 45, payload);
  // Removing the right child must shrink the root's subtree_end back to 36.
  EXPECT_TRUE(interval_tree.Remove(21, 45, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 36);
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
1561
TEST_F(IntervalTreeTest, TwoLevelsRight_RootFirst) {
  // Tree after both inserts (subtree_end shown in parentheses):
  //  [20, 36] (45)
  //      \
  //   [21, 45] (45)
  HeapSimulator::Chunk payload({1, 2});  // Chunk contents are irrelevant here.
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(21, 45, payload);
  // Deleting the root first: [21, 45] is promoted and becomes a leaf root.
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  const auto* promoted_root = interval_tree.GetRoot();
  EXPECT_EQ(promoted_root->subtree_end, 45);
  EXPECT_EQ(promoted_root->start, 21);
  EXPECT_EQ(promoted_root->end, 45);
  EXPECT_EQ(promoted_root->left, nullptr);
  EXPECT_EQ(promoted_root->right, nullptr);
  EXPECT_TRUE(interval_tree.Remove(21, 45, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
1579
TEST_F(IntervalTreeTest, TwoLevelsLeft_RootFirst) {
  // Tree after both inserts (subtree_end shown in parentheses):
  //      [20, 36] (45)
  //      /
  //  [1, 45] (45)
  HeapSimulator::Chunk payload({1, 2});  // Chunk contents are irrelevant here.
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(1, 45, payload);
  // Deleting the root first: [1, 45] is promoted and becomes a leaf root.
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  const auto* promoted_root = interval_tree.GetRoot();
  EXPECT_EQ(promoted_root->subtree_end, 45);
  EXPECT_EQ(promoted_root->start, 1);
  EXPECT_EQ(promoted_root->end, 45);
  EXPECT_EQ(promoted_root->left, nullptr);
  EXPECT_EQ(promoted_root->right, nullptr);
  EXPECT_TRUE(interval_tree.Remove(1, 45, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
1597
TEST_F(IntervalTreeTest, ThreeLevelsRight) {
  // Tree after the three inserts (subtree_end shown in parentheses):
  //  [20, 36] (45)
  //      \
  //   [21, 45] (45)
  //       \
  //      [22, 40] (40)
  HeapSimulator::Chunk payload({1, 2});  // Chunk contents are irrelevant here.
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(21, 45, payload);
  interval_tree.Add(22, 40, payload);
  // Removing [21, 45] drops the maximum endpoint in the tree to 40.
  EXPECT_TRUE(interval_tree.Remove(21, 45, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(interval_tree.Remove(22, 40, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
TEST_F(IntervalTreeTest, ThreeLevelsLeftLeft) {
  // Tree after the three inserts (subtree_end shown in parentheses):
  //        [20, 36] (45)
  //        /
  //   [10, 45] (45)
  //      /
  //  [1, 40] (40)
  HeapSimulator::Chunk payload({1, 2});  // Chunk contents are irrelevant here.
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(10, 45, payload);
  interval_tree.Add(1, 40, payload);
  // Removals shrink the maximum endpoint step by step: 45 -> 40 -> 36.
  EXPECT_TRUE(interval_tree.Remove(10, 45, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(interval_tree.Remove(1, 40, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 36);
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
1634
TEST_F(IntervalTreeTest, ThreeLevelsLeftRight) {
  // Tree after the three inserts (subtree_end shown in parentheses):
  //        [20, 36] (45)
  //        /
  //   [10, 45] (45)
  //        \
  //     [15, 40] (40)
  HeapSimulator::Chunk payload({1, 2});  // Chunk contents are irrelevant here.
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(10, 45, payload);
  interval_tree.Add(15, 40, payload);
  // Removals shrink the maximum endpoint step by step: 45 -> 40 -> 36.
  EXPECT_TRUE(interval_tree.Remove(10, 45, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(interval_tree.Remove(15, 40, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 36);
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
1653
TEST_F(IntervalTreeTest, ThreeLevelsRightLeft) {
  // Tree after the three inserts (subtree_end shown in parentheses):
  //  [20, 36] (45)
  //      \
  //   [25, 45] (45)
  //      /
  //  [22, 40] (40)
  HeapSimulator::Chunk payload({1, 2});  // Chunk contents are irrelevant here.
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(25, 45, payload);
  interval_tree.Add(22, 40, payload);
  // Removing [25, 45] drops the maximum endpoint in the tree to 40.
  EXPECT_TRUE(interval_tree.Remove(25, 45, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(interval_tree.Remove(22, 40, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
1672
TEST_F(IntervalTreeTest, ThreeLevelsRightLeftChunkDifferent) {
  // Same shape as ThreeLevelsRightLeft, but each node carries a distinct
  // chunk so we can verify which node's chunk sits at the root after each
  // removal.
  HeapSimulator::Chunk chunk1({1, 2});
  HeapSimulator::Chunk chunk2({2, 3});
  HeapSimulator::Chunk chunk3({3, 4});
  // [20, 36] (45)            Chunk1({1, 2})
  //      \
  //   [25, 45] (45)          Chunk2({2, 3})
  //      /
  //  [22, 40] (40)           Chunk3({3, 4})
  BufferIntervalTree tree;
  tree.Add(20, 36, chunk1);
  tree.Add(25, 45, chunk2);
  tree.Add(22, 40, chunk3);
  EXPECT_TRUE(tree.Remove(25, 45, chunk2));
  // Chunk 1 is still the root after removing chunk 2.
  EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
  EXPECT_EQ(tree.GetRoot()->chunk.offset, 1);
  EXPECT_EQ(tree.GetRoot()->chunk.size, 2);
  EXPECT_TRUE(tree.Remove(20, 36, chunk1));
  // Chunk 3 becomes the root now.
  EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
  EXPECT_EQ(tree.GetRoot()->chunk.offset, 3);
  EXPECT_EQ(tree.GetRoot()->chunk.size, 4);
  EXPECT_TRUE(tree.Remove(22, 40, chunk3));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}
1699
1700 } // namespace
1701 } // namespace xla
1702