1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/heap_simulator.h"
17
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/buffer_value.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
30 #include "tensorflow/compiler/xla/service/hlo_value.h"
31 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/test.h"
36
37 namespace xla {
38 namespace {
39
// Fixture for tests of HeapSimulator::MinimumMemoryFor{Module,Computation};
// inherits module-building helpers (CreateNewVerifiedModule etc.) from
// HloTestBase.
class MinimumMemoryForSequenceTest : public HloTestBase {};
41
TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
  // Builds a module with a while loop (entry -> while{cond, body}) and checks
  // the minimum memory reported for the whole module under a schedule.
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});

  // While-condition: reads both tuple elements and compares them (iter < data).
  auto cond_builder = HloComputation::Builder("WhileCond");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_data, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  // While-body: identity on the tuple param (no extra instructions).
  auto body_builder = HloComputation::Builder("WhileBody");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  // Entry: tuple up the two scalar params and feed the while loop.
  auto builder = HloComputation::Builder(TestName());
  // Entry params: 8 bytes (4 bytes per param), TOTAL=8
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
  // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
  // While: 8 bytes (4 bytes per element), TOTAL=32
  // Both cond and body use a max of 24 bytes, TOTAL=56
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  // Buffer sizes account for an 8-byte pointer per tuple element.
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_data, cond_lt});
  schedule.set_sequence(body_computation, {body_param});
  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
  TF_ASSERT_OK(schedule.Verify());

  // NOTE(review): the running totals in the comments above (TOTAL=56) do not
  // match the asserted value; presumably buffer aliasing between the while,
  // its init tuple, and the subcomputation params reduces the peak to 25
  // bytes — TODO confirm the exact breakdown.
  EXPECT_EQ(
      25,
      HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
}
101
TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
  // Verifies that MinimumMemoryForComputation folds in the memory consumed by
  // subcomputations (the while cond/body) via `memory_by_computation`, while
  // not double-counting the aliased while output buffer.

  // HloModule SubcomputationAccounting

  // %WhileBody (body_param: f32[4]) -> f32[4] {
  //   %body_param = f32[4]{0} parameter(0)
  //   %constant.1 = f32[4]{0} constant({1, 1, 1, 1})
  //   ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param, f32[4]{0}
  //   %constant.1)
  // }

  // %WhileCond (cond_param: f32[4]) -> pred[] {
  //   %cond_param = f32[4]{0} parameter(0)
  //   %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
  //   %reshape = f32[] reshape(f32[1]{0} %slice)
  //   %constant = f32[] constant(0)
  //   ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant),
  //   direction=NE
  // }

  // ENTRY %SubcomputationAccounting () -> f32[2,4] {
  //   %constant.3 = f32[2,4]{1,0} constant(f32[2,4] { { 1, 2, 3, 4 }, { 1, 2,
  //   3, 4 } }) %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0}
  //   %constant.3), dimensions={0,1} %constant.2 = f32[4]{0} constant({1, 1, 1,
  //   1}) %while = f32[4]{0} while(f32[4]{0} %constant.2),
  //   condition=%WhileCond, body=%WhileBody %broadcast = f32[2,4]{1,0}
  //   broadcast(f32[4]{0} %while), dimensions={1} ROOT %add = f32[2,4]{1,0}
  //   add(f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast)
  // }

  auto module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});

  // While-condition: reshape(slice(param)) != 0
  // Needs 5 bytes (f32[1] slice + f32[] reshape + pred[]; see map below)
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
  HloInstruction* slice =
      cond_builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1}));
  HloInstruction* reshape =
      cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
  HloInstruction* zero = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  HloInstruction* cond_comparison = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape,
                                    zero, ComparisonDirection::kNe));
  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());

  // While-body: param - 1
  // Needs 16 bytes (the f32[4] subtract result; see map below)
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "body_param"));
  HloInstruction* one_vector =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  HloInstruction* subtract =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          r1f32, HloOpcode::kSubtract, body_param, one_vector));
  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());

  // Entry: transpose(matrix) + bcast(while)
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* while_init =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  // Creates 16 bytes, ignoring subcomputations
  HloInstruction* while_loop =
      builder.AddInstruction(HloInstruction::CreateWhile(
          r1f32, cond_computation, body_computation, while_init));

  // Creates 32 bytes and frees 16
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, while_loop, {1}));

  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
          {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
  // Creates 32 bytes
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));

  // Creates 32 bytes and frees 64
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));

  auto entry_computation = module->AddEntryComputation(builder.Build());

  // Schedule every computation explicitly so the simulator's view of liveness
  // is deterministic.
  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> cond_vec = {cond_param, slice, reshape, zero,
                                           cond_comparison};
  std::vector<HloInstruction*> while_body_vec = {body_param, one_vector,
                                                 subtract};
  std::vector<HloInstruction*> entry_comp_vec = {while_init, while_loop, bcast,
                                                 matrix, transpose, add};
  schedule.set_sequence(cond_computation, cond_vec);
  schedule.set_sequence(body_computation, while_body_vec);
  schedule.set_sequence(entry_computation, entry_comp_vec);

  // No pointer size here: plain element byte sizes.
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  // Pre-computed peak memory of each subcomputation, as fed to the simulator.
  absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
  memory_by_computation[cond_computation] = 5;
  memory_by_computation[body_computation] = 16;

  std::unique_ptr<HloAliasAnalysis> alias_analysis =
      HloAliasAnalysis::Run(module.get()).ValueOrDie();

  // HeapSimulator accounts for subcomputations. The output buffer is aliased,
  // so we don't double count.
  EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
                    *entry_computation, schedule.sequence(entry_computation),
                    *alias_analysis, size_fn, &memory_by_computation)
                    .ValueOrDie());
}
221
// Tags recorded by HeapCallRecorder, one per HeapAlgorithm method.
const char kAlloc[] = "Alloc";
const char kFree[] = "Free";
const char kShare[] = "Share";
const char kFinish[] = "Finish";

