/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/heap_simulator.h"

#include <algorithm>
#include <functional>
#include <list>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

using absl::flat_hash_map;
using absl::flat_hash_set;
bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
  CHECK_NE(size, 0);
  CHECK_NE(other_chunk.size, 0);
  return offset < other_chunk.chunk_end() && other_chunk.offset < chunk_end();
}

/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
    const HloSchedule& schedule,
    const LogicalBuffer::SizeFunction& size_function) {
  if (schedule.empty()) {
    return 0;
  }
  const HloModule* module = schedule.module();

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  // The absolute minimum memory required for a given sequence of instructions
  // is determined by the sequence of Alloc and Free calls on a simulated heap,
  // ignoring fragmentation. We run the heap simulation on the whole module,
  // rather than summing each computation, since it gives us a better lower
  // bound by minimizing the liveness of sub-computations.
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), *module,
          schedule, *alias_analysis, size_function));
  return result.heap_size;
}
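
// A minimal usage sketch (illustrative, not part of this file's API surface):
// a pass that wants this fragmentation-free lower bound for an already
// scheduled module could call MinimumMemoryForModule as below. The lambda is
// an assumed size function; real callers usually pass the backend's
// buffer-size function.
//
//   StatusOr<int64> ScheduleLowerBoundBytes(const HloSchedule& schedule) {
//     auto size_fn = [](const BufferValue& buffer) {
//       return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
//     };
//     return HeapSimulator::MinimumMemoryForModule(schedule, size_fn);
//   }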

/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
          sequence, alias_analysis, size_function, HeapSimulator::Options(),
          memory_by_computation));
  return result.heap_size;
}

StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const HloSchedule* schedule) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
          sequence, alias_analysis, size_function, schedule,
          HeapSimulator::Options()));
  return result.heap_size;
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm, const HloModule& module,
    const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
  const HloComputation* entry_computation = module.entry_computation();
  const HloInstructionSequence& instruction_sequence =
      schedule.sequence(entry_computation);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(schedule, alias_analysis, entry_computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
                                         instruction_sequence, alias_analysis,
                                         hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/nullptr, memory_by_computation);
  HloSchedule schedule(computation.parent());
  schedule.set_sequence(&computation, instruction_sequence);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, alias_analysis, &computation,
                                        /*module_scoped_analysis=*/false));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis,
                                         hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
    const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/schedule, nullptr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(*schedule, alias_analysis, &computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis,
                                         hlo_live_range.get()));
  return heap.Finish();
}
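
// Sketch of driving one of the Run overloads above with a packing algorithm
// instead of the stats-only heap (local names and the alignment value are
// illustrative assumptions):
//
//   TF_ASSIGN_OR_RETURN(
//       HeapSimulator::Result<HloValue> result,
//       HeapSimulator::Run(
//           absl::make_unique<GlobalDecreasingSizeBestFitHeap<HloValue>>(
//               /*alignment=*/64,
//               GlobalDecreasingSizeBestFitHeap<HloValue>::kSpatial),
//           computation, instruction_sequence, alias_analysis, size_fn,
//           HeapSimulator::Options()));
//   for (const auto& heap_result : result.heap_results) {
//     for (const auto& entry : heap_result.chunk_map) {
//       VLOG(2) << entry.first->ToShortString() << " -> offset "
//               << entry.second.offset << ", size " << entry.second.size;
//     }
//   }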

// Runs a heap simulation for the given 'computation', assuming the given
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis, HloLiveRange* hlo_live_range) {
  XLA_VLOG_LINES(1, computation.parent()->ToString());
  XLA_VLOG_LINES(2, computation.ToString());

  VLOG(1) << hlo_live_range->ToString();

  HloDataflowAnalysis& dataflow_analysis = alias_analysis.dataflow_analysis();

  // Record the buffer define/free events for each time step. We free all
  // remaining buffers (entry parameters, etc.) after the program has finished
  // running, so we set the vector sizes to schedule_end_time() + 1.
  std::vector<std::vector<const HloValue*>> buffers_defined(
      hlo_live_range->schedule_end_time() + 1);
  std::vector<std::vector<const HloValue*>> buffers_freed(
      hlo_live_range->schedule_end_time() + 1);

  // values_to_assign tracks the HloValues that we need to assign a buffer to.
  // Note that we only need to assign a buffer to a value when both of the
  // following conditions are met:
  //
  // - The user specifically asks us to assign a buffer to a set of HloValues,
  //   and the value is in the set. If the user doesn't provide such a set, we
  //   assign buffers to all HloValues by default.
  //
  // - If the instruction is in a nested call of the current computation, we
  //   only assign a buffer when doing global heap simulation.
  std::vector<const HloValue*> values_to_assign;
  values_to_assign.reserve(dataflow_analysis.values().size());

  for (const HloValue* value : dataflow_analysis.values()) {
    // Ignore buffers that are not tracked.
    if (hlo_live_range->instruction_schedule().count(
            value->defining_instruction()) == 0) {
      continue;
    }
    if (IgnoreBuffer(value)) {
      continue;
    }
    values_to_assign.push_back(value);
  }

  auto& buffer_live_ranges = hlo_live_range->buffer_live_ranges();

  absl::c_sort(values_to_assign,
               [&](const HloValue* value1, const HloValue* value2) {
                 const auto& live_range1 = buffer_live_ranges.at(value1);
                 const auto& live_range2 = buffer_live_ranges.at(value2);
                 return std::forward_as_tuple(live_range1.start,
                                              live_range1.end, value1->id()) <
                        std::forward_as_tuple(live_range2.start,
                                              live_range2.end, value2->id());
               });

  // For each value that we need to assign a buffer to, add the define and free
  // events.
  for (const HloValue* value : values_to_assign) {
    auto live_range = buffer_live_ranges.at(value);
    buffers_defined[live_range.start].push_back(value);
    buffers_freed[live_range.end].push_back(value);
  }

  // All HloValues in an HloBuffer should be allocated to the same address.
  // This map tracks the first value that got allocated in a buffer.
  absl::flat_hash_map<const HloBuffer*, const HloValue*> first_allocated_value;

  VLOG(1) << "Program time: " << hlo_live_range->schedule_end_time();

  // Go through each step in the program and replay the buffer define and free
  // events.
  for (int64 i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) {
    VLOG(1) << "Time step: " << i;

    for (const HloValue* value : buffers_defined[i]) {
      bool shared = false;
      VLOG(1) << "Start buffer: " << value->ToShortString();
      const HloBuffer* hlo_buffer =
          &alias_analysis.GetBufferContainingValue(*value);
      if (first_allocated_value.count(hlo_buffer) != 0) {
        // We've already assigned an address for another value in this
        // HloBuffer (an HloBuffer holds several aliased HloValues). All values
        // in a buffer should be assigned the same address. Find the one that's
        // already allocated and reuse its address.
        ShareBuffer(value, first_allocated_value[hlo_buffer],
                    value->instruction());
        VLOG(1) << " ShareWith "
                << first_allocated_value[hlo_buffer]->ToShortString();
        continue;
      }
      if (options_.may_reuse_operand_buffers &&
          hlo_buffer->values().size() == 1) {
        // We don't support sharing an aliased buffer
        // (hlo_buffer->values().size() > 1) with its operand.
        for (const HloInstruction* operand : value->instruction()->operands()) {
          const HloValueSet operand_value_set =
              dataflow_analysis.GetValueSet(operand);
          for (const HloValue* operand_value : operand_value_set.values()) {
            const HloBuffer* operand_buffer =
                &alias_analysis.GetBufferContainingValue(*operand_value);
            if (operand_buffer->values().size() > 1) {
              continue;
            }
            auto it = buffer_live_ranges.find(operand_value);
            if (it == buffer_live_ranges.end()) {
              continue;
            }

            auto& operand_live_range = it->second;

            auto& user_live_range = buffer_live_ranges[value];

            // Can only share buffers that are about to be freed.
            if (operand_live_range.end != i) {
              continue;
            }

            if (IgnoreBuffer(operand_value)) {
              continue;
            }

            if (!absl::c_linear_search(buffers_freed[i], operand_value)) {
              // If the operand buffer is not being freed (either because it
              // has existing users, or because it has been reused by other
              // buffers), don't consider the operand as a candidate for
              // buffer sharing.
              continue;
            }

            // The instruction that defines the operand value can be different
            // from the actual operand; passing the defining instruction
            // directly into CanShareOperandBufferWithUser can trigger a check
            // failure. The first condition guards against that case.
            if (value->instruction()->IsUserOf(operand_value->instruction()) &&
                value->instruction()->opcode() != HloOpcode::kCopy &&
                dataflow_analysis.CanShareOperandBufferWithUser(
                    operand_value->instruction(), operand_value->index(),
                    value->instruction(), value->index())) {
              // Remove the operand buffer right before sharing (allocating) a
              // new one.
              Free(operand_value, operand_value->instruction());
              buffers_freed[i].erase(
                  std::remove(buffers_freed[i].begin(), buffers_freed[i].end(),
                              operand_value),
                  buffers_freed[i].end());
              ShareBuffer(value, operand_value, value->instruction());
              // The live range of the operand buffer is now extended to the
              // end of the current instruction.
              operand_live_range.end = user_live_range.end;
              VLOG(1) << "Sharing " << value->ToShortString() << " with "
                      << operand_value->ToShortString()
                      << ", size: " << size_fn_(*value);
              shared = true;
              break;
            }
          }
          if (shared) {
            break;
          }
        }
      }
      if (!shared) {
        Alloc(value, value->instruction());
        first_allocated_value[hlo_buffer] = value;
      }
    }

    if (!buffers_freed[i].empty()) {
      VLOG(1) << "Free Buffer: ";
    }
    for (const HloValue* value : buffers_freed[i]) {
      VLOG(1) << " " << value->ToShortString();

      Free(value, value->instruction());
    }
  }
  return Status::OK();
}
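
// Illustration of the event replay above with three hypothetical values whose
// live ranges are a:[0, 2], b:[1, 3], c:[2, 2] (defines are processed before
// frees within a time step):
//
//   time 0: Alloc(a)
//   time 1: Alloc(b)
//   time 2: Alloc(c), then Free(a), Free(c)
//   time 3: Free(b)
//
// When operand sharing triggers instead, the operand is freed early and
// ShareBuffer lets the new value reuse the operand's offset.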

HeapSimulator::HeapSimulator(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const HloSchedule* schedule,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation)
    : no_fragmentation_stats_(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>()),
      algorithm_(std::move(algorithm)),
      size_fn_(size_fn),
      options_(options),
      schedule_(schedule),
      memory_by_computation_(memory_by_computation) {
  debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}

HeapSimulator::~HeapSimulator() {}

bool HeapSimulator::IgnoreBuffer(const HloValue* buffer) const {
  // Buffers for constants are ignored unless the alloc_constants option is
  // set. Also ignore buffers that we're not meant to assign.
  //
  // TODO(b/32248867): For consistency, constants should get allocations.
  if (!options_.alloc_constants &&
      buffer->instruction()->opcode() == HloOpcode::kConstant) {
    return true;
  }
  return options_.buffers_to_assign != nullptr &&
         !options_.buffers_to_assign->contains(buffer);
}

// Alloc always calls the underlying heap algorithm.
void HeapSimulator::Alloc(const HloValue* buffer,
                          const HloInstruction* instruction) {
  CHECK(!allocated_buffers_.contains(buffer))
      << "Alloc called on allocated buffer: " << *buffer;
  CHECK(!freed_buffers_.contains(buffer))
      << "Alloc called on freed buffer: " << *buffer;

  allocated_buffers_.insert(buffer);
  const int64 size = size_fn_(*buffer);
  algorithm_->Alloc(buffer, size);
  no_fragmentation_stats_->Alloc(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
                 nullptr);
}

// Free forwards the call to the underlying algorithm (and to the
// fragmentation-stats heap); any sharing bookkeeping is handled by the
// algorithm itself via ShareWith.
void HeapSimulator::Free(const HloValue* buffer,
                         const HloInstruction* instruction) {
  const int64 size = size_fn_(*buffer);
  algorithm_->Free(buffer, size);
  no_fragmentation_stats_->Free(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
}

// ShareBuffer associates 'buffer' with a previously allocated (or shared)
// 'shared' buffer by forwarding to the algorithm's ShareWith, so both end up
// at the same offset. The 'buffer' must be a non-allocated, non-freed buffer,
// just like in calls to Alloc.
void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
                                const HloInstruction* instruction) {
  algorithm_->ShareWith(buffer, shared, size_fn_(*shared));
  no_fragmentation_stats_->ShareWith(buffer, shared, size_fn_(*shared));
  FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
                 shared);
}
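
// A sketch of the call sequence a HeapAlgorithm typically observes from the
// simulator for two hypothetical values v0 and v1 that live in the same
// HloBuffer (v0 is allocated first, v1 aliases it):
//
//   algorithm->Alloc(v0, size);          // from HeapSimulator::Alloc
//   algorithm->ShareWith(v1, v0, size);  // from HeapSimulator::ShareBuffer
//   algorithm->Free(v0, size);           // from HeapSimulator::Free
//   algorithm->Free(v1, size);           // each value is freed individually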

HeapSimulator::Result<HloValue> HeapSimulator::Finish() {
  Result<HloValue> result = algorithm_->Finish();

  // Post-process the result to add chunks for shared buffers. An empty chunk
  // map means that either no buffers were allocated, or the heap was only
  // collecting statistics, e.g. NoFragmentationStatsHeap.
  size_t total_chunk_count = absl::c_accumulate(
      result.heap_results, static_cast<size_t>(0),
      [&](size_t lhs, const HeapResult<HloValue>& rhs) -> size_t {
        return lhs + rhs.chunk_map.size();
      });
  if (total_chunk_count != 0) {
    // If we were told to assign specific buffers, make sure we've assigned
    // exactly that many buffers.
    if (options_.buffers_to_assign != nullptr) {
      CHECK_EQ(options_.buffers_to_assign->size(), total_chunk_count);
    }
  }

  // Fragmentation is the difference between the actual and ideal sizes.
  const Result<HloValue> no_frag_result = no_fragmentation_stats_->Finish();
  result.fragmentation_size = result.heap_size - no_frag_result.heap_size;

  // Copy the debug trace we collected to the final result.
  result.debug_trace.Swap(&debug_trace_);

  return result;
}

void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
                                   const HloValue* buffer,
                                   const HloInstruction* instruction,
                                   const HloValue* share_with_canonical) {
  HeapSimulatorTrace::Event* event = debug_trace_.add_events();
  event->set_kind(kind);
  event->set_buffer_id(buffer->id());
  event->set_computation_name(instruction->parent()->name());
  event->set_instruction_name(instruction->name());
  if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
    CHECK(share_with_canonical != nullptr);
    event->set_share_with_canonical_id(share_with_canonical->id());
  } else {
    CHECK(share_with_canonical == nullptr);
  }
}

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::Alloc(const BufferType* buffer,
                                                 int64 size) {
  current_heap_size_ += size;
  if (current_heap_size_ > max_heap_size_) {
    max_heap_size_ = current_heap_size_;
  }
}

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::AccountForSubcomputationMemory(
    const HloInstruction* instruction, int64 alloc_size_by_instruction,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  // We only count the memory usage of the largest subcomputation, instead of
  // adding them all, because subcomputations won't execute in parallel.
  int64 max_subcomputation_bytes = 0;
  for (const auto* c : instruction->called_computations()) {
    auto it = memory_by_computation.find(c);
    if (it != memory_by_computation.end()) {
      int64 subcomputation_bytes = it->second;
      if (subcomputation_bytes > max_subcomputation_bytes) {
        max_subcomputation_bytes = subcomputation_bytes;
      }
    }
  }
  if (max_subcomputation_bytes > 0 &&
      (instruction->opcode() == HloOpcode::kWhile ||
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kConditional)) {
    // The output buffer of while/call/conditional is always aliased with the
    // output buffer of the root instruction in the body. Don't double count.
    max_subcomputation_bytes -= alloc_size_by_instruction;
  }
  max_heap_size_ =
      std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
}
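
// Worked example of the accounting above with hypothetical numbers: a kWhile
// whose body computation peaks at 640 bytes and whose condition peaks at 128
// bytes, with alloc_size_by_instruction = 64 (the while's output, aliased
// with the body root). Only the larger subcomputation is counted, and the
// aliased output is subtracted once:
//
//   max_subcomputation_bytes = max(640, 128) - 64 = 576
//   max_heap_size_ = std::max(max_heap_size_, current_heap_size_ + 576)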

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::Free(const BufferType* buffer,
                                                int64 size) {
  current_heap_size_ -= size;
}

template <typename BufferType>
HeapSimulator::Result<BufferType>
NoFragmentationStatsHeap<BufferType>::Finish() {
  // The result.chunk_map is empty, since we only collect stats, and don't
  // actually compute chunk assignments.
  Result result;
  result.heap_size = max_heap_size_;
  return result;
}

template <typename BufferType>
GlobalDecreasingSizeBestFitHeap<BufferType>::GlobalDecreasingSizeBestFitHeap(
    int64 alignment, Type type)
    : alignment_(alignment) {
  if (type == kTemporal) {
    buffer_interval_compare_ = GetTemporalBufferIntervalCompare();
  } else {
    CHECK(type == kSpatial);
    buffer_interval_compare_ = GetSpatialBufferIntervalCompare();
  }
}

template <typename BufferType>
typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap<BufferType>::GetTemporalBufferIntervalCompare()
    const {
  return [&](const BufferInterval& x, const BufferInterval& y) {
    int64 x_end = x.end;
    for (auto colocation : GetTransitiveColocations(x)) {
      x_end = std::max(x_end, buffer_intervals_.at(colocation).end);
    }

    int64 y_end = y.end;
    for (auto colocation : GetTransitiveColocations(y)) {
      y_end = std::max(y_end, buffer_intervals_.at(colocation).end);
    }

    if (x_end - x.start != y_end - y.start) {
      return x_end - x.start > y_end - y.start;
    }

    if (x.size != y.size) {
      return x.size > y.size;
    }
    return *x.buffer < *y.buffer;
  };
}

template <typename BufferType>
/*static*/ typename GlobalDecreasingSizeBestFitHeap<
    BufferType>::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap<BufferType>::GetSpatialBufferIntervalCompare() {
  return [&](const BufferInterval& x, const BufferInterval& y) {
    if (x.size != y.size) {
      return x.size > y.size;
    }
    if (x.end - x.start != y.end - y.start) {
      return x.end - x.start > y.end - y.start;
    }
    return *x.buffer < *y.buffer;
  };
}
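
// Example of the two orderings on hypothetical intervals
//   p: size 16, live range [0, 9]    q: size 32, live range [2, 5]
// The temporal compare sorts p before q (longer live range first, size as the
// tie-breaker), while the spatial compare sorts q before p (larger size
// first, live-range length as the tie-breaker). Ties on both keys fall back
// to the buffer ordering, so the sort is deterministic.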

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::Alloc(
    const BufferType* buffer, int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }

  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, true});
  DCHECK(emplace_result.second);
  ++current_time_;
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::ShareWith(
    const BufferType* buffer, const BufferType* share_with, int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }
  DCHECK_NE(buffer_intervals_.count(share_with), 0);
  buffer_intervals_[share_with].colocations.push_back(buffer);
  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, false});
  DCHECK(emplace_result.second);
  ++current_time_;
}

template <typename BufferType>
absl::flat_hash_set<const BufferType*>
GlobalDecreasingSizeBestFitHeap<BufferType>::GetTransitiveColocations(
    const BufferInterval& interval) const {
  absl::flat_hash_set<const BufferType*> result;
  std::vector<const BufferInterval*> worklist = {&interval};
  while (!worklist.empty()) {
    const BufferInterval* item = worklist.back();
    worklist.pop_back();
    for (const BufferType* buffer_colocated : item->colocations) {
      result.insert(buffer_colocated);
      worklist.push_back(&buffer_intervals_.at(buffer_colocated));
    }
  }

  return result;
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::Free(const BufferType* buffer,
                                                       int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    return;
  }
  BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
  DCHECK_EQ(buffer_interval.buffer, buffer);
  DCHECK_EQ(buffer_interval.size, size);
  DCHECK_EQ(buffer_interval.end, -1);
  if (buffer_interval.end != -1) {
    return;
  }
  buffer_interval.end = current_time_;
  ++current_time_;
}

using Chunk = HeapSimulator::Chunk;

void BufferIntervalTree::Add(int64 start, int64 end, const Chunk& chunk) {
  node_storage_.emplace_back(BufferIntervalTreeNode{
      start, end, end, chunk,
      /*left=*/nullptr, /*right=*/nullptr, /*parent=*/nullptr});
  if (root_ == nullptr) {
    root_ = &node_storage_.back();
    // This is root.
    return;
  }

  BufferIntervalTreeNode* parent = root_;
  while (true) {
    parent->subtree_end = std::max(parent->subtree_end, end);
    if (parent->start > start) {
      if (parent->left == nullptr) {
        parent->left = &node_storage_.back();
        node_storage_.back().parent = parent;
        return;
      }
      parent = parent->left;
    } else {
      if (parent->right == nullptr) {
        parent->right = &node_storage_.back();
        node_storage_.back().parent = parent;
        return;
      }
      parent = parent->right;
    }
  }
}

bool BufferIntervalTree::Remove(int64 start, int64 end, const Chunk& chunk) {
  BufferIntervalTreeNode* to_delete = root_;
  while (to_delete != nullptr) {
    if (to_delete->start == start && to_delete->end == end &&
        to_delete->chunk.offset == chunk.offset) {
      break;
    }
    if (start < to_delete->start) {
      to_delete = to_delete->left;
    } else {
      to_delete = to_delete->right;
    }
  }
  if (to_delete == nullptr) {
    // Nothing to delete.
    return false;
  }
  // Found the node to be deleted; enter the deletion sequence.

  // Recursively traverse the parents of a node and restore the `subtree_end`
  // invariant at each of them. A recursive lambda needs an explicit
  // std::function declaration.
  std::function<void(BufferIntervalTreeNode*)> fix_up =
      [&](BufferIntervalTreeNode* node) {
        if (node == nullptr) {
          return;
        }
        node->subtree_end = node->end;
        if (node->left) {
          node->subtree_end =
              std::max(node->subtree_end, node->left->subtree_end);
        }
        if (node->right) {
          node->subtree_end =
              std::max(node->subtree_end, node->right->subtree_end);
        }
        // Recursively go up.
        fix_up(node->parent);
      };

  if (to_delete->right == nullptr) {
    // to_delete has no right child; simply move up the left child of
    // to_delete, if any.
    //
    // Turn:
    //        parent
    //       /
    //   to_delete
    //    /      \
    //  left   nullptr
    //
    // Into:
    //        parent
    //       /
    //     left
    if (root_ == to_delete) {
      // Deleting the root simply resets the root.
      root_ = to_delete->left;
      return true;
    }

    if (to_delete == to_delete->parent->left) {
      // to_delete is the left child of its parent.
      to_delete->parent->left = to_delete->left;
    }
    if (to_delete == to_delete->parent->right) {
      // to_delete is the right child of its parent.
      to_delete->parent->right = to_delete->left;
    }
    // Rewire parent to the node being moved up.
    if (to_delete->left) {
      to_delete->left->parent = to_delete->parent;
    }
    // Fix up starting from the subroot.
    fix_up(to_delete);
  } else {
    // 1. Find the left-most node of the right subtree and promote it to the
    // position of to_delete.
    BufferIntervalTreeNode* to_promote = to_delete->right;
    while (to_promote->left != nullptr) {
      // Go to the left-most subtree.
      to_promote = to_promote->left;
    }

    // 2. Copy the content of `to_promote` to `to_delete`.
    to_delete->start = to_promote->start;
    to_delete->end = to_promote->end;
    // This may be incorrect, but we fix it up later in the `fix_up` procedure.
    to_delete->subtree_end = to_promote->subtree_end;
    to_delete->chunk = to_promote->chunk;
    auto to_promote_parent = to_promote->parent;
    // 3. Move the right child of `to_promote` up, if there is any.
    //
    // Turn
    //
    //   to_delete
    //           \
    //       to_promote_parent
    //          /
    //     to_promote
    //             \
    //            right
    // into
    //
    //   to_promote
    //           \
    //       to_promote_parent
    //          /
    //       right
    if (to_promote_parent->left == to_promote) {
      to_promote_parent->left = to_promote->right;
    } else {
      to_promote_parent->right = to_promote->right;
    }
    if (to_promote->right) {
      // Set the correct parent.
      to_promote->right->parent = to_promote_parent;
    }
    // 4. Recursively fix up the `subtree_end` values starting from
    // `to_promote_parent`.
    fix_up(to_promote_parent);
  }
  // Don't free the entry in node_storage_ until we free the entire tree.
  return true;
}

std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
    int64 start, int64 end) const {
  std::vector<Chunk> result;
  if (root_ == nullptr) {
    return result;
  }
  std::vector<const BufferIntervalTreeNode*> visiting_stack;
  visiting_stack.push_back(root_);
  while (!visiting_stack.empty()) {
    const BufferIntervalTreeNode* top = visiting_stack.back();
    visiting_stack.pop_back();
    if (start > top->subtree_end) {
      continue;
    }
    if (top->left != nullptr) {
      visiting_stack.push_back(top->left);
    }
    if (top->start <= end && top->end >= start) {
      result.push_back(top->chunk);
    }
    if (end < top->start) {
      continue;
    }
    if (top->right != nullptr) {
      visiting_stack.push_back(top->right);
    }
  }
  return result;
}
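
// A minimal sketch of the tree's contract, independent of the heap simulator
// (the start/end/offset values are illustrative):
//
//   BufferIntervalTree tree;
//   tree.Add(/*start=*/0, /*end=*/3, Chunk{/*offset=*/0, /*size=*/16});
//   tree.Add(/*start=*/2, /*end=*/6, Chunk{/*offset=*/16, /*size=*/8});
//   std::vector<Chunk> chunks =
//       tree.ChunksOverlappingInTime(1, 2);  // -> chunks at offsets {0, 16}
//   tree.Remove(0, 3, Chunk{0, 16});
//   chunks = tree.ChunksOverlappingInTime(1, 2);  // -> chunk at offset {16}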

template <typename BufferType>
HeapSimulator::Result<BufferType>
GlobalDecreasingSizeBestFitHeap<BufferType>::Finish() {
  std::vector<BufferInterval> sorted_buffer_intervals =
      GetSortedBufferIntervals();

  for (auto& buffer_interval : sorted_buffer_intervals) {
    if (!buffer_interval.need_allocation) {
      continue;
    }

    ChunkCandidate chunk_candidate = FindChunkCandidate(buffer_interval);
    // This implementation of the heap algorithm does not have a notion of
    // maximum heap size, so it just commits.
    CommitChunk(buffer_interval, chunk_candidate);
  }
  VLOG(1) << "result heap_size: " << result_.heap_size;
  Result result;
  result.heap_size = result_.heap_size;
  result.heap_results.emplace_back(result_);
  return result;
}

template <typename BufferType>
std::vector<
    typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval>
GlobalDecreasingSizeBestFitHeap<BufferType>::GetSortedBufferIntervals() const {
  std::vector<BufferInterval> sorted_buffer_intervals;
  for (auto& entry : buffer_intervals_) {
    sorted_buffer_intervals.push_back(entry.second);
  }
  absl::c_sort(sorted_buffer_intervals, buffer_interval_compare_);

  return sorted_buffer_intervals;
}

template <typename BufferType>
typename GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
GlobalDecreasingSizeBestFitHeap<BufferType>::FindChunkCandidate(
    const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
    int64 preferred_offset) const {
  VLOG(1) << "Finding chunks for buffer: "
          << buffer_interval.buffer->ToString();
  VLOG(1) << "Size " << buffer_interval.size << ", start "
          << buffer_interval.start << ", end " << buffer_interval.end;
  auto chunks_overlapping_in_time = interval_tree_.ChunksOverlappingInTime(
      buffer_interval.start, buffer_interval.end);
  // Get all colocated buffers and gather all interfering chunks.
  //
  // Imagine that we've already allocated three chunks: a, b and c. Now we
  // want to allocate d. Since e is colocated with d, we have to allocate
  // chunks for them together at the same address. To do this, we first gather
  // all chunks that overlap with d and e on the time dimension; in this case
  // the overlapping chunks are a and b (c doesn't overlap with either d or
  // e). Then we create a new chunk that doesn't overlap with a and b on the
  // space dimension.
  //
  // space
  //   ^
  //   |+--d---+      +---e---+
  //   |
  //   |+---+  +---------------+  +-------+
  //   ||   |  |               |  |       |
  //   ||   |  |               |  |       |
  //   |+-a-+  +-------b-------+  +---c---+
  //   ----------------------------------------> time
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    auto colocation_interval = buffer_intervals_.at(colocation);
    auto colocation_overlapping = interval_tree_.ChunksOverlappingInTime(
        colocation_interval.start, colocation_interval.end);
    VLOG(1) << " Alias size " << colocation_interval.size << ", start "
            << colocation_interval.start << ", end " << colocation_interval.end
            << " " << colocation_interval.buffer->ToString();
    chunks_overlapping_in_time.insert(chunks_overlapping_in_time.end(),
                                      colocation_overlapping.begin(),
                                      colocation_overlapping.end());
  }
  absl::c_sort(chunks_overlapping_in_time, [](const Chunk& x, const Chunk& y) {
    return x.offset < y.offset;
  });

  // Find the minimum free chunk that can hold this buffer.
  ChunkCandidate chunk_candidate{Chunk{-1, INT64_MAX}, result_.heap_size};
  Chunk& min_fit_chunk = chunk_candidate.chunk;
  int64 preferred_chunk_end = preferred_offset + buffer_interval.size;
  auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) {
    if (free_size < buffer_interval.size) {
      return;
    }

    // If a preferred offset is provided, pick that offset.
    if (free_offset <= preferred_offset &&
        free_offset + free_size >= preferred_chunk_end) {
      min_fit_chunk = {preferred_offset, buffer_interval.size};
    } else if (free_offset + free_size == result_.heap_size &&
               free_offset <= preferred_offset) {
      // If the free chunk is at the very end of the heap and the preferred
      // offset lies within it, pick the preferred offset and grow the heap.
      min_fit_chunk = {preferred_offset, buffer_interval.size};
      chunk_candidate.heap_size = preferred_chunk_end;
    }

    // Pick the min-fit chunk only if we didn't have a preferred offset or a
    // chunk at the preferred offset hasn't been found.
    if ((preferred_offset < 0 || min_fit_chunk.offset != preferred_offset) &&
        free_size < min_fit_chunk.size) {
      min_fit_chunk = {free_offset, free_size};
    }
  };

  int64 offset = 0;
  for (auto& chunk : chunks_overlapping_in_time) {
    if (offset < chunk.offset) {
      use_free_chunk_if_smaller(offset, chunk.offset - offset);
    }
    offset = std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_));
  }
  use_free_chunk_if_smaller(offset, result_.heap_size - offset);
  // When a preferred offset is provided and it is larger than the current
  // heap size, simply use the preferred offset and grow the heap.
  if (result_.heap_size <= preferred_offset) {
    chunk_candidate.heap_size = preferred_chunk_end;
    min_fit_chunk = {preferred_offset, buffer_interval.size};
  }

  if (min_fit_chunk.offset == -1) {
    // Increase the heap size to fit in the last free chunk.
    chunk_candidate.heap_size = offset + buffer_interval.size;
    min_fit_chunk = {offset, buffer_interval.size};
  }

  min_fit_chunk.size = buffer_interval.size;
  return chunk_candidate;
}
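
// Worked example of the free-chunk scan above (hypothetical state): alignment
// 1, result_.heap_size 48, and two chunks already live during the requested
// time range at offsets [0, 16) and [32, 48). For a 12-byte buffer with no
// preferred offset, the scan visits the gap between the chunks and the
// (empty) space at the end of the heap:
//
//   use_free_chunk_if_smaller(16, 16);  // records min_fit_chunk = {16, 16}
//   use_free_chunk_if_smaller(48, 0);   // too small, ignored
//
// so the returned candidate is chunk {16, 12} with heap_size still 48. If no
// gap had been large enough, the candidate would fall past the last chunk and
// heap_size would grow to offset + size.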

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::CommitChunk(
    const GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval&
        buffer_interval,
    GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
        chunk_candidate) {
  // Update the maximum heap size according to the one determined by the chunk
  // candidate.
  result_.heap_size = chunk_candidate.heap_size;
  interval_tree_.Add(buffer_interval.start, buffer_interval.end,
                     chunk_candidate.chunk);
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    AddToChunkMap(colocation, chunk_candidate.chunk);
    auto colocation_interval = buffer_intervals_[colocation];
    interval_tree_.Add(colocation_interval.start, colocation_interval.end,
                       chunk_candidate.chunk);
  }

  AddToChunkMap(buffer_interval.buffer, chunk_candidate.chunk);
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::AddToChunkMap(
    const BufferType* buffer, Chunk chunk) {
  const auto emplace_result = result_.chunk_map.emplace(buffer, chunk);
  DCHECK(emplace_result.second);
}

HeapSimulator::Result<HloValue>
ConstrainedGlobalDecreasingSizeBestFitHeap::Finish() {
  std::vector<BufferInterval> sorted_buffer_vec = GetSortedBufferIntervals();
  // Convert into a std::list so that erase() is O(1).
  std::list<BufferInterval> sorted_buffer_intervals(sorted_buffer_vec.begin(),
                                                    sorted_buffer_vec.end());

  // Use do-while here, because we need to create 1 heap in `multi_heap_result`
  // even if `sorted_buffer_intervals` is empty.
  Result multi_heap_result;
  do {
    // Place as many buffers into the currently processed heap as possible.
    for (auto it = sorted_buffer_intervals.begin();
         it != sorted_buffer_intervals.end();) {
      BufferInterval buffer_interval = *it;
      if (!buffer_interval.need_allocation) {
        it = sorted_buffer_intervals.erase(it);
        continue;
      }
      if (buffer_interval.size > size_limit_per_heap_) {
        LOG(WARNING) << "Alloc buffer size " << buffer_interval.size
                     << " larger than the per-heap size limit "
                     << size_limit_per_heap_;
      }

      ChunkCandidate chunk_candidate = FindChunkCandidate(buffer_interval);
      if (chunk_candidate.heap_size <= size_limit_per_heap_ ||
          // Commit the chunk as long as the heap is empty. We do this because
          // we want the size constraint to be soft, meaning that results are
          // successfully generated even if there are some buffer sizes larger
          // than the given constraint size.
          result_.heap_size == 0) {
        CommitChunk(buffer_interval, chunk_candidate);
        it = sorted_buffer_intervals.erase(it);
        continue;
      }

      ++it;
    }
    // Collect the result from the currently processed heap and reset the heap
    // states.
    multi_heap_result.heap_size += result_.heap_size;
    multi_heap_result.heap_results.push_back(std::move(result_));
    result_ = {};
    interval_tree_ = {};
  } while (!sorted_buffer_intervals.empty());

  VLOG(1) << "Number of heaps produced = "
          << multi_heap_result.heap_results.size();
  return multi_heap_result;
}
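
// Example of the soft size constraint above (hypothetical numbers, assuming
// the decreasing-size sort order and live ranges that all overlap): with
// size_limit_per_heap_ = 64 and buffers of sizes 48, 40, 24 and 8, the first
// pass packs {48, 8} (heap size 56), the second pass packs {40, 24} (heap
// size 64), and the loop terminates. A single 80-byte buffer would still be
// committed into its own heap; only the warning above is emitted.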

template <typename BufferType>
HeapSimulator::Result<BufferType>
ChooseBestHeapAlgorithm<BufferType>::Finish() {
  DCHECK(!algorithms_.empty());
  std::vector<Result> results(algorithms_.size());
  int64 min_size = INT64_MAX;
  int min_size_index = -1;
  for (int i = 0; i < algorithms_.size(); ++i) {
    results[i] = algorithms_[i]->Finish();
    if (results[i].heap_size < min_size) {
      min_size = results[i].heap_size;
      min_size_index = i;
    }
  }

  DCHECK_GE(min_size_index, 0);
  return results[min_size_index];
}

template class GlobalDecreasingSizeBestFitHeap<HloValue>;
template class GlobalDecreasingSizeBestFitHeap<
    MemorySpaceAssignmentRepacker::AllocationBlock>;
template class ChooseBestHeapAlgorithm<HloValue>;

}  // namespace xla