/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/heap_simulator.h"

#include <algorithm>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

using absl::flat_hash_map;
using absl::flat_hash_set;

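// Two chunks overlap iff their half-open ranges [offset, offset + size)
// intersect; zero-sized chunks are disallowed by the CHECKs below.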
bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
  CHECK_NE(size, 0);
  CHECK_NE(other_chunk.size, 0);
  return offset < other_chunk.chunk_end() && other_chunk.offset < chunk_end();
}

/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
    const HloSchedule& schedule,
    const LogicalBuffer::SizeFunction& size_function) {
  if (schedule.empty()) {
    return 0;
  }
  const HloModule* module = schedule.module();

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  // The absolute minimum memory required for a given sequence of instructions
  // is determined by the sequence of Alloc and Free calls on a simulated heap,
  // ignoring fragmentation. We run the heap simulation on the whole module,
  // rather than summing each computation, since it gives us a better lower
  // bound, by minimizing the liveness of sub-computations.
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result result,
      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
                         schedule, *alias_analysis, size_function));
  return result.heap_size;
}

/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result result,
      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
                         computation, sequence, alias_analysis, size_function,
                         HeapSimulator::Options(), memory_by_computation));
  return result.heap_size;
}

StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const HloSchedule* schedule) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result result,
      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
                         computation, sequence, alias_analysis, size_function,
                         schedule, HeapSimulator::Options()));
  return result.heap_size;
}

/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
    const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
  const HloComputation* entry_computation = module.entry_computation();
  const HloInstructionSequence& instruction_sequence =
      schedule.sequence(entry_computation);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(schedule, alias_analysis, entry_computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
                                         instruction_sequence, alias_analysis,
                                         hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/nullptr, memory_by_computation);
  HloSchedule schedule(computation.parent());
  schedule.set_sequence(&computation, instruction_sequence);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, alias_analysis, &computation,
                                        /*module_scoped_analysis=*/false));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
    const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/schedule, nullptr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(*schedule, alias_analysis, &computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}

// Runs a heap simulation for the given 'computation', assuming the given
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis, HloLiveRange* hlo_live_range) {
  XLA_VLOG_LINES(1, computation.parent()->ToString());
  XLA_VLOG_LINES(2, computation.ToString());

  VLOG(1) << hlo_live_range->ToString();

  HloDataflowAnalysis& dataflow_analysis = alias_analysis.dataflow_analysis();

  // Record the buffer define/free events for each time step. We free all
  // remaining buffers (entry parameters, etc.) after the program has finished
  // running, so we set the vector sizes to schedule_end_time() + 1.
  std::vector<std::vector<const HloValue*>> buffers_defined(
      hlo_live_range->schedule_end_time() + 1);
  std::vector<std::vector<const HloValue*>> buffers_freed(
      hlo_live_range->schedule_end_time() + 1);
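  // buffers_defined[t] holds the values whose live range starts at time step
  // t; buffers_freed[t] holds the values whose live range ends at t (see the
  // loop that populates them below).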

  // values_to_assign tracks the HloValues that we need to assign a buffer to.
  // Note that we only need to assign a buffer to a value when both of the
  // following conditions are met:
  //
  // - The user specifically asks us to assign buffers to a set of HloValues,
  //   and the value is in that set. If the user doesn't provide such a set, we
  //   assign buffers to all HloValues by default.
  //
  // - If the instruction is in a nested call of the current computation, we
  //   only assign a buffer when doing global (whole-module) heap simulation.
  std::vector<const HloValue*> values_to_assign;
  values_to_assign.reserve(dataflow_analysis.values().size());

  for (const HloValue* value : dataflow_analysis.values()) {
    // Ignore buffers that are not tracked.
    if (hlo_live_range->instruction_schedule().count(
            value->defining_instruction()) == 0) {
      continue;
    }
    if (IgnoreBuffer(value)) {
      continue;
    }
    values_to_assign.push_back(value);
  }

  auto& buffer_live_ranges = hlo_live_range->buffer_live_ranges();

  absl::c_sort(values_to_assign,
               [&](const HloValue* value1, const HloValue* value2) {
                 const auto& live_range1 = buffer_live_ranges.at(value1);
                 const auto& live_range2 = buffer_live_ranges.at(value2);
                 return std::forward_as_tuple(live_range1.start,
                                              live_range1.end, value1->id()) <
                        std::forward_as_tuple(live_range2.start,
                                              live_range2.end, value2->id());
               });

  // For each value that we need to assign a buffer to, add the define and free
  // events.
  for (const HloValue* value : values_to_assign) {
    auto live_range = buffer_live_ranges.at(value);
    buffers_defined[live_range.start].push_back(value);
    buffers_freed[live_range.end].push_back(value);
  }

  // All HloValues in an HloBuffer should be allocated at the same address.
  // This map tracks the first value that got allocated in each buffer.
  absl::flat_hash_map<const HloBuffer*, const HloValue*> first_allocated_value;

  VLOG(1) << "Program time: " << hlo_live_range->schedule_end_time();

  // Go through each step in the program and replay the buffer define and free
  // events.
  for (int64 i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) {
    VLOG(1) << "Time step: " << i;

    for (const HloValue* value : buffers_defined[i]) {
      bool shared = false;
      VLOG(1) << "Start buffer: " << value->ToShortString();
      const HloBuffer* hlo_buffer =
          &alias_analysis.GetBufferContainingValue(*value);
      if (first_allocated_value.count(hlo_buffer) != 0) {
        // We've already assigned an address to another value in this HloBuffer
        // (an HloBuffer holds several aliased HloValues). All values in a
        // buffer should be assigned the same address. Find the one that's
        // already allocated and reuse its address.
        ShareBuffer(value, first_allocated_value[hlo_buffer],
                    value->instruction());
        VLOG(1) << "  ShareWith "
                << first_allocated_value[hlo_buffer]->ToShortString();
        continue;
      }
      if (options_.may_reuse_operand_buffers &&
          hlo_buffer->values().size() == 1) {
        // We don't support sharing an aliased buffer
        // (hlo_buffer->values().size() > 1) with its operand.
        for (const HloInstruction* operand : value->instruction()->operands()) {
          const HloValueSet operand_value_set =
              dataflow_analysis.GetValueSet(operand);
          for (const HloValue* operand_value : operand_value_set.values()) {
            const HloBuffer* operand_buffer =
                &alias_analysis.GetBufferContainingValue(*operand_value);
            if (operand_buffer->values().size() > 1) {
              continue;
            }
            auto it = buffer_live_ranges.find(operand_value);
            if (it == buffer_live_ranges.end()) {
              continue;
            }

            auto& operand_live_range = it->second;

            auto& user_live_range = buffer_live_ranges[value];

            // Can only share buffers that are about to be freed.
            if (operand_live_range.end != i) {
              continue;
            }

            if (IgnoreBuffer(operand_value)) {
              continue;
            }

            if (!absl::c_linear_search(buffers_freed[i], operand_value)) {
              // If the operand buffer is not being freed (either because it
              // has existing users, or it has been reused by other buffers),
              // don't consider the operand as a candidate for buffer sharing.
              continue;
            }

            // The instruction that defines the operand value can be different
            // from the actual operand; passing the defining instruction
            // directly into CanShareOperandBufferWithUser can trigger a check
            // failure. The first condition guards against that case.
            if (value->instruction()->IsUserOf(operand_value->instruction()) &&
                value->instruction()->opcode() != HloOpcode::kCopy &&
                dataflow_analysis.CanShareOperandBufferWithUser(
                    operand_value->instruction(), operand_value->index(),
                    value->instruction(), value->index())) {
              // Remove the operand buffer right before sharing (allocating) a
              // new one.
              Free(operand_value, operand_value->instruction());
              buffers_freed[i].erase(
                  std::remove(buffers_freed[i].begin(), buffers_freed[i].end(),
                              operand_value),
                  buffers_freed[i].end());
              ShareBuffer(value, operand_value, value->instruction());
              // The live range of the operand buffer is now extended to the
              // end of the current instruction.
              operand_live_range.end = user_live_range.end;
              VLOG(1) << "Sharing " << value->ToShortString() << " with "
                      << operand_value->ToShortString()
                      << ", size: " << size_fn_(*value);
              shared = true;
              break;
            }
          }
          if (shared) {
            break;
          }
        }
      }
      if (!shared) {
        Alloc(value, value->instruction());
        first_allocated_value[hlo_buffer] = value;
      }
    }

    if (!buffers_freed[i].empty()) {
      VLOG(1) << "Free Buffer: ";
    }
    for (const HloValue* value : buffers_freed[i]) {
      VLOG(1) << "  " << value->ToShortString();

      Free(value, value->instruction());
    }
  }
  return Status::OK();
}

HeapSimulator::HeapSimulator(
    std::unique_ptr<HeapAlgorithm> algorithm,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const HloSchedule* schedule,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation)
    : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
      algorithm_(std::move(algorithm)),
      size_fn_(size_fn),
      options_(options),
      schedule_(schedule),
      memory_by_computation_(memory_by_computation) {
  debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}

HeapSimulator::~HeapSimulator() {}

bool HeapSimulator::IgnoreBuffer(const HloValue* buffer) const {
  // Buffers for constants are ignored unless the alloc_constants option is
  // set. Also ignore buffers that we're not meant to assign.
  //
  // TODO(b/32248867): For consistency, constants should get allocations.
  if (!options_.alloc_constants &&
      buffer->instruction()->opcode() == HloOpcode::kConstant) {
    return true;
  }
  return options_.buffers_to_assign != nullptr &&
         !options_.buffers_to_assign->contains(buffer);
}

// Alloc always calls the underlying heap algorithm.
void HeapSimulator::Alloc(const HloValue* buffer,
                          const HloInstruction* instruction) {
  CHECK(!allocated_buffers_.contains(buffer))
      << "Alloc called on allocated buffer: " << *buffer;
  CHECK(!freed_buffers_.contains(buffer))
      << "Alloc called on freed buffer: " << *buffer;

  allocated_buffers_.insert(buffer);
  const int64 size = size_fn_(*buffer);
  algorithm_->Alloc(buffer, size);
  no_fragmentation_stats_->Alloc(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
                 nullptr);
}

// Free calls the underlying algorithm for non-shared buffers, and for shared
// buffers whose group liveness has expired. Shared group liveness is tracked
// by maintaining a refcount; the Free call on the last buffer in the group
// causes Free to be called on the underlying algorithm.
void HeapSimulator::Free(const HloValue* buffer,
                         const HloInstruction* instruction) {
  const int64 size = size_fn_(*buffer);
  algorithm_->Free(buffer, size);
  no_fragmentation_stats_->Free(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
}

// ShareBuffer associates buffers with their SharedGroup in shared_buffers_.
// The 'buffer' must be a non-allocated, non-freed buffer, just like in calls
// to Alloc. The 'shared' buffer must be a previously allocated or shared
// buffer. Both 'buffer' and 'shared' will be associated with the same
// SharedGroup.
void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
                                const HloInstruction* instruction) {
  algorithm_->ShareWith(buffer, shared, size_fn_(*shared));
  no_fragmentation_stats_->ShareWith(buffer, shared, size_fn_(*shared));
  FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
                 shared);
}

HeapSimulator::Result HeapSimulator::Finish() {
  Result result = algorithm_->Finish();

  // Post-process the result to add chunks for shared buffers. An empty chunk
  // map means that either no buffers were allocated, or the heap was only
  // collecting statistics, e.g. NoFragmentationStatsHeap.
  if (!result.chunk_map.empty()) {
    // If we were told to assign specific buffers, make sure we've assigned
    // exactly that many buffers.
    if (options_.buffers_to_assign != nullptr) {
      CHECK_EQ(options_.buffers_to_assign->size(), result.chunk_map.size());
    }
  }

  // Fragmentation is the difference between the actual and ideal sizes.
  const Result no_frag_result = no_fragmentation_stats_->Finish();
  result.fragmentation_size = result.heap_size - no_frag_result.heap_size;

  // Copy the debug trace we collected to the final result.
  result.debug_trace.Swap(&debug_trace_);

  return result;
}

void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
                                   const HloValue* buffer,
                                   const HloInstruction* instruction,
                                   const HloValue* share_with_canonical) {
  HeapSimulatorTrace::Event* event = debug_trace_.add_events();
  event->set_kind(kind);
  event->set_buffer_id(buffer->id());
  event->set_computation_name(instruction->parent()->name());
  event->set_instruction_name(instruction->name());
  if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
    CHECK(share_with_canonical != nullptr);
    event->set_share_with_canonical_id(share_with_canonical->id());
  } else {
    CHECK(share_with_canonical == nullptr);
  }
}

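// NoFragmentationStatsHeap only tracks the running and peak heap sizes; it
// never assigns chunks to buffers, so its Finish() result has an empty
// chunk_map.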
void NoFragmentationStatsHeap::Alloc(const HloValue* buffer, int64 size) {
  current_heap_size_ += size;
  if (current_heap_size_ > max_heap_size_) {
    max_heap_size_ = current_heap_size_;
  }
}

void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
    const HloInstruction* instruction, int64 alloc_size_by_instruction,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  // We only count the memory usage of the largest subcomputation, instead of
  // adding them all, because subcomputations won't execute in parallel.
  int64 max_subcomputation_bytes = 0;
  for (const auto* c : instruction->called_computations()) {
    auto it = memory_by_computation.find(c);
    if (it != memory_by_computation.end()) {
      int64 subcomputation_bytes = it->second;
      if (subcomputation_bytes > max_subcomputation_bytes) {
        max_subcomputation_bytes = subcomputation_bytes;
      }
    }
  }
  if (max_subcomputation_bytes > 0 &&
      (instruction->opcode() == HloOpcode::kWhile ||
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kConditional)) {
    // The output buffer of while/call/conditional is always aliased with the
    // output buffer of the root instruction in the body. Don't double count.
    max_subcomputation_bytes -= alloc_size_by_instruction;
  }
  max_heap_size_ =
      std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
}

void NoFragmentationStatsHeap::Free(const HloValue* buffer, int64 size) {
  current_heap_size_ -= size;
}

HeapSimulator::Result NoFragmentationStatsHeap::Finish() {
  // The result.chunk_map is empty, since we only collect stats, and don't
  // actually compute chunk assignments.
  Result result;
  result.heap_size = max_heap_size_;
  return result;
}

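// The constructor chooses how buffer intervals will be ordered in Finish():
// kTemporal sorts primarily by decreasing live-range length (including
// colocations), kSpatial primarily by decreasing buffer size.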
GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap(
    int64 alignment, Type type)
    : alignment_(alignment) {
  if (type == kTemporal) {
    buffer_interval_compare_ = GetTemporalBufferIntervalCompare();
  } else {
    CHECK(type == kSpatial);
    buffer_interval_compare_ = GetSpatialBufferIntervalCompare();
  }
}

GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() const {
  return [&](const BufferInterval& x, const BufferInterval& y) {
    int64 x_end = x.end;
    for (auto colocation : GetTransitiveColocations(x)) {
      x_end = std::max(x_end, buffer_intervals_.at(colocation).end);
    }

    int64 y_end = y.end;
    for (auto colocation : GetTransitiveColocations(y)) {
      y_end = std::max(y_end, buffer_intervals_.at(colocation).end);
    }

    if (x_end - x.start != y_end - y.start) {
      return x_end - x.start > y_end - y.start;
    }

    if (x.size != y.size) {
      return x.size > y.size;
    }
    return x.buffer->id() < y.buffer->id();
  };
}

/*static*/ GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() {
  return [&](const BufferInterval& x, const BufferInterval& y) {
    if (x.size != y.size) {
      return x.size > y.size;
    }
    if (x.end - x.start != y.end - y.start) {
      return x.end - x.start > y.end - y.start;
    }
    return x.buffer->id() < y.buffer->id();
  };
}

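// Alloc and ShareWith only record the buffer's live interval; actual offsets
// are assigned later, when Finish() runs the packing algorithm.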
void GlobalDecreasingSizeBestFitHeap::Alloc(const HloValue* buffer,
                                            int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }

  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, true});
  DCHECK(emplace_result.second);
  ++current_time_;
}

void GlobalDecreasingSizeBestFitHeap::ShareWith(const HloValue* buffer,
                                                const HloValue* share_with,
                                                int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }
  DCHECK_NE(buffer_intervals_.count(share_with), 0);
  buffer_intervals_[share_with].colocations.push_back(buffer);
  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, false});
  DCHECK(emplace_result.second);
  ++current_time_;
}

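// Transitively collects every buffer colocated with 'interval' (directly or
// through a chain of ShareWith calls), using an explicit worklist.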
absl::flat_hash_set<const HloValue*>
GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations(
    const BufferInterval& interval) const {
  absl::flat_hash_set<const HloValue*> result;
  std::vector<const BufferInterval*> worklist = {&interval};
  while (!worklist.empty()) {
    const BufferInterval* item = worklist.back();
    worklist.pop_back();
    for (const HloValue* buffer_colocated : item->colocations) {
      result.insert(buffer_colocated);
      worklist.push_back(&buffer_intervals_.at(buffer_colocated));
    }
  }

  return result;
}

void GlobalDecreasingSizeBestFitHeap::Free(const HloValue* buffer, int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    return;
  }
  BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
  DCHECK_EQ(buffer_interval.buffer, buffer);
  DCHECK_EQ(buffer_interval.size, size);
  DCHECK_EQ(buffer_interval.end, -1);
  if (buffer_interval.end != -1) {
    return;
  }
  buffer_interval.end = current_time_;
  ++current_time_;
}

using Chunk = HeapSimulator::Chunk;

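// Adds an interval to an (unbalanced) binary search tree keyed on start time.
// Each node's subtree_end is kept up to date along the insertion path so that
// queries can prune subtrees that end before the query window starts.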
void BufferIntervalTree::Add(int64 start, int64 end, const Chunk& chunk) {
  node_storage_.emplace_back(
      BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr});

  if (node_storage_.size() == 1) {
    // This is the root node.
    return;
  }

  BufferIntervalTreeNode* parent = &node_storage_.front();
  while (true) {
    parent->subtree_end = std::max(parent->subtree_end, end);
    if (parent->start > start) {
      if (parent->left == nullptr) {
        parent->left = &node_storage_.back();
        return;
      }
      parent = parent->left;
    } else {
      if (parent->right == nullptr) {
        parent->right = &node_storage_.back();
        return;
      }
      parent = parent->right;
    }
  }
}

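// Returns the chunks of all intervals that overlap [start, end], walking the
// tree iteratively and pruning subtrees that cannot contain an overlap.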
std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
    int64 start, int64 end) const {
  std::vector<Chunk> result;
  if (node_storage_.empty()) {
    return result;
  }
  std::vector<const BufferIntervalTreeNode*> visiting_stack;
  visiting_stack.push_back(&node_storage_.front());
  while (!visiting_stack.empty()) {
    const BufferIntervalTreeNode* top = visiting_stack.back();
    visiting_stack.pop_back();
    if (start > top->subtree_end) {
      continue;
    }
    if (top->left != nullptr) {
      visiting_stack.push_back(top->left);
    }
    if (top->start <= end && top->end >= start) {
      result.push_back(top->chunk);
    }
    if (end < top->start) {
      continue;
    }
    if (top->right != nullptr) {
      visiting_stack.push_back(top->right);
    }
  }
  return result;
}

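// Finish() runs the actual packing: buffer intervals are sorted with the
// configured comparison function, intervals created by ShareWith (which have
// need_allocation == false and reuse another buffer's chunk) are skipped, and
// every other interval is committed to the chunk found by FindChunkCandidate.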
HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
  std::vector<BufferInterval> sorted_buffer_intervals =
      GetSortedBufferIntervals();

  for (auto& buffer_interval : sorted_buffer_intervals) {
    if (!buffer_interval.need_allocation) {
      continue;
    }

    ChunkCandidate chunk_candidate = FindChunkCandidate(buffer_interval);
    // This implementation of the heap algorithm does not have a notion of
    // maximum heap size, so it just commits.
    CommitChunk(buffer_interval, chunk_candidate);
  }
  VLOG(1) << "result heap_size: " << result_.heap_size;
  return result_;
}

std::vector<GlobalDecreasingSizeBestFitHeap::BufferInterval>
GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const {
  std::vector<BufferInterval> sorted_buffer_intervals;
  for (auto& entry : buffer_intervals_) {
    sorted_buffer_intervals.push_back(entry.second);
  }
  absl::c_sort(sorted_buffer_intervals, buffer_interval_compare_);

  return sorted_buffer_intervals;
}

GlobalDecreasingSizeBestFitHeap::ChunkCandidate
GlobalDecreasingSizeBestFitHeap::FindChunkCandidate(
    const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
    int64 preferred_offset) const {
  VLOG(1) << "Finding chunks for buffer: "
          << buffer_interval.buffer->ToString();
  VLOG(1) << "Size " << buffer_interval.size << ", start "
          << buffer_interval.start << ", end " << buffer_interval.end;
  auto chunks_overlapping_in_time = interval_tree_.ChunksOverlappingInTime(
      buffer_interval.start, buffer_interval.end);
  // Get all colocated buffers and gather all interfering chunks.
  //
  // Imagine that we've already allocated three chunks: a, b and c, and now we
  // want to allocate d. Since e is colocated with d, we have to allocate
  // chunks for them together at the same address. To do this, we first gather
  // all chunks that overlap with d and e on the time dimension; in this case
  // the overlapping chunks are a and b (c doesn't overlap with either d or
  // e). We then create a new chunk that doesn't overlap with a and b on the
  // space dimension.
  //
  // space
  //   ^
  //   |+--d---+      +---e---+
  //   |
  //   |+---+  +---------------+  +-------+
  //   ||   |  |               |  |       |
  //   ||   |  |               |  |       |
  //   |+-a-+  +-------b-------+  +---c---+
  //   ----------------------------------------> time
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    auto colocation_interval = buffer_intervals_.at(colocation);
    auto colocation_overlapping = interval_tree_.ChunksOverlappingInTime(
        colocation_interval.start, colocation_interval.end);
    VLOG(1) << "  Alias size " << colocation_interval.size << ", start "
            << colocation_interval.start << ", end " << colocation_interval.end
            << " " << colocation_interval.buffer->ToString();
    chunks_overlapping_in_time.insert(chunks_overlapping_in_time.end(),
                                      colocation_overlapping.begin(),
                                      colocation_overlapping.end());
  }
  absl::c_sort(chunks_overlapping_in_time, [](const Chunk& x, const Chunk& y) {
    return x.offset < y.offset;
  });

  // Find the minimum free chunk that can hold this buffer.
  ChunkCandidate chunk_candidate{Chunk{-1, INT64_MAX}, result_.heap_size};
  Chunk& min_fit_chunk = chunk_candidate.chunk;
  int64 preferred_chunk_end = preferred_offset + buffer_interval.size;
  auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) {
    if (free_size < buffer_interval.size) {
      return;
    }

    // If a preferred offset is provided and this free chunk covers it, pick
    // the preferred offset.
    if (free_offset <= preferred_offset &&
        free_offset + free_size >= preferred_chunk_end) {
      min_fit_chunk = {preferred_offset, buffer_interval.size};
    } else if (free_offset + free_size == result_.heap_size &&
               free_offset <= preferred_offset) {
      // If the free chunk is at the very end of the heap and the preferred
      // offset lies within it, pick the preferred offset and grow the heap.
      min_fit_chunk = {preferred_offset, buffer_interval.size};
      chunk_candidate.heap_size = preferred_chunk_end;
    }

    // Pick the min-fit chunk only if no preferred offset was given, or no
    // chunk at the preferred offset has been found.
    if ((preferred_offset < 0 || min_fit_chunk.offset != preferred_offset) &&
        free_size < min_fit_chunk.size) {
      min_fit_chunk = {free_offset, free_size};
    }
  };

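  // Sweep the previously committed chunks in increasing offset order and try
  // every aligned gap between consecutive chunks (and finally the space above
  // the last chunk) as a candidate location.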
  int64 offset = 0;
  for (auto& chunk : chunks_overlapping_in_time) {
    if (offset < chunk.offset) {
      use_free_chunk_if_smaller(offset, chunk.offset - offset);
    }
    offset = std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_));
  }
  use_free_chunk_if_smaller(offset, result_.heap_size - offset);
  // When a preferred offset is provided and it lies beyond the current heap
  // size, simply use the preferred offset and grow the heap to fit it.
  if (result_.heap_size <= preferred_offset) {
    chunk_candidate.heap_size = preferred_chunk_end;
    min_fit_chunk = {preferred_offset, buffer_interval.size};
  }

  if (min_fit_chunk.offset == -1) {
    // Increase the heap size to fit in the last free chunk.
    chunk_candidate.heap_size = offset + buffer_interval.size;
    min_fit_chunk = {offset, buffer_interval.size};
  }

  min_fit_chunk.size = buffer_interval.size;
  return chunk_candidate;
}

void GlobalDecreasingSizeBestFitHeap::CommitChunk(
    const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
    GlobalDecreasingSizeBestFitHeap::ChunkCandidate chunk_candidate) {
  // Update the maximum heap size according to the one determined by the chunk
  // candidate.
  result_.heap_size = chunk_candidate.heap_size;
  interval_tree_.Add(buffer_interval.start, buffer_interval.end,
                     chunk_candidate.chunk);
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    AddToChunkMap(colocation, chunk_candidate.chunk);
    auto colocation_interval = buffer_intervals_[colocation];
    interval_tree_.Add(colocation_interval.start, colocation_interval.end,
                       chunk_candidate.chunk);
  }

  AddToChunkMap(buffer_interval.buffer, chunk_candidate.chunk);
}

void GlobalDecreasingSizeBestFitHeap::AddToChunkMap(const HloValue* buffer,
                                                    Chunk chunk) {
  const auto emplace_result = result_.chunk_map.emplace(buffer, chunk);
  DCHECK(emplace_result.second);
}

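// ChooseBestHeapAlgorithm::Finish runs every candidate algorithm to completion
// and returns the result with the smallest heap size.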
HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() {
  DCHECK(!algorithms_.empty());
  std::vector<Result> results(algorithms_.size());
  int64 min_size = INT64_MAX;
  int min_size_index = -1;
  for (int i = 0; i < algorithms_.size(); ++i) {
    results[i] = algorithms_[i]->Finish();
    if (results[i].heap_size < min_size) {
      min_size = results[i].heap_size;
      min_size_index = i;
    }
  }

  DCHECK_GE(min_size_index, 0);
  return results[min_size_index];
}

}  // namespace xla