// CallSequence records a sequence of Alloc/Free/Share/Finish calls, each
// paired with the buffer it was called on (nullptr for Finish).
using CallSequence = std::vector<std::pair<string, const HloValue*>>;
229
230 // HeapCallRecorder is a dummy heap algorithm that simply records its calls.
231 class HeapCallRecorder : public HeapAlgorithm<HloValue> {
232 public:
HeapCallRecorder(CallSequence * calls)233 explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
~HeapCallRecorder()234 ~HeapCallRecorder() override {}
235
Alloc(const HloValue * buffer,int64 size)236 void Alloc(const HloValue* buffer, int64 size) override {
237 calls_->emplace_back(kAlloc, buffer);
238 // Instead of assigning a real offset, we set the cardinality of the Alloc
239 // call. This isn't a valid assignment, but allows us to easily test for
240 // buffer sharing.
241 const int64 offset = result_.chunk_map.size();
242 result_.chunk_map.emplace(buffer, Chunk{offset, size});
243 }
244
ShareWith(const HloValue * buffer,const HloValue * shared,int64 size)245 void ShareWith(const HloValue* buffer, const HloValue* shared,
246 int64 size) override {
247 calls_->emplace_back(kShare, buffer);
248 // Instead of assigning a real offset, we set the cardinality of the Alloc
249 // call. This isn't a valid assignment, but allows us to easily test for
250 // buffer sharing.
251 const int64 offset = result_.chunk_map[shared].offset;
252 result_.chunk_map.emplace(buffer, Chunk{offset, size});
253 }
Free(const HloValue * buffer,int64 size)254 void Free(const HloValue* buffer, int64 size) override {
255 calls_->emplace_back(kFree, buffer);
256 }
Finish()257 Result Finish() override {
258 calls_->emplace_back(kFinish, nullptr);
259 HeapSimulator::Result<HloValue> result;
260 result.heap_size = result_.heap_size;
261 result.heap_results.emplace_back(std::move(result_));
262 return result;
263 }
264
265 private:
266 CallSequence* calls_;
267 HeapSimulator::HeapResult<HloValue> result_;
268 };
269
// HeapSimulatorTracker runs the heap simulator, recording the sequence of calls
// made to the underlying heap algorithm. Tests compare the actual call
// sequence against an expected sequence.
class HeapSimulatorTracker {
 public:
  // Runs the simulation over `module`'s entry computation immediately, using
  // the given flat instruction sequence.
  // NOTE(review): `must_alias_set` is accepted but never read by either
  // constructor — presumably a leftover or reserved parameter; confirm before
  // relying on it.
  explicit HeapSimulatorTracker(
      std::unique_ptr<HloModule> module,
      const std::vector<HloInstruction*>& instruction_sequence,
      const std::vector<HloInstruction*>& must_alias_set = {},
      const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
    module_ = std::move(module);
    Init(instruction_sequence, can_share_buffer);
  }

  // Constructor for testing a single entry computation.
  explicit HeapSimulatorTracker(
      const string& name, std::unique_ptr<HloComputation> entry_computation,
      const std::vector<HloInstruction*>& instruction_sequence,
      const std::vector<HloInstruction*>& must_alias_set = {},
      const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
    HloModuleConfig config;
    module_ = absl::make_unique<HloModule>(name, config);
    module_->AddEntryComputation(std::move(entry_computation));
    Init(instruction_sequence, can_share_buffer);
  }

  // Creates an empty module only; the caller populates it through module()
  // and then invokes RunWholeModule().
  explicit HeapSimulatorTracker(const string& name) {
    HloModuleConfig config;
    module_ = absl::make_unique<HloModule>(name, config);
  }

