/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/heap_simulator.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <tuple>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

using absl::flat_hash_map;
using absl::flat_hash_set;

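// Two chunks overlap iff their half-open address ranges [offset, chunk_end())
// intersect. Zero-sized chunks are disallowed here.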
bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const {
  CHECK_NE(size, 0);
  CHECK_NE(other_chunk.size, 0);
  return offset < other_chunk.chunk_end() && other_chunk.offset < chunk_end();
}

/*static*/
StatusOr<int64_t> HeapSimulator::MinimumMemoryForModule(
    const HloSchedule& schedule,
    const LogicalBuffer::SizeFunction& size_function) {
  if (schedule.empty()) {
    return 0;
  }
  const HloModule* module = schedule.module();

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  // The absolute minimum memory required for a given sequence of instructions
  // is determined by the sequence of Alloc and Free calls on a simulated heap,
  // ignoring fragmentation. We run the heap simulation on the whole module,
  // rather than summing each computation, since it gives us a better lower
  // bound, by minimizing the liveness of sub-computations.
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(std::make_unique<NoFragmentationStatsHeap<HloValue>>(),
                         *module, schedule, *alias_analysis, size_function));
  return result.heap_size;
}

/*static*/
StatusOr<int64_t> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64_t>*
        memory_by_computation) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(std::make_unique<NoFragmentationStatsHeap<HloValue>>(),
                         computation, sequence, alias_analysis, size_function,
                         HeapSimulator::Options(), memory_by_computation));
  return result.heap_size;
}

StatusOr<int64_t> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const HloAliasAnalysis& alias_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const HloSchedule* schedule) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result<HloValue> result,
      HeapSimulator::Run(std::make_unique<NoFragmentationStatsHeap<HloValue>>(),
                         computation, sequence, alias_analysis, size_function,
                         schedule, HeapSimulator::Options()));
  return result.heap_size;
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm, const HloModule& module,
    const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
  const HloComputation* entry_computation = module.entry_computation();
  const HloInstructionSequence& instruction_sequence =
      schedule.sequence(entry_computation);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(schedule, alias_analysis, entry_computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation,
                                         instruction_sequence, alias_analysis,
                                         hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const absl::flat_hash_map<const HloComputation*, int64_t>*
        memory_by_computation) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/nullptr, memory_by_computation);
  HloSchedule schedule(computation.parent());
  schedule.set_sequence(&computation, instruction_sequence);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, alias_analysis, &computation,
                                        /*module_scoped_analysis=*/false));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result<HloValue>> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule,
    const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/schedule, nullptr);
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloLiveRange> hlo_live_range,
      HloLiveRange::Run(*schedule, alias_analysis, &computation));
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         alias_analysis, hlo_live_range.get()));
  return heap.Finish();
}

// Runs a heap simulation for the given 'computation', assuming the given
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const HloAliasAnalysis& alias_analysis, HloLiveRange* hlo_live_range) {
  XLA_VLOG_LINES(1, computation.parent()->ToString());
  XLA_VLOG_LINES(2, computation.ToString());

  VLOG(1) << hlo_live_range->ToString();

  HloDataflowAnalysis& dataflow_analysis = alias_analysis.dataflow_analysis();

  // Record the buffer define/free events for each time step. We free all
  // remaining buffers (entry parameters, etc.) after the program has finished
  // running, so we set the vector sizes to program_end_time + 1.
  std::vector<std::vector<const HloValue*>> buffers_defined(
      hlo_live_range->schedule_end_time() + 1);
  std::vector<std::vector<const HloValue*>> buffers_freed(
      hlo_live_range->schedule_end_time() + 1);

  // values_to_assign tracks the HloValues that we need to assign a buffer to.
  // Note that we only need to assign a buffer to a value when both of the
  // following conditions are met:
  //
  // - The user specifically asks us to assign a buffer to a set of HloValues,
  //   and the value is in the set. If the user doesn't provide such a set, we
  //   assign buffers to all HloValues by default.
  //
  // - If the instruction is in a nested call of the current computation, we
  //   only assign a buffer if we are doing global heap simulation.
  std::vector<const HloValue*> values_to_assign;
  values_to_assign.reserve(dataflow_analysis.values().size());

  auto& buffer_live_ranges = hlo_live_range->buffer_live_ranges();

  for (const HloValue* value : dataflow_analysis.values()) {
    // Ignore buffers that are not tracked.
    if (!buffer_live_ranges.contains(value)) {
      continue;
    }
    if (IgnoreBuffer(value)) {
      continue;
    }

    values_to_assign.push_back(value);
  }

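  // Process values in increasing order of (live range start, live range end,
  // value id) so that buffer assignment is deterministic.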
  absl::c_sort(values_to_assign,
               [&](const HloValue* value1, const HloValue* value2) {
                 const auto& live_range1 = buffer_live_ranges.at(value1);
                 const auto& live_range2 = buffer_live_ranges.at(value2);
                 return std::forward_as_tuple(live_range1.start,
                                              live_range1.end, value1->id()) <
                        std::forward_as_tuple(live_range2.start,
                                              live_range2.end, value2->id());
               });

  // For each value that we need to assign a buffer to, add the define and free
  // events.
  for (const HloValue* value : values_to_assign) {
    auto live_range = buffer_live_ranges.at(value);
    buffers_defined[live_range.start].push_back(value);
    buffers_freed[live_range.end].push_back(value);
  }

  // All HloValues in an HloBuffer should be allocated at the same address.
  // This map tracks the first value that got allocated in a buffer.
  absl::flat_hash_map<const HloBuffer*, const HloValue*> first_allocated_value;

  VLOG(1) << "Program time: " << hlo_live_range->schedule_end_time();

  // Go through each step in the program and replay the buffer define and free
  // events.
  for (int64_t i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) {
    VLOG(1) << "Time step: " << i;

    for (const HloValue* value : buffers_defined[i]) {
      bool shared = false;
      VLOG(1) << "Start buffer: " << value->ToShortString();
      const HloBuffer* hlo_buffer =
          &alias_analysis.GetBufferContainingValue(*value);
      if (first_allocated_value.count(hlo_buffer) != 0) {
        // We've already assigned an address for another value in this
        // HloBuffer (an HloBuffer holds several aliased HloValues). All values
        // in a buffer should be assigned the same address. Find the one that's
        // already allocated and reuse its address.
        ShareBuffer(value, first_allocated_value[hlo_buffer],
                    value->instruction());
        VLOG(1) << "  ShareWith "
                << first_allocated_value[hlo_buffer]->ToShortString();
        continue;
      }
      if (options_.may_reuse_operand_buffers &&
          hlo_buffer->values().size() == 1) {
        // We don't support sharing an aliased buffer
        // (hlo_buffer->values().size() > 1) with its operand.
        for (const HloInstruction* operand : value->instruction()->operands()) {
          const HloValueSet operand_value_set =
              dataflow_analysis.GetValueSet(operand);
          for (const HloValue* operand_value : operand_value_set.values()) {
            const HloBuffer* operand_buffer =
                &alias_analysis.GetBufferContainingValue(*operand_value);
            if (operand_buffer->values().size() > 1) {
              continue;
            }
            auto it = buffer_live_ranges.find(operand_value);
            if (it == buffer_live_ranges.end()) {
              continue;
            }

            auto& operand_live_range = it->second;

            auto& user_live_range = buffer_live_ranges[value];

            // Can only share buffers that are about to be freed.
            if (operand_live_range.end != i) {
              continue;
            }

            if (IgnoreBuffer(operand_value)) {
              continue;
            }

            if (!absl::c_linear_search(buffers_freed[i], operand_value)) {
              // If the operand buffer is not being freed (either because it
              // has existing users, or it has been reused by other buffers),
              // don't consider the operand as a candidate for buffer sharing.
              continue;
            }

            // The instruction that defines the operand value can be different
            // from the actual operand; passing the defining instruction
            // directly into CanShareOperandBufferWithUser can trigger a check
            // failure. The first condition guards against that case.
            if (value->instruction()->IsUserOf(operand_value->instruction()) &&
                value->instruction()->opcode() != HloOpcode::kCopy &&
                dataflow_analysis.CanShareOperandBufferWithUser(
                    operand_value->instruction(), operand_value->index(),
                    value->instruction(), value->index())) {
              // Remove the operand buffer right before sharing (allocating) a
              // new one.
              Free(operand_value, operand_value->instruction());
              buffers_freed[i].erase(
                  std::remove(buffers_freed[i].begin(), buffers_freed[i].end(),
                              operand_value),
                  buffers_freed[i].end());
              ShareBuffer(value, operand_value, value->instruction());
              // The live range of the operand buffer is now extended to the
              // end of the current instruction.
              operand_live_range.end = user_live_range.end;
              VLOG(1) << "Sharing " << value->ToShortString() << " with "
                      << operand_value->ToShortString()
                      << ", size: " << size_fn_(*value);
              shared = true;
              break;
            }
          }
          if (shared) {
            break;
          }
        }
      }
      if (!shared) {
        Alloc(value, value->instruction());
        first_allocated_value[hlo_buffer] = value;
      }
    }

    if (!buffers_freed[i].empty()) {
      VLOG(1) << "Free Buffer: ";
    }
    for (const HloValue* value : buffers_freed[i]) {
      VLOG(1) << "  " << value->ToShortString();

      Free(value, value->instruction());
    }
  }
  return OkStatus();
}

HeapSimulator::HeapSimulator(
    std::unique_ptr<HeapAlgorithm<HloValue>> algorithm,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const HloSchedule* schedule,
    const absl::flat_hash_map<const HloComputation*, int64_t>*
        memory_by_computation)
    : no_fragmentation_stats_(
          std::make_unique<NoFragmentationStatsHeap<HloValue>>()),
      algorithm_(std::move(algorithm)),
      size_fn_(size_fn),
      options_(options),
      schedule_(schedule),
      memory_by_computation_(memory_by_computation) {
  debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}

HeapSimulator::~HeapSimulator() {}

bool HeapSimulator::IgnoreBuffer(const HloValue* buffer) const {
  // Buffers for constants are ignored unless the alloc_constants option is
  // set. Also ignore buffers that we're not meant to assign.
  //
  // TODO(b/32248867): For consistency, constants should get allocations.
  if (!options_.alloc_constants &&
      buffer->instruction()->opcode() == HloOpcode::kConstant) {
    return true;
  }
  return options_.buffers_to_assign != nullptr &&
         !options_.buffers_to_assign->contains(buffer);
}

// Alloc always calls the underlying heap algorithm.
void HeapSimulator::Alloc(const HloValue* buffer,
                          const HloInstruction* instruction) {
  CHECK(!allocated_buffers_.contains(buffer))
      << "Alloc called on allocated buffer: " << *buffer;
  CHECK(!freed_buffers_.contains(buffer))
      << "Alloc called on freed buffer: " << *buffer;

  allocated_buffers_.insert(buffer);
  const int64_t size = size_fn_(*buffer);
  algorithm_->Alloc(buffer, size);
  no_fragmentation_stats_->Alloc(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
                 nullptr);
}

// Free calls the underlying algorithm for non-shared buffers, and for shared
// buffers whose group liveness has expired. Shared group liveness is tracked
// by maintaining a refcount; the Free call on the last buffer in the group
// causes Free to be called on the underlying algorithm.
void HeapSimulator::Free(const HloValue* buffer,
                         const HloInstruction* instruction) {
  const int64_t size = size_fn_(*buffer);
  algorithm_->Free(buffer, size);
  no_fragmentation_stats_->Free(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
}

// ShareBuffer associates buffers with their SharedGroup in shared_buffers_.
// The 'buffer' must be a non-allocated, non-freed buffer, just like in calls
// to Alloc. The 'shared' buffer must be a previously allocated or shared
// buffer. Both 'buffer' and 'shared' will be associated with the same
// SharedGroup.
void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared,
                                const HloInstruction* instruction) {
  algorithm_->ShareWith(buffer, shared, size_fn_(*shared));
  no_fragmentation_stats_->ShareWith(buffer, shared, size_fn_(*shared));
  FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
                 shared);
}

HeapSimulator::Result<HloValue> HeapSimulator::Finish() {
  Result<HloValue> result = algorithm_->Finish();

  // Post-process the result to add chunks for shared buffers. An empty chunk
  // map means that either no buffers were allocated, or the heap was only
  // collecting statistics, e.g. NoFragmentationStatsHeap.
  size_t total_chunk_count = absl::c_accumulate(
      result.heap_results, static_cast<size_t>(0),
      [&](size_t lhs, const HeapResult<HloValue>& rhs) -> size_t {
        return lhs + rhs.chunk_map.size();
      });
  if (total_chunk_count != 0) {
    // If we were told to assign specific buffers, make sure we've assigned
    // exactly that many buffers.
    if (options_.buffers_to_assign != nullptr) {
      CHECK_EQ(options_.buffers_to_assign->size(), total_chunk_count);
    }
  }

  // Fragmentation is the difference between the actual and ideal sizes.
  const Result<HloValue> no_frag_result = no_fragmentation_stats_->Finish();
  result.fragmentation_size = result.heap_size - no_frag_result.heap_size;

  // Copy the debug trace we collected to the final result.
  result.debug_trace.Swap(&debug_trace_);

  return result;
}

void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
                                   const HloValue* buffer,
                                   const HloInstruction* instruction,
                                   const HloValue* share_with_canonical) {
  HeapSimulatorTrace::Event* event = debug_trace_.add_events();
  event->set_kind(kind);
  event->set_buffer_id(buffer->id());
  event->set_computation_name(instruction->parent()->name());
  event->set_instruction_name(instruction->name());
  if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
    CHECK(share_with_canonical != nullptr);
    event->set_share_with_canonical_id(share_with_canonical->id());
  } else {
    CHECK(share_with_canonical == nullptr);
  }
}

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::Alloc(const BufferType* buffer,
                                                 int64_t size) {
  current_heap_size_ += size;
  if (current_heap_size_ > max_heap_size_) {
    max_heap_size_ = current_heap_size_;
  }
}

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::AccountForSubcomputationMemory(
    const HloInstruction* instruction, int64_t alloc_size_by_instruction,
    const absl::flat_hash_map<const HloComputation*, int64_t>&
        memory_by_computation) {
  // We only count the memory usage of the largest subcomputation, instead of
  // adding them all, because subcomputations won't execute in parallel.
  int64_t max_subcomputation_bytes = 0;
  for (const auto* c : instruction->called_computations()) {
    auto it = memory_by_computation.find(c);
    if (it != memory_by_computation.end()) {
      int64_t subcomputation_bytes = it->second;
      if (subcomputation_bytes > max_subcomputation_bytes) {
        max_subcomputation_bytes = subcomputation_bytes;
      }
    }
  }
  if (max_subcomputation_bytes > 0 &&
      (instruction->opcode() == HloOpcode::kWhile ||
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kConditional)) {
    // The output buffer of while/call/conditional is always aliased with the
    // output buffer of the root instruction in the body. Don't double count.
    max_subcomputation_bytes -= alloc_size_by_instruction;
  }
  max_heap_size_ =
      std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
}

template <typename BufferType>
void NoFragmentationStatsHeap<BufferType>::Free(const BufferType* buffer,
                                                int64_t size) {
  current_heap_size_ -= size;
}

template <typename BufferType>
HeapSimulator::Result<BufferType>
NoFragmentationStatsHeap<BufferType>::Finish() {
  // The result.chunk_map is empty, since we only collect stats, and don't
  // actually compute chunk assignments.
  Result result;
  result.heap_size = max_heap_size_;
  return result;
}

template <typename BufferType>
GlobalDecreasingSizeBestFitHeap<BufferType>::GlobalDecreasingSizeBestFitHeap(
    int64_t alignment, Type type)
    : alignment_(alignment) {
  if (type == kTemporal) {
    buffer_interval_compare_ = GetTemporalBufferIntervalCompare();
  } else {
    CHECK(type == kSpatial);
    buffer_interval_compare_ = GetSpatialBufferIntervalCompare();
  }
}

template <typename BufferType>
typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap<BufferType>::GetTemporalBufferIntervalCompare()
    const {
  return LessThanByKey([this](const BufferInterval& x) {
    int64_t x_end = x.end;
    for (auto colocation : GetTransitiveColocations(x)) {
      x_end = std::max(x_end, buffer_intervals_.at(colocation).end);
    }
    // Sort by duration (descending), size (descending), buffer (ascending).
    return std::make_tuple(x.start - x_end, -x.size, std::cref(*x.buffer));
  });
}

template <typename BufferType>
/*static*/ typename GlobalDecreasingSizeBestFitHeap<
    BufferType>::BufferIntervalCompare
GlobalDecreasingSizeBestFitHeap<BufferType>::GetSpatialBufferIntervalCompare() {
  return LessThanByKey([](const BufferInterval& x) {
    // Sort by size (descending), duration (descending), buffer (ascending).
    return std::make_tuple(-x.size, x.start - x.end, std::cref(*x.buffer));
  });
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::Alloc(
    const BufferType* buffer, int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }

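  // This algorithm defers placement: Alloc only opens a live interval at the
  // current logical time; actual offsets are chosen later in Finish().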
  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, true});
  DCHECK(emplace_result.second);
  ++current_time_;
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::ShareWith(
    const BufferType* buffer, const BufferType* share_with, int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }
  DCHECK_NE(buffer_intervals_.count(share_with), 0);
  buffer_intervals_[share_with].colocations.push_back(buffer);
  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1, {}, false});
  DCHECK(emplace_result.second);
  ++current_time_;
}

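// Collects every buffer reachable from `interval` through colocation edges,
// using an iterative depth-first traversal over buffer_intervals_.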
template <typename BufferType>
absl::flat_hash_set<const BufferType*>
GlobalDecreasingSizeBestFitHeap<BufferType>::GetTransitiveColocations(
    const BufferInterval& interval) const {
  absl::flat_hash_set<const BufferType*> result;
  std::vector<const BufferInterval*> worklist = {&interval};
  while (!worklist.empty()) {
    const BufferInterval* item = worklist.back();
    worklist.pop_back();
    for (const BufferType* buffer_colocated : item->colocations) {
      if (result.insert(buffer_colocated).second) {
        worklist.push_back(&buffer_intervals_.at(buffer_colocated));
      }
    }
  }

  return result;
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::Free(const BufferType* buffer,
                                                       int64_t size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    return;
  }
  BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
  DCHECK_EQ(buffer_interval.buffer, buffer);
  DCHECK_EQ(buffer_interval.size, size);
  DCHECK_EQ(buffer_interval.end, -1);
  if (buffer_interval.end != -1) {
    return;
  }
  buffer_interval.end = current_time_;
  ++current_time_;
}

using Chunk = HeapSimulator::Chunk;

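// The tree is a binary search tree keyed on interval start time; each node
// also stores `subtree_end`, the maximum end time in its subtree, so that
// time-overlap queries can prune whole subtrees. Insertion does not rebalance.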
void BufferIntervalTree::Add(int64_t start, int64_t end, const Chunk& chunk) {
  node_storage_.emplace_back(BufferIntervalTreeNode{
      start, end, end, chunk,
      /*left=*/nullptr, /*right=*/nullptr, /*parent=*/nullptr});
  if (root_ == nullptr) {
    root_ = &node_storage_.back();
    // This is root.
    return;
  }

  BufferIntervalTreeNode* parent = root_;
  while (true) {
    parent->subtree_end = std::max(parent->subtree_end, end);
    if (parent->start > start) {
      if (parent->left == nullptr) {
        parent->left = &node_storage_.back();
        node_storage_.back().parent = parent;
        return;
      }
      parent = parent->left;
    } else {
      if (parent->right == nullptr) {
        parent->right = &node_storage_.back();
        node_storage_.back().parent = parent;
        return;
      }
      parent = parent->right;
    }
  }
}

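// Removes the node matching (start, end, chunk.offset) using standard BST
// deletion, then restores the `subtree_end` invariant on the affected path.
// Returns true if a matching node was found.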
bool BufferIntervalTree::Remove(int64_t start, int64_t end,
                                const Chunk& chunk) {
  BufferIntervalTreeNode* to_delete = root_;
  while (to_delete != nullptr) {
    if (to_delete->start == start && to_delete->end == end &&
        to_delete->chunk.offset == chunk.offset) {
      break;
    }
    if (start < to_delete->start) {
      to_delete = to_delete->left;
    } else {
      to_delete = to_delete->right;
    }
  }
  if (to_delete == nullptr) {
    // Nothing to delete.
    return false;
  }
  // Found the node to be deleted; enter the deletion sequence.

  // Recursively traverse the parents of a node and fix up the `subtree_end`
  // invariant of each. A recursive lambda needs an explicit std::function
  // declaration.
  std::function<void(BufferIntervalTreeNode*)> fix_up =
      [&](BufferIntervalTreeNode* node) {
        if (node == nullptr) {
          return;
        }
        node->subtree_end = node->end;
        if (node->left) {
          node->subtree_end =
              std::max(node->subtree_end, node->left->subtree_end);
        }
        if (node->right) {
          node->subtree_end =
              std::max(node->subtree_end, node->right->subtree_end);
        }
        // Recursively go up.
        fix_up(node->parent);
      };

  if (to_delete->right == nullptr) {
    // to_delete has no right child; simply move up the left child of
    // to_delete, if any.
    //
    // Turn:
    //      parent
    //     /
    //  to_delete
    //   /      \
    // left   nullptr
    //
    // Into:
    //      parent
    //     /
    //   left
    if (root_ == to_delete) {
      // Deleting the root simply resets the root.
      root_ = to_delete->left;
      return true;
    }

    if (to_delete == to_delete->parent->left) {
      // to_delete is the left child of its parent.
      to_delete->parent->left = to_delete->left;
    }
    if (to_delete == to_delete->parent->right) {
      // to_delete is the right child of its parent.
      to_delete->parent->right = to_delete->left;
    }
    // Rewire the parent to the node being moved up.
    if (to_delete->left) {
      to_delete->left->parent = to_delete->parent;
    }
    // Fix up starting from the subroot.
    fix_up(to_delete);
  } else {
    // 1. Find the left-most node of the right subtree and promote it to the
    // position of to_delete.
    BufferIntervalTreeNode* to_promote = to_delete->right;
    while (to_promote->left != nullptr) {
      // Walk to the left-most node.
      to_promote = to_promote->left;
    }

    // 2. Copy the content of `to_promote` to `to_delete`.
    to_delete->start = to_promote->start;
    to_delete->end = to_promote->end;
    // This may be temporarily incorrect; it is repaired later by the `fix_up`
    // procedure.
    to_delete->subtree_end = to_promote->subtree_end;
    to_delete->chunk = to_promote->chunk;
    auto to_promote_parent = to_promote->parent;
    // 3. Move the right child of `to_promote` up, if there is any.
    //
    // Turn
    //
    //  to_delete
    //      \
    //  to_promote_parent
    //     /
    //  to_promote
    //      \
    //      right
    // into
    //
    //  to_promote
    //      \
    //  to_promote_parent
    //     /
    //   right
    if (to_promote_parent->left == to_promote) {
      to_promote_parent->left = to_promote->right;
    } else {
      to_promote_parent->right = to_promote->right;
    }
    if (to_promote->right) {
      // Set the correct parent.
      to_promote->right->parent = to_promote_parent;
    }
    // 4. Recursively fix up the `subtree_end` values starting from
    // `to_promote_parent`.
    fix_up(to_promote_parent);
  }
  // Don't free the entry in node_storage_ until we free the entire tree.
  return true;
}

std::vector<Chunk> BufferIntervalTree::ChunksOverlappingInTime(
    int64_t start, int64_t end) const {
  std::vector<Chunk> result;
  if (root_ == nullptr) {
    return result;
  }
  std::vector<const BufferIntervalTreeNode*> visiting_stack;
  visiting_stack.push_back(root_);
  while (!visiting_stack.empty()) {
    const BufferIntervalTreeNode* top = visiting_stack.back();
    visiting_stack.pop_back();
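    // If the query starts after every interval in this subtree has ended,
    // nothing below `top` can overlap, so the whole subtree is pruned.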
    if (start > top->subtree_end) {
      continue;
    }
    if (top->left != nullptr) {
      visiting_stack.push_back(top->left);
    }
    if (top->start <= end && top->end >= start) {
      result.push_back(top->chunk);
    }
    if (end < top->start) {
      // Intervals in the right subtree start even later, so they cannot
      // overlap the query either.
      continue;
    }
    if (top->right != nullptr) {
      visiting_stack.push_back(top->right);
    }
  }
  return result;
}

template <typename BufferType>
HeapSimulator::Result<BufferType>
GlobalDecreasingSizeBestFitHeap<BufferType>::Finish() {
  std::vector<BufferInterval> sorted_buffer_intervals =
      GetSortedBufferIntervals();

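  // Place buffers in sorted order. Intervals created by ShareWith have
  // need_allocation == false; they receive their offsets when the colocated
  // interval that owns the allocation is committed.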
  for (auto& buffer_interval : sorted_buffer_intervals) {
    if (!buffer_interval.need_allocation) {
      continue;
    }

    // This implementation of the heap algorithm does not have a notion of
    // maximum heap size, so it just commits.
    CommitChunk(buffer_interval, FindChunkCandidate(buffer_interval));
  }
  VLOG(1) << "result heap_size: " << result_.heap_size;
  Result result;
  result.heap_size = result_.heap_size;
  result.heap_results.emplace_back(result_);
  return result;
}

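// Copies all recorded intervals into a vector and orders them with the
// configured comparator (temporal or spatial, see the constructor).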
template <typename BufferType>
std::vector<
    typename GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval>
GlobalDecreasingSizeBestFitHeap<BufferType>::GetSortedBufferIntervals() const {
  std::vector<BufferInterval> sorted_buffer_intervals;
  sorted_buffer_intervals.reserve(buffer_intervals_.size());
  for (auto& entry : buffer_intervals_) {
    sorted_buffer_intervals.push_back(entry.second);
  }
  absl::c_sort(sorted_buffer_intervals, buffer_interval_compare_);

  return sorted_buffer_intervals;
}

template <typename BufferType>
typename GlobalDecreasingSizeBestFitHeap<BufferType>::Chunk
GlobalDecreasingSizeBestFitHeap<BufferType>::FindChunkCandidate(
    const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval,
    int64_t preferred_offset) const {
  VLOG(1) << "Finding chunks for buffer: "
          << buffer_interval.buffer->ToString();
  VLOG(1) << "Size " << buffer_interval.size << ", start "
          << buffer_interval.start << ", end " << buffer_interval.end;
  // Get all colocated buffers and gather all interfering chunks.
  //
  // Imagine that we've already allocated three chunks: a, b, and c. Now we
  // want to allocate d. Since e is colocated with d, we have to allocate
  // chunks for them together at the same address. To do this, we first gather
  // all chunks that overlap with d and e on the time dimension; in this case
  // the overlapping chunks are a and b (c doesn't overlap with either d or
  // e). Then we create a new chunk that doesn't overlap with a and b on the
  // space dimension.
  //
  // space
  //   ^
  //   |+--d---+      +---e---+
  //   |
  //   |+---+  +---------------+  +-------+
  //   ||   |  |               |  |       |
  //   ||   |  |               |  |       |
  //   |+-a-+  +-------b-------+  +---c---+
  //   ----------------------------------------> time

  // Map free chunk offsets -> ends.
  // We use `greater` for the comparison so that we can use `lower_bound` to
  // find the largest key less than or equal to the lookup value.
  absl::btree_map<int64_t, int64_t, std::greater<int64_t>> free_chunks{
      {0, INT64_MAX}};  // Initialize with "infinite" free memory.

  // Find the max size of the interval across its colocations and use this
  // value to determine whether the buffer will fit in the heap.
  int64_t max_colocation_size = buffer_interval.size;
  for (const BufferType* colocation :
       GetTransitiveColocations(buffer_interval)) {
    max_colocation_size =
        std::max(max_colocation_size, buffer_intervals_.at(colocation).size);
  }

  // Subtract chunks that are in use from the free chunks.
  auto subtract_used_chunks = [&](const std::vector<Chunk>& used_chunks) {
    for (const Chunk& used_chunk : used_chunks) {
      // Find the free chunks containing the start and end of the used chunk.
      auto it_end = free_chunks.lower_bound(used_chunk.chunk_end());
      if (it_end == free_chunks.end()) continue;
      auto it_start = free_chunks.lower_bound(used_chunk.offset);

      // Store the original free chunk end, in case `it_start == it_end`.
      int64_t free_chunk_end = it_end->second;

      // Shrink the free chunk containing the start of the used range, removing
      // it if it becomes too small for the buffer.
      if (it_start != free_chunks.end()) {
        if (used_chunk.offset - it_start->first >= buffer_interval.size) {
          it_start->second = std::min(it_start->second, used_chunk.offset);
        } else {
          ++it_start;  // Increment iterator so that this entry is erased below.
        }
      }

      // Erase from the start chunk (possibly inclusive) to the end chunk
      // (always inclusive). We iterate from end to start, as the map is in
      // reverse order.
      free_chunks.erase(it_end, it_start);

      // Create a new free chunk after the used chunk, if it is large enough.
      int64_t chunk_end_aligned = RoundUpTo(used_chunk.chunk_end(), alignment_);
      if (free_chunk_end - chunk_end_aligned >= max_colocation_size) {
        CHECK(free_chunks.insert({chunk_end_aligned, free_chunk_end}).second);
      }
    }
  };

  subtract_used_chunks(interval_tree_.ChunksOverlappingInTime(
      buffer_interval.start, buffer_interval.end));

  for (const BufferType* colocation :
       GetTransitiveColocations(buffer_interval)) {
    const BufferInterval& interval = buffer_intervals_.at(colocation);
    VLOG(1) << "  Alias size " << interval.size << ", start " << interval.start
            << ", end " << interval.end << " " << interval.buffer->ToString();

    subtract_used_chunks(
        interval_tree_.ChunksOverlappingInTime(interval.start, interval.end));
  }

  // Try to find a large enough free chunk containing the preferred offset.
  Chunk chunk{preferred_offset, max_colocation_size};
  auto it = (preferred_offset < 0) ? free_chunks.end()
                                   : free_chunks.lower_bound(preferred_offset);
  if (it == free_chunks.end() || (it->second < chunk.chunk_end())) {
    // Otherwise, find the smallest free chunk. In the case of a tie, prefer
    // the smallest offset. We ensure above that all of the free chunks are
    // large enough to store the buffer.
    chunk.offset = absl::c_min_element(free_chunks, [](auto a, auto b) {
                     return std::forward_as_tuple(a.second - a.first, a.first) <
                            std::forward_as_tuple(b.second - b.first, b.first);
                   })->first;
  }
  return chunk;
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::CommitChunk(
    const GlobalDecreasingSizeBestFitHeap<BufferType>::BufferInterval&
        buffer_interval,
    GlobalDecreasingSizeBestFitHeap<BufferType>::Chunk chunk) {
  // Update the maximum heap size according to the chunk candidate. In the case
  // of colocations of different sizes, the chunk size returned is the maximum
  // across all colocations, so use this value to update the heap size.
  result_.heap_size = result_.UpdatedHeapSize(chunk);
  // Now, update the chunk size to the actual size of the buffer interval.
  chunk.size = buffer_interval.size;
  interval_tree_.Add(buffer_interval.start, buffer_interval.end, chunk);
  for (auto colocation : GetTransitiveColocations(buffer_interval)) {
    auto colocation_interval = buffer_intervals_[colocation];
    // Create a colocation chunk with the same offset but with the correct size
    // of the colocated interval, in case the colocations are of different
    // sizes.
    Chunk colocation_chunk{chunk.offset, colocation_interval.size};
    AddToChunkMap(colocation, colocation_chunk);
    interval_tree_.Add(colocation_interval.start, colocation_interval.end,
                       colocation_chunk);
  }

  AddToChunkMap(buffer_interval.buffer, chunk);
}

template <typename BufferType>
void GlobalDecreasingSizeBestFitHeap<BufferType>::AddToChunkMap(
    const BufferType* buffer, Chunk chunk) {
  const auto emplace_result = result_.chunk_map.emplace(buffer, chunk);
  DCHECK(emplace_result.second);
}

HeapSimulator::Result<HloValue>
ConstrainedGlobalDecreasingSizeBestFitHeap::Finish() {
  std::vector<BufferInterval> sorted_buffer_vec = GetSortedBufferIntervals();
  // Convert into a std::list so that erase() is O(1).
  std::list<BufferInterval> sorted_buffer_intervals(sorted_buffer_vec.begin(),
                                                    sorted_buffer_vec.end());

  // Use a do-while loop here, because we need to create at least one heap in
  // `multi_heap_result` even if `sorted_buffer_intervals` is empty.
  Result multi_heap_result;
  do {
    // Place as many buffers as possible into the currently processed heap.
    for (auto it = sorted_buffer_intervals.begin();
         it != sorted_buffer_intervals.end();) {
      BufferInterval buffer_interval = *it;
      if (!buffer_interval.need_allocation) {
        it = sorted_buffer_intervals.erase(it);
        continue;
      }
      if (buffer_interval.size > size_limit_per_heap_) {
        LOG(WARNING) << "Alloc buffer size " << buffer_interval.size
                     << " larger than the per-heap size limit "
                     << size_limit_per_heap_;
      }

      Chunk chunk_candidate = FindChunkCandidate(buffer_interval);
      if (chunk_candidate.chunk_end() <= size_limit_per_heap_ ||
          // Commit the chunk as long as the heap is empty. We do this because
          // we want the size constraint to be soft, meaning that results are
          // successfully generated even if there are some buffer sizes larger
          // than the given constraint size.
          result_.heap_size == 0) {
        CommitChunk(buffer_interval, chunk_candidate);
        it = sorted_buffer_intervals.erase(it);
        continue;
      }

      ++it;
    }
    // Collect the result from the currently processed heap and reset the heap
    // states.
    multi_heap_result.heap_size += result_.heap_size;
    multi_heap_result.heap_results.push_back(std::move(result_));
    result_ = {};
    interval_tree_ = {};
  } while (!sorted_buffer_intervals.empty());

  VLOG(1) << "Number of heaps produced = "
          << multi_heap_result.heap_results.size();
  return multi_heap_result;
}

template <typename BufferType>
HeapSimulator::Result<BufferType>
ChooseBestHeapAlgorithm<BufferType>::Finish() {
  DCHECK(!algorithms_.empty());
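  // Run every candidate algorithm to completion and keep the result with the
  // smallest heap size.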
  std::vector<Result> results(algorithms_.size());
  int64_t min_size = INT64_MAX;
  int min_size_index = -1;
  for (int i = 0; i < algorithms_.size(); ++i) {
    results[i] = algorithms_[i]->Finish();
    if (results[i].heap_size < min_size) {
      min_size = results[i].heap_size;
      min_size_index = i;
    }
  }

  DCHECK_GE(min_size_index, 0);
  return results[min_size_index];
}

template class GlobalDecreasingSizeBestFitHeap<HloValue>;
template class GlobalDecreasingSizeBestFitHeap<
    MemorySpaceAssignmentRepacker::AllocationBlock>;
template class ChooseBestHeapAlgorithm<HloValue>;

}  // namespace xla