1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/heap_simulator.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/buffer_value.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
30 #include "tensorflow/compiler/xla/service/hlo_value.h"
31 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/test.h"
36 
37 namespace xla {
38 namespace {
39 
40 class MinimumMemoryForSequenceTest : public HloTestBase {};
41 
42 TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
43   auto module = CreateNewVerifiedModule();
44   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
45   const Shape tuple_shape =
46       ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
47 
48   auto cond_builder = HloComputation::Builder("WhileCond");
49   // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
50   HloInstruction* cond_param = cond_builder.AddInstruction(
51       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
52   HloInstruction* cond_iter = cond_builder.AddInstruction(
53       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
54   HloInstruction* cond_data = cond_builder.AddInstruction(
55       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
56   // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
57   HloInstruction* cond_lt = cond_builder.AddInstruction(
58       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
59                                     cond_data, ComparisonDirection::kLt));
60   HloComputation* cond_computation =
61       module->AddEmbeddedComputation(cond_builder.Build());
62 
63   auto body_builder = HloComputation::Builder("WhileBody");
64   // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
65   HloInstruction* body_param = body_builder.AddInstruction(
66       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
67   HloComputation* body_computation =
68       module->AddEmbeddedComputation(body_builder.Build());
69 
70   auto builder = HloComputation::Builder(TestName());
71   // Entry params: 8 bytes (4 bytes per param), TOTAL=8
72   HloInstruction* iter = builder.AddInstruction(
73       HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
74   HloInstruction* data = builder.AddInstruction(
75       HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
76   // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
77   HloInstruction* tuple =
78       builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
79   // While: 8 bytes (4 bytes per element), TOTAL=32
80   // Both cond and body use a max of 24 bytes, TOTAL=56
81   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
82       tuple_shape, cond_computation, body_computation, tuple));
83   HloComputation* entry_computation =
84       module->AddEntryComputation(builder.Build());
85 
86   auto size_fn = [](const BufferValue& buffer) {
87     return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
88   };
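  // With pointer_size=8, the (f32, f32) tuple costs 16 bytes for its two
  // pointers plus 4 bytes per element, matching the 24-byte figure in the
  // comments above.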
89 
90   HloSchedule schedule(module.get());
91   schedule.set_sequence(cond_computation,
92                         {cond_param, cond_iter, cond_data, cond_lt});
93   schedule.set_sequence(body_computation, {body_param});
94   schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
95   TF_ASSERT_OK(schedule.Verify());
96 
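  // The expected 25 bytes presumably corresponds to the peak inside the while
  // loop: the 16-byte tuple buffer plus its two 4-byte elements (aliased with
  // the entry tuple and the while result) and the 1-byte PRED produced by the
  // condition.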
97   EXPECT_EQ(
98       25,
99       HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
100 }
101 
102 TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) {
103   // HloModule SubcomputationAccounting
104 
105   // %WhileBody (body_param: f32[4]) -> f32[4] {
106   //   %body_param = f32[4]{0} parameter(0)
107   //   %constant.1 = f32[4]{0} constant({1, 1, 1, 1})
108   //   ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param,
109   //       f32[4]{0} %constant.1)
110   // }
111 
112   // %WhileCond (cond_param: f32[4]) -> pred[] {
113   //   %cond_param = f32[4]{0} parameter(0)
114   //   %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]}
115   //   %reshape = f32[] reshape(f32[1]{0} %slice)
116   //   %constant = f32[] constant(0)
117   //   ROOT %not-equal-to = pred[] compare(f32[] %reshape, f32[] %constant),
118   //   direction=NE
119   // }
120 
121   // ENTRY %SubcomputationAccounting () -> f32[2,4] {
122   //   %constant.3 = f32[2,4]{1,0}
123   //       constant(f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
124   //   %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0} %constant.3), dimensions={0,1}
125   //   %constant.2 = f32[4]{0} constant({1, 1, 1, 1})
126   //   %while = f32[4]{0} while(f32[4]{0} %constant.2), condition=%WhileCond, body=%WhileBody
127   //   %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={1}
128   //   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast)
129   // }
130 
131   auto module = CreateNewVerifiedModule();
132   const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
133   const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
134   const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
135 
136   // reshape(slice(param)) != 0
137   // Needs 5 bytes
138   auto cond_builder = HloComputation::Builder("WhileCond");
139   HloInstruction* cond_param = cond_builder.AddInstruction(
140       HloInstruction::CreateParameter(0, r1f32, "cond_param"));
141   HloInstruction* slice =
142       cond_builder.AddInstruction(HloInstruction::CreateSlice(
143           ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1}));
144   HloInstruction* reshape =
145       cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice));
146   HloInstruction* zero = cond_builder.AddInstruction(
147       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
148   HloInstruction* cond_comparison = cond_builder.AddInstruction(
149       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), reshape,
150                                     zero, ComparisonDirection::kNe));
151   auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
152 
153   // param - 1
154   // Needs 16 bytes
155   auto body_builder = HloComputation::Builder("WhileBody");
156   HloInstruction* body_param = body_builder.AddInstruction(
157       HloInstruction::CreateParameter(0, r1f32, "body_param"));
158   HloInstruction* one_vector =
159       body_builder.AddInstruction(HloInstruction::CreateConstant(
160           LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
161   HloInstruction* subtract =
162       body_builder.AddInstruction(HloInstruction::CreateBinary(
163           r1f32, HloOpcode::kSubtract, body_param, one_vector));
164   auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
165 
166   // transpose(matrix) + bcast(while)
167   auto builder = HloComputation::Builder(TestName());
168   HloInstruction* while_init =
169       builder.AddInstruction(HloInstruction::CreateConstant(
170           LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
171   // Creates 16 bytes, ignoring subcomputations
172   HloInstruction* while_loop =
173       builder.AddInstruction(HloInstruction::CreateWhile(
174           r1f32, cond_computation, body_computation, while_init));
175 
176   // Creates 32 bytes and frees 16
177   HloInstruction* bcast = builder.AddInstruction(
178       HloInstruction::CreateBroadcast(r2f32, while_loop, {1}));
179 
180   HloInstruction* matrix = builder.AddInstruction(
181       HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
182           {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
183   // Creates 32 bytes
184   HloInstruction* transpose = builder.AddInstruction(
185       HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));
186 
187   // Creates 32 bytes and frees 64
188   HloInstruction* add = builder.AddInstruction(
189       HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));
190 
191   auto entry_computation = module->AddEntryComputation(builder.Build());
192 
193   HloSchedule schedule(module.get());
194   std::vector<HloInstruction*> cond_vec = {cond_param, slice, reshape, zero,
195                                            cond_comparison};
196   std::vector<HloInstruction*> while_body_vec = {body_param, one_vector,
197                                                  subtract};
198   std::vector<HloInstruction*> entry_comp_vec = {while_init, while_loop, bcast,
199                                                  matrix,     transpose,  add};
200   schedule.set_sequence(cond_computation, cond_vec);
201   schedule.set_sequence(body_computation, while_body_vec);
202   schedule.set_sequence(entry_computation, entry_comp_vec);
203 
204   auto size_fn = [](const BufferValue& buffer) {
205     return ShapeUtil::ByteSizeOf(buffer.shape());
206   };
207   absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
208   memory_by_computation[cond_computation] = 5;
209   memory_by_computation[body_computation] = 16;
210 
211   std::unique_ptr<HloAliasAnalysis> alias_analysis =
212       HloAliasAnalysis::Run(module.get()).ValueOrDie();
213 
214   // HeapSimulator accounts for subcomputations. The output buffer is aliased,
215   // so we don't double count.
216   EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
217                     *entry_computation, schedule.sequence(entry_computation),
218                     *alias_analysis, size_fn, &memory_by_computation)
219                     .ValueOrDie());
220 }
221 
222 const char kAlloc[] = "Alloc";
223 const char kFree[] = "Free";
224 const char kShare[] = "Share";
225 const char kFinish[] = "Finish";
226 
227 // CallSequence records a sequence of Alloc/Free/Finish calls.
228 using CallSequence = std::vector<std::pair<string, const HloValue*>>;
229 
230 // HeapCallRecorder is a dummy heap algorithm that simply records its calls.
231 class HeapCallRecorder : public HeapAlgorithm<HloValue> {
232  public:
233   explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {}
234   ~HeapCallRecorder() override {}
235 
236   void Alloc(const HloValue* buffer, int64 size) override {
237     calls_->emplace_back(kAlloc, buffer);
238     // Instead of assigning a real offset, we set the cardinality of the Alloc
239     // call.  This isn't a valid assignment, but allows us to easily test for
240     // buffer sharing.
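    // (For example, the first Alloc'd buffer gets offset 0, the next offset 1,
    // and so on; ShareWith below reuses the offset of the buffer it shares
    // with.)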
241     const int64 offset = result_.chunk_map.size();
242     result_.chunk_map.emplace(buffer, Chunk{offset, size});
243   }
244 
245   void ShareWith(const HloValue* buffer, const HloValue* shared,
246                  int64 size) override {
247     calls_->emplace_back(kShare, buffer);
248     // Instead of assigning a real offset, we reuse the offset already
249     // recorded for the shared buffer.  This isn't a valid assignment, but
250     // allows us to easily test for buffer sharing.
251     const int64 offset = result_.chunk_map[shared].offset;
252     result_.chunk_map.emplace(buffer, Chunk{offset, size});
253   }
254   void Free(const HloValue* buffer, int64 size) override {
255     calls_->emplace_back(kFree, buffer);
256   }
257   Result Finish() override {
258     calls_->emplace_back(kFinish, nullptr);
259     HeapSimulator::Result<HloValue> result;
260     result.heap_size = result_.heap_size;
261     result.heap_results.emplace_back(std::move(result_));
262     return result;
263   }
264 
265  private:
266   CallSequence* calls_;
267   HeapSimulator::HeapResult<HloValue> result_;
268 };
269 
270 // HeapSimulatorTracker runs the heap simulator, recording the sequence of calls
271 // made to the underlying heap algorithm.  Tests compare the actual call
272 // sequence against an expected sequence.
273 class HeapSimulatorTracker {
274  public:
275   explicit HeapSimulatorTracker(
276       std::unique_ptr<HloModule> module,
277       const std::vector<HloInstruction*>& instruction_sequence,
278       const std::vector<HloInstruction*>& must_alias_set = {},
279       const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
280     module_ = std::move(module);
281     Init(instruction_sequence, can_share_buffer);
282   }
283 
284   // Constructor for testing a single entry computation.
285   explicit HeapSimulatorTracker(
286       const string& name, std::unique_ptr<HloComputation> entry_computation,
287       const std::vector<HloInstruction*>& instruction_sequence,
288       const std::vector<HloInstruction*>& must_alias_set = {},
289       const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
290     HloModuleConfig config;
291     module_ = absl::make_unique<HloModule>(name, config);
292     module_->AddEntryComputation(std::move(entry_computation));
293     Init(instruction_sequence, can_share_buffer);
294   }
295 
296   explicit HeapSimulatorTracker(const string& name) {
297     HloModuleConfig config;
298     module_ = absl::make_unique<HloModule>(name, config);
299   }
300 
301   // Similar to the single entry computation constructor above, but runs the
302   // simulation over the entire module.
303   void RunWholeModule(
304       const std::vector<HloInstruction*>& full_module_sequence) {
305     alias_analysis_ = HloAliasAnalysis::Run(module_.get()).ConsumeValueOrDie();
306 
307     // Construct the module sequence grouped by computation.
308     HloSchedule schedule(module_.get());
309     absl::flat_hash_map<const HloInstruction*, int> reverse_position;
310     for (int i = 0; i < full_module_sequence.size(); ++i) {
311       HloInstruction* instruction = full_module_sequence[i];
312       schedule.GetOrCreateSequence(instruction->parent())
313           .push_back(instruction);
314       reverse_position[instruction] = full_module_sequence.size() - i;
315     }
316 
317     // Hack the size_fn so that it returns a decreasing value as we step through
318     // the sequence. This lets us ensure the Alloc calls are in the sequence
319     // order. The Free calls are sorted by BufferValue.id, which is at least
320     // deterministic.
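    // (With a three-instruction sequence, for instance, the sizes come out as
    // 3, 2, 1 in schedule order.)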
321     auto size_fn = [&reverse_position](const BufferValue& buffer) {
322       return reverse_position[buffer.instruction()];
323     };
324     auto algorithm = absl::make_unique<HeapCallRecorder>(&actual_calls_);
325     result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
326                                  *alias_analysis_, size_fn)
327                   .ConsumeValueOrDie();
328   }
329 
330   HloModule* module() { return module_.get(); }
331 
332   // Returns the buffer defined at the given instruction and index.
333   const HloValue* BufferAt(const HloInstruction* instruction,
334                            const ShapeIndex& index) const {
335     return &alias_analysis_->dataflow_analysis().GetUniqueValueAt(instruction,
336                                                                   index);
337   }
338 
339   int64 OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) {
340     const HloValue* buffer = BufferAt(instruction, index);
341     CHECK_EQ(1, result_.heap_results.size());
342     return result_.heap_results.at(0).chunk_map.at(buffer).offset;
343   }
344 
345   // Ensures the expected sequence of Alloc/Free/Finish calls was performed.
346   void ExpectCallSequence(const CallSequence& expected) const {
347     auto to_string = [](const CallSequence& sequence) {
348       std::string output;
349       for (int64 i = 0; i < sequence.size(); ++i) {
350         auto pair = sequence.at(i);
351         absl::StrAppendFormat(&output, "%d", i);
352         absl::StrAppendFormat(&output, " :%s", pair.first);
353         if (pair.second != nullptr) {
354           absl::StrAppendFormat(&output, " - %s{%s}\n",
355                                 pair.second->instruction()->name(),
356                                 pair.second->index().ToString());
357         }
358       }
359       return output;
360     };
361     EXPECT_EQ(expected, actual_calls_) << "Expected:\n"
362                                        << to_string(expected) << " \nActual:\n"
363                                        << to_string(actual_calls_) << "\n";
364   }
365 
366   // Ensures the buffers defined by the respective (instruction,index) pairs are
367   // shared, relying on the unique offsets assigned in
368   // HeapCallRecorder::Alloc.
369   void ExpectSharedBuffers(const HloInstruction* instruction_a,
370                            const ShapeIndex& index_a,
371                            const HloInstruction* instruction_b,
372                            const ShapeIndex& index_b) {
373     int64 offset_a = OffsetAt(instruction_a, index_a);
374     int64 offset_b = OffsetAt(instruction_b, index_b);
375     EXPECT_EQ(offset_a, offset_b);
376   }
377 
378  private:
379   void Init(const std::vector<HloInstruction*>& instruction_sequence,
380             const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) {
381     // Since we're only tracking the sequence of Alloc/Free calls, the actual
382     // size of the buffers doesn't matter, so we always return 0.  We rely on
383     // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls
384     // by buffer id, for determinism in the tests.
385     auto zero_size = [](const BufferValue& buffer) { return 0; };
386     auto algorithm = absl::make_unique<HeapCallRecorder>(&actual_calls_);
387 
388     alias_analysis_ =
389         HloAliasAnalysis::Run(module_.get(), can_share_buffer).ValueOrDie();
390 
391     HeapSimulator::Options options;
392 
393     result_ =
394         HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
395                            HloInstructionSequence(instruction_sequence),
396                            *alias_analysis_, zero_size, options)
397             .ConsumeValueOrDie();
398   }
399 
400   std::unique_ptr<HloModule> module_;
401   std::unique_ptr<HloAliasAnalysis> alias_analysis_;
402   CallSequence actual_calls_;
403   HeapSimulator::Result<HloValue> result_;
404 };
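// A minimal usage sketch, mirroring the OneParam test below: build a
// computation, hand it to the tracker together with a schedule, and compare
// the recorded calls.
//
//   HeapSimulatorTracker tracker(TestName(), builder.Build(), {param0});
//   tracker.ExpectCallSequence({{kAlloc, tracker.BufferAt(param0, {})},
//                               {kFree, tracker.BufferAt(param0, {})},
//                               {kFinish, nullptr}});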
405 
406 class HeapSimulatorTest : public HloTestBase {
407  protected:
408   HeapSimulatorTest() {}
409   ~HeapSimulatorTest() override {}
410 
411   // Shapes for use in the examples.
412   Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {});
413   Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
414 };
415 
416 TEST_F(HeapSimulatorTest, ScalarConstant) {
417   auto builder = HloComputation::Builder(TestName());
418   auto const0 = builder.AddInstruction(
419       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
420 
421   // Constants aren't assigned.  See b/32248867
422   HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0});
423   tracker.ExpectCallSequence({{kFinish, nullptr}});
424 }
425 
426 TEST_F(HeapSimulatorTest, OneParam) {
427   auto builder = HloComputation::Builder(TestName());
428   auto param0 = builder.AddInstruction(
429       HloInstruction::CreateParameter(0, f32scalar_, "param0"));
430 
431   // A single parameter which is also the output.
432   HeapSimulatorTracker tracker(TestName(), builder.Build(), {param0});
433   tracker.ExpectCallSequence({
434       {kAlloc, tracker.BufferAt(param0, {})},
435       {kFree, tracker.BufferAt(param0, {})},
436       {kFinish, nullptr},
437   });
438 }
439 
440 TEST_F(HeapSimulatorTest, Multiply) {
441   auto builder = HloComputation::Builder(TestName());
442   auto paramA = builder.AddInstruction(
443       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
444   auto paramX = builder.AddInstruction(
445       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
446   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
447       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
448 
449   // We must keep all parameters and outputs.
450   HeapSimulatorTracker tracker(TestName(), builder.Build(),
451                                {paramA, paramX, mul});
452   tracker.ExpectCallSequence({
453       {kAlloc, tracker.BufferAt(paramA, {})},
454       {kAlloc, tracker.BufferAt(paramX, {})},
455       {kAlloc, tracker.BufferAt(mul, {})},
456       // All params and outputs are freed at the end.
457       {kFree, tracker.BufferAt(paramA, {})},
458       {kFree, tracker.BufferAt(paramX, {})},
459       {kFree, tracker.BufferAt(mul, {})},
460       {kFinish, nullptr},
461   });
462 }
463 
464 TEST_F(HeapSimulatorTest, MultiplyAdd) {
465   auto builder = HloComputation::Builder(TestName());
466   auto paramA = builder.AddInstruction(
467       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
468   auto paramX = builder.AddInstruction(
469       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
470   auto paramY = builder.AddInstruction(
471       HloInstruction::CreateParameter(2, f32vec4_, "paramY"));
472   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
473       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
474   auto add = builder.AddInstruction(
475       HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));
476 
477   // The buffer for add is the output, and it's shared with the buffer for
478   // mul.
479   HeapSimulatorTracker tracker(TestName(), builder.Build(),
480                                {paramA, paramX, mul, paramY, add});
481   tracker.ExpectCallSequence({
482       {kAlloc, tracker.BufferAt(paramA, {})},
483       {kAlloc, tracker.BufferAt(paramX, {})},
484       {kAlloc, tracker.BufferAt(paramY, {})},
485       {kAlloc, tracker.BufferAt(mul, {})},
486       {kFree, tracker.BufferAt(mul, {})},
487       {kShare, tracker.BufferAt(add, {})},
488       // All params and outputs are freed at the end.
489       {kFree, tracker.BufferAt(paramA, {})},
490       {kFree, tracker.BufferAt(paramX, {})},
491       {kFree, tracker.BufferAt(paramY, {})},
492       {kFree, tracker.BufferAt(add, {})},
493       {kFinish, nullptr},
494   });
495   tracker.ExpectSharedBuffers(add, {}, mul, {});
496 }
497 
498 TEST_F(HeapSimulatorTest, FusionOutputsOnlyShareOnce) {
499   // Test that only one output of a fusion node will be shared with its operand.
500   auto can_share_buffer =
501       [](const HloInstruction* instr, const HloInstruction* operand,
502          const ShapeIndex& user_index) -> absl::optional<bool> {
503     if (instr->opcode() == HloOpcode::kFusion) {
504       return true;
505     }
506     return false;
507   };
508 
509   HloModuleConfig config;
510   auto module = absl::make_unique<HloModule>(TestName(), config);
511 
512   auto builder = HloComputation::Builder(TestName());
513   auto paramA = builder.AddInstruction(
514       HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
515   auto negate = builder.AddInstruction(
516       HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, paramA));
517 
518   // The fusion node has two outputs, both are eligible for being reused with
519   // operand.
520   auto fusion_builder = HloComputation::Builder("simple_two_way_forwarding");
521   {
522     auto param = fusion_builder.AddInstruction(
523         HloInstruction::CreateParameter(0, f32vec4_, "x"));
524     fusion_builder.AddInstruction(HloInstruction::CreateTuple({param, param}));
525   }
526   auto fusion_computation =
527       module->AddEmbeddedComputation(fusion_builder.Build());
528 
529   auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
530       ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
531       HloInstruction::FusionKind::kLoop, {negate}, fusion_computation));
532 
533   auto element0 = builder.AddInstruction(
534       HloInstruction::CreateGetTupleElement(f32scalar_, fusion, 0));
535 
536   auto element1 = builder.AddInstruction(
537       HloInstruction::CreateGetTupleElement(f32scalar_, fusion, 1));
538 
539   auto negate0 = builder.AddInstruction(
540       HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element0));
541   auto negate1 = builder.AddInstruction(
542       HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element1));
543 
544   builder.AddInstruction(HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd,
545                                                       negate0, negate1));
546 
547   module->AddEntryComputation(builder.Build());
548   HeapSimulatorTracker tracker(
549       std::move(module),
550       {paramA, negate, fusion, element0, element1, negate0, negate1}, {},
551       can_share_buffer);
552   tracker.ExpectCallSequence({
553       {kAlloc, tracker.BufferAt(paramA, {})},
554       {kAlloc, tracker.BufferAt(negate, {})},
555       {kAlloc, tracker.BufferAt(fusion, {})},
556       {kFree, tracker.BufferAt(negate, {})},
557       {kShare, tracker.BufferAt(fusion, {0})},
558       {kAlloc, tracker.BufferAt(fusion, {1})},
559       {kFree, tracker.BufferAt(fusion, {})},
560       {kAlloc, tracker.BufferAt(negate0, {})},
561       {kFree, tracker.BufferAt(fusion, {0})},
562       {kFree, tracker.BufferAt(negate0, {})},
563       {kAlloc, tracker.BufferAt(negate1, {})},
564       {kFree, tracker.BufferAt(fusion, {1})},
565       {kFree, tracker.BufferAt(negate1, {})},
566       {kFree, tracker.BufferAt(paramA, {})},
567       {kFinish, nullptr},
568   });
569 }
570 
571 TEST_F(HeapSimulatorTest, FusionOutputsOnlyShareOnceOutputShortLived) {
572   // Test that only one output of a fusion node will be shared with its operand.
573   // This variant of the test has a fusion node that dies immediately.
574   auto can_share_buffer =
575       [](const HloInstruction* instr, const HloInstruction* operand,
576          const ShapeIndex& user_index) -> absl::optional<bool> {
577     if (instr->opcode() == HloOpcode::kFusion) {
578       return true;
579     }
580     return false;
581   };
582 
583   HloModuleConfig config;
584   auto module = absl::make_unique<HloModule>(TestName(), config);
585 
586   auto builder = HloComputation::Builder(TestName());
587   auto paramA = builder.AddInstruction(
588       HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
589   auto negate = builder.AddInstruction(
590       HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, paramA));
591 
592   // The fusion node has two outputs, both are eligible for being reused with
593   // operand.
594   auto fusion_builder = HloComputation::Builder("simple_two_way_forwarding");
595   {
596     auto param = fusion_builder.AddInstruction(
597         HloInstruction::CreateParameter(0, f32vec4_, "x"));
598     fusion_builder.AddInstruction(HloInstruction::CreateTuple({param, param}));
599   }
600   auto fusion_computation =
601       module->AddEmbeddedComputation(fusion_builder.Build());
602 
603   auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
604       ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
605       HloInstruction::FusionKind::kLoop, {negate}, fusion_computation));
606 
607   auto element1 = builder.AddInstruction(
608       HloInstruction::CreateGetTupleElement(f32scalar_, fusion, 1));
609 
610   auto negate1 = builder.AddInstruction(
611       HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, element1));
612 
613   module->AddEntryComputation(builder.Build());
614   HeapSimulatorTracker tracker(std::move(module),
615                                {paramA, negate, fusion, element1, negate1}, {},
616                                can_share_buffer);
617   tracker.ExpectCallSequence({
618       {kAlloc, tracker.BufferAt(paramA, {})},
619       {kAlloc, tracker.BufferAt(negate, {})},
620       {kFree, tracker.BufferAt(negate, {})},
621       {kShare, tracker.BufferAt(fusion, {0})},
622       {kAlloc, tracker.BufferAt(fusion, {})},
623       {kAlloc, tracker.BufferAt(fusion, {1})},
624       {kFree, tracker.BufferAt(fusion, {0})},
625       {kFree, tracker.BufferAt(fusion, {})},
626       {kAlloc, tracker.BufferAt(negate1, {})},
627       {kFree, tracker.BufferAt(fusion, {1})},
628       {kFree, tracker.BufferAt(paramA, {})},
629       {kFree, tracker.BufferAt(negate1, {})},
630       {kFinish, nullptr},
631   });
632 }
633 
634 TEST_F(HeapSimulatorTest, BufferReusedOnce) {
635   HeapSimulatorTracker tracker(TestName());
636   auto builder = HloComputation::Builder(TestName());
637 
638   HloComputation::Builder fusion_builder("fusion");
639   {
640     HloComputation::Builder& builder = fusion_builder;
641     auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
642         /*parameter_number=*/0, f32vec4_, "A"));
643     auto exp = builder.AddInstruction(
644         HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, a_param));
645     auto neg = builder.AddInstruction(
646         HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
647 
648     builder.AddInstruction(HloInstruction::CreateTuple({exp, neg}));
649   }
650   auto fusion_computation =
651       tracker.module()->AddEmbeddedComputation(fusion_builder.Build());
652   auto a_param = builder.AddInstruction(
653       HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
654   auto neg = builder.AddInstruction(
655       HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
656   auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
657       ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
658       HloInstruction::FusionKind::kLoop, {neg}, fusion_computation));
659   tracker.module()->AddEntryComputation(builder.Build());
660 
661   tracker.RunWholeModule({a_param, neg, fusion});
662 
663   auto neg_buffer = tracker.OffsetAt(neg, {});
664   int64 output_buffer_0 = tracker.OffsetAt(fusion, {0});
665   int64 output_buffer_1 = tracker.OffsetAt(fusion, {1});
666   // Only one buffer should be shared.
667   EXPECT_TRUE((neg_buffer == output_buffer_0) ^
668               (neg_buffer == output_buffer_1));
669 }
670 
671 TEST_F(HeapSimulatorTest, MultiplyDot) {
672   auto builder = HloComputation::Builder(TestName());
673   auto paramA = builder.AddInstruction(
674       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
675   auto paramX = builder.AddInstruction(
676       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
677   auto paramY = builder.AddInstruction(
678       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
679   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
680       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
681   DotDimensionNumbers dot_dnums;
682   dot_dnums.add_lhs_contracting_dimensions(1);
683   dot_dnums.add_rhs_contracting_dimensions(0);
684   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
685       f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
686 
687   // The buffer for dot is the output, and it cannot be shared with the buffer
688   // for mul, since dot isn't elementwise.
689   HeapSimulatorTracker tracker(TestName(), builder.Build(),
690                                {paramA, paramX, mul, paramY, dot});
691   tracker.ExpectCallSequence({
692       {kAlloc, tracker.BufferAt(paramA, {})},
693       {kAlloc, tracker.BufferAt(paramX, {})},
694       {kAlloc, tracker.BufferAt(paramY, {})},
695       {kAlloc, tracker.BufferAt(mul, {})},
696       {kAlloc, tracker.BufferAt(dot, {})},
697       // All params and outputs are freed at the end.
698       {kFree, tracker.BufferAt(mul, {})},
699       {kFree, tracker.BufferAt(paramA, {})},
700       {kFree, tracker.BufferAt(paramX, {})},
701       {kFree, tracker.BufferAt(paramY, {})},
702       {kFree, tracker.BufferAt(dot, {})},
703       {kFinish, nullptr},
704   });
705 }
706 
707 TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
708   auto builder = HloComputation::Builder(TestName());
709   auto paramA = builder.AddInstruction(
710       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
711   auto paramX = builder.AddInstruction(
712       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
713   auto paramY = builder.AddInstruction(
714       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
715   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
716       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
717   DotDimensionNumbers dot_dnums;
718   dot_dnums.add_lhs_contracting_dimensions(1);
719   dot_dnums.add_rhs_contracting_dimensions(0);
720   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
721       f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
722   auto add = builder.AddInstruction(
723       HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));
724 
725   // The buffer for add is the output, and it's shared with the buffer for
726   // dot.
727   HeapSimulatorTracker tracker(TestName(), builder.Build(),
728                                {paramA, paramX, mul, paramY, dot, add});
729   tracker.ExpectCallSequence({
730       {kAlloc, tracker.BufferAt(paramA, {})},
731       {kAlloc, tracker.BufferAt(paramX, {})},
732       {kAlloc, tracker.BufferAt(paramY, {})},
733       {kAlloc, tracker.BufferAt(mul, {})},
734       {kAlloc, tracker.BufferAt(dot, {})},
735       {kFree, tracker.BufferAt(mul, {})},
736       {kFree, tracker.BufferAt(dot, {})},
737       {kShare, tracker.BufferAt(add, {})},
738       // All params and outputs are freed at the end.
739       {kFree, tracker.BufferAt(paramA, {})},
740       {kFree, tracker.BufferAt(paramX, {})},
741       {kFree, tracker.BufferAt(paramY, {})},
742       {kFree, tracker.BufferAt(add, {})},
743       {kFinish, nullptr},
744   });
745   tracker.ExpectSharedBuffers(add, {}, dot, {});
746 }
747 
748 TEST_F(HeapSimulatorTest, MultiplyDotDot) {
749   auto builder = HloComputation::Builder(TestName());
750   auto paramA = builder.AddInstruction(
751       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
752   auto paramX = builder.AddInstruction(
753       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
754   auto paramY = builder.AddInstruction(
755       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
756   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
757       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
758   DotDimensionNumbers dot_dnums;
759   dot_dnums.add_lhs_contracting_dimensions(1);
760   dot_dnums.add_rhs_contracting_dimensions(0);
761   auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
762       f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
763   auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
764       f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
765 
766   // The buffer for dot1 is the output.  No buffers can be shared.  The buffer
767   // for mul is freed before the end, since it's no longer used after dot0
768   // finishes.
769   HeapSimulatorTracker tracker(TestName(), builder.Build(),
770                                {paramA, paramX, mul, paramY, dot0, dot1});
771   tracker.ExpectCallSequence({
772       {kAlloc, tracker.BufferAt(paramA, {})},
773       {kAlloc, tracker.BufferAt(paramX, {})},
774       {kAlloc, tracker.BufferAt(paramY, {})},
775       {kAlloc, tracker.BufferAt(mul, {})},
776       {kAlloc, tracker.BufferAt(dot0, {})},
777       {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
778       {kAlloc, tracker.BufferAt(dot1, {})},
779       {kFree, tracker.BufferAt(dot0, {})},
780       // All params and outputs are freed at the end.
781       {kFree, tracker.BufferAt(paramA, {})},
782       {kFree, tracker.BufferAt(paramX, {})},
783       {kFree, tracker.BufferAt(paramY, {})},
784       {kFree, tracker.BufferAt(dot1, {})},
785       {kFinish, nullptr},
786   });
787 }
788 
789 TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
790   auto builder = HloComputation::Builder(TestName());
791   auto paramA = builder.AddInstruction(
792       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
793   auto paramX = builder.AddInstruction(
794       HloInstruction::CreateParameter(1, f32vec4_, "paramX"));
795   auto paramY = builder.AddInstruction(
796       HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
797   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
798       f32vec4_, HloOpcode::kMultiply, paramA, paramX));
799   DotDimensionNumbers dot_dnums;
800   dot_dnums.add_lhs_contracting_dimensions(1);
801   dot_dnums.add_rhs_contracting_dimensions(0);
802   auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
803       f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
804   auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
805       f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
806   auto tuple =
807       builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));
808 
809   // The buffers for dot0, dot1 and tuple are the output.  No buffers can be
810   // shared.  The buffer for mul is freed before the end, since it's no longer
811   // used after dot0 finishes.
812   HeapSimulatorTracker tracker(
813       TestName(), builder.Build(),
814       {paramA, paramX, mul, paramY, dot0, dot1, tuple});
815   tracker.ExpectCallSequence({
816       {kAlloc, tracker.BufferAt(paramA, {})},
817       {kAlloc, tracker.BufferAt(paramX, {})},
818       {kAlloc, tracker.BufferAt(paramY, {})},
819       {kAlloc, tracker.BufferAt(mul, {})},
820       {kAlloc, tracker.BufferAt(dot0, {})},
821       {kFree, tracker.BufferAt(mul, {})},  // mul no longer used
822       {kAlloc, tracker.BufferAt(dot1, {})},
823       {kAlloc, tracker.BufferAt(tuple, {})},
824       // All params and outputs are freed at the end.
825       {kFree, tracker.BufferAt(paramA, {})},
826       {kFree, tracker.BufferAt(paramX, {})},
827       {kFree, tracker.BufferAt(paramY, {})},
828       {kFree, tracker.BufferAt(dot0, {})},
829       {kFree, tracker.BufferAt(dot1, {})},
830       {kFree, tracker.BufferAt(tuple, {})},
831       {kFinish, nullptr},
832   });
833 }
834 
835 TEST_F(HeapSimulatorTest, IndependentTupleElements) {
836   auto builder = HloComputation::Builder(TestName());
837   auto paramA = builder.AddInstruction(
838       HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
839   auto paramB = builder.AddInstruction(
840       HloInstruction::CreateParameter(1, f32scalar_, "paramB"));
841   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
842       f32scalar_, HloOpcode::kMultiply, paramA, paramB));
843   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
844       f32scalar_, HloOpcode::kAdd, paramA, paramB));
845   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add}));
846   auto element0 = builder.AddInstruction(
847       HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0));
848   auto broadcast = builder.AddInstruction(
849       HloInstruction::CreateBroadcast(f32vec4_, element0, {0}));
850   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
851       f32scalar_, HloOpcode::kSubtract, paramA, paramB));
852   auto element1 = builder.AddInstruction(
853       HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1));
854   auto output = builder.AddInstruction(
855       HloInstruction::CreateTuple({broadcast, sub, element1}));
856 
857   HeapSimulatorTracker tracker(TestName(), builder.Build(),
858                                {paramA, paramB, mul, add, tuple, element0,
859                                 broadcast, sub, element1, output});
860   tracker.ExpectCallSequence({
861       {kAlloc, tracker.BufferAt(paramA, {})},
862       {kAlloc, tracker.BufferAt(paramB, {})},
863       {kAlloc, tracker.BufferAt(mul, {})},
864       {kAlloc, tracker.BufferAt(add, {})},
865       {kAlloc, tracker.BufferAt(tuple, {})},
866       {kAlloc, tracker.BufferAt(broadcast, {})},
867       // The mul can be freed right after the broadcast happens, even though
868       // the other GetTupleElement is still alive.
869       {kFree, tracker.BufferAt(mul, {})},
870       {kAlloc, tracker.BufferAt(sub, {})},
871       // The temporary tuple is now dead.
872       {kFree, tracker.BufferAt(tuple, {})},
873       {kAlloc, tracker.BufferAt(output, {})},
874       // All params and outputs are freed at the end.
875       {kFree, tracker.BufferAt(paramA, {})},
876       {kFree, tracker.BufferAt(paramB, {})},
877       {kFree, tracker.BufferAt(add, {})},
878       {kFree, tracker.BufferAt(broadcast, {})},
879       {kFree, tracker.BufferAt(sub, {})},
880       {kFree, tracker.BufferAt(output, {})},
881       {kFinish, nullptr},
882   });
883 }
884 
885 TEST_F(HeapSimulatorTest, WholeModule) {
886   HeapSimulatorTracker tracker(TestName());
887 
888   const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
889   const Shape tuple_shape =
890       ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
891 
892   auto cond_builder = HloComputation::Builder("WhileCond");
893   HloInstruction* cond_param = cond_builder.AddInstruction(
894       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
895   HloInstruction* cond_iter = cond_builder.AddInstruction(
896       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
897   HloInstruction* cond_data = cond_builder.AddInstruction(
898       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
899   HloInstruction* cond_lt = cond_builder.AddInstruction(
900       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
901                                     cond_data, ComparisonDirection::kLt));
902   HloComputation* cond_computation =
903       tracker.module()->AddEmbeddedComputation(cond_builder.Build());
904 
905   auto body_builder = HloComputation::Builder("WhileBody");
906   HloInstruction* body_param = body_builder.AddInstruction(
907       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
908   HloComputation* body_computation =
909       tracker.module()->AddEmbeddedComputation(body_builder.Build());
910 
911   auto builder = HloComputation::Builder(TestName());
912   HloInstruction* param = builder.AddInstruction(
913       HloInstruction::CreateParameter(0, tuple_shape, "param"));
914   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
915       tuple_shape, cond_computation, body_computation, param));
916   tracker.module()->AddEntryComputation(builder.Build());
917 
918   tracker.RunWholeModule(
919       {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt});
920   tracker.ExpectCallSequence({
921       // The entry computation param and while_op are allocated first.
922       {kAlloc, tracker.BufferAt(param, {})},
923       {kAlloc, tracker.BufferAt(param, {0})},
924       {kAlloc, tracker.BufferAt(param, {1})},
925 
926       // Now the final cond less-than buffer is allocated.
927       {kAlloc, tracker.BufferAt(cond_lt, {})},
928 
929       // The order of the remaining Free calls is based on the BufferValue.id,
930       // which is deterministic, but not obvious.
931       {kFree, tracker.BufferAt(cond_lt, {})},
932       {kFree, tracker.BufferAt(param, {})},
933       {kFree, tracker.BufferAt(param, {0})},
934       {kFree, tracker.BufferAt(param, {1})},
935       {kFinish, nullptr},
936   });
937 }
938 
939 // Base class for heap algorithm tests.
940 class HeapAlgorithmTestBase : public ::testing::Test {
941  protected:
942   HeapAlgorithmTestBase() : builder_("heap_simulator_test") {
943     buffer_a_ = DummyBufferValue();
944     buffer_b_ = DummyBufferValue();
945     buffer_c_ = DummyBufferValue();
946     buffer_d_ = DummyBufferValue();
947     buffer_e_ = DummyBufferValue();
948     buffer_f_ = DummyBufferValue();
949     buffer_g_ = DummyBufferValue();
950     buffer_h_ = DummyBufferValue();
951     buffer_i_ = DummyBufferValue();
952   }
953   ~HeapAlgorithmTestBase() override {}
954 
955   const HloValue* buffer_a_;
956   const HloValue* buffer_b_;
957   const HloValue* buffer_c_;
958   const HloValue* buffer_d_;
959   const HloValue* buffer_e_;
960   const HloValue* buffer_f_;
961   const HloValue* buffer_g_;
962   const HloValue* buffer_h_;
963   const HloValue* buffer_i_;
964 
965  private:
966   // Create a dummy HloValue to pass to the heap algorithm.
967   const HloValue* DummyBufferValue() {
968     const HloValue::Id id = buffers_.size();
969     auto const0 = builder_.AddInstruction(
970         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
971     buffers_.emplace_back(
972         absl::make_unique<HloValue>(id, const0, ShapeIndex{}));
973     return buffers_.back().get();
974   }
975 
976   HloComputation::Builder builder_;
977   std::vector<std::unique_ptr<HloValue>> buffers_;
978 };
979 
980 class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {};
981 
982 TEST_F(NoFragmentationStatsHeapTest, Empty) {
983   NoFragmentationStatsHeap<HloValue> heap;
984   EXPECT_EQ(0, heap.Finish().heap_size);
985 }
986 
987 TEST_F(NoFragmentationStatsHeapTest, Simple) {
988   NoFragmentationStatsHeap<HloValue> heap;
989   heap.Alloc(buffer_a_, 10);
990   heap.Alloc(buffer_b_, 20);
991   heap.Alloc(buffer_c_, 30);
992   heap.Alloc(buffer_d_, 30);
993   heap.Free(buffer_a_, 10);
994   heap.Free(buffer_b_, 20);
995   heap.Free(buffer_c_, 30);
996   heap.Free(buffer_d_, 30);
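  // All four buffers are live at once, so the peak is 10 + 20 + 30 + 30 = 90.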
997   EXPECT_EQ(90, heap.Finish().heap_size);
998 }
999 
1000 TEST_F(NoFragmentationStatsHeapTest, Mixed) {
1001   NoFragmentationStatsHeap<HloValue> heap;
1002   heap.Alloc(buffer_a_, 10);  // max: A
1003 
1004   heap.Alloc(buffer_b_, 20);  // max: A+B
1005   heap.Free(buffer_b_, 20);
1006 
1007   heap.Alloc(buffer_c_, 30);  // max: A+C
1008   heap.Free(buffer_c_, 30);
1009 
1010   heap.Alloc(buffer_d_, 5);  // max: A+C
1011   heap.Free(buffer_d_, 5);
1012 
1013   heap.Free(buffer_a_, 10);
1014   EXPECT_EQ(40, heap.Finish().heap_size);
1015 }
1016 
1017 class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {
1018  protected:
1019   class InheritedGlobalDecreasingSizeBestFitHeap
1020       : public GlobalDecreasingSizeBestFitHeap<HloValue> {
1021    public:
1022     InheritedGlobalDecreasingSizeBestFitHeap()
1023         : GlobalDecreasingSizeBestFitHeap(/*alignment=*/1) {}
1024 
1025     // Finds a chunk candidate and returns the offset and the new heap size.
1026     std::pair<int64, int64> FindChunkCandidate(const HloValue* buffer,
1027                                                int64 size, int64 start,
1028                                                int64 end,
1029                                                int64 preferred_offset = -1) {
1030       buffer_interval_.buffer = buffer;
1031       buffer_interval_.size = size;
1032       buffer_interval_.start = start;
1033       buffer_interval_.end = end;
1034       chunk_candidate_ = GlobalDecreasingSizeBestFitHeap::FindChunkCandidate(
1035           buffer_interval_, preferred_offset);
1036       EXPECT_EQ(chunk_candidate_.chunk.size, size);
1037       return {chunk_candidate_.chunk.offset, chunk_candidate_.heap_size};
1038     }
1039 
1040     // Commits the previously found chunk candidate.
1041     void CommitChunk() {
1042       GlobalDecreasingSizeBestFitHeap::CommitChunk(buffer_interval_,
1043                                                    chunk_candidate_);
1044     }
1045 
1046    private:
1047     BufferInterval buffer_interval_;
1048     ChunkCandidate chunk_candidate_;
1049   };
1050 
1051   InheritedGlobalDecreasingSizeBestFitHeap heap_;
1052 };
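// A rough usage sketch for the fixture's helper (names as declared above):
// query a chunk candidate for a buffer's live range, then commit it.
//
//   std::pair<int64, int64> offset_and_heap_size =
//       heap_.FindChunkCandidate(buffer_a_, /*size=*/10, /*start=*/0, /*end=*/5);
//   heap_.CommitChunk();  // commits the candidate found above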
1053 
1054 TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
1055   GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
1056   const HeapSimulator::Result<HloValue> result = heap.Finish();
1057   EXPECT_EQ(0, result.heap_size);
1058   EXPECT_EQ(1, result.heap_results.size());
1059   EXPECT_EQ(0, result.heap_results.at(0).chunk_map.size());
1060 }
1061 
1062 TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
1063   // space
1064   //   ^
1065   //   |  +---a---+
1066   //   |      +-------+
1067   //   |      +---c---+
1068   //   |    +-------+
1069   //   |    |   b   |
1070   //   |    +-------+
1071   //   |         +-------+
1072   //   |         |       |
1073   //   |         |   d   |
1074   //   |         +-------+
1075   //   -----------------> time
1076   GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
1077   heap.Alloc(buffer_a_, 10);
1078   heap.Alloc(buffer_b_, 30);
1079   heap.Alloc(buffer_c_, 20);
1080   heap.Alloc(buffer_d_, 40);
1081   heap.Free(buffer_a_, 10);
1082   heap.Free(buffer_b_, 30);
1083   heap.Free(buffer_c_, 20);
1084   heap.Free(buffer_d_, 40);
1085 
1086   const HeapSimulator::Result<HloValue> results = heap.Finish();
1087   EXPECT_EQ(1, results.heap_results.size());
1088   const HeapSimulator::HeapResult<HloValue>& result =
1089       results.heap_results.at(0);
1090   EXPECT_EQ(100, result.heap_size);
1091   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
1092   EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
1093   EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
1094   EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);
1095 
1096   EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
1097   EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
1098   EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset);
1099   EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
1100 }
1101 
1102 TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
1103   // space
1104   //   ^
1105   //   |      +-------+
1106   //   |      +---b---+
1107   //   |            +-------+
1108   //   |            |       |
1109   //   |            |   d   |
1110   //   |  +---a---+ +-------+
1111   //   |
1112   //   |         +-------+
1113   //   |         |       |
1114   //   |         |   c   |
1115   //   |         |       |
1116   //   |         +-------+
1117   //   ---------------------> time
1118   GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/20);
1119   heap.Alloc(buffer_a_, 10);
1120   heap.Alloc(buffer_b_, 20);
1121   heap.Alloc(buffer_c_, 50);
1122   heap.Free(buffer_a_, 10);
1123   heap.Alloc(buffer_d_, 40);
1124   heap.Free(buffer_b_, 20);
1125   heap.Free(buffer_c_, 50);
1126   heap.Free(buffer_d_, 40);
1127 
1128   const HeapSimulator::Result<HloValue> results = heap.Finish();
1129   EXPECT_EQ(1, results.heap_results.size());
1130   const HeapSimulator::HeapResult<HloValue>& result =
1131       results.heap_results.at(0);
1132   EXPECT_EQ(120, result.heap_size);
1133   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
1134   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
1135   EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);
1136   EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);
1137 
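  // With 20-byte alignment, c occupies [0, 50), a is bumped up to the aligned
  // offset 60, d reuses [60, 100) once a is freed, and b lands at [100, 120).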
1138   EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset);
1139   EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset);
1140   EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
1141   EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
1142 }
1143 
1144 TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
1145   // space
1146   //   ^
1147   //   |    +-------+
1148   //   |    +---b---+
1149   //   |         +-------+
1150   //   |         |   d   |
1151   //   | +--a--+ +-------+
1152   //   |      +-------+
1153   //   |      |       |
1154   //   |      |   c   |
1155   //   |      +-------+
1156   //   |           +-------+
1157   //   |           |       |
1158   //   |           |   e   |
1159   //   |           |       |
1160   //   |           +-------+
1161   //   ---------------------> time
1162   GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
1163   heap.Alloc(buffer_a_, 10);
1164   heap.Alloc(buffer_b_, 20);
1165   heap.Alloc(buffer_c_, 40);
1166   heap.Free(buffer_a_, 10);
1167   heap.Alloc(buffer_d_, 30);
1168   heap.Alloc(buffer_e_, 50);
1169   heap.Free(buffer_b_, 20);
1170   heap.Free(buffer_c_, 40);
1171   heap.Free(buffer_d_, 30);
1172   heap.Free(buffer_e_, 50);
1173 
1174   const HeapSimulator::Result<HloValue> results = heap.Finish();
1175   EXPECT_EQ(1, results.heap_results.size());
1176   const HeapSimulator::HeapResult<HloValue>& result =
1177       results.heap_results.at(0);
1178   EXPECT_EQ(140, result.heap_size);
1179   EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
1180   EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
1181   EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
1182   EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
1183   EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size);
1184 
1185   EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
1186   EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset);
1187   EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset);
1188   EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset);
1189   EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset);
1190 }
1191 
1192 TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) {
1193   // space      colocate
1194   //   ^   +--------------+
1195   //   |   v              v
1196   //   |+------+      +-------+
1197   //   ||      |      |       |
1198   //   ||      |+----+|       |
1199   //   |+--a---++-b--++---c---+
1200   //   ---------------------> time
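  //
  // c is allocated via ShareWith(c, a, ...), so it reuses a's chunk at
  // offset 0; since b only lives in the gap between a and c, everything fits
  // in a single 40-byte heap.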
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 40);
  heap.Free(buffer_a_, 40);
  heap.Alloc(buffer_b_, 20);
  heap.Free(buffer_b_, 20);
  heap.ShareWith(buffer_c_, buffer_a_, 40);
  heap.Free(buffer_c_, 40);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(40, result.heap_size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
}

TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
  // space
  //   ^       +---------------+
  //   |       +-------b-------+
  //   |+------+      +-------+
  //   ||      |      |       |
  //   ||      |      |       | <--- colocate with a
  //   |+--a---+      +---c---+
  //   ---------------------> time
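  //
  // As above, c shares a's chunk at offset 0, but this time b stays live
  // across c's allocation, so b has to sit above the shared 40-byte chunk at
  // offset 40 and the heap grows to 60.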
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 40);
  heap.Free(buffer_a_, 40);
  heap.Alloc(buffer_b_, 20);

  heap.ShareWith(buffer_c_, buffer_a_, 40);
  heap.Free(buffer_c_, 40);
  heap.Free(buffer_b_, 20);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(60, result.heap_size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);

  EXPECT_EQ(0, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
}

TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) {
  // space
  //   ^+------+      +-------+
  //   ||      |      |       | <--- colocate with a
  //   |+--a---+      +---c---+
  //   |       +---------------+
  //   |       |               |
  //   |       |               |
  //   |       +-------b-------+
  //   ---------------------> time
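  //
  // Here b (30 bytes) is the largest buffer and goes at offset 0; the
  // colocated pair a/c (10 bytes) then sits on top of it at offset 30, for a
  // 40-byte heap.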
  GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_b_, 30);

  heap.ShareWith(buffer_c_, buffer_a_, 10);
  heap.Free(buffer_c_, 10);
  heap.Free(buffer_b_, 30);

  const HeapSimulator::Result<HloValue> results = heap.Finish();
  EXPECT_EQ(1, results.heap_results.size());
  const HeapSimulator::HeapResult<HloValue>& result =
      results.heap_results.at(0);
  EXPECT_EQ(40, result.heap_size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
  EXPECT_EQ(10, result.chunk_map.at(buffer_c_).size);

  EXPECT_EQ(30, result.chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.chunk_map.at(buffer_b_).offset);
  EXPECT_EQ(30, result.chunk_map.at(buffer_c_).offset);
}

TEST_F(GlobalDecreasingSizeBestFitHeapTest, ChunkCandidate) {
  // space
  //   ^
  // 35|
  //   |            +-----------+
  //   |            |           |
  // 30|            |           |
  //   |            |  po: 15   |
  //   |            |           |
  // 25|            +-----g-----+
  //   |         +-----+
  //   |         |po:20|
  // 20|         +--f--+
  //   |                                +-----+
  //   |                                |     |
  // 15|                                |     |
  //   |      +-----------------+       |po:10|
  //   |      |                 |       |     |
  // 10|      +-------c---------+       +--e--+
  //   |         +-----+  +-----------+
  //   |         |     |  |   po: 5   |
  //  5|         |     |  +-----a-----+
  //   |+-----+  |     |
  //   ||po:10|  |     |
  //  0|+--d--+  +--b--+
  //   -----------------------------------------> time
  //    0  1  2  3  4  5  6  7  8  9 10 11 12 13
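  //
  // Each FindChunkCandidate call below passes (buffer, size, start, end) and
  // optionally a preferred offset ("po" in the diagram above); the expected
  // pair reads as {chosen offset, resulting heap size}, and CommitChunk()
  // reserves the candidate that was just found.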
  using pair = std::pair<int64, int64>;
  EXPECT_EQ(pair(5, 10), heap_.FindChunkCandidate(buffer_a_, 5, 6, 10, 5));
  heap_.CommitChunk();  // offset: 5, size: 5, start: 6, end: 10
  // Preferred offset 5 is returned.
  EXPECT_EQ(pair(0, 10), heap_.FindChunkCandidate(buffer_b_, 10, 3, 5));
  heap_.CommitChunk();  // offset: 0, size: 10, start: 3, end: 5
  EXPECT_EQ(pair(10, 15), heap_.FindChunkCandidate(buffer_c_, 5, 2, 8));
  heap_.CommitChunk();  // offset: 10, size: 5, start: 2, end: 8
  EXPECT_EQ(pair(0, 15), heap_.FindChunkCandidate(buffer_d_, 5, 0, 2, 10));
  heap_.CommitChunk();  // offset: 0, size: 5, start: 0, end: 2
  // Preferred offset 10 could not be given because it is occupied.
  EXPECT_EQ(pair(10, 20), heap_.FindChunkCandidate(buffer_e_, 10, 11, 13, 10));
  heap_.CommitChunk();  // offset: 10, size: 10, start: 11, end: 13
  // Preferred offset 10 is returned.
  EXPECT_EQ(pair(20, 25), heap_.FindChunkCandidate(buffer_f_, 5, 3, 5, 20));
  heap_.CommitChunk();  // offset: 20, size: 5, start: 3, end: 5
  // Preferred offset 20 is returned.
  EXPECT_EQ(pair(25, 35), heap_.FindChunkCandidate(buffer_g_, 10, 4, 8, 15));
  heap_.CommitChunk();  // offset: 25, size: 10, start: 4, end: 8
  // Preferred offset 15 could not be given because it is occupied.
}

class ConstrainedGlobalDecreasingSizeBestFitHeapTest
    : public HeapAlgorithmTestBase {};

TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
  // space
  //   ^
  //   |      +-------+
  //   |      +---c---+
  //   |    +-------+
  //   |    |   b   |
  //   |    +-------+
  //   | ................ // split into two allocations.
  //   |  +---a---+
  //   |         +-------+
  //   |         |       |
  //   |         |   d   |
  //   |         +-------+
  //   -----------------> time
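  //
  // With size_limit_per_heap=50 the buffers no longer fit in one heap: d (40)
  // and a (10) fill the first 50-byte heap exactly, b and c spill into a
  // second heap, and result.heap_size below is the 100-byte total across
  // both.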
  ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/50,
                                                  /*alignment=*/1);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 30);
  heap.Alloc(buffer_c_, 20);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_a_, 10);
  heap.Free(buffer_b_, 30);
  heap.Free(buffer_c_, 20);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> result = heap.Finish();
  EXPECT_EQ(100, result.heap_size);
  EXPECT_EQ(2, result.heap_results.size());

  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_d_));
  EXPECT_EQ(10, result.heap_results[0].chunk_map.at(buffer_a_).size);
  EXPECT_EQ(40, result.heap_results[0].chunk_map.at(buffer_d_).size);
  EXPECT_EQ(40, result.heap_results[0].chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_d_).offset);
}

TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest,
       DecreasingSizeWithAlignment) {
  // space
  //   ^
  //   |      +-------+
  //   |      +---b---+
  //   |            +-------+
  //   |            |       |
  //   |            |   d   |
  //   |            +-------+
  //   | ...................
  //   |  +---a---+
  //   |
  //   |         +-------+
  //   |         |       |
  //   |         |   c   |
  //   |         |       |
  //   |         +-------+
  //   ---------------------> time
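  //
  // Same shape as the unconstrained alignment test, but with a 70-byte limit
  // per heap: c (50 bytes, padded up to the next 20-byte boundary) and a
  // share the first 70-byte heap, while b and d end up in a second 60-byte
  // heap.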
  ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/70,
                                                  /*alignment=*/20);
  heap.Alloc(buffer_a_, 10);
  heap.Alloc(buffer_b_, 20);
  heap.Alloc(buffer_c_, 50);
  heap.Free(buffer_a_, 10);
  heap.Alloc(buffer_d_, 40);
  heap.Free(buffer_b_, 20);
  heap.Free(buffer_c_, 50);
  heap.Free(buffer_d_, 40);

  const HeapSimulator::Result<HloValue> result = heap.Finish();
  EXPECT_EQ(130, result.heap_size);  // 70 + 60
  EXPECT_EQ(2, result.heap_results.size());

  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_c_));
  EXPECT_EQ(10, result.heap_results[0].chunk_map.at(buffer_a_).size);
  EXPECT_EQ(50, result.heap_results[0].chunk_map.at(buffer_c_).size);
  EXPECT_EQ(60, result.heap_results[0].chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_c_).offset);
}

TEST_F(ConstrainedGlobalDecreasingSizeBestFitHeapTest, ColocatedII) {
  // space
  //   ^
  //   |       +---------------+
  //   |       +-------b-------+
  //   | ....................
  //   |+------+      +-------+
  //   ||      |      |       |
  //   ||      |      |       | <--- colocate with a
  //   |+--a---+      +---c---+
  //   ---------------------> time
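  //
  // a and c are colocated and reuse the same 30-byte chunk at offset 0 in the
  // first heap (note that c's recorded chunk size is a's 30 bytes, not the 40
  // requested in ShareWith); b cannot also fit under the 50-byte limit, so it
  // goes into a second heap and the combined heap size is 50.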
  ConstrainedGlobalDecreasingSizeBestFitHeap heap(/*size_limit_per_heap=*/50,
                                                  /*alignment=*/20);
  heap.Alloc(buffer_a_, 30);
  heap.Free(buffer_a_, 30);
  heap.Alloc(buffer_b_, 20);

  heap.ShareWith(buffer_c_, buffer_a_, 40);
  heap.Free(buffer_c_, 40);
  heap.Free(buffer_b_, 20);

  const HeapSimulator::Result<HloValue> result = heap.Finish();
  EXPECT_EQ(50, result.heap_size);
  EXPECT_EQ(2, result.heap_results.size());

  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_a_));
  EXPECT_TRUE(result.heap_results[0].chunk_map.contains(buffer_c_));
  EXPECT_EQ(30, result.heap_results[0].chunk_map.at(buffer_a_).size);
  EXPECT_EQ(30, result.heap_results[0].chunk_map.at(buffer_c_).size);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_a_).offset);
  EXPECT_EQ(0, result.heap_results[0].chunk_map.at(buffer_c_).offset);
}

class IntervalTreeTest : public ::testing::Test {};

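// In the tree diagrams below each node is written as [start, end] followed by
// its subtree_end in parentheses, i.e. the maximum interval end found in the
// subtree rooted at that node; the tests check that this bookkeeping stays
// correct as intervals are removed.
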
TEST_F(IntervalTreeTest, InsertAndRemove) {
  HeapSimulator::Chunk chunk({1, 2});
  BufferIntervalTree tree;
  tree.Add(1, 2, chunk);
  EXPECT_TRUE(tree.Remove(1, 2, chunk));
  EXPECT_FALSE(tree.Remove(1, 2, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
  // Do it again.
  tree.Add(1, 2, chunk);
  EXPECT_TRUE(tree.Remove(1, 2, chunk));
  EXPECT_FALSE(tree.Remove(1, 2, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}

TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsLeft) {
  HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
  //    [20, 36] (45)
  //     /
  //  [1, 45] (45)

  BufferIntervalTree tree;
  tree.Add(20, 36, chunk);
  tree.Add(1, 45, chunk);
  EXPECT_TRUE(tree.Remove(1, 45, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 36);
  EXPECT_TRUE(tree.Remove(20, 36, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}

TEST_F(IntervalTreeTest, InsertAndRemoveTwoLevelsRight) {
  HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
  //    [20, 36] (45)
  //          \
  //         [21, 45] (45)
  BufferIntervalTree tree;
  tree.Add(20, 36, chunk);
  tree.Add(21, 45, chunk);
  EXPECT_TRUE(tree.Remove(21, 45, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 36);
  EXPECT_TRUE(tree.Remove(20, 36, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}

TEST_F(IntervalTreeTest, TwoLevelsRight_RootFirst) {
  HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
  //    [20, 36] (45)
  //          \
  //         [21, 45] (45)
  BufferIntervalTree tree;
  tree.Add(20, 36, chunk);
  tree.Add(21, 45, chunk);
  EXPECT_TRUE(tree.Remove(20, 36, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 45);
  EXPECT_EQ(tree.GetRoot()->start, 21);
  EXPECT_EQ(tree.GetRoot()->end, 45);
  EXPECT_EQ(tree.GetRoot()->left, nullptr);
  EXPECT_EQ(tree.GetRoot()->right, nullptr);
  EXPECT_TRUE(tree.Remove(21, 45, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}

TEST_F(IntervalTreeTest, TwoLevelsLeft_RootFirst) {
  HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
  //    [20, 36] (45)
  //      /
  //  [1, 45] (45)
  BufferIntervalTree tree;
  tree.Add(20, 36, chunk);
  tree.Add(1, 45, chunk);
  EXPECT_TRUE(tree.Remove(20, 36, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 45);
  EXPECT_EQ(tree.GetRoot()->start, 1);
  EXPECT_EQ(tree.GetRoot()->end, 45);
  EXPECT_EQ(tree.GetRoot()->left, nullptr);
  EXPECT_EQ(tree.GetRoot()->right, nullptr);
  EXPECT_TRUE(tree.Remove(1, 45, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}

TEST_F(IntervalTreeTest, ThreeLevelsRight) {
  HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
  //    [20, 36] (45)
  //          \
  //         [21, 45] (45)
  //              \
  //              [22, 40] (40)
  BufferIntervalTree tree;
  tree.Add(20, 36, chunk);
  tree.Add(21, 45, chunk);
  tree.Add(22, 40, chunk);
  EXPECT_TRUE(tree.Remove(21, 45, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(tree.Remove(20, 36, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(tree.Remove(22, 40, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}

TEST_F(IntervalTreeTest, ThreeLevelsLeftLeft) {
  HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
  //    [20, 36] (45)
  //       /
  //  [10, 45] (45)
  //      /
  // [1, 40] (40)
  BufferIntervalTree tree;
  tree.Add(20, 36, chunk);
  tree.Add(10, 45, chunk);
  tree.Add(1, 40, chunk);
  EXPECT_TRUE(tree.Remove(10, 45, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(tree.Remove(1, 40, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 36);
  EXPECT_TRUE(tree.Remove(20, 36, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}

TEST_F(IntervalTreeTest, ThreeLevelsLeftRight) {
  HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
  //    [20, 36] (45)
  //       /
  //  [10, 45] (45)
  //      \
  //     [15, 40] (40)
  BufferIntervalTree tree;
  tree.Add(20, 36, chunk);
  tree.Add(10, 45, chunk);
  tree.Add(15, 40, chunk);
  EXPECT_TRUE(tree.Remove(10, 45, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(tree.Remove(15, 40, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 36);
  EXPECT_TRUE(tree.Remove(20, 36, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}

TEST_F(IntervalTreeTest, ThreeLevelsRightLeft) {
  HeapSimulator::Chunk chunk({1, 2});  // Value in chunk doesn't matter here.
  //    [20, 36] (45)
  //          \
  //         [25, 45] (45)
  //           /
  //       [22, 40] (40)
  BufferIntervalTree tree;
  tree.Add(20, 36, chunk);
  tree.Add(25, 45, chunk);
  tree.Add(22, 40, chunk);
  EXPECT_TRUE(tree.Remove(25, 45, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(tree.Remove(20, 36, chunk));
  EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
  EXPECT_TRUE(tree.Remove(22, 40, chunk));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}

TEST_F(IntervalTreeTest, ThreeLevelsRightLeftChunkDifferent) {
  HeapSimulator::Chunk chunk1({1, 2});
  HeapSimulator::Chunk chunk2({2, 3});
  HeapSimulator::Chunk chunk3({3, 4});
  //    [20, 36] (45) Chunk1({1, 2})
  //          \
  //         [25, 45] (45) Chunk2({2, 3})
  //           /
  //       [22, 40] (40) Chunk3({3, 4})
  BufferIntervalTree tree;
  tree.Add(20, 36, chunk1);
  tree.Add(25, 45, chunk2);
  tree.Add(22, 40, chunk3);
  EXPECT_TRUE(tree.Remove(25, 45, chunk2));
  // Chunk 1 is still the root after removing chunk 2.
  EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
  EXPECT_EQ(tree.GetRoot()->chunk.offset, 1);
  EXPECT_EQ(tree.GetRoot()->chunk.size, 2);
  EXPECT_TRUE(tree.Remove(20, 36, chunk1));
  // Chunk 3 becomes the root now.
  EXPECT_EQ(tree.GetRoot()->subtree_end, 40);
  EXPECT_EQ(tree.GetRoot()->chunk.offset, 3);
  EXPECT_EQ(tree.GetRoot()->chunk.size, 4);
  EXPECT_TRUE(tree.Remove(22, 40, chunk3));
  ASSERT_EQ(tree.GetRoot(), nullptr);
}

}  // namespace
}  // namespace xla