1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/heap_simulator.h"
17
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/buffer_value.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
30 #include "tensorflow/compiler/xla/service/hlo_value.h"
31 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35
36 namespace xla {
37 namespace {
38
// Fixture for tests of HeapSimulator::MinimumMemoryFor{Module,Computation};
// inherits the module/computation-building helpers from HloTestBase.
class MinimumMemoryForSequenceTest : public HloTestBase {};
40
// Verifies that MinimumMemoryForModule sums the entry computation's peak with
// the peak of the while cond/body subcomputations (the running byte totals are
// tracked in the comments below; pointer_size is 8).
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});

  auto cond_builder = HloComputation::Builder("WhileCond");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_data, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  // Entry params: 8 bytes (4 bytes per param), TOTAL=8
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
  // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
  // While: 8 bytes (4 bytes per element), TOTAL=32
  // Both cond and body use a max of 24 bytes, TOTAL=56
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  // Sizes count 8 bytes per pointer for tuple shapes, matching the byte
  // accounting in the comments above.
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_data, cond_lt});
  schedule.set_sequence(body_computation, {body_param});
  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
  TF_ASSERT_OK(schedule.Verify());

  EXPECT_EQ(
      56,
      HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
}
100
// Verifies that MinimumMemoryForComputation adds the per-subcomputation peaks
// supplied via memory_by_computation to the entry computation's own peak,
// without double-counting the while loop's aliased output buffer.
TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
  // HloModule SubcomputationAccounting

  // %WhileBody (body_param: f32[4]) -> f32[4] {
  //   %body_param = f32[4]{0} parameter(0)
  //   %constant.1 = f32[4]{0} constant({1, 1, 1, 1})
  //   ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param, f32[4]{0}
  //   %constant.1)
  // }

  // %WhileCond (cond_param: f32[4]) -> pred[] {
  //   %cond_param = f32[4]{0} parameter(0)
  //   %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
  //   %reshape = f32[] reshape(f32[1]{0} %slice)
  //   %constant = f32[] constant(0)
  //   ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant),
  //   direction=NE
  // }

  // ENTRY %SubcomputationAccounting () -> f32[2,4] {
  //   %constant.3 = f32[2,4]{1,0} constant(f32[2,4] { { 1, 2, 3, 4 }, { 1, 2,
  //   3, 4 } }) %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0}
  //   %constant.3), dimensions={0,1} %constant.2 = f32[4]{0} constant({1, 1, 1,
  //   1}) %while = f32[4]{0} while(f32[4]{0} %constant.2),
  //   condition=%WhileCond, body=%WhileBody %broadcast = f32[2,4]{1,0}
  //   broadcast(f32[4]{0} %while), dimensions={1} ROOT %add = f32[2,4]{1,0}
  //   add(f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast)
  // }

  auto module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});

  // reshape(slice(param)) != 0
  // Needs 5 bytes
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
  HloInstruction* slice =
      cond_builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1}));
  HloInstruction* reshape =
      cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
  HloInstruction* zero = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  HloInstruction* cond_comparison = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape,
                                    zero, ComparisonDirection::kNe));
  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());

  // param - 1
  // Needs 16 bytes
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "body_param"));
  HloInstruction* one_vector =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  HloInstruction* subtract =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          r1f32, HloOpcode::kSubtract, body_param, one_vector));
  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());

  // transpose(matrix) + bcast(while)
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* while_init =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  // Creates 16 bytes, ignoring subcomputations
  HloInstruction* while_loop =
      builder.AddInstruction(HloInstruction::CreateWhile(
          r1f32, cond_computation, body_computation, while_init));

  // Creates 32 bytes and frees 16
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, while_loop, {1}));

  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
          {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
  // Creates 32 bytes
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));

  // Creates 32 bytes and frees 64
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));

  auto entry_computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> cond_vec = {cond_param, slice, reshape, zero,
                                           cond_comparison};
  std::vector<HloInstruction*> while_body_vec = {body_param, one_vector,
                                                 subtract};
  std::vector<HloInstruction*> entry_comp_vec = {while_init, while_loop, bcast,
                                                 matrix, transpose, add};
  schedule.set_sequence(cond_computation, cond_vec);
  schedule.set_sequence(body_computation, while_body_vec);
  schedule.set_sequence(entry_computation, entry_comp_vec);

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  // Pre-computed peaks for each subcomputation, matching the "Needs N bytes"
  // comments above; these are what the simulator should add to the entry peak.
  absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
  memory_by_computation[cond_computation] = 5;
  memory_by_computation[body_computation] = 16;
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();

  // HeapSimulator accounts for subcomputations. The output buffer is aliased,
  // so we don't double count.
  EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
                    *entry_computation, schedule.sequence(entry_computation),
                    *points_to_analysis, size_fn, &memory_by_computation)
                    .ValueOrDie());
}
219
// Tags recorded by HeapCallRecorder to identify each heap-algorithm call.
const char kAlloc[] = "Alloc";
const char kFree[] = "Free";
const char kFinish[] = "Finish";

// CallSequence records a sequence of Alloc/Free/Finish calls.  The buffer
// pointer is null for Finish entries.
using CallSequence = std::vector<std::pair<string, const BufferValue*>>;
226
// HeapCallRecorder is a dummy heap algorithm that simply records its calls.
class HeapCallRecorder : public HeapAlgorithm {
 public:
  // `calls` is not owned and must outlive this recorder.
  explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
  ~HeapCallRecorder() override {}

