/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/heap_simulator.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <memory>
#include <tuple>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

using absl::flat_hash_map;
using absl::flat_hash_set;

bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
  CHECK_NE(size, 0);
  CHECK_NE(other_chunk.size, 0);
  return offset < other_chunk.chunk_end() && other_chunk.offset < chunk_end();
}

/*static*/
StatusOr<int64_t> HeapSimulator::MinimumMemoryForModule(
    const HloSchedule& schedule,
    const LogicalBuffer::SizeFunction& size_function) {
  if (schedule.empty()) {
    return 0;
  }
  const HloModule* module = schedule.module();

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  // The absolute minimum memory required for a given sequence of instructions
  // is determined by the sequence of Alloc and Free calls on a simulated heap,
  // ignoring fragmentation. We run the heap simulation on the whole module,
  // rather than summing each computation, since it gives us a better lower
  // bound, by minimizing the liveness of sub-computations.
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(std::make_unique<NoFragmentationStatsHeap<HloValue>>(),
                         *module, schedule, *alias_analysis, size_function));
  return result.heap_size;
}
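
// Illustrative usage sketch (comment only, not compiled): given a valid
// HloSchedule `schedule` and a size function `size_fn`, both assumed to be
// provided by the caller, the minimum-memory lower bound is obtained with:
//
//   TF_ASSIGN_OR_RETURN(int64_t min_bytes,
//                       HeapSimulator::MinimumMemoryForModule(schedule,
//                                                             size_fn));
//
// Because a NoFragmentationStatsHeap is used internally, the returned value
// ignores fragmentation and alignment and is therefore a lower bound, not an
// achievable allocation.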

/*static*/
StatusOr<int64_t> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64_t>*
        memory_by_computation) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(std::make_unique<NoFragmentationStatsHeap<HloValue>>(),
                         computation, sequence, alias_analysis, size_function,
                         HeapSimulator::Options(), memory_by_computation));
  return result.heap_size;
}

StatusOr<int64_t> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const HloSchedule* schedule) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(std::make_unique<NoFragmentationStatsHeap<HloValue>>(),
                         computation, sequence, alias_analysis, size_function,
                         schedule, HeapSimulator::Options()));
  return result.heap_size;
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm, const HloModule& module,
    const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
  const HloComputation* entry_computation = module.entry_computation();
  const HloInstructionSequence& instruction_sequence =
      schedule.sequence(entry_computation);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(schedule, alias_analysis, entry_computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
                                         instruction_sequence, alias_analysis,
                                         hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const absl::flat_hash_map<const HloComputation*, int64_t>*
        memory_by_computation) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/nullptr, memory_by_computation);
  HloSchedule schedule(computation.parent());
  schedule.set_sequence(&computation, instruction_sequence);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, alias_analysis, &computation,
                                        /*module_scoped_analysis=*/false));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
    const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/schedule, nullptr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(*schedule, alias_analysis, &computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}

// Runs a heap simulation for the given 'computation', assuming the given
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis, HloLiveRange* hlo_live_range) {
  XLA_VLOG_LINES(1, computation.parent()->ToString());
  XLA_VLOG_LINES(2, computation.ToString());

  VLOG(1) << hlo_live_range->ToString();

  HloDataflowAnalysis& dataflow_analysis = alias_analysis.dataflow_analysis();

  // Record the buffer define/free events for each time step. We free all
  // remaining buffers (entry parameters, etc.) after the program has finished
  // running, so we set the size to schedule_end_time() + 1.
  std::vector<std::vector<const HloValue*>> buffers_defined(
      hlo_live_range->schedule_end_time() + 1);
  std::vector<std::vector<const HloValue*>> buffers_freed(
      hlo_live_range->schedule_end_time() + 1);

  // values_to_assign tracks the HloValues that we need to assign a buffer to.
  // Note that we only need to assign a buffer to a value when both of the
  // following conditions are met:
  //
  // - The user specifically asks us to assign buffers to a set of HloValues,
  // and the value is in that set. If the user doesn't provide such a set, by
  // default we assign a buffer to every HloValue.
  //
  // - If the instruction is in a nested call of the current computation, we
  // only assign a buffer when doing global (whole-module) heap simulation.
  std::vector<const HloValue*> values_to_assign;
  values_to_assign.reserve(dataflow_analysis.values().size());

  auto& buffer_live_ranges = hlo_live_range->buffer_live_ranges();

  for (const HloValue* value : dataflow_analysis.values()) {
    // Ignore buffers that are not tracked.
    if (!buffer_live_ranges.contains(value)) {
      continue;
    }
    if (IgnoreBuffer(value)) {
      continue;
    }

    values_to_assign.push_back(value);
  }

  absl::c_sort(values_to_assign,
               [&](const HloValue* value1, const HloValue* value2) {
                 const auto& live_range1 = buffer_live_ranges.at(value1);
                 const auto& live_range2 = buffer_live_ranges.at(value2);
                 return std::forward_as_tuple(live_range1.start,
                                              live_range1.end, value1->id()) <
                        std::forward_as_tuple(live_range2.start,
                                              live_range2.end, value2->id());
               });

  // For each value that we need to assign a buffer to, add the define and free
  // events.
  for (const HloValue* value : values_to_assign) {
    auto live_range = buffer_live_ranges.at(value);
    buffers_defined[live_range.start].push_back(value);
    buffers_freed[live_range.end].push_back(value);
  }

  // All HloValues in an HloBuffer should be allocated at the same address.
  // This map tracks the first value that got allocated in each buffer.
  absl::flat_hash_map<const HloBuffer*, const HloValue*> first_allocated_value;

  VLOG(1) << "Program time: " << hlo_live_range->schedule_end_time();

  // Go through each step in the program and replay the buffer define and free
  // events.
  for (int64_t i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) {
    VLOG(1) << "Time step: " << i;

    for (const HloValue* value : buffers_defined[i]) {
      bool shared = false;
      VLOG(1) << "Start buffer: " << value->ToShortString();
      const HloBuffer* hlo_buffer =
          &alias_analysis.GetBufferContainingValue(*value);
      if (first_allocated_value.count(hlo_buffer) != 0) {
        // We've already assigned an address for another value in this
        // HloBuffer (an HloBuffer holds several aliased HloValues). All values
        // in a buffer should be assigned the same address. Find the one that's
        // already allocated and reuse its address.
        ShareBuffer(value, first_allocated_value[hlo_buffer],
                    value->instruction());
        VLOG(1) << "  ShareWith "
                << first_allocated_value[hlo_buffer]->ToShortString();
        continue;
      }
      if (options_.may_reuse_operand_buffers &&
          hlo_buffer->values().size() == 1) {
        // We don't support sharing an aliased buffer
        // (hlo_buffer->values().size() > 1) with its operand.
        for (const HloInstruction* operand : value->instruction()->operands()) {
          const HloValueSet operand_value_set =
              dataflow_analysis.GetValueSet(operand);
          for (const HloValue* operand_value : operand_value_set.values()) {
            const HloBuffer* operand_buffer =
                &alias_analysis.GetBufferContainingValue(*operand_value);
            if (operand_buffer->values().size() > 1) {
              continue;
            }
            auto it = buffer_live_ranges.find(operand_value);
            if (it == buffer_live_ranges.end()) {
              continue;
            }

            auto& operand_live_range = it->second;

            auto& user_live_range = buffer_live_ranges[value];

            // Can only share buffers that are about to be freed.
            if (operand_live_range.end != i) {
              continue;
            }

            if (IgnoreBuffer(operand_value)) {
              continue;
            }

            if (!absl::c_linear_search(buffers_freed[i], operand_value)) {
              // If the operand buffer is not being freed (either because it has
              // existing users, or it has been reused by other buffers), don't
              // consider the operand as a candidate for buffer sharing.
              continue;
            }

            // The instruction that defines the operand value can be different
            // from the actual operand; passing the defining instruction
            // directly into "CanShareOperandBufferWithUser" would trigger a
            // check failure. The first condition guards against that case.
            if (value->instruction()->IsUserOf(operand_value->instruction()) &&
                value->instruction()->opcode() != HloOpcode::kCopy &&
                dataflow_analysis.CanShareOperandBufferWithUser(
                    operand_value->instruction(), operand_value->index(),
                    value->instruction(), value->index())) {
              // Remove the operand buffer right before sharing (allocating) a
              // new one.
              Free(operand_value, operand_value->instruction());
              buffers_freed[i].erase(
                  std::remove(buffers_freed[i].begin(), buffers_freed[i].end(),
                              operand_value),
                  buffers_freed[i].end());
              ShareBuffer(value, operand_value, value->instruction());
              // The live range of the operand buffer is now extended to the
              // end of the current instruction.
              operand_live_range.end = user_live_range.end;
              VLOG(1) << "Sharing " << value->ToShortString() << " with "
                      << operand_value->ToShortString()
                      << ", size: " << size_fn_(*value);
              shared = true;
              break;
            }
          }
          if (shared) {
            break;
          }
        }
      }
      if (!shared) {
        Alloc(value, value->instruction());
        first_allocated_value[hlo_buffer] = value;
      }
    }

    if (!buffers_freed[i].empty()) {
      VLOG(1) << "Free Buffer: ";
    }
    for (const HloValue* value : buffers_freed[i]) {
      VLOG(1) << "  " << value->ToShortString();

      Free(value, value->instruction());
    }
  }
  return OkStatus();
}

HeapSimulator::HeapSimulator(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const HloSchedule* schedule,
    const absl::flat_hash_map<const HloComputation*, int64_t>*
        memory_by_computation)
    : no_fragmentation_stats_(
          std::make_unique<NoFragmentationStatsHeap<HloValue>>()),
      algorithm_(std::move(algorithm)),
      size_fn_(size_fn),
      options_(options),
      schedule_(schedule),
      memory_by_computation_(memory_by_computation) {
  debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}

HeapSimulator::~HeapSimulator() {}

bool HeapSimulator::IgnoreBuffer(const HloValue* buffer) const {
  // Buffers for constants are ignored unless the alloc_constants option is
  // set. Also ignore buffers that we're not meant to assign.
  //
  // TODO(b/32248867): For consistency, constants should get allocations.
  if (!options_.alloc_constants &&
      buffer->instruction()->opcode() == HloOpcode::kConstant) {
    return true;
  }
  return options_.buffers_to_assign != nullptr &&
         !options_.buffers_to_assign->contains(buffer);
}

// Alloc always calls the underlying heap algorithm.
void HeapSimulator::Alloc(const HloValue* buffer,
                          const HloInstruction* instruction) {
  CHECK(!allocated_buffers_.contains(buffer))
      << "Alloc called on allocated buffer: " << *buffer;
  CHECK(!freed_buffers_.contains(buffer))
      << "Alloc called on freed buffer: " << *buffer;

  allocated_buffers_.insert(buffer);
  const int64_t size = size_fn_(*buffer);
  algorithm_->Alloc(buffer, size);
  no_fragmentation_stats_->Alloc(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
                 nullptr);
}

// Free calls the underlying algorithm for non-shared buffers, and for shared
// buffers whose group liveness has expired.  Shared group liveness is tracked
// by maintaining a refcount; the Free call on the last buffer in the group
// causes Free to be called on the underlying algorithm.
void HeapSimulator::Free(const HloValue* buffer,
                         const HloInstruction* instruction) {
  const int64_t size = size_fn_(*buffer);
  algorithm_->Free(buffer, size);
  no_fragmentation_stats_->Free(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
}

// ShareBuffer associates buffers with their SharedGroup in shared_buffers_.
// The 'buffer' must be a non-allocated, non-freed buffer, just like in calls
// to Alloc.  The 'shared' buffer must be a previously allocated or shared
// buffer. Both 'buffer' and 'shared' will be associated with the same
// SharedGroup.
void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
                                const HloInstruction* instruction) {
  algorithm_->ShareWith(buffer, shared, size_fn_(*shared));
  no_fragmentation_stats_->ShareWith(buffer, shared, size_fn_(*shared));
  FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
                 shared);
}
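
// Illustrative call sequence (comment only): the protocol the simulator
// follows for a value `b` that aliases an already-allocated value `a`, where
// `a`, `b`, and `instr` are hypothetical names used only for this sketch.
//
//   Alloc(a, a->instruction());   // 'a' gets its own chunk.
//   ShareBuffer(b, a, instr);     // 'b' reuses 'a''s address.
//   Free(b, instr);               // frees are issued per value,
//   Free(a, a->instruction());    // in live-range order.
//
// Like Alloc, ShareBuffer expects 'buffer' to be neither allocated nor freed
// yet, while 'shared' must already be allocated (or itself shared).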

HeapSimulator::Result<HloValue> HeapSimulator::Finish() {
  Result<HloValue> result = algorithm_->Finish();

  // Post-process the result to add chunks for shared buffers.  An empty chunk
  // map means that either no buffers were allocated, or the heap was only
  // collecting statistics, e.g. NoFragmentationStatsHeap.
  size_t total_chunk_count = absl::c_accumulate(
      result.heap_results, static_cast<size_t>(0),
      [&](size_t lhs, const HeapResult<HloValue>& rhs) -> size_t {
        return lhs + rhs.chunk_map.size();
      });
  if (total_chunk_count != 0) {
    // If we were told to assign specific buffers, make sure we've assigned
    // exactly that many buffers.
    if (options_.buffers_to_assign != nullptr) {
      CHECK_EQ(options_.buffers_to_assign->size(), total_chunk_count);
    }
  }

  // Fragmentation is the difference between the actual and ideal sizes.
  const Result<HloValue> no_frag_result = no_fragmentation_stats_->Finish();
  result.fragmentation_size = result.heap_size - no_frag_result.heap_size;

  // Copy the debug trace we collected to the final result.
  result.debug_trace.Swap(&debug_trace_);

  return result;
}

void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
                                   const HloValue* buffer,
                                   const HloInstruction* instruction,
                                   const HloValue* share_with_canonical) {
  HeapSimulatorTrace::Event* event = debug_trace_.add_events();
  event->set_kind(kind);
  event->set_buffer_id(buffer->id());
  event->set_computation_name(instruction->parent()->name());
  event->set_instruction_name(instruction->name());
  if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
    CHECK(share_with_canonical != nullptr);
    event->set_share_with_canonical_id(share_with_canonical->id());
  } else {
    CHECK(share_with_canonical == nullptr);
  }
}

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::Alloc(const BufferType* buffer,
                                                 int64_t size) {
  current_heap_size_ += size;
  if (current_heap_size_ > max_heap_size_) {
    max_heap_size_ = current_heap_size_;
  }
}

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::AccountForSubcomputationMemory(
    const HloInstruction* instruction, int64_t alloc_size_by_instruction,
    const absl::flat_hash_map<const HloComputation*, int64_t>&
        memory_by_computation) {
  // We only count the memory usage of the largest subcomputation, instead of
  // adding them all, because subcomputations won't execute in parallel.
  int64_t max_subcomputation_bytes = 0;
  for (const auto* c : instruction->called_computations()) {
    auto it = memory_by_computation.find(c);
    if (it != memory_by_computation.end()) {
      int64_t subcomputation_bytes = it->second;
      if (subcomputation_bytes > max_subcomputation_bytes) {
        max_subcomputation_bytes = subcomputation_bytes;
      }
    }
  }
  if (max_subcomputation_bytes > 0 &&
      (instruction->opcode() == HloOpcode::kWhile ||
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kConditional)) {
    // The output buffer of while/call/conditional is always aliased with the
    // output buffer of the root instruction in the body. Don't double count.
    max_subcomputation_bytes -= alloc_size_by_instruction;
  }
  max_heap_size_ =
      std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
}
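
// Worked example (hypothetical numbers): suppose current_heap_size_ is 100
// bytes, `instruction` is a kWhile whose body and condition computations are
// recorded in memory_by_computation as 60 and 40 bytes respectively, and
// alloc_size_by_instruction (the while's own output allocation) is 10 bytes.
// Only the largest subcomputation counts, so max_subcomputation_bytes = 60;
// because the while output aliases the body root's output, 10 bytes are
// subtracted, leaving 50. The tracked peak then becomes at least
// 100 + 50 = 150 bytes.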

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::Free(const BufferType* buffer,
                                                int64_t size) {
  current_heap_size_ -= size;
}

template <typename BufferType>
HeapSimulator::Result<BufferType>
NoFragmentationStatsHeap<BufferType>::Finish() {
  // The result.chunk_map is empty, since we only collect stats, and don't
  // actually compute chunk assignments.
  Result result;
  result.heap_size = max_heap_size_;
  return result;
}

template <typename BufferType>
GlobalDecreasingSizeBestFitHeap<BufferType>::GlobalDecreasingSizeBestFitHeap(
    int64_t alignment, Type type)
    : alignment_(alignment) {
  if (type == kTemporal) {
    buffer_interval_compare_ = GetTemporalBufferIntervalCompare();
  } else {
    CHECK(type == kSpatial);
    buffer_interval_compare_ = GetSpatialBufferIntervalCompare();
  }
}

template <typename BufferType>
typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap<BufferType>::GetTemporalBufferIntervalCompare()
    const {
  return LessThanByKey([this](const BufferInterval& x) {
    int64_t x_end = x.end;
    for (auto colocation : GetTransitiveColocations(x)) {
      x_end = std::max(x_end, buffer_intervals_.at(colocation).end);
    }
    // Sort by duration (descending), size (descending), buffer (ascending).
    return std::make_tuple(x.start - x_end, -x.size, std::cref(*x.buffer));
  });
}
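
// Worked example (hypothetical intervals): for interval A with start=2,
// end=10, size=32 and interval B with start=0, end=4, size=64, the temporal
// keys are (2 - 10, -32, ...) = (-8, -32, ...) and (0 - 4, -64, ...) =
// (-4, -64, ...). A's key compares smaller, so the longer-lived A is placed
// before the larger-but-shorter-lived B; the spatial comparator below would
// order them the other way around.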

template <typename BufferType>
/*static*/ typename GlobalDecreasingSizeBestFitHeap<
    BufferType>::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap<BufferType>::GetSpatialBufferIntervalCompare() {
  return LessThanByKey([](const BufferInterval& x) {
    // Sort by size (descending), duration (descending), buffer (ascending).
    return std::make_tuple(-x.size, x.start - x.end, std::cref(*x.buffer));
  });
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::Alloc(
    const BufferType* buffer, int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }

  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, true});
  DCHECK(emplace_result.second);
  ++current_time_;
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::ShareWith(
    const BufferType* buffer, const BufferType* share_with, int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }
  DCHECK_NE(buffer_intervals_.count(share_with), 0);
  buffer_intervals_[share_with].colocations.push_back(buffer);
  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, false});
  DCHECK(emplace_result.second);
  ++current_time_;
}

template <typename BufferType>
absl::flat_hash_set<const BufferType*>
GlobalDecreasingSizeBestFitHeap<BufferType>::GetTransitiveColocations(
    const BufferInterval& interval) const {
  absl::flat_hash_set<const BufferType*> result;
  std::vector<const BufferInterval*> worklist = {&interval};
  while (!worklist.empty()) {
    const BufferInterval* item = worklist.back();
    worklist.pop_back();
    for (const BufferType* buffer_colocated : item->colocations) {
      if (result.insert(buffer_colocated).second) {
        worklist.push_back(&buffer_intervals_.at(buffer_colocated));
      }
    }
  }

  return result;
}
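
// Illustrative trace (hypothetical buffers a, b, c): after ShareWith(b, a, s)
// and ShareWith(c, b, s), a's colocations contain {b} and b's contain {c}.
// GetTransitiveColocations(intervals[a]) then walks a -> b -> c and returns
// {b, c}; the queried buffer itself is not included in the result.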

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::Free(const BufferType* buffer,
                                                       int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    return;
  }
  BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
  DCHECK_EQ(buffer_interval.buffer, buffer);
  DCHECK_EQ(buffer_interval.size, size);
  DCHECK_EQ(buffer_interval.end, -1);
  if (buffer_interval.end != -1) {
    return;
  }
  buffer_interval.end = current_time_;
  ++current_time_;
}

using Chunk = HeapSimulator::Chunk;

void BufferIntervalTree::Add(int64_t start, int64_t end, const Chunk& chunk) {
  node_storage_.emplace_back(BufferIntervalTreeNode{
      start, end, end, chunk,
      /*left=*/nullptr, /*right=*/nullptr, /*parent=*/nullptr});
  if (root_ == nullptr) {
    root_ = &node_storage_.back();
    // This node is the root.
    return;
  }

  BufferIntervalTreeNode* parent = root_;
  while (true) {
    parent->subtree_end = std::max(parent->subtree_end, end);
    if (parent->start > start) {
      if (parent->left == nullptr) {
        parent->left = &node_storage_.back();
        node_storage_.back().parent = parent;
        return;
      }
      parent = parent->left;
    } else {
      if (parent->right == nullptr) {
        parent->right = &node_storage_.back();
        node_storage_.back().parent = parent;
        return;
      }
      parent = parent->right;
    }
  }
}

bool BufferIntervalTree::Remove(int64_t start, int64_t end,
                                const Chunk& chunk) {
  BufferIntervalTreeNode* to_delete = root_;
  while (to_delete != nullptr) {
    if (to_delete->start == start && to_delete->end == end &&
        to_delete->chunk.offset == chunk.offset) {
      break;
    }
    if (start < to_delete->start) {
      to_delete = to_delete->left;
    } else {
      to_delete = to_delete->right;
    }
  }
  if (to_delete == nullptr) {
    // Nothing to delete.
    return false;
  }
  // Found the node to be deleted, enter the deletion sequence.

  // Recursively traverse the parents of a node and fix up the `subtree_end`
  // invariant of each node. Recursive lambdas need an explicit
  // std::function declaration.
  std::function<void(BufferIntervalTreeNode*)> fix_up =
      [&](BufferIntervalTreeNode* node) {
        if (node == nullptr) {
          return;
        }
        node->subtree_end = node->end;
        if (node->left) {
          node->subtree_end =
              std::max(node->subtree_end, node->left->subtree_end);
        }
        if (node->right) {
          node->subtree_end =
              std::max(node->subtree_end, node->right->subtree_end);
        }
        // Recursively go up.
        fix_up(node->parent);
      };

  if (to_delete->right == nullptr) {
    // to_delete has no right child, so simply move up the left child of
    // to_delete, if any.
    //
    // Turn:
    //      parent
    //       /
    // to_delete
    //  /      \
    // left    nullptr
    //
    // Into:
    //      parent
    //      /
    //    left
    if (root_ == to_delete) {
      // Deleting the root simply resets the root.
      root_ = to_delete->left;
      return true;
    }

    if (to_delete == to_delete->parent->left) {
      // to_delete is the left child of its parent.
      to_delete->parent->left = to_delete->left;
    }
    if (to_delete == to_delete->parent->right) {
      // to_delete is the right child of its parent.
      to_delete->parent->right = to_delete->left;
    }
    // Rewire the parent to the node being moved up.
    if (to_delete->left) {
      to_delete->left->parent = to_delete->parent;
    }
    // Fix up starting from the subroot.
    fix_up(to_delete);
  } else {
    // 1. Find the left-most node of the right subtree and promote it to the
    // position of to_delete.
    BufferIntervalTreeNode* to_promote = to_delete->right;
    while (to_promote->left != nullptr) {
      // Go to the left-most subtree.
      to_promote = to_promote->left;
    }

    // 2. Copy the content of `to_promote` to `to_delete`.
    to_delete->start = to_promote->start;
    to_delete->end = to_promote->end;
    // This value may be incorrect, but we will fix it up later in the
    // `fix_up` procedure.
    to_delete->subtree_end = to_promote->subtree_end;
    to_delete->chunk = to_promote->chunk;
    auto to_promote_parent = to_promote->parent;
    // 3. Move the right child of `to_promote` up, if there is any.
    //
    // Turn
    //
    // to_delete
    //         \
    //        to_promote_parent
    //         /
    //    to_promote
    //          \
    //          right
    // into
    //
    // to_promote
    //         \
    //         to_promote_parent
    //         /
    //      right
    if (to_promote_parent->left == to_promote) {
      to_promote_parent->left = to_promote->right;
    } else {
      to_promote_parent->right = to_promote->right;
    }
    if (to_promote->right) {
      // Set the correct parent.
      to_promote->right->parent = to_promote_parent;
    }
    // 4. Recursively fix up the `subtree_end` starting from
    // `to_promote_parent`.
    fix_up(to_promote_parent);
  }
  // Don't free the entry in node_storage_ until we free the entire tree.
  return true;
}

std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
    int64_t start, int64_t end) const {
  std::vector<Chunk> result;
  if (root_ == nullptr) {
    return result;
  }
  std::vector<const BufferIntervalTreeNode*> visiting_stack;
  visiting_stack.push_back(root_);
  while (!visiting_stack.empty()) {
    const BufferIntervalTreeNode* top = visiting_stack.back();
    visiting_stack.pop_back();
    if (start > top->subtree_end) {
      continue;
    }
    if (top->left != nullptr) {
      visiting_stack.push_back(top->left);
    }
    if (top->start <= end && top->end >= start) {
      result.push_back(top->chunk);
    }
    if (end < top->start) {
      continue;
    }
    if (top->right != nullptr) {
      visiting_stack.push_back(top->right);
    }
  }
  return result;
}
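
// Illustrative usage (hypothetical values): adding two live ranges and then
// querying a window that touches both of them.
//
//   BufferIntervalTree tree;
//   tree.Add(/*start=*/1, /*end=*/5, Chunk{/*offset=*/0, /*size=*/16});
//   tree.Add(/*start=*/6, /*end=*/10, Chunk{/*offset=*/16, /*size=*/8});
//   // Both nodes satisfy start <= 7 and end >= 4, so both chunks are
//   // returned.
//   std::vector<Chunk> chunks = tree.ChunksOverlappingInTime(4, 7);
//
// The `subtree_end` value cached in each node lets the query skip whole
// subtrees whose intervals end before the query window starts.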

template <typename BufferType>
HeapSimulator::Result<BufferType>
GlobalDecreasingSizeBestFitHeap<BufferType>::Finish() {
  std::vector<BufferInterval> sorted_buffer_intervals =
      GetSortedBufferIntervals();

  for (auto& buffer_interval : sorted_buffer_intervals) {
    if (!buffer_interval.need_allocation) {
      continue;
    }

    // This implementation of the heap algorithm does not have a notion of
    // maximum heap size, so it just commits.
    CommitChunk(buffer_interval, FindChunkCandidate(buffer_interval));
  }
  VLOG(1) << "result heap_size: " << result_.heap_size;
  Result result;
  result.heap_size = result_.heap_size;
  result.heap_results.emplace_back(result_);
  return result;
}

template <typename BufferType>
std::vector<
    typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval>
GlobalDecreasingSizeBestFitHeap<BufferType>::GetSortedBufferIntervals() const {
  std::vector<BufferInterval> sorted_buffer_intervals;
  sorted_buffer_intervals.reserve(buffer_intervals_.size());
  for (auto& entry : buffer_intervals_) {
    sorted_buffer_intervals.push_back(entry.second);
  }
  absl::c_sort(sorted_buffer_intervals, buffer_interval_compare_);

  return sorted_buffer_intervals;
}

template <typename BufferType>
typename GlobalDecreasingSizeBestFitHeap<BufferType>::Chunk
GlobalDecreasingSizeBestFitHeap<BufferType>::FindChunkCandidate(
    const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
    int64_t preferred_offset) const {
  VLOG(1) << "Finding chunks for buffer: "
          << buffer_interval.buffer->ToString();
  VLOG(1) << "Size " << buffer_interval.size << ", start "
          << buffer_interval.start << ", end " << buffer_interval.end;
  // Get all colocated buffers and gather all interfering chunks.
  //
  // Imagine that we've already allocated three chunks: a, b and c. And now
  // we want to allocate d. Since e is colocated with d, we have to allocate
  // chunks for them together at the same address. To do this, we first gather
  // all chunks that overlap with d and e on the time dimension; in this case
  // the overlapping chunks are a and b (c doesn't overlap with either d or
  // e). We then create a new chunk that doesn't overlap with a and b on the
  // space dimension.
  //
  // space
  //   ^
  //   |+--d---+      +---e---+
  //   |
  //   |+---+  +---------------+  +-------+
  //   ||   |  |               |  |       |
  //   ||   |  |               |  |       |
  //   |+-a-+  +-------b-------+  +---c---+
  //   ----------------------------------------> time

  // Map free chunk offsets -> ends.
  // We use `greater` for the comparison so that we can use `lower_bound` to
  // find the largest key less than or equal to the lookup value.
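  // For example (hypothetical contents), with keys iterated in descending
  // order {32, 16, 0}, free_chunks.lower_bound(20) returns the entry with key
  // 16 -- the free chunk that starts at or below offset 20 -- because under
  // std::greater "not less than 20" means "not greater than 20".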
  absl::btree_map<int64_t, int64_t, std::greater<int64_t>> free_chunks{
      {0, INT64_MAX}};  // Initialize with "infinite" free memory.

  // Find the max size of the interval across its colocations and use this
  // value to determine whether the buffer will fit in the heap.
  int64_t max_colocation_size = buffer_interval.size;
  for (const BufferType* colocation :
       GetTransitiveColocations(buffer_interval)) {
    max_colocation_size =
        std::max(max_colocation_size, buffer_intervals_.at(colocation).size);
  }

  // Subtract chunks that are in use from the free chunks.
  auto subtract_used_chunks = [&](const std::vector<Chunk>& used_chunks) {
    for (const Chunk& used_chunk : used_chunks) {
      // Find the free chunks containing the start and end of the used chunk.
      auto it_end = free_chunks.lower_bound(used_chunk.chunk_end());
      if (it_end == free_chunks.end()) continue;
      auto it_start = free_chunks.lower_bound(used_chunk.offset);

      // Store the original free chunk end, in case `it_start == it_end`.
      int64_t free_chunk_end = it_end->second;

      // Subtract from the free chunk containing the start of the used range,
      // removing it if it becomes too small for the buffer.
      if (it_start != free_chunks.end()) {
        if (used_chunk.offset - it_start->first >= buffer_interval.size) {
          it_start->second = std::min(it_start->second, used_chunk.offset);
        } else {
          ++it_start;  // Increment iterator so that this entry is erased below.
        }
      }

      // Erase from the start chunk (possibly inclusive) to the end chunk
      // (always inclusive). We iterate from end to start, as the map is in
      // reverse order.
      free_chunks.erase(it_end, it_start);

      // Create a new free chunk after the used chunk, if it is large enough.
      int64_t chunk_end_aligned = RoundUpTo(used_chunk.chunk_end(), alignment_);
      if (free_chunk_end - chunk_end_aligned >= max_colocation_size) {
        CHECK(free_chunks.insert({chunk_end_aligned, free_chunk_end}).second);
      }
    }
  };

  subtract_used_chunks(interval_tree_.ChunksOverlappingInTime(
      buffer_interval.start, buffer_interval.end));

  for (const BufferType* colocation :
       GetTransitiveColocations(buffer_interval)) {
    const BufferInterval& interval = buffer_intervals_.at(colocation);
    VLOG(1) << "  Alias size " << interval.size << ", start " << interval.start
            << ", end " << interval.end << " " << interval.buffer->ToString();

    subtract_used_chunks(
        interval_tree_.ChunksOverlappingInTime(interval.start, interval.end));
  }

  // Try to find a large enough free chunk containing the preferred offset.
  Chunk chunk{preferred_offset, max_colocation_size};
  auto it = (preferred_offset < 0) ? free_chunks.end()
                                   : free_chunks.lower_bound(preferred_offset);
  if (it == free_chunks.end() || (it->second < chunk.chunk_end())) {
    // Otherwise, find the smallest free chunk. In the case of a tie, prefer the
    // smallest offset. We ensure above that all of the free chunks are large
    // enough to store the buffer.
    chunk.offset = absl::c_min_element(free_chunks, [](auto a, auto b) {
                     return std::forward_as_tuple(a.second - a.first, a.first) <
                            std::forward_as_tuple(b.second - b.first, b.first);
                   })->first;
  }
  return chunk;
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::CommitChunk(
    const GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval&
        buffer_interval,
    GlobalDecreasingSizeBestFitHeap<BufferType>::Chunk chunk) {
  // Update the maximum heap size according to the one determined by the chunk
  // candidate. In case of colocations of different sizes, the chunk size
  // returned is the maximum over all colocations, so use this value to update
  // the heap size.
  result_.heap_size = result_.UpdatedHeapSize(chunk);
  // Now, update the chunk size to the actual size of the buffer interval.
  chunk.size = buffer_interval.size;
  interval_tree_.Add(buffer_interval.start, buffer_interval.end, chunk);
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    auto colocation_interval = buffer_intervals_[colocation];
    // Create a colocation chunk with the same offset but with the correct size
    // of the colocated interval, in case the colocations are of different
    // sizes.
    Chunk colocation_chunk{chunk.offset, colocation_interval.size};
    AddToChunkMap(colocation, colocation_chunk);
    interval_tree_.Add(colocation_interval.start, colocation_interval.end,
                       colocation_chunk);
  }

  AddToChunkMap(buffer_interval.buffer, chunk);
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::AddToChunkMap(
    const BufferType* buffer, Chunk chunk) {
  const auto emplace_result = result_.chunk_map.emplace(buffer, chunk);
  DCHECK(emplace_result.second);
}

HeapSimulator::Result<HloValue>
ConstrainedGlobalDecreasingSizeBestFitHeap::Finish() {
  std::vector<BufferInterval> sorted_buffer_vec = GetSortedBufferIntervals();
  // Convert into a std::list so that erase() is O(1).
  std::list<BufferInterval> sorted_buffer_intervals(sorted_buffer_vec.begin(),
                                                    sorted_buffer_vec.end());

  // Use do-while here, because we need to create 1 heap in `multi_heap_result`
  // even if `sorted_buffer_intervals` is empty.
  Result multi_heap_result;
  do {
    // Place as many buffers as possible into the currently processed heap.
    for (auto it = sorted_buffer_intervals.begin();
         it != sorted_buffer_intervals.end();) {
      BufferInterval buffer_interval = *it;
      if (!buffer_interval.need_allocation) {
        it = sorted_buffer_intervals.erase(it);
        continue;
      }
      if (buffer_interval.size > size_limit_per_heap_) {
        LOG(WARNING) << "Alloc buffer size " << buffer_interval.size
                     << " larger than the per-heap size limit "
                     << size_limit_per_heap_;
      }

      Chunk chunk_candidate = FindChunkCandidate(buffer_interval);
      if (chunk_candidate.chunk_end() <= size_limit_per_heap_ ||
          // Commit the chunk as long as the heap is empty. We do this because
          // we want the size constraint to be soft, meaning that results are
          // successfully generated even if there are some buffer sizes larger
          // than the given constraint size.
          result_.heap_size == 0) {
        CommitChunk(buffer_interval, chunk_candidate);
        it = sorted_buffer_intervals.erase(it);
        continue;
      }

      ++it;
    }
    // Collect the result from the currently processed heap and reset the heap
    // states.
    multi_heap_result.heap_size += result_.heap_size;
    multi_heap_result.heap_results.push_back(std::move(result_));
    result_ = {};
    interval_tree_ = {};
  } while (!sorted_buffer_intervals.empty());

  VLOG(1) << "Number of heaps produced = "
          << multi_heap_result.heap_results.size();
  return multi_heap_result;
}
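
// Worked example (hypothetical sizes): with size_limit_per_heap_ = 64 and two
// time-overlapping buffers of 48 bytes each, the first buffer is committed at
// offset 0 (chunk end 48 <= 64). The second buffer's candidate chunk would end
// at 96 > 64 while the current heap is non-empty, so it is deferred; the outer
// do-while then starts a fresh heap where it lands at offset 0. The result is
// two heaps of 48 bytes each, and a single over-limit buffer would still be
// placed (with a warning) because the constraint is soft.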

template <typename BufferType>
HeapSimulator::Result<BufferType>
ChooseBestHeapAlgorithm<BufferType>::Finish() {
  DCHECK(!algorithms_.empty());
  std::vector<Result> results(algorithms_.size());
  int64_t min_size = INT64_MAX;
  int min_size_index = -1;
  for (int i = 0; i < algorithms_.size(); ++i) {
    results[i] = algorithms_[i]->Finish();
    if (results[i].heap_size < min_size) {
      min_size = results[i].heap_size;
      min_size_index = i;
    }
  }

  DCHECK_GE(min_size_index, 0);
  return results[min_size_index];
}

template class GlobalDecreasingSizeBestFitHeap<HloValue>;
template class GlobalDecreasingSizeBestFitHeap<
    MemorySpaceAssignmentRepacker::AllocationBlock>;
template class ChooseBestHeapAlgorithm<HloValue>;

}  // namespace xla