/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/heap_simulator.h"

#include <memory>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

class MinimumMemoryForSequenceTest : public HloTestBase {};

TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
  auto module = CreateNewVerifiedModule();
  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});

  auto cond_builder = HloComputation::Builder("WhileCond");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_data, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  // Entry params: 8 bytes (4 bytes per param), TOTAL=8
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
  // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
  // While: 8 bytes (4 bytes per element), TOTAL=32
  // Both cond and body use a max of 24 bytes, TOTAL=56
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_data, cond_lt});
  schedule.set_sequence(body_computation, {body_param});
  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
  TF_ASSERT_OK(schedule.Verify());

  EXPECT_EQ(
      25,
      HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
}

TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
  // HloModule SubcomputationAccounting

  // %WhileBody (body_param: f32[4]) -> f32[4] {
  //   %body_param = f32[4]{0} parameter(0)
  //   %constant.1 = f32[4]{0} constant({1, 1, 1, 1})
  //   ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param, f32[4]{0} %constant.1)
  // }

  // %WhileCond (cond_param: f32[4]) -> pred[] {
  //   %cond_param = f32[4]{0} parameter(0)
  //   %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
  //   %reshape = f32[] reshape(f32[1]{0} %slice)
  //   %constant = f32[] constant(0)
  //   ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant), direction=NE
  // }

  // ENTRY %SubcomputationAccounting () -> f32[2,4] {
  //   %constant.3 = f32[2,4]{1,0} constant(f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
  //   %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0} %constant.3), dimensions={0,1}
  //   %constant.2 = f32[4]{0} constant({1, 1, 1, 1})
  //   %while = f32[4]{0} while(f32[4]{0} %constant.2), condition=%WhileCond, body=%WhileBody
  //   %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={1}
  //   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast)
  // }

  auto module = CreateNewVerifiedModule();
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});

  // reshape(slice(param)) != 0
  // Needs 5 bytes
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
  HloInstruction* slice =
      cond_builder.AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1}));
  HloInstruction* reshape =
      cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
  HloInstruction* zero = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
  HloInstruction* cond_comparison = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape,
                                    zero, ComparisonDirection::kNe));
  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());

  // param - 1
  // Needs 16 bytes
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "body_param"));
  HloInstruction* one_vector =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  HloInstruction* subtract =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          r1f32, HloOpcode::kSubtract, body_param, one_vector));
  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());

  // transpose(matrix) + bcast(while)
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* while_init =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  // Creates 16 bytes, ignoring subcomputations
  HloInstruction* while_loop =
      builder.AddInstruction(HloInstruction::CreateWhile(
          r1f32, cond_computation, body_computation, while_init));

  // Creates 32 bytes and frees 16
  HloInstruction* bcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(r2f32, while_loop, {1}));

  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
          {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
  // Creates 32 bytes
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));

  // Creates 32 bytes and frees 64
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));

  auto entry_computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> cond_vec = {cond_param, slice, reshape, zero,
                                           cond_comparison};
  std::vector<HloInstruction*> while_body_vec = {body_param, one_vector,
                                                 subtract};
  std::vector<HloInstruction*> entry_comp_vec = {while_init, while_loop, bcast,
                                                 matrix,     transpose,  add};
  schedule.set_sequence(cond_computation, cond_vec);
  schedule.set_sequence(body_computation, while_body_vec);
  schedule.set_sequence(entry_computation, entry_comp_vec);

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  absl::flat_hash_map<const HloComputation*, int64_t> memory_by_computation;
  memory_by_computation[cond_computation] = 5;
  memory_by_computation[body_computation] = 16;

  std::unique_ptr<HloAliasAnalysis> alias_analysis =
      HloAliasAnalysis::Run(module.get()).ValueOrDie();

  // HeapSimulator accounts for subcomputations. The output buffer is aliased,
  // so we don't double count.
  EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
                    *entry_computation, schedule.sequence(entry_computation),
                    *alias_analysis, size_fn, &memory_by_computation)
                    .ValueOrDie());
}

const char kAlloc[] = "Alloc";
const char kFree[] = "Free";
const char kShare[] = "Share";
const char kFinish[] = "Finish";

// CallSequence records a sequence of Alloc/Free/Share/Finish calls.
using CallSequence = std::vector<std::pair<std::string, const HloValue*>>;

// HeapCallRecorder is a dummy heap algorithm that simply records its calls.
class HeapCallRecorder : public HeapAlgorithm<HloValue> {
 public:
  explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
  ~HeapCallRecorder() override {}

  void Alloc(const HloValue* buffer, int64_t size) override {
    calls_->emplace_back(kAlloc, buffer);
    // Instead of assigning a real offset, we set the cardinality of the Alloc
    // call.  This isn't a valid assignment, but allows us to easily test for
    // buffer sharing.
    const int64_t offset = result_.chunk_map.size();
    result_.chunk_map.emplace(buffer, Chunk{offset, size});
  }

  void ShareWith(const HloValue* buffer, const HloValue* shared,
                 int64_t size) override {
    calls_->emplace_back(kShare, buffer);
    // Instead of assigning a real offset, we set the cardinality of the Alloc
    // call.  This isn't a valid assignment, but allows us to easily test for
    // buffer sharing.
    const int64_t offset = result_.chunk_map[shared].offset;
    result_.chunk_map.emplace(buffer, Chunk{offset, size});
  }
  void Free(const HloValue* buffer, int64_t size) override {
    calls_->emplace_back(kFree, buffer);
  }
  Result Finish() override {
    calls_->emplace_back(kFinish, nullptr);
    HeapSimulator::Result<HloValue> result;
    result.heap_size = result_.heap_size;
    result.heap_results.emplace_back(std::move(result_));
    return result;
  }

 private:
  CallSequence* calls_;
  HeapSimulator::HeapResult<HloValue> result_;
};

// HeapSimulatorTracker runs the heap simulator, recording the sequence of calls
// made to the underlying heap algorithm.  Tests compare the actual call
// sequence against an expected sequence.
class HeapSimulatorTracker {
 public:
  explicit HeapSimulatorTracker(
      std::unique_ptr<HloModule> module,
      const std::vector<HloInstruction*>& instruction_sequence,
      const std::vector<HloInstruction*>& must_alias_set = {},
      const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
    module_ = std::move(module);
    Init(instruction_sequence, can_share_buffer);
  }

  // Constructor for testing a single entry computation.
  explicit HeapSimulatorTracker(
      const std::string& name,
      std::unique_ptr<HloComputation> entry_computation,
      const std::vector<HloInstruction*>& instruction_sequence,
      const std::vector<HloInstruction*>& must_alias_set = {},
      const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
    HloModuleConfig config;
    module_ = std::make_unique<HloModule>(name, config);
    module_->AddEntryComputation(std::move(entry_computation));
    Init(instruction_sequence, can_share_buffer);
  }

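  // Constructor used by tests that populate the module themselves via
  // module() and then drive the simulation with RunWholeModule().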
  explicit HeapSimulatorTracker(const std::string& name) {
    HloModuleConfig config;
    module_ = std::make_unique<HloModule>(name, config);
  }

  // Similar to the single entry computation constructor above, but runs the
  // simulation over the entire module.
  void RunWholeModule(
      const std::vector<HloInstruction*>& full_module_sequence) {
    alias_analysis_ = HloAliasAnalysis::Run(module_.get()).value();

    // Construct the module sequence grouped by computation.
    HloSchedule schedule(module_.get());
    absl::flat_hash_map<const HloInstruction*, int> reverse_position;
    for (int i = 0; i < full_module_sequence.size(); ++i) {
      HloInstruction* instruction = full_module_sequence[i];
      schedule.GetOrCreateSequence(instruction->parent())
          .push_back(instruction);
      reverse_position[instruction] = full_module_sequence.size() - i;
    }

    // Hack the size_fn so that it returns a decreasing value as we step through
    // the sequence. This lets us ensure the Alloc calls are in the sequence
    // order. The Free calls are sorted by BufferValue.id, which is at least
    // deterministic.
    auto size_fn = [&reverse_position](const BufferValue& buffer) {
      return reverse_position[buffer.instruction()];
    };
    auto algorithm = std::make_unique<HeapCallRecorder>(&actual_calls_);
    result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
                                 *alias_analysis_, size_fn)
                  .value();
  }

  HloModule* module() { return module_.get(); }

  // Returns the buffer defined at the given instruction and index.
  const HloValue* BufferAt(const HloInstruction* instruction,
                           const ShapeIndex& index) const {
    return &alias_analysis_->dataflow_analysis().GetUniqueValueAt(instruction,
                                                                  index);
  }

  int64_t OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) {
    const HloValue* buffer = BufferAt(instruction, index);
    CHECK_EQ(1, result_.heap_results.size());
    return result_.heap_results.at(0).chunk_map.at(buffer).offset;
  }

  // Ensures the expected sequence of Alloc/Free/Share/Finish calls was
  // performed.
  void ExpectCallSequence(const CallSequence& expected) const {
    auto to_string = [](const CallSequence& sequence) {
      std::string output;
      for (int64_t i = 0; i < sequence.size(); ++i) {
        auto pair = sequence.at(i);
        absl::StrAppendFormat(&output, "%d", i);
        absl::StrAppendFormat(&output, " :%s", pair.first);
        if (pair.second != nullptr) {
          absl::StrAppendFormat(&output, " - %s{%s}\n",
                                pair.second->instruction()->name(),
                                pair.second->index().ToString());
        }
      }
      return output;
    };
    EXPECT_EQ(expected, actual_calls_) << "Expected:\n"
                                       << to_string(expected) << " \nActual:\n"
                                       << to_string(actual_calls_) << "\n";
  }

  // Ensures the buffers defined by the respective (instruction,index) pairs are
  // shared, relying on the unique offsets assigned in
  // HeapCallRecorder::Alloc.
  void ExpectSharedBuffers(const HloInstruction* instruction_a,
                           const ShapeIndex& index_a,
                           const HloInstruction* instruction_b,
                           const ShapeIndex& index_b) {
    int64_t offset_a = OffsetAt(instruction_a, index_a);
    int64_t offset_b = OffsetAt(instruction_b, index_b);
    EXPECT_EQ(offset_a, offset_b);
  }

 private:
  void Init(const std::vector<HloInstruction*>& instruction_sequence,
            const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) {
    // Since we're only tracking the sequence of Alloc/Free calls, the actual
    // size of the buffers doesn't matter, so we always return 0.  We rely on
    // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls
    // by buffer id, for determinism in the tests.
    auto zero_size = [](const BufferValue& buffer) { return 0; };
    auto algorithm = std::make_unique<HeapCallRecorder>(&actual_calls_);

    alias_analysis_ =
        HloAliasAnalysis::Run(module_.get(), can_share_buffer).ValueOrDie();

    HeapSimulator::Options options;

    result_ =
        HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
                           HloInstructionSequence(instruction_sequence),
                           *alias_analysis_, zero_size, options)
            .value();
  }

  std::unique_ptr<HloModule> module_;
  std::unique_ptr<HloAliasAnalysis> alias_analysis_;
  CallSequence actual_calls_;
  HeapSimulator::Result<HloValue> result_;
};

class HeapSimulatorTest : public HloTestBase {
 protected:
  HeapSimulatorTest() {}
  ~HeapSimulatorTest() override {}

  // Shapes for use in the examples.
  Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {});
  Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
};

TEST_F(HeapSimulatorTest, ScalarConstant) {
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  // Constants aren't assigned.  See b/32248867
  HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0});
  tracker.ExpectCallSequence({{kFinish, nullptr}});
}

TEST_F(HeapSimulatorTest, OneParam) {
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "param0"));

  // A single parameter which is also the output.
  HeapSimulatorTracker tracker(TestName(), builder.Build(), {param0});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(param0, {})},
      {kFree, tracker.BufferAt(param0, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, Multiply) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));

  // We must keep all parameters and outputs.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, MultiplyAdd) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));

  // The buffer for add is the output, and it's shared with the buffer for
  // mul.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, add});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kShare, tracker.BufferAt(add, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(add, {})},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(add, {}, mul, {});
}

TEST_F(HeapSimulatorTest, FusionOutputsOnlyShareOnce) {
  // Test that only one output of a fusion node will be shared with its operand.
  auto can_share_buffer =
      [](const HloInstruction* instr, const HloInstruction* operand,
         const ShapeIndex& user_index) -> std::optional<bool> {
    return instr->opcode() == HloOpcode::kFusion &&
           operand->shape().IsArray() &&
           ShapeUtil::Equal(operand->shape(),
                            ShapeUtil::GetSubshape(instr->shape(), user_index));
  };

  HloModuleConfig config;
  auto module = std::make_unique<HloModule>(TestName(), config);

  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, paramA));

  // The fusion node has two outputs, both are eligible for being reused with
  // operand.
  auto fusion_builder = HloComputation::Builder("simple_two_way_forwarding");
  {
    auto param = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(0, f32vec4_, "x"));
    fusion_builder.AddInstruction(HloInstruction::CreateTuple({param, param}));
  }
  auto fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {negate}, fusion_computation));

  auto element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, fusion, 0));

  auto element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, fusion, 1));

  auto negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element0));
  auto negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element1));

  builder.AddInstruction(HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd,
                                                      negate0, negate1));

  module->AddEntryComputation(builder.Build());
  HeapSimulatorTracker tracker(
      std::move(module),
      {paramA, negate, fusion, element0, element1, negate0, negate1}, {},
      can_share_buffer);
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(negate, {})},
      {kAlloc, tracker.BufferAt(fusion, {})},
      {kFree, tracker.BufferAt(negate, {})},
      {kShare, tracker.BufferAt(fusion, {0})},
      {kAlloc, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(fusion, {})},
      {kAlloc, tracker.BufferAt(negate0, {})},
      {kFree, tracker.BufferAt(fusion, {0})},
      {kFree, tracker.BufferAt(negate0, {})},
      {kAlloc, tracker.BufferAt(negate1, {})},
      {kFree, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(negate1, {})},
      {kFree, tracker.BufferAt(paramA, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, FusionOutputsOnlyShareOnceOutputShortLived) {
  // Test that only one output of a fusion node will be shared with its operand.
  // This variant of the test has a fusion node that dies immediately.
  auto can_share_buffer =
      [](const HloInstruction* instr, const HloInstruction* operand,
         const ShapeIndex& user_index) -> std::optional<bool> {
    if (instr->opcode() == HloOpcode::kFusion) {
      return true;
    }
    return false;
  };

  HloModuleConfig config;
  auto module = std::make_unique<HloModule>(TestName(), config);

  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, paramA));

  // The fusion node has two outputs, both are eligible for being reused with
  // operand.
  auto fusion_builder = HloComputation::Builder("simple_two_way_forwarding");
  {
    auto param = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(0, f32vec4_, "x"));
    fusion_builder.AddInstruction(HloInstruction::CreateTuple({param, param}));
  }
  auto fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {negate}, fusion_computation));

  auto element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, fusion, 1));

  auto negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element1));

  module->AddEntryComputation(builder.Build());
  HeapSimulatorTracker tracker(std::move(module),
                               {paramA, negate, fusion, element1, negate1}, {},
                               can_share_buffer);
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(negate, {})},
      {kFree, tracker.BufferAt(negate, {})},
      {kShare, tracker.BufferAt(fusion, {0})},
      {kAlloc, tracker.BufferAt(fusion, {})},
      {kAlloc, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(fusion, {0})},
      {kFree, tracker.BufferAt(fusion, {})},
      {kAlloc, tracker.BufferAt(negate1, {})},
      {kFree, tracker.BufferAt(fusion, {1})},
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(negate1, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, BufferReusedOnce) {
  HeapSimulatorTracker tracker(TestName());
  auto builder = HloComputation::Builder(TestName());

  HloComputation::Builder fusion_builder("fusion");
  {
    HloComputation::Builder& builder = fusion_builder;
    auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0, f32vec4_, "A"));
    auto exp = builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, a_param));
    auto neg = builder.AddInstruction(
        HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));

    builder.AddInstruction(HloInstruction::CreateTuple({exp, neg}));
  }
  auto fusion_computation =
      tracker.module()->AddEmbeddedComputation(fusion_builder.Build());
  auto a_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
      HloInstruction::FusionKind::kLoop, {neg}, fusion_computation));
  tracker.module()->AddEntryComputation(builder.Build());

  tracker.RunWholeModule({a_param, neg, fusion});

  auto neg_buffer = tracker.OffsetAt(neg, {});
  int64_t output_buffer_0 = tracker.OffsetAt(fusion, {0});
  int64_t output_buffer_1 = tracker.OffsetAt(fusion, {1});
  // Only one buffer should be shared.
  EXPECT_TRUE((neg_buffer == output_buffer_0) ^
              (neg_buffer == output_buffer_1));
}

TEST_F(HeapSimulatorTest, MultiplyDot) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));

  // The buffer for dot is the output, and it cannot be shared with the buffer
  // for mul, since dot isn't elementwise.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(dot, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));

  // The buffer for add is the output, and it's shared with the buffer for
  // dot.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot, add});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(dot, {})},
      {kFree, tracker.BufferAt(mul, {})},
      {kFree, tracker.BufferAt(dot, {})},
      {kShare, tracker.BufferAt(add, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(add, {})},
      {kFinish, nullptr},
  });
  tracker.ExpectSharedBuffers(add, {}, dot, {});
}

TEST_F(HeapSimulatorTest, MultiplyDotDot) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));

  // The buffer for dot1 is the output.  No buffers can be shared.  The buffer
  // for mul is freed before the end, since it's no longer used after dot0
  // finishes.
  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramX, mul, paramY, dot0, dot1});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
      {kAlloc, tracker.BufferAt(dot1, {})},
      {kFree, tracker.BufferAt(dot0, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot1, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramX = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
  auto paramY = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec4_, HloOpcode::kMultiply, paramA, paramX));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));

  // The buffers for dot0, dot1 and tuple are the output.  No buffers can be
  // shared.  The buffer for mul is freed before the end, since it's no longer
  // used after dot0 finishes.
  HeapSimulatorTracker tracker(
      TestName(), builder.Build(),
      {paramA, paramX, mul, paramY, dot0, dot1, tuple});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramX, {})},
      {kAlloc, tracker.BufferAt(paramY, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
      {kAlloc, tracker.BufferAt(dot1, {})},
      {kAlloc, tracker.BufferAt(tuple, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramX, {})},
      {kFree, tracker.BufferAt(paramY, {})},
      {kFree, tracker.BufferAt(dot0, {})},
      {kFree, tracker.BufferAt(dot1, {})},
      {kFree, tracker.BufferAt(tuple, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, IndependentTupleElements) {
  auto builder = HloComputation::Builder(TestName());
  auto paramA = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
  auto paramB = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32scalar_, "paramB"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kMultiply, paramA, paramB));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kAdd, paramA, paramB));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add}));
  auto element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec4_, element0, {0}));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32scalar_, HloOpcode::kSubtract, paramA, paramB));
  auto element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1));
  auto output = builder.AddInstruction(
      HloInstruction::CreateTuple({broadcast, sub, element1}));

  HeapSimulatorTracker tracker(TestName(), builder.Build(),
                               {paramA, paramB, mul, add, tuple, element0,
                                broadcast, sub, element1, output});
  tracker.ExpectCallSequence({
      {kAlloc, tracker.BufferAt(paramA, {})},
      {kAlloc, tracker.BufferAt(paramB, {})},
      {kAlloc, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(add, {})},
      {kAlloc, tracker.BufferAt(tuple, {})},
      {kAlloc, tracker.BufferAt(broadcast, {})},
      // The mul can be freed right after the broadcast happens, even though
      // the other GetTupleElement is still alive.
      {kFree, tracker.BufferAt(mul, {})},
      {kAlloc, tracker.BufferAt(sub, {})},
      // The temporary tuple is now dead.
      {kFree, tracker.BufferAt(tuple, {})},
      {kAlloc, tracker.BufferAt(output, {})},
      // All params and outputs are freed at the end.
      {kFree, tracker.BufferAt(paramA, {})},
      {kFree, tracker.BufferAt(paramB, {})},
      {kFree, tracker.BufferAt(add, {})},
      {kFree, tracker.BufferAt(broadcast, {})},
      {kFree, tracker.BufferAt(sub, {})},
      {kFree, tracker.BufferAt(output, {})},
      {kFinish, nullptr},
  });
}

TEST_F(HeapSimulatorTest, WholeModule) {
  HeapSimulatorTracker tracker(TestName());

  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});

  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
  HloInstruction* cond_data = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_data, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      tracker.module()->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloComputation* body_computation =
      tracker.module()->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, param));
  tracker.module()->AddEntryComputation(builder.Build());

  tracker.RunWholeModule(
      {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt});
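  // Note: the while op and the cond/body parameters alias the entry param's
  // buffers, so only the param buffers and cond_lt appear in the expected
  // sequence below.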
  tracker.ExpectCallSequence({
      // The entry computation param and while_op are allocated first.
      {kAlloc, tracker.BufferAt(param, {})},
      {kAlloc, tracker.BufferAt(param, {0})},
      {kAlloc, tracker.BufferAt(param, {1})},

      // Now the final cond less-than buffer is allocated.
      {kAlloc, tracker.BufferAt(cond_lt, {})},

      // The order of the remaining Free calls is based on the BufferValue.id,
      // which is deterministic, but not obvious.
      {kFree, tracker.BufferAt(cond_lt, {})},
      {kFree, tracker.BufferAt(param, {})},
      {kFree, tracker.BufferAt(param, {0})},
      {kFree, tracker.BufferAt(param, {1})},
      {kFinish, nullptr},
  });
}

// Base class for heap algorithm tests.
class HeapAlgorithmTestBase : public ::testing::Test {
 protected:
  HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
    buffer_a_ = DummyBufferValue();
    buffer_b_ = DummyBufferValue();
    buffer_c_ = DummyBufferValue();
    buffer_d_ = DummyBufferValue();
    buffer_e_ = DummyBufferValue();
    buffer_f_ = DummyBufferValue();
    buffer_g_ = DummyBufferValue();
    buffer_h_ = DummyBufferValue();
    buffer_i_ = DummyBufferValue();
  }
  ~HeapAlgorithmTestBase() override {}

  const HloValue* buffer_a_;
  const HloValue* buffer_b_;
  const HloValue* buffer_c_;
  const HloValue* buffer_d_;
  const HloValue* buffer_e_;
  const HloValue* buffer_f_;
  const HloValue* buffer_g_;
  const HloValue* buffer_h_;
  const HloValue* buffer_i_;

 private:
  // Create a dummy HloValue to pass to the heap algorithm.
  const HloValue* DummyBufferValue() {
    const HloValue::Id id = buffers_.size();
    auto const0 = builder_.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
    buffers_.emplace_back(std::make_unique<HloValue>(id, const0, ShapeIndex{}));
    return buffers_.back().get();
  }

  HloComputation::Builder builder_;
  std::vector<std::unique_ptr<HloValue>> buffers_;
};

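// NoFragmentationStatsHeap is expected to track only the peak total size of
// live buffers; the tests below check exactly that via Finish().heap_size.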
class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};

TEST_F(NoFragmentationStatsHeapTest, Empty) {
  NoFragmentationStatsHeap<HloValue> heap;
  EXPECT_EQ(0, heap.Finish().heap_size);
}

TEST_F(NoFragmentationStatsHeapTest, Simple) {
  NoFragmentationStatsHeap<HloValue> heap;
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 30);
  heap.Alloc(buffer_d_, 30);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 30);
  heap.Free(buffer_d_, 30);
  EXPECT_EQ(90, heap.Finish().heap_size);
}

TEST_F(NoFragmentationStatsHeapTest, Mixed) {
  NoFragmentationStatsHeap<HloValue> heap;
  heap.Alloc(buffer_a_, 10);  // max: A

  heap.Alloc(buffer_b_, 20);  // max: A+B
  heap.Free(buffer_b_, 20);

  heap.Alloc(buffer_c_, 30);  // max: A+C
  heap.Free(buffer_c_, 30);

  heap.Alloc(buffer_d_, 5);  // max: A+C
  heap.Free(buffer_d_, 5);

  heap.Free(buffer_a_, 10);
  EXPECT_EQ(40, heap.Finish().heap_size);
}

class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {
 protected:
  class InheritedGlobalDecreasingSizeBestFitHeap
      : public GlobalDecreasingSizeBestFitHeap<HloValue> {
   public:
    InheritedGlobalDecreasingSizeBestFitHeap()
        : GlobalDecreasingSizeBestFitHeap(/*alignment=*/1) {}

    // Finds a chunk candidate and returns the offset and the new heap size.
    std::pair<int64_t, int64_t> FindChunkCandidate(
        const HloValue* buffer, int64_t size, int64_t start, int64_t end,
        int64_t preferred_offset = -1) {
      buffer_interval_.buffer = buffer;
      buffer_interval_.size = size;
      buffer_interval_.start = start;
      buffer_interval_.end = end;
      chunk_candidate_ = GlobalDecreasingSizeBestFitHeap::FindChunkCandidate(
          buffer_interval_, preferred_offset);
      EXPECT_EQ(chunk_candidate_.size, size);
      return {chunk_candidate_.offset,
              result_.UpdatedHeapSize(chunk_candidate_)};
    }

    // Commits the previously found chunk candidate.
    void CommitChunk() {
      GlobalDecreasingSizeBestFitHeap::CommitChunk(buffer_interval_,
                                                   chunk_candidate_);
    }

   private:
    BufferInterval buffer_interval_;
    Chunk chunk_candidate_;
  };

  InheritedGlobalDecreasingSizeBestFitHeap heap_;
};

TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  const HeapSimulator::Result<HloValue> result = heap.Finish();
  EXPECT_EQ(0, result.heap_size);
  EXPECT_EQ(1, result.heap_results.size());
  EXPECT_EQ(0, result.heap_results.at(0).chunk_map.size());
}

TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
  // space
  //   ^
  //   |  +---a---+
  //   |      +-------+
  //   |      +---c---+
  //   |    +-------+
  //   |    |   b   |
  //   |    +-------+
  //   |         +-------+
  //   |         |       |
  //   |         |   d   |
  //   |         +-------+
  //   -----------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 30);
  heap.Alloc(buffer_c_, 20);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 30);
  heap.Free(buffer_c_, 20);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(100, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
}

TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
  // space
  //   ^
  //   |      +-------+
  //   |      +---b---+
  //   |            +-------+
  //   |            |       |
  //   |            |   d   |
  //   |  +---a---+ +-------+
  //   |
  //   |         +-------+
  //   |         |       |
  //   |         |   c   |
  //   |         |       |
  //   |         +-------+
  //   ---------------------> time
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/20);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 50);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 50);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(120, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);

  EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
  EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
}

TEST_F(GlobalDecreasingSizeBestFitHeapTest,BestFit)1143 TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
1144   // space
1145   //   ^
1146   //   |    +-------+
1147   //   |    +---b---+
1148   //   |         +-------+
1149   //   |         |   d   |
1150   //   | +--a--+ +-------+
1151   //   |      +-------+
1152   //   |      |       |
1153   //   |      |   c   |
1154   //   |      +-------+
1155   //   |           +-------+
1156   //   |           |       |
1157   //   |           |   e   |
1158   //   |           |       |
1159   //   |           +-------+
1160   //   ---------------------> time
1161   GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
1162   heap.Alloc(buffer_a_, 10);
1163   heap.Alloc(buffer_b_, 20);
1164   heap.Alloc(buffer_c_, 40);
1165   heap.Free(buffer_a_, 10);
1166   heap.Alloc(buffer_d_, 30);
1167   heap.Alloc(buffer_e_, 50);
1168   heap.Free(buffer_b_, 20);
1169   heap.Free(buffer_c_, 40);
1170   heap.Free(buffer_d_, 30);
1171   heap.Free(buffer_e_, 50);
1172 
1173   const HeapSimulator::Result<HloValue> results = heap.Finish();
1174   EXPECT_EQ(1, results.heap_results.size());
1175   const HeapSimulator::HeapResult<HloValue>& result =
1176       results.heap_results.at(0);
1177   EXPECT_EQ(140, result.heap_size);
1178   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
1179   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
1180   EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
1181   EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
1182   EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size);
1183 
1184   EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
1185   EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset);
1186   EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset);
1187   EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset);
1188   EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset);
1189 }
1190 
TEST_F(GlobalDecreasingSizeBestFitHeapTest,Colocated)1191 TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) {
1192   // space      colocate
1193   //   ^   +--------------+
1194   //   |   v              v
1195   //   |+------+      +-------+
1196   //   ||      |      |       |
1197   //   ||      |+----+|       |
1198   //   |+--a---++-b--++---c---+
1199   //   ---------------------> time
1200   GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
1201   heap.Alloc(buffer_a_, 40);
1202   heap.Free(buffer_a_, 40);
1203   heap.Alloc(buffer_b_, 20);
1204   heap.Free(buffer_b_, 20);
1205   heap.ShareWith(buffer_c_, buffer_a_, 40);
1206   heap.Free(buffer_c_, 40);
1207 
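  // a is freed before b is allocated and b before c, so all three buffers can
  // start at offset 0; c shares a's 40-byte chunk via ShareWith, and the peak
  // heap usage is 40 bytes.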
1208   const HeapSimulator::Result<HloValue> results = heap.Finish();
1209   EXPECT_EQ(1, results.heap_results.size());
1210   const HeapSimulator::HeapResult<HloValue>& result =
1211       results.heap_results.at(0);
1212   EXPECT_EQ(40, result.heap_size);
1213   EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
1214   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
1215   EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
1216 
1217   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
1218   EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
1219   EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
1220 }
1221 
1222 TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
1223   // space
1224   //   ^       +---------------+
1225   //   |       +-------b-------+
1226   //   |+------+      +-------+
1227   //   ||      |      |       |
1228   //   ||      |      |       | <--- colocate with a
1229   //   |+--a---+      +---c---+
1230   //   ---------------------> time
1231   GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
1232   heap.Alloc(buffer_a_, 40);
1233   heap.Free(buffer_a_, 40);
1234   heap.Alloc(buffer_b_, 20);
1235 
1236   heap.ShareWith(buffer_c_, buffer_a_, 40);
1237   heap.Free(buffer_c_, 40);
1238   heap.Free(buffer_b_, 20);
1239 
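  // c shares a's chunk at offset 0, but b is live while c is, so b is placed
  // above the shared 40-byte chunk at offset 40, giving a 60-byte heap.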
1240   const HeapSimulator::Result<HloValue> results = heap.Finish();
1241   EXPECT_EQ(1, results.heap_results.size());
1242   const HeapSimulator::HeapResult<HloValue>& result =
1243       results.heap_results.at(0);
1244   EXPECT_EQ(60, result.heap_size);
1245   EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
1246   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
1247   EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
1248 
1249   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
1250   EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
1251   EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
1252 }
1253 
1254 TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) {
1255   // space
1256   //   ^+------+      +-------+
1257   //   ||      |      |       | <--- colocate with a
1258   //   |+--a---+      +---c---+
1259   //   |       +---------------+
1260   //   |       |               |
1261   //   |       |               |
1262   //   |       +-------b-------+
1263   //   ---------------------> time
1264   GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
1265   heap.Alloc(buffer_a_, 10);
1266   heap.Free(buffer_a_, 10);
1267   heap.Alloc(buffer_b_, 30);
1268 
1269   heap.ShareWith(buffer_c_, buffer_a_, 10);
1270   heap.Free(buffer_c_, 10);
1271   heap.Free(buffer_b_, 30);
1272 
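  // The larger buffer b (30 bytes) takes offset 0; a and its colocated partner
  // c share the 10-byte chunk above it at offset 30, giving a 40-byte heap.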
1273   const HeapSimulator::Result<HloValue> results = heap.Finish();
1274   EXPECT_EQ(1, results.heap_results.size());
1275   const HeapSimulator::HeapResult<HloValue>& result =
1276       results.heap_results.at(0);
1277   EXPECT_EQ(40, result.heap_size);
1278   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
1279   EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
1280   EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);
1281 
1282   EXPECT_EQ(30, result.chunk_map.at(buffer_a_).offset);
1283   EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
1284   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
1285 }
1286 
1287 TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedDifferentSize1) {
1288   // space
1289   //   ^
1290   //   |         +---------------+
1291   //   |+------+ +-------b-------+
1292   //   ||      |      +-------+
1293   //   ||      |      |       | <--- colocate with a
1294   //   |+--a---+      +---c---+
1295   //   ---------------------> time
1296   GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
1297   heap.Alloc(buffer_a_, 40);
1298   heap.Free(buffer_a_, 40);
1299   heap.Alloc(buffer_b_, 20);
1300 
1301   heap.ShareWith(buffer_c_, buffer_a_, 30);
1302   heap.Free(buffer_c_, 30);
1303   heap.Free(buffer_b_, 20);
1304 
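  // c shares a's chunk at offset 0 but only needs 30 bytes, so b (live while
  // c is) fits at offset 30 and the heap is only 50 bytes.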
1305   const HeapSimulator::Result<HloValue> results = heap.Finish();
1306   EXPECT_EQ(1, results.heap_results.size());
1307   const HeapSimulator::HeapResult<HloValue>& result =
1308       results.heap_results.at(0);
1309   EXPECT_EQ(50, result.heap_size);
1310   EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
1311   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
1312   EXPECT_EQ(30, result.chunk_map.at(buffer_c_).size);
1313 
1314   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
1315   EXPECT_EQ(30, result.chunk_map.at(buffer_b_).offset);
1316   EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
1317 }
1318 
1319 TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedDifferentSize2) {
1320   // space
1321   //   ^         +-------------+
1322   //   |         +-----b-------+
1323   //   |              +-------+
1324   //   |+------+      |       |
1325   //   ||      |      |       |
1326   //   ||      |      |       | <--- colocate with a
1327   //   |+--a---+      +---c---+
1328   //   ---------------------> time
1329   GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
1330   heap.Alloc(buffer_a_, 40);
1331   heap.Free(buffer_a_, 40);
1332   heap.Alloc(buffer_b_, 20);
1333 
1334   heap.ShareWith(buffer_c_, buffer_a_, 50);
1335   heap.Free(buffer_c_, 50);
1336   heap.Free(buffer_b_, 20);
1337 
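  // Here the colocated c needs 50 bytes, more than a's 40, so the shared chunk
  // is sized for c; b goes above it at offset 50 and the heap grows to 70.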
1338   const HeapSimulator::Result<HloValue> results = heap.Finish();
1339   EXPECT_EQ(1, results.heap_results.size());
1340   const HeapSimulator::HeapResult<HloValue>& result =
1341       results.heap_results.at(0);
1342   EXPECT_EQ(70, result.heap_size);
1343   EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
1344   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
1345   EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);
1346 
1347   EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
1348   EXPECT_EQ(50, result.chunk_map.at(buffer_b_).offset);
1349   EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
1350 }
1351 
1352 TEST_F(GlobalDecreasingSizeBestFitHeapTest, ChunkCandidate) {
1353   // space
1354   //   ^
1355   // 35|
1356   //   |            +-----------+
1357   //   |            |           |
1358   // 30|            |           |
1359   //   |            |  po: 15   |
1360   //   |            |           |
1361   // 25|            +-----g-----+
1362   //   |         +-----+
1363   //   |         |po:20|
1364   // 20|         +--f--+
1365   //   |                                +-----+
1366   //   |                                |     |
1367   // 15|                                |     |
1368   //   |      +-----------------+       |po:10|
1369   //   |      |                 |       |     |
1370   // 10|      +-------c---------+       +--e--+
1371   //   |         +-----+  +-----------+
1372   //   |         |     |  |   po: 5   |
1373   //  5|         |     |  +-----a-----+
1374   //   |+-----+  |     |
1375   //   ||po:10|  |     |
1376   //  0|+--d--+  +--b--+
1377   //   -----------------------------------------> time
1378   //    0  1  2  3  4  5  6  7  8  9 10 11 12 13
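  // FindChunkCandidate(buffer, size, start, end[, preferred_offset]) returns a
  // (chosen offset, resulting heap size) pair; CommitChunk() then reserves the
  // candidate just found, as the per-call comments below record.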
1379   using pair = std::pair<int64_t, int64_t>;
1380   EXPECT_EQ(pair(5, 10), heap_.FindChunkCandidate(buffer_a_, 5, 6, 10, 5));
1381   heap_.CommitChunk();  // offset: 5, size: 5, start: 6, end: 10
1382   // Preferred offset 5 is returned.
1383   EXPECT_EQ(pair(0, 10), heap_.FindChunkCandidate(buffer_b_, 10, 3, 5));
1384   heap_.CommitChunk();  // offset: 0, size: 10, start: 3, end: 5
1385   EXPECT_EQ(pair(10, 15), heap_.FindChunkCandidate(buffer_c_, 5, 2, 8));
1386   heap_.CommitChunk();  // offset: 10, size: 5, start: 2, end: 8
1387   EXPECT_EQ(pair(0, 15), heap_.FindChunkCandidate(buffer_d_, 5, 0, 2, 10));
1388   heap_.CommitChunk();  // offset: 0, size: 5, start: 0, end: 2
1389   // Preferred offset 10 could not be given because it is occupied.
1390   EXPECT_EQ(pair(10, 20), heap_.FindChunkCandidate(buffer_e_, 10, 11, 13, 10));
1391   heap_.CommitChunk();  // offset: 10, size: 10, start: 11, end: 13
1392   // Preferred offset 10 is returned.
1393   EXPECT_EQ(pair(20, 25), heap_.FindChunkCandidate(buffer_f_, 5, 3, 5, 20));
1394   heap_.CommitChunk();  // offset: 20, size: 5, start: 3, end: 5
1395   // Preferred offset 20 is returned.
1396   EXPECT_EQ(pair(25, 35), heap_.FindChunkCandidate(buffer_g_, 10, 4, 8, 15));
1397   heap_.CommitChunk();  // offset: 25, size: 10, start: 4, end: 8
1398   // Preferred offset 15 could not be given because it is occupied.
1399 }
1400 
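// ConstrainedGlobalDecreasingSizeBestFitHeap enforces size_limit_per_heap by
// splitting the assignment into multiple heap_results, each no larger than the
// limit; Result::heap_size is the sum of the individual heap sizes.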
1401 class ConstrainedGlobalDecreasingSizeBestFitHeapTest
1402     : public HeapAlgorithmTestBase {};
1403 
1404 TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
1405   // space
1406   //   ^
1407   //   |      +-------+
1408   //   |      +---c---+
1409   //   |    +-------+
1410   //   |    |   b   |
1411   //   |    +-------+
1412   //   | ................ // split into two allocations.
1413   //   |  +---a---+
1414   //   |         +-------+
1415   //   |         |       |
1416   //   |         |   d   |
1417   //   |         +-------+
1418   //   -----------------> time
1419   ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/50,
1420                                                   /*alignment=*/1);
1421   heap.Alloc(buffer_a_, 10);
1422   heap.Alloc(buffer_b_, 30);
1423   heap.Alloc(buffer_c_, 20);
1424   heap.Alloc(buffer_d_, 40);
1425   heap.Free(buffer_a_, 10);
1426   heap.Free(buffer_b_, 30);
1427   heap.Free(buffer_c_, 20);
1428   heap.Free(buffer_d_, 40);
1429 
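  // The 50-byte per-heap limit forces a split: d (40) and a (10) exactly fill
  // the first heap, while b (30) and c (20) go to a second 50-byte heap, so
  // the combined heap_size is 100.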
1430   const HeapSimulator::Result<HloValue> result = heap.Finish();
1431   EXPECT_EQ(100, result.heap_size);
1432   EXPECT_EQ(2, result.heap_results.size());
1433 
1434   EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
1435   EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_d_));
1436   EXPECT_EQ(10, result.heap_results[0].chunk_map.at(buffer_a_).size);
1437   EXPECT_EQ(40, result.heap_results[0].chunk_map.at(buffer_d_).size);
1438   EXPECT_EQ(40, result.heap_results[0].chunk_map.at(buffer_a_).offset);
1439   EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_d_).offset);
1440 }
1441 
1442 TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest,
1443        DecreasingSizeWithAlignment) {
1444   // space
1445   //   ^
1446   //   |      +-------+
1447   //   |      +---b---+
1448   //   |            +-------+
1449   //   |            |       |
1450   //   |            |   d   |
1451   //   |            +-------+
1452   //   | ...................
1453   //   |  +---a---+
1454   //   |
1455   //   |         +-------+
1456   //   |         |       |
1457   //   |         |   c   |
1458   //   |         |       |
1459   //   |         +-------+
1460   //   ---------------------> time
1461   ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/70,
1462                                                   /*alignment=*/20);
1463   heap.Alloc(buffer_a_, 10);
1464   heap.Alloc(buffer_b_, 20);
1465   heap.Alloc(buffer_c_, 50);
1466   heap.Free(buffer_a_, 10);
1467   heap.Alloc(buffer_d_, 40);
1468   heap.Free(buffer_b_, 20);
1469   heap.Free(buffer_c_, 50);
1470   heap.Free(buffer_d_, 40);
1471 
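  // With a 70-byte limit and 20-byte alignment, c ([0, 50)) and a (at aligned
  // offset 60) make up the 70-byte first heap; b and d form a second 60-byte
  // heap, for a total of 130.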
1472   const HeapSimulator::Result<HloValue> result = heap.Finish();
1473   EXPECT_EQ(130, result.heap_size);  // 70 + 60
1474   EXPECT_EQ(2, result.heap_results.size());
1475 
1476   EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
1477   EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_c_));
1478   EXPECT_EQ(10, result.heap_results[0].chunk_map.at(buffer_a_).size);
1479   EXPECT_EQ(50, result.heap_results[0].chunk_map.at(buffer_c_).size);
1480   EXPECT_EQ(60, result.heap_results[0].chunk_map.at(buffer_a_).offset);
1481   EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_c_).offset);
1482 }
1483 
1484 TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
1485   // space
1486   //   ^
1487   //   |       +---------------+
1488   //   |       +-------b-------+
1489   //   | ....................
1490   //   |+------+      +-------+
1491   //   ||      |      |       |
1492   //   ||      |      |       | <--- colocate with a
1493   //   |+--a---+      +---c---+
1494   //   ---------------------> time
1495   ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/50,
1496                                                   /*alignment=*/20);
1497   heap.Alloc(buffer_a_, 30);
1498   heap.Free(buffer_a_, 30);
1499   heap.Alloc(buffer_b_, 20);
1500 
1501   heap.ShareWith(buffer_c_, buffer_a_, 40);
1502   heap.Free(buffer_c_, 40);
1503   heap.Free(buffer_b_, 20);
1504 
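  // The colocated pair a/c stays together in the first heap, which is sized
  // for c's 40 bytes; b spills into a second 20-byte heap, so the total is 60.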
1505   const HeapSimulator::Result<HloValue> result = heap.Finish();
1506   EXPECT_EQ(60, result.heap_size);  // 40 + 20
1507   EXPECT_EQ(2, result.heap_results.size());
1508 
1509   EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
1510   EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_c_));
1511   EXPECT_EQ(30, result.heap_results[0].chunk_map.at(buffer_a_).size);
1512   EXPECT_EQ(40, result.heap_results[0].chunk_map.at(buffer_c_).size);
1513   EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_a_).offset);
1514   EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_c_).offset);
1515 }
1516 
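// The tree diagrams in the tests below use the notation [start, end]
// (subtree_end), so the parenthesized value is what GetRoot()->subtree_end is
// expected to report for that node.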
1517 class IntervalTreeTest : public ::testing::Test {};
1518 
1519 TEST_F(IntervalTreeTest, InsertAndRemove) {
1520   HeapSimulator::Chunk chunk({1, 2});
1521   BufferIntervalTree tree;
1522   tree.Add(1, 2, chunk);
1523   EXPECT_TRUE(tree.Remove(1, 2, chunk));
1524   EXPECT_FALSE(tree.Remove(1, 2, chunk));
1525   ASSERT_EQ(tree.GetRoot(), nullptr);
1526   // Do it again.
1527   tree.Add(1, 2, chunk);
1528   EXPECT_TRUE(tree.Remove(1, 2, chunk));
1529   EXPECT_FALSE(tree.Remove(1, 2, chunk));
1530   ASSERT_EQ(tree.GetRoot(), nullptr);
1531 }
1532 
1533 TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsLeft) {
1534   HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
1535   //    [20, 36] (45)
1536   //     /
1537   //  [1, 45] (45)
1538 
1539   BufferIntervalTree tree;
1540   tree.Add(20, 36, chunk);
1541   tree.Add(1, 45, chunk);
1542   EXPECT_TRUE(tree.Remove(1, 45, chunk));
1543   EXPECT_EQ(tree.GetRoot()->subtree_end, 36);
1544   EXPECT_TRUE(tree.Remove(20, 36, chunk));
1545   ASSERT_EQ(tree.GetRoot(), nullptr);
1546 }
1547 
1548 TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsRight) {
1549   HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
1550   //    [20, 36] (45)
1551   //          \
1552   //         [21, 45] (45)
1553   BufferIntervalTree tree;
1554   tree.Add(20, 36, chunk);
1555   tree.Add(21, 45, chunk);
1556   EXPECT_TRUE(tree.Remove(21, 45, chunk));
1557   EXPECT_EQ(tree.GetRoot()->subtree_end, 36);
1558   EXPECT_TRUE(tree.Remove(20, 36, chunk));
1559   ASSERT_EQ(tree.GetRoot(), nullptr);
1560 }
1561 
1562 TEST_F(IntervalTreeTest, TwoLevelsRight_RootFirst) {
1563   HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
1564   //    [20, 36] (45)
1565   //          \
1566   //         [21, 45] (45)
1567   BufferIntervalTree tree;
1568   tree.Add(20, 36, chunk);
1569   tree.Add(21, 45, chunk);
1570   EXPECT_TRUE(tree.Remove(20, 36, chunk));
1571   EXPECT_EQ(tree.GetRoot()->subtree_end, 45);
1572   EXPECT_EQ(tree.GetRoot()->start, 21);
1573   EXPECT_EQ(tree.GetRoot()->end, 45);
1574   EXPECT_EQ(tree.GetRoot()->left, nullptr);
1575   EXPECT_EQ(tree.GetRoot()->right, nullptr);
1576   EXPECT_TRUE(tree.Remove(21, 45, chunk));
1577   ASSERT_EQ(tree.GetRoot(), nullptr);
1578 }
1579 
1580 TEST_F(IntervalTreeTest, TwoLevelsLeft_RootFirst) {
1581   HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
1582   //    [20, 36] (45)
1583   //      /
1584   //  [1, 45] (45)
1585   BufferIntervalTree tree;
1586   tree.Add(20, 36, chunk);
1587   tree.Add(1, 45, chunk);
1588   EXPECT_TRUE(tree.Remove(20, 36, chunk));
1589   EXPECT_EQ(tree.GetRoot()->subtree_end, 45);
1590   EXPECT_EQ(tree.GetRoot()->start, 1);
1591   EXPECT_EQ(tree.GetRoot()->end, 45);
1592   EXPECT_EQ(tree.GetRoot()->left, nullptr);
1593   EXPECT_EQ(tree.GetRoot()->right, nullptr);
1594   EXPECT_TRUE(tree.Remove(1, 45, chunk));
1595   ASSERT_EQ(tree.GetRoot(), nullptr);
1596 }
1597 
1598 TEST_F(IntervalTreeTest, ThreeLevelsRight) {
1599   HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
1600   //    [20, 36] (45)
1601   //          \
1602   //         [21, 45] (45)
1603   //              \
1604   //              [22, 40] (40)
1605   BufferIntervalTree tree;
1606   tree.Add(20, 36, chunk);
1607   tree.Add(21, 45, chunk);
1608   tree.Add(22, 40, chunk);
1609   EXPECT_TRUE(tree.Remove(21, 45, chunk));
1610   EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
1611   EXPECT_TRUE(tree.Remove(20, 36, chunk));
1612   EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
1613   EXPECT_TRUE(tree.Remove(22, 40, chunk));
1614   ASSERT_EQ(tree.GetRoot(), nullptr);
1615 }
1616 TEST_F(IntervalTreeTest, ThreeLevelsLeftLeft) {
1617   HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
1618   //    [20, 36] (45)
1619   //       /
1620   //  [10, 45] (45)
1621   //      /
1622   // [1, 40] (40)
1623   BufferIntervalTree tree;
1624   tree.Add(20, 36, chunk);
1625   tree.Add(10, 45, chunk);
1626   tree.Add(1, 40, chunk);
1627   EXPECT_TRUE(tree.Remove(10, 45, chunk));
1628   EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
1629   EXPECT_TRUE(tree.Remove(1, 40, chunk));
1630   EXPECT_EQ(tree.GetRoot()->subtree_end, 36);
1631   EXPECT_TRUE(tree.Remove(20, 36, chunk));
1632   ASSERT_EQ(tree.GetRoot(), nullptr);
1633 }
1634 
1635 TEST_F(IntervalTreeTest, ThreeLevelsLeftRight) {
1636   HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
1637   //    [20, 36] (45)
1638   //       /
1639   //  [10, 45] (45)
1640   //      \
1641   //     [15, 40] (40)
1642   BufferIntervalTree tree;
1643   tree.Add(20, 36, chunk);
1644   tree.Add(10, 45, chunk);
1645   tree.Add(15, 40, chunk);
1646   EXPECT_TRUE(tree.Remove(10, 45, chunk));
1647   EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
1648   EXPECT_TRUE(tree.Remove(15, 40, chunk));
1649   EXPECT_EQ(tree.GetRoot()->subtree_end, 36);
1650   EXPECT_TRUE(tree.Remove(20, 36, chunk));
1651   ASSERT_EQ(tree.GetRoot(), nullptr);
1652 }
1653 
1654 TEST_F(IntervalTreeTest, ThreeLevelsRightLeft) {
1655   HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
1656   //    [20, 36] (45)
1657   //          \
1658   //         [25, 45] (45)
1659   //           /
1660   //       [22, 40] (40)
1661   BufferIntervalTree tree;
1662   tree.Add(20, 36, chunk);
1663   tree.Add(25, 45, chunk);
1664   tree.Add(22, 40, chunk);
1665   EXPECT_TRUE(tree.Remove(25, 45, chunk));
1666   EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
1667   EXPECT_TRUE(tree.Remove(20, 36, chunk));
1668   EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
1669   EXPECT_TRUE(tree.Remove(22, 40, chunk));
1670   ASSERT_EQ(tree.GetRoot(), nullptr);
1671 }
1672 
1673 TEST_F(IntervalTreeTest, ThreeLevelsRightLeftChunkDifferent) {
1674   HeapSimulator::Chunk chunk1({1, 2});
1675   HeapSimulator::Chunk chunk2({2, 3});
1676   HeapSimulator::Chunk chunk3({3, 4});
1677   //    [20, 36] (45) Chunk1({1, 2})
1678   //          \
1679   //         [25, 45] (45) Chunk2({2, 3})
1680   //           /
1681   //       [22, 40] (40) Chunk3({3, 4})
1682   BufferIntervalTree tree;
1683   tree.Add(20, 36, chunk1);
1684   tree.Add(25, 45, chunk2);
1685   tree.Add(22, 40, chunk3);
1686   EXPECT_TRUE(tree.Remove(25, 45, chunk2));
1687   // Chunk 1 is still the root after removing chunk 2.
1688   EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
1689   EXPECT_EQ(tree.GetRoot()->chunk.offset, 1);
1690   EXPECT_EQ(tree.GetRoot()->chunk.size, 2);
1691   EXPECT_TRUE(tree.Remove(20, 36, chunk1));
1692   // Chunk 3 becomes the root now.
1693   EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
1694   EXPECT_EQ(tree.GetRoot()->chunk.offset, 3);
1695   EXPECT_EQ(tree.GetRoot()->chunk.size, 4);
1696   EXPECT_TRUE(tree.Remove(22, 40, chunk3));
1697   ASSERT_EQ(tree.GetRoot(), nullptr);
1698 }
1699 
1700 }  // namespace
1701 }  // namespace xla
1702