1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/memory_space_assignment.h"
17 
18 namespace xla {
19 
20 namespace {
21 // Define a dummy chunk for chunks that will be allocated in the default memory
22 // space and for keeping track of the number of asynchronous copies.
23 const HeapSimulator::Chunk kDummyChunk{-1, -1};
24 }  // namespace
25 
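// Note: the compute-bound estimate below is roofline-style: it takes the
// maximum of the flop-limited and the transcendental-limited execution times,
// so whichever functional unit is the bottleneck determines the elapsed time.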
26 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
27     const HloInstruction& instruction) const {
28   return std::max(
29       cost_analysis_.flop_count(instruction) /
30           cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey),
31       cost_analysis_.transcendental_count(instruction) /
32           cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey));
33 }
34 
35 float MemorySpaceAssignmentCostAnalysis::
36     GetInstructionElapsedDueToMemorySlowdown(int64 bytes) const {
37   return bytes /
38          cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
39 }
40 
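// Note: the memory-bound estimate below charges bytes that live in the
// alternate memory at alternate_mem_bandwidth_bytes_per_second_ and the
// remaining bytes at the default kBytesAccessedKey rate. For example (made-up
// numbers), with 100 bytes accessed in total, 40 of them from an operand in
// the alternate memory, a 10 B/s default rate, and a 40 B/s alternate rate,
// the estimate is 40/40 + 60/10 = 7 seconds instead of 100/10 = 10 seconds.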
41 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory(
42     const HloInstruction& instruction,
43     absl::optional<int64> operand_in_alternate_mem,
44     bool output_in_alternate_mem) const {
45   float bytes_accessed = cost_analysis_.bytes_accessed(instruction);
46   float elapsed_due_to_bytes =
47       bytes_accessed /
48       cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
49   if (operand_in_alternate_mem) {
50     // Estimate the elapsed time due to the operand being in the alternate
51     // memory space.
52     float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed(
53         instruction, *operand_in_alternate_mem);
54     float elapsed_due_to_operand_bytes =
55         operand_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_;
56     bytes_accessed -= operand_bytes_accessed;
57     elapsed_due_to_bytes =
58         elapsed_due_to_operand_bytes +
59         bytes_accessed /
60             cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
61   }
62   if (output_in_alternate_mem) {
63     // Estimate the elapsed time due to the output being in the alternate memory
64     // space.
65     float output_bytes_accessed =
66         cost_analysis_.output_bytes_accessed(instruction);
67     float elapsed_due_to_output_bytes =
68         output_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_;
69     bytes_accessed -= output_bytes_accessed;
70     elapsed_due_to_bytes =
71         elapsed_due_to_output_bytes +
72         bytes_accessed /
73             cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
74   }
75   return elapsed_due_to_bytes;
76 }
77 
78 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed(
79     const HloInstruction& instruction,
80     absl::optional<int64> operand_in_alternate_mem,
81     bool output_in_alternate_mem) const {
82   return std::max(
83       GetInstructionElapsedDueToCompute(instruction),
84       GetInstructionElapsedDueToMemory(instruction, operand_in_alternate_mem,
85                                        output_in_alternate_mem));
86 }
87 
88 float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
89     const Shape& shape) const {
90   int64 size_in_bytes = cost_analysis_.GetShapeSize(shape);
91   return static_cast<float>(size_in_bytes) /
92          async_copy_bandwidth_bytes_per_second_;
93 }
94 
95 int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
96   return hlo_live_range_.schedule_end_time();
97 }
98 
99 bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
100     const Shape& shape, int64 start_time, int64 end_time) const {
101   return end_time - start_time <= max_overlap_count_;
102 }
103 
104 int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
105     const Shape& shape, int64 start_time, int64 latest_end_time) const {
106   return std::min(start_time + min_overlap_count_, latest_end_time);
107 }
108 
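// Note: this picker reasons purely in instruction counts. Begin() starts the
// candidate prefetch time at max(start_time, end_time - max_overlap_count_),
// Next() walks it forward one instruction at a time, and Done() stops once no
// more than min_overlap_count_ instructions remain before end_time.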
109 void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
110                                                    int64 start_time,
111                                                    int64 end_time) {
112   end_time_ = end_time;
113   current_prefetch_time_ = std::max(start_time, end_time_ - max_overlap_count_);
114 }
115 
116 int64 InstructionCountPrefetchIntervalPicker::Next() {
117   CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
118                     "Done() returned true";
119   return current_prefetch_time_++;
120 }
121 
122 bool InstructionCountPrefetchIntervalPicker::Done() const {
123   return end_time_ - current_prefetch_time_ <= min_overlap_count_;
124 }
125 
126 std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const {
127   return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_);
128 }
129 
130 std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString(
131     const Shape& shape, int64 start_time, int64 end_time) const {
132   return absl::StrCat("Overlapped HLOs = ", end_time - start_time);
133 }
134 
135 CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
136     const MemorySpaceAssignmentCostAnalysis& cost_analysis,
137     float min_async_copy_to_overlap_ratio,
138     float max_async_copy_to_overlap_ratio)
139     : cost_analysis_(cost_analysis),
140       min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
141       max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio) {
142   instruction_schedule_ =
143       &cost_analysis_.hlo_live_range().instruction_schedule();
144 
145   // First create a vector of elapsed times of HLO instructions.
146   std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
147                                                0.0);
148   for (const auto& instruction_and_logical_time : *instruction_schedule_) {
149     float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
150         *instruction_and_logical_time.first);
151     int64 logical_time = instruction_and_logical_time.second;
152     if (logical_time >= instructions_elapsed_time.size()) {
153       instructions_elapsed_time.resize(logical_time + 1, 0.0);
154     }
155     instructions_elapsed_time[logical_time] = elapsed_time;
156   }
157   // As an optimization, create a cumulative sum vector of elapsed time.
158   float cumsum = 0.0;
159   for (float elapsed_time : instructions_elapsed_time) {
160     cumsum += elapsed_time;
161     elapsed_time_cumsum_.push_back(cumsum);
162   }
163 }
164 
165 bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
166     const Shape& shape, int64 start_time, int64 end_time) const {
167   // Even though this method decides whether we allow the buffer in alternate
168   // memory _without_ asynchronous copies, calculate how long the copy would
169   // have taken and compare it to the elapsed time in the logical interval.
170   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
171   float logical_interval_elapsed =
172       GetLogicalIntervalElapsed(start_time, end_time);
173   return max_async_copy_to_overlap_ratio_ * async_copy_elapsed >
174          logical_interval_elapsed;
175 }
176 
177 int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
178     const Shape& shape, int64 start_time, int64 latest_end_time) const {
179   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
180   int64 end_time;
181   for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) {
182     float logical_interval_elapsed =
183         GetLogicalIntervalElapsed(start_time, end_time);
184     if (logical_interval_elapsed >=
185         min_async_copy_to_overlap_ratio_ * async_copy_elapsed) {
186       break;
187     }
188   }
189   return end_time;
190 }
191 
192 void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
193                                                int64 start_time,
194                                                int64 end_time) {
195   const Shape& shape = use.instruction->operand(use.operand_number)->shape();
196   // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
197   async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape);
198   // Estimate the time we would save by having this op in alternate memory.
199   float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
200   float elapsed_time_in_alternate_mem = cost_analysis_.GetInstructionElapsed(
201       *use.instruction, use.operand_number);
202   inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
203   end_logical_time_ = end_time;
204   // Find the earliest time we're allowed to start prefetching.
205   for (current_logical_prefetch_time_ = start_time;
206        current_logical_prefetch_time_ <= end_logical_time_ &&
207        max_async_copy_to_overlap_ratio_ * async_copy_elapsed_ <
208            GetLogicalIntervalElapsed(current_logical_prefetch_time_,
209                                      end_logical_time_);
210        ++current_logical_prefetch_time_) {
211   }
212 }
213 
214 int64 CostAnalysisPrefetchIntervalPicker::Next() {
215   CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
216                     "Done() returned true";
217   return current_logical_prefetch_time_++;
218 }
219 
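// Note: Done() below compares the remaining logical interval (plus the
// instruction time saved by serving the operand from the alternate memory)
// against min_async_copy_to_overlap_ratio_ * async_copy_elapsed_; once the
// remaining slack is smaller than that threshold, no further prefetch start
// times are offered.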
220 bool CostAnalysisPrefetchIntervalPicker::Done() const {
221   // The end time is inclusive, so we're done if the prefetch time is greater
222   // than that.
223   if (current_logical_prefetch_time_ > end_logical_time_) {
224     return true;
225   }
226   float logical_interval_elapsed = GetLogicalIntervalElapsed(
227       current_logical_prefetch_time_, end_logical_time_);
228   return async_copy_elapsed_ * min_async_copy_to_overlap_ratio_ >
229          logical_interval_elapsed + inst_elapsed_reduction_;
230 }
231 
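// Note: using the cumulative sums built in the constructor, the interval cost
// below is elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time],
// i.e. the total elapsed time of the instructions scheduled strictly between
// start_time and end_time (excluding both endpoint instructions).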
232 float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
233     int64 start_time, int64 end_time) const {
234   return elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time];
235 }
236 
237 std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
238   float logical_interval_elapsed = GetLogicalIntervalElapsed(
239       current_logical_prefetch_time_, end_logical_time_);
240   return absl::StrCat(
241       "Async copy elapsed (s) = ", async_copy_elapsed_,
242       ", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
243       ", logical interval elapsed (s) = ", logical_interval_elapsed);
244 }
245 
246 std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
247     const Shape& shape, int64 start_time, int64 end_time) const {
248   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
249   float logical_interval_elapsed =
250       GetLogicalIntervalElapsed(start_time, end_time);
251   return absl::StrCat(
252       "Async copy elapsed (s) = ", async_copy_elapsed,
253       ", logical interval elapsed (s) = ", logical_interval_elapsed);
254 }
255 
256 std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*>
257 AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
258     const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
259   std::vector<const BufferInterval*> colocated_intervals;
260   std::vector<const BufferInterval*> worklist = {&interval};
261   while (!worklist.empty()) {
262     const BufferInterval* item = worklist.back();
263     worklist.pop_back();
264     colocated_intervals.push_back(item);
265     for (const HloValue* buffer_colocated : item->colocations) {
266       worklist.push_back(&buffer_intervals_.at(buffer_colocated));
267     }
268   }
269 
270   absl::c_stable_sort(colocated_intervals, [&](const BufferInterval* x,
271                                                const BufferInterval* y) {
272     return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end);
273   });
274   return colocated_intervals;
275 }
276 
277 bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
278     const BufferInterval& interval) const {
279   // If the buffer is a tuple, don't use this algorithm for now. The buffers
280   // that are pointed to by the tuple will still use this algorithm.  Because
281   // tuples are cheap to place in the alternate memory (they are just pointers)
282   // we don't need to use prefetch/evict logic.
283   if (interval.buffer->shape().IsTuple()) {
284     VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
285             << " in default mem because it is a tuple.";
286     return false;
287   }
288 
289   // The semantics of TupleSelect are weird: TupleSelect doesn't define a
290   // buffer, but just forwards the buffers on either the left or the right side.
291   // This means the two different inputs to TupleSelect must not alias, yet
292   // they should be allocated in the same memory space, and both buffers must be
293   // kept alive for the entire live range of TupleSelect. Instead, just don't
294   // allocate TupleSelect in the alternate memory space.
295   // TODO(berkin): Not allocating add-dependencies either since they need to be
296   // treated specially. We should revisit this later.
297   for (const HloPosition& position : interval.buffer->positions()) {
298     if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
299         position.instruction->opcode() == HloOpcode::kAddDependency) {
300       VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
301               << " in default mem because it has a tuple-select or "
302               << "add-dependency position.";
303       return false;
304     }
305   }
306 
307   // Send and Recv HLOs return a request identifier. These should not be
308   // allocated in the alternate memory.
309   const HloPosition& defining_position = interval.buffer->defining_position();
310   if ((defining_position.instruction->opcode() == HloOpcode::kSend ||
311        defining_position.instruction->opcode() == HloOpcode::kRecv) &&
312       defining_position.index == ShapeIndex({1})) {
313     VLOG(4)
314         << "Keeping value " << interval.buffer->ToShortString()
315         << " in default mem because it is a request identifier for send/recv.";
316     return false;
317   }
318 
319   return true;
320 }
321 
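// Note: Finish() below is the main driver: it sorts the buffer intervals,
// skips intervals that need no allocation or are disallowed in the alternate
// memory, colors intervals that are already reserved in the alternate memory,
// and for every remaining colocated value walks its uses through
// FindAllocation(), falling back to default memory (by clearing the pending
// state) whenever an allocation attempt fails.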
322 HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
323   std::vector<BufferInterval> sorted_buffer_intervals =
324       GetSortedBufferIntervals();
325 
326   VLOG(1) << "Assigning buffers to alternate memory. Max heap size = "
327           << options_.max_size_in_bytes;
328 
329   AddInputAndOutputRequiredAssignments();
330 
331   for (auto& interval : sorted_buffer_intervals) {
332     if (!interval.need_allocation) {
333       continue;
334     }
335 
336     if (!IsIntervalAllowedInAlternateMemory(interval)) {
337       continue;
338     }
339 
340     auto colocated_intervals = GetSortedColocatedIntervals(interval);
341 
342     if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
343       VLOG(4) << "Interval " << interval.buffer->ToShortString()
344               << " is reserved in the alternate memory. Total reserved bytes = "
345               << reserved_in_bytes_;
346       for (const BufferInterval* colocated_interval : colocated_intervals) {
347         const HloValue* value = colocated_interval->buffer;
348         // Color all of the aliased reserved buffers here because reserved
349         // alternate memory allocations will not have an entry in preset
350         // allocations that is normally used for coloring.
351         for (auto& position : value->positions()) {
352           VLOG(3) << "Coloring " << position.ToString();
353           Shape* shape = ShapeUtil::GetMutableSubshape(
354               position.instruction->mutable_shape(), position.index);
355           CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
356                                   << position.ToString();
357           shape->mutable_layout()->set_memory_space(
358               options_.alternate_memory_space);
359         }
360       }
361       // Increment the reserved part of alternate memory so that it is not
362       // available for other buffers. Since all colocated intervals should have
363       // the same size, just use the first one.
364       reserved_in_bytes_ += options_.size_fn(*colocated_intervals[0]->buffer);
365       continue;
366     }
367 
368     if (colocated_intervals.size() > 1 &&
369         !options_.allocate_across_sequential_calls) {
370       VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
371               << " because it aliases with another interval and "
372               << "allocate_across_sequential_calls is false.";
373       continue;
374     }
375 
376     const HloComputation* defining_computation =
377         colocated_intervals[0]->buffer->defining_instruction()->parent();
378     MemorySpaceAssignment::Allocation* aliased_allocation = nullptr;
379     for (const BufferInterval* colocated_interval : colocated_intervals) {
380       const HloValue* value = colocated_interval->buffer;
381       const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
382       allocation_sequence_list_->push_back({value, {}});
383       MemorySpaceAssignment::AllocationSequence* allocation_sequence =
384           &allocation_sequence_list_->back().sequence;
385       int64 definition_time =
386           instruction_schedule.at(value->defining_instruction());
387       // Sort the uses by the use time.
388       std::vector<HloUse> uses = value->uses();
389       absl::c_stable_sort(uses, [&](HloUse use1, HloUse use2) {
390         return instruction_schedule.at(use1.instruction) <
391                instruction_schedule.at(use2.instruction);
392       });
393 
394       // If there was an aliased allocation for this buffer, propagate that for
395       // this HloValue.
396       if (aliased_allocation != nullptr) {
397         VLOG(3) << "Adding an aliased allocation: ("
398                 << aliased_allocation->start_time() << ", "
399                 << aliased_allocation->end_time()
400                 << ") pos: " << aliased_allocation->defining_position()
401                 << " mem space: "
402                 << (aliased_allocation->memory_space() == MemorySpace::kDefault
403                         ? "default"
404                         : "alt");
405         allocation_sequence->push_back(
406             absl::make_unique<MemorySpaceAssignment::Allocation>(
407                 value->defining_instruction(), value->defining_position(),
408                 aliased_allocation->memory_space(), aliased_allocation->chunk(),
409                 definition_time, definition_time));
410       }
411 
412       // Iterate over the uses.
413       for (HloUse use : uses) {
414         int64 use_time = instruction_schedule.at(use.instruction);
415         int64 last_use_time = instruction_schedule.at(uses.back().instruction);
416         int64 latest_prefetch_time = use_time;
417 
418         if (use.instruction->parent() != defining_computation) {
419           VLOG(3) << "skip use " << use.ToString()
420                   << " because it's in a different computation.";
421           continue;
422         }
423 
424         // Sequential calls include kWhile, kCall, and kConditional opcodes.
425         bool is_sequential_call =
426             (GetInstructionCallContext(use.instruction->opcode()) ==
427              CallContext::kSequential);
428         if (is_sequential_call) {
429           for (const HloComputation* called_computation :
430                use.instruction->called_computations()) {
431             const HloLiveRange::TimeBound& computation_span =
432                 hlo_live_range_.computation_span_times().at(called_computation);
433             latest_prefetch_time =
434                 std::min(computation_span.start, latest_prefetch_time);
435           }
436         }
437 
438         // Bitcasts don't define buffers and don't directly consume buffers.
439         // Skip allocating buffers for bitcast uses. The uses that feed from
440         // bitcasts will be handled specially.
441         if (use.instruction->opcode() != HloOpcode::kBitcast) {
442           if (!FindAllocation(definition_time, use_time, last_use_time,
443                               latest_prefetch_time, value->defining_position(),
444                               use, value, colocated_interval->size,
445                               allocation_sequence)) {
446             // If the allocation finding failed (e.g., due to running out of
447             // asynchronous copies), then fall back to allocating the buffer
448             // entirely in the default memory.
449             pending_chunks_.clear();
450             pending_async_copies_.clear();
451             allocation_sequence->clear();
452             break;
453           }
454 
455           // If there are multiple uses, they can try using the memory
456           // allocation already at the alternate memory.
457           definition_time = use_time;
458         }
459 
460         // If the use is a sequential call (e.g. a while loop), the other
461         // colocated intervals must alias with this allocation.
462         if (is_sequential_call) {
463           aliased_allocation =
464               GetLiveAllocationAt(*allocation_sequence, use_time);
465         }
466       }
467     }
468 
469     CommitPendingChunks();
470   }
471 
472   if (VLOG_IS_ON(3)) {
473     for (const auto& value_and_sequence : *allocation_sequence_list_) {
474       VLOG(3) << "Allocation for " << value_and_sequence.value->ToShortString();
475       for (const auto& alloc : value_and_sequence.sequence) {
476         std::string addr_str = ": default";
477         if (alloc->memory_space() == MemorySpace::kAlternate) {
478           addr_str = absl::StrCat(": alt ", alloc->chunk().offset);
479         }
480 
481         VLOG(3) << "  " << alloc->start_time() << "-" << alloc->end_time()
482                 << addr_str << ", " << alloc->uses().size() << " uses";
483       }
484     }
485   }
486 
487   return result_;
488 }
489 
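// Note: this comparison defines a strict ordering only for copies whose
// intervals do not "cross": if one copy starts earlier but ends later than
// another (i.e. its interval strictly contains the other's), neither compares
// less than the other, so a std::set ordered by this operator treats them as
// equivalent. AsynchronousCopyOrdering::ViolatesOrdering below relies on that
// equivalence to detect such crossing (out-of-order) copies.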
490 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
491   return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
492          (a.start_time <= b.start_time && a.end_time < b.end_time);
493 }
494 
495 void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) {
496   auto it_and_inserted = ranges_.insert(copy);
497   CHECK(it_and_inserted.second ||
498         it_and_inserted.first->start_time == copy.start_time);
499 }
500 
501 bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time,
502                                                 int64 end_time) const {
503   // We allow identical start and end times. It is enough to check for just the
504   // start time in case we find a match in ranges_ because the found value will
505   // either be identical to {start_time, end_time} (and this doesn't violate) or
506   // its start_time will be smaller and end_time will be larger (this violates).
507   auto copy_it = ranges_.find(
508       {start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate});
509   return copy_it != ranges_.end() && copy_it->start_time != start_time;
510 }
511 
512 /*static*/ MemorySpaceAssignment::Allocation*
513 AlternateMemoryBestFitHeap::GetLiveAllocationAt(
514     const MemorySpaceAssignment::AllocationSequence& allocations, int64 time) {
515   for (auto allocation_it = allocations.rbegin();
516        allocation_it != allocations.rend(); ++allocation_it) {
517     if ((*allocation_it)->start_time() <= time &&
518         (*allocation_it)->end_time() >= time) {
519       return allocation_it->get();
520     }
521   }
522   return nullptr;
523 }
524 
525 void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
526   // Go through the parameters and outputs and pin them to the corresponding
527   // memory by adding a required assignment.
528   const HloModule& module = alias_analysis_.dataflow_analysis().module();
529   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
530   HloComputation* entry_computation = module.entry_computation();
531   for (HloInstruction* parameter_instruction :
532        entry_computation->parameter_instructions()) {
533     int64 parameter_instruction_time =
534         instruction_schedule.at(parameter_instruction);
535     ShapeUtil::ForEachSubshape(
536         parameter_instruction->shape(),
537         [&](const Shape& subshape, const ShapeIndex& index) {
538           MemorySpace memory_space = MemorySpace::kDefault;
539           if (subshape.has_layout() && subshape.layout().memory_space() ==
540                                            options_.alternate_memory_space) {
541             memory_space = MemorySpace::kAlternate;
542           }
543           for (const HloBuffer* buffer :
544                alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
545             for (const HloValue* value : buffer->values()) {
546               VLOG(3) << "Adding required assignment for parameter value = "
547                       << value->ToShortString()
548                       << " time = " << parameter_instruction_time << " space = "
549                       << (memory_space == MemorySpace::kDefault ? "def"
550                                                                 : "alt");
551               required_assignments_[value].push_back(
552                   {memory_space, /*time=*/parameter_instruction_time});
553             }
554           }
555         });
556   }
557   HloInstruction* root_instruction = entry_computation->root_instruction();
558   int64 root_instruction_time = instruction_schedule.at(root_instruction);
559   ShapeUtil::ForEachSubshape(
560       root_instruction->shape(),
561       [&](const Shape& subshape, const ShapeIndex& index) {
562         MemorySpace memory_space = MemorySpace::kDefault;
563         if (subshape.has_layout() && subshape.layout().memory_space() ==
564                                          options_.alternate_memory_space) {
565           memory_space = MemorySpace::kAlternate;
566         }
567         for (const HloBuffer* buffer :
568              alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
569           for (const HloValue* value : buffer->values()) {
570             VLOG(3) << "Adding required assignment for output value = "
571                     << value->ToShortString()
572                     << " time = " << root_instruction_time << " space = "
573                     << (memory_space == MemorySpace::kDefault ? "def" : "alt");
574             required_assignments_[value].push_back(
575                 {memory_space, /*time=*/root_instruction_time});
576           }
577         }
578       });
579 }
580 
581 bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory(
582     absl::Span<const BufferInterval* const> colocated_intervals) const {
583   auto is_position_in_alternate_memory = [&](const HloPosition& position) {
584     const Shape& shape = position.shape();
585     return shape.has_layout() &&
586            shape.layout().memory_space() == options_.alternate_memory_space;
587   };
588 
589   const HloModule& module = alias_analysis_.dataflow_analysis().module();
590   const HloComputation* entry_computation = module.entry_computation();
591   const HloInstruction* root_instruction =
592       entry_computation->root_instruction();
593   for (const BufferInterval* colocated_interval : colocated_intervals) {
594     const HloValue* value = colocated_interval->buffer;
595     if (value->defining_instruction()->opcode() == HloOpcode::kParameter &&
596         value->defining_instruction()->parent() == entry_computation &&
597         is_position_in_alternate_memory(value->defining_position())) {
598       return true;
599     }
600 
601     for (const HloPosition& position : value->positions()) {
602       if (position.instruction == root_instruction &&
603           is_position_in_alternate_memory(position)) {
604         return true;
605       }
606     }
607   }
608   return false;
609 }
610 
611 void AlternateMemoryBestFitHeap::CommitPendingChunks() {
612   for (auto interval_and_chunk : pending_chunks_) {
613     VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-"
614             << interval_and_chunk.first.end << " : ["
615             << interval_and_chunk.second.chunk.offset << ", "
616             << interval_and_chunk.second.chunk.size << "]";
617     CommitChunk(interval_and_chunk.first, interval_and_chunk.second);
618   }
619   pending_chunks_.clear();
620   // Also add the pending async copies to the interval tree.
621   for (const auto& interval : pending_async_copies_) {
622     if (options_.max_outstanding_async_copies >= 0) {
623       async_copy_interval_tree_.Add(interval.start_time, interval.end_time,
624                                     kDummyChunk);
625     }
626     if (interval.destination == MemorySpace::kAlternate) {
627       async_copy_ordering_.AddCopy(interval);
628     }
629   }
630   pending_async_copies_.clear();
631 }
632 
633 void AlternateMemoryBestFitHeap::AddToPendingChunks(
634     const BufferInterval& buffer_interval,
635     const ChunkCandidate& chunk_candidate) {
636   pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
637 }
638 
639 bool AlternateMemoryBestFitHeap::RequiredInDefaultMemory(const HloValue* buffer,
640                                                          int64 time) const {
641   auto required_assignment_it = required_assignments_.find(buffer);
642   return required_assignment_it != required_assignments_.end() &&
643          absl::c_any_of(
644              required_assignment_it->second,
645              [&](const RequiredMemoryAssignment& required_assignment) {
646                return required_assignment.memory_space ==
647                           MemorySpace::kDefault &&
648                       required_assignment.time == time;
649              });
650 }
651 
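// Note: FindAllocation() below tries, in order: (1) keeping the value in the
// alternate memory for the whole interval without any copy, (2) if the value
// currently lives in the alternate memory, evicting it to default memory with
// an asynchronous copy, and (3) prefetching it back into the alternate memory
// ahead of the use at a start time chosen by the prefetch interval picker. If
// prefetching is not possible, the use is served from the default-memory
// allocation; the function only returns false when a required eviction cannot
// be scheduled.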
652 bool AlternateMemoryBestFitHeap::FindAllocation(
653     int64 start_time, int64 end_time, int64 last_use_time,
654     int64 latest_prefetch_time, HloPosition defining_position, HloUse use,
655     const HloValue* buffer, int64 size,
656     MemorySpaceAssignment::AllocationSequence* allocations) {
657   HloInstruction* operand =
658       use.instruction->mutable_operand(use.operand_number);
659   // If the operand is a bitcast, we look at bitcast's operand until we find a
660   // non-bitcast operand.
661   HloInstruction* non_bitcast_operand = operand;
662   while (non_bitcast_operand->opcode() == HloOpcode::kBitcast) {
663     non_bitcast_operand = non_bitcast_operand->mutable_operand(0);
664   }
665   // Create an alternate memory interval that starts at the earliest
666   // possible position, given by max_prefetch_interval.
667   BufferInterval alternate_mem_interval;
668   alternate_mem_interval.buffer = buffer;
669   alternate_mem_interval.size = size;
670   alternate_mem_interval.end = end_time;
671 
672   // start_time == end_time is a special case where the value is consumed
673   // multiple times by the same instruction. We can just find the previous
674   // allocation and use that allocation.
675   if (start_time == end_time) {
676     MemorySpaceAssignment::Allocation* allocation =
677         GetLiveAllocationAt(*allocations, end_time);
678     CHECK_NE(allocation, nullptr);
679     allocation->AddUse(use);
680     return true;
681   }
682 
683   VLOG(2) << "Finding allocation for " << buffer->ToShortString() << " ("
684           << start_time << ", " << end_time
685           << ") latest prefetch = " << latest_prefetch_time
686           << " last use = " << last_use_time << " use = " << use.ToString()
687           << ". Size = " << size
688           << ", def pos = " << defining_position.ToString()
689           << ", operand = " << operand->ToShortString()
690           << (non_bitcast_operand != operand
691                   ? ", non_bitcast_operand = " +
692                         non_bitcast_operand->ToShortString()
693                   : "");
694   CHECK_LE(start_time, end_time);
695 
696   // There could be a requirement to pin this buffer to default memory either
697   // because it is a parameter or an output.  If the buffer is a parameter, then
698   // we're allowed to prefetch. If the use expects the output to be in default
699   // memory, we cannot prefetch it because if we did, it would be in alternate
700   // memory instead.
701   bool in_default_mem_at_start = RequiredInDefaultMemory(buffer, start_time);
702   bool in_default_mem_at_end = RequiredInDefaultMemory(buffer, end_time);
703 
704   // First try keeping the allocation entirely in the alternate memory.
705   if (!in_default_mem_at_start && !in_default_mem_at_end &&
706       TryAllocatingInAlternateMemoryNoCopy(
707           start_time, end_time, last_use_time, defining_position, use,
708           alternate_mem_interval, non_bitcast_operand, allocations)) {
709     return true;
710   }
711 
712   auto prev_allocation_it = allocations->rbegin();
713   // Find a previous allocation that is in the default memory space (not
714   // necessarily the very last allocation).
715   auto prev_allocation_in_default_mem_it = std::find_if(
716       allocations->rbegin(), allocations->rend(), [&](const auto& allocation) {
717         return allocation->memory_space() == MemorySpace::kDefault &&
718                allocation->defining_position() == defining_position;
719       });
720 
721   if (prev_allocation_in_default_mem_it == allocations->rend() &&
722       prev_allocation_it != allocations->rend() &&
723       (*prev_allocation_it)->memory_space() == MemorySpace::kAlternate &&
724       (*prev_allocation_it)->defining_position() == defining_position) {
725     // If there was an allocation for this HloValue that was in the alternate
726     // memory space, we also need to perform an eviction.
727     int64 eviction_start_time = (*prev_allocation_it)->start_time();
728     int64 eviction_end_time = (*prev_allocation_it)->end_time();
729     CHECK(eviction_start_time <= eviction_end_time);
730 
731     int64 preferred_eviction_end_time = std::max(
732         options_.prefetch_interval_picker->PreferredEvictionEndTime(
733             non_bitcast_operand->shape(), eviction_start_time, end_time),
734         eviction_end_time);
735 
736     BufferInterval eviction_mem_interval;
737     eviction_mem_interval.buffer = buffer;
738     eviction_mem_interval.size = size;
739     // Try to reserve a buffer from the end of the previous allocation to the
740     // preferred eviction end time.
741     eviction_mem_interval.start = eviction_end_time + 1;
742     eviction_mem_interval.end = preferred_eviction_end_time;
743     int64 preferred_offset = (*prev_allocation_it)->chunk().offset;
744     VLOG(4) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
745             << ") preferred end time = " << eviction_mem_interval.end;
746 
747     for (; eviction_mem_interval.end > eviction_end_time;
748          --eviction_mem_interval.end) {
749       ChunkCandidate chunk_candidate =
750           FindChunkCandidate(eviction_mem_interval, preferred_offset);
751       if (chunk_candidate.chunk.offset == preferred_offset) {
752         AddToPendingChunks(eviction_mem_interval, chunk_candidate);
753         break;
754       }
755     }
756     eviction_end_time = eviction_mem_interval.end;
757 
758     VLOG(3) << "Evicting buffer at " << (*prev_allocation_it)->chunk().offset
759             << " (" << eviction_start_time << ", " << eviction_end_time << ")";
760 
761     bool eviction_interval_too_short =
762         (eviction_start_time == eviction_end_time);
763     bool eviction_violates_outstanding_copies =
764         ViolatesMaximumOutstandingAsyncCopies(eviction_start_time,
765                                               eviction_end_time);
766 
767     // See if this interval would violate the asynchronous copy limit.
768     if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) {
769       (*prev_allocation_it)->Extend(eviction_end_time);
770       AddAsyncCopy(**prev_allocation_it, MemorySpace::kDefault, kDummyChunk,
771                    eviction_start_time, (*prev_allocation_it)->end_time(),
772                    eviction_end_time, allocations);
773     } else {
774       if (eviction_violates_outstanding_copies) {
775         VLOG(3) << "This violates the maximum async copies.";
776       } else {
777         VLOG(3) << "Eviction interval is too short (" << eviction_start_time
778                 << ", " << eviction_end_time << ").";
779       }
780       // If the original interval violated the limit, try sub-intervals within
781       // this interval.
782       bool eviction_scheduled = false;
783       for (int64 time = eviction_start_time; time < eviction_end_time; ++time) {
784         VLOG(3) << "Try evicting (" << time << ", " << time + 1 << ")";
785         if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1)) {
786           VLOG(3) << "Eviction successful.";
787           AddAsyncCopy(**prev_allocation_it, MemorySpace::kDefault, kDummyChunk,
788                        time, time + 1, time + 1, allocations);
789           eviction_scheduled = true;
790           break;
791         }
792       }
793 
794       if (!eviction_scheduled) {
795         // If the eviction couldn't be scheduled, then fail. This buffer will be
796         // kept in the default memory.
797         VLOG(3) << "Bailing: Could not evict " << use.ToString()
798                 << " because we hit the limit of maximum asynchronous copies "
799                 << "between "
800                 << hlo_live_range_.flattened_instruction_sequence()
801                        .instructions()[eviction_start_time]
802                 << " and "
803                 << hlo_live_range_.flattened_instruction_sequence()
804                        .instructions()[eviction_end_time];
805         return false;
806       }
807     }
808     prev_allocation_in_default_mem_it = allocations->rbegin();
809   } else if (prev_allocation_in_default_mem_it == allocations->rend()) {
810     allocations->push_back(absl::make_unique<MemorySpaceAssignment::Allocation>(
811         non_bitcast_operand, defining_position, MemorySpace::kDefault,
812         kDummyChunk, start_time, end_time));
813     prev_allocation_in_default_mem_it = allocations->rbegin();
814   }
815 
816   CHECK(prev_allocation_in_default_mem_it != allocations->rend());
817   CHECK((*prev_allocation_in_default_mem_it)->memory_space() ==
818         MemorySpace::kDefault);
819 
820   // If the buffer must be in default memory at the end_time, don't prefetch.
821   if (in_default_mem_at_end) {
822     VLOG(4)
823         << "Not trying to prefetch because use requires buffer in default mem.";
824     (*prev_allocation_in_default_mem_it)->Extend(end_time);
825     (*prev_allocation_in_default_mem_it)->AddUse(use);
826     return true;
827   }
828 
829   // Try partially placing the buffer in the alternate space. The time that is
830   // overlapped will be used to asynchronously copy the buffer from the
831   // default memory to the alternate memory.
832   //
833   //                      start                 end
834   //                      time                  time
835   //                      X---------------------X
836   // Alternate:                          +------+
837   // Default:             +---------------------+
838   //                                     ^      ^
839   //                                   Copy    Copy
840   //                                   Start   Done
841   options_.prefetch_interval_picker->Begin(
842       use, (*prev_allocation_in_default_mem_it)->earliest_available_time(),
843       latest_prefetch_time);
844   VLOG(4) << "Trying prefetch picker = "
845           << options_.prefetch_interval_picker->ToDebugString();
846   while (!options_.prefetch_interval_picker->Done()) {
847     alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
848     VLOG(4) << "Trying alternate memory allocation ("
849             << alternate_mem_interval.start << ", "
850             << alternate_mem_interval.end << ")";
851     // If this additional asynchronous copy would violate the limit, try a
852     // different interval.
853     if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start,
854                                               alternate_mem_interval.end)) {
855       VLOG(4) << "This would violate the outstanding async copy limit.";
856       continue;
857     }
858     if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
859                                   alternate_mem_interval.end)) {
860       VLOG(4) << "This would violate asynchronous copy ordering.";
861       continue;
862     }
863 
864     ChunkCandidate chunk_candidate = FindChunkCandidate(alternate_mem_interval);
865     // Check if the new heap size fits within limits.
866     if (chunk_candidate.heap_size < available_heap_size()) {
867       VLOG(3) << "Move the buffer to alternate memory at "
868               << alternate_mem_interval.start
869               << ". Offset = " << chunk_candidate.chunk.offset
870               << ", size = " << chunk_candidate.chunk.size
871               << ", heap_size = " << chunk_candidate.heap_size
872               << ", prefetch picker = "
873               << options_.prefetch_interval_picker->ToDebugString();
874       AddToPendingChunks(alternate_mem_interval, chunk_candidate);
875 
876       AddAsyncCopy(**prev_allocation_in_default_mem_it, MemorySpace::kAlternate,
877                    chunk_candidate.chunk, alternate_mem_interval.start,
878                    end_time, latest_prefetch_time, allocations);
879 
880       allocations->back()->AddUse(use);
881       return true;
882     }
883   }
884 
885   // If a copy wasn't inserted, then add this use to the latest allocation in
886   // default memory.
887   (*prev_allocation_in_default_mem_it)->Extend(end_time);
888   (*prev_allocation_in_default_mem_it)->AddUse(use);
889   return true;
890 }
891 
892 void AlternateMemoryBestFitHeap::AddAsyncCopy(
893     const MemorySpaceAssignment::Allocation& prev_allocation,
894     MemorySpace memory_space, Chunk chunk, int64 start_time, int64 end_time,
895     int64 copy_done_schedule_before_time,
896     MemorySpaceAssignment::AllocationSequence* allocations) {
897   VLOG(3) << "Copy to "
898           << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
899                   ? "default"
900                   : "alternate")
901           << " memory between " << start_time << " and "
902           << copy_done_schedule_before_time << " keeping until " << end_time;
903 
904   allocations->push_back(
905       absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
906           prev_allocation, memory_space, chunk, start_time, end_time,
907           copy_done_schedule_before_time));
908 
909   // Register the additional async copy with the interval tree to keep track of
910   // the limit at any given time.
911   pending_async_copies_.push_back({start_time, end_time, memory_space});
912 }
913 
914 bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
915     int64 start_time, int64 end_time) const {
916   if (options_.max_outstanding_async_copies < 0) {
917     return false;
918   }
919 
920   // Count both the asynchronous copies in the interval tree as well as the
921   // pending asynchronous copies belonging to this buffer.
922   int64 num_async_copies =
923       async_copy_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
924           .size();
925 
926   for (const auto& interval : pending_async_copies_) {
927     if (interval.start_time > start_time && interval.end_time < end_time) {
928       num_async_copies++;
929     }
930   }
931   // Add one because we are checking if adding an additional asynchronous copy
932   // would violate the limit.
933   return num_async_copies + 1 > options_.max_outstanding_async_copies;
934 }
935 
936 bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(
937     int64 start_time, int64 end_time) const {
938   if (async_copy_ordering_.ViolatesOrdering(start_time, end_time)) {
939     return true;
940   }
941 
942   // Also check pending async copies.
943   for (const auto& async_copy : pending_async_copies_) {
944     if (async_copy.destination == MemorySpace::kAlternate &&
945         async_copy.start_time <= end_time &&
946         start_time <= async_copy.end_time) {
947       return true;
948     }
949   }
950   return false;
951 }
952 
953 bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy(
954     int64 start_time, int64 end_time, int64 last_use_time,
955     HloPosition defining_position, HloUse use,
956     BufferInterval alternate_mem_interval, HloInstruction* non_bitcast_operand,
957     MemorySpaceAssignment::AllocationSequence* allocations) {
958   MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
959   bool can_eliminate_copy = false;
960   if (allocations->empty()) {
961     // There hasn't been any allocations for this interval so far. We can
962     // eliminate copy if the value can be placed in the alternate memory.
963     can_eliminate_copy =
964         options_.is_allowed_in_alternate_mem_fn(*alternate_mem_interval.buffer);
965   } else {
966     // If there has been a previous allocation, we can eliminate the copy if the
967     // previous allocation was also in the alternate memory.
968     prev_allocation = allocations->back().get();
969     can_eliminate_copy =
970         (prev_allocation->memory_space() == MemorySpace::kAlternate);
971   }
972 
973   if (!can_eliminate_copy) {
974     return false;
975   }
976 
977   if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
978           non_bitcast_operand->shape(), start_time + 1, end_time)) {
979     return false;
980   }
981 
982   alternate_mem_interval.start = start_time;
983 
984   // Prefer the offset that was previously used for the previous allocation.
985   absl::optional<int64> preferred_offset;
986   if (prev_allocation != nullptr) {
987     preferred_offset = prev_allocation->chunk().offset;
988     // If there is a previous allocation, set the start time one after the end
989     // of the previous allocation's end.
990     alternate_mem_interval.start = prev_allocation->end_time() + 1;
991   }
992 
993   VLOG(4) << "We can eliminate copy to alternate memory. Preferred offset = "
994           << (preferred_offset ? *preferred_offset : -1);
995   // In case there are additional uses after this use, we rely on the last use
996   // time to try to reserve a chunk in the heap simulator. This is to prevent
997   // the following scenario:
998   //
999   //                            +-------+
1000   //                           /         \
1001   //                   Producer--->Use1   +-->Use2
1002   //                       +---------+---------+
1003   // New buffer:           |         |         |
1004   //                       +---------+---------+
1005   //
1006   //                                     +-----------+
1007   // Current heap:                       | offset: 0 |
1008   //           --------------------------+-----------+------
1009   //
1010   // Because we allocate buffers greedily, Producer to Use1 segment first, and
1011   // then Use1 to Use2 segment, it is possible to allocate the first segment at
1012   // an offset that is available for the first segment (e.g. offset 0) but not
1013   // for the entire live range. This can result in unnecessary copies. By using
1014   // the last use time, we try to find an allocation that is available for the
1015   // entire Producer to Use2 range.
1016   absl::optional<ChunkCandidate> chunk_candidate = FindBestNoCopyChunkCandidate(
1017       end_time, last_use_time, preferred_offset, &alternate_mem_interval);
1018   // Check if the new heap size fits within limits. Also ensure if a
1019   // preferred offset was provided, that offset was used.
1020   if (chunk_candidate) {
1021     VLOG(3) << "Keep the buffer in alternate memory. Offset = "
1022             << chunk_candidate->chunk.offset
1023             << ", size = " << chunk_candidate->chunk.size
1024             << ", heap_size = " << chunk_candidate->heap_size
1025             << ", prefetch picker = "
1026             << options_.prefetch_interval_picker->ToNoCopyDebugString(
1027                    non_bitcast_operand->shape(), start_time, end_time);
1028     AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
1029 
1030     // If there was a previous allocation, the buffer location is the
1031     // same as the previous. Otherwise, it is the operand.
1032     if (prev_allocation != nullptr &&
1033         (prev_allocation->is_copy_allocation() ||
1034          prev_allocation->defining_position() == defining_position)) {
1035       prev_allocation->Extend(end_time);
1036     } else {
1037       allocations->push_back(
1038           absl::make_unique<MemorySpaceAssignment::Allocation>(
1039               non_bitcast_operand, defining_position, MemorySpace::kAlternate,
1040               chunk_candidate->chunk, start_time, end_time));
1041     }
1042     allocations->back()->AddUse(use);
1043     return true;
1044   }
1045   return false;
1046 }
1047 
1048 absl::optional<AlternateMemoryBestFitHeap::ChunkCandidate>
1049 AlternateMemoryBestFitHeap::FindBestNoCopyChunkCandidate(
1050     int64 end_time, int64 last_use_time, absl::optional<int64> preferred_offset,
1051     BufferInterval* alternate_mem_interval) const {
1052   if (!preferred_offset) {
1053     // Find a chunk that's as long living as possible.
1054     for (alternate_mem_interval->end = last_use_time;
1055          alternate_mem_interval->end >= end_time;
1056          --alternate_mem_interval->end) {
1057       ChunkCandidate chunk_candidate =
1058           FindChunkCandidate(*alternate_mem_interval);
1059       if (chunk_candidate.heap_size <= available_heap_size()) {
1060         alternate_mem_interval->end = end_time;
1061         return chunk_candidate;
1062       }
1063     }
1064     return absl::nullopt;
1065   }
1066   // If a preferred offset is given, try to find an allocation at that offset
1067   // only.
1068   alternate_mem_interval->end = end_time;
1069   ChunkCandidate chunk_candidate =
1070       FindChunkCandidate(*alternate_mem_interval, *preferred_offset);
1071   if (chunk_candidate.chunk.offset == *preferred_offset) {
1072     return chunk_candidate;
1073   }
1074   return absl::nullopt;
1075 }
1076 
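// Note: the count below is the peak number of copies in flight: it walks the
// entry computation's schedule, incrementing on each kCopyStart, decrementing
// on each kCopyDone, and recording the running maximum.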
1077 /*static*/ int64 MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(
1078     const HloModule& module) {
1079   int64 max_copies = 0;
1080   int64 current_copies = 0;
1081   for (HloInstruction* instruction :
1082        module.schedule().sequence(module.entry_computation()).instructions()) {
1083     if (instruction->opcode() == HloOpcode::kCopyStart) {
1084       current_copies++;
1085     } else if (instruction->opcode() == HloOpcode::kCopyDone) {
1086       current_copies--;
1087     }
1088     max_copies = std::max(max_copies, current_copies);
1089   }
1090   return max_copies;
1091 }
1092 
1093 /*static*/ MemorySpaceAssignment::BufferIntervalCompare
1094 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
1095     const MemorySpaceAssignmentCostAnalysis& cost_analysis) {
1096   return [&](const BufferInterval& x, const BufferInterval& y) {
1097     // Returns a heuristic value that captures how much putting this tensor in
1098     // the alternate memory would help if the op is memory bound, or otherwise
1099     // how far off the op is from memory boundedness. The larger this number,
1100     // the higher priority it will be placed in the alternate memory.
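    // As a made-up example: if an op's memory-bound elapsed time is 10us, its
    // compute-bound elapsed time is 4us, and its elapsed time with the operand
    // or output in the alternate memory is 6us, the benefit is 10 - 6 = 4us
    // (memory bound). If instead compute were 10us and memory 6us, the value
    // would be 6 - 10 = -4us, i.e. negative by how far the op is from being
    // memory bound.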
1101     auto get_alternate_mem_benefit =
1102         [&](const HloInstruction& instruction,
1103             float elapsed_time_due_to_alternate_mem) {
1104           float elapsed_time_due_to_compute =
1105               cost_analysis.GetInstructionElapsedDueToCompute(instruction);
1106           float elapsed_time_due_to_memory =
1107               cost_analysis.GetInstructionElapsedDueToMemory(instruction);
1108           if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
1109             // Memory bound, return how much alternate memory is better.
1110             return elapsed_time_due_to_memory -
1111                    elapsed_time_due_to_alternate_mem;
1112           } else {
1113             // Compute bound, return how far off are we to memory boundedness.
1114             return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
1115           }
1116         };
1117 
1118     auto get_memory_boundedness = [&](const BufferInterval& interval) {
1119       const HloInstruction& defining_instruction =
1120           *interval.buffer->defining_instruction();
1121       float alternate_mem_benefit = get_alternate_mem_benefit(
1122           defining_instruction, cost_analysis.GetInstructionElapsedDueToMemory(
1123                                     defining_instruction,
1124                                     /*operand_in_alternate_mem=*/{},
1125                                     /*output_in_alternate_mem=*/true));
1126       for (const HloUse& use : interval.buffer->uses()) {
1127         float use_alternate_mem_benefit = get_alternate_mem_benefit(
1128             *use.instruction, cost_analysis.GetInstructionElapsedDueToMemory(
1129                                   *use.instruction, use.operand_number));
1130         // If the benefit is positive (memory bound), add it to this buffer's
1131         // benefit. If the benefit is negative (compute bound), calculate the
1132         // maximum.
1133         if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
1134           alternate_mem_benefit += use_alternate_mem_benefit;
1135         } else {
1136           alternate_mem_benefit =
1137               std::max(alternate_mem_benefit, use_alternate_mem_benefit);
1138         }
1139       }
1140 
1141       // Get performance slowdown in seconds of prefetching current
1142       // BufferInterval causing to other BufferIntervals.
1143       float alternate_mem_slowdown =
1144           cost_analysis.GetInstructionElapsedDueToMemorySlowdown(interval.size);
1145 
1146       // Scale the slowdown based on the time of this buffer. We would want
1147       // earlier buffers have lower slowdown values, because they are less
1148       // likely to overlap with other HLOs.
1149       // TODO (yuemmawang) We may want a piecewise function, where a lower
1150       // slowdown for early HLOs, and full slowdown for mid-to-late HLOs.
1151       // TODO (yuemmawang) Further in a smarter way, we want buffers overlapped
1152       // with more HLOs have higher slowdown, and vice versa.
1153       float scale = interval.start * 1.0 / cost_analysis.GetScheduleEndTime();
1154       alternate_mem_slowdown *= scale;
1155 
1156       return alternate_mem_benefit - alternate_mem_slowdown;
1157     };
1158 
1159     float x_memory_boundedness = get_memory_boundedness(x);
1160     float y_memory_boundedness = get_memory_boundedness(y);
1161     if (x_memory_boundedness != y_memory_boundedness) {
1162       return x_memory_boundedness > y_memory_boundedness;
1163     }
1164     // Tie-break if the memory boundedness is the same.
1165     return GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare()(
1166         x, y);
1167   };
1168 }
1169 
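// Worked example of the heuristic: suppose an instruction's elapsed time due
// to memory is 10us, its elapsed time due to compute is 4us, and placing the
// tensor in the alternate memory would reduce the memory time to 7us. The op
// is memory bound, so the benefit is 10 - 7 = 3us. If instead the memory time
// were 4us and the compute time 10us, the op is compute bound and the value
// is 4 - 10 = -6us, a negative number expressing how far the op is from being
// memory bound. Positive per-use benefits accumulate; negative ones are
// combined with std::max.
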
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
MemorySpaceAssignment::Run(HloModule* module,
                           const HloLiveRange& hlo_live_range,
                           const HloAliasAnalysis& alias_analysis,
                           const Options& options) {
  CHECK(module->has_schedule());
  VLOG(4) << "Module before memory space assignment: ";
  XLA_VLOG_LINES(4, module->ToString());
  VLOG(4) << "Schedule: " << module->schedule().ToString();
  MemorySpaceAssignment memory_space_assignment(module, options,
                                                hlo_live_range);
  auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
      &memory_space_assignment.allocation_sequence_list_, options,
      alias_analysis, hlo_live_range);

  HeapSimulator::Options heap_simulator_options;
  heap_simulator_options.may_reuse_operand_buffers = false;
  TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
                                        module->schedule(), alias_analysis,
                                        options.size_fn, heap_simulator_options)
                         .status());

  TF_RETURN_IF_ERROR(memory_space_assignment.Process());
  memory_space_assignment.ScheduleAsynchronousCopies();
  TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph());
  TF_RETURN_IF_ERROR(memory_space_assignment.FixSchedule());

  VLOG(4) << "Module after memory space assignment: ";
  XLA_VLOG_LINES(4, module->ToString());
  TF_CHECK_OK(module->schedule().Verify());
  VLOG(1) << "Maximum number of outstanding async copies: "
          << CountMaximumOutstandingAsyncCopies(*module);

  TF_RETURN_IF_ERROR(
      memory_space_assignment.VerifyAndExportHeapSimulatorTrace());

  return std::move(memory_space_assignment.preset_assignments_);
}

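// Usage sketch (illustrative; the option setup is assumed to happen elsewhere
// in the backend): given a scheduled HloModule*, a caller would first compute
// alias analysis and the live range, then run the pass and keep the returned
// PresetAssignments for later layout and allocation decisions:
//
//   // Options options = ...;  // alternate_memory_space, size_fn, etc.
//   TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
//   TF_ASSIGN_OR_RETURN(auto hlo_live_range,
//                       HloLiveRange::Run(module->schedule(), *alias_analysis,
//                                         module->entry_computation()));
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PresetAssignments> preset_assignments,
//       MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis,
//                                  options));
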
void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
  HloInstruction* operand =
      use.instruction->mutable_operand(use.operand_number);
  // If the use is a tuple, look inside the tuple to find the actual use.
  for (int64 index : use.operand_index) {
    if (operand->opcode() != HloOpcode::kTuple) {
      break;
    }
    operand = operand->mutable_operand(index);
  }

  // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts.
  std::function<HloInstruction*(HloInstruction*)> get_simplified_operand;
  get_simplified_operand = [&](HloInstruction* instruction) {
    while (instruction->opcode() == HloOpcode::kGetTupleElement) {
      HloInstruction* operand =
          get_simplified_operand(instruction->mutable_operand(0));
      if (operand->opcode() == HloOpcode::kTuple) {
        instruction = operand->mutable_operand(instruction->tuple_index());
      } else {
        return instruction;
      }
    }
    return instruction;
  };
  operand = get_simplified_operand(operand);

  uses_.push_back(use);
}

Status MemorySpaceAssignment::Allocation::Process(
    MemorySpaceAssignment* memory_space_assignment) {
  return Status::OK();
}

StatusOr<HloInstruction*> MemorySpaceAssignment::Allocation::ReplaceTupleWith(
    HloInstruction* new_instruction, HloInstruction* tuple,
    ShapeIndex shape_index) {
  const Shape& tuple_shape = tuple->shape();
  CHECK(tuple->shape().IsTuple())
      << "ReplaceTupleWith was called for a non-tuple. Tuple = "
      << tuple->ToString()
      << ", new_instruction = " << new_instruction->ToString()
      << ", shape_index = " << shape_index.ToString();

  HloComputation* computation = new_instruction->parent();
  std::vector<HloInstruction*> tuple_args(tuple_shape.tuple_shapes_size());
  for (int64 i = 0; i < tuple_shape.tuple_shapes_size(); ++i) {
    const Shape& subshape = tuple_shape.tuple_shapes(i);
    if (i == shape_index[0]) {
      // If the subshape is still a tuple, recurse and pass a new shape index
      // for the one level deeper.
      if (subshape.IsTuple()) {
        HloInstruction* get_tuple_element = computation->AddInstruction(
            HloInstruction::CreateGetTupleElement(subshape, tuple, i));
        TF_ASSIGN_OR_RETURN(tuple_args[i],
                            ReplaceTupleWith(new_instruction, get_tuple_element,
                                             ShapeIndex(shape_index.begin() + 1,
                                                        shape_index.end())));
      } else {
        if (subshape != new_instruction->shape()) {
          VLOG(4) << "Old shape = " << subshape.ToString()
                  << ", new shape = " << new_instruction->shape().ToString()
                  << "; inserting a bitcast.";
          new_instruction = computation->AddInstruction(
              HloInstruction::CreateBitcast(subshape, new_instruction));
        }
        tuple_args[i] = new_instruction;
      }
    } else {
      HloInstruction* get_tuple_element = computation->AddInstruction(
          HloInstruction::CreateGetTupleElement(subshape, tuple, i));
      tuple_args[i] = get_tuple_element;
    }
  }
  return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args));
}

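// Example: for a tuple t with shape ((f32[10], f32[20]), f32[30]), calling
// ReplaceTupleWith(n, t, /*shape_index=*/{0, 1}) rebuilds the tuple as
//   tuple(tuple(get-tuple-element(get-tuple-element(t, 0), 0), n),
//         get-tuple-element(t, 1))
// i.e. only the element at index {0, 1} is swapped for n (bitcast-adjusted if
// its shape differs from f32[20]); all other leaves are re-extracted from t.
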
Status MemorySpaceAssignment::CopyAllocation::Process(
    MemorySpaceAssignment* memory_space_assignment) {
  // Copy allocations need to insert asynchronous copy nodes.
  HloInstruction* producing_instruction = defining_position().instruction;
  CHECK_NE(producing_instruction, nullptr);

  Shape shape = defining_position().shape();
  CHECK(shape.IsArray()) << "CopyAllocation shape is not an array. Shape = "
                         << shape.ToString()
                         << " position = " << defining_position().shape();
  HloComputation* computation = producing_instruction->parent();

  // If the instruction we're copying from is a tuple, descend into it with
  // kGetTupleElement instructions to extract the array value to copy, since
  // asynchronous copies only support array types.
  if (!producing_instruction->shape().IsArray()) {
    producing_instruction = defining_position().instruction;
    for (int64 index : defining_position().index) {
      producing_instruction =
          computation->AddInstruction(HloInstruction::CreateGetTupleElement(
              producing_instruction->shape().tuple_shapes(index),
              producing_instruction, index));
    }
  }
  copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
      HloOpcode::kCopyStart, producing_instruction));
  copy_done_ = computation->AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
  // Update the allocation with the copy done instruction so that if there
  // are further copies from it, it can find the correct instruction.
  instruction_ = copy_done_;

  // Also update the defining position.
  defining_position_ = HloPosition{copy_done_, {}};

  // Replace all the uses with the new copy instruction.
  for (HloUse use : uses_) {
    // If the operand is a tuple, we need to descend to the actual instruction
    // we want to replace.
    HloInstruction* replacement_instruction;
    Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
    if (operand_shape.IsTuple()) {
      TF_ASSIGN_OR_RETURN(
          replacement_instruction,
          ReplaceTupleWith(copy_done_,
                           use.instruction->mutable_operand(use.operand_number),
                           use.operand_index));
    } else if (operand_shape != copy_done_->shape()) {
      VLOG(4) << "Old shape = " << operand_shape.ToString()
              << ", new shape = " << copy_done_->shape().ToString()
              << "; inserting a bitcast.";
      replacement_instruction = computation->AddInstruction(
          HloInstruction::CreateBitcast(operand_shape, copy_done_));
    } else {
      replacement_instruction = copy_done_;
    }
    TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
        use.operand_number, replacement_instruction));
  }

  return Status::OK();
}

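// Sketch of the HLO this produces for a copied value %v with shape f32[128]
// (instruction names are illustrative):
//   %copy-start = (f32[128], f32[128], u32[]) copy-start(%v)
//   %copy-done  = f32[128] copy-done(%copy-start)
// Every recorded use of the allocation is then rewired to consume %copy-done
// (through a rebuilt tuple or a bitcast when the operand shape requires it).
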
Status MemorySpaceAssignment::Process() {
  // Insert CopyStart/CopyDone pairs.
  int64 alternate_memory_size = 0;
  for (auto& value_and_sequence : allocation_sequence_list_) {
    for (auto& allocation : value_and_sequence.sequence) {
      TF_RETURN_IF_ERROR(allocation->Process(this));
      // Add the offset and size of the allocation in the alternate memory to
      // the output map. Special case for bitcast: since bitcast doesn't define
      // its own buffer, that shouldn't be exported as a preset chunk.
      if (allocation->memory_space() == MemorySpace::kAlternate &&
          allocation->instruction()->opcode() != HloOpcode::kBitcast) {
        preset_assignments_->add_chunk(allocation->defining_position(),
                                       allocation->chunk());
        alternate_memory_size =
            std::max(alternate_memory_size, allocation->chunk().chunk_end());
      }
    }
  }

  if (!preset_assignments_->chunks().empty()) {
    preset_assignments_
        ->assignment_information_for_space(options_.alternate_memory_space)
        ->size = alternate_memory_size;
  }

  if (VLOG_IS_ON(3)) {
    VLOG(3) << "Exported alternate memory allocations:";
    for (auto& pair : preset_assignments_->chunks()) {
      VLOG(3) << " [" << pair.second.offset << ", " << pair.second.size
              << "] : " << pair.first.ToString();
    }
    VLOG(3) << "Exported alternate memory sizes:";
    for (auto& pair : preset_assignments_->assignment_informations()) {
      VLOG(3) << "  space: " << pair.first << ", size: " << pair.second.size;
    }
  }

  // Color the pending positions and all of their aliased buffers.
  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
  for (const auto& defining_position_and_chunk :
       preset_assignments_->chunks()) {
    const HloPosition& defining_position = defining_position_and_chunk.first;
    for (auto& buffer : alias_analysis->ComputeBuffersAt(
             defining_position.instruction, defining_position.index)) {
      for (auto& value : buffer->values()) {
        for (auto& position : value->positions()) {
          VLOG(3) << "Coloring " << position.ToString();
          Shape* shape = ShapeUtil::GetMutableSubshape(
              position.instruction->mutable_shape(), position.index);
          CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
                                  << position.ToString();
          shape->mutable_layout()->set_memory_space(
              options_.alternate_memory_space);
        }
      }
    }
  }

  return Status::OK();
}

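// For example, if an allocation for a hypothetical position %p (an f32[256]
// value) received chunk {offset = 128, size = 1024}, that (position, chunk)
// pair is exported via preset_assignments_, the alternate memory size grows
// to at least 1152 bytes (offset + size), and the layouts of %p and every
// aliased position are tagged with options_.alternate_memory_space.
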
void PresetAssignments::RemoveAssignmentForInstruction(
    const HloInstruction* instruction) {
  for (auto& position_and_chunk : chunks_) {
    const HloPosition& position = position_and_chunk.first;
    if (position.instruction == instruction) {
      VLOG(3) << "Removing instruction from preset assignments.";
      // Swap the removed position and chunk with the back and pop back.
      position_and_chunk = chunks_.back();
      chunks_.pop_back();
      break;
    }
  }
}

Status MemorySpaceAssignment::SimplifyGraph() {
  for (HloComputation* computation : module_->MakeNonfusionComputations()) {
    // Parallel computations aren't in the schedule and don't need to be
    // modified.
    if (!computations_in_schedule_.contains(computation)) {
      VLOG(4) << "Not simplifying " << computation->name()
              << " because it's not in the schedule.";
      continue;
    }
    // Drop control dependencies. Since the computation is already scheduled,
    // we don't need control dependencies anymore, and having control
    // predecessors/successors prevents us from removing instructions without
    // users (HloComputation::IsSafelyRemovable returns false if there are
    // control dependencies).
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
    }
    // We perform limited DCE and forward the tuple operand in patterns like
    // GetTupleElement(Tuple(a, b), 0). This is mostly because memory space
    // assignment is run late in compilation (after DCE and arithmetic
    // simplification passes) and we don't want to generate redundant code.
    // Run to a fixed point.
    bool computation_modified = true;
    while (computation_modified) {
      computation_modified = false;
      VLOG(4) << "Running simplify graph loop over " << computation->name();
      for (HloInstruction* instruction :
           computation->MakeInstructionPostOrder()) {
        if (computation->IsSafelyRemovable(instruction) &&
            instruction->user_count() == 0 && !instruction->HasSideEffect() &&
            instruction != computation->root_instruction() &&
            instruction->opcode() != HloOpcode::kCopyStart &&
            instruction->opcode() != HloOpcode::kCopyDone) {
          VLOG(4) << "Instruction removed: " << instruction->ToString();
          // Ensure the exported preset assignments don't contain a reference
          // to the removed instruction.
          preset_assignments_->RemoveAssignmentForInstruction(instruction);
          // Instead of deleting the instruction from the schedule, replace it
          // with a nullptr. This is needed because FixSchedule relies on the
          // logical time that is the index into flattened_instructions_ for
          // scheduling asynchronous copies.
          auto instruction_it =
              absl::c_find(flattened_instructions_, instruction);
          if (instruction_it != flattened_instructions_.end()) {
            *instruction_it = nullptr;
          }
          TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
          computation_modified = true;
        } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
          HloInstruction* operand = instruction->mutable_operand(0);
          if (operand->opcode() == HloOpcode::kTuple) {
            HloInstruction* forwarded_instruction =
                operand->mutable_operand(instruction->tuple_index());
            VLOG(4) << "Replacing uses of " << instruction->ToString()
                    << " with " << forwarded_instruction->ToString();
            TF_RETURN_IF_ERROR(
                instruction->ReplaceAllUsesWith(forwarded_instruction));
            computation_modified = true;
          }
        }
      }
    }
  }

  return Status::OK();
}

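// Example of the simplification: with
//   %t = tuple(%a, %b)
//   %g = get-tuple-element(%t), index=0
// every use of %g is replaced with %a on one pass; on a later pass %g and %t
// become dead and are removed, with their slots in flattened_instructions_
// set to nullptr so the logical-time indices used by FixSchedule stay valid.
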
void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted(
    HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
    absl::flat_hash_set<HloInstruction*>* inserted_instructions) const {
  if (inserted_instructions->contains(new_instruction)) {
    return;
  }
  for (HloInstruction* operand : new_instruction->operands()) {
    // CopyStart/CopyDone dependencies should always be already inserted; it is
    // a red flag when they haven't already been inserted.
    CHECK((operand->opcode() != HloOpcode::kCopyStart &&
           operand->opcode() != HloOpcode::kCopyDone) ||
          inserted_instructions->contains(operand))
        << "Inserted instruction " << new_instruction->ToString()
        << " has un-inserted dependency: " << operand->ToString();
    EnsureInstructionAndOperandsInserted(operand, new_sequence,
                                         inserted_instructions);
  }
  VLOG(4) << "inserting: " << new_instruction->ToShortString();
  new_sequence->push_back(new_instruction);
  inserted_instructions->insert(new_instruction);
}

void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
  for (MemorySpace memory_space :
       {MemorySpace::kDefault, MemorySpace::kAlternate}) {
    std::vector<CopyAllocation*> copy_allocations;
    for (auto& value_and_sequence : allocation_sequence_list_) {
      for (auto& allocation : value_and_sequence.sequence) {
        if (allocation->is_copy_allocation()) {
          auto copy_allocation = static_cast<CopyAllocation*>(allocation.get());
          if (copy_allocation->memory_space() == memory_space) {
            copy_allocations.push_back(copy_allocation);
          }
        }
      }
    }

    absl::c_stable_sort(
        copy_allocations, [](CopyAllocation* first, CopyAllocation* second) {
          return std::forward_as_tuple(first->copy_done_schedule_before(),
                                       first->copy_start_schedule_after()) <
                 std::forward_as_tuple(second->copy_done_schedule_before(),
                                       second->copy_start_schedule_after());
        });

    CopyAllocation* prev_copy_allocation = nullptr;
    for (CopyAllocation* copy_allocation : copy_allocations) {
      // If the copy start doesn't happen to be scheduled at the correct
      // computation, delay it until the correct computation starts.
      int64 copy_start_schedule_after =
          copy_allocation->copy_start_schedule_after();
      // Accessing flattened_instructions_ here without checking if it is
      // nullptr is safe because this method is called before SimplifyGraph.
      while (copy_allocation->instruction()->parent() !=
             flattened_instructions_[copy_start_schedule_after]->parent()) {
        VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
                << (copy_start_schedule_after + 1) << ") for "
                << copy_allocation->copy_start()->ToString()
                << " because it is not in the correct computation.";
        copy_allocation->set_copy_start_schedule_after(
            ++copy_start_schedule_after);
      }

      schedule_after_[copy_allocation->copy_start_schedule_after()].push_back(
          copy_allocation->copy_start());
      schedule_before_[copy_allocation->copy_done_schedule_before()].push_back(
          copy_allocation->copy_done());
      prev_copy_allocation = copy_allocation;
    }
  }
}

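// Example of the sort order: copy allocations with
// (copy_done_schedule_before, copy_start_schedule_after) of (5, 1), (5, 3),
// and (9, 2) are processed in exactly that order, so copies whose results are
// needed earlier register their CopyStart/CopyDone pairs first in
// schedule_after_ / schedule_before_.
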
Status MemorySpaceAssignment::FixSchedule() {
  CHECK(module_->has_schedule());
  HloSchedule& schedule = module_->schedule();
  for (const HloComputation* computation :
       module_->MakeNonfusionComputations()) {
    // Parallel computations aren't in the schedule and don't need to be
    // modified.
    if (!computations_in_schedule_.contains(computation)) {
      VLOG(4) << "Not scheduling " << computation->name()
              << " because it's not in the schedule.";
      continue;
    }
    CHECK(schedule.is_computation_scheduled(computation));
    HloInstructionSequence new_sequence;

    absl::flat_hash_set<HloInstruction*> inserted_instructions;

    VLOG(4) << "Scheduling: " << computation->ToString();

    for (int64 instruction_index = 0;
         instruction_index < flattened_instructions_.size();
         ++instruction_index) {
      auto insts_before_iter = schedule_before_.find(instruction_index);
      if (insts_before_iter != schedule_before_.end()) {
        for (HloInstruction* new_instruction : insts_before_iter->second) {
          if (new_instruction->parent() == computation) {
            VLOG(4) << "before " << instruction_index << ": "
                    << new_instruction->name();
            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
                                                 &inserted_instructions);
          }
        }
      }
      HloInstruction* instruction = flattened_instructions_[instruction_index];
      // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
      // it was deleted) and not previously inserted. Also bitcasts and tuples
      // are treated specially and only inserted as a result of operand
      // dependencies.
      if (instruction != nullptr &&
          !inserted_instructions.contains(instruction) &&
          instruction->parent() == computation &&
          instruction->opcode() != HloOpcode::kBitcast &&
          instruction->opcode() != HloOpcode::kTuple) {
        VLOG(4) << "inst " << instruction_index << ": " << instruction->name();
        EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
                                             &inserted_instructions);
      }
      auto insts_after_iter = schedule_after_.find(instruction_index);
      if (insts_after_iter != schedule_after_.end()) {
        for (HloInstruction* new_instruction : insts_after_iter->second) {
          if (new_instruction->parent() == computation) {
            VLOG(4) << "after " << instruction_index << ": "
                    << new_instruction->name();
            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
                                                 &inserted_instructions);
          }
        }
      }
    }
    // For rare cases where the original sequence is empty, ensure the root
    // instruction and its dependencies are scheduled.
    EnsureInstructionAndOperandsInserted(computation->root_instruction(),
                                         &new_sequence, &inserted_instructions);
    CHECK_EQ(new_sequence.size(), computation->instruction_count())
        << "New sequence for computation " << computation->name() << " has "
        << new_sequence.size() << " instructions, expects "
        << computation->instruction_count() << ".";
    schedule.set_sequence(computation, new_sequence);
  }

  return Status::OK();
}

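// Example of how the schedule is rebuilt around logical times: if
// schedule_after_[3] contains a copy-start and schedule_before_[7] contains
// the matching copy-done, the new sequence becomes
//   ..., inst_3, copy-start, inst_4, inst_5, inst_6, copy-done, inst_7, ...
// so the asynchronous copy overlaps with the instructions scheduled between
// the two logical times.
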
Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
  VLOG(3) << "Verifying:";
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module_));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(module_->schedule(), *alias_analysis,
                                        module_->entry_computation()));

  BufferIntervalTree interval_tree;
  absl::flat_hash_set<int64> seen_buffers;
  std::map<std::pair<int64, int64>,
           std::tuple<const HloValue*, Chunk, HeapSimulatorTrace::Event::Kind>>
      events;

  for (const auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    const Chunk& chunk = position_and_chunk.second;
    const HloBuffer& buffer =
        alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
    if (seen_buffers.contains(buffer.id())) {
      continue;
    }
    seen_buffers.insert(buffer.id());

    int64 start_time = INT64_MAX;
    int64 end_time = -1;
    for (const HloValue* value : buffer.values()) {
      const HloLiveRange::TimeBound& time_bound =
          hlo_live_range->buffer_live_ranges().at(value);
      VLOG(3) << "  value: " << value->ToShortString() << " ("
              << time_bound.start << ", " << time_bound.end << ")";
      start_time = std::min(start_time, time_bound.start);
      end_time = std::max(end_time, time_bound.end);
      events[std::make_pair(time_bound.start, value->id())] =
          std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC);
      events[std::make_pair(time_bound.end, value->id())] =
          std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE);
    }
    CHECK_GE(start_time, 0);
    CHECK_GT(end_time, 0);
    // Get the chunks overlapping in time and check whether they overlap in
    // space as well.
    // TODO(berkin): For now checking against end_time - 1 (exclusive), but we
    // really should check against end_time (inclusive) for cases where the
    // operand can't share buffer with user (see
    // HloDataflowAnalysis::CanShareOperandBufferWithUser).
    if (options_.verify || VLOG_IS_ON(1)) {
      // Verify only if the option is set or if vlog is on.
      for (const Chunk& overlapping_chunk :
           interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
        if (chunk.OverlapsWith(overlapping_chunk)) {
          return InternalError(
              ("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk"
               " off: %d size: %d"),
              buffer.ToString(), start_time, end_time, chunk.offset, chunk.size,
              overlapping_chunk.offset, overlapping_chunk.size);
        }
      }
    }
    interval_tree.Add(start_time, end_time - 1, chunk);
    VLOG(3) << " buffer: " << buffer.ToString() << ": (" << start_time << ", "
            << end_time << ") off: " << position_and_chunk.second.offset
            << ", size: " << position_and_chunk.second.size;
  }

  HeapSimulatorTrace* heap_trace =
      &preset_assignments_
           ->assignment_information_for_space(options_.alternate_memory_space)
           ->heap_simulator_trace;
  int64 memory_usage = 0;
  int64 max_memory_usage = 0;
  for (const auto& event : events) {
    int64 time = event.first.first;
    int64 buffer_id = event.first.second;
    const HloValue* value;
    Chunk chunk;
    HeapSimulatorTrace::Event::Kind kind;
    std::tie(value, chunk, kind) = event.second;
    HeapSimulatorTrace::Event* heap_trace_event = heap_trace->add_events();
    heap_trace_event->set_kind(kind);
    heap_trace_event->set_buffer_id(buffer_id);
    heap_trace_event->set_instruction_name(value->instruction()->name());
    heap_trace_event->set_computation_name(
        value->instruction()->parent()->name());

    if (kind == HeapSimulatorTrace::Event::ALLOC) {
      memory_usage += chunk.size;
    } else {
      CHECK_EQ(kind, HeapSimulatorTrace::Event::FREE);
      memory_usage -= chunk.size;
    }
    max_memory_usage = std::max(max_memory_usage, memory_usage);
    VLOG(3) << "Memory usage: " << memory_usage << " at time: " << time;
  }
  VLOG(1) << "Max memory usage ignoring fragmentation: " << max_memory_usage;

  return Status::OK();
}

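// Example of the overlap check: a chunk at offset 0 with size 128 that is
// live over logical times (3, 10) and a chunk at offset 64 with size 64 that
// is live over (8, 12) overlap both in time and in space, so the verifier
// would return an InternalError for the second buffer; chunks with disjoint
// offsets or disjoint live ranges pass.
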
}  // namespace xla