/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/heap_simulator.h"
17 
18 #include <algorithm>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/map_util.h"
25 #include "tensorflow/compiler/xla/util.h"
26 
27 namespace xla {
28 
29 using absl::flat_hash_map;
30 using absl::flat_hash_set;
31 
32 /*static*/
MinimumMemoryForModule(const HloSchedule & schedule,const LogicalBuffer::SizeFunction & size_function)33 StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
34     const HloSchedule& schedule,
35     const LogicalBuffer::SizeFunction& size_function) {
36   if (schedule.empty()) {
37     return 0;
38   }
39 
40   const HloModule* module = schedule.module();
41   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
42                       TuplePointsToAnalysis::Run(module));
43 
44   // The absolute minimum memory required for a given sequence of instructions
45   // is determined by the sequence of Alloc and Free calls on a simulated heap,
46   // ignoring fragmentation. We run the heap simulation on the whole module,
47   // rather than summing each computation, since it gives us a better lower
48   // bound, by minimizing the liveness of sub-computations.
49   TF_ASSIGN_OR_RETURN(
50       HeapSimulator::Result result,
51       HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
52                          schedule, *points_to_analysis, size_function));
53   return result.heap_size;
54 }
55 
56 /*static*/
MinimumMemoryForComputation(const HloComputation & computation,const HloInstructionSequence & sequence,const TuplePointsToAnalysis & points_to_analysis,const LogicalBuffer::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64> * memory_by_computation)57 StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
58     const HloComputation& computation, const HloInstructionSequence& sequence,
59     const TuplePointsToAnalysis& points_to_analysis,
60     const LogicalBuffer::SizeFunction& size_function,
61     const absl::flat_hash_map<const HloComputation*, int64>*
62         memory_by_computation) {
63   TF_ASSIGN_OR_RETURN(
64       HeapSimulator::Result result,
65       HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(),
66                          computation, sequence, points_to_analysis,
67                          size_function, HeapSimulator::Options(),
68                          memory_by_computation));
69   return result.heap_size;
70 }
71 
72 /*static*/
Run(std::unique_ptr<HeapAlgorithm> algorithm,const HloModule & module,const HloSchedule & schedule,const TuplePointsToAnalysis & points_to_analysis,const BufferValue::SizeFunction & size_fn,const Options & options)73 StatusOr<HeapSimulator::Result> HeapSimulator::Run(
74     std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
75     const HloSchedule& schedule,
76     const TuplePointsToAnalysis& points_to_analysis,
77     const BufferValue::SizeFunction& size_fn, const Options& options) {
78   HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
79   const HloComputation* entry_computation = module.entry_computation();
80   const HloInstructionSequence& instruction_sequence =
81       schedule.sequence(entry_computation);
82   TF_RETURN_IF_ERROR(heap.RunComputation(
83       *entry_computation, instruction_sequence, points_to_analysis));
84   return heap.Finish();
85 }
86 
87 /*static*/
Run(std::unique_ptr<HeapAlgorithm> algorithm,const HloComputation & computation,const HloInstructionSequence & instruction_sequence,const TuplePointsToAnalysis & points_to_analysis,const BufferValue::SizeFunction & size_fn,const Options & options,const absl::flat_hash_map<const HloComputation *,int64> * memory_by_computation)88 StatusOr<HeapSimulator::Result> HeapSimulator::Run(
89     std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
90     const HloInstructionSequence& instruction_sequence,
91     const TuplePointsToAnalysis& points_to_analysis,
92     const BufferValue::SizeFunction& size_fn, const Options& options,
93     const absl::flat_hash_map<const HloComputation*, int64>*
94         memory_by_computation) {
95   HeapSimulator heap(std::move(algorithm), size_fn, options,
96                      /*schedule=*/nullptr, memory_by_computation);
97   TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
98                                          points_to_analysis));
99   return heap.Finish();
100 }
101 
102 // Runs a heap simulation for the given 'computation', assuming the given
103 // 'instruction_sequence'.
RunComputation(const HloComputation & computation,const HloInstructionSequence & instruction_sequence,const TuplePointsToAnalysis & points_to_analysis)104 Status HeapSimulator::RunComputation(
105     const HloComputation& computation,
106     const HloInstructionSequence& instruction_sequence,
107     const TuplePointsToAnalysis& points_to_analysis) {
108   VLOG(3) << "Computation:\n" << computation.ToString();
109   // The goal here is to minimize memory usage, assuming the given sequential
110   // ordering of instructions.  The strategy is to walk through the instruction
111   // sequence, calling Alloc and Free on the underlying heap algorithm.  The
112   // heap algorithm takes care of packing and reducing fragmentation.
113   //
114   // 'live_buffers' tracks the liveness of each buffer that we assign, by
115   // associating it with a set of HloInstructions that need to be visited.  When
116   // the set becomes empty, the buffer is no longer used, and can be freed.
117   // 'used_buffers' is the reverse map - it tracks which buffers were used by an
118   // instruction, so that we can remove the instructions from a buffer's live
119   // set after they are visited.
120   flat_hash_map<const BufferValue*, flat_hash_set<const HloInstruction*>>
121       live_buffers;
122   flat_hash_map<const HloInstruction*, flat_hash_set<const BufferValue*>>
123       used_buffers;
124   auto add_user_to_buffer = [this, &live_buffers, &used_buffers](
125                                 const HloInstruction* user,
126                                 const BufferValue* buffer) {
127     if (!IgnoreBuffer(buffer)) {
128       VLOG(4) << "  Adding user " << user->name() << " to buffer "
129               << buffer->ToString();
130       live_buffers[buffer].insert(user);
131       used_buffers[user].insert(buffer);
132     }
133   };
134 
135   // Initialize live_buffers for each buffer that we're going to assign.  The
136   // set of instructions that need to be visited contains all users of all
137   // aliases, that is, all users of all instructions that have the buffer
138   // contained in their points-to set.
139   for (const HloInstruction* instruction :
140        instruction_sequence.instructions()) {
141     const PointsToSet& points_to =
142         points_to_analysis.GetPointsToSet(instruction);
143     const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet();
144     for (const HloInstruction* user : instruction->users()) {
145       if (user->opcode() != HloOpcode::kGetTupleElement) {
146         for (const BufferValue* buffer : buffer_set) {
147           add_user_to_buffer(user, buffer);
148         }
149       } else {
150         // A GetTupleElement doesn't need to keep all of its operand's buffers
151         // alive. It only needs the buffers that relate to the element it's
152         // extracting, and the tuple it's extracting from, but not the buffers
153         // for the other elements.
154         for (const BufferValue* buffer : points_to.element({})) {
155           add_user_to_buffer(user, buffer);
156         }
157         const PointsToSet& gte_points_to =
158             points_to_analysis.GetPointsToSet(user);
159         for (const BufferValue* buffer : gte_points_to.CreateFlattenedSet()) {
160           add_user_to_buffer(user, buffer);
161         }
162       }
163     }
164   }
165 
166   const HloInstruction* root = computation.root_instruction();
167   BufferValueCompactPointerSet output_source_buffers =
168       ToBufferValueCompactPointerSet(
169           points_to_analysis.GetPointsToSet(root).CreateFlattenedSet());
170 
171   std::vector<const BufferValue*> dead_buffers_to_free;
172   std::vector<const BufferValue*> operand_buffers_to_free;
173   for (const HloInstruction* instruction :
174        instruction_sequence.instructions()) {
175     const TuplePointsToAnalysis::BufferDefinitionVector&
176         buffers_defined_by_instruction =
177             points_to_analysis.GetBuffersDefinedByInstruction(instruction);
178 
179     VLOG(3) << "Instruction: " << instruction->ToString();
180     for (const BufferValue* buffer : buffers_defined_by_instruction) {
181       VLOG(4) << "  Defines: " << buffer->ToString()
182               << (IgnoreBuffer(buffer) ? " (Ignored)" : "");
183     }
184 
185     dead_buffers_to_free.clear();
186     for (const BufferValue* buffer : buffers_defined_by_instruction) {
187       if (IgnoreBuffer(buffer)) {
188         continue;
189       }
190       // Add a nullptr sentry to ensure entry parameters and output source
191       // buffers are not freed until the very end.
192       const bool entry_parameter =
193           &computation == computation.parent()->entry_computation() &&
194           buffer->instruction()->opcode() == HloOpcode::kParameter;
195       const bool output = output_source_buffers.count(buffer) > 0;
196       if (entry_parameter || output) {
197         live_buffers[buffer].insert(nullptr);
198       }
199 
200       // If the buffer has no users and isn't an entry parameter or output, it
201       // must be a dead value.
202       if (!live_buffers.contains(buffer)) {
203         dead_buffers_to_free.push_back(buffer);
204       }
205     }
206 
207     // Update live_buffers to indicate we've visited this instruction; this is
208     // the inverse of the initialization logic.  We erase this instruction from
209     // all source buffers of all operands of this instruction.  Buffers that
210     // have no instructions left to visit are moved from live_buffers to
211     // operand_buffers_to_free.
212     operand_buffers_to_free.clear();
213     for (const BufferValue* operand_buffer : used_buffers[instruction]) {
214       if (IgnoreBuffer(operand_buffer)) {
215         continue;
216       }
217       VLOG(4) << "  Removing user " << instruction->name() << " from buffer "
218               << operand_buffer->ToString();
219       auto it = live_buffers.find(operand_buffer);
220       flat_hash_set<const HloInstruction*>* live_set = &it->second;
221       live_set->erase(instruction);
222       if (live_set->empty()) {
223         live_buffers.erase(it);
224         operand_buffers_to_free.push_back(operand_buffer);
225       }
226     }
227     // Sort to get a deterministic iteration order.
228     absl::c_sort(operand_buffers_to_free,
229                  [](const BufferValue* x, const BufferValue* y) {
230                    return x->id() < y->id();
231                  });
232 
233     // Allocate buffers defined by this instruction.  This is the latest point
234     // that we can allocate; right before the buffer is first used.  This must
235     // happen before dead or operand buffers are freed; the instruction reads
236     // the operand buffers to produce its output.
237     //
238     // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer
239     // that we should assign.
240 
241     // Make sure each buffer get reused at most once.
242     flat_hash_set<const BufferValue*> reused_buffers;
243     int64 alloc_size_by_instruction = 0;
244     for (const BufferValue* buffer : buffers_defined_by_instruction) {
245       if (IgnoreBuffer(buffer)) {
246         continue;
247       }
248 
249       // Check whether the buffer can share with one of its operands; we can
250       // save memory by sharing the buffer, rather than allocating a new one.
251       // We can only share with the operand buffer if it is about to be freed;
252       // we must be the last user of the buffer.
253       bool shared = false;
254       if (options_.may_reuse_operand_buffers) {
255         for (const BufferValue* operand_buffer : operand_buffers_to_free) {
256           if (reused_buffers.contains(operand_buffer)) {
257             continue;
258           }
259           if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
260               buffer->instruction()->opcode() != HloOpcode::kCopy &&
261               points_to_analysis.CanShareOperandBufferWithUser(
262                   operand_buffer->instruction(), operand_buffer->index(),
263                   buffer->instruction(), buffer->index())) {
264             VLOG(3) << "  Sharing: " << buffer->ToString() << " with "
265                     << operand_buffer->ToString();
266             ShareBuffer(buffer, operand_buffer, instruction);
267             shared = true;
268             reused_buffers.insert(operand_buffer);
269             break;
270           }
271         }
272       }
273 
274       if (!shared) {
275         VLOG(3) << "  Allocating: " << buffer->ToString();
276         alloc_size_by_instruction += size_fn_(*buffer);
277         Alloc(buffer, instruction);
278       }
279     }
280     // Account for the memory used by subcomputations when estimating the
281     // current heap size.
282     if (memory_by_computation_ != nullptr) {
283       algorithm_->AccountForSubcomputationMemory(
284           instruction, alloc_size_by_instruction, *memory_by_computation_);
285     }
286 
287     // If all computations in the module have been scheduled, we can save memory
288     // by running the heap-simulation for sub-computations inline. E.g. the
289     // buffers for the condition and body of a kWhile instruction are only live
290     // for the duration of the instruction itself.
291     //
292     // The order that the sub-computations are simulated does not affect
293     // correctness; since the whole module has been scheduled, we know that the
294     // sub-computations will never be run concurrently.
295     if (schedule_ != nullptr) {
296       if (instruction->opcode() == HloOpcode::kCall ||
297           instruction->opcode() == HloOpcode::kConditional ||
298           instruction->opcode() == HloOpcode::kWhile) {
299         for (const HloComputation* called_computation :
300              instruction->called_computations()) {
301           const HloInstructionSequence& called_sequence =
302               schedule_->sequence(called_computation);
303           TF_RETURN_IF_ERROR(RunComputation(
304               *called_computation, called_sequence, points_to_analysis));
305         }
306       }
307 
308       // Other sub-computations (e.g. Map, Reduce, ...) are skipped; they are
309       // assigned "thread-local" allocations, meaning their buffers are not
310       // allocated up-front at the beginning of the computation.
311     }
312 
313     // Free buffers that are no longer live.  This is the earliest point that we
314     // can de-allocate; right after the last use of the buffer.
315     for (const BufferValue* buffer : dead_buffers_to_free) {
316       VLOG(3) << "  Freeing dead: " << buffer->ToString();
317       Free(buffer, instruction);
318     }
319     for (const BufferValue* buffer : operand_buffers_to_free) {
320       VLOG(3) << "  Freeing operand: " << buffer->ToString();
321       Free(buffer, instruction);
322     }
323   }
324 
325   // Any remaining live buffers must be entry parameters or output source
326   // buffers, which had a nullptr sentry added.  Free them now, in a
327   // deterministic order.
328   std::vector<const BufferValue*> to_free;
329   to_free.reserve(live_buffers.size());
330   for (const auto& buffer_pending : live_buffers) {
331     const BufferValue* buffer = buffer_pending.first;
332     const flat_hash_set<const HloInstruction*>& pending = buffer_pending.second;
333     CHECK_EQ(pending.size(), 1) << *buffer;
334     CHECK(*pending.begin() == nullptr) << *buffer;
335     to_free.push_back(buffer);
336   }
337 
338   absl::c_sort(to_free, [](const BufferValue* x, const BufferValue* y) {
339     return x->id() < y->id();
340   });
341   for (const BufferValue* buffer : to_free) {
342     VLOG(3) << "Freeing pending: " << buffer->ToString();
343     Free(buffer, root);
344   }
345 
346   return Status::OK();
347 }
348 
HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm,const BufferValue::SizeFunction & size_fn,const Options & options,const HloSchedule * schedule,const absl::flat_hash_map<const HloComputation *,int64> * memory_by_computation)349 HeapSimulator::HeapSimulator(
350     std::unique_ptr<HeapAlgorithm> algorithm,
351     const BufferValue::SizeFunction& size_fn, const Options& options,
352     const HloSchedule* schedule,
353     const absl::flat_hash_map<const HloComputation*, int64>*
354         memory_by_computation)
355     : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
356       algorithm_(std::move(algorithm)),
357       size_fn_(size_fn),
358       options_(options),
359       schedule_(schedule),
360       memory_by_computation_(memory_by_computation) {
361   debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
362 }
363 
~HeapSimulator()364 HeapSimulator::~HeapSimulator() {}
365 
IgnoreBuffer(const BufferValue * buffer) const366 bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const {
367   // Buffers for constants are ignored unless the alloc_constants option is
368   // set. Also ignore buffers that we're not meant to assign.
369   //
370   // TODO(b/32248867): For consistency, constants should get allocations.
371   if (!options_.alloc_constants &&
372       buffer->instruction()->opcode() == HloOpcode::kConstant) {
373     return true;
374   }
375   return options_.buffers_to_assign != nullptr &&
376          !options_.buffers_to_assign->contains(buffer);
377 }
378 
379 // Alloc always calls the underlying heap algorithm.
Alloc(const BufferValue * buffer,const HloInstruction * instruction)380 void HeapSimulator::Alloc(const BufferValue* buffer,
381                           const HloInstruction* instruction) {
382   CHECK(!allocated_buffers_.contains(buffer))
383       << "Alloc called on allocated buffer: " << *buffer;
384   CHECK(!freed_buffers_.contains(buffer))
385       << "Alloc called on freed buffer: " << *buffer;
386 
387   allocated_buffers_.insert(buffer);
388   const int64 size = size_fn_(*buffer);
389   algorithm_->Alloc(buffer, size);
390   no_fragmentation_stats_->Alloc(buffer, size);
391   FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction,
392                  nullptr);
393 }
394 
395 // Free calls the underlying algorithm for non-shared buffers, and for shared
396 // buffers whose group liveness has expired.  Shared group liveness is tracked
397 // by maintaining a refcount; the Free call on the last buffer in the group
398 // causes Free to be called on the underlying algorithm.
Free(const BufferValue * buffer,const HloInstruction * instruction)399 void HeapSimulator::Free(const BufferValue* buffer,
400                          const HloInstruction* instruction) {
401   auto shared_it = shared_buffers_.find(buffer);
402   if (shared_it != shared_buffers_.end()) {
403     std::shared_ptr<SharedGroup> group = shared_it->second;
404     --group->refcount;
405     if (group->refcount > 0) {
406       return;
407     }
408     CHECK_EQ(group->refcount, 0)
409         << "Free caused negative refcount on shared buffer: " << *buffer;
410     buffer = group->canonical;
411   }
412 
413   CHECK(allocated_buffers_.contains(buffer))
414       << "Free called on non-allocated buffer: " << *buffer;
415   CHECK(!freed_buffers_.contains(buffer))
416       << "Free called on freed buffer: " << *buffer;
417 
418   freed_buffers_.insert(buffer);
419   const int64 size = size_fn_(*buffer);
420   algorithm_->Free(buffer, size);
421   no_fragmentation_stats_->Free(buffer, size);
422 
423   FillDebugTrace(HeapSimulatorTrace::Event::FREE, buffer, instruction, nullptr);
424 }
425 
426 // ShareBuffer associates buffers with their SharedGroup in shared_buffers_.
427 // The 'buffer' must be a non-allocated, non-freed buffer, just like in calls to
428 // Alloc.  The 'shared' buffer must be a previously allocated or shared buffer.
429 // Both 'buffer' and 'shared' will be associated with the same SharedGroup.
ShareBuffer(const BufferValue * buffer,const BufferValue * shared,const HloInstruction * instruction)430 void HeapSimulator::ShareBuffer(const BufferValue* buffer,
431                                 const BufferValue* shared,
432                                 const HloInstruction* instruction) {
433   CHECK_LE(size_fn_(*buffer), size_fn_(*shared))
434       << "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared;
435   CHECK(!allocated_buffers_.contains(buffer))
436       << "ShareBuffer called on allocated buffer: " << *buffer;
437   CHECK(!freed_buffers_.contains(buffer))
438       << "ShareBuffer called on freed buffer: " << *buffer;
439   CHECK(!freed_buffers_.contains(shared))
440       << "ShareBuffer called on freed shared buffer: " << *shared;
441 
442   const BufferValue* canonical = nullptr;
443   auto shared_it = shared_buffers_.find(shared);
444   if (shared_it != shared_buffers_.end()) {
445     // The 'shared' buffer already has a group; it might be the canonical, but
446     // also might not be.  Just add 'buffer' to the existing group.
447     std::shared_ptr<SharedGroup> group = shared_it->second;
448     canonical = group->canonical;
449     ++group->refcount;
450     shared_buffers_.emplace(buffer, group);
451   } else {
452     // The 'shared' buffer doesn't have a group; it must be the canonical.  Add
453     // both 'buffer' and 'shared' to a new group.
454     CHECK(allocated_buffers_.contains(shared))
455         << "ShareBuffer called on non-allocated shared buffer: " << *shared;
456     auto group = std::make_shared<SharedGroup>();
457     canonical = shared;
458     group->canonical = canonical;
459     group->refcount = 2;
460     shared_buffers_.emplace(buffer, group);
461     shared_buffers_.emplace(shared, group);
462   }
463 
464   FillDebugTrace(HeapSimulatorTrace::Event::SHARE_WITH, buffer, instruction,
465                  canonical);
466 }
467 
Finish()468 HeapSimulator::Result HeapSimulator::Finish() {
469   Result result = algorithm_->Finish();
470 
471   // Post-process the result to add chunks for shared buffers.  An empty chunk
472   // map means that either no buffers were allocated, or the heap was only
473   // collecting statistics, e.g. NoFragmentationStatsHeap.
474   if (!result.chunk_map.empty()) {
475     for (const auto& share_pair : shared_buffers_) {
476       const BufferValue* buffer = share_pair.first;
477       std::shared_ptr<SharedGroup> group = share_pair.second;
478       if (buffer != group->canonical) {
479         // The canonical must already exist in the chunk_map, since we called
480         // Alloc(canonical) on the underlying algorithm.  Add non-canonical
481         // chunks with the same offset as the canonical.
482         Chunk chunk = FindOrDie(result.chunk_map, group->canonical);
483         chunk.size = size_fn_(*buffer);
484         result.chunk_map.emplace(buffer, chunk);
485       }
486     }
487     // If we were told to assign specific buffers, make sure we've assigned
488     // exactly that many buffers.
489     if (options_.buffers_to_assign != nullptr) {
490       CHECK_EQ(options_.buffers_to_assign->size(), result.chunk_map.size());
491     }
492   }
493 
494   // Fragmentation is the difference between the actual and ideal sizes.
495   const Result no_frag_result = no_fragmentation_stats_->Finish();
496   result.fragmentation_size = result.heap_size - no_frag_result.heap_size;
497 
498   // Copy the debug trace we collected to the final result.
499   result.debug_trace.Swap(&debug_trace_);
500 
501   return result;
502 }
503 
FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,const BufferValue * buffer,const HloInstruction * instruction,const BufferValue * share_with_canonical)504 void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind,
505                                    const BufferValue* buffer,
506                                    const HloInstruction* instruction,
507                                    const BufferValue* share_with_canonical) {
508   HeapSimulatorTrace::Event* event = debug_trace_.add_events();
509   event->set_kind(kind);
510   event->set_buffer_id(buffer->id());
511   event->set_computation_name(instruction->parent()->name());
512   event->set_instruction_name(instruction->name());
513   if (kind == HeapSimulatorTrace::Event::SHARE_WITH) {
514     CHECK(share_with_canonical != nullptr);
515     event->set_share_with_canonical_id(share_with_canonical->id());
516   } else {
517     CHECK(share_with_canonical == nullptr);
518   }
519 }
520 
Alloc(const BufferValue * buffer,int64 size)521 void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
522   current_heap_size_ += size;
523   if (current_heap_size_ > max_heap_size_) {
524     max_heap_size_ = current_heap_size_;
525   }
526 }
527 
AccountForSubcomputationMemory(const HloInstruction * instruction,int64 alloc_size_by_instruction,const absl::flat_hash_map<const HloComputation *,int64> & memory_by_computation)528 void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
529     const HloInstruction* instruction, int64 alloc_size_by_instruction,
530     const absl::flat_hash_map<const HloComputation*, int64>&
531         memory_by_computation) {
532   // We only count the memory usage of the largest subcomputation, instead of
533   // adding them all, because subcomputations won't execute in parallel.
534   int64 max_subcomputation_bytes = 0;
535   for (const auto* c : instruction->called_computations()) {
536     auto it = memory_by_computation.find(c);
537     if (it != memory_by_computation.end()) {
538       int64 subcomputation_bytes = it->second;
539       if (subcomputation_bytes > max_subcomputation_bytes) {
540         max_subcomputation_bytes = subcomputation_bytes;
541       }
542     }
543   }
544   if (max_subcomputation_bytes > 0 &&
545       (instruction->opcode() == HloOpcode::kWhile ||
546        instruction->opcode() == HloOpcode::kCall ||
547        instruction->opcode() == HloOpcode::kConditional)) {
548     // The output buffer of while/call/conditional is always aliased with the
549     // output buffer of the root instruction in the body. Don't double count.
550     max_subcomputation_bytes -= alloc_size_by_instruction;
551   }
552   max_heap_size_ =
553       std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
554 }
555 
Free(const BufferValue * buffer,int64 size)556 void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) {
557   current_heap_size_ -= size;
558 }
559 
Finish()560 HeapSimulator::Result NoFragmentationStatsHeap::Finish() {
561   // The result.chunk_map is empty, since we only collect stats, and don't
562   // actually compute chunk assignments.
563   Result result;
564   result.heap_size = max_heap_size_;
565   return result;
566 }
567 
Alloc(const BufferValue * buffer,int64 size)568 void DecreasingSizeRunsHeap::Alloc(const BufferValue* buffer, int64 size) {
569   SetMode(kAlloc);
570   run_.emplace_back(Op{buffer, size});
571 }
572 
Free(const BufferValue * buffer,int64 size)573 void DecreasingSizeRunsHeap::Free(const BufferValue* buffer, int64 size) {
574   CHECK(mode_ != kInit) << "Free called on empty heap: " << *buffer;
575   SetMode(kFree);
576   run_.emplace_back(Op{buffer, size});
577 }
578 
Finish()579 HeapSimulator::Result DecreasingSizeRunsHeap::Finish() {
580   CallAndDrainRun();
581   return algorithm_->Finish();
582 }
583 
SetMode(Mode mode)584 void DecreasingSizeRunsHeap::SetMode(Mode mode) {
585   if (mode_ != mode) {
586     CallAndDrainRun();
587     mode_ = mode;
588   }
589 }
590 
CallAndDrainRun()591 void DecreasingSizeRunsHeap::CallAndDrainRun() {
592   if (mode_ == kInit) {
593     CHECK(run_.empty());
594     return;
595   }
596 
597   // Call ops in the run sorted by decreasing size, breaking ties by buffer id.
598   absl::c_sort(run_, [](const Op& a, const Op& b) {
599     if (a.size != b.size) {
600       return a.size > b.size;
601     }
602     return a.buffer->id() < b.buffer->id();
603   });
604   for (const Op& op : run_) {
605     if (mode_ == kAlloc) {
606       algorithm_->Alloc(op.buffer, op.size);
607     } else {
608       algorithm_->Free(op.buffer, op.size);
609     }
610   }
611   run_.clear();
612 }
613 
Alloc(const BufferValue * buffer,int64 size)614 void LazyBestFitHeap::Alloc(const BufferValue* buffer, int64 size) {
615   // Degenerate case: 0-sized buffers are always allocated at offset 0.
616   if (size == 0) {
617     result_.chunk_map.emplace(buffer, Chunk{0, 0});
618   }
619 
620   // First try to allocate from the best-fitting free chunk.
621   auto best_fit_it = free_.lower_bound(Chunk{0, size});
622   while (best_fit_it != free_.end()) {
623     // Account for alignment.
624     const Chunk best = *best_fit_it;
625     const int64 new_offset = RoundUpToNearest(best.offset, alignment_);
626     const int64 new_end = new_offset + size;
627     if (new_end > best.chunk_end()) {
628       // We don't fit after accounting for alignment.
629       ++best_fit_it;
630       continue;
631     }
632     // The buffer is allocated a chunk out of the best-fitting free chunk.
633     free_.erase(best_fit_it);
634     result_.chunk_map.emplace(buffer, Chunk{new_offset, size});
635     // Add remaining portions of the best-fitting free chunk back into free_.
636     AddFreeChunk(best.offset, new_offset - best.offset);
637     AddFreeChunk(new_end, best.chunk_end() - new_end);
638     return;
639   }
640 
641   // The buffer doesn't completely fit into any existing free chunk.  If the
642   // last free chunk is adjacent to the end of the heap, allocate the buffer
643   // re-using that space, increasing the heap size.
644   //
645   // Allocating the buffer now causes the heap to grow by less than the buffer
646   // size, whereas if we allocated lazily in Free, the heap would grow by
647   // exactly the buffer size.  However it's still a greedy heuristical approach;
648   // we might have ended up with a tighter packing by being lazy here.
649   //
650   // In theory we could also check if we could re-use space from the first free
651   // chunk and grow the heap at the front, and choose whether to grow from the
652   // front or back based on the amount of re-use.  But that's more complicated,
653   // and these are all heuristics anyways, so it isn't implemented.
654   for (auto it = free_.begin(); it != free_.end(); ++it) {
655     if (it->chunk_end() == result_.heap_size) {
656       // Account for alignment in the last free chunk.
657       const Chunk last = *it;
658       const int64 new_offset = RoundUpToNearest(last.offset, alignment_);
659       if (new_offset >= last.chunk_end()) {
660         // There's no point in using the last free chunk if alignment causes us
661         // to skip over it anyways.
662         break;
663       }
664       // The buffer is allocated a chunk that includes the last free chunk.
665       free_.erase(it);
666       result_.chunk_map.emplace(buffer, Chunk{new_offset, size});
667       // Add remaining portion of the last free chunk back into free_.
668       AddFreeChunk(last.offset, new_offset - last.offset);
669       // Grow the heap.
670       const int64 new_end = new_offset + size;
671       CHECK_GT(new_end, result_.heap_size);
672       CHECK_LT(new_end, result_.heap_size + size);
673       result_.heap_size = new_end;
674       return;
675     }
676   }
677 
678   // Otherwise lazily allocate the buffer in Free.
679   result_.chunk_map.emplace(buffer, Chunk{kLazyAllocOffset, size});
680 }
681 
Free(const BufferValue * buffer,int64 size)682 void LazyBestFitHeap::Free(const BufferValue* buffer, int64 size) {
683   auto alloc_it = result_.chunk_map.find(buffer);
684   CHECK(alloc_it != result_.chunk_map.end())
685       << "Free called on non-allocated buffer: " << *buffer;
686   Chunk* alloc = &alloc_it->second;
687   CHECK_EQ(alloc->size, size) << "Free with mismatched sizes: " << *buffer;
688   if (alloc->offset != kLazyAllocOffset) {
689     // The buffer was already allocated in Alloc, do a normal free.
690     AddFreeChunk(alloc->offset, alloc->size);
691   } else {
692     // This buffer is lazily allocated, so we *can not* allocate out of existing
693     // free chunks, since that might cause interference between buffers.  The
694     // buffer is allocated by growing the heap, accounting for alignment.
695     alloc->offset = RoundUpToNearest(result_.heap_size, alignment_);
696     const int64 new_end = alloc->chunk_end();
697     AddFreeChunk(result_.heap_size, new_end - result_.heap_size);
698     CHECK_GT(new_end, result_.heap_size);
699     CHECK_GE(new_end, result_.heap_size + alloc->size);
700     result_.heap_size = new_end;
701   }
702 }
703 
AddFreeChunk(int64 offset,int64 size)704 void LazyBestFitHeap::AddFreeChunk(int64 offset, int64 size) {
705   if (size <= 0) {
706     return;
707   }
708 
709   // Coalesce the chunk with adjacent free chunks on either side.  We must
710   // remove the free chunks from free_, since it's ordered by size.
711   Chunk chunk{offset, size};
712   for (auto it = free_.begin(); it != free_.end();) {
713     if (it->chunk_end() == chunk.offset || it->offset == chunk.chunk_end()) {
714       chunk.offset = std::min(chunk.offset, it->offset);
715       chunk.size += it->size;
716       it = free_.erase(it);
717     } else {
718       ++it;
719     }
720   }
721 
722   // This is the only place we add free chunks to free_.  It maintains the
723   // invariant that all free chunks are disjoint and non-adjacent.
724   free_.emplace(chunk);
725 }
726 
Finish()727 HeapSimulator::Result LazyBestFitHeap::Finish() {
728   if (!free_.empty()) {
729     // When Finish is called, all calls to Alloc must have had corresponding
730     // calls to Free, which will result in a single free chunk [0, heap_size).
731     CHECK_EQ(free_.size(), 1);
732     CHECK_EQ(free_.begin()->offset, 0);
733     CHECK_EQ(free_.begin()->size, result_.heap_size);
734   }
735   return result_;
736 }
737 
Alloc(const BufferValue * buffer,int64 size)738 void GlobalDecreasingSizeBestFitHeap::Alloc(const BufferValue* buffer,
739                                             int64 size) {
740   // Degenerate case: 0-sized buffers are always allocated at offset 0.
741   if (size == 0) {
742     result_.chunk_map.emplace(buffer, Chunk{0, 0});
743     return;
744   }
745   auto emplace_result = buffer_intervals_.emplace(
746       buffer, BufferInterval{buffer, size, current_time_, -1});
747   DCHECK(emplace_result.second);
748   ++current_time_;
749 }
750 
Free(const BufferValue * buffer,int64 size)751 void GlobalDecreasingSizeBestFitHeap::Free(const BufferValue* buffer,
752                                            int64 size) {
753   // Degenerate case: 0-sized buffers are always allocated at offset 0.
754   if (size == 0) {
755     return;
756   }
757   BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
758   DCHECK_EQ(buffer_interval.buffer, buffer);
759   DCHECK_EQ(buffer_interval.size, size);
760   DCHECK_EQ(buffer_interval.end, -1);
761   buffer_interval.end = current_time_;
762   ++current_time_;
763 }
764 
765 namespace {
766 
767 // Node in BufferIntervalTree that stores the alloc and free times of a buffer,
768 // and the chunk assigned to it.
769 struct BufferIntervalTreeNode {
770   // Alloc time.
771   int64 start;
772   // Free time.
773   int64 end;
774   // Maximum free time of all nodes in the subtree where this node is the root.
775   int64 subtree_end;
776   // Allocated chunk for the buffer.
777   HeapSimulator::Chunk chunk;
778   // Left child.
779   BufferIntervalTreeNode* left;
780   // Right child.
781   BufferIntervalTreeNode* right;
782 };
783 
784 // An interval tree that can query buffers overlapping in time.
785 class BufferIntervalTree {
786  public:
BufferIntervalTree(int capacity)787   explicit BufferIntervalTree(int capacity) : node_storage_(capacity) {}
788 
789   using Chunk = HeapSimulator::Chunk;
790 
791   // Adds a buffer to the interval tree, with the time interval and allocated
792   // chunk specified.
Add(int64 start,int64 end,const Chunk & chunk)793   void Add(int64 start, int64 end, const Chunk& chunk) {
794     int index = node_count_;
795     DCHECK_LT(index, node_storage_.size());
796     ++node_count_;
797 
798     node_storage_[index] =
799         BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr};
800 
801     if (index == 0) {
802       // This is root.
803       return;
804     }
805 
806     BufferIntervalTreeNode* parent = &node_storage_[0];
807     while (true) {
808       parent->subtree_end = std::max(parent->subtree_end, end);
809       if (parent->start > start) {
810         if (parent->left == nullptr) {
811           parent->left = &node_storage_[index];
812           return;
813         }
814         parent = parent->left;
815       } else {
816         if (parent->right == nullptr) {
817           parent->right = &node_storage_[index];
818           return;
819         }
820         parent = parent->right;
821       }
822     }
823   }
824 
825   // Returns vector of allocated chunks that overlap with the given time
826   // interval.
ChunksOverlappingInTime(int64 start,int64 end)827   std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) {
828     std::vector<Chunk> result;
829     if (node_count_ == 0) {
830       return result;
831     }
832     std::vector<BufferIntervalTreeNode*> visiting_stack;
833     visiting_stack.push_back(&node_storage_[0]);
834     while (!visiting_stack.empty()) {
835       BufferIntervalTreeNode* top = visiting_stack.back();
836       visiting_stack.pop_back();
837       if (start > top->subtree_end) {
838         continue;
839       }
840       if (top->left != nullptr) {
841         visiting_stack.push_back(top->left);
842       }
843       if (top->start <= end && top->end >= start) {
844         result.push_back(top->chunk);
845       }
846       if (end < top->start) {
847         continue;
848       }
849       if (top->right != nullptr) {
850         visiting_stack.push_back(top->right);
851       }
852     }
853     return result;
854   }
855 
856  private:
857   int64 node_count_ = 0;
858   std::vector<BufferIntervalTreeNode> node_storage_;
859 };
860 
861 }  // namespace
862 
Finish()863 HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
864   std::vector<BufferInterval> sorted_buffer_intervals;
865   for (auto& entry : buffer_intervals_) {
866     sorted_buffer_intervals.push_back(entry.second);
867   }
868   absl::c_sort(sorted_buffer_intervals,
869                [](const BufferInterval& x, const BufferInterval& y) {
870                  if (x.size != y.size) {
871                    return x.size > y.size;
872                  }
873                  if (x.end - x.start != y.end - y.start) {
874                    return x.end - x.start > y.end - y.start;
875                  }
876                  return x.buffer->id() < y.buffer->id();
877                });
878 
879   BufferIntervalTree interval_tree(sorted_buffer_intervals.size());
880   for (auto& buffer_interval : sorted_buffer_intervals) {
881     auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime(
882         buffer_interval.start, buffer_interval.end);
883     absl::c_sort(
884         chunks_overlapping_in_time,
885         [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; });
886 
887     // Find the minimum free chunk that can hold this buffer.
888     Chunk min_fit_chunk{-1, INT64_MAX};
889     auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) {
890       if (free_size < buffer_interval.size) {
891         return;
892       }
893 
894       if (free_size < min_fit_chunk.size) {
895         min_fit_chunk = {free_offset, free_size};
896       }
897     };
898 
899     int64 offset = 0;
900     for (auto& chunk : chunks_overlapping_in_time) {
901       if (offset < chunk.offset) {
902         use_free_chunk_if_smaller(offset, chunk.offset - offset);
903       }
904       offset =
905           std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_));
906     }
907     use_free_chunk_if_smaller(offset, result_.heap_size - offset);
908 
909     if (min_fit_chunk.offset == -1) {
910       // Increase the heap size to fit in the last free chunk.
911       result_.heap_size = offset + buffer_interval.size;
912       min_fit_chunk = {offset, buffer_interval.size};
913     }
914 
915     min_fit_chunk.size = buffer_interval.size;
916     const auto emplace_result =
917         result_.chunk_map.emplace(buffer_interval.buffer, min_fit_chunk);
918     DCHECK(emplace_result.second);
919 
920     interval_tree.Add(buffer_interval.start, buffer_interval.end,
921                       min_fit_chunk);
922   }
923   return result_;
924 }
925 
Finish()926 HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() {
927   DCHECK(!algorithms_.empty());
928   std::vector<Result> results(algorithms_.size());
929   int64 min_size = INT64_MAX;
930   int min_size_index = -1;
931   for (int i = 0; i < algorithms_.size(); ++i) {
932     results[i] = algorithms_[i]->Finish();
933     if (results[i].heap_size < min_size) {
934       min_size = results[i].heap_size;
935       min_size_index = i;
936     }
937   }
938 
939   DCHECK_GE(min_size_index, 0);
940   return results[min_size_index];
941 }
942 
943 }  // namespace xla
944