  void Alloc(const BufferValue* buffer, int64 size) override {
    calls_->emplace_back(kAlloc, buffer);
    // Instead of assigning a real offset, we set the cardinality of the Alloc
    // call. This isn't a valid assignment, but allows us to easily test for
    // buffer sharing: two buffers share iff they got the same "offset".
    const int64 offset = result_.chunk_map.size();
    result_.chunk_map.emplace(buffer, Chunk{offset, size});
  }
  void Free(const BufferValue* buffer, int64 size) override {
    calls_->emplace_back(kFree, buffer);
  }
  Result Finish() override {
    calls_->emplace_back(kFinish, nullptr);
    return result_;
  }

 private:
  CallSequence* calls_;  // not owned
  Result result_;        // accumulates the fake chunk assignments
};
253
254 // HeapSimulatorTracker runs the heap simulator, recording the sequence of calls
255 // made to the underlying heap algorithm. Tests compare the actual call
256 // sequence against an expected sequence.
257 class HeapSimulatorTracker {
258 public:
259 // Constructor for testing a single entry computation.
HeapSimulatorTracker(const string & name,std::unique_ptr<HloComputation> computation,const std::vector<HloInstruction * > & instruction_sequence)260 HeapSimulatorTracker(
261 const string& name, std::unique_ptr<HloComputation> computation,
262 const std::vector<HloInstruction*>& instruction_sequence) {
263 HloModuleConfig config;
264 module_ = absl::make_unique<HloModule>(name, config);
265 module_->AddEntryComputation(std::move(computation));
266 points_to_analysis_ =
267 TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
268 // Since we're only tracking the sequence of Alloc/Free calls, the actual
269 // size of the buffers doesn't matter, so we always return 0. We rely on
270 // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by
271 // buffer id, for determinism in the tests.
272 auto zero_size = [](const BufferValue& buffer) { return 0; };
273 auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
274 absl::make_unique<HeapCallRecorder>(&actual_calls_));
275 result_ =
276 HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
277 HloInstructionSequence(instruction_sequence),
278 *points_to_analysis_, zero_size)
279 .ConsumeValueOrDie();
280 }
281
HeapSimulatorTracker(const string & name)282 explicit HeapSimulatorTracker(const string& name) {
283 HloModuleConfig config;
284 module_ = absl::make_unique<HloModule>(name, config);
285 }
286
287 // Similar to the single entry computation constructor above, but runs the
288 // simulation over the entire module.
RunWholeModule(const std::vector<HloInstruction * > & full_module_sequence)289 void RunWholeModule(
290 const std::vector<HloInstruction*>& full_module_sequence) {
291 points_to_analysis_ =
292 TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
293
294 // Construct the module sequence grouped by computation.
295 HloSchedule schedule(module_.get());
296 absl::flat_hash_map<const HloInstruction*, int> reverse_position;
297 for (int i = 0; i < full_module_sequence.size(); ++i) {
298 HloInstruction* instruction = full_module_sequence[i];
299 schedule.GetOrCreateSequence(instruction->parent())
300 .push_back(instruction);
301 reverse_position[instruction] = full_module_sequence.size() - i;
302 }
303
304 // Hack the size_fn so that it returns a decreasing value as we step through
305 // the sequence. This lets us ensure the Alloc calls are in the sequence
306 // order. The Free calls are sorted by BufferValue.id, which is at least
307 // deterministic.
308 auto size_fn = [&reverse_position](const BufferValue& buffer) {
309 return reverse_position[buffer.instruction()];
310 };
311 auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
312 absl::make_unique<HeapCallRecorder>(&actual_calls_));
313 result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
314 *points_to_analysis_, size_fn)
315 .ConsumeValueOrDie();
316 }
317
module()318 HloModule* module() { return module_.get(); }
319
320 // Returns the buffer defined at the given instruction and index.
BufferAt(const HloInstruction * instruction,const ShapeIndex & index) const321 const BufferValue* BufferAt(const HloInstruction* instruction,
322 const ShapeIndex& index) const {
323 return points_to_analysis_->GetBufferDefinedAt(instruction, index)
324 .ConsumeValueOrDie();
325 }
326
OffsetAt(const HloInstruction * instruction,const ShapeIndex & index)327 int64 OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) {
328 const BufferValue* buffer = BufferAt(instruction, index);
329 return result_.chunk_map.at(buffer).offset;
330 }
331
332 // Ensures the expected sequence of Alloc/Free/Finish calls was performed.
ExpectCallSequence(const CallSequence & expected) const333 void ExpectCallSequence(const CallSequence& expected) const {
334 EXPECT_EQ(expected, actual_calls_);
335 }
336
337 // Ensures the buffers defined by the respective (instruction,index) pairs are
338 // shared, relying on the unique offsets assigned in HeapCallRecorder::Alloc.
ExpectSharedBuffers(const HloInstruction * instruction_a,const ShapeIndex & index_a,const HloInstruction * instruction_b,const ShapeIndex & index_b)339 void ExpectSharedBuffers(const HloInstruction* instruction_a,
340 const ShapeIndex& index_a,
341 const HloInstruction* instruction_b,
342 const ShapeIndex& index_b) {
343 int64 offset_a = OffsetAt(instruction_a, index_a);
344 int64 offset_b = OffsetAt(instruction_b, index_b);
345 EXPECT_EQ(offset_a, offset_b);
346 }
347
348 private:
349 std::unique_ptr<HloModule> module_;
350 std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
351 CallSequence actual_calls_;
352 HeapSimulator::Result result_;
353 };
354
// Fixture for the HeapSimulator call-sequence tests below; provides the
// shapes shared by the examples.
class HeapSimulatorTest : public HloTestBase {
 protected:
  HeapSimulatorTest() {}
  ~HeapSimulatorTest() override {}

