/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/heap_simulator.h"

#include <algorithm>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

using absl::flat_hash_map;
using absl::flat_hash_set;

bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
  CHECK_NE(size, 0);
  CHECK_NE(other_chunk.size, 0);
  return offset < other_chunk.chunk_end() && other_chunk.offset < chunk_end();
}
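
// Illustrative example (not part of the original file): with half-open
// [offset, offset + size) intervals, two chunks overlap iff each starts
// before the other ends. Assuming chunk_end() == offset + size:
//
//   HeapSimulator::Chunk a{/*offset=*/0, /*size=*/16};   // occupies [0, 16)
//   HeapSimulator::Chunk b{/*offset=*/8, /*size=*/16};   // occupies [8, 24)
//   HeapSimulator::Chunk c{/*offset=*/16, /*size=*/8};   // occupies [16, 24)
//   a.OverlapsWith(b);  // true:  0 < 24 && 8 < 16
//   a.OverlapsWith(c);  // false: the intervals only touch at 16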

/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
    const HloSchedule& schedule,
    const LogicalBuffer::SizeFunction& size_function) {
  if (schedule.empty()) {
    return 0;
  }
  const HloModule* module = schedule.module();

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  // The absolute minimum memory required for a given sequence of instructions
  // is determined by the sequence of Alloc and Free calls on a simulated heap,
  // ignoring fragmentation. We run the heap simulation on the whole module,
  // rather than summing each computation, since it gives us a better lower
  // bound, by minimizing the liveness of sub-computations.
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), *module,
          schedule, *alias_analysis, size_function));
  return result.heap_size;
}
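
// Example usage (an illustrative sketch, not part of the original file). It
// assumes the caller already has a scheduled HloModule `module` and a
// backend-specific size function; the pointer size of 8 is an assumption.
//
//   const HloSchedule& schedule = module->schedule();
//   auto size_fn = [](const BufferValue& buffer) {
//     return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
//   };
//   TF_ASSIGN_OR_RETURN(
//       int64 min_bytes,
//       HeapSimulator::MinimumMemoryForModule(schedule, size_fn));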

/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
          sequence, alias_analysis, size_function, HeapSimulator::Options(),
          memory_by_computation));
  return result.heap_size;
}

StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const HloSchedule* schedule) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
          sequence, alias_analysis, size_function, schedule,
          HeapSimulator::Options()));
  return result.heap_size;
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm, const HloModule& module,
    const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
  const HloComputation* entry_computation = module.entry_computation();
  const HloInstructionSequence& instruction_sequence =
      schedule.sequence(entry_computation);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(schedule, alias_analysis, entry_computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
                                         instruction_sequence, alias_analysis,
                                         hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/nullptr, memory_by_computation);
  HloSchedule schedule(computation.parent());
  schedule.set_sequence(&computation, instruction_sequence);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, alias_analysis, &computation,
                                        /*module_scoped_analysis=*/false));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
    const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/schedule, nullptr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(*schedule, alias_analysis, &computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}
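
// Example usage (an illustrative sketch, not part of the original file): run
// the simulator with a packing algorithm instead of the stats-only heap,
// assuming `module`, `schedule`, `alias_analysis`, and `size_fn` already
// exist in the caller:
//
//   auto algorithm =
//       absl::make_unique<GlobalDecreasingSizeBestFitHeap<HloValue>>(
//           /*alignment=*/64,
//           GlobalDecreasingSizeBestFitHeap<HloValue>::kSpatial);
//   TF_ASSIGN_OR_RETURN(
//       HeapSimulator::Result<HloValue> result,
//       HeapSimulator::Run(std::move(algorithm), *module, schedule,
//                          *alias_analysis, size_fn));
//   // result.heap_size is the simulated peak; each HeapResult in
//   // result.heap_results maps buffers to Chunk{offset, size}.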

// Runs a heap simulation for the given 'computation', assuming the given
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis, HloLiveRange* hlo_live_range) {
  XLA_VLOG_LINES(1, computation.parent()->ToString());
  XLA_VLOG_LINES(2, computation.ToString());

  VLOG(1) << hlo_live_range->ToString();

  HloDataflowAnalysis& dataflow_analysis = alias_analysis.dataflow_analysis();

  // Record the buffer define/free events for each time step. We free all
  // remaining buffers (entry parameters, etc.) after the program has finished
  // running, so we size these vectors to program_end_time + 1.
  std::vector<std::vector<const HloValue*>> buffers_defined(
      hlo_live_range->schedule_end_time() + 1);
  std::vector<std::vector<const HloValue*>> buffers_freed(
      hlo_live_range->schedule_end_time() + 1);

  // values_to_assign tracks the HloValues that we need to assign a buffer to.
  // Note that we only need to assign a buffer to a value when both of the
  // following conditions are met:
  //
  // - The user specifically asks us to assign buffers to a set of HloValues,
  // and the value is in that set. If the user doesn't provide such a set, we
  // assign a buffer to every HloValue by default.
  //
  // - If the instruction is in a nested call of the current computation, we
  // only assign a buffer when doing global (whole-module) heap simulation.
  std::vector<const HloValue*> values_to_assign;
  values_to_assign.reserve(dataflow_analysis.values().size());

  auto& buffer_live_ranges = hlo_live_range->buffer_live_ranges();

  for (const HloValue* value : dataflow_analysis.values()) {
    // Ignore buffers that are not tracked.
    if (!buffer_live_ranges.contains(value)) {
      continue;
    }
    if (IgnoreBuffer(value)) {
      continue;
    }

    values_to_assign.push_back(value);
  }

  absl::c_sort(values_to_assign,
               [&](const HloValue* value1, const HloValue* value2) {
                 const auto& live_range1 = buffer_live_ranges.at(value1);
                 const auto& live_range2 = buffer_live_ranges.at(value2);
                 return std::forward_as_tuple(live_range1.start,
                                              live_range1.end, value1->id()) <
                        std::forward_as_tuple(live_range2.start,
                                              live_range2.end, value2->id());
               });

  // For each value that we need to assign a buffer to, add the define and free
  // events.
  for (const HloValue* value : values_to_assign) {
    auto live_range = buffer_live_ranges.at(value);
    buffers_defined[live_range.start].push_back(value);
    buffers_freed[live_range.end].push_back(value);
  }

  // All HloValues in an HloBuffer should be allocated to the same address. This
  // map tracks the first value that got allocated in a buffer.
  absl::flat_hash_map<const HloBuffer*, const HloValue*> first_allocated_value;

  VLOG(1) << "Program time: " << hlo_live_range->schedule_end_time();

  // Go through each step in the program and replay the buffer define and free
  // events.
  for (int64_t i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) {
    VLOG(1) << "Time step: " << i;

    for (const HloValue* value : buffers_defined[i]) {
      bool shared = false;
      VLOG(1) << "Start buffer: " << value->ToShortString();
      const HloBuffer* hlo_buffer =
          &alias_analysis.GetBufferContainingValue(*value);
      if (first_allocated_value.count(hlo_buffer) != 0) {
        // We've already assigned an address for another value in this HloBuffer
        // (an HloBuffer holds several aliased HloValues). All values in a
        // buffer should be assigned the same address. Find the one that's
        // already allocated and reuse its address.
        ShareBuffer(value, first_allocated_value[hlo_buffer],
                    value->instruction());
        VLOG(1) << "  ShareWith "
                << first_allocated_value[hlo_buffer]->ToShortString();
        continue;
      }
      if (options_.may_reuse_operand_buffers &&
          hlo_buffer->values().size() == 1) {
        // We don't support sharing an aliased buffer
        // (hlo_buffer->values().size() > 1) with its operand.
        for (const HloInstruction* operand : value->instruction()->operands()) {
          const HloValueSet operand_value_set =
              dataflow_analysis.GetValueSet(operand);
          for (const HloValue* operand_value : operand_value_set.values()) {
            const HloBuffer* operand_buffer =
                &alias_analysis.GetBufferContainingValue(*operand_value);
            if (operand_buffer->values().size() > 1) {
              continue;
            }
            auto it = buffer_live_ranges.find(operand_value);
            if (it == buffer_live_ranges.end()) {
              continue;
            }

            auto& operand_live_range = it->second;

            auto& user_live_range = buffer_live_ranges[value];

            // Can only share buffers that are about to be freed.
            if (operand_live_range.end != i) {
              continue;
            }

            if (IgnoreBuffer(operand_value)) {
              continue;
            }

            if (!absl::c_linear_search(buffers_freed[i], operand_value)) {
              // If the operand buffer is not being freed (either because it has
              // existing users, or because it has been reused by other
              // buffers), don't consider the operand as a candidate for buffer
              // sharing.
              continue;
            }

            // The instruction that defines the operand value can be different
            // from the actual operand; passing the defining instruction
            // directly into "CanShareOperandBufferWithUser" would trigger a
            // check failure. The first condition guards against that case.
            if (value->instruction()->IsUserOf(operand_value->instruction()) &&
                value->instruction()->opcode() != HloOpcode::kCopy &&
                dataflow_analysis.CanShareOperandBufferWithUser(
                    operand_value->instruction(), operand_value->index(),
                    value->instruction(), value->index())) {
              // Remove the operand buffer right before sharing (allocating) a
              // new one.
              Free(operand_value, operand_value->instruction());
              buffers_freed[i].erase(
                  std::remove(buffers_freed[i].begin(), buffers_freed[i].end(),
                              operand_value),
                  buffers_freed[i].end());
              ShareBuffer(value, operand_value, value->instruction());
              // The live range of the operand buffer is now extended to the end
              // of the current instruction.
              operand_live_range.end = user_live_range.end;
              VLOG(1) << "Sharing " << value->ToShortString() << " with "
                      << operand_value->ToShortString()
                      << ", size:" << size_fn_(*value);
              shared = true;
              break;
            }
          }
          if (shared) {
            break;
          }
        }
      }
      if (!shared) {
        Alloc(value, value->instruction());
        first_allocated_value[hlo_buffer] = value;
      }
    }

    if (!buffers_freed[i].empty()) {
      VLOG(1) << "Free Buffer: ";
    }
    for (const HloValue* value : buffers_freed[i]) {
      VLOG(1) << "  " << value->ToShortString();

      Free(value, value->instruction());
    }
  }
  return Status::OK();
}

HeapSimulator::HeapSimulator(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const HloSchedule* schedule,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation)
    : no_fragmentation_stats_(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>()),
      algorithm_(std::move(algorithm)),
      size_fn_(size_fn),
      options_(options),
      schedule_(schedule),
      memory_by_computation_(memory_by_computation) {
  debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}

HeapSimulator::~HeapSimulator() {}

bool HeapSimulator::IgnoreBuffer(const HloValue* buffer) const {
  // Buffers for constants are ignored unless the alloc_constants option is
  // set. Also ignore buffers that we're not meant to assign.
  //
  // TODO(b/32248867): For consistency, constants should get allocations.
  if (!options_.alloc_constants &&
      buffer->instruction()->opcode() == HloOpcode::kConstant) {
    return true;
  }
  return options_.buffers_to_assign != nullptr &&
         !options_.buffers_to_assign->contains(buffer);
}

// Alloc always calls the underlying heap algorithm.
void HeapSimulator::Alloc(const HloValue* buffer,
                          const HloInstruction* instruction) {
  CHECK(!allocated_buffers_.contains(buffer))
      << "Alloc called on allocated buffer: " << *buffer;
  CHECK(!freed_buffers_.contains(buffer))
      << "Alloc called on freed buffer: " << *buffer;

  allocated_buffers_.insert(buffer);
  const int64_t size = size_fn_(*buffer);
  algorithm_->Alloc(buffer, size);
  no_fragmentation_stats_->Alloc(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
                 nullptr);
}

// Free calls the underlying algorithm for non-shared buffers, and for shared
// buffers whose group liveness has expired.  Shared group liveness is tracked
// by maintaining a refcount; the Free call on the last buffer in the group
// causes Free to be called on the underlying algorithm.
void HeapSimulator::Free(const HloValue* buffer,
                         const HloInstruction* instruction) {
  const int64_t size = size_fn_(*buffer);
  algorithm_->Free(buffer, size);
  no_fragmentation_stats_->Free(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
}

// ShareBuffer associates buffers with their SharedGroup in shared_buffers_.
// The 'buffer' must be a non-allocated, non-freed buffer, just like in calls
// to Alloc.  The 'shared' buffer must be a previously allocated or shared
// buffer. Both 'buffer' and 'shared' will be associated with the same
// SharedGroup.
void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
                                const HloInstruction* instruction) {
  algorithm_->ShareWith(buffer, shared, size_fn_(*shared));
  no_fragmentation_stats_->ShareWith(buffer, shared, size_fn_(*shared));
  FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
                 shared);
}

HeapSimulator::Result<HloValue> HeapSimulator::Finish() {
  Result<HloValue> result = algorithm_->Finish();

  // Post-process the result to add chunks for shared buffers.  An empty chunk
  // map means that either no buffers were allocated, or the heap was only
  // collecting statistics, e.g. NoFragmentationStatsHeap.
  size_t total_chunk_count = absl::c_accumulate(
      result.heap_results, static_cast<size_t>(0),
      [&](size_t lhs, const HeapResult<HloValue>& rhs) -> size_t {
        return lhs + rhs.chunk_map.size();
      });
  if (total_chunk_count != 0) {
    // If we were told to assign specific buffers, make sure we've assigned
    // exactly that many buffers.
    if (options_.buffers_to_assign != nullptr) {
      CHECK_EQ(options_.buffers_to_assign->size(), total_chunk_count);
    }
  }

  // Fragmentation is the difference between the actual and ideal sizes.
  const Result<HloValue> no_frag_result = no_fragmentation_stats_->Finish();
  result.fragmentation_size = result.heap_size - no_frag_result.heap_size;

  // Copy the debug trace we collected to the final result.
  result.debug_trace.Swap(&debug_trace_);

  return result;
}

void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
                                   const HloValue* buffer,
                                   const HloInstruction* instruction,
                                   const HloValue* share_with_canonical) {
  HeapSimulatorTrace::Event* event = debug_trace_.add_events();
  event->set_kind(kind);
  event->set_buffer_id(buffer->id());
  event->set_computation_name(instruction->parent()->name());
  event->set_instruction_name(instruction->name());
  if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
    CHECK(share_with_canonical != nullptr);
    event->set_share_with_canonical_id(share_with_canonical->id());
  } else {
    CHECK(share_with_canonical == nullptr);
  }
}

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::Alloc(const BufferType* buffer,
                                                 int64_t size) {
  current_heap_size_ += size;
  if (current_heap_size_ > max_heap_size_) {
    max_heap_size_ = current_heap_size_;
  }
}

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::AccountForSubcomputationMemory(
    const HloInstruction* instruction, int64_t alloc_size_by_instruction,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  // We only count the memory usage of the largest subcomputation, instead of
  // adding them all, because subcomputations won't execute in parallel.
  int64_t max_subcomputation_bytes = 0;
  for (const auto* c : instruction->called_computations()) {
    auto it = memory_by_computation.find(c);
    if (it != memory_by_computation.end()) {
      int64_t subcomputation_bytes = it->second;
      if (subcomputation_bytes > max_subcomputation_bytes) {
        max_subcomputation_bytes = subcomputation_bytes;
      }
    }
  }
  if (max_subcomputation_bytes > 0 &&
      (instruction->opcode() == HloOpcode::kWhile ||
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kConditional)) {
    // The output buffer of while/call/conditional is always aliased with the
    // output buffer of the root instruction in the body. Don't double count.
    max_subcomputation_bytes -= alloc_size_by_instruction;
  }
  max_heap_size_ =
      std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
}

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::Free(const BufferType* buffer,
                                                int64_t size) {
  current_heap_size_ -= size;
}

template <typename BufferType>
HeapSimulator::Result<BufferType>
NoFragmentationStatsHeap<BufferType>::Finish() {
  // The result.chunk_map is empty, since we only collect stats, and don't
  // actually compute chunk assignments.
  Result result;
  result.heap_size = max_heap_size_;
  return result;
}
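
// Worked example (illustrative, not part of the original file; `a`, `b`, and
// `c` stand for already-defined HloValue pointers): the stats-only heap just
// tracks the running total of live bytes and its peak, giving a
// fragmentation-free lower bound.
//
//   NoFragmentationStatsHeap<HloValue> stats;
//   stats.Alloc(a, 16);  // current = 16, max = 16
//   stats.Alloc(b, 32);  // current = 48, max = 48
//   stats.Free(a, 16);   // current = 32, max = 48
//   stats.Alloc(c, 16);  // current = 48, max = 48
//   // stats.Finish().heap_size == 48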

template <typename BufferType>
GlobalDecreasingSizeBestFitHeap<BufferType>::GlobalDecreasingSizeBestFitHeap(
    int64_t alignment, Type type)
    : alignment_(alignment) {
  if (type == kTemporal) {
    buffer_interval_compare_ = GetTemporalBufferIntervalCompare();
  } else {
    CHECK(type == kSpatial);
    buffer_interval_compare_ = GetSpatialBufferIntervalCompare();
  }
}

template <typename BufferType>
typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap<BufferType>::GetTemporalBufferIntervalCompare()
    const {
  return [&](const BufferInterval& x, const BufferInterval& y) {
    int64_t x_end = x.end;
    for (auto colocation : GetTransitiveColocations(x)) {
      x_end = std::max(x_end, buffer_intervals_.at(colocation).end);
    }

    int64_t y_end = y.end;
    for (auto colocation : GetTransitiveColocations(y)) {
      y_end = std::max(y_end, buffer_intervals_.at(colocation).end);
    }

    if (x_end - x.start != y_end - y.start) {
      return x_end - x.start > y_end - y.start;
    }

    if (x.size != y.size) {
      return x.size > y.size;
    }
    return *x.buffer < *y.buffer;
  };
}

template <typename BufferType>
/*static*/ typename GlobalDecreasingSizeBestFitHeap<
    BufferType>::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap<BufferType>::GetSpatialBufferIntervalCompare() {
  return [&](const BufferInterval& x, const BufferInterval& y) {
    if (x.size != y.size) {
      return x.size > y.size;
    }
    if (x.end - x.start != y.end - y.start) {
      return x.end - x.start > y.end - y.start;
    }
    return *x.buffer < *y.buffer;
  };
}
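
// Example (illustrative, not part of the original file): given three buffer
// intervals
//
//   A: size 4, live over [0, 10)
//   B: size 8, live over [2, 5)
//   C: size 8, live over [1, 9)
//
// the temporal comparator orders by decreasing live-range length first,
// giving A, C, B, while the spatial comparator orders by decreasing size and
// breaks the B/C tie by live-range length, giving C, B, A.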

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::Alloc(
    const BufferType* buffer, int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }

  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, true});
  DCHECK(emplace_result.second);
  ++current_time_;
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::ShareWith(
    const BufferType* buffer, const BufferType* share_with, int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }
  DCHECK_NE(buffer_intervals_.count(share_with), 0);
  buffer_intervals_[share_with].colocations.push_back(buffer);
  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, false});
  DCHECK(emplace_result.second);
  ++current_time_;
}

template <typename BufferType>
absl::flat_hash_set<const BufferType*>
GlobalDecreasingSizeBestFitHeap<BufferType>::GetTransitiveColocations(
    const BufferInterval& interval) const {
  absl::flat_hash_set<const BufferType*> result;
  std::vector<const BufferInterval*> worklist = {&interval};
  while (!worklist.empty()) {
    const BufferInterval* item = worklist.back();
    worklist.pop_back();
    for (const BufferType* buffer_colocated : item->colocations) {
      result.insert(buffer_colocated);
      worklist.push_back(&buffer_intervals_.at(buffer_colocated));
    }
  }

  return result;
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::Free(const BufferType* buffer,
                                                       int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    return;
  }
  BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
  DCHECK_EQ(buffer_interval.buffer, buffer);
  DCHECK_EQ(buffer_interval.size, size);
  DCHECK_EQ(buffer_interval.end, -1);
  if (buffer_interval.end != -1) {
    return;
  }
  buffer_interval.end = current_time_;
  ++current_time_;
}

using Chunk = HeapSimulator::Chunk;

void BufferIntervalTree::Add(int64_t start, int64_t end, const Chunk& chunk) {
  node_storage_.emplace_back(BufferIntervalTreeNode{
      start, end, end, chunk,
      /*left=*/nullptr, /*right=*/nullptr, /*parent=*/nullptr});
  if (root_ == nullptr) {
    root_ = &node_storage_.back();
    // This is root.
    return;
  }

  BufferIntervalTreeNode* parent = root_;
  while (true) {
    parent->subtree_end = std::max(parent->subtree_end, end);
    if (parent->start > start) {
      if (parent->left == nullptr) {
        parent->left = &node_storage_.back();
        node_storage_.back().parent = parent;
        return;
      }
      parent = parent->left;
    } else {
      if (parent->right == nullptr) {
        parent->right = &node_storage_.back();
        node_storage_.back().parent = parent;
        return;
      }
      parent = parent->right;
    }
  }
}

bool BufferIntervalTree::Remove(int64_t start, int64_t end,
                                const Chunk& chunk) {
  BufferIntervalTreeNode* to_delete = root_;
  while (to_delete != nullptr) {
    if (to_delete->start == start && to_delete->end == end &&
        to_delete->chunk.offset == chunk.offset) {
      break;
    }
    if (start < to_delete->start) {
      to_delete = to_delete->left;
    } else {
      to_delete = to_delete->right;
    }
  }
  if (to_delete == nullptr) {
    // Nothing to delete.
    return false;
  }
  // Found the node to be deleted, enter the deletion sequence.

  // Recursively traverse the parents of a node and fix up the `subtree_end`
  // invariant of each node. A recursive lambda needs an explicit
  // std::function declaration.
  std::function<void(BufferIntervalTreeNode*)> fix_up =
      [&](BufferIntervalTreeNode* node) {
        if (node == nullptr) {
          return;
        }
        node->subtree_end = node->end;
        if (node->left) {
          node->subtree_end =
              std::max(node->subtree_end, node->left->subtree_end);
        }
        if (node->right) {
          node->subtree_end =
              std::max(node->subtree_end, node->right->subtree_end);
        }
        // Recursively go up.
        fix_up(node->parent);
      };

  if (to_delete->right == nullptr) {
    // to_delete has no right child; simply move up the left child of
    // to_delete, if any.
    //
    // Turn:
    //      parent
    //       /
    // to_delete
    //  /      \
    // left    nullptr
    //
    // Into:
    //      parent
    //      /
    //    left
    if (root_ == to_delete) {
      // Deleting the root simply resets the root.
      root_ = to_delete->left;
      return true;
    }

    if (to_delete == to_delete->parent->left) {
      // to_delete is the left child of its parent.
      to_delete->parent->left = to_delete->left;
    }
    if (to_delete == to_delete->parent->right) {
      // to_delete is the right child of its parent.
      to_delete->parent->right = to_delete->left;
    }
    // Rewire the parent to the node being moved up.
    if (to_delete->left) {
      to_delete->left->parent = to_delete->parent;
    }
    // Fix up starting from the subroot.
    fix_up(to_delete);
  } else {
    // 1. Find the left-most node of the right subtree and promote it to the
    // position of to_delete.
    BufferIntervalTreeNode* to_promote = to_delete->right;
    while (to_promote->left != nullptr) {
      // Go to the left-most subtree.
      to_promote = to_promote->left;
    }

    // 2. Copy the content of `to_promote` to `to_delete`.
    to_delete->start = to_promote->start;
    to_delete->end = to_promote->end;
    // This is incorrect but we will fix it up later in the `fix_up`
    // procedure.
    to_delete->subtree_end = to_promote->subtree_end;
    to_delete->chunk = to_promote->chunk;
    auto to_promote_parent = to_promote->parent;
    // 3. Move the right child of `to_promote` up, if there is any.
    //
    // Turn
    //
    // to_delete
    //         \
    //        to_promote_parent
    //         /
    //    to_promote
    //          \
    //          right
    // into
    //
    // to_promote
    //         \
    //         to_promote_parent
    //         /
    //      right
    if (to_promote_parent->left == to_promote) {
      to_promote_parent->left = to_promote->right;
    } else {
      to_promote_parent->right = to_promote->right;
    }
    if (to_promote->right) {
      // Set the correct parent.
      to_promote->right->parent = to_promote_parent;
    }
    // 4. Recursively fix up the `subtree_end` starting from
    // `to_promote_parent`.
    fix_up(to_promote_parent);
  }
  // Don't free the entry in node_storage_ until we free the entire tree.
  return true;
}

std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
    int64_t start, int64_t end) const {
  std::vector<Chunk> result;
  if (root_ == nullptr) {
    return result;
  }
  std::vector<const BufferIntervalTreeNode*> visiting_stack;
  visiting_stack.push_back(root_);
  while (!visiting_stack.empty()) {
    const BufferIntervalTreeNode* top = visiting_stack.back();
    visiting_stack.pop_back();
    if (start > top->subtree_end) {
      continue;
    }
    if (top->left != nullptr) {
      visiting_stack.push_back(top->left);
    }
    if (top->start <= end && top->end >= start) {
      result.push_back(top->chunk);
    }
    if (end < top->start) {
      continue;
    }
    if (top->right != nullptr) {
      visiting_stack.push_back(top->right);
    }
  }
  return result;
}
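
// Example usage (illustrative, not part of the original file): the interval
// tree answers "which committed chunks are live somewhere in [start, end]?",
// which is exactly what FindChunkCandidate needs below.
//
//   BufferIntervalTree tree;
//   tree.Add(/*start=*/0, /*end=*/5, Chunk{/*offset=*/0, /*size=*/16});
//   tree.Add(/*start=*/3, /*end=*/9, Chunk{/*offset=*/16, /*size=*/8});
//   std::vector<Chunk> live = tree.ChunksOverlappingInTime(4, 10);
//   // `live` contains both chunks; [4, 10] intersects [0, 5] and [3, 9].
//   tree.Remove(0, 5, Chunk{0, 16});  // returns true on success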

template <typename BufferType>
HeapSimulator::Result<BufferType>
GlobalDecreasingSizeBestFitHeap<BufferType>::Finish() {
  std::vector<BufferInterval> sorted_buffer_intervals =
      GetSortedBufferIntervals();

  for (auto& buffer_interval : sorted_buffer_intervals) {
    if (!buffer_interval.need_allocation) {
      continue;
    }

    ChunkCandidate chunk_candidate = FindChunkCandidate(buffer_interval);
    // This implementation of the heap algorithm does not have a notion of
    // maximum heap size, so it just commits.
    CommitChunk(buffer_interval, chunk_candidate);
  }
  VLOG(1) << "result heap_size: " << result_.heap_size;
  Result result;
  result.heap_size = result_.heap_size;
  result.heap_results.emplace_back(result_);
  return result;
}

template <typename BufferType>
std::vector<
    typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval>
GlobalDecreasingSizeBestFitHeap<BufferType>::GetSortedBufferIntervals() const {
  std::vector<BufferInterval> sorted_buffer_intervals;
  for (auto& entry : buffer_intervals_) {
    sorted_buffer_intervals.push_back(entry.second);
  }
  absl::c_sort(sorted_buffer_intervals, buffer_interval_compare_);

  return sorted_buffer_intervals;
}

template <typename BufferType>
typename GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
GlobalDecreasingSizeBestFitHeap<BufferType>::FindChunkCandidate(
    const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
    int64_t preferred_offset) const {
  VLOG(1) << "Finding chunks for buffer: "
          << buffer_interval.buffer->ToString();
  VLOG(1) << "Size " << buffer_interval.size << ", start "
          << buffer_interval.start << ", end " << buffer_interval.end;
  auto chunks_overlapping_in_time = interval_tree_.ChunksOverlappingInTime(
      buffer_interval.start, buffer_interval.end);
  // Get all colocated buffers and gather all interfering chunks.
  //
  // Imagine that we've already allocated three chunks: a, b, and c. Now we
  // want to allocate d. Since e is colocated with d, we have to allocate
  // chunks for them together at the same address. To do this, we first gather
  // all chunks that overlap with d and e on the time dimension, in this case
  // the overlapping chunks are a and b (c doesn't overlap with either d or
  // e), then create a new chunk that doesn't overlap with a and b on the
  // space dimension.
  //
  // space
  //   ^
  //   |+--d---+      +---e---+
  //   |
  //   |+---+  +---------------+  +-------+
  //   ||   |  |               |  |       |
  //   ||   |  |               |  |       |
  //   |+-a-+  +-------b-------+  +---c---+
  //   ----------------------------------------> time
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    auto colocation_interval = buffer_intervals_.at(colocation);
    auto colocation_overlapping = interval_tree_.ChunksOverlappingInTime(
        colocation_interval.start, colocation_interval.end);
    VLOG(1) << "  Alias size " << colocation_interval.size << ", start "
            << colocation_interval.start << ", end " << colocation_interval.end
            << " " << colocation_interval.buffer->ToString();
    chunks_overlapping_in_time.insert(chunks_overlapping_in_time.end(),
                                      colocation_overlapping.begin(),
                                      colocation_overlapping.end());
  }
  absl::c_sort(chunks_overlapping_in_time, [](const Chunk& x, const Chunk& y) {
    return x.offset < y.offset;
  });

  // Find the minimum free chunk that can hold this buffer.
  ChunkCandidate chunk_candidate{Chunk{-1, INT64_MAX}, result_.heap_size};
  Chunk& min_fit_chunk = chunk_candidate.chunk;
  int64_t preferred_chunk_end = preferred_offset + buffer_interval.size;
  auto use_free_chunk_if_smaller = [&](int64_t free_offset, int64_t free_size) {
    if (free_size < buffer_interval.size) {
      return;
    }

    // If a preferred offset is provided, pick that offset.
    if (free_offset <= preferred_offset &&
        free_offset + free_size >= preferred_chunk_end) {
      min_fit_chunk = {preferred_offset, buffer_interval.size};
    } else if (free_offset + free_size == result_.heap_size &&
               free_offset <= preferred_offset) {
      // If the free chunk is at the very end and the preferred offset lies
      // within it, pick the preferred offset and grow the heap.
      min_fit_chunk = {preferred_offset, buffer_interval.size};
      chunk_candidate.heap_size = preferred_chunk_end;
    }

    // Pick the min-fit chunk only if we didn't have a preferred offset or a
    // chunk at the preferred offset hasn't been found.
    if ((preferred_offset < 0 || min_fit_chunk.offset != preferred_offset) &&
        free_size < min_fit_chunk.size) {
      min_fit_chunk = {free_offset, free_size};
    }
  };

  int64_t offset = 0;
  for (auto& chunk : chunks_overlapping_in_time) {
    if (offset < chunk.offset) {
      use_free_chunk_if_smaller(offset, chunk.offset - offset);
    }
    offset = std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_));
  }
  use_free_chunk_if_smaller(offset, result_.heap_size - offset);
  // When a preferred offset is provided and it is larger than the current
  // heap size, simply use the preferred offset.
  if (result_.heap_size <= preferred_offset) {
    chunk_candidate.heap_size = preferred_chunk_end;
    min_fit_chunk = {preferred_offset, buffer_interval.size};
  }

  if (min_fit_chunk.offset == -1) {
    // Increase the heap size to fit in the last free chunk.
    chunk_candidate.heap_size = offset + buffer_interval.size;
    min_fit_chunk = {offset, buffer_interval.size};
  }

  min_fit_chunk.size = buffer_interval.size;
  return chunk_candidate;
}
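
// Worked example (illustrative, not part of the original file): suppose
// alignment_ == 8, result_.heap_size == 48, and the chunks overlapping the
// candidate's live range are Chunk{0, 16} and Chunk{32, 16}. The scan above
// sees the free gap [16, 32) and an empty gap at the heap end. For a 12-byte
// buffer with no preferred offset, the 16-byte gap at offset 16 is the
// smallest fit, so the result is Chunk{/*offset=*/16, /*size=*/12} with
// heap_size staying at 48. A 24-byte buffer fits no existing gap, so the heap
// grows: Chunk{/*offset=*/48, /*size=*/24} with heap_size 72.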

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::CommitChunk(
    const GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval&
        buffer_interval,
    GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
        chunk_candidate) {
  // Update the maximum heap size according to the one determined by the chunk
  // candidate.
  result_.heap_size = chunk_candidate.heap_size;
  interval_tree_.Add(buffer_interval.start, buffer_interval.end,
                     chunk_candidate.chunk);
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    AddToChunkMap(colocation, chunk_candidate.chunk);
    auto colocation_interval = buffer_intervals_[colocation];
    interval_tree_.Add(colocation_interval.start, colocation_interval.end,
                       chunk_candidate.chunk);
  }

  AddToChunkMap(buffer_interval.buffer, chunk_candidate.chunk);
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::AddToChunkMap(
    const BufferType* buffer, Chunk chunk) {
  const auto emplace_result = result_.chunk_map.emplace(buffer, chunk);
  DCHECK(emplace_result.second);
}

HeapSimulator::Result<HloValue>
ConstrainedGlobalDecreasingSizeBestFitHeap::Finish() {
  std::vector<BufferInterval> sorted_buffer_vec = GetSortedBufferIntervals();
  // Convert into std::list so that erase() is O(1).
  std::list<BufferInterval> sorted_buffer_intervals(sorted_buffer_vec.begin(),
                                                    sorted_buffer_vec.end());

  // Use do-while here, because we need to create 1 heap in `multi_heap_result`
  // even if `sorted_buffer_intervals` is empty.
  Result multi_heap_result;
  do {
    // Place as many buffers as possible into the currently processed heap.
    for (auto it = sorted_buffer_intervals.begin();
         it != sorted_buffer_intervals.end();) {
      BufferInterval buffer_interval = *it;
      if (!buffer_interval.need_allocation) {
        it = sorted_buffer_intervals.erase(it);
        continue;
      }
      if (buffer_interval.size > size_limit_per_heap_) {
        LOG(WARNING) << "Alloc buffer size " << buffer_interval.size
                     << " larger than the per-heap size limit "
                     << size_limit_per_heap_;
      }

      ChunkCandidate chunk_candidate = FindChunkCandidate(buffer_interval);
      if (chunk_candidate.heap_size <= size_limit_per_heap_ ||
          // Commit the chunk as long as the heap is empty. We do this because
          // we want the size constraint to be soft, meaning that results are
          // successfully generated even if there are some buffer sizes larger
          // than the given constraint size.
          result_.heap_size == 0) {
        CommitChunk(buffer_interval, chunk_candidate);
        it = sorted_buffer_intervals.erase(it);
        continue;
      }

      ++it;
    }
    // Collect the result from the currently processed heap and reset the heap
    // states.
    multi_heap_result.heap_size += result_.heap_size;
    multi_heap_result.heap_results.push_back(std::move(result_));
    result_ = {};
    interval_tree_ = {};
  } while (!sorted_buffer_intervals.empty());

  VLOG(1) << "Number of heaps produced = "
          << multi_heap_result.heap_results.size();
  return multi_heap_result;
}
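
// Worked example (illustrative, not part of the original file): with
// size_limit_per_heap_ == 64 and three simultaneously live buffers of sizes
// 80, 48, and 32 (in sorted order), the first pass commits only the 80-byte
// buffer: it exceeds the limit, but the limit is soft while the current heap
// is still empty, and adding either remaining buffer would push past 64. The
// second pass commits the 48-byte buffer and the third pass the 32-byte one,
// producing three heaps with multi_heap_result.heap_size == 80 + 48 + 32.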

template <typename BufferType>
HeapSimulator::Result<BufferType>
ChooseBestHeapAlgorithm<BufferType>::Finish() {
  DCHECK(!algorithms_.empty());
  std::vector<Result> results(algorithms_.size());
  int64_t min_size = INT64_MAX;
  int min_size_index = -1;
  for (int i = 0; i < algorithms_.size(); ++i) {
    results[i] = algorithms_[i]->Finish();
    if (results[i].heap_size < min_size) {
      min_size = results[i].heap_size;
      min_size_index = i;
    }
  }

  DCHECK_GE(min_size_index, 0);
  return results[min_size_index];
}

template class GlobalDecreasingSizeBestFitHeap<HloValue>;
template class GlobalDecreasingSizeBestFitHeap<
    MemorySpaceAssignmentRepacker::AllocationBlock>;
template class ChooseBestHeapAlgorithm<HloValue>;

}  // namespace xla