  // Similar to the single entry computation constructor above, but runs the
  // simulation over the entire module.
  void RunWholeModule(
      const std::vector<HloInstruction*>& full_module_sequence) {
    alias_analysis_ = HloAliasAnalysis::Run(module_.get()).ConsumeValueOrDie();

    // Construct the module sequence grouped by computation.
    HloSchedule schedule(module_.get());
    absl::flat_hash_map<const HloInstruction*, int> reverse_position;
    for (int i = 0; i < full_module_sequence.size(); ++i) {
      HloInstruction* instruction = full_module_sequence[i];
      schedule.GetOrCreateSequence(instruction->parent())
          .push_back(instruction);
      reverse_position[instruction] = full_module_sequence.size() - i;
    }

    // Hack the size_fn so that it returns a decreasing value as we step through
    // the sequence. This lets us ensure the Alloc calls are in the sequence
    // order. The Free calls are sorted by BufferValue.id, which is at least
    // deterministic.
    auto size_fn = [&reverse_position](const BufferValue& buffer) {
      return reverse_position[buffer.instruction()];
    };
    auto algorithm = absl::make_unique<HeapCallRecorder>(&actual_calls_);
    result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
                                 *alias_analysis_, size_fn)
                  .ConsumeValueOrDie();
  }

  HloModule* module() { return module_.get(); }

  // Returns the buffer defined at the given instruction and index.
  const HloValue* BufferAt(const HloInstruction* instruction,
                           const ShapeIndex& index) const {
    return &alias_analysis_->dataflow_analysis().GetUniqueValueAt(instruction,
                                                                  index);
  }

  // Returns the synthetic offset HeapCallRecorder assigned to the buffer
  // defined at (instruction, index). CHECK-fails unless the simulation
  // produced exactly one heap result.
  int64 OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) {
    const HloValue* buffer = BufferAt(instruction, index);
    CHECK_EQ(1, result_.heap_results.size());
    return result_.heap_results.at(0).chunk_map.at(buffer).offset;
  }

  // Ensures the expected sequence of Alloc/Free/Finish calls was performed.
  void ExpectCallSequence(const CallSequence& expected) const {
    // Pretty-printer for a call sequence, used only in the failure message.
    auto to_string = [](const CallSequence& sequence) {
      std::string output;
      for (int64 i = 0; i < sequence.size(); ++i) {
        auto pair = sequence.at(i);
        absl::StrAppendFormat(&output, "%d", i);
        absl::StrAppendFormat(&output, " :%s", pair.first);
        if (pair.second != nullptr) {
          absl::StrAppendFormat(&output, " - %s{%s}\n",
                                pair.second->instruction()->name(),
                                pair.second->index().ToString());
        }
      }
      return output;
    };
    EXPECT_EQ(expected, actual_calls_) << "Expected:\n"
                                       << to_string(expected) << " \nActual:\n"
                                       << to_string(actual_calls_) << "\n";
  }

  // Ensures the buffers defined by the respective (instruction,index) pairs
  // are shared, relying on the unique offsets assigned in
  // HeapCallRecorder::Alloc.
  void ExpectSharedBuffers(const HloInstruction* instruction_a,
                           const ShapeIndex& index_a,
                           const HloInstruction* instruction_b,
                           const ShapeIndex& index_b) {
    int64 offset_a = OffsetAt(instruction_a, index_a);
    int64 offset_b = OffsetAt(instruction_b, index_b);
    EXPECT_EQ(offset_a, offset_b);
  }

 private:
  // Shared tail of the first two constructors: runs alias analysis and the
  // simulator over the entry computation with the given sequence.
  void Init(const std::vector<HloInstruction*>& instruction_sequence,
            const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) {
    // Since we're only tracking the sequence of Alloc/Free calls, the actual
    // size of the buffers doesn't matter, so we always return 0. We rely on
    // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls
    // by buffer id, for determinism in the tests.
    auto zero_size = [](const BufferValue& buffer) { return 0; };
    auto algorithm = absl::make_unique<HeapCallRecorder>(&actual_calls_);

    alias_analysis_ =
        HloAliasAnalysis::Run(module_.get(), can_share_buffer).ValueOrDie();

    HeapSimulator::Options options;

    result_ =
        HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
                           HloInstructionSequence(instruction_sequence),
                           *alias_analysis_, zero_size, options)
            .ConsumeValueOrDie();
  }

  std::unique_ptr<HloModule> module_;
  std::unique_ptr<HloAliasAnalysis> alias_analysis_;
  CallSequence actual_calls_;  // Filled in by HeapCallRecorder.
  HeapSimulator::Result<HloValue> result_;
};
405
406 class HeapSimulatorTest : public HloTestBase {
407 protected:
HeapSimulatorTest()408 HeapSimulatorTest() {}
~HeapSimulatorTest()409 ~HeapSimulatorTest() override {}
410
411 // Shapes for use in the examples.
412 Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {});
413 Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
414 };
415
TEST_F(HeapSimulatorTest, ScalarConstant) {
  auto builder = HloComputation::Builder(TestName());
  auto scalar_const = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  // Constant buffers are never assigned by the simulator (see b/32248867),
  // so the only call recorded is the terminating Finish.
  HeapSimulatorTracker tracker(TestName(), builder.Build(), {scalar_const});
  tracker.ExpectCallSequence({{kFinish, nullptr}});
}
425
TEST_F(HeapSimulatorTest, OneParam) {
  auto builder = HloComputation::Builder(TestName());
  auto lone_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "param0"));

  // The single parameter is also the computation's output, so its buffer is
  // allocated first and freed only at the very end.
  HeapSimulatorTracker tracker(TestName(), builder.Build(), {lone_param});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(lone_param, {})},
      {kFree, tracker.BufferAt(lone_param, {})},
      {kFinish, nullptr},
  });
}
439
TEST_F(HeapSimulatorTest, Multiply) {
  auto builder = HloComputation::Builder(TestName());
  auto scale = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto vec = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto product = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, scale, vec));

  // Parameters and the output must all be kept live for the entire
  // computation, so every Alloc precedes every Free.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {scale, vec, product});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(scale, {})},
      {kAlloc, tracker.BufferAt(vec, {})},
      {kAlloc, tracker.BufferAt(product, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(scale, {})},
      {kFree, tracker.BufferAt(vec, {})},
      {kFree, tracker.BufferAt(product, {})},
      {kFinish, nullptr},
  });
}
463
TEST_F(HeapSimulatorTest, MultiplyAdd) {
  auto builder = HloComputation::Builder(TestName());
  auto alpha = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto x_vec = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto y_vec = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
  auto product = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, alpha, x_vec));
  auto sum = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, product, y_vec));

  // The add produces the output, and its buffer is shared with the multiply
  // result, whose last use is the add itself.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {alpha, x_vec, product, y_vec, sum});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(alpha, {})},
      {kAlloc, tracker.BufferAt(x_vec, {})},
      {kAlloc, tracker.BufferAt(y_vec, {})},
      {kAlloc, tracker.BufferAt(product, {})},
      {kFree, tracker.BufferAt(product, {})},
      {kShare, tracker.BufferAt(sum, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(alpha, {})},
      {kFree, tracker.BufferAt(x_vec, {})},
      {kFree, tracker.BufferAt(y_vec, {})},
      {kFree, tracker.BufferAt(sum, {})},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(sum, {}, product, {});
}
497
TEST_F(HeapSimulatorTest, FusionOutputsOnlyShareOnce) {
  // Test that only one output of a fusion node will be shared with its operand.
  // The callback claims every fusion output may alias any operand, so the
  // simulator itself must limit sharing to a single output.
  auto can_share_buffer =
      [](const HloInstruction* instr, const HloInstruction* operand,
         const ShapeIndex& user_index) -> absl::optional<bool> {
    if (instr->opcode() == HloOpcode::kFusion) {
      return true;
    }
    return false;
  };

  HloModuleConfig config;
  auto module = absl::make_unique<HloModule>(TestName(), config);

  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, paramA));

  // The fusion node has two outputs, both are eligible for being reused with
  // operand.
  auto fusion_builder = HloComputation::Builder("simple_two_way_forwarding");
  {
    auto param = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(0, f32vec4_, "x"));
    fusion_builder.AddInstruction(HloInstruction::CreateTuple({param, param}));
  }
  auto fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {negate}, fusion_computation));

  auto element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, fusion, 0));

  auto element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, fusion, 1));

  auto negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element0));
  auto negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element1));

  builder.AddInstruction(HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd,
                                                      negate0, negate1));

  module->AddEntryComputation(builder.Build());
  HeapSimulatorTracker tracker(
      std::move(module),
      {paramA, negate, fusion, element0, element1, negate0, negate1}, {},
      can_share_buffer);
  // Expected: output {0} Shares negate's buffer, but output {1} gets a fresh
  // Alloc — exactly one share despite both being eligible.
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(negate, {})},
      {kAlloc, tracker.BufferAt(fusion, {})},
      {kFree, tracker.BufferAt(negate, {})},
      {kShare, tracker.BufferAt(fusion, {0})},
      {kAlloc, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(fusion, {})},
      {kAlloc, tracker.BufferAt(negate0, {})},
      {kFree, tracker.BufferAt(fusion, {0})},
      {kFree, tracker.BufferAt(negate0, {})},
      {kAlloc, tracker.BufferAt(negate1, {})},
      {kFree, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(negate1, {})},
      {kFree, tracker.BufferAt(paramA, {})},
      {kFinish, nullptr},
  });
}
570
TEST_F(HeapSimulatorTest, FusionOutputsOnlyShareOnceOutputShortLived) {
  // Test that only one output of a fusion node will be shared with its operand.
  // This variant of the test has a fusion node that dies immediately.
  auto can_share_buffer =
      [](const HloInstruction* instr, const HloInstruction* operand,
         const ShapeIndex& user_index) -> absl::optional<bool> {
    if (instr->opcode() == HloOpcode::kFusion) {
      return true;
    }
    return false;
  };

  HloModuleConfig config;
  auto module = absl::make_unique<HloModule>(TestName(), config);

  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, paramA));

  // The fusion node has two outputs, both are eligible for being reused with
  // operand.
  auto fusion_builder = HloComputation::Builder("simple_two_way_forwarding");
  {
    auto param = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(0, f32vec4_, "x"));
    fusion_builder.AddInstruction(HloInstruction::CreateTuple({param, param}));
  }
  auto fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {negate}, fusion_computation));

  // Only output {1} is consumed; output {0} is dead right after the fusion.
  auto element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, fusion, 1));

  auto negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element1));

  module->AddEntryComputation(builder.Build());
  HeapSimulatorTracker tracker(std::move(module),
                               {paramA, negate, fusion, element1, negate1}, {},
                               can_share_buffer);
  // Even though output {0} dies immediately, it is still the (single) output
  // chosen to share negate's buffer; {1} gets its own Alloc.
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(negate, {})},
      {kFree, tracker.BufferAt(negate, {})},
      {kShare, tracker.BufferAt(fusion, {0})},
      {kAlloc, tracker.BufferAt(fusion, {})},
      {kAlloc, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(fusion, {0})},
      {kFree, tracker.BufferAt(fusion, {})},
      {kAlloc, tracker.BufferAt(negate1, {})},
      {kFree, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(negate1, {})},
      {kFinish, nullptr},
  });
}
633
TEST_F(HeapSimulatorTest, BufferReusedOnce) {
  // Whole-module run: a fusion with two elementwise outputs (exp(A), neg(A))
  // may reuse its operand's buffer for exactly one of the outputs.
  HeapSimulatorTracker tracker(TestName());
  auto builder = HloComputation::Builder(TestName());

  HloComputation::Builder fusion_builder("fusion");
  {
    // Intentionally shadows the outer `builder` so the fusion body reads the
    // same as entry-computation construction code.
    HloComputation::Builder& builder = fusion_builder;
    auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0, f32vec4_, "A"));
    auto exp = builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, a_param));
    auto neg = builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));

    builder.AddInstruction(HloInstruction::CreateTuple({exp, neg}));
  }
  auto fusion_computation =
      tracker.module()->AddEmbeddedComputation(fusion_builder.Build());
  auto a_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {neg}, fusion_computation));
  tracker.module()->AddEntryComputation(builder.Build());

  tracker.RunWholeModule({a_param, neg, fusion});

  auto neg_buffer = tracker.OffsetAt(neg, {});
  int64 output_buffer_0 = tracker.OffsetAt(fusion, {0});
  int64 output_buffer_1 = tracker.OffsetAt(fusion, {1});
  // Only one buffer should be shared: xor ensures exactly one output matches
  // neg's offset, not zero and not both.
  EXPECT_TRUE((neg_buffer == output_buffer_0) ^
              (neg_buffer == output_buffer_1));
}
670
TEST_F(HeapSimulatorTest, MultiplyDot) {
  auto builder = HloComputation::Builder(TestName());
  auto alpha = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto x_vec = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto y_scalar = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto product = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, alpha, x_vec));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot_result = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, product, y_scalar, dot_dnums, DefaultPrecisionConfig(2)));

  // The dot is the output. Since dot is not elementwise, it cannot share the
  // multiply result's buffer: the multiply gets a plain Free, no kShare.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {alpha, x_vec, product, y_scalar, dot_result});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(alpha, {})},
      {kAlloc, tracker.BufferAt(x_vec, {})},
      {kAlloc, tracker.BufferAt(y_scalar, {})},
      {kAlloc, tracker.BufferAt(product, {})},
      {kAlloc, tracker.BufferAt(dot_result, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(product, {})},
      {kFree, tracker.BufferAt(alpha, {})},
      {kFree, tracker.BufferAt(x_vec, {})},
      {kFree, tracker.BufferAt(y_scalar, {})},
      {kFree, tracker.BufferAt(dot_result, {})},
      {kFinish, nullptr},
  });
}
706
TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
  auto builder = HloComputation::Builder(TestName());
  auto alpha = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto x_vec = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto y_scalar = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto product = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, alpha, x_vec));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot_result = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, product, y_scalar, dot_dnums, DefaultPrecisionConfig(2)));
  auto sum = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kAdd, dot_result, alpha));

  // The elementwise add produces the output, and shares the buffer of its
  // dying dot operand.
  HeapSimulatorTracker tracker(
      TestName(), builder.Build(),
      {alpha, x_vec, product, y_scalar, dot_result, sum});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(alpha, {})},
      {kAlloc, tracker.BufferAt(x_vec, {})},
      {kAlloc, tracker.BufferAt(y_scalar, {})},
      {kAlloc, tracker.BufferAt(product, {})},
      {kAlloc, tracker.BufferAt(dot_result, {})},
      {kFree, tracker.BufferAt(product, {})},
      {kFree, tracker.BufferAt(dot_result, {})},
      {kShare, tracker.BufferAt(sum, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(alpha, {})},
      {kFree, tracker.BufferAt(x_vec, {})},
      {kFree, tracker.BufferAt(y_scalar, {})},
      {kFree, tracker.BufferAt(sum, {})},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(sum, {}, dot_result, {});
}
747
TEST_F(HeapSimulatorTest, MultiplyDotDot) {
  // Builds mul = paramA * paramX; dot0 = mul . paramY; dot1 = dot0 . paramY.
  // Verifies that mul's buffer is freed as soon as its last use (dot0) has
  // executed, before dot1 is even allocated.
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));

  // The buffer for dot1 is the output. No buffers can be shared. The buffer
  // for mul is freed before the end, since it's no longer used after dot0
  // finishes.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot0, dot1});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
      {kAlloc, tracker.BufferAt(dot1, {})},
      {kFree, tracker.BufferAt(dot0, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot1, {})},
      {kFinish, nullptr},
  });
}
788
TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
  // Same chain as MultiplyDotDot, but the root is tuple(dot0, dot1), so dot0
  // stays live until the end instead of being freed after dot1 is produced.
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));

  // The buffers for dot0, dot1 and tuple are the output. No buffers can be
  // shared. The buffer for mul is freed before the end, since it's no longer
  // used after dot0 finishes.
  HeapSimulatorTracker tracker(
      TestName(), builder.Build(),
      {paramA, paramX, mul, paramY, dot0, dot1, tuple});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
      {kAlloc, tracker.BufferAt(dot1, {})},
      {kAlloc, tracker.BufferAt(tuple, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(dot1, {})},
      {kFree, tracker.BufferAt(tuple, {})},
      {kFinish, nullptr},
  });
}
834
TEST_F(HeapSimulatorTest, IndependentTupleElements) {
  // Two independent values (mul, add) are packed into a temporary tuple and
  // then extracted again. Verifies that liveness is tracked per tuple
  // element: mul can be freed once element0's use (broadcast) is done, even
  // though element1 (add) is still alive inside the same tuple.
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramB = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32scalar_, "paramB"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kMultiply, paramA, paramB));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kAdd, paramA, paramB));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add}));
  auto element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec4_, element0, {0}));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kSubtract, paramA, paramB));
  auto element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1));
  auto output = builder.AddInstruction(
      HloInstruction::CreateTuple({broadcast, sub, element1}));

  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramB, mul, add, tuple, element0,
                                broadcast, sub, element1, output});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramB, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(add, {})},
      {kAlloc, tracker.BufferAt(tuple, {})},
      {kAlloc, tracker.BufferAt(broadcast, {})},
      // The mul can be freed right after the broadcast happens, even though
      // the other GetTupleElement is still alive.
      {kFree, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(sub, {})},
      // The temporary tuple is now dead.
      {kFree, tracker.BufferAt(tuple, {})},
      {kAlloc, tracker.BufferAt(output, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramB, {})},
      {kFree, tracker.BufferAt(add, {})},
      {kFree, tracker.BufferAt(broadcast, {})},
      {kFree, tracker.BufferAt(sub, {})},
      {kFree, tracker.BufferAt(output, {})},
      {kFinish, nullptr},
  });
}
884
TEST_F(HeapSimulatorTest, WholeModule) {
  // Runs the simulator over an entire module (entry + while cond/body
  // computations) rather than a single computation, via RunWholeModule.
  HeapSimulatorTracker tracker(TestName());

  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});

  // While condition: returns iter < data for the (iter, data) tuple.
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_data, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      tracker.module()->AddEmbeddedComputation(cond_builder.Build());

  // While body: the identity function on the loop state.
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloComputation* body_computation =
      tracker.module()->AddEmbeddedComputation(body_builder.Build());

  // Entry: while(param) with the cond/body computations above.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, param));
  tracker.module()->AddEntryComputation(builder.Build());

  tracker.RunWholeModule(
      {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt});
  tracker.ExpectCallSequence({
      // The entry computation param and while_op are allocated first.
      // NOTE(review): only param's buffers appear below — the while op (and
      // the sub-computation params) presumably alias them; confirm against
      // HeapSimulator's buffer-sharing rules.
      {kAlloc, tracker.BufferAt(param, {})},
      {kAlloc, tracker.BufferAt(param, {0})},
      {kAlloc, tracker.BufferAt(param, {1})},

      // Now the final cond less-than buffer is allocated.
      {kAlloc, tracker.BufferAt(cond_lt, {})},

      // The order of the remaining Free calls is based on the BufferValue.id,
      // which is deterministic, but not obvious.
      {kFree, tracker.BufferAt(cond_lt, {})},
      {kFree, tracker.BufferAt(param, {})},
      {kFree, tracker.BufferAt(param, {0})},
      {kFree, tracker.BufferAt(param, {1})},
      {kFinish, nullptr},
  });
}
938
939 // Base class for heap algorithm tests.
940 class HeapAlgorithmTestBase : public ::testing::Test {
941 protected:
HeapAlgorithmTestBase()942 HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
943 buffer_a_ = DummyBufferValue();
944 buffer_b_ = DummyBufferValue();
945 buffer_c_ = DummyBufferValue();
946 buffer_d_ = DummyBufferValue();
947 buffer_e_ = DummyBufferValue();
948 buffer_f_ = DummyBufferValue();
949 buffer_g_ = DummyBufferValue();
950 buffer_h_ = DummyBufferValue();
951 buffer_i_ = DummyBufferValue();
952 }
~HeapAlgorithmTestBase()953 ~HeapAlgorithmTestBase() override {}
954
955 const HloValue* buffer_a_;
956 const HloValue* buffer_b_;
957 const HloValue* buffer_c_;
958 const HloValue* buffer_d_;
959 const HloValue* buffer_e_;
960 const HloValue* buffer_f_;
961 const HloValue* buffer_g_;
962 const HloValue* buffer_h_;
963 const HloValue* buffer_i_;
964
965 private:
966 // Create a dummy HloValue to pass to the heap algorithm.
DummyBufferValue()967 const HloValue* DummyBufferValue() {
968 const HloValue::Id id = buffers_.size();
969 auto const0 = builder_.AddInstruction(
970 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
971 buffers_.emplace_back(
972 absl::make_unique<HloValue>(id, const0, ShapeIndex{}));
973 return buffers_.back().get();
974 }
975
976 HloComputation::Builder builder_;
977 std::vector<std::unique_ptr<HloValue>> buffers_;
978 };
979
// Tests for NoFragmentationStatsHeap, exercised below purely through its
// Alloc/Free/Finish interface and the reported heap_size.
class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};
981
TEST_F(NoFragmentationStatsHeapTest, Empty) {
  // With no Alloc/Free activity, the reported peak heap size is zero.
  NoFragmentationStatsHeap<HloValue> stats_heap;
  const auto result = stats_heap.Finish();
  EXPECT_EQ(0, result.heap_size);
}
986
TEST_F(NoFragmentationStatsHeapTest, Simple) {
  // All four buffers are live simultaneously, so the peak heap size is the
  // sum of their sizes: 10 + 20 + 30 + 30 = 90.
  NoFragmentationStatsHeap<HloValue> heap;
  const std::pair<const HloValue*, int64> buffers[] = {
      {buffer_a_, 10}, {buffer_b_, 20}, {buffer_c_, 30}, {buffer_d_, 30}};
  for (const auto& entry : buffers) {
    heap.Alloc(entry.first, entry.second);
  }
  for (const auto& entry : buffers) {
    heap.Free(entry.first, entry.second);
  }
  EXPECT_EQ(90, heap.Finish().heap_size);
}
999
TEST_F(NoFragmentationStatsHeapTest, Mixed) {
  // b, c, and d each live only while a does; the peak of 40 is reached while
  // a (10) and c (30) overlap, and the later, smaller d does not raise it.
  NoFragmentationStatsHeap<HloValue> stats;
  stats.Alloc(buffer_a_, 10);  // live: A (10)

  stats.Alloc(buffer_b_, 20);  // live: A+B (30)
  stats.Free(buffer_b_, 20);

  stats.Alloc(buffer_c_, 30);  // live: A+C (40) <- peak
  stats.Free(buffer_c_, 30);

  stats.Alloc(buffer_d_, 5);  // live: A+D (15)
  stats.Free(buffer_d_, 5);

  stats.Free(buffer_a_, 10);
  EXPECT_EQ(40, stats.Finish().heap_size);
}
1016
class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {
 protected:
  // Subclass that exposes GlobalDecreasingSizeBestFitHeap's protected
  // FindChunkCandidate/CommitChunk machinery for direct testing.
  class InheritedGlobalDecreasingSizeBestFitHeap
      : public GlobalDecreasingSizeBestFitHeap<HloValue> {
   public:
    InheritedGlobalDecreasingSizeBestFitHeap()
        : GlobalDecreasingSizeBestFitHeap(/*alignment=*/1) {}

    // Finds a chunk candidate and returns the offset and the new heap size.
    std::pair<int64, int64> FindChunkCandidate(const HloValue* buffer,
                                               int64 size, int64 start,
                                               int64 end,
                                               int64 preferred_offset = -1) {
      // Stash the interval and candidate so a subsequent CommitChunk() call
      // can commit exactly what was found here.
      buffer_interval_.buffer = buffer;
      buffer_interval_.size = size;
      buffer_interval_.start = start;
      buffer_interval_.end = end;
      chunk_candidate_ = GlobalDecreasingSizeBestFitHeap::FindChunkCandidate(
          buffer_interval_, preferred_offset);
      EXPECT_EQ(chunk_candidate_.chunk.size, size);
      return {chunk_candidate_.chunk.offset, chunk_candidate_.heap_size};
    }

    // Commits the previously found chunk candidate.
    void CommitChunk() {
      GlobalDecreasingSizeBestFitHeap::CommitChunk(buffer_interval_,
                                                   chunk_candidate_);
    }

   private:
    // State carried from FindChunkCandidate() to CommitChunk().
    BufferInterval buffer_interval_;
    ChunkCandidate chunk_candidate_;
  };

