/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/heap_simulator.h"

#include <algorithm>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

using absl::flat_hash_map;
using absl::flat_hash_set;

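// Returns true if this chunk overlaps 'other_chunk' in offset space. Both
// chunks must be non-empty; overlap is ill-defined for zero-sized chunks.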
bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
  CHECK_NE(size, 0);
  CHECK_NE(other_chunk.size, 0);
  return offset < other_chunk.chunk_end() && other_chunk.offset < chunk_end();
}

/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
    const HloSchedule& schedule,
    const LogicalBuffer::SizeFunction& size_function) {
  if (schedule.empty()) {
    return 0;
  }
  const HloModule* module = schedule.module();

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  // The absolute minimum memory required for a given sequence of instructions
  // is determined by the sequence of Alloc and Free calls on a simulated heap,
  // ignoring fragmentation. We run the heap simulation on the whole module,
  // rather than summing each computation, since it gives us a better lower
  // bound, by minimizing the liveness of sub-computations.
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result result,
      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
                         schedule, *alias_analysis, size_function));
  return result.heap_size;
}

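// Computes the minimum memory for a single computation by running a
// NoFragmentationStatsHeap simulation over its instruction sequence.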
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result result,
      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
                         computation, sequence, alias_analysis, size_function,
                         HeapSimulator::Options(), memory_by_computation));
  return result.heap_size;
}

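// As above, but takes a module-level schedule instead of a per-computation
// memory map; the schedule is forwarded to the heap simulation.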
StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const HloSchedule* schedule) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result result,
      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
                         computation, sequence, alias_analysis, size_function,
                         schedule, HeapSimulator::Options()));
  return result.heap_size;
}

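// Runs the heap simulation over the entry computation of 'module', using the
// module-wide schedule and alias analysis.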
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
    const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
  const HloComputation* entry_computation = module.entry_computation();
  const HloInstructionSequence& instruction_sequence =
      schedule.sequence(entry_computation);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(schedule, alias_analysis, entry_computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
                                         instruction_sequence, alias_analysis,
                                         hlo_live_range.get()));
  return heap.Finish();
}

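// Runs the heap simulation over a single computation. A temporary HloSchedule
// is built around 'instruction_sequence', and live ranges are computed without
// module-scoped analysis.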
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/nullptr, memory_by_computation);
  HloSchedule schedule(computation.parent());
  schedule.set_sequence(&computation, instruction_sequence);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, alias_analysis, &computation,
                                        /*module_scoped_analysis=*/false));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}

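// Runs the heap simulation over a single computation, using the provided
// module-level 'schedule' to compute live ranges.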
/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
    const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/schedule, nullptr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(*schedule, alias_analysis, &computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}

// Runs a heap simulation for the given 'computation', assuming the given
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis, HloLiveRange* hlo_live_range) {
  XLA_VLOG_LINES(1, computation.parent()->ToString());
  XLA_VLOG_LINES(2, computation.ToString());

  VLOG(1) << hlo_live_range->ToString();

  HloDataflowAnalysis& dataflow_analysis = alias_analysis.dataflow_analysis();

  // Record the buffer define/free events for each time step. We free all
  // remaining buffers (entry parameters, etc.) after the program has finished
  // running, so we size the vectors to program_end_time + 1.
  std::vector<std::vector<const HloValue*>> buffers_defined(
      hlo_live_range->schedule_end_time() + 1);
  std::vector<std::vector<const HloValue*>> buffers_freed(
      hlo_live_range->schedule_end_time() + 1);

  // values_to_assign tracks the HloValues that we need to assign a buffer to.
  // Note that we only need to assign a buffer to a value when both of the
  // following conditions are met:
  //
  // - The user specifically asks us to assign a buffer to a set of HloValues,
  // and the value is in the set. If the user doesn't provide such a set, we
  // assign a buffer to every HloValue by default.
  //
  // - If the instruction is in a nested call of the current computation, we
  // only assign a buffer when doing global heap simulation.
  std::vector<const HloValue*> values_to_assign;
  values_to_assign.reserve(dataflow_analysis.values().size());

  for (const HloValue* value : dataflow_analysis.values()) {
    // Ignore buffers that are not tracked.
    if (hlo_live_range->instruction_schedule().count(
            value->defining_instruction()) == 0) {
      continue;
    }
    if (IgnoreBuffer(value)) {
      continue;
    }
    values_to_assign.push_back(value);
  }

  auto& buffer_live_ranges = hlo_live_range->buffer_live_ranges();

  absl::c_sort(values_to_assign,
               [&](const HloValue* value1, const HloValue* value2) {
                 const auto& live_range1 = buffer_live_ranges.at(value1);
                 const auto& live_range2 = buffer_live_ranges.at(value2);
                 return std::forward_as_tuple(live_range1.start,
                                              live_range1.end, value1->id()) <
                        std::forward_as_tuple(live_range2.start,
                                              live_range2.end, value2->id());
               });

  // For each value that we need to assign a buffer to, add the define and free
  // events.
  for (const HloValue* value : values_to_assign) {
    auto live_range = buffer_live_ranges.at(value);
    buffers_defined[live_range.start].push_back(value);
    buffers_freed[live_range.end].push_back(value);
  }

  // All HloValues in an HloBuffer should be allocated at the same address.
  // This map tracks the first value that got allocated in each buffer.
  absl::flat_hash_map<const HloBuffer*, const HloValue*> first_allocated_value;

  VLOG(1) << "Program time: " << hlo_live_range->schedule_end_time();

  // Go through each step in the program and replay the buffer define and free
  // events.
  for (int64 i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) {
    VLOG(1) << "Time step: " << i;

    for (const HloValue* value : buffers_defined[i]) {
      bool shared = false;
      VLOG(1) << "Start buffer: " << value->ToShortString();
      const HloBuffer* hlo_buffer =
          &alias_analysis.GetBufferContainingValue(*value);
      if (first_allocated_value.count(hlo_buffer) != 0) {
        // We've already assigned an address for another value in this HloBuffer
        // (HloBuffer holds several aliased HloValues). All values in a buffer
        // should be assigned the same address. Find the one that's already
        // allocated and reuse its address.
        ShareBuffer(value, first_allocated_value[hlo_buffer],
                    value->instruction());
        VLOG(1) << "  ShareWith "
                << first_allocated_value[hlo_buffer]->ToShortString();
        continue;
      }
      if (options_.may_reuse_operand_buffers &&
          hlo_buffer->values().size() == 1) {
        // We don't support sharing an aliased buffer
        // (hlo_buffer->values().size() > 1) with its operand.
        for (const HloInstruction* operand : value->instruction()->operands()) {
          const HloValueSet operand_value_set =
              dataflow_analysis.GetValueSet(operand);
          for (const HloValue* operand_value : operand_value_set.values()) {
            const HloBuffer* operand_buffer =
                &alias_analysis.GetBufferContainingValue(*operand_value);
            if (operand_buffer->values().size() > 1) {
              continue;
            }
            auto it = buffer_live_ranges.find(operand_value);
            if (it == buffer_live_ranges.end()) {
              continue;
            }

            auto& operand_live_range = it->second;

            auto& user_live_range = buffer_live_ranges[value];

            // Can only share buffers that are about to be freed.
            if (operand_live_range.end != i) {
              continue;
            }

            if (IgnoreBuffer(operand_value)) {
              continue;
            }

            if (!absl::c_linear_search(buffers_freed[i], operand_value)) {
              // If the operand buffer is not being freed (either because it has
              // existing users, or it has been reused by other buffers), don't
              // consider the operand as a candidate for buffer sharing.
              continue;
            }

            // The instruction that defines the operand value can be different
            // from the actual operand; passing the defining instruction
            // directly into "CanShareOperandBufferWithUser" would trigger a
            // check failure. The first condition guards against that case.
            if (value->instruction()->IsUserOf(operand_value->instruction()) &&
                value->instruction()->opcode() != HloOpcode::kCopy &&
                dataflow_analysis.CanShareOperandBufferWithUser(
                    operand_value->instruction(), operand_value->index(),
                    value->instruction(), value->index())) {
              // Remove the operand buffer right before sharing (allocating) a
              // new one.
              Free(operand_value, operand_value->instruction());
              buffers_freed[i].erase(
                  std::remove(buffers_freed[i].begin(), buffers_freed[i].end(),
                              operand_value),
                  buffers_freed[i].end());
              ShareBuffer(value, operand_value, value->instruction());
              // The live range of the operand buffer is now extended to the end
              // of the current instruction.
              operand_live_range.end = user_live_range.end;
              VLOG(1) << "Sharing " << value->ToShortString() << " with "
                      << operand_value->ToShortString()
                      << ", size:" << size_fn_(*value);
              shared = true;
              break;
            }
          }
          if (shared) {
            break;
          }
        }
      }
      if (!shared) {
        Alloc(value, value->instruction());
        first_allocated_value[hlo_buffer] = value;
      }
    }

    if (!buffers_freed[i].empty()) {
      VLOG(1) << "Free Buffer: ";
    }
    for (const HloValue* value : buffers_freed[i]) {
      VLOG(1) << "  " << value->ToShortString();

      Free(value, value->instruction());
    }
  }
  return Status::OK();
}

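// The constructor wires the wrapped algorithm together with a
// NoFragmentationStatsHeap, which is always run alongside it so that Finish()
// can report the fragmentation-free lower bound.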
HeapSimulator::HeapSimulator(
    std::unique_ptr<HeapAlgorithm> algorithm,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const HloSchedule* schedule,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation)
    : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
      algorithm_(std::move(algorithm)),
      size_fn_(size_fn),
      options_(options),
      schedule_(schedule),
      memory_by_computation_(memory_by_computation) {
  debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}

HeapSimulator::~HeapSimulator() {}

bool HeapSimulator::IgnoreBuffer(const HloValue* buffer) const {
  // Buffers for constants are ignored unless the alloc_constants option is
  // set. Also ignore buffers that we're not meant to assign.
  //
  // TODO(b/32248867): For consistency, constants should get allocations.
  if (!options_.alloc_constants &&
      buffer->instruction()->opcode() == HloOpcode::kConstant) {
    return true;
  }
  return options_.buffers_to_assign != nullptr &&
         !options_.buffers_to_assign->contains(buffer);
}

// Alloc always calls the underlying heap algorithm.
void HeapSimulator::Alloc(const HloValue* buffer,
                          const HloInstruction* instruction) {
  CHECK(!allocated_buffers_.contains(buffer))
      << "Alloc called on allocated buffer: " << *buffer;
  CHECK(!freed_buffers_.contains(buffer))
      << "Alloc called on freed buffer: " << *buffer;

  allocated_buffers_.insert(buffer);
  const int64 size = size_fn_(*buffer);
  algorithm_->Alloc(buffer, size);
  no_fragmentation_stats_->Alloc(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
                 nullptr);
}

// Free always forwards to the underlying algorithm; any tracking of shared
// buffers is handled by the algorithm itself (see ShareBuffer below).
void HeapSimulator::Free(const HloValue* buffer,
                         const HloInstruction* instruction) {
  const int64 size = size_fn_(*buffer);
  algorithm_->Free(buffer, size);
  no_fragmentation_stats_->Free(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
}

// ShareBuffer tells the underlying algorithm that 'buffer' should be placed at
// the same address as 'shared'. The 'buffer' must be a non-allocated,
// non-freed buffer, just like in calls to Alloc. The 'shared' buffer must be a
// previously allocated or shared buffer.
void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
                                const HloInstruction* instruction) {
  algorithm_->ShareWith(buffer, shared, size_fn_(*shared));
  no_fragmentation_stats_->ShareWith(buffer, shared, size_fn_(*shared));
  FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
                 shared);
}

HeapSimulator::Result HeapSimulator::Finish() {
  Result result = algorithm_->Finish();

  // Post-process the result to add chunks for shared buffers.  An empty chunk
  // map means that either no buffers were allocated, or the heap was only
  // collecting statistics, e.g. NoFragmentationStatsHeap.
  if (!result.chunk_map.empty()) {
    // If we were told to assign specific buffers, make sure we've assigned
    // exactly that many buffers.
    if (options_.buffers_to_assign != nullptr) {
      CHECK_EQ(options_.buffers_to_assign->size(), result.chunk_map.size());
    }
  }

  // Fragmentation is the difference between the actual and ideal sizes.
  const Result no_frag_result = no_fragmentation_stats_->Finish();
  result.fragmentation_size = result.heap_size - no_frag_result.heap_size;

  // Copy the debug trace we collected to the final result.
  result.debug_trace.Swap(&debug_trace_);

  return result;
}

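// Records a single ALLOC, FREE or SHARE_WITH event in the debug trace that is
// returned as part of the simulation result.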
void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
                                   const HloValue* buffer,
                                   const HloInstruction* instruction,
                                   const HloValue* share_with_canonical) {
  HeapSimulatorTrace::Event* event = debug_trace_.add_events();
  event->set_kind(kind);
  event->set_buffer_id(buffer->id());
  event->set_computation_name(instruction->parent()->name());
  event->set_instruction_name(instruction->name());
  if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
    CHECK(share_with_canonical != nullptr);
    event->set_share_with_canonical_id(share_with_canonical->id());
  } else {
    CHECK(share_with_canonical == nullptr);
  }
}

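// NoFragmentationStatsHeap only tracks the running heap size and remembers the
// peak; it never assigns offsets.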
void NoFragmentationStatsHeap::Alloc(const HloValue* buffer, int64 size) {
  current_heap_size_ += size;
  if (current_heap_size_ > max_heap_size_) {
    max_heap_size_ = current_heap_size_;
  }
}

void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
    const HloInstruction* instruction, int64 alloc_size_by_instruction,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  // We only count the memory usage of the largest subcomputation, instead of
  // adding them all, because subcomputations won't execute in parallel.
  int64 max_subcomputation_bytes = 0;
  for (const auto* c : instruction->called_computations()) {
    auto it = memory_by_computation.find(c);
    if (it != memory_by_computation.end()) {
      int64 subcomputation_bytes = it->second;
      if (subcomputation_bytes > max_subcomputation_bytes) {
        max_subcomputation_bytes = subcomputation_bytes;
      }
    }
  }
  if (max_subcomputation_bytes > 0 &&
      (instruction->opcode() == HloOpcode::kWhile ||
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kConditional)) {
    // The output buffer of while/call/conditional is always aliased with the
    // output buffer of the root instruction in the body. Don't double count.
    max_subcomputation_bytes -= alloc_size_by_instruction;
  }
  max_heap_size_ =
      std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
}

void NoFragmentationStatsHeap::Free(const HloValue* buffer, int64 size) {
  current_heap_size_ -= size;
}

HeapSimulator::Result NoFragmentationStatsHeap::Finish() {
  // The result.chunk_map is empty, since we only collect stats, and don't
  // actually compute chunk assignments.
  Result result;
  result.heap_size = max_heap_size_;
  return result;
}

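// The heap type selects the order in which buffers are considered: kTemporal
// considers long-lived buffers first, kSpatial considers large buffers first.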
GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap(
    int64 alignment, Type type)
    : alignment_(alignment) {
  if (type == kTemporal) {
    buffer_interval_compare_ = GetTemporalBufferIntervalCompare();
  } else {
    CHECK(type == kSpatial);
    buffer_interval_compare_ = GetSpatialBufferIntervalCompare();
  }
}

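// Orders buffers by decreasing live-range length (extended to cover transitive
// colocations), then by decreasing size, with the buffer id as a deterministic
// tie-breaker.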
GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() const {
  return [&](const BufferInterval& x, const BufferInterval& y) {
    int64 x_end = x.end;
    for (auto colocation : GetTransitiveColocations(x)) {
      x_end = std::max(x_end, buffer_intervals_.at(colocation).end);
    }

    int64 y_end = y.end;
    for (auto colocation : GetTransitiveColocations(y)) {
      y_end = std::max(y_end, buffer_intervals_.at(colocation).end);
    }

    if (x_end - x.start != y_end - y.start) {
      return x_end - x.start > y_end - y.start;
    }

    if (x.size != y.size) {
      return x.size > y.size;
    }
    return x.buffer->id() < y.buffer->id();
  };
}

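// Orders buffers by decreasing size, then by decreasing live-range length,
// with the buffer id as a deterministic tie-breaker.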
/*static*/ GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() {
  return [&](const BufferInterval& x, const BufferInterval& y) {
    if (x.size != y.size) {
      return x.size > y.size;
    }
    if (x.end - x.start != y.end - y.start) {
      return x.end - x.start > y.end - y.start;
    }
    return x.buffer->id() < y.buffer->id();
  };
}

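// Alloc only records the start of the buffer's live interval; actual offsets
// are chosen later in Finish(), once all intervals are known.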
void GlobalDecreasingSizeBestFitHeap::Alloc(const HloValue* buffer,
                                            int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }

  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, true});
  DCHECK(emplace_result.second);
  ++current_time_;
}

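// ShareWith records 'buffer' as a colocation of 'share_with'; the buffer gets
// its own interval but is marked as not needing a separate allocation.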
void GlobalDecreasingSizeBestFitHeap::ShareWith(const HloValue* buffer,
                                                const HloValue* share_with,
                                                int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }
  DCHECK_NE(buffer_intervals_.count(share_with), 0);
  buffer_intervals_[share_with].colocations.push_back(buffer);
  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, false});
  DCHECK(emplace_result.second);
  ++current_time_;
}

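// Returns every buffer reachable from 'interval' by following colocation
// edges.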
absl::flat_hash_set<const HloValue*>
GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations(
    const BufferInterval& interval) const {
  absl::flat_hash_set<const HloValue*> result;
  std::vector<const BufferInterval*> worklist = {&interval};
  while (!worklist.empty()) {
    const BufferInterval* item = worklist.back();
    worklist.pop_back();
    for (const HloValue* buffer_colocated : item->colocations) {
      result.insert(buffer_colocated);
      worklist.push_back(&buffer_intervals_.at(buffer_colocated));
    }
  }

  return result;
}

void GlobalDecreasingSizeBestFitHeap::Free(const HloValue* buffer, int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    return;
  }
  BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
  DCHECK_EQ(buffer_interval.buffer, buffer);
  DCHECK_EQ(buffer_interval.size, size);
  DCHECK_EQ(buffer_interval.end, -1);
  if (buffer_interval.end != -1) {
    return;
  }
  buffer_interval.end = current_time_;
  ++current_time_;
}

using Chunk = HeapSimulator::Chunk;

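// Inserts the interval [start, end] with its chunk into the (unbalanced)
// binary interval tree, updating subtree_end along the insertion path.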
void BufferIntervalTree::Add(int64 start, int64 end, const Chunk& chunk) {
  node_storage_.emplace_back(
      BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr});

  if (node_storage_.size() == 1) {
    // This is root.
    return;
  }

  BufferIntervalTreeNode* parent = &node_storage_.front();
  while (true) {
    parent->subtree_end = std::max(parent->subtree_end, end);
    if (parent->start > start) {
      if (parent->left == nullptr) {
        parent->left = &node_storage_.back();
        return;
      }
      parent = parent->left;
    } else {
      if (parent->right == nullptr) {
        parent->right = &node_storage_.back();
        return;
      }
      parent = parent->right;
    }
  }
}

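// Returns the chunks of all intervals in the tree that overlap the time range
// [start, end].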
std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
    int64 start, int64 end) const {
  std::vector<Chunk> result;
  if (node_storage_.empty()) {
    return result;
  }
  std::vector<const BufferIntervalTreeNode*> visiting_stack;
  visiting_stack.push_back(&node_storage_.front());
  while (!visiting_stack.empty()) {
    const BufferIntervalTreeNode* top = visiting_stack.back();
    visiting_stack.pop_back();
    if (start > top->subtree_end) {
      continue;
    }
    if (top->left != nullptr) {
      visiting_stack.push_back(top->left);
    }
    if (top->start <= end && top->end >= start) {
      result.push_back(top->chunk);
    }
    if (end < top->start) {
      continue;
    }
    if (top->right != nullptr) {
      visiting_stack.push_back(top->right);
    }
  }
  return result;
}

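// Assigns offsets to the recorded buffer intervals in sorted order, finding a
// best-fit chunk for each interval against the chunks already committed.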
HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
  std::vector<BufferInterval> sorted_buffer_intervals =
      GetSortedBufferIntervals();

  for (auto& buffer_interval : sorted_buffer_intervals) {
    if (!buffer_interval.need_allocation) {
      continue;
    }

    ChunkCandidate chunk_candidate = FindChunkCandidate(buffer_interval);
    // This implementation of the heap algorithm does not have a notion of
    // maximum heap size, so it just commits.
    CommitChunk(buffer_interval, chunk_candidate);
  }
  VLOG(1) << "result heap_size: " << result_.heap_size;
  return result_;
}

std::vector<GlobalDecreasingSizeBestFitHeap::BufferInterval>
GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const {
  std::vector<BufferInterval> sorted_buffer_intervals;
  for (auto& entry : buffer_intervals_) {
    sorted_buffer_intervals.push_back(entry.second);
  }
  absl::c_sort(sorted_buffer_intervals, buffer_interval_compare_);

  return sorted_buffer_intervals;
}

GlobalDecreasingSizeBestFitHeap::ChunkCandidate
GlobalDecreasingSizeBestFitHeap::FindChunkCandidate(
    const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
    int64 preferred_offset) const {
  VLOG(1) << "Finding chunks for buffer: "
          << buffer_interval.buffer->ToString();
  VLOG(1) << "Size " << buffer_interval.size << ", start "
          << buffer_interval.start << ", end " << buffer_interval.end;
  auto chunks_overlapping_in_time = interval_tree_.ChunksOverlappingInTime(
      buffer_interval.start, buffer_interval.end);
  // Get all colocated buffers and gather all interfering chunks.
  //
  // Imagine that we've already allocated three chunks: a, b and c. Now we
  // want to allocate d. Since e is colocated with d, we have to allocate
  // chunks for them together at the same address. To do this, we first gather
  // all chunks that overlap with d and e on the time dimension; in this case
  // the overlapping chunks are a and b (c doesn't overlap with either d or
  // e). We then create a new chunk that doesn't overlap with a and b on the
  // space dimension.
  //
  // space
  //   ^
  //   |+--d---+      +---e---+
  //   |
  //   |+---+  +---------------+  +-------+
  //   ||   |  |               |  |       |
  //   ||   |  |               |  |       |
  //   |+-a-+  +-------b-------+  +---c---+
  //   ----------------------------------------> time
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    auto colocation_interval = buffer_intervals_.at(colocation);
    auto colocation_overlapping = interval_tree_.ChunksOverlappingInTime(
        colocation_interval.start, colocation_interval.end);
    VLOG(1) << "  Alias size " << colocation_interval.size << ", start "
            << colocation_interval.start << ", end " << colocation_interval.end
            << " " << colocation_interval.buffer->ToString();
    chunks_overlapping_in_time.insert(chunks_overlapping_in_time.end(),
                                      colocation_overlapping.begin(),
                                      colocation_overlapping.end());
  }
  absl::c_sort(chunks_overlapping_in_time, [](const Chunk& x, const Chunk& y) {
    return x.offset < y.offset;
  });

  // Find the minimum free chunk that can hold this buffer.
  ChunkCandidate chunk_candidate{Chunk{-1, INT64_MAX}, result_.heap_size};
  Chunk& min_fit_chunk = chunk_candidate.chunk;
  int64 preferred_chunk_end = preferred_offset + buffer_interval.size;
  auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) {
    if (free_size < buffer_interval.size) {
      return;
    }

    // If a preferred offset is provided, pick that offset.
    if (free_offset <= preferred_offset &&
        free_offset + free_size >= preferred_chunk_end) {
      min_fit_chunk = {preferred_offset, buffer_interval.size};
    } else if (free_offset + free_size == result_.heap_size &&
               free_offset <= preferred_offset) {
      // If the free chunk is at the very end of the heap and the preferred
      // offset lies within it, pick the preferred offset and grow the heap.
      min_fit_chunk = {preferred_offset, buffer_interval.size};
      chunk_candidate.heap_size = preferred_chunk_end;
    }

    // Pick the min-fit chunk only if we didn't have a preferred offset or a
    // chunk at the preferred offset hasn't been found.
    if ((preferred_offset < 0 || min_fit_chunk.offset != preferred_offset) &&
        free_size < min_fit_chunk.size) {
      min_fit_chunk = {free_offset, free_size};
    }
  };

  int64 offset = 0;
  for (auto& chunk : chunks_overlapping_in_time) {
    if (offset < chunk.offset) {
      use_free_chunk_if_smaller(offset, chunk.offset - offset);
    }
    offset = std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_));
  }
  use_free_chunk_if_smaller(offset, result_.heap_size - offset);
  // When a preferred offset is provided and it is larger than the current heap
  // size, simply use the preferred offset and grow the heap to fit it.
  if (result_.heap_size <= preferred_offset) {
    chunk_candidate.heap_size = preferred_chunk_end;
    min_fit_chunk = {preferred_offset, buffer_interval.size};
  }

  if (min_fit_chunk.offset == -1) {
    // Increase the heap size to fit in the last free chunk.
    chunk_candidate.heap_size = offset + buffer_interval.size;
    min_fit_chunk = {offset, buffer_interval.size};
  }

  min_fit_chunk.size = buffer_interval.size;
  return chunk_candidate;
}

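// Commits the chosen chunk: updates the heap size, records the interval (and
// all transitively colocated intervals) in the interval tree, and adds the
// chunk to the result map for each of those buffers.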
void GlobalDecreasingSizeBestFitHeap::CommitChunk(
    const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
    GlobalDecreasingSizeBestFitHeap::ChunkCandidate chunk_candidate) {
  // Update the maximum heap size according to the one determined by the chunk
  // candidate.
  result_.heap_size = chunk_candidate.heap_size;
  interval_tree_.Add(buffer_interval.start, buffer_interval.end,
                     chunk_candidate.chunk);
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    AddToChunkMap(colocation, chunk_candidate.chunk);
    auto colocation_interval = buffer_intervals_[colocation];
    interval_tree_.Add(colocation_interval.start, colocation_interval.end,
                       chunk_candidate.chunk);
  }

  AddToChunkMap(buffer_interval.buffer, chunk_candidate.chunk);
}

void GlobalDecreasingSizeBestFitHeap::AddToChunkMap(const HloValue* buffer,
                                                    Chunk chunk) {
  const auto emplace_result = result_.chunk_map.emplace(buffer, chunk);
  DCHECK(emplace_result.second);
}

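// Runs every candidate algorithm to completion and returns the result with the
// smallest heap size.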
HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() {
  DCHECK(!algorithms_.empty());
  std::vector<Result> results(algorithms_.size());
  int64 min_size = INT64_MAX;
  int min_size_index = -1;
  for (int i = 0; i < algorithms_.size(); ++i) {
    results[i] = algorithms_[i]->Finish();
    if (results[i].heap_size < min_size) {
      min_size = results[i].heap_size;
      min_size_index = i;
    }
  }

  DCHECK_GE(min_size_index, 0);
  return results[min_size_index];
}

}  // namespace xla