/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/heap_simulator.h"

#include <algorithm>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

using absl::flat_hash_map;
using absl::flat_hash_set;
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
    const HloSchedule& schedule,
    const LogicalBuffer::SizeFunction& size_function) {
  if (schedule.empty()) {
    return 0;
  }

  const HloModule* module = schedule.module();
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                      TuplePointsToAnalysis::Run(module));

  // The absolute minimum memory required for a given sequence of instructions
  // is determined by the sequence of Alloc and Free calls on a simulated heap,
  // ignoring fragmentation. We run the heap simulation on the whole module,
  // rather than summing each computation, since it gives us a better lower
  // bound, by minimizing the liveness of sub-computations.
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result result,
      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
                         schedule, *points_to_analysis, size_function));
  return result.heap_size;
}

/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
    const HloComputation& computation, const HloInstructionSequence& sequence,
    const TuplePointsToAnalysis& points_to_analysis,
    const LogicalBuffer::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result result,
      HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
                         computation, sequence, points_to_analysis,
                         size_function, HeapSimulator::Options(),
                         memory_by_computation));
  return result.heap_size;
}

/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
    const HloSchedule& schedule,
    const TuplePointsToAnalysis& points_to_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options) {
  HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
  const HloComputation* entry_computation = module.entry_computation();
  const HloInstructionSequence& instruction_sequence =
      schedule.sequence(entry_computation);
  TF_RETURN_IF_ERROR(heap.RunComputation(
      *entry_computation, instruction_sequence, points_to_analysis));
  return heap.Finish();
}

/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
    std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const TuplePointsToAnalysis& points_to_analysis,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation) {
  HeapSimulator heap(std::move(algorithm), size_fn, options,
                     /*schedule=*/nullptr, memory_by_computation);
  TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                         points_to_analysis));
  return heap.Finish();
}
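
// Example usage (an illustrative sketch, not part of this file's API): a pass
// that already has a scheduled HloModule, a TuplePointsToAnalysis, and a size
// function (assumed here to be in scope as `module`, `schedule`, and
// `size_fn`) can drive the simulator with any HeapAlgorithm, mirroring
// MinimumMemoryForModule above:
//
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to,
//                       TuplePointsToAnalysis::Run(module));
//   TF_ASSIGN_OR_RETURN(
//       HeapSimulator::Result result,
//       HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
//                          *module, schedule, *points_to, size_fn));
//   VLOG(1) << "Simulated heap size: " << result.heap_size;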

// Runs a heap simulation for the given 'computation', assuming the given
// 'instruction_sequence'.
Status HeapSimulator::RunComputation(
    const HloComputation& computation,
    const HloInstructionSequence& instruction_sequence,
    const TuplePointsToAnalysis& points_to_analysis) {
  VLOG(3) << "Computation:\n" << computation.ToString();
  // The goal here is to minimize memory usage, assuming the given sequential
  // ordering of instructions. The strategy is to walk through the instruction
  // sequence, calling Alloc and Free on the underlying heap algorithm. The
  // heap algorithm takes care of packing and reducing fragmentation.
  //
  // 'live_buffers' tracks the liveness of each buffer that we assign, by
  // associating it with a set of HloInstructions that need to be visited. When
  // the set becomes empty, the buffer is no longer used, and can be freed.
  // 'used_buffers' is the reverse map - it tracks which buffers were used by an
  // instruction, so that we can remove the instructions from a buffer's live
  // set after they are visited.
  flat_hash_map<const BufferValue*, flat_hash_set<const HloInstruction*>>
      live_buffers;
  flat_hash_map<const HloInstruction*, flat_hash_set<const BufferValue*>>
      used_buffers;
  auto add_user_to_buffer = [this, &live_buffers, &used_buffers](
                                const HloInstruction* user,
                                const BufferValue* buffer) {
    if (!IgnoreBuffer(buffer)) {
      VLOG(4) << " Adding user " << user->name() << " to buffer "
              << buffer->ToString();
      live_buffers[buffer].insert(user);
      used_buffers[user].insert(buffer);
    }
  };

  // Initialize live_buffers for each buffer that we're going to assign. The
  // set of instructions that need to be visited contains all users of all
  // aliases, that is, all users of all instructions that have the buffer
  // contained in their points-to set.
  for (const HloInstruction* instruction :
       instruction_sequence.instructions()) {
    const PointsToSet& points_to =
        points_to_analysis.GetPointsToSet(instruction);
    const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet();
    for (const HloInstruction* user : instruction->users()) {
      if (user->opcode() != HloOpcode::kGetTupleElement) {
        for (const BufferValue* buffer : buffer_set) {
          add_user_to_buffer(user, buffer);
        }
      } else {
        // A GetTupleElement doesn't need to keep all of its operand's buffers
        // alive. It only needs the buffers that relate to the element it's
        // extracting, and the tuple it's extracting from, but not the buffers
        // for the other elements.
        for (const BufferValue* buffer : points_to.element({})) {
          add_user_to_buffer(user, buffer);
        }
        const PointsToSet& gte_points_to =
            points_to_analysis.GetPointsToSet(user);
        for (const BufferValue* buffer : gte_points_to.CreateFlattenedSet()) {
          add_user_to_buffer(user, buffer);
        }
      }
    }
  }

  const HloInstruction* root = computation.root_instruction();
  BufferValueCompactPointerSet output_source_buffers =
      ToBufferValueCompactPointerSet(
          points_to_analysis.GetPointsToSet(root).CreateFlattenedSet());

  std::vector<const BufferValue*> dead_buffers_to_free;
  std::vector<const BufferValue*> operand_buffers_to_free;
  for (const HloInstruction* instruction :
       instruction_sequence.instructions()) {
    const TuplePointsToAnalysis::BufferDefinitionVector&
        buffers_defined_by_instruction =
            points_to_analysis.GetBuffersDefinedByInstruction(instruction);

    VLOG(3) << "Instruction: " << instruction->ToString();
    for (const BufferValue* buffer : buffers_defined_by_instruction) {
      VLOG(4) << " Defines: " << buffer->ToString()
              << (IgnoreBuffer(buffer) ? " (Ignored)" : "");
    }

    dead_buffers_to_free.clear();
    for (const BufferValue* buffer : buffers_defined_by_instruction) {
      if (IgnoreBuffer(buffer)) {
        continue;
      }
      // Add a nullptr sentry to ensure entry parameters and output source
      // buffers are not freed until the very end.
      const bool entry_parameter =
          &computation == computation.parent()->entry_computation() &&
          buffer->instruction()->opcode() == HloOpcode::kParameter;
      const bool output = output_source_buffers.count(buffer) > 0;
      if (entry_parameter || output) {
        live_buffers[buffer].insert(nullptr);
      }

      // If the buffer has no users and isn't an entry parameter or output, it
      // must be a dead value.
      if (!live_buffers.contains(buffer)) {
        dead_buffers_to_free.push_back(buffer);
      }
    }

    // Update live_buffers to indicate we've visited this instruction; this is
    // the inverse of the initialization logic. We erase this instruction from
    // all source buffers of all operands of this instruction. Buffers that
    // have no instructions left to visit are moved from live_buffers to
    // operand_buffers_to_free.
    operand_buffers_to_free.clear();
    for (const BufferValue* operand_buffer : used_buffers[instruction]) {
      if (IgnoreBuffer(operand_buffer)) {
        continue;
      }
      VLOG(4) << " Removing user " << instruction->name() << " from buffer "
              << operand_buffer->ToString();
      auto it = live_buffers.find(operand_buffer);
      flat_hash_set<const HloInstruction*>* live_set = &it->second;
      live_set->erase(instruction);
      if (live_set->empty()) {
        live_buffers.erase(it);
        operand_buffers_to_free.push_back(operand_buffer);
      }
    }
    // Sort to get a deterministic iteration order.
    absl::c_sort(operand_buffers_to_free,
                 [](const BufferValue* x, const BufferValue* y) {
                   return x->id() < y->id();
                 });

    // Allocate buffers defined by this instruction. This is the latest point
    // that we can allocate; right before the buffer is first used. This must
    // happen before dead or operand buffers are freed; the instruction reads
    // the operand buffers to produce its output.
    //
    // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer
    // that we should assign.

    // Make sure each buffer gets reused at most once.
    flat_hash_set<const BufferValue*> reused_buffers;
    int64 alloc_size_by_instruction = 0;
    for (const BufferValue* buffer : buffers_defined_by_instruction) {
      if (IgnoreBuffer(buffer)) {
        continue;
      }

      // Check whether the buffer can share with one of its operands; we can
      // save memory by sharing the buffer, rather than allocating a new one.
      // We can only share with the operand buffer if it is about to be freed;
      // we must be the last user of the buffer.
      bool shared = false;
      if (options_.may_reuse_operand_buffers) {
        for (const BufferValue* operand_buffer : operand_buffers_to_free) {
          if (reused_buffers.contains(operand_buffer)) {
            continue;
          }
          if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
              buffer->instruction()->opcode() != HloOpcode::kCopy &&
              points_to_analysis.CanShareOperandBufferWithUser(
                  operand_buffer->instruction(), operand_buffer->index(),
                  buffer->instruction(), buffer->index())) {
            VLOG(3) << " Sharing: " << buffer->ToString() << " with "
                    << operand_buffer->ToString();
            ShareBuffer(buffer, operand_buffer, instruction);
            shared = true;
            reused_buffers.insert(operand_buffer);
            break;
          }
        }
      }

      if (!shared) {
        VLOG(3) << " Allocating: " << buffer->ToString();
        alloc_size_by_instruction += size_fn_(*buffer);
        Alloc(buffer, instruction);
      }
    }
    // Account for the memory used by subcomputations when estimating the
    // current heap size.
    if (memory_by_computation_ != nullptr) {
      algorithm_->AccountForSubcomputationMemory(
          instruction, alloc_size_by_instruction, *memory_by_computation_);
    }

    // If all computations in the module have been scheduled, we can save
    // memory by running the heap-simulation for sub-computations inline. E.g.
    // the buffers for the condition and body of a kWhile instruction are only
    // live for the duration of the instruction itself.
    //
    // The order that the sub-computations are simulated does not affect
    // correctness; since the whole module has been scheduled, we know that the
    // sub-computations will never be run concurrently.
    if (schedule_ != nullptr) {
      if (instruction->opcode() == HloOpcode::kCall ||
          instruction->opcode() == HloOpcode::kConditional ||
          instruction->opcode() == HloOpcode::kWhile) {
        for (const HloComputation* called_computation :
             instruction->called_computations()) {
          const HloInstructionSequence& called_sequence =
              schedule_->sequence(called_computation);
          TF_RETURN_IF_ERROR(RunComputation(
              *called_computation, called_sequence, points_to_analysis));
        }
      }

      // Other sub-computations (e.g. Map, Reduce, ...) are skipped; they are
      // assigned "thread-local" allocations, meaning their buffers are not
      // allocated up-front at the beginning of the computation.
    }

    // Free buffers that are no longer live. This is the earliest point that we
    // can de-allocate; right after the last use of the buffer.
    for (const BufferValue* buffer : dead_buffers_to_free) {
      VLOG(3) << " Freeing dead: " << buffer->ToString();
      Free(buffer, instruction);
    }
    for (const BufferValue* buffer : operand_buffers_to_free) {
      VLOG(3) << " Freeing operand: " << buffer->ToString();
      Free(buffer, instruction);
    }
  }

  // Any remaining live buffers must be entry parameters or output source
  // buffers, which had a nullptr sentry added. Free them now, in a
  // deterministic order.
  std::vector<const BufferValue*> to_free;
  to_free.reserve(live_buffers.size());
  for (const auto& buffer_pending : live_buffers) {
    const BufferValue* buffer = buffer_pending.first;
    const flat_hash_set<const HloInstruction*>& pending = buffer_pending.second;
    CHECK_EQ(pending.size(), 1) << *buffer;
    CHECK(*pending.begin() == nullptr) << *buffer;
    to_free.push_back(buffer);
  }

  absl::c_sort(to_free, [](const BufferValue* x, const BufferValue* y) {
    return x->id() < y->id();
  });
  for (const BufferValue* buffer : to_free) {
    VLOG(3) << "Freeing pending: " << buffer->ToString();
    Free(buffer, root);
  }

  return Status::OK();
}

HeapSimulator::HeapSimulator(
    std::unique_ptr<HeapAlgorithm> algorithm,
    const BufferValue::SizeFunction& size_fn, const Options& options,
    const HloSchedule* schedule,
    const absl::flat_hash_map<const HloComputation*, int64>*
        memory_by_computation)
    : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
      algorithm_(std::move(algorithm)),
      size_fn_(size_fn),
      options_(options),
      schedule_(schedule),
      memory_by_computation_(memory_by_computation) {
  debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
}

HeapSimulator::~HeapSimulator() {}

bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const {
  // Buffers for constants are ignored unless the alloc_constants option is
  // set. Also ignore buffers that we're not meant to assign.
  //
  // TODO(b/32248867): For consistency, constants should get allocations.
  if (!options_.alloc_constants &&
      buffer->instruction()->opcode() == HloOpcode::kConstant) {
    return true;
  }
  return options_.buffers_to_assign != nullptr &&
         !options_.buffers_to_assign->contains(buffer);
}

// Alloc always calls the underlying heap algorithm.
void HeapSimulator::Alloc(const BufferValue* buffer,
                          const HloInstruction* instruction) {
  CHECK(!allocated_buffers_.contains(buffer))
      << "Alloc called on allocated buffer: " << *buffer;
  CHECK(!freed_buffers_.contains(buffer))
      << "Alloc called on freed buffer: " << *buffer;

  allocated_buffers_.insert(buffer);
  const int64 size = size_fn_(*buffer);
  algorithm_->Alloc(buffer, size);
  no_fragmentation_stats_->Alloc(buffer, size);
  FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
                 nullptr);
}

// Free calls the underlying algorithm for non-shared buffers, and for shared
// buffers whose group liveness has expired. Shared group liveness is tracked
// by maintaining a refcount; the Free call on the last buffer in the group
// causes Free to be called on the underlying algorithm.
void HeapSimulator::Free(const BufferValue* buffer,
                         const HloInstruction* instruction) {
  auto shared_it = shared_buffers_.find(buffer);
  if (shared_it != shared_buffers_.end()) {
    std::shared_ptr<SharedGroup> group = shared_it->second;
    --group->refcount;
    if (group->refcount > 0) {
      return;
    }
    CHECK_EQ(group->refcount, 0)
        << "Free caused negative refcount on shared buffer: " << *buffer;
    buffer = group->canonical;
  }

  CHECK(allocated_buffers_.contains(buffer))
      << "Free called on non-allocated buffer: " << *buffer;
  CHECK(!freed_buffers_.contains(buffer))
      << "Free called on freed buffer: " << *buffer;

  freed_buffers_.insert(buffer);
  const int64 size = size_fn_(*buffer);
  algorithm_->Free(buffer, size);
  no_fragmentation_stats_->Free(buffer, size);

  FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
}

// ShareBuffer associates buffers with their SharedGroup in shared_buffers_.
// The 'buffer' must be a non-allocated, non-freed buffer, just like in calls to
// Alloc. The 'shared' buffer must be a previously allocated or shared buffer.
// Both 'buffer' and 'shared' will be associated with the same SharedGroup.
void HeapSimulator::ShareBuffer(const BufferValue* buffer,
                                const BufferValue* shared,
                                const HloInstruction* instruction) {
  CHECK_LE(size_fn_(*buffer), size_fn_(*shared))
      << "ShareBuffer oversized buffer: " << *buffer << " shared: " << *shared;
  CHECK(!allocated_buffers_.contains(buffer))
      << "ShareBuffer called on allocated buffer: " << *buffer;
  CHECK(!freed_buffers_.contains(buffer))
      << "ShareBuffer called on freed buffer: " << *buffer;
  CHECK(!freed_buffers_.contains(shared))
      << "ShareBuffer called on freed shared buffer: " << *shared;

  const BufferValue* canonical = nullptr;
  auto shared_it = shared_buffers_.find(shared);
  if (shared_it != shared_buffers_.end()) {
    // The 'shared' buffer already has a group; it might be the canonical, but
    // also might not be. Just add 'buffer' to the existing group.
    std::shared_ptr<SharedGroup> group = shared_it->second;
    canonical = group->canonical;
    ++group->refcount;
    shared_buffers_.emplace(buffer, group);
  } else {
    // The 'shared' buffer doesn't have a group; it must be the canonical. Add
    // both 'buffer' and 'shared' to a new group.
    CHECK(allocated_buffers_.contains(shared))
        << "ShareBuffer called on non-allocated shared buffer: " << *shared;
    auto group = std::make_shared<SharedGroup>();
    canonical = shared;
    group->canonical = canonical;
    group->refcount = 2;
    shared_buffers_.emplace(buffer, group);
    shared_buffers_.emplace(shared, group);
  }

  FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
                 canonical);
}

HeapSimulator::Result HeapSimulator::Finish() {
  Result result = algorithm_->Finish();

  // Post-process the result to add chunks for shared buffers. An empty chunk
  // map means that either no buffers were allocated, or the heap was only
  // collecting statistics, e.g. NoFragmentationStatsHeap.
  if (!result.chunk_map.empty()) {
    for (const auto& share_pair : shared_buffers_) {
      const BufferValue* buffer = share_pair.first;
      std::shared_ptr<SharedGroup> group = share_pair.second;
      if (buffer != group->canonical) {
        // The canonical must already exist in the chunk_map, since we called
        // Alloc(canonical) on the underlying algorithm. Add non-canonical
        // chunks with the same offset as the canonical.
        Chunk chunk = FindOrDie(result.chunk_map, group->canonical);
        chunk.size = size_fn_(*buffer);
        result.chunk_map.emplace(buffer, chunk);
      }
    }
    // If we were told to assign specific buffers, make sure we've assigned
    // exactly that many buffers.
    if (options_.buffers_to_assign != nullptr) {
      CHECK_EQ(options_.buffers_to_assign->size(), result.chunk_map.size());
    }
  }

  // Fragmentation is the difference between the actual and ideal sizes.
  const Result no_frag_result = no_fragmentation_stats_->Finish();
  result.fragmentation_size = result.heap_size - no_frag_result.heap_size;

  // Copy the debug trace we collected to the final result.
  result.debug_trace.Swap(&debug_trace_);

  return result;
}

void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
                                   const BufferValue* buffer,
                                   const HloInstruction* instruction,
                                   const BufferValue* share_with_canonical) {
  HeapSimulatorTrace::Event* event = debug_trace_.add_events();
  event->set_kind(kind);
  event->set_buffer_id(buffer->id());
  event->set_computation_name(instruction->parent()->name());
  event->set_instruction_name(instruction->name());
  if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
    CHECK(share_with_canonical != nullptr);
    event->set_share_with_canonical_id(share_with_canonical->id());
  } else {
    CHECK(share_with_canonical == nullptr);
  }
}

void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
  current_heap_size_ += size;
  if (current_heap_size_ > max_heap_size_) {
    max_heap_size_ = current_heap_size_;
  }
}

void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
    const HloInstruction* instruction, int64 alloc_size_by_instruction,
    const absl::flat_hash_map<const HloComputation*, int64>&
        memory_by_computation) {
  // We only count the memory usage of the largest subcomputation, instead of
  // adding them all, because subcomputations won't execute in parallel.
  int64 max_subcomputation_bytes = 0;
  for (const auto* c : instruction->called_computations()) {
    auto it = memory_by_computation.find(c);
    if (it != memory_by_computation.end()) {
      int64 subcomputation_bytes = it->second;
      if (subcomputation_bytes > max_subcomputation_bytes) {
        max_subcomputation_bytes = subcomputation_bytes;
      }
    }
  }
  if (max_subcomputation_bytes > 0 &&
      (instruction->opcode() == HloOpcode::kWhile ||
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kConditional)) {
    // The output buffer of while/call/conditional is always aliased with the
    // output buffer of the root instruction in the body. Don't double count.
    max_subcomputation_bytes -= alloc_size_by_instruction;
  }
  max_heap_size_ =
      std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
}

void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) {
  current_heap_size_ -= size;
}

HeapSimulator::Result NoFragmentationStatsHeap::Finish() {
  // The result.chunk_map is empty, since we only collect stats, and don't
  // actually compute chunk assignments.
  Result result;
  result.heap_size = max_heap_size_;
  return result;
}

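// DecreasingSizeRunsHeap batches consecutive Alloc calls (and, separately,
// consecutive Free calls) into a run, then forwards the run to the wrapped
// algorithm_ sorted by decreasing size, so the underlying algorithm sees
// larger buffers first.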
void DecreasingSizeRunsHeap::Alloc(const BufferValue* buffer, int64 size) {
  SetMode(kAlloc);
  run_.emplace_back(Op{buffer, size});
}

void DecreasingSizeRunsHeap::Free(const BufferValue* buffer, int64 size) {
  CHECK(mode_ != kInit) << "Free called on empty heap: " << *buffer;
  SetMode(kFree);
  run_.emplace_back(Op{buffer, size});
}

HeapSimulator::Result DecreasingSizeRunsHeap::Finish() {
  CallAndDrainRun();
  return algorithm_->Finish();
}

void DecreasingSizeRunsHeap::SetMode(Mode mode) {
  if (mode_ != mode) {
    CallAndDrainRun();
    mode_ = mode;
  }
}

void DecreasingSizeRunsHeap::CallAndDrainRun() {
  if (mode_ == kInit) {
    CHECK(run_.empty());
    return;
  }

  // Call ops in the run sorted by decreasing size, breaking ties by buffer id.
  absl::c_sort(run_, [](const Op& a, const Op& b) {
    if (a.size != b.size) {
      return a.size > b.size;
    }
    return a.buffer->id() < b.buffer->id();
  });
  for (const Op& op : run_) {
    if (mode_ == kAlloc) {
      algorithm_->Alloc(op.buffer, op.size);
    } else {
      algorithm_->Free(op.buffer, op.size);
    }
  }
  run_.clear();
}

void LazyBestFitHeap::Alloc(const BufferValue* buffer, int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
  }

  // First try to allocate from the best-fitting free chunk.
  auto best_fit_it = free_.lower_bound(Chunk{0, size});
  while (best_fit_it != free_.end()) {
    // Account for alignment.
    const Chunk best = *best_fit_it;
    const int64 new_offset = RoundUpToNearest(best.offset, alignment_);
    const int64 new_end = new_offset + size;
    if (new_end > best.chunk_end()) {
      // We don't fit after accounting for alignment.
      ++best_fit_it;
      continue;
    }
    // The buffer is allocated a chunk out of the best-fitting free chunk.
    free_.erase(best_fit_it);
    result_.chunk_map.emplace(buffer, Chunk{new_offset, size});
    // Add remaining portions of the best-fitting free chunk back into free_.
    AddFreeChunk(best.offset, new_offset - best.offset);
    AddFreeChunk(new_end, best.chunk_end() - new_end);
    return;
  }

  // The buffer doesn't completely fit into any existing free chunk. If the
  // last free chunk is adjacent to the end of the heap, allocate the buffer
  // re-using that space, increasing the heap size.
  //
  // Allocating the buffer now causes the heap to grow by less than the buffer
  // size, whereas if we allocated lazily in Free, the heap would grow by
  // exactly the buffer size. However, it's still a greedy heuristic approach;
  // we might have ended up with a tighter packing by being lazy here.
  //
  // In theory we could also check if we could re-use space from the first free
  // chunk and grow the heap at the front, and choose whether to grow from the
  // front or back based on the amount of re-use. But that's more complicated,
  // and these are all heuristics anyways, so it isn't implemented.
  for (auto it = free_.begin(); it != free_.end(); ++it) {
    if (it->chunk_end() == result_.heap_size) {
      // Account for alignment in the last free chunk.
      const Chunk last = *it;
      const int64 new_offset = RoundUpToNearest(last.offset, alignment_);
      if (new_offset >= last.chunk_end()) {
        // There's no point in using the last free chunk if alignment causes us
        // to skip over it anyways.
        break;
      }
      // The buffer is allocated a chunk that includes the last free chunk.
      free_.erase(it);
      result_.chunk_map.emplace(buffer, Chunk{new_offset, size});
      // Add remaining portion of the last free chunk back into free_.
      AddFreeChunk(last.offset, new_offset - last.offset);
      // Grow the heap.
      const int64 new_end = new_offset + size;
      CHECK_GT(new_end, result_.heap_size);
      CHECK_LT(new_end, result_.heap_size + size);
      result_.heap_size = new_end;
      return;
    }
  }

  // Otherwise lazily allocate the buffer in Free.
  result_.chunk_map.emplace(buffer, Chunk{kLazyAllocOffset, size});
}

void LazyBestFitHeap::Free(const BufferValue* buffer, int64 size) {
  auto alloc_it = result_.chunk_map.find(buffer);
  CHECK(alloc_it != result_.chunk_map.end())
      << "Free called on non-allocated buffer: " << *buffer;
  Chunk* alloc = &alloc_it->second;
  CHECK_EQ(alloc->size, size) << "Free with mismatched sizes: " << *buffer;
  if (alloc->offset != kLazyAllocOffset) {
    // The buffer was already allocated in Alloc, do a normal free.
    AddFreeChunk(alloc->offset, alloc->size);
  } else {
    // This buffer is lazily allocated, so we *cannot* allocate out of existing
    // free chunks, since that might cause interference between buffers. The
    // buffer is allocated by growing the heap, accounting for alignment.
    alloc->offset = RoundUpToNearest(result_.heap_size, alignment_);
    const int64 new_end = alloc->chunk_end();
    AddFreeChunk(result_.heap_size, new_end - result_.heap_size);
    CHECK_GT(new_end, result_.heap_size);
    CHECK_GE(new_end, result_.heap_size + alloc->size);
    result_.heap_size = new_end;
  }
}

void LazyBestFitHeap::AddFreeChunk(int64 offset, int64 size) {
  if (size <= 0) {
    return;
  }

  // Coalesce the chunk with adjacent free chunks on either side. We must
  // remove the free chunks from free_, since it's ordered by size.
  Chunk chunk{offset, size};
  for (auto it = free_.begin(); it != free_.end();) {
    if (it->chunk_end() == chunk.offset || it->offset == chunk.chunk_end()) {
      chunk.offset = std::min(chunk.offset, it->offset);
      chunk.size += it->size;
      it = free_.erase(it);
    } else {
      ++it;
    }
  }

  // This is the only place we add free chunks to free_. It maintains the
  // invariant that all free chunks are disjoint and non-adjacent.
  free_.emplace(chunk);
}

HeapSimulator::Result LazyBestFitHeap::Finish() {
  if (!free_.empty()) {
    // When Finish is called, all calls to Alloc must have had corresponding
    // calls to Free, which will result in a single free chunk [0, heap_size).
    CHECK_EQ(free_.size(), 1);
    CHECK_EQ(free_.begin()->offset, 0);
    CHECK_EQ(free_.begin()->size, result_.heap_size);
  }
  return result_;
}

void GlobalDecreasingSizeBestFitHeap::Alloc(const BufferValue* buffer,
                                            int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    result_.chunk_map.emplace(buffer, Chunk{0, 0});
    return;
  }
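  // Nothing is placed yet: Alloc and Free only record each buffer's live
  // interval in logical time, and Finish() assigns offsets for all buffers at
  // once.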
  auto emplace_result = buffer_intervals_.emplace(
      buffer, BufferInterval{buffer, size, current_time_, -1});
  DCHECK(emplace_result.second);
  ++current_time_;
}

void GlobalDecreasingSizeBestFitHeap::Free(const BufferValue* buffer,
                                           int64 size) {
  // Degenerate case: 0-sized buffers are always allocated at offset 0.
  if (size == 0) {
    return;
  }
  BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
  DCHECK_EQ(buffer_interval.buffer, buffer);
  DCHECK_EQ(buffer_interval.size, size);
  DCHECK_EQ(buffer_interval.end, -1);
  buffer_interval.end = current_time_;
  ++current_time_;
}

namespace {

// Node in BufferIntervalTree that stores the alloc and free times of a buffer,
// and the chunk assigned to it.
struct BufferIntervalTreeNode {
  // Alloc time.
  int64 start;
  // Free time.
  int64 end;
  // Maximum free time of all nodes in the subtree where this node is the root.
  int64 subtree_end;
  // Allocated chunk for the buffer.
  HeapSimulator::Chunk chunk;
  // Left child.
  BufferIntervalTreeNode* left;
  // Right child.
  BufferIntervalTreeNode* right;
};

// An interval tree that can query buffers overlapping in time.
class BufferIntervalTree {
 public:
  explicit BufferIntervalTree(int capacity) : node_storage_(capacity) {}

  using Chunk = HeapSimulator::Chunk;

  // Adds a buffer to the interval tree, with the time interval and allocated
  // chunk specified.
  void Add(int64 start, int64 end, const Chunk& chunk) {
    int index = node_count_;
    DCHECK_LT(index, node_storage_.size());
    ++node_count_;

    node_storage_[index] =
        BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr};

    if (index == 0) {
      // This is root.
      return;
    }

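    // Walk down from the root as in a binary search tree keyed on start time,
    // updating subtree_end along the path so interval queries can prune
    // subtrees whose intervals all end before the query starts.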
    BufferIntervalTreeNode* parent = &node_storage_[0];
    while (true) {
      parent->subtree_end = std::max(parent->subtree_end, end);
      if (parent->start > start) {
        if (parent->left == nullptr) {
          parent->left = &node_storage_[index];
          return;
        }
        parent = parent->left;
      } else {
        if (parent->right == nullptr) {
          parent->right = &node_storage_[index];
          return;
        }
        parent = parent->right;
      }
    }
  }

  // Returns vector of allocated chunks that overlap with the given time
  // interval.
  std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) {
    std::vector<Chunk> result;
    if (node_count_ == 0) {
      return result;
    }
    std::vector<BufferIntervalTreeNode*> visiting_stack;
    visiting_stack.push_back(&node_storage_[0]);
    while (!visiting_stack.empty()) {
      BufferIntervalTreeNode* top = visiting_stack.back();
      visiting_stack.pop_back();
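      // Prune: no node in this subtree can overlap [start, end] if the
      // subtree's maximum end time is before the query start. Otherwise visit
      // the left subtree (smaller starts), emit this node if it overlaps, and
      // only descend right (larger starts) when the query extends past this
      // node's start.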
      if (start > top->subtree_end) {
        continue;
      }
      if (top->left != nullptr) {
        visiting_stack.push_back(top->left);
      }
      if (top->start <= end && top->end >= start) {
        result.push_back(top->chunk);
      }
      if (end < top->start) {
        continue;
      }
      if (top->right != nullptr) {
        visiting_stack.push_back(top->right);
      }
    }
    return result;
  }

 private:
  int64 node_count_ = 0;
  std::vector<BufferIntervalTreeNode> node_storage_;
};

}  // namespace

HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
  std::vector<BufferInterval> sorted_buffer_intervals;
  for (auto& entry : buffer_intervals_) {
    sorted_buffer_intervals.push_back(entry.second);
  }
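  // Place buffers in a global order: largest first, then longest-lived first,
  // breaking remaining ties by buffer id so the result is deterministic.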
  absl::c_sort(sorted_buffer_intervals,
               [](const BufferInterval& x, const BufferInterval& y) {
                 if (x.size != y.size) {
                   return x.size > y.size;
                 }
                 if (x.end - x.start != y.end - y.start) {
                   return x.end - x.start > y.end - y.start;
                 }
                 return x.buffer->id() < y.buffer->id();
               });

  BufferIntervalTree interval_tree(sorted_buffer_intervals.size());
  for (auto& buffer_interval : sorted_buffer_intervals) {
    auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime(
        buffer_interval.start, buffer_interval.end);
    absl::c_sort(
        chunks_overlapping_in_time,
        [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; });

    // Find the minimum free chunk that can hold this buffer.
    Chunk min_fit_chunk{-1, INT64_MAX};
    auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) {
      if (free_size < buffer_interval.size) {
        return;
      }

      if (free_size < min_fit_chunk.size) {
        min_fit_chunk = {free_offset, free_size};
      }
    };

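    // Sweep the already-placed chunks that are live at the same time, in
    // offset order, considering the gap before each chunk and finally the gap
    // between the last chunk and the current heap end.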
    int64 offset = 0;
    for (auto& chunk : chunks_overlapping_in_time) {
      if (offset < chunk.offset) {
        use_free_chunk_if_smaller(offset, chunk.offset - offset);
      }
      offset =
          std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_));
    }
    use_free_chunk_if_smaller(offset, result_.heap_size - offset);

    if (min_fit_chunk.offset == -1) {
      // Increase the heap size to fit in the last free chunk.
      result_.heap_size = offset + buffer_interval.size;
      min_fit_chunk = {offset, buffer_interval.size};
    }

    min_fit_chunk.size = buffer_interval.size;
    const auto emplace_result =
        result_.chunk_map.emplace(buffer_interval.buffer, min_fit_chunk);
    DCHECK(emplace_result.second);

    interval_tree.Add(buffer_interval.start, buffer_interval.end,
                      min_fit_chunk);
  }
  return result_;
}

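// Runs Finish() on every candidate algorithm and returns the result with the
// smallest heap size.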
HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() {
  DCHECK(!algorithms_.empty());
  std::vector<Result> results(algorithms_.size());
  int64 min_size = INT64_MAX;
  int min_size_index = -1;
  for (int i = 0; i < algorithms_.size(); ++i) {
    results[i] = algorithms_[i]->Finish();
    if (results[i].heap_size < min_size) {
      min_size = results[i].heap_size;
      min_size_index = i;
    }
  }

  DCHECK_GE(min_size_index, 0);
  return results[min_size_index];
}

}  // namespace xla