  InheritedGlobalDecreasingSizeBestFitHeap heap_;
};
1053
TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
  // Finishing a heap with no allocations yields a single, empty heap result.
  GlobalDecreasingSizeBestFitHeap<HloValue> best_fit(/*alignment=*/1);
  const HeapSimulator::Result<HloValue> res = best_fit.Finish();
  EXPECT_EQ(0, res.heap_size);
  EXPECT_EQ(1, res.heap_results.size());
  EXPECT_EQ(0, res.heap_results.at(0).chunk_map.size());
}
1061
TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
  // Four fully-overlapping buffers: the algorithm packs them in decreasing
  // size order (d=40 at offset 0, then b=30, c=20, a=10 stacked above).
  // space
  //   ^
  //   |  +---a---+
  //   |      +-------+
  //   |      +---c---+
  //   |    +-------+
  //   |    |   b   |
  //   |    +-------+
  //   | +-------+
  //   | |       |
  //   | |   d   |
  //   | +-------+
  //   -----------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 30);
  heap.Alloc(buffer_c_, 20);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 30);
  heap.Free(buffer_c_, 20);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(100, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
}
1101
TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
  // Like DecreasingSize, but with alignment 20: a (size 10) lands at the
  // aligned offset 60, and d later reuses that same aligned slot.
  // space
  //   ^
  //   |      +-------+
  //   |      +---b---+
  //   |            +-------+
  //   |            |       |
  //   |            |   d   |
  //   |  +---a---+ +-------+
  //   |
  //   | +-------+
  //   | |       |
  //   | |   c   |
  //   | |       |
  //   | +-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/20);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 50);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 50);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(120, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
}
1143
TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
  // Checks best-fit hole selection: d (30) is placed into the gap above c
  // left by the freed a, rather than growing the heap.
  // space
  //   ^
  //   |            +-------+
  //   |            +---b---+
  //   |            +-------+
  //   |            |   d   |
  //   | +--a--+    +-------+
  //   |      +-------+
  //   |      |       |
  //   |      |   c   |
  //   |      +-------+
  //   |            +-------+
  //   |            |       |
  //   |            |   e   |
  //   |            |       |
  //   |            +-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 40);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 30);
  heap.Alloc(buffer_e_, 50);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 40);
  heap.Free(buffer_d_, 30);
  heap.Free(buffer_e_, 50);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(140, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
  EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size);

  EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset);
}
1191
TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) {
  // ShareWith(c, a) makes c reuse a's chunk; since their lifetimes don't
  // overlap with b's either, everything fits at offset 0 in a 40-byte heap.
  // space      colocate
  //   ^   +--------------+
  //   |   v              v
  //   |+------+      +-------+
  //   ||      |      |       |
  //   ||      |+----+|       |
  //   |+--a---++-b--++---c---+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 40);
  heap.Free(buffer_a_, 40);
  heap.Alloc(buffer_b_, 20);
  heap.Free(buffer_b_, 20);
  heap.ShareWith(buffer_c_, buffer_a_, 40);
  heap.Free(buffer_c_, 40);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(40, result.heap_size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
}
1222
TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
  // c is colocated with a (both at offset 0), but b's lifetime overlaps c's,
  // so b must be stacked above, growing the heap to 60.
  // space
  //   ^       +---------------+
  //   |       +-------b-------+
  //   |+------+      +-------+
  //   ||      |      |       |
  //   ||      |      |       | <--- colocate with a
  //   |+--a---+      +---c---+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 40);
  heap.Free(buffer_a_, 40);
  heap.Alloc(buffer_b_, 20);

  heap.ShareWith(buffer_c_, buffer_a_, 40);
  heap.Free(buffer_c_, 40);
  heap.Free(buffer_b_, 20);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(60, result.heap_size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
}
1254
TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) {
  // Like ColocatedII but the colocated pair (a/c, size 10) is smaller than
  // the overlapping b (size 30): b takes offset 0 and a/c sit above it.
  // space
  //   ^+------+      +-------+
  //   ||      |      |       | <--- colocate with a
  //   |+--a---+      +---c---+
  //   |       +---------------+
  //   |       |               |
  //   |       |               |
  //   |       +-------b-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_b_, 30);

  heap.ShareWith(buffer_c_, buffer_a_, 10);
  heap.Free(buffer_c_, 10);
  heap.Free(buffer_b_, 30);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(40, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);

  EXPECT_EQ(30, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
}
1287
TEST_F(GlobalDecreasingSizeBestFitHeapTest, ChunkCandidate) {
  // Drives FindChunkCandidate/CommitChunk directly through the heap_ fixture
  // member, checking that a preferred offset ("po" in the diagram) is honored
  // when free and falls back to best-fit when occupied.
  // space
  //   ^
  // 35|
  //   |         +-----------+
  //   |         |           |
  // 30|         |           |
  //   |         |  po: 15   |
  //   |         |           |
  // 25|         +-----g-----+
  //   |         +-----+
  //   |         |po:20|
  // 20|         +--f--+
  //   |                 +-----+
  //   |                 |     |
  // 15|                 |     |
  //   |   +-----------------+ |po:10|
  //   |   |                 | |     |
  // 10|   +-------c---------+ +--e--+
  //   |         +-----+  +-----------+
  //   |         |     |  |  po: 5    |
  //  5|         |     |  +-----a-----+
  //   |+-----+  |     |
  //   ||po:10|  |     |
  //  0|+--d--+  +--b--+
  //   -----------------------------------------> time
  //    0  1  2  3  4  5  6  7  8  9 10 11 12 13
  using pair = std::pair<int64, int64>;
  EXPECT_EQ(pair(5, 10), heap_.FindChunkCandidate(buffer_a_, 5, 6, 10, 5));
  heap_.CommitChunk();  // offset: 5, size: 5, start: 6, end: 10
  // Preferred offset 5 is returned.
  EXPECT_EQ(pair(0, 10), heap_.FindChunkCandidate(buffer_b_, 10, 3, 5));
  heap_.CommitChunk();  // offset: 0, size: 10, start: 3, end: 5
  EXPECT_EQ(pair(10, 15), heap_.FindChunkCandidate(buffer_c_, 5, 2, 8));
  heap_.CommitChunk();  // offset: 10, size: 5, start: 2, end: 8
  EXPECT_EQ(pair(0, 15), heap_.FindChunkCandidate(buffer_d_, 5, 0, 2, 10));
  heap_.CommitChunk();  // offset: 0, size: 5, start: 0, end: 2
  // Preferred offset 10 could not be given because it is occupied.
  EXPECT_EQ(pair(10, 20), heap_.FindChunkCandidate(buffer_e_, 10, 11, 13, 10));
  heap_.CommitChunk();  // offset: 10, size: 10, start: 11, end: 13
  // Preferred offset 10 is returned.
  EXPECT_EQ(pair(20, 25), heap_.FindChunkCandidate(buffer_f_, 5, 3, 5, 20));
  heap_.CommitChunk();  // offset: 20, size: 5, start: 3, end: 5
  // Preferred offset 20 is returned.
  EXPECT_EQ(pair(25, 35), heap_.FindChunkCandidate(buffer_g_, 10, 4, 8, 15));
  heap_.CommitChunk();  // offset: 25, size: 10, start: 4, end: 8
  // Preferred offset 15 could not be given because it is occupied.
}
1336
// Tests for ConstrainedGlobalDecreasingSizeBestFitHeap; the tests below
// exercise its size_limit_per_heap parameter, which splits allocations
// across multiple heap results.
class ConstrainedGlobalDecreasingSizeBestFitHeapTest
    : public HeapAlgorithmTestBase {};
1339
TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
  // The 50-byte per-heap limit forces the 100 bytes of buffers into two
  // heaps; only the first heap's placement (a and d) is checked here.
  // space
  //   ^
  //   |      +-------+
  //   |      +---c---+
  //   |    +-------+
  //   |    |   b   |
  //   |    +-------+
  //   | ................ // split into two allocations.
  //   |  +---a---+
  //   | +-------+
  //   | |       |
  //   | |   d   |
  //   | +-------+
  //   -----------------> time
  ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/50,
                                                  /*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 30);
  heap.Alloc(buffer_c_, 20);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 30);
  heap.Free(buffer_c_, 20);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> result = heap.Finish();
  EXPECT_EQ(100, result.heap_size);
  EXPECT_EQ(2, result.heap_results.size());

  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_d_));
  EXPECT_EQ(10, result.heap_results[0].chunk_map.at(buffer_a_).size);
  EXPECT_EQ(40, result.heap_results[0].chunk_map.at(buffer_d_).size);
  EXPECT_EQ(40, result.heap_results[0].chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_d_).offset);
}
1377
TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest,
       DecreasingSizeWithAlignment) {
  // With a 70-byte per-heap limit and alignment 20, c and the aligned a land
  // in the first heap (size 70) and the rest spills into a second (size 60).
  // space
  //   ^
  //   |      +-------+
  //   |      +---b---+
  //   |            +-------+
  //   |            |       |
  //   |            |   d   |
  //   |            +-------+
  //   | ...................
  //   |  +---a---+
  //   |
  //   | +-------+
  //   | |       |
  //   | |   c   |
  //   | |       |
  //   | +-------+
  //   ---------------------> time
  ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/70,
                                                  /*alignment=*/20);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 50);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 50);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> result = heap.Finish();
  EXPECT_EQ(130, result.heap_size);  // 70 + 60
  EXPECT_EQ(2, result.heap_results.size());

  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_c_));
  EXPECT_EQ(10, result.heap_results[0].chunk_map.at(buffer_a_).size);
  EXPECT_EQ(50, result.heap_results[0].chunk_map.at(buffer_c_).size);
  EXPECT_EQ(60, result.heap_results[0].chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_c_).offset);
}
1419
TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
  // The colocated pair a/c shares a chunk in the first heap; the overlapping
  // b is pushed into a second heap by the 50-byte per-heap limit.
  // NOTE(review): c is allocated via ShareWith with size 40 but the first
  // heap reports its chunk size as 30 (a's size) — presumably ShareWith
  // reuses the colocated chunk verbatim; confirm against the algorithm.
  // space
  //   ^
  //   |       +---------------+
  //   |       +-------b-------+
  //   | ....................
  //   |+------+      +-------+
  //   ||      |      |       |
  //   ||      |      |       | <--- colocate with a
  //   |+--a---+      +---c---+
  //   ---------------------> time
  ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/50,
                                                  /*alignment=*/20);
  heap.Alloc(buffer_a_, 30);
  heap.Free(buffer_a_, 30);
  heap.Alloc(buffer_b_, 20);

  heap.ShareWith(buffer_c_, buffer_a_, 40);
  heap.Free(buffer_c_, 40);
  heap.Free(buffer_b_, 20);

  const HeapSimulator::Result<HloValue> result = heap.Finish();
  EXPECT_EQ(50, result.heap_size);
  EXPECT_EQ(2, result.heap_results.size());

  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_c_));
  EXPECT_EQ(30, result.heap_results[0].chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.heap_results[0].chunk_map.at(buffer_c_).size);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_c_).offset);
}
1452
// Tests for BufferIntervalTree, checking Add/Remove and the maintenance of
// each node's subtree_end augmentation.
class IntervalTreeTest : public ::testing::Test {};
1454
TEST_F(IntervalTreeTest, InsertAndRemove) {
  // A single interval can be added and removed; a second Remove of the same
  // interval reports failure. Run the whole cycle twice to check that the
  // tree is reusable after it has been emptied.
  HeapSimulator::Chunk chunk({1, 2});
  BufferIntervalTree tree;
  for (int round = 0; round < 2; ++round) {
    tree.Add(1, 2, chunk);
    EXPECT_TRUE(tree.Remove(1, 2, chunk));
    EXPECT_FALSE(tree.Remove(1, 2, chunk));
    ASSERT_EQ(tree.GetRoot(), nullptr);
  }
}
1468
TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsLeft) {
  HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
  //    [20, 36] (45)
  //      /
  //  [1, 45] (45)

  BufferIntervalTree tree;
  tree.Add(20, 36, chunk);
  tree.Add(1, 45, chunk);
  // Removing the left child shrinks the root's subtree_end from 45 to 36.
  EXPECT_TRUE(tree.Remove(1, 45, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 36);
  EXPECT_TRUE(tree.Remove(20, 36, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}
1483
TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsRight) {
  // Tree shape after the two Adds (subtree_end in parens):
  //    [20, 36] (45)
  //          \
  //      [21, 45] (45)
  // Removing the right child must shrink the root's subtree_end to 36.
  BufferIntervalTree interval_tree;
  HeapSimulator::Chunk payload({1, 2});  // The chunk payload is irrelevant.
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(21, 45, payload);
  EXPECT_TRUE(interval_tree.Remove(21, 45, payload));
  EXPECT_EQ(36, interval_tree.GetRoot()->subtree_end);
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  ASSERT_EQ(nullptr, interval_tree.GetRoot());
}
1497
TEST_F(IntervalTreeTest, TwoLevelsRight_RootFirst) {
  HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
  //    [20, 36] (45)
  //          \
  //      [21, 45] (45)
  BufferIntervalTree tree;
  tree.Add(20, 36, chunk);
  tree.Add(21, 45, chunk);
  // Removing the root first: the right child is promoted to root with no
  // children and subtree_end 45.
  EXPECT_TRUE(tree.Remove(20, 36, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 45);
  EXPECT_EQ(tree.GetRoot()->start, 21);
  EXPECT_EQ(tree.GetRoot()->end, 45);
  EXPECT_EQ(tree.GetRoot()->left, nullptr);
  EXPECT_EQ(tree.GetRoot()->right, nullptr);
  EXPECT_TRUE(tree.Remove(21, 45, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}
1515
TEST_F(IntervalTreeTest, TwoLevelsLeft_RootFirst) {
  HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
  //    [20, 36] (45)
  //      /
  //  [1, 45] (45)
  BufferIntervalTree tree;
  tree.Add(20, 36, chunk);
  tree.Add(1, 45, chunk);
  // Removing the root first: the left child is promoted to root with no
  // children and subtree_end 45.
  EXPECT_TRUE(tree.Remove(20, 36, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 45);
  EXPECT_EQ(tree.GetRoot()->start, 1);
  EXPECT_EQ(tree.GetRoot()->end, 45);
  EXPECT_EQ(tree.GetRoot()->left, nullptr);
  EXPECT_EQ(tree.GetRoot()->right, nullptr);
  EXPECT_TRUE(tree.Remove(1, 45, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}
1533
TEST_F(IntervalTreeTest, ThreeLevelsRight) {
  // Payload is arbitrary; the test only checks subtree_end maintenance.
  HeapSimulator::Chunk payload({1, 2});
  // Three-node right chain (subtree_end in parentheses):
  //   [20, 36] (45)
  //       \
  //     [21, 45] (45)
  //         \
  //       [22, 40] (40)
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(21, 45, payload);
  interval_tree.Add(22, 40, payload);
  // Removing the middle node drops the max end from 45 to 40.
  EXPECT_TRUE(interval_tree.Remove(21, 45, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  // Removing the original root keeps subtree_end at 40.
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  // Removing the last interval empties the tree.
  EXPECT_TRUE(interval_tree.Remove(22, 40, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
TEST_F(IntervalTreeTest, ThreeLevelsLeftLeft) {
  // Payload is arbitrary; the test only checks subtree_end maintenance.
  HeapSimulator::Chunk payload({1, 2});
  // Three-node left chain (subtree_end in parentheses):
  //        [20, 36] (45)
  //       /
  //   [10, 45] (45)
  //     /
  //  [1, 40] (40)
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(10, 45, payload);
  interval_tree.Add(1, 40, payload);
  // Removing the middle node drops the max end from 45 to 40.
  EXPECT_TRUE(interval_tree.Remove(10, 45, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  // Removing the deepest node drops the max end to the root's own 36.
  EXPECT_TRUE(interval_tree.Remove(1, 40, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 36);
  // Removing the last interval empties the tree.
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
1570
TEST_F(IntervalTreeTest, ThreeLevelsLeftRight) {
  // Payload is arbitrary; the test only checks subtree_end maintenance.
  HeapSimulator::Chunk payload({1, 2});
  // Left child with a right grandchild (subtree_end in parentheses):
  //        [20, 36] (45)
  //       /
  //   [10, 45] (45)
  //        \
  //      [15, 40] (40)
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(10, 45, payload);
  interval_tree.Add(15, 40, payload);
  // Removing the middle node drops the max end from 45 to 40.
  EXPECT_TRUE(interval_tree.Remove(10, 45, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  // Removing the grandchild drops the max end to the root's own 36.
  EXPECT_TRUE(interval_tree.Remove(15, 40, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 36);
  // Removing the last interval empties the tree.
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
1589
TEST_F(IntervalTreeTest, ThreeLevelsRightLeft) {
  // Payload is arbitrary; the test only checks subtree_end maintenance.
  HeapSimulator::Chunk payload({1, 2});
  // Right child with a left grandchild (subtree_end in parentheses):
  //   [20, 36] (45)
  //       \
  //     [25, 45] (45)
  //      /
  //   [22, 40] (40)
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, payload);
  interval_tree.Add(25, 45, payload);
  interval_tree.Add(22, 40, payload);
  // Removing the middle node drops the max end from 45 to 40.
  EXPECT_TRUE(interval_tree.Remove(25, 45, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  // Removing the original root keeps subtree_end at 40.
  EXPECT_TRUE(interval_tree.Remove(20, 36, payload));
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  // Removing the last interval empties the tree.
  EXPECT_TRUE(interval_tree.Remove(22, 40, payload));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
1608
TEST_F(IntervalTreeTest, ThreeLevelsRightLeftChunkDifferent) {
  // Distinct payloads so we can verify which node occupies the root slot
  // after each removal.
  HeapSimulator::Chunk chunk_a({1, 2});
  HeapSimulator::Chunk chunk_b({2, 3});
  HeapSimulator::Chunk chunk_c({3, 4});
  // Starting shape (subtree_end in parentheses):
  //   [20, 36] (45) chunk_a({1, 2})
  //       \
  //     [25, 45] (45) chunk_b({2, 3})
  //      /
  //   [22, 40] (40) chunk_c({3, 4})
  BufferIntervalTree interval_tree;
  interval_tree.Add(20, 36, chunk_a);
  interval_tree.Add(25, 45, chunk_b);
  interval_tree.Add(22, 40, chunk_c);
  EXPECT_TRUE(interval_tree.Remove(25, 45, chunk_b));
  // chunk_a is still at the root after chunk_b is removed.
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  EXPECT_EQ(interval_tree.GetRoot()->chunk.offset, 1);
  EXPECT_EQ(interval_tree.GetRoot()->chunk.size, 2);
  EXPECT_TRUE(interval_tree.Remove(20, 36, chunk_a));
  // chunk_c is promoted to the root.
  EXPECT_EQ(interval_tree.GetRoot()->subtree_end, 40);
  EXPECT_EQ(interval_tree.GetRoot()->chunk.offset, 3);
  EXPECT_EQ(interval_tree.GetRoot()->chunk.size, 4);
  // Removing the last interval empties the tree.
  EXPECT_TRUE(interval_tree.Remove(22, 40, chunk_c));
  ASSERT_EQ(interval_tree.GetRoot(), nullptr);
}
1635
1636 } // namespace
1637 } // namespace xla
1638