  // Shapes for use in the examples.
  Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {});
  Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
};
364
// A computation consisting solely of a constant produces no heap traffic.
TEST_F(HeapSimulatorTest, ScalarConstant) {
  auto b = HloComputation::Builder(TestName());
  auto scalar_const = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  // Constants aren't assigned buffers (see b/32248867), so the only recorded
  // call is the final Finish.
  HeapSimulatorTracker tracker(TestName(), b.Build(), {scalar_const});
  tracker.ExpectCallSequence({{kFinish, nullptr}});
}
374
// A single parameter that doubles as the output: allocated once, freed only
// at the very end.
TEST_F(HeapSimulatorTest, OneParam) {
  auto b = HloComputation::Builder(TestName());
  auto param = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "param0"));

  HeapSimulatorTracker tracker(TestName(), b.Build(), {param});
  const BufferValue* param_buffer = tracker.BufferAt(param, {});
  tracker.ExpectCallSequence({
      {kAlloc, param_buffer},
      {kFree, param_buffer},
      {kFinish, nullptr},
  });
}
388
// Parameters and the output stay live for the whole program, so every Free
// happens at the end.
TEST_F(HeapSimulatorTest, Multiply) {
  auto b = HloComputation::Builder(TestName());
  auto param_a = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto param_x = b.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto mul = b.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, param_a, param_x));

  HeapSimulatorTracker tracker(TestName(), b.Build(),
                               {param_a, param_x, mul});
  auto buf = [&tracker](const HloInstruction* instr) {
    return tracker.BufferAt(instr, {});
  };
  tracker.ExpectCallSequence({
      {kAlloc, buf(param_a)},
      {kAlloc, buf(param_x)},
      {kAlloc, buf(mul)},
      // All params and outputs are freed at the end.
      {kFree, buf(param_a)},
      {kFree, buf(param_x)},
      {kFree, buf(mul)},
      {kFinish, nullptr},
  });
}
412
// The elementwise add reuses mul's buffer for the output, so add never gets
// its own Alloc.
TEST_F(HeapSimulatorTest, MultiplyAdd) {
  auto b = HloComputation::Builder(TestName());
  auto param_a = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto param_x = b.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto param_y = b.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
  auto mul = b.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, param_a, param_x));
  auto add = b.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, param_y));

  HeapSimulatorTracker tracker(TestName(), b.Build(),
                               {param_a, param_x, mul, param_y, add});
  auto buf = [&tracker](const HloInstruction* instr) {
    return tracker.BufferAt(instr, {});
  };
  tracker.ExpectCallSequence({
      {kAlloc, buf(param_a)},
      {kAlloc, buf(param_x)},
      {kAlloc, buf(mul)},
      {kAlloc, buf(param_y)},
      // All params and outputs are freed at the end.
      {kFree, buf(param_a)},
      {kFree, buf(param_x)},
      {kFree, buf(mul)},
      {kFree, buf(param_y)},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(add, {}, mul, {});
}
443
// A two-output fusion (exp and neg of the same input) may share its input
// buffer with at most one of its outputs.
TEST_F(HeapSimulatorTest, BufferReusedOnce) {
  HeapSimulatorTracker tracker(TestName());
  auto builder = HloComputation::Builder(TestName());

  // Fused computation: A -> (exp(A), neg(A)).
  HloComputation::Builder fusion_builder("fusion");
  {
    // Shadows the outer `builder` so the AddInstruction calls below target
    // the fusion computation.
    HloComputation::Builder& builder = fusion_builder;
    auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0, f32vec4_, "A"));
    auto exp = builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, a_param));
    auto neg = builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));

    builder.AddInstruction(HloInstruction::CreateTuple({exp, neg}));
  }
  auto fusion_computation =
      tracker.module()->AddEmbeddedComputation(fusion_builder.Build());
  auto a_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {neg}, fusion_computation));
  tracker.module()->AddEntryComputation(builder.Build());

  tracker.RunWholeModule({a_param, neg, fusion});

  auto neg_buffer = tracker.OffsetAt(neg, {});
  int64 output_buffer_0 = tracker.OffsetAt(fusion, {0});
  int64 output_buffer_1 = tracker.OffsetAt(fusion, {1});
  // Only one buffer should be shared: neg's buffer must alias exactly one of
  // the two fusion outputs (XOR of the two equality checks).
  EXPECT_TRUE((neg_buffer == output_buffer_0) ^
              (neg_buffer == output_buffer_1));
}
480
// Since dot is not elementwise, its output buffer cannot alias mul's buffer;
// both get their own allocation and live until the end.
TEST_F(HeapSimulatorTest, MultiplyDot) {
  auto b = HloComputation::Builder(TestName());
  auto param_a = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto param_x = b.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto param_y = b.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = b.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, param_a, param_x));
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  auto dot = b.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, param_y, dnums, DefaultPrecisionConfig(2)));

  HeapSimulatorTracker tracker(TestName(), b.Build(),
                               {param_a, param_x, mul, param_y, dot});
  auto buf = [&tracker](const HloInstruction* instr) {
    return tracker.BufferAt(instr, {});
  };
  tracker.ExpectCallSequence({
      {kAlloc, buf(param_a)},
      {kAlloc, buf(param_x)},
      {kAlloc, buf(mul)},
      {kAlloc, buf(param_y)},
      {kAlloc, buf(dot)},
      // All params and outputs are freed at the end.
      {kFree, buf(param_a)},
      {kFree, buf(param_x)},
      {kFree, buf(mul)},
      {kFree, buf(param_y)},
      {kFree, buf(dot)},
      {kFinish, nullptr},
  });
}
516
// The elementwise add reuses dot's buffer for the output, so add never gets
// its own Alloc.
TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
  auto b = HloComputation::Builder(TestName());
  auto param_a = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto param_x = b.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto param_y = b.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = b.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, param_a, param_x));
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  auto dot = b.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, param_y, dnums, DefaultPrecisionConfig(2)));
  auto add = b.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, param_a));

  HeapSimulatorTracker tracker(TestName(), b.Build(),
                               {param_a, param_x, mul, param_y, dot, add});
  auto buf = [&tracker](const HloInstruction* instr) {
    return tracker.BufferAt(instr, {});
  };
  tracker.ExpectCallSequence({
      {kAlloc, buf(param_a)},
      {kAlloc, buf(param_x)},
      {kAlloc, buf(mul)},
      {kAlloc, buf(param_y)},
      {kAlloc, buf(dot)},
      // All params and outputs are freed at the end.
      {kFree, buf(param_a)},
      {kFree, buf(param_x)},
      {kFree, buf(mul)},
      {kFree, buf(param_y)},
      {kFree, buf(dot)},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(add, {}, dot, {});
}
554
// No sharing is possible here; mul dies as soon as dot0 has consumed it, so
// its Free precedes dot1's Alloc.
TEST_F(HeapSimulatorTest, MultiplyDotDot) {
  auto b = HloComputation::Builder(TestName());
  auto param_a = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto param_x = b.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto param_y = b.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = b.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, param_a, param_x));
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = b.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, param_y, dnums, DefaultPrecisionConfig(2)));
  auto dot1 = b.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, param_y, dnums, DefaultPrecisionConfig(2)));

  HeapSimulatorTracker tracker(TestName(), b.Build(),
                               {param_a, param_x, mul, param_y, dot0, dot1});
  auto buf = [&tracker](const HloInstruction* instr) {
    return tracker.BufferAt(instr, {});
  };
  tracker.ExpectCallSequence({
      {kAlloc, buf(param_a)},
      {kAlloc, buf(param_x)},
      {kAlloc, buf(mul)},
      {kAlloc, buf(param_y)},
      {kAlloc, buf(dot0)},
      {kFree, buf(mul)},  // mul is dead once dot0 finishes
      {kAlloc, buf(dot1)},
      // All params and outputs are freed at the end.
      {kFree, buf(param_a)},
      {kFree, buf(param_x)},
      {kFree, buf(param_y)},
      {kFree, buf(dot0)},
      {kFree, buf(dot1)},
      {kFinish, nullptr},
  });
}
595
// dot0, dot1 and the tuple are all part of the output, so none of them can
// share; only mul dies early, right after dot0 has consumed it.
TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
  auto b = HloComputation::Builder(TestName());
  auto param_a = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto param_x = b.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto param_y = b.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = b.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, param_a, param_x));
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = b.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, param_y, dnums, DefaultPrecisionConfig(2)));
  auto dot1 = b.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, param_y, dnums, DefaultPrecisionConfig(2)));
  auto tuple = b.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));

  HeapSimulatorTracker tracker(
      TestName(), b.Build(),
      {param_a, param_x, mul, param_y, dot0, dot1, tuple});
  auto buf = [&tracker](const HloInstruction* instr) {
    return tracker.BufferAt(instr, {});
  };
  tracker.ExpectCallSequence({
      {kAlloc, buf(param_a)},
      {kAlloc, buf(param_x)},
      {kAlloc, buf(mul)},
      {kAlloc, buf(param_y)},
      {kAlloc, buf(dot0)},
      {kFree, buf(mul)},  // mul is dead once dot0 finishes
      {kAlloc, buf(dot1)},
      {kAlloc, buf(tuple)},
      // All params and outputs are freed at the end.
      {kFree, buf(param_a)},
      {kFree, buf(param_x)},
      {kFree, buf(param_y)},
      {kFree, buf(dot0)},
      {kFree, buf(dot1)},
      {kFree, buf(tuple)},
      {kFinish, nullptr},
  });
}
641
// Tuple elements can die independently: the mul element is freed after the
// broadcast consumes it, while the add element (reached via the other
// GetTupleElement) stays live until the end.
TEST_F(HeapSimulatorTest, IndependentTupleElements) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramB = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32scalar_, "paramB"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kMultiply, paramA, paramB));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kAdd, paramA, paramB));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add}));
  auto element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec4_, element0, {0}));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kSubtract, paramA, paramB));
  auto element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1));
  auto output = builder.AddInstruction(
      HloInstruction::CreateTuple({broadcast, sub, element1}));

  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramB, mul, add, tuple, element0,
                                broadcast, sub, element1, output});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramB, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(add, {})},
      {kAlloc, tracker.BufferAt(tuple, {})},
      {kAlloc, tracker.BufferAt(broadcast, {})},
      // The mul can be freed right after the broadcast happens, even though
      // the other GetTupleElement is still alive.
      {kFree, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(sub, {})},
      // The temporary tuple is now dead.
      {kFree, tracker.BufferAt(tuple, {})},
      {kAlloc, tracker.BufferAt(output, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramB, {})},
      {kFree, tracker.BufferAt(add, {})},
      {kFree, tracker.BufferAt(broadcast, {})},
      {kFree, tracker.BufferAt(sub, {})},
      {kFree, tracker.BufferAt(output, {})},
      {kFinish, nullptr},
  });
}
691
// Runs the heap simulator over a whole module containing a while loop, and
// checks the exact Alloc/Free call sequence across the entry, condition, and
// body computations.
TEST_F(HeapSimulatorTest, WholeModule) {
  HeapSimulatorTracker tracker(TestName());

  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});

  // While condition: compares the two scalar elements of its tuple parameter.
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_data, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      tracker.module()->AddEmbeddedComputation(cond_builder.Build());

  // While body: just returns its tuple parameter unchanged.
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloComputation* body_computation =
      tracker.module()->AddEmbeddedComputation(body_builder.Build());

  // Entry computation: a single while op over the tuple-shaped parameter.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, param));
  tracker.module()->AddEntryComputation(builder.Build());

  tracker.RunWholeModule(
      {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt});
  tracker.ExpectCallSequence({
      // The entry computation param and while_op are allocated first.
      {kAlloc, tracker.BufferAt(param, {})},
      {kAlloc, tracker.BufferAt(param, {0})},
      {kAlloc, tracker.BufferAt(param, {1})},
      {kAlloc, tracker.BufferAt(while_op, {})},
      {kAlloc, tracker.BufferAt(while_op, {0})},
      {kAlloc, tracker.BufferAt(while_op, {1})},

      // Now the while body param is allocated and freed.
      {kAlloc, tracker.BufferAt(body_param, {})},
      {kAlloc, tracker.BufferAt(body_param, {0})},
      {kAlloc, tracker.BufferAt(body_param, {1})},
      {kFree, tracker.BufferAt(body_param, {})},
      {kFree, tracker.BufferAt(body_param, {0})},
      {kFree, tracker.BufferAt(body_param, {1})},

      // Now the while cond param is allocated. The GTE instructions just alias
      // the param elements, so the param tuple can immediately be freed.
      {kAlloc, tracker.BufferAt(cond_param, {})},
      {kAlloc, tracker.BufferAt(cond_param, {0})},
      {kAlloc, tracker.BufferAt(cond_param, {1})},
      {kFree, tracker.BufferAt(cond_param, {})},

      // Now the final cond less-than buffer is allocated.
      {kAlloc, tracker.BufferAt(cond_lt, {})},

      // The order of the remaining Free calls is based on the BufferValue.id,
      // which is deterministic, but not obvious.
      {kFree, tracker.BufferAt(param, {})},
      {kFree, tracker.BufferAt(param, {0})},
      {kFree, tracker.BufferAt(param, {1})},

      {kFree, tracker.BufferAt(while_op, {})},
      {kFree, tracker.BufferAt(while_op, {0})},
      {kFree, tracker.BufferAt(while_op, {1})},

      {kFree, tracker.BufferAt(cond_param, {0})},
      {kFree, tracker.BufferAt(cond_param, {1})},
      {kFree, tracker.BufferAt(cond_lt, {})},

      {kFinish, nullptr},
  });
}
771
772 // Base class for heap algorithm tests.
773 class HeapAlgorithmTestBase : public ::testing::Test {
774 protected:
HeapAlgorithmTestBase()775 HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
776 buffer_a_ = DummyBufferValue();
777 buffer_b_ = DummyBufferValue();
778 buffer_c_ = DummyBufferValue();
779 buffer_d_ = DummyBufferValue();
780 buffer_e_ = DummyBufferValue();
781 buffer_f_ = DummyBufferValue();
782 buffer_g_ = DummyBufferValue();
783 buffer_h_ = DummyBufferValue();
784 buffer_i_ = DummyBufferValue();
785 }
~HeapAlgorithmTestBase()786 ~HeapAlgorithmTestBase() override {}
787
788 const BufferValue* buffer_a_;
789 const BufferValue* buffer_b_;
790 const BufferValue* buffer_c_;
791 const BufferValue* buffer_d_;
792 const BufferValue* buffer_e_;
793 const BufferValue* buffer_f_;
794 const BufferValue* buffer_g_;
795 const BufferValue* buffer_h_;
796 const BufferValue* buffer_i_;
797
798 private:
799 // Create a dummy BufferValue to pass to the heap algorithm.
DummyBufferValue()800 const BufferValue* DummyBufferValue() {
801 const BufferValue::Id id = buffers_.size();
802 auto const0 = builder_.AddInstruction(
803 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
804 buffers_.emplace_back(
805 absl::make_unique<HloValue>(id, const0, ShapeIndex{}));
806 return buffers_.back().get();
807 }
808
809 HloComputation::Builder builder_;
810 std::vector<std::unique_ptr<BufferValue>> buffers_;
811 };
812
// Fixture exposing the dummy buffers for NoFragmentationStatsHeap tests.
class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};
814
// With no allocations at all, the tracked high-water mark stays at zero.
TEST_F(NoFragmentationStatsHeapTest, Empty) {
  NoFragmentationStatsHeap stats_heap;
  const HeapSimulator::Result result = stats_heap.Finish();
  EXPECT_EQ(0, result.heap_size);
}
819
// Four buffers live simultaneously, so the high-water mark is the sum of all
// their sizes: 10 + 20 + 30 + 30 = 90.
TEST_F(NoFragmentationStatsHeapTest, Simple) {
  NoFragmentationStatsHeap stats_heap;
  stats_heap.Alloc(buffer_a_, 10);
  stats_heap.Alloc(buffer_b_, 20);
  stats_heap.Alloc(buffer_c_, 30);
  stats_heap.Alloc(buffer_d_, 30);
  stats_heap.Free(buffer_a_, 10);
  stats_heap.Free(buffer_b_, 20);
  stats_heap.Free(buffer_c_, 30);
  stats_heap.Free(buffer_d_, 30);
  const HeapSimulator::Result result = stats_heap.Finish();
  EXPECT_EQ(90, result.heap_size);
}
832
// Interleaved alloc/free: the reported heap_size is the maximum number of
// bytes ever simultaneously live (A+C = 40), not the total allocated.
TEST_F(NoFragmentationStatsHeapTest, Mixed) {
  NoFragmentationStatsHeap heap;
  heap.Alloc(buffer_a_, 10);  // max: A

  heap.Alloc(buffer_b_, 20);  // max: A+B
  heap.Free(buffer_b_, 20);

  heap.Alloc(buffer_c_, 30);  // max: A+C
  heap.Free(buffer_c_, 30);

  heap.Alloc(buffer_d_, 5);  // max: A+C
  heap.Free(buffer_d_, 5);

  heap.Free(buffer_a_, 10);
  EXPECT_EQ(40, heap.Finish().heap_size);
}
849
// Fixture exposing the dummy buffers for DecreasingSizeRunsHeap tests.
class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {};
851
// With no allocations, only the Finish call is forwarded to the wrapped
// algorithm (here a HeapCallRecorder, which records the calls it receives).
TEST_F(DecreasingSizeRunsHeapTest, Empty) {
  CallSequence call_sequence;
  DecreasingSizeRunsHeap heap(
      absl::make_unique<HeapCallRecorder>(&call_sequence));
  heap.Finish();
  EXPECT_EQ(call_sequence, CallSequence({
                               {kFinish, nullptr},
                           }));
}
861
// A single run of Allocs followed by a single run of Frees: each run is
// re-ordered before being forwarded to the wrapped algorithm.
TEST_F(DecreasingSizeRunsHeapTest, Simple) {
  CallSequence call_sequence;
  DecreasingSizeRunsHeap heap(
      absl::make_unique<HeapCallRecorder>(&call_sequence));
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 30);
  heap.Alloc(buffer_d_, 30);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 30);
  heap.Free(buffer_d_, 30);
  heap.Finish();
  // Runs of Allocs and Frees are sorted by decreasing size, with buffer id
  // tiebreaker.
  EXPECT_EQ(call_sequence, CallSequence({
                               {kAlloc, buffer_c_},
                               {kAlloc, buffer_d_},
                               {kAlloc, buffer_b_},
                               {kAlloc, buffer_a_},
                               {kFree, buffer_c_},
                               {kFree, buffer_d_},
                               {kFree, buffer_b_},
                               {kFree, buffer_a_},
                               {kFinish, nullptr},
                           }));
}
889
// Alternating Alloc/Free runs: sorting happens only within each run, so the
// relative order of the runs themselves is preserved.
TEST_F(DecreasingSizeRunsHeapTest, Mixed) {
  CallSequence call_sequence;
  DecreasingSizeRunsHeap heap(
      absl::make_unique<HeapCallRecorder>(&call_sequence));
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Free(buffer_b_, 20);

  heap.Alloc(buffer_c_, 30);
  heap.Free(buffer_c_, 30);

  heap.Alloc(buffer_d_, 5);
  heap.Free(buffer_d_, 5);
  heap.Free(buffer_a_, 10);
  heap.Finish();
  // Runs of Allocs and Frees are sorted by decreasing size.
  EXPECT_EQ(call_sequence, CallSequence({
                               {kAlloc, buffer_b_},
                               {kAlloc, buffer_a_},
                               {kFree, buffer_b_},

                               {kAlloc, buffer_c_},
                               {kFree, buffer_c_},

                               {kAlloc, buffer_d_},
                               {kFree, buffer_a_},
                               {kFree, buffer_d_},
                               {kFinish, nullptr},
                           }));
}
920
// Fixture exposing the dummy buffers for LazyBestFitHeap tests.
class LazyBestFitHeapTest : public HeapAlgorithmTestBase {};
922
// Finishing an untouched heap yields a zero-sized result with no chunks.
TEST_F(LazyBestFitHeapTest, Empty) {
  LazyBestFitHeap lazy_heap(/*alignment=*/1);
  const HeapSimulator::Result result = lazy_heap.Finish();
  EXPECT_EQ(0, result.heap_size);
  EXPECT_EQ(0, result.chunk_map.size());
}
929
// All four buffers are simultaneously live, so they get packed back-to-back
// at offsets 0, 10, 30, 60 for a total heap of 90.
TEST_F(LazyBestFitHeapTest, Simple) {
  LazyBestFitHeap heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 30);
  heap.Alloc(buffer_d_, 30);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 30);
  heap.Free(buffer_d_, 30);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(90, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(10, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
}
953
// Exercises lazy offset assignment: offsets are only fixed when a buffer is
// freed, which lets B, C and D all reuse offset 0.
TEST_F(LazyBestFitHeapTest, Mixed) {
  LazyBestFitHeap heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);  // A lazy offset

  heap.Alloc(buffer_b_, 20);  // B lazy offset
  heap.Free(buffer_b_, 20);   // B range = [0, 20)  free = [0, 20)

  heap.Alloc(buffer_c_, 30);  // C range = [0, 30)
  heap.Free(buffer_c_, 30);   // free = [0, 30)

  heap.Alloc(buffer_d_, 5);  // D range = [0, 5)   free = [5, 30)
  heap.Free(buffer_d_, 5);   // free = [0, 30)

  heap.Free(buffer_a_, 10);  // A range = [30, 40) free = [0, 40)

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(40, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(5, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(30, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
}
981
// Verifies the best-fit rule: among several free chunks that could hold a new
// allocation, the smallest sufficient one is chosen.
TEST_F(LazyBestFitHeapTest, BestFit) {
  LazyBestFitHeap heap(/*alignment=*/1);

  // First alloc/free buffer_a_, to force a big free chunk to appear.
  heap.Alloc(buffer_a_, 200);  // A lazy offset
  heap.Free(buffer_a_, 200);   // A range = [0, 200) free = [0, 200)

  // Now alloc a bunch of buffers that are allocated out of the free chunk.
  heap.Alloc(buffer_b_, 30);  // B range = [0, 30)    free = [30, 200)
  heap.Alloc(buffer_c_, 30);  // C range = [30, 60)   free = [60, 200)
  heap.Alloc(buffer_d_, 20);  // D range = [60, 80)   free = [80, 200)
  heap.Alloc(buffer_e_, 20);  // E range = [80, 100)  free = [100, 200)
  heap.Alloc(buffer_f_, 10);  // F range = [100, 110) free = [110, 200)
  heap.Alloc(buffer_g_, 10);  // G range = [110, 120) free = [120, 200)
  heap.Alloc(buffer_h_, 80);  // H range = [120, 200)

  // Free buffers to create free chunks of different sizes.
  heap.Free(buffer_c_, 30);  // free = [30, 60)
  heap.Free(buffer_e_, 20);  // free = [30, 60), [80, 100)
  heap.Free(buffer_g_, 10);  // free = [30, 60), [80, 100), [110, 120)

  // The best fit is picked out of the existing free chunks: 15 bytes fits
  // the 20-byte chunk at [80, 100) better than the 30-byte chunk at [30, 60).
  heap.Alloc(buffer_i_, 15);  // I range = [80, 95)

  // The frees here ensure the buffer-coalescing logic is exercised.
  heap.Free(buffer_b_, 30);
  heap.Free(buffer_d_, 20);
  heap.Free(buffer_f_, 10);
  heap.Free(buffer_h_, 80);
  heap.Free(buffer_i_, 15);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(200, result.heap_size);
  EXPECT_EQ(200, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_d_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_e_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_f_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_g_).size);
  EXPECT_EQ(80, result.chunk_map.at(buffer_h_).size);
  EXPECT_EQ(15, result.chunk_map.at(buffer_i_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
  EXPECT_EQ(80, result.chunk_map.at(buffer_e_).offset);
  EXPECT_EQ(100, result.chunk_map.at(buffer_f_).offset);
  EXPECT_EQ(110, result.chunk_map.at(buffer_g_).offset);
  EXPECT_EQ(120, result.chunk_map.at(buffer_h_).offset);
  EXPECT_EQ(80, result.chunk_map.at(buffer_i_).offset);
}
1035
// Shows the benefit of lazy offset assignment: because A and C only get
// offsets when freed, the coalesced [0, 20) chunk can hold all of D.
TEST_F(LazyBestFitHeapTest, Lazy) {
  LazyBestFitHeap heap(/*alignment=*/1);

  // First alloc some buffers, which are all lazily allocated offsets.
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 5);
  heap.Alloc(buffer_c_, 10);

  // Now free some buffers, which forces offset assignment.
  heap.Free(buffer_a_, 10);  // A range = [0, 10)  free = [0, 10)
  heap.Free(buffer_c_, 10);  // C range = [10, 20) free = [0, 20)

  // If we hadn't lazily assigned offsets, the free chunk wouldn't be large
  // enough to hold the entire allocation.
  heap.Alloc(buffer_d_, 20);  // D range = [0, 20)

  heap.Free(buffer_b_, 5);  // B range = [20, 25)
  heap.Free(buffer_d_, 20);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(25, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(5, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
}
1067
// When no free chunk is large enough, a free chunk that abuts the end of the
// heap is extended (growing the heap) rather than opening a fresh region.
TEST_F(LazyBestFitHeapTest, ReuseLastFreeChunk) {
  LazyBestFitHeap heap(/*alignment=*/1);

  // First alloc/free buffer_a_, to force a big free chunk to appear.
  heap.Alloc(buffer_a_, 60);  // A lazy offset
  heap.Free(buffer_a_, 60);   // A range = [0, 60) free = [0, 60)

  // Now alloc a bunch of buffers that are allocated out of the free chunk.
  heap.Alloc(buffer_b_, 10);  // B range = [0, 10)  free = [10, 60)
  heap.Alloc(buffer_c_, 20);  // C range = [10, 30) free = [30, 60)
  heap.Alloc(buffer_d_, 30);  // D range = [30, 60)

  // Free buffers to create free chunks of different sizes.
  heap.Free(buffer_b_, 10);  // free = [0, 10)
  heap.Free(buffer_d_, 30);  // free = [0, 10), [30, 60)

  // No free chunks are large enough, but the last free chunk is adjacent to
  // the end of the heap, so we re-use that chunk and grow the heap to 70.
  heap.Alloc(buffer_e_, 40);  // E range = [30, 70)

  heap.Free(buffer_c_, 20);
  heap.Free(buffer_e_, 40);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(70, result.heap_size);
  EXPECT_EQ(60, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_e_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_e_).offset);
}
1105
// Verifies that lazily-assigned offsets respect the 64-byte alignment: C and
// B land at 64 and 128 even though smaller gaps exist.
TEST_F(LazyBestFitHeapTest, Alignment) {
  LazyBestFitHeap heap(/*alignment=*/64);

  // First alloc some buffers, which are all lazily allocated offsets.
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 5);
  heap.Alloc(buffer_c_, 10);

  // Now free some buffers, which forces offset assignment with alignment.
  heap.Free(buffer_a_, 10);  // A range = [0, 10)  free = [0, 10)
  heap.Free(buffer_c_, 10);  // C range = [64, 74) free = [0, 74)

  // If we hadn't lazily assigned offsets, and accounted for alignment, the
  // free chunk wouldn't be large enough to hold the entire allocation.
  heap.Alloc(buffer_d_, 74);  // D range = [0, 74) free = [)

  heap.Free(buffer_b_, 5);    // B range = [128, 133) free = [74, 133)
  heap.Alloc(buffer_e_, 23);  // E range = [128, 151) free = [74, 128)

  heap.Free(buffer_d_, 74);  // free = [0, 128)
  heap.Free(buffer_e_, 23);  // free = [0, 151)

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(151, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(5, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(74, result.chunk_map.at(buffer_d_).size);
  EXPECT_EQ(23, result.chunk_map.at(buffer_e_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(128, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(64, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
  EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset);
}
1142
// Fixture exposing the dummy buffers for GlobalDecreasingSizeBestFitHeap
// tests.
class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {};
1144
// Finishing with no allocations yields a zero-sized heap and no chunks.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(0, result.heap_size);
  EXPECT_EQ(0, result.chunk_map.size());
}
1151
// Buffers are placed globally in decreasing-size order (d, b, c, a), packing
// into the layout pictured below.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
  // space
  //   ^
  //   |  +---a---+
  //   |      +-------+
  //   |      +---c---+
  //   |    +-------+
  //   |    |   b   |
  //   |    +-------+
  //   | +-------+
  //   | |       |
  //   | |   d   |
  //   | +-------+
  //   -----------------> time
  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 30);
  heap.Alloc(buffer_c_, 20);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 30);
  heap.Free(buffer_c_, 20);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(100, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
}
1188
// Same decreasing-size placement, but with a 20-byte alignment: a (size 10)
// is rounded up to an aligned offset of 60.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
  // space
  //   ^
  //   |      +-------+
  //   |      +---b---+
  //   |            +-------+
  //   |            |       |
  //   |            |   d   |
  //   |  +---a---+ +-------+
  //   |
  //   |  +-------+
  //   |  |       |
  //   |  |   c   |
  //   |  |       |
  //   |  +-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 50);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 50);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(120, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
}
1227
// With overlapping lifetimes, d (size 30) is best-fit into the gap left above
// c after a is freed, sharing offset 90 with a.
TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
  // space
  //   ^
  //   |    +-------+
  //   |    +---b---+
  //   |         +-------+
  //   |         |   d   |
  //   | +--a--+ +-------+
  //   |      +-------+
  //   |      |       |
  //   |      |   c   |
  //   |      +-------+
  //   | +-------+
  //   | |       |
  //   | |   e   |
  //   | |       |
  //   | +-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 40);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 30);
  heap.Alloc(buffer_e_, 50);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 40);
  heap.Free(buffer_d_, 30);
  heap.Free(buffer_e_, 50);

  const HeapSimulator::Result result = heap.Finish();
  EXPECT_EQ(140, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
  EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size);

  EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset);
}
1272
1273 } // namespace
1274 } // namespace xla
1275