/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/heap_simulator.h"

#include <algorithm>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

using absl::flat_hash_map;
using absl::flat_hash_set;

bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
  CHECK_NE(size, 0);
  CHECK_NE(other_chunk.size, 0);
  return offset < other_chunk.chunk_end() && other_chunk.offset < chunk_end();
}
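
// Illustrative sketch (hypothetical values): two chunks overlap iff their
// half-open ranges [offset, offset + size) intersect. For example:
//   Chunk a{/*offset=*/0, /*size=*/16};   // occupies [0, 16)
//   Chunk b{/*offset=*/16, /*size=*/8};   // occupies [16, 24)
// a.OverlapsWith(b) is false, since the two ranges only touch at offset 16.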

/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
    const HloSchedule& schedule,
    const LogicalBuffer::SizeFunction& size_function) {
  if (schedule.empty()) {
    return 0;
  }
  const HloModule* module = schedule.module();

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  // The absolute minimum memory required for a given sequence of instructions
  // is determined by the sequence of Alloc and Free calls on a simulated heap,
  // ignoring fragmentation. We run the heap simulation on the whole module,
  // rather than summing each computation, since it gives us a better lower
  // bound, by minimizing the liveness of sub-computations.
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), *module,
          schedule, *alias_analysis, size_function));
  return result.heap_size;
}
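
// Usage sketch (illustrative only; assumes a `schedule` and `size_fn` already
// exist in the caller):
//
//   StatusOr<int64> min_bytes =
//       HeapSimulator::MinimumMemoryForModule(schedule, size_fn);
//
// Because the NoFragmentationStatsHeap ignores fragmentation, the returned
// value is a lower bound on what any real allocator can achieve for this
// schedule.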

/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
          sequence, alias_analysis, size_function, HeapSimulator::Options(),
          memory_by_computation));
  return result.heap_size;
}

StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const HloSchedule* schedule) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), computation,
          sequence, alias_analysis, size_function, schedule,
          HeapSimulator::Options()));
  return result.heap_size;
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm, const HloModule& module,
    const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
  const HloComputation* entry_computation = module.entry_computation();
  const HloInstructionSequence& instruction_sequence =
      schedule.sequence(entry_computation);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(schedule, alias_analysis, entry_computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
                                         instruction_sequence, alias_analysis,
                                         hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/nullptr, memory_by_computation);
  HloSchedule schedule(computation.parent());
  schedule.set_sequence(&computation, instruction_sequence);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, alias_analysis, &computation,
                                        /*module_scoped_analysis=*/false));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
    const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/schedule, nullptr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(*schedule, alias_analysis, &computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}
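
// Usage sketch (illustrative only; assumes `module`, `schedule`,
// `alias_analysis`, and `size_fn` already exist in the caller):
//
//   TF_ASSIGN_OR_RETURN(
//       HeapSimulator::Result<HloValue> result,
//       HeapSimulator::Run(
//           absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), *module,
//           schedule, *alias_analysis, size_fn));
//
// result.heap_size is the simulated peak memory, and each HeapResult in
// result.heap_results maps buffers to the chunks chosen by the algorithm
// (empty for the stats-only heap used above).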

// Runs a heap simulation for the given 'computation', assuming the given
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis, HloLiveRange* hlo_live_range) {
  XLA_VLOG_LINES(1, computation.parent()->ToString());
  XLA_VLOG_LINES(2, computation.ToString());

  VLOG(1) << hlo_live_range->ToString();

  HloDataflowAnalysis& dataflow_analysis = alias_analysis.dataflow_analysis();

  // Record the buffer define/free events for each time step. We free all
  // remaining buffers (entry parameters, etc.) after the program has finished
  // running, so we size these vectors to schedule_end_time() + 1.
  std::vector<std::vector<const HloValue*>> buffers_defined(
      hlo_live_range->schedule_end_time() + 1);
  std::vector<std::vector<const HloValue*>> buffers_freed(
      hlo_live_range->schedule_end_time() + 1);

  // values_to_assign tracks the HloValues that we need to assign a buffer to.
  // Note that we only need to assign a buffer to a value when both of the
  // following conditions are met:
  //
  // - The user specifically asks us to assign buffers to a set of HloValues,
  //   and the value is in that set. If the user doesn't provide such a set, we
  //   assign buffers to all HloValues by default.
  //
  // - If the instruction is in a nested call of the current computation, we
  //   only assign a buffer when doing global (whole-module) heap simulation.
  std::vector<const HloValue*> values_to_assign;
  values_to_assign.reserve(dataflow_analysis.values().size());

  auto& buffer_live_ranges = hlo_live_range->buffer_live_ranges();

  for (const HloValue* value : dataflow_analysis.values()) {
    // Ignore buffers that are not tracked.
    if (!buffer_live_ranges.contains(value)) {
      continue;
    }
    if (IgnoreBuffer(value)) {
      continue;
    }

    values_to_assign.push_back(value);
  }

  absl::c_sort(values_to_assign,
               [&](const HloValue* value1, const HloValue* value2) {
                 const auto& live_range1 = buffer_live_ranges.at(value1);
                 const auto& live_range2 = buffer_live_ranges.at(value2);
                 return std::forward_as_tuple(live_range1.start,
                                              live_range1.end, value1->id()) <
                        std::forward_as_tuple(live_range2.start,
                                              live_range2.end, value2->id());
               });

  // For each value that we need to assign a buffer to, add its define and free
  // events.
  for (const HloValue* value : values_to_assign) {
    auto live_range = buffer_live_ranges.at(value);
    buffers_defined[live_range.start].push_back(value);
    buffers_freed[live_range.end].push_back(value);
  }
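
  // For illustration (hypothetical values): if value A lives over [0, 3] and
  // value B over [2, 5], then A appears in buffers_defined[0] and
  // buffers_freed[3], while B appears in buffers_defined[2] and
  // buffers_freed[5]. Ignoring buffer sharing, the replay loop below then
  // allocates A at step 0, allocates B at step 2, frees A at step 3, and
  // frees B at step 5.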

  // All HloValues in an HloBuffer should be allocated at the same address.
  // This map tracks the first value that got allocated in each buffer.
  absl::flat_hash_map<const HloBuffer*, const HloValue*> first_allocated_value;

  VLOG(1) << "Program time: " << hlo_live_range->schedule_end_time();

  // Go through each step in the program and replay the buffer define and free
  // events.
  for (int64_t i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) {
    VLOG(1) << "Time step: " << i;

    for (const HloValue* value : buffers_defined[i]) {
      bool shared = false;
      VLOG(1) << "Start buffer: " << value->ToShortString();
      const HloBuffer* hlo_buffer =
          &alias_analysis.GetBufferContainingValue(*value);
      if (first_allocated_value.count(hlo_buffer) != 0) {
        // We've already assigned an address for another value in this
        // HloBuffer (an HloBuffer holds several aliased HloValues). All values
        // in a buffer should be assigned the same address. Find the one that's
        // already allocated and reuse its address.
        ShareBuffer(value, first_allocated_value[hlo_buffer],
                    value->instruction());
        VLOG(1) << "  ShareWith "
                << first_allocated_value[hlo_buffer]->ToShortString();
        continue;
      }
      if (options_.may_reuse_operand_buffers &&
          hlo_buffer->values().size() == 1) {
        // We don't support sharing an aliased buffer
        // (hlo_buffer->values().size() > 1) with its operand.
        for (const HloInstruction* operand : value->instruction()->operands()) {
          const HloValueSet operand_value_set =
              dataflow_analysis.GetValueSet(operand);
          for (const HloValue* operand_value : operand_value_set.values()) {
            const HloBuffer* operand_buffer =
                &alias_analysis.GetBufferContainingValue(*operand_value);
            if (operand_buffer->values().size() > 1) {
              continue;
            }
            auto it = buffer_live_ranges.find(operand_value);
            if (it == buffer_live_ranges.end()) {
              continue;
            }

            auto& operand_live_range = it->second;

            auto& user_live_range = buffer_live_ranges[value];

            // We can only share buffers that are about to be freed.
            if (operand_live_range.end != i) {
              continue;
            }

            if (IgnoreBuffer(operand_value)) {
              continue;
            }

            if (!absl::c_linear_search(buffers_freed[i], operand_value)) {
              // If the operand buffer is not being freed (either because it
              // has existing users, or it has been reused by other buffers),
              // don't consider the operand as a candidate for buffer sharing.
              continue;
            }

            // The instruction that defines the operand value can be different
            // from the actual operand; passing the defining instruction
            // directly into CanShareOperandBufferWithUser would trigger a
            // check failure. The first condition guards against that case.
            if (value->instruction()->IsUserOf(operand_value->instruction()) &&
                value->instruction()->opcode() != HloOpcode::kCopy &&
                dataflow_analysis.CanShareOperandBufferWithUser(
                    operand_value->instruction(), operand_value->index(),
                    value->instruction(), value->index())) {
              // Remove the operand buffer right before sharing (allocating)
              // the new one.
              Free(operand_value, operand_value->instruction());
              buffers_freed[i].erase(
                  std::remove(buffers_freed[i].begin(), buffers_freed[i].end(),
                              operand_value),
                  buffers_freed[i].end());
              ShareBuffer(value, operand_value, value->instruction());
              // The live range of the operand buffer is now extended to the
              // end of the current instruction.
              operand_live_range.end = user_live_range.end;
              VLOG(1) << "Sharing " << value->ToShortString() << " with "
                      << operand_value->ToShortString()
                      << ", size: " << size_fn_(*value);
              shared = true;
              break;
            }
          }
          if (shared) {
            break;
          }
        }
      }
      if (!shared) {
        Alloc(value, value->instruction());
        first_allocated_value[hlo_buffer] = value;
      }
    }

    if (!buffers_freed[i].empty()) {
      VLOG(1) << "Free Buffer: ";
    }
    for (const HloValue* value : buffers_freed[i]) {
      VLOG(1) << "  " << value->ToShortString();

      Free(value, value->instruction());
    }
  }
  return Status::OK();
}

HeapSimulator::HeapSimulator(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const HloSchedule* schedule,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation)
    : no_fragmentation_stats_(
          absl::make_unique<NoFragmentationStatsHeap<HloValue>>()),
      algorithm_(std::move(algorithm)),
      size_fn_(size_fn),
      options_(options),
      schedule_(schedule),
      memory_by_computation_(memory_by_computation) {
  debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}

HeapSimulator::~HeapSimulator() {}

bool HeapSimulator::IgnoreBuffer(const HloValue* buffer) const {
  // Buffers for constants are ignored unless the alloc_constants option is
  // set. Also ignore buffers that we're not meant to assign.
  //
  // TODO(b/32248867): For consistency, constants should get allocations.
  if (!options_.alloc_constants &&
      buffer->instruction()->opcode() == HloOpcode::kConstant) {
    return true;
  }
  return options_.buffers_to_assign != nullptr &&
         !options_.buffers_to_assign->contains(buffer);
}

// Alloc always calls the underlying heap algorithm.
void HeapSimulator::Alloc(const HloValue* buffer,
                          const HloInstruction* instruction) {
  CHECK(!allocated_buffers_.contains(buffer))
      << "Alloc called on allocated buffer: " << *buffer;
  CHECK(!freed_buffers_.contains(buffer))
      << "Alloc called on freed buffer: " << *buffer;

  allocated_buffers_.insert(buffer);
  const int64_t size = size_fn_(*buffer);
  algorithm_->Alloc(buffer, size);
  no_fragmentation_stats_->Alloc(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
                 nullptr);
}

// Free calls the underlying algorithm for non-shared buffers, and for shared
// buffers whose group liveness has expired. Shared group liveness is tracked
// by maintaining a refcount; the Free call on the last buffer in the group
// causes Free to be called on the underlying algorithm.
void HeapSimulator::Free(const HloValue* buffer,
                         const HloInstruction* instruction) {
  const int64_t size = size_fn_(*buffer);
  algorithm_->Free(buffer, size);
  no_fragmentation_stats_->Free(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
}

// ShareBuffer associates buffers with their SharedGroup in shared_buffers_.
// The 'buffer' must be a non-allocated, non-freed buffer, just like in calls
// to Alloc. The 'shared' buffer must be a previously allocated or shared
// buffer. Both 'buffer' and 'shared' will be associated with the same
// SharedGroup.
void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
                                const HloInstruction* instruction) {
  algorithm_->ShareWith(buffer, shared, size_fn_(*shared));
  no_fragmentation_stats_->ShareWith(buffer, shared, size_fn_(*shared));
  FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
                 shared);
}

HeapSimulator::Result<HloValue> HeapSimulator::Finish() {
  Result<HloValue> result = algorithm_->Finish();

  // Post-process the result to add chunks for shared buffers. An empty chunk
  // map means that either no buffers were allocated, or the heap was only
  // collecting statistics, e.g. NoFragmentationStatsHeap.
  size_t total_chunk_count = absl::c_accumulate(
      result.heap_results, static_cast<size_t>(0),
      [&](size_t lhs, const HeapResult<HloValue>& rhs) -> size_t {
        return lhs + rhs.chunk_map.size();
      });
  if (total_chunk_count != 0) {
    // If we were told to assign specific buffers, make sure we've assigned
    // exactly that many buffers.
    if (options_.buffers_to_assign != nullptr) {
      CHECK_EQ(options_.buffers_to_assign->size(), total_chunk_count);
    }
  }

  // Fragmentation is the difference between the actual and ideal sizes.
  const Result<HloValue> no_frag_result = no_fragmentation_stats_->Finish();
  result.fragmentation_size = result.heap_size - no_frag_result.heap_size;

  // Copy the debug trace we collected to the final result.
  result.debug_trace.Swap(&debug_trace_);

  return result;
}

void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
                                   const HloValue* buffer,
                                   const HloInstruction* instruction,
                                   const HloValue* share_with_canonical) {
  HeapSimulatorTrace::Event* event = debug_trace_.add_events();
  event->set_kind(kind);
  event->set_buffer_id(buffer->id());
  event->set_computation_name(instruction->parent()->name());
  event->set_instruction_name(instruction->name());
  if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
    CHECK(share_with_canonical != nullptr);
    event->set_share_with_canonical_id(share_with_canonical->id());
  } else {
    CHECK(share_with_canonical == nullptr);
  }
}

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::Alloc(const BufferType* buffer,
                                                 int64_t size) {
  current_heap_size_ += size;
  if (current_heap_size_ > max_heap_size_) {
    max_heap_size_ = current_heap_size_;
  }
}
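
// For illustration (hypothetical sizes): with the call sequence
// Alloc(a, 10), Alloc(b, 20), Free(a, 10), Alloc(c, 5), the current heap size
// goes 10 -> 30 -> 20 -> 25, so max_heap_size_ ends at 30. That running
// maximum, not the final size, is what Finish() reports.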

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::AccountForSubcomputationMemory(
    const HloInstruction* instruction, int64_t alloc_size_by_instruction,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  // We only count the memory usage of the largest subcomputation, instead of
  // adding them all, because subcomputations won't execute in parallel.
  int64_t max_subcomputation_bytes = 0;
  for (const auto* c : instruction->called_computations()) {
    auto it = memory_by_computation.find(c);
    if (it != memory_by_computation.end()) {
      int64_t subcomputation_bytes = it->second;
      if (subcomputation_bytes > max_subcomputation_bytes) {
        max_subcomputation_bytes = subcomputation_bytes;
      }
    }
  }
  if (max_subcomputation_bytes > 0 &&
      (instruction->opcode() == HloOpcode::kWhile ||
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kConditional)) {
    // The output buffer of while/call/conditional is always aliased with the
    // output buffer of the root instruction in the body. Don't double count.
    max_subcomputation_bytes -= alloc_size_by_instruction;
  }
  max_heap_size_ =
      std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
}
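
// Worked example (hypothetical numbers): suppose a kWhile instruction calls a
// body computation that needs 120 bytes and a condition that needs 40 bytes,
// its own output (alloc_size_by_instruction) is 16 bytes, and
// current_heap_size_ is 200. Only the largest callee counts (120), the
// aliased output is subtracted (120 - 16 = 104), and max_heap_size_ becomes
// max(max_heap_size_, 200 + 104) = max(max_heap_size_, 304).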

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::Free(const BufferType* buffer,
                                                int64_t size) {
  current_heap_size_ -= size;
}

template <typename BufferType>
HeapSimulator::Result<BufferType>
NoFragmentationStatsHeap<BufferType>::Finish() {
  // The result.chunk_map is empty, since we only collect stats, and don't
  // actually compute chunk assignments.
  Result result;
  result.heap_size = max_heap_size_;
  return result;
}

template <typename BufferType>
GlobalDecreasingSizeBestFitHeap<BufferType>::GlobalDecreasingSizeBestFitHeap(
    int64_t alignment, Type type)
    : alignment_(alignment) {
  if (type == kTemporal) {
    buffer_interval_compare_ = GetTemporalBufferIntervalCompare();
  } else {
    CHECK(type == kSpatial);
    buffer_interval_compare_ = GetSpatialBufferIntervalCompare();
  }
}

template <typename BufferType>
typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap<BufferType>::GetTemporalBufferIntervalCompare()
    const {
  return [&](const BufferInterval& x, const BufferInterval& y) {
    int64_t x_end = x.end;
    for (auto colocation : GetTransitiveColocations(x)) {
      x_end = std::max(x_end, buffer_intervals_.at(colocation).end);
    }

    int64_t y_end = y.end;
    for (auto colocation : GetTransitiveColocations(y)) {
      y_end = std::max(y_end, buffer_intervals_.at(colocation).end);
    }

    if (x_end - x.start != y_end - y.start) {
      return x_end - x.start > y_end - y.start;
    }

    if (x.size != y.size) {
      return x.size > y.size;
    }
    return *x.buffer < *y.buffer;
  };
}

template <typename BufferType>
/*static*/ typename GlobalDecreasingSizeBestFitHeap<
    BufferType>::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap<BufferType>::GetSpatialBufferIntervalCompare() {
  return [&](const BufferInterval& x, const BufferInterval& y) {
    if (x.size != y.size) {
      return x.size > y.size;
    }
    if (x.end - x.start != y.end - y.start) {
      return x.end - x.start > y.end - y.start;
    }
    return *x.buffer < *y.buffer;
  };
}
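
// For intuition (hypothetical values): the temporal compare orders long-lived
// intervals first (extending each interval's end over its colocations) and
// breaks ties by size, while the spatial compare orders large intervals first
// and breaks ties by live-range length. Given a{size=8, start=0, end=10} and
// b{size=64, start=3, end=4}, the temporal order is (a, b) and the spatial
// order is (b, a). Both fall back to the buffer ordering, so the sort is
// deterministic.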

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::Alloc(
    const BufferType* buffer, int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }

  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, true});
  DCHECK(emplace_result.second);
  ++current_time_;
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::ShareWith(
    const BufferType* buffer, const BufferType* share_with, int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }
  DCHECK_NE(buffer_intervals_.count(share_with), 0);
  buffer_intervals_[share_with].colocations.push_back(buffer);
  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, false});
  DCHECK(emplace_result.second);
  ++current_time_;
}

template <typename BufferType>
absl::flat_hash_set<const BufferType*>
GlobalDecreasingSizeBestFitHeap<BufferType>::GetTransitiveColocations(
    const BufferInterval& interval) const {
  absl::flat_hash_set<const BufferType*> result;
  std::vector<const BufferInterval*> worklist = {&interval};
  while (!worklist.empty()) {
    const BufferInterval* item = worklist.back();
    worklist.pop_back();
    for (const BufferType* buffer_colocated : item->colocations) {
      result.insert(buffer_colocated);
      worklist.push_back(&buffer_intervals_.at(buffer_colocated));
    }
  }

  return result;
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::Free(const BufferType* buffer,
                                                       int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    return;
  }
  BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
  DCHECK_EQ(buffer_interval.buffer, buffer);
  DCHECK_EQ(buffer_interval.size, size);
  DCHECK_EQ(buffer_interval.end, -1);
  if (buffer_interval.end != -1) {
    return;
  }
  buffer_interval.end = current_time_;
  ++current_time_;
}

using Chunk = HeapSimulator::Chunk;

void BufferIntervalTree::Add(int64_t start, int64_t end, const Chunk& chunk) {
  node_storage_.emplace_back(BufferIntervalTreeNode{
      start, end, end, chunk,
      /*left=*/nullptr, /*right=*/nullptr, /*parent=*/nullptr});
  if (root_ == nullptr) {
    root_ = &node_storage_.back();
    // This is the root.
    return;
  }

  BufferIntervalTreeNode* parent = root_;
  while (true) {
    parent->subtree_end = std::max(parent->subtree_end, end);
    if (parent->start > start) {
      if (parent->left == nullptr) {
        parent->left = &node_storage_.back();
        node_storage_.back().parent = parent;
        return;
      }
      parent = parent->left;
    } else {
      if (parent->right == nullptr) {
        parent->right = &node_storage_.back();
        node_storage_.back().parent = parent;
        return;
      }
      parent = parent->right;
    }
  }
}
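
// Illustrative example (hypothetical intervals): adding [1, 45], [21, 45],
// and [22, 40] builds a binary search tree keyed on `start`, where each
// node's `subtree_end` holds the maximum `end` in its subtree (45 at the root
// and its right child, 40 at the leaf). A later query such as
// ChunksOverlappingInTime(50, 60) can then skip any subtree whose subtree_end
// is below the query start.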

bool BufferIntervalTree::Remove(int64_t start, int64_t end,
                                const Chunk& chunk) {
  BufferIntervalTreeNode* to_delete = root_;
  while (to_delete != nullptr) {
    if (to_delete->start == start && to_delete->end == end &&
        to_delete->chunk.offset == chunk.offset) {
      break;
    }
    if (start < to_delete->start) {
      to_delete = to_delete->left;
    } else {
      to_delete = to_delete->right;
    }
  }
  if (to_delete == nullptr) {
    // Nothing to delete.
    return false;
  }
  // Found the node to be deleted; enter the deletion sequence.

  // Recursively traverse the parents of a node and fix up the `subtree_end`
  // invariant at each ancestor. A recursive lambda needs an explicit
  // std::function declaration.
  std::function<void(BufferIntervalTreeNode*)> fix_up =
      [&](BufferIntervalTreeNode* node) {
        if (node == nullptr) {
          return;
        }
        node->subtree_end = node->end;
        if (node->left) {
          node->subtree_end =
              std::max(node->subtree_end, node->left->subtree_end);
        }
        if (node->right) {
          node->subtree_end =
              std::max(node->subtree_end, node->right->subtree_end);
        }
        // Recursively go up.
        fix_up(node->parent);
      };

  if (to_delete->right == nullptr) {
    // to_delete has no right child; simply move up the left child of
    // to_delete, if any.
    //
    // Turn:
    //      parent
    //     /
    //  to_delete
    //   /      \
    // left    nullptr
    //
    // Into:
    //      parent
    //     /
    //   left
    if (root_ == to_delete) {
      // Deleting the root simply means resetting the root.
      root_ = to_delete->left;
      return true;
    }

    if (to_delete == to_delete->parent->left) {
      // to_delete is the left child of its parent.
      to_delete->parent->left = to_delete->left;
    }
    if (to_delete == to_delete->parent->right) {
      // to_delete is the right child of its parent.
      to_delete->parent->right = to_delete->left;
    }
    // Rewire the parent to the node being moved up.
    if (to_delete->left) {
      to_delete->left->parent = to_delete->parent;
    }
    // Fix up starting from the subroot.
    fix_up(to_delete);
  } else {
    // 1. Find the left-most node of the right subtree and promote it to the
    //    position of to_delete.
    BufferIntervalTreeNode* to_promote = to_delete->right;
    while (to_promote->left != nullptr) {
      // Go to the left-most subtree.
      to_promote = to_promote->left;
    }

    // 2. Copy the content of `to_promote` to `to_delete`.
    to_delete->start = to_promote->start;
    to_delete->end = to_promote->end;
    // This may be incorrect, but it is fixed up later in the `fix_up`
    // procedure.
    to_delete->subtree_end = to_promote->subtree_end;
    to_delete->chunk = to_promote->chunk;
    auto to_promote_parent = to_promote->parent;
    // 3. Move the right child of `to_promote` up, if there is one.
    //
    // Turn:
    //
    //  to_delete
    //      \
    //   to_promote_parent
    //     /
    //  to_promote
    //      \
    //      right
    //
    // into:
    //
    //  to_promote
    //      \
    //   to_promote_parent
    //     /
    //   right
    if (to_promote_parent->left == to_promote) {
      to_promote_parent->left = to_promote->right;
    } else {
      to_promote_parent->right = to_promote->right;
    }
    if (to_promote->right) {
      // Set the correct parent.
      to_promote->right->parent = to_promote_parent;
    }
    // 4. Recursively fix up the `subtree_end` invariant starting from
    //    `to_promote_parent`.
    fix_up(to_promote_parent);
  }
  // Don't free the entry in node_storage_ until we free the entire tree.
  return true;
}

std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
    int64_t start, int64_t end) const {
  std::vector<Chunk> result;
  if (root_ == nullptr) {
    return result;
  }
  std::vector<const BufferIntervalTreeNode*> visiting_stack;
  visiting_stack.push_back(root_);
  while (!visiting_stack.empty()) {
    const BufferIntervalTreeNode* top = visiting_stack.back();
    visiting_stack.pop_back();
    if (start > top->subtree_end) {
      continue;
    }
    if (top->left != nullptr) {
      visiting_stack.push_back(top->left);
    }
    if (top->start <= end && top->end >= start) {
      result.push_back(top->chunk);
    }
    if (end < top->start) {
      continue;
    }
    if (top->right != nullptr) {
      visiting_stack.push_back(top->right);
    }
  }
  return result;
}
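
// For illustration (hypothetical intervals): the query is an inclusive
// overlap test, so with nodes covering [0, 2], [3, 7], and [5, 9],
// ChunksOverlappingInTime(2, 5) returns the chunks of all three nodes, while
// ChunksOverlappingInTime(8, 10) returns only the chunk of [5, 9]. The
// subtree_end field lets the traversal skip whole subtrees whose intervals
// end before the query starts.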

template <typename BufferType>
HeapSimulator::Result<BufferType>
GlobalDecreasingSizeBestFitHeap<BufferType>::Finish() {
  std::vector<BufferInterval> sorted_buffer_intervals =
      GetSortedBufferIntervals();

  for (auto& buffer_interval : sorted_buffer_intervals) {
    if (!buffer_interval.need_allocation) {
      continue;
    }

    ChunkCandidate chunk_candidate = FindChunkCandidate(buffer_interval);
    // This implementation of the heap algorithm does not have a notion of
    // maximum heap size, so it just commits.
    CommitChunk(buffer_interval, chunk_candidate);
  }
  VLOG(1) << "result heap_size: " << result_.heap_size;
  Result result;
  result.heap_size = result_.heap_size;
  result.heap_results.emplace_back(result_);
  return result;
}

template <typename BufferType>
std::vector<
    typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval>
GlobalDecreasingSizeBestFitHeap<BufferType>::GetSortedBufferIntervals() const {
  std::vector<BufferInterval> sorted_buffer_intervals;
  for (auto& entry : buffer_intervals_) {
    sorted_buffer_intervals.push_back(entry.second);
  }
  absl::c_sort(sorted_buffer_intervals, buffer_interval_compare_);

  return sorted_buffer_intervals;
}

template <typename BufferType>
typename GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
GlobalDecreasingSizeBestFitHeap<BufferType>::FindChunkCandidate(
    const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
    int64_t preferred_offset) const {
  VLOG(1) << "Finding chunks for buffer: "
          << buffer_interval.buffer->ToString();
  VLOG(1) << "Size " << buffer_interval.size << ", start "
          << buffer_interval.start << ", end " << buffer_interval.end;
  auto chunks_overlapping_in_time = interval_tree_.ChunksOverlappingInTime(
      buffer_interval.start, buffer_interval.end);
  // Get all colocated buffers and gather all interfering chunks.
  //
  // Imagine that we've already allocated three chunks: a, b and c, and now we
  // want to allocate d. Since e is colocated with d, we have to allocate
  // chunks for them together at the same address. To do this, we first gather
  // all chunks that overlap with d and e on the time dimension; in this case
  // the overlapping chunks are a and b (c doesn't overlap with either d or
  // e). We then create a new chunk that doesn't overlap with a and b on the
  // space dimension.
  //
  // space
  //   ^
  //   |+--d---+      +---e---+
  //   |
  //   |+---+  +---------------+  +-------+
  //   ||   |  |               |  |       |
  //   ||   |  |               |  |       |
  //   |+-a-+  +-------b-------+  +---c---+
  //   ----------------------------------------> time
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    auto colocation_interval = buffer_intervals_.at(colocation);
    auto colocation_overlapping = interval_tree_.ChunksOverlappingInTime(
        colocation_interval.start, colocation_interval.end);
    VLOG(1) << "  Alias size " << colocation_interval.size << ", start "
            << colocation_interval.start << ", end " << colocation_interval.end
            << " " << colocation_interval.buffer->ToString();
    chunks_overlapping_in_time.insert(chunks_overlapping_in_time.end(),
                                      colocation_overlapping.begin(),
                                      colocation_overlapping.end());
  }
  absl::c_sort(chunks_overlapping_in_time, [](const Chunk& x, const Chunk& y) {
    return x.offset < y.offset;
  });

  // Find the smallest free chunk that can hold this buffer.
  ChunkCandidate chunk_candidate{Chunk{-1, INT64_MAX}, result_.heap_size};
  Chunk& min_fit_chunk = chunk_candidate.chunk;
  int64_t preferred_chunk_end = preferred_offset + buffer_interval.size;
  auto use_free_chunk_if_smaller = [&](int64_t free_offset, int64_t free_size) {
    if (free_size < buffer_interval.size) {
      return;
    }

    // If a preferred offset is provided and fits in this free chunk, pick
    // that offset.
    if (free_offset <= preferred_offset &&
        free_offset + free_size >= preferred_chunk_end) {
      min_fit_chunk = {preferred_offset, buffer_interval.size};
    } else if (free_offset + free_size == result_.heap_size &&
               free_offset <= preferred_offset) {
      // If the free chunk is at the very end of the heap and the preferred
      // offset lies within it, pick the preferred offset and grow the heap.
      min_fit_chunk = {preferred_offset, buffer_interval.size};
      chunk_candidate.heap_size = preferred_chunk_end;
    }

    // Pick the min-fit chunk only if we didn't have a preferred offset, or a
    // chunk at the preferred offset hasn't been found.
    if ((preferred_offset < 0 || min_fit_chunk.offset != preferred_offset) &&
        free_size < min_fit_chunk.size) {
      min_fit_chunk = {free_offset, free_size};
    }
  };

  int64_t offset = 0;
  for (auto& chunk : chunks_overlapping_in_time) {
    if (offset < chunk.offset) {
      use_free_chunk_if_smaller(offset, chunk.offset - offset);
    }
    offset = std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_));
  }
  use_free_chunk_if_smaller(offset, result_.heap_size - offset);
  // When a preferred offset is provided and it is at or beyond the current
  // heap size, simply use the preferred offset and grow the heap to fit it.
  if (result_.heap_size <= preferred_offset) {
    chunk_candidate.heap_size = preferred_chunk_end;
    min_fit_chunk = {preferred_offset, buffer_interval.size};
  }

  if (min_fit_chunk.offset == -1) {
    // No fit was found; increase the heap size to fit in the last free chunk.
    chunk_candidate.heap_size = offset + buffer_interval.size;
    min_fit_chunk = {offset, buffer_interval.size};
  }

  min_fit_chunk.size = buffer_interval.size;
  return chunk_candidate;
}
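
// Worked example (hypothetical values): suppose the heap is currently 96
// bytes, alignment_ is 1, and the chunks overlapping the candidate's live
// range are {offset=0, size=16} and {offset=48, size=32}. The scan visits the
// gaps [16, 48) and [80, 96). For a 24-byte buffer with no preferred offset,
// the min-fit choice is offset 16; for a 40-byte buffer nothing fits, so the
// candidate is placed at offset 80 and the heap grows to 120 bytes.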

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::CommitChunk(
    const GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval&
        buffer_interval,
    GlobalDecreasingSizeBestFitHeap<BufferType>::ChunkCandidate
        chunk_candidate) {
  // Update the maximum heap size according to the one determined by the chunk
  // candidate.
  result_.heap_size = chunk_candidate.heap_size;
  interval_tree_.Add(buffer_interval.start, buffer_interval.end,
                     chunk_candidate.chunk);
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    AddToChunkMap(colocation, chunk_candidate.chunk);
    auto colocation_interval = buffer_intervals_[colocation];
    interval_tree_.Add(colocation_interval.start, colocation_interval.end,
                       chunk_candidate.chunk);
  }

  AddToChunkMap(buffer_interval.buffer, chunk_candidate.chunk);
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::AddToChunkMap(
    const BufferType* buffer, Chunk chunk) {
  const auto emplace_result = result_.chunk_map.emplace(buffer, chunk);
  DCHECK(emplace_result.second);
}

HeapSimulator::Result<HloValue>
ConstrainedGlobalDecreasingSizeBestFitHeap::Finish() {
  std::vector<BufferInterval> sorted_buffer_vec = GetSortedBufferIntervals();
  // Convert into a std::list so that erase() is O(1).
  std::list<BufferInterval> sorted_buffer_intervals(sorted_buffer_vec.begin(),
                                                    sorted_buffer_vec.end());

  // Use do-while here, because we need to create at least one heap in
  // `multi_heap_result` even if `sorted_buffer_intervals` is empty.
  Result multi_heap_result;
  do {
    // Place as many buffers as possible into the heap currently being built.
    for (auto it = sorted_buffer_intervals.begin();
         it != sorted_buffer_intervals.end();) {
      BufferInterval buffer_interval = *it;
      if (!buffer_interval.need_allocation) {
        it = sorted_buffer_intervals.erase(it);
        continue;
      }
      if (buffer_interval.size > size_limit_per_heap_) {
        LOG(WARNING) << "Alloc buffer size " << buffer_interval.size
                     << " larger than the per-heap size limit "
                     << size_limit_per_heap_;
      }

      ChunkCandidate chunk_candidate = FindChunkCandidate(buffer_interval);
      if (chunk_candidate.heap_size <= size_limit_per_heap_ ||
          // Commit the chunk as long as the heap is empty. We do this because
          // we want the size constraint to be soft, meaning that results are
          // successfully generated even if there are some buffer sizes larger
          // than the given constraint size.
          result_.heap_size == 0) {
        CommitChunk(buffer_interval, chunk_candidate);
        it = sorted_buffer_intervals.erase(it);
        continue;
      }

      ++it;
    }
    // Collect the result from the heap just built and reset the heap state
    // for the next one.
    multi_heap_result.heap_size += result_.heap_size;
    multi_heap_result.heap_results.push_back(std::move(result_));
    result_ = {};
    interval_tree_ = {};
  } while (!sorted_buffer_intervals.empty());

  VLOG(1) << "Number of heaps produced = "
          << multi_heap_result.heap_results.size();
  return multi_heap_result;
}
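
// For illustration (hypothetical limit): with size_limit_per_heap_ = 1024,
// buffers whose placement would push the current heap past 1024 bytes are
// deferred to a later iteration of the do-while loop, and each iteration
// emits one HeapResult. A single buffer larger than 1024 bytes is still
// placed eventually (with a warning), because any buffer is committed while
// the current heap is empty; that is what makes the limit a soft constraint.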

template <typename BufferType>
HeapSimulator::Result<BufferType>
ChooseBestHeapAlgorithm<BufferType>::Finish() {
  DCHECK(!algorithms_.empty());
  std::vector<Result> results(algorithms_.size());
  int64_t min_size = INT64_MAX;
  int min_size_index = -1;
  for (int i = 0; i < algorithms_.size(); ++i) {
    results[i] = algorithms_[i]->Finish();
    if (results[i].heap_size < min_size) {
      min_size = results[i].heap_size;
      min_size_index = i;
    }
  }

  DCHECK_GE(min_size_index, 0);
  return results[min_size_index];
}

template class GlobalDecreasingSizeBestFitHeap<HloValue>;
template class GlobalDecreasingSizeBestFitHeap<
    MemorySpaceAssignmentRepacker::AllocationBlock>;
template class ChooseBestHeapAlgorithm<HloValue>;

}  // namespace xla