1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/memory_space_assignment.h"
17
18 namespace xla {
19
20 namespace {
// Define a dummy chunk for chunks that will be allocated in the default memory
// space and for keeping track of the number of asynchronous copies.
23 const HeapSimulator::Chunk kDummyChunk{-1, -1};
24 } // namespace
25
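// Returns the elapsed time the instruction would take if it were purely
// compute bound: the larger of the time to issue its flops and the time to
// issue its transcendental ops, each at the corresponding peak rate.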
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
27 const HloInstruction& instruction) const {
28 return std::max(
29 cost_analysis_.flop_count(instruction) /
30 cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey),
31 cost_analysis_.transcendental_count(instruction) /
32 cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey));
33 }
34
35 float MemorySpaceAssignmentCostAnalysis::
    GetInstructionElapsedDueToMemorySlowdown(int64 bytes) const {
37 return bytes /
38 cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
39 }
40
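// Returns the elapsed time due to memory traffic. By default all bytes are
// charged at the default-memory bandwidth; if the given operand and/or the
// output live in the alternate memory, those bytes are charged at the
// alternate-memory bandwidth instead.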
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory(
42 const HloInstruction& instruction,
43 absl::optional<int64> operand_in_alternate_mem,
44 bool output_in_alternate_mem) const {
45 float bytes_accessed = cost_analysis_.bytes_accessed(instruction);
46 float elapsed_due_to_bytes =
47 bytes_accessed /
48 cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
49 if (operand_in_alternate_mem) {
50 // Estimate the elapsed time due to the operand being in the alternate
51 // memory space.
52 float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed(
53 instruction, *operand_in_alternate_mem);
54 float elapsed_due_to_operand_bytes =
55 operand_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_;
56 bytes_accessed -= operand_bytes_accessed;
57 elapsed_due_to_bytes =
58 elapsed_due_to_operand_bytes +
59 bytes_accessed /
60 cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
61 }
62 if (output_in_alternate_mem) {
63 // Estimate the elapsed time due to the output being in the alternate memory
64 // space.
65 float output_bytes_accessed =
66 cost_analysis_.output_bytes_accessed(instruction);
67 float elapsed_due_to_output_bytes =
68 output_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_;
69 bytes_accessed -= output_bytes_accessed;
70 elapsed_due_to_bytes =
71 elapsed_due_to_output_bytes +
72 bytes_accessed /
73 cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
74 }
75 return elapsed_due_to_bytes;
76 }
77
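// The overall elapsed time models compute and memory accesses as perfectly
// overlapped, so the instruction takes as long as the slower of the two.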
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed(
79 const HloInstruction& instruction,
80 absl::optional<int64> operand_in_alternate_mem,
81 bool output_in_alternate_mem) const {
82 return std::max(
83 GetInstructionElapsedDueToCompute(instruction),
84 GetInstructionElapsedDueToMemory(instruction, operand_in_alternate_mem,
85 output_in_alternate_mem));
86 }
87
float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
89 const Shape& shape) const {
90 int64 size_in_bytes = cost_analysis_.GetShapeSize(shape);
91 return static_cast<float>(size_in_bytes) /
92 async_copy_bandwidth_bytes_per_second_;
93 }
94
int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
96 return hlo_live_range_.schedule_end_time();
97 }
98
bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
100 const Shape& shape, int64 start_time, int64 end_time) const {
101 return end_time - start_time <= max_overlap_count_;
102 }
103
int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
105 const Shape& shape, int64 start_time, int64 latest_end_time) const {
106 return std::min(start_time + min_overlap_count_, latest_end_time);
107 }
108
void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
110 int64 start_time,
111 int64 end_time) {
112 end_time_ = end_time;
113 current_prefetch_time_ = std::max(start_time, end_time_ - max_overlap_count_);
114 }
115
int64 InstructionCountPrefetchIntervalPicker::Next() {
  CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
                    "Done() is true";
119 return current_prefetch_time_++;
120 }
121
bool InstructionCountPrefetchIntervalPicker::Done() const {
123 return end_time_ - current_prefetch_time_ <= min_overlap_count_;
124 }
125
std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const {
127 return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_);
128 }
129
std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString(
131 const Shape& shape, int64 start_time, int64 end_time) const {
132 return absl::StrCat("Overlapped HLOs = ", end_time - start_time);
133 }
134
CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
136 const MemorySpaceAssignmentCostAnalysis& cost_analysis,
137 float min_async_copy_to_overlap_ratio,
138 float max_async_copy_to_overlap_ratio)
139 : cost_analysis_(cost_analysis),
140 min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
141 max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio) {
142 instruction_schedule_ =
143 &cost_analysis_.hlo_live_range().instruction_schedule();
144
145 // First create a vector of elapsed times of HLO instructions.
146 std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
147 0.0);
148 for (const auto& instruction_and_logical_time : *instruction_schedule_) {
149 float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
150 *instruction_and_logical_time.first);
151 int64 logical_time = instruction_and_logical_time.second;
152 if (logical_time >= instructions_elapsed_time.size()) {
153 instructions_elapsed_time.resize(logical_time + 1, 0.0);
154 }
155 instructions_elapsed_time[logical_time] = elapsed_time;
156 }
157 // As an optimization, create a cumulative sum vector of elapsed time.
158 float cumsum = 0.0;
159 for (float elapsed_time : instructions_elapsed_time) {
160 cumsum += elapsed_time;
161 elapsed_time_cumsum_.push_back(cumsum);
162 }
163 }
164
bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
166 const Shape& shape, int64 start_time, int64 end_time) const {
  // Even though this method returns whether we allow the buffer in alternate
  // memory _without_ asynchronous copies, calculate how long it would have
  // taken to copy it and compare it to the elapsed time in the logical
  // interval.
170 float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
171 float logical_interval_elapsed =
172 GetLogicalIntervalElapsed(start_time, end_time);
173 return max_async_copy_to_overlap_ratio_ * async_copy_elapsed >
174 logical_interval_elapsed;
175 }
176
int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
178 const Shape& shape, int64 start_time, int64 latest_end_time) const {
179 float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
180 int64 end_time;
181 for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) {
182 float logical_interval_elapsed =
183 GetLogicalIntervalElapsed(start_time, end_time);
184 if (logical_interval_elapsed >=
185 min_async_copy_to_overlap_ratio_ * async_copy_elapsed) {
186 break;
187 }
188 }
189 return end_time;
190 }
191
void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
193 int64 start_time,
194 int64 end_time) {
195 const Shape& shape = use.instruction->operand(use.operand_number)->shape();
196 // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
197 async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape);
198 // Estimate the time we would save by having this op in alternate memory.
199 float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
200 float elapsed_time_in_alternate_mem = cost_analysis_.GetInstructionElapsed(
201 *use.instruction, use.operand_number);
202 inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
203 end_logical_time_ = end_time;
204 // Find the earliest time we're allowed to start prefetching.
205 for (current_logical_prefetch_time_ = start_time;
206 current_logical_prefetch_time_ <= end_logical_time_ &&
207 max_async_copy_to_overlap_ratio_ * async_copy_elapsed_ <
208 GetLogicalIntervalElapsed(current_logical_prefetch_time_,
209 end_logical_time_);
210 ++current_logical_prefetch_time_) {
211 }
212 }
213
int64 CostAnalysisPrefetchIntervalPicker::Next() {
  CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
                    "Done() is true";
217 return current_logical_prefetch_time_++;
218 }
219
bool CostAnalysisPrefetchIntervalPicker::Done() const {
221 // The end time is inclusive, so we're done if the prefetch time is greater
222 // than that.
223 if (current_logical_prefetch_time_ > end_logical_time_) {
224 return true;
225 }
226 float logical_interval_elapsed = GetLogicalIntervalElapsed(
227 current_logical_prefetch_time_, end_logical_time_);
228 return async_copy_elapsed_ * min_async_copy_to_overlap_ratio_ >
229 logical_interval_elapsed + inst_elapsed_reduction_;
230 }
231
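// Returns the sum of the elapsed times of the instructions scheduled strictly
// between start_time and end_time, using the precomputed cumulative sums.
// For example, GetLogicalIntervalElapsed(3, 6) returns the elapsed times of
// the instructions at logical times 4 and 5.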
float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
233 int64 start_time, int64 end_time) const {
234 return elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time];
235 }
236
std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
238 float logical_interval_elapsed = GetLogicalIntervalElapsed(
239 current_logical_prefetch_time_, end_logical_time_);
240 return absl::StrCat(
241 "Async copy elapsed (s) = ", async_copy_elapsed_,
242 ", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
243 ", logical interval elapsed (s) = ", logical_interval_elapsed);
244 }
245
std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
247 const Shape& shape, int64 start_time, int64 end_time) const {
248 float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
249 float logical_interval_elapsed =
250 GetLogicalIntervalElapsed(start_time, end_time);
251 return absl::StrCat(
252 "Async copy elapsed (s) = ", async_copy_elapsed,
253 ", logical interval elapsed (s) = ", logical_interval_elapsed);
254 }
255
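// Gathers the given interval together with all intervals it is (transitively)
// colocated with, and returns them sorted by (start, end) time.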
256 std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*>
AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
258 const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
259 std::vector<const BufferInterval*> colocated_intervals;
260 std::vector<const BufferInterval*> worklist = {&interval};
261 while (!worklist.empty()) {
262 const BufferInterval* item = worklist.back();
263 worklist.pop_back();
264 colocated_intervals.push_back(item);
265 for (const HloValue* buffer_colocated : item->colocations) {
266 worklist.push_back(&buffer_intervals_.at(buffer_colocated));
267 }
268 }
269
270 absl::c_stable_sort(colocated_intervals, [&](const BufferInterval* x,
271 const BufferInterval* y) {
272 return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end);
273 });
274 return colocated_intervals;
275 }
276
bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
278 const BufferInterval& interval) const {
279 // If the buffer is a tuple, don't use this algorithm for now. The buffers
280 // that are pointed to by the tuple will still use this algorithm. Because
281 // tuples are cheap to place in the alternate memory (they are just pointers)
282 // we don't need to use prefetch/evict logic.
283 if (interval.buffer->shape().IsTuple()) {
284 VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
285 << " in default mem because it is a tuple.";
286 return false;
287 }
288
  // The semantics of TupleSelect are weird: TupleSelect doesn't define a
  // buffer, but just forwards the buffers on either the left or right side.
  // This means that the two different inputs to TupleSelect must not alias,
  // yet they should be allocated in the same memory space, and both buffers
  // must be kept alive for the entire live range of TupleSelect. Instead,
  // just don't allocate TupleSelect in the alternate memory space.
295 // TODO(berkin): Not allocating add-dependencies either since they need to be
296 // treated specially. We should revisit this later.
297 for (const HloPosition& position : interval.buffer->positions()) {
298 if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
299 position.instruction->opcode() == HloOpcode::kAddDependency) {
300 VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
301 << " in default mem because it has a tuple-select or "
302 << "add-dependency position.";
303 return false;
304 }
305 }
306
307 // Send and Recv HLOs return a request identifier. These should not be
308 // allocated in the alternate memory.
309 const HloPosition& defining_position = interval.buffer->defining_position();
310 if ((defining_position.instruction->opcode() == HloOpcode::kSend ||
311 defining_position.instruction->opcode() == HloOpcode::kRecv) &&
312 defining_position.index == ShapeIndex({1})) {
313 VLOG(4)
314 << "Keeping value " << interval.buffer->ToShortString()
315 << " in default mem because it is a request identifier for send/recv.";
316 return false;
317 }
318
319 return true;
320 }
321
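// Main entry point of the heap algorithm: walks the buffer intervals in
// sorted order, honors reserved alternate-memory assignments, and for every
// other eligible interval builds an AllocationSequence by calling
// FindAllocation() for each use, committing the pending chunks per interval.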
HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
323 std::vector<BufferInterval> sorted_buffer_intervals =
324 GetSortedBufferIntervals();
325
326 VLOG(1) << "Assigning buffers to alternate memory. Max heap size = "
327 << options_.max_size_in_bytes;
328
329 AddInputAndOutputRequiredAssignments();
330
331 for (auto& interval : sorted_buffer_intervals) {
332 if (!interval.need_allocation) {
333 continue;
334 }
335
336 if (!IsIntervalAllowedInAlternateMemory(interval)) {
337 continue;
338 }
339
340 auto colocated_intervals = GetSortedColocatedIntervals(interval);
341
342 if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
343 VLOG(4) << "Interval " << interval.buffer->ToShortString()
344 << " is reserved in the alternate memory. Total reserved bytes = "
345 << reserved_in_bytes_;
346 for (const BufferInterval* colocated_interval : colocated_intervals) {
347 const HloValue* value = colocated_interval->buffer;
348 // Color all of the aliased reserved buffers here because reserved
349 // alternate memory allocations will not have an entry in preset
350 // allocations that is normally used for coloring.
351 for (auto& position : value->positions()) {
352 VLOG(3) << "Coloring " << position.ToString();
353 Shape* shape = ShapeUtil::GetMutableSubshape(
354 position.instruction->mutable_shape(), position.index);
355 CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
356 << position.ToString();
357 shape->mutable_layout()->set_memory_space(
358 options_.alternate_memory_space);
359 }
360 }
361 // Increment the reserved part of alternate memory so that it is not
362 // available for other buffers. Since all colocated intervals should have
363 // the same size, just use the first one.
364 reserved_in_bytes_ += options_.size_fn(*colocated_intervals[0]->buffer);
365 continue;
366 }
367
368 if (colocated_intervals.size() > 1 &&
369 !options_.allocate_across_sequential_calls) {
      VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
              << " because it aliases with another interval and "
                 "allocate_across_sequential_calls is false.";
373 continue;
374 }
375
376 const HloComputation* defining_computation =
377 colocated_intervals[0]->buffer->defining_instruction()->parent();
378 MemorySpaceAssignment::Allocation* aliased_allocation = nullptr;
379 for (const BufferInterval* colocated_interval : colocated_intervals) {
380 const HloValue* value = colocated_interval->buffer;
381 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
382 allocation_sequence_list_->push_back({value, {}});
383 MemorySpaceAssignment::AllocationSequence* allocation_sequence =
384 &allocation_sequence_list_->back().sequence;
385 int64 definition_time =
386 instruction_schedule.at(value->defining_instruction());
387 // Sort the uses by the use time.
388 std::vector<HloUse> uses = value->uses();
389 absl::c_stable_sort(uses, [&](HloUse use1, HloUse use2) {
390 return instruction_schedule.at(use1.instruction) <
391 instruction_schedule.at(use2.instruction);
392 });
393
394 // If there was an aliased allocation for this buffer, propagate that for
395 // this HloValue.
396 if (aliased_allocation != nullptr) {
397 VLOG(3) << "Adding an aliased allocation: ("
398 << aliased_allocation->start_time() << ", "
399 << aliased_allocation->end_time()
400 << ") pos: " << aliased_allocation->defining_position()
401 << " mem space: "
402 << (aliased_allocation->memory_space() == MemorySpace::kDefault
403 ? "default"
404 : "alt");
405 allocation_sequence->push_back(
406 absl::make_unique<MemorySpaceAssignment::Allocation>(
407 value->defining_instruction(), value->defining_position(),
408 aliased_allocation->memory_space(), aliased_allocation->chunk(),
409 definition_time, definition_time));
410 }
411
412 // Iterate over the uses.
413 for (HloUse use : uses) {
414 int64 use_time = instruction_schedule.at(use.instruction);
415 int64 last_use_time = instruction_schedule.at(uses.back().instruction);
416 int64 latest_prefetch_time = use_time;
417
418 if (use.instruction->parent() != defining_computation) {
419 VLOG(3) << "skip use " << use.ToString()
420 << " because it's in a different computation.";
421 continue;
422 }
423
424 // Sequential calls include kWhile, kCall, and kConditional opcodes.
425 bool is_sequential_call =
426 (GetInstructionCallContext(use.instruction->opcode()) ==
427 CallContext::kSequential);
428 if (is_sequential_call) {
429 for (const HloComputation* called_computation :
430 use.instruction->called_computations()) {
431 const HloLiveRange::TimeBound& computation_span =
432 hlo_live_range_.computation_span_times().at(called_computation);
433 latest_prefetch_time =
434 std::min(computation_span.start, latest_prefetch_time);
435 }
436 }
437
438 // Bitcasts don't define buffers and don't directly consume buffers.
439 // Skip allocating buffers for bitcast uses. The uses that feed from
440 // bitcasts will be handled specially.
441 if (use.instruction->opcode() != HloOpcode::kBitcast) {
442 if (!FindAllocation(definition_time, use_time, last_use_time,
443 latest_prefetch_time, value->defining_position(),
444 use, value, colocated_interval->size,
445 allocation_sequence)) {
446 // If the allocation finding failed (e.g., due to running out of
447 // asynchronous copies), then fall back to allocating the buffer
448 // entirely in the default memory.
449 pending_chunks_.clear();
450 pending_async_copies_.clear();
451 allocation_sequence->clear();
452 break;
453 }
454
          // If there are multiple uses, they can try using the memory
          // allocation already in the alternate memory.
457 definition_time = use_time;
458 }
459
        // If the use was a sequential call (e.g. a while loop), the other
        // colocated intervals must alias with this allocation.
462 if (is_sequential_call) {
463 aliased_allocation =
464 GetLiveAllocationAt(*allocation_sequence, use_time);
465 }
466 }
467 }
468
469 CommitPendingChunks();
470 }
471
472 if (VLOG_IS_ON(3)) {
473 for (const auto& value_and_sequence : *allocation_sequence_list_) {
474 VLOG(3) << "Allocation for " << value_and_sequence.value->ToShortString();
475 for (const auto& alloc : value_and_sequence.sequence) {
476 std::string addr_str = ": default";
477 if (alloc->memory_space() == MemorySpace::kAlternate) {
478 addr_str = absl::StrCat(": alt ", alloc->chunk().offset);
479 }
480
481 VLOG(3) << " " << alloc->start_time() << "-" << alloc->end_time()
482 << addr_str << ", " << alloc->uses().size() << " uses";
483 }
484 }
485 }
486
487 return result_;
488 }
489
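// Strict weak ordering for asynchronous copies: a copy compares less than
// another only if it both starts and ends no later than the other (with at
// least one of the two strictly earlier). Copies whose intervals are nested,
// e.g. (1, 5) and (3, 4), compare as equivalent, which is what lets
// ViolatesOrdering() detect them with a set lookup.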
bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
491 return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
492 (a.start_time <= b.start_time && a.end_time < b.end_time);
493 }
494
void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) {
496 auto it_and_inserted = ranges_.insert(copy);
497 CHECK(it_and_inserted.second ||
498 it_and_inserted.first->start_time == copy.start_time);
499 }
500
bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time,
502 int64 end_time) const {
503 // We allow identical start and end times. It is enough to check for just the
504 // start time in case we find a match in ranges_ because the found value will
505 // either be identical to {start_time, end_time} (and this doesn't violate) or
506 // its start_time will be smaller and end_time will be larger (this violates).
507 auto copy_it = ranges_.find(
508 {start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate});
509 return copy_it != ranges_.end() && copy_it->start_time != start_time;
510 }
511
512 /*static*/ MemorySpaceAssignment::Allocation*
AlternateMemoryBestFitHeap::GetLiveAllocationAt(
514 const MemorySpaceAssignment::AllocationSequence& allocations, int64 time) {
515 for (auto allocation_it = allocations.rbegin();
516 allocation_it != allocations.rend(); ++allocation_it) {
517 if ((*allocation_it)->start_time() <= time &&
518 (*allocation_it)->end_time() >= time) {
519 return allocation_it->get();
520 }
521 }
522 return nullptr;
523 }
524
void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
526 // Go through the parameters and outputs and pin them to the corresponding
527 // memory by adding a required assignment.
528 const HloModule& module = alias_analysis_.dataflow_analysis().module();
529 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
530 HloComputation* entry_computation = module.entry_computation();
531 for (HloInstruction* parameter_instruction :
532 entry_computation->parameter_instructions()) {
533 int64 parameter_instruction_time =
534 instruction_schedule.at(parameter_instruction);
535 ShapeUtil::ForEachSubshape(
536 parameter_instruction->shape(),
537 [&](const Shape& subshape, const ShapeIndex& index) {
538 MemorySpace memory_space = MemorySpace::kDefault;
539 if (subshape.has_layout() && subshape.layout().memory_space() ==
540 options_.alternate_memory_space) {
541 memory_space = MemorySpace::kAlternate;
542 }
543 for (const HloBuffer* buffer :
544 alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
545 for (const HloValue* value : buffer->values()) {
546 VLOG(3) << "Adding required assignment for parameter value = "
547 << value->ToShortString()
548 << " time = " << parameter_instruction_time << " space = "
549 << (memory_space == MemorySpace::kDefault ? "def"
550 : "alt");
551 required_assignments_[value].push_back(
552 {memory_space, /*time=*/parameter_instruction_time});
553 }
554 }
555 });
556 }
557 HloInstruction* root_instruction = entry_computation->root_instruction();
558 int64 root_instruction_time = instruction_schedule.at(root_instruction);
559 ShapeUtil::ForEachSubshape(
560 root_instruction->shape(),
561 [&](const Shape& subshape, const ShapeIndex& index) {
562 MemorySpace memory_space = MemorySpace::kDefault;
563 if (subshape.has_layout() && subshape.layout().memory_space() ==
564 options_.alternate_memory_space) {
565 memory_space = MemorySpace::kAlternate;
566 }
567 for (const HloBuffer* buffer :
568 alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
569 for (const HloValue* value : buffer->values()) {
570 VLOG(3) << "Adding required assignment for output value = "
571 << value->ToShortString()
572 << " time = " << root_instruction_time << " space = "
573 << (memory_space == MemorySpace::kDefault ? "def" : "alt");
574 required_assignments_[value].push_back(
575 {memory_space, /*time=*/root_instruction_time});
576 }
577 }
578 });
579 }
580
bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory(
582 absl::Span<const BufferInterval* const> colocated_intervals) const {
583 auto is_position_in_alternate_memory = [&](const HloPosition& position) {
584 const Shape& shape = position.shape();
585 return shape.has_layout() &&
586 shape.layout().memory_space() == options_.alternate_memory_space;
587 };
588
589 const HloModule& module = alias_analysis_.dataflow_analysis().module();
590 const HloComputation* entry_computation = module.entry_computation();
591 const HloInstruction* root_instruction =
592 entry_computation->root_instruction();
593 for (const BufferInterval* colocated_interval : colocated_intervals) {
594 const HloValue* value = colocated_interval->buffer;
595 if (value->defining_instruction()->opcode() == HloOpcode::kParameter &&
596 value->defining_instruction()->parent() == entry_computation &&
597 is_position_in_alternate_memory(value->defining_position())) {
598 return true;
599 }
600
601 for (const HloPosition& position : value->positions()) {
602 if (position.instruction == root_instruction &&
603 is_position_in_alternate_memory(position)) {
604 return true;
605 }
606 }
607 }
608 return false;
609 }
610
void AlternateMemoryBestFitHeap::CommitPendingChunks() {
612 for (auto interval_and_chunk : pending_chunks_) {
613 VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-"
614 << interval_and_chunk.first.end << " : ["
615 << interval_and_chunk.second.chunk.offset << ", "
616 << interval_and_chunk.second.chunk.size << "]";
617 CommitChunk(interval_and_chunk.first, interval_and_chunk.second);
618 }
619 pending_chunks_.clear();
620 // Also add the pending async copies to the interval tree.
621 for (const auto& interval : pending_async_copies_) {
622 if (options_.max_outstanding_async_copies >= 0) {
623 async_copy_interval_tree_.Add(interval.start_time, interval.end_time,
624 kDummyChunk);
625 }
626 if (interval.destination == MemorySpace::kAlternate) {
627 async_copy_ordering_.AddCopy(interval);
628 }
629 }
630 pending_async_copies_.clear();
631 }
632
void AlternateMemoryBestFitHeap::AddToPendingChunks(
634 const BufferInterval& buffer_interval,
635 const ChunkCandidate& chunk_candidate) {
636 pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
637 }
638
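// Returns true if a required assignment pins this buffer to the default
// memory space at the given schedule time.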
bool AlternateMemoryBestFitHeap::RequiredInDefaultMemory(const HloValue* buffer,
640 int64 time) const {
641 auto required_assignment_it = required_assignments_.find(buffer);
642 return required_assignment_it != required_assignments_.end() &&
643 absl::c_any_of(
644 required_assignment_it->second,
645 [&](const RequiredMemoryAssignment& required_assignment) {
646 return required_assignment.memory_space ==
647 MemorySpace::kDefault &&
648 required_assignment.time == time;
649 });
650 }
651
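// Tries to find an allocation covering (start_time, end_time) for this use.
// The strategies, in order: (1) keep the value in the alternate memory with
// no copy, (2) evict a previous alternate-memory allocation back to default
// memory if one exists, and (3) prefetch the value into the alternate memory
// shortly before the use. Returns false only if a required eviction cannot be
// scheduled; otherwise the use is served from default memory as a fallback.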
bool AlternateMemoryBestFitHeap::FindAllocation(
653 int64 start_time, int64 end_time, int64 last_use_time,
654 int64 latest_prefetch_time, HloPosition defining_position, HloUse use,
655 const HloValue* buffer, int64 size,
656 MemorySpaceAssignment::AllocationSequence* allocations) {
657 HloInstruction* operand =
658 use.instruction->mutable_operand(use.operand_number);
659 // If the operand is a bitcast, we look at bitcast's operand until we find a
660 // non-bitcast operand.
661 HloInstruction* non_bitcast_operand = operand;
662 while (non_bitcast_operand->opcode() == HloOpcode::kBitcast) {
663 non_bitcast_operand = non_bitcast_operand->mutable_operand(0);
664 }
665 // Create an alternate memory interval that starts at the earliest
666 // possible position, given by max_prefetch_interval.
667 BufferInterval alternate_mem_interval;
668 alternate_mem_interval.buffer = buffer;
669 alternate_mem_interval.size = size;
670 alternate_mem_interval.end = end_time;
671
672 // start_time == end_time is a special case where the value is consumed
673 // multiple times by the same instruction. We can just find the previous
674 // allocation and use that allocation.
675 if (start_time == end_time) {
676 MemorySpaceAssignment::Allocation* allocation =
677 GetLiveAllocationAt(*allocations, end_time);
678 CHECK_NE(allocation, nullptr);
679 allocation->AddUse(use);
680 return true;
681 }
682
683 VLOG(2) << "Finding allocation for " << buffer->ToShortString() << " ("
684 << start_time << ", " << end_time
685 << ") latest prefetch = " << latest_prefetch_time
686 << " last use = " << last_use_time << " use = " << use.ToString()
687 << ". Size = " << size
688 << ", def pos = " << defining_position.ToString()
689 << ", operand = " << operand->ToShortString()
690 << (non_bitcast_operand != operand
691 ? ", non_bitcast_operand = " +
692 non_bitcast_operand->ToShortString()
693 : "");
694 CHECK_LE(start_time, end_time);
695
  // There could be a requirement to pin this buffer to default memory either
  // because it is a parameter or an output. If the buffer is a parameter, then
  // we're allowed to prefetch. If the use expects the output to be in default
  // memory, we cannot prefetch it because if we did, it would be in alternate
  // memory instead.
701 bool in_default_mem_at_start = RequiredInDefaultMemory(buffer, start_time);
702 bool in_default_mem_at_end = RequiredInDefaultMemory(buffer, end_time);
703
704 // First try keeping the allocation entirely in the alternate memory.
705 if (!in_default_mem_at_start && !in_default_mem_at_end &&
706 TryAllocatingInAlternateMemoryNoCopy(
707 start_time, end_time, last_use_time, defining_position, use,
708 alternate_mem_interval, non_bitcast_operand, allocations)) {
709 return true;
710 }
711
712 auto prev_allocation_it = allocations->rbegin();
713 // Find a previous allocation that is in the default memory space (not
714 // necessarily the very last allocation).
715 auto prev_allocation_in_default_mem_it = std::find_if(
716 allocations->rbegin(), allocations->rend(), [&](const auto& allocation) {
717 return allocation->memory_space() == MemorySpace::kDefault &&
718 allocation->defining_position() == defining_position;
719 });
720
721 if (prev_allocation_in_default_mem_it == allocations->rend() &&
722 prev_allocation_it != allocations->rend() &&
723 (*prev_allocation_it)->memory_space() == MemorySpace::kAlternate &&
724 (*prev_allocation_it)->defining_position() == defining_position) {
725 // If there was an allocation for this HloValue that was in the alternate
726 // memory space, we also need to perform an eviction.
727 int64 eviction_start_time = (*prev_allocation_it)->start_time();
728 int64 eviction_end_time = (*prev_allocation_it)->end_time();
729 CHECK(eviction_start_time <= eviction_end_time);
730
731 int64 preferred_eviction_end_time = std::max(
732 options_.prefetch_interval_picker->PreferredEvictionEndTime(
733 non_bitcast_operand->shape(), eviction_start_time, end_time),
734 eviction_end_time);
735
736 BufferInterval eviction_mem_interval;
737 eviction_mem_interval.buffer = buffer;
738 eviction_mem_interval.size = size;
739 // Try to reserve a buffer from the end of the previous allocation to the
740 // preferred eviction end time.
741 eviction_mem_interval.start = eviction_end_time + 1;
742 eviction_mem_interval.end = preferred_eviction_end_time;
743 int64 preferred_offset = (*prev_allocation_it)->chunk().offset;
744 VLOG(4) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
745 << ") preferred end time = " << eviction_mem_interval.end;
746
747 for (; eviction_mem_interval.end > eviction_end_time;
748 --eviction_mem_interval.end) {
749 ChunkCandidate chunk_candidate =
750 FindChunkCandidate(eviction_mem_interval, preferred_offset);
751 if (chunk_candidate.chunk.offset == preferred_offset) {
752 AddToPendingChunks(eviction_mem_interval, chunk_candidate);
753 break;
754 }
755 }
756 eviction_end_time = eviction_mem_interval.end;
757
758 VLOG(3) << "Evicting buffer at " << (*prev_allocation_it)->chunk().offset
759 << " (" << eviction_start_time << ", " << eviction_end_time << ")";
760
761 bool eviction_interval_too_short =
762 (eviction_start_time == eviction_end_time);
763 bool eviction_violates_outstanding_copies =
764 ViolatesMaximumOutstandingAsyncCopies(eviction_start_time,
765 eviction_end_time);
766
767 // See if this interval would violate the asynchronous copy limit.
768 if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) {
769 (*prev_allocation_it)->Extend(eviction_end_time);
770 AddAsyncCopy(**prev_allocation_it, MemorySpace::kDefault, kDummyChunk,
771 eviction_start_time, (*prev_allocation_it)->end_time(),
772 eviction_end_time, allocations);
773 } else {
774 if (eviction_violates_outstanding_copies) {
775 VLOG(3) << "This violates the maximum async copies.";
776 } else {
777 VLOG(3) << "Eviction interval is too short (" << eviction_start_time
778 << ", " << eviction_end_time << ").";
779 }
780 // If the original interval violated the limit, try sub-intervals within
781 // this interval.
782 bool eviction_scheduled = false;
783 for (int64 time = eviction_start_time; time < eviction_end_time; ++time) {
784 VLOG(3) << "Try evicting (" << time << ", " << time + 1 << ")";
785 if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1)) {
786 VLOG(3) << "Eviction successful.";
787 AddAsyncCopy(**prev_allocation_it, MemorySpace::kDefault, kDummyChunk,
788 time, time + 1, time + 1, allocations);
789 eviction_scheduled = true;
790 break;
791 }
792 }
793
794 if (!eviction_scheduled) {
795 // If the eviction couldn't be scheduled, then fail. This buffer will be
796 // kept in the default memory.
797 VLOG(3) << "Bailing: Could not evict " << use.ToString()
798 << " because we hit the limit of maximum asynchronous copies "
799 << "between "
800 << hlo_live_range_.flattened_instruction_sequence()
801 .instructions()[eviction_start_time]
802 << " and "
803 << hlo_live_range_.flattened_instruction_sequence()
804 .instructions()[eviction_end_time];
805 return false;
806 }
807 }
808 prev_allocation_in_default_mem_it = allocations->rbegin();
809 } else if (prev_allocation_in_default_mem_it == allocations->rend()) {
810 allocations->push_back(absl::make_unique<MemorySpaceAssignment::Allocation>(
811 non_bitcast_operand, defining_position, MemorySpace::kDefault,
812 kDummyChunk, start_time, end_time));
813 prev_allocation_in_default_mem_it = allocations->rbegin();
814 }
815
816 CHECK(prev_allocation_in_default_mem_it != allocations->rend());
817 CHECK((*prev_allocation_in_default_mem_it)->memory_space() ==
818 MemorySpace::kDefault);
819
820 // If the buffer must be in default memory at the end_time, don't prefetch.
821 if (in_default_mem_at_end) {
822 VLOG(4)
823 << "Not trying to prefetch because use requires buffer in default mem.";
824 (*prev_allocation_in_default_mem_it)->Extend(end_time);
825 (*prev_allocation_in_default_mem_it)->AddUse(use);
826 return true;
827 }
828
829 // Try partially placing the buffer in the alternate space. The time that is
830 // overlapped will be used to asynchronously copy the buffer from the
831 // default memory to the alternate memory.
832 //
833 // start end
834 // time time
835 // X---------------------X
836 // Alternate: +------+
837 // Default: +---------------------+
838 // ^ ^
839 // Copy Copy
840 // Start Done
841 options_.prefetch_interval_picker->Begin(
842 use, (*prev_allocation_in_default_mem_it)->earliest_available_time(),
843 latest_prefetch_time);
844 VLOG(4) << "Trying prefetch picker = "
845 << options_.prefetch_interval_picker->ToDebugString();
846 while (!options_.prefetch_interval_picker->Done()) {
847 alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
848 VLOG(4) << "Trying alternate memory allocation ("
849 << alternate_mem_interval.start << ", "
850 << alternate_mem_interval.end << ")";
851 // If this additional asynchronous copy would violate the limit, try a
852 // different interval.
853 if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start,
854 alternate_mem_interval.end)) {
855 VLOG(4) << "This would violate the outstanding async copy limit.";
856 continue;
857 }
858 if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
859 alternate_mem_interval.end)) {
860 VLOG(4) << "This would violate asynchronous copy ordering.";
861 continue;
862 }
863
864 ChunkCandidate chunk_candidate = FindChunkCandidate(alternate_mem_interval);
865 // Check if the new heap size fits within limits.
866 if (chunk_candidate.heap_size < available_heap_size()) {
867 VLOG(3) << "Move the buffer to alternate memory at "
868 << alternate_mem_interval.start
869 << ". Offset = " << chunk_candidate.chunk.offset
870 << ", size = " << chunk_candidate.chunk.size
871 << ", heap_size = " << chunk_candidate.heap_size
872 << ", prefetch picker = "
873 << options_.prefetch_interval_picker->ToDebugString();
874 AddToPendingChunks(alternate_mem_interval, chunk_candidate);
875
876 AddAsyncCopy(**prev_allocation_in_default_mem_it, MemorySpace::kAlternate,
877 chunk_candidate.chunk, alternate_mem_interval.start,
878 end_time, latest_prefetch_time, allocations);
879
880 allocations->back()->AddUse(use);
881 return true;
882 }
883 }
884
885 // If a copy wasn't inserted, then add this use to the latest allocation in
886 // default memory.
887 (*prev_allocation_in_default_mem_it)->Extend(end_time);
888 (*prev_allocation_in_default_mem_it)->AddUse(use);
889 return true;
890 }
891
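// Appends a CopyAllocation that starts the asynchronous copy at start_time,
// requires the copy-done to be scheduled before
// copy_done_schedule_before_time, and keeps the destination buffer live until
// end_time. The copy is also recorded as pending so that the outstanding-copy
// limit and the copy-ordering checks account for it.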
void AlternateMemoryBestFitHeap::AddAsyncCopy(
893 const MemorySpaceAssignment::Allocation& prev_allocation,
894 MemorySpace memory_space, Chunk chunk, int64 start_time, int64 end_time,
895 int64 copy_done_schedule_before_time,
896 MemorySpaceAssignment::AllocationSequence* allocations) {
897 VLOG(3) << "Copy to "
898 << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
899 ? "default"
900 : "alternate")
901 << " memory between " << start_time << " and "
902 << copy_done_schedule_before_time << " keeping until " << end_time;
903
904 allocations->push_back(
905 absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
906 prev_allocation, memory_space, chunk, start_time, end_time,
907 copy_done_schedule_before_time));
908
909 // Register the additional async copy with the interval tree to keep track of
910 // the limit at any given time.
911 pending_async_copies_.push_back({start_time, end_time, memory_space});
912 }
913
bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
915 int64 start_time, int64 end_time) const {
916 if (options_.max_outstanding_async_copies < 0) {
917 return false;
918 }
919
  // Count both the asynchronous copies in the interval tree and the pending
  // asynchronous copies belonging to this buffer.
922 int64 num_async_copies =
923 async_copy_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
924 .size();
925
926 for (const auto& interval : pending_async_copies_) {
927 if (interval.start_time > start_time && interval.end_time < end_time) {
928 num_async_copies++;
929 }
930 }
931 // Add one because we are checking if adding an additional asynchronous copy
932 // would violate the limit.
933 return num_async_copies + 1 > options_.max_outstanding_async_copies;
934 }
935
bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(
937 int64 start_time, int64 end_time) const {
938 if (async_copy_ordering_.ViolatesOrdering(start_time, end_time)) {
939 return true;
940 }
941
942 // Also check pending async copies.
943 for (const auto& async_copy : pending_async_copies_) {
944 if (async_copy.destination == MemorySpace::kAlternate &&
945 async_copy.start_time <= end_time &&
946 start_time <= async_copy.end_time) {
947 return true;
948 }
949 }
950 return false;
951 }
952
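// Attempts to place the value in the alternate memory for its entire live
// range without inserting any copies. This requires that the value (or its
// previous allocation) is already eligible for the alternate memory, that the
// prefetch interval picker accepts the no-copy interval, and that a suitable
// chunk can be found in the heap.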
bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy(
954 int64 start_time, int64 end_time, int64 last_use_time,
955 HloPosition defining_position, HloUse use,
956 BufferInterval alternate_mem_interval, HloInstruction* non_bitcast_operand,
957 MemorySpaceAssignment::AllocationSequence* allocations) {
958 MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
959 bool can_eliminate_copy = false;
960 if (allocations->empty()) {
    // There haven't been any allocations for this interval so far. We can
    // eliminate the copy if the value can be placed in the alternate memory.
963 can_eliminate_copy =
964 options_.is_allowed_in_alternate_mem_fn(*alternate_mem_interval.buffer);
965 } else {
966 // If there has been a previous allocation, we can eliminate the copy if the
967 // previous allocation was also in the alternate memory.
968 prev_allocation = allocations->back().get();
969 can_eliminate_copy =
970 (prev_allocation->memory_space() == MemorySpace::kAlternate);
971 }
972
973 if (!can_eliminate_copy) {
974 return false;
975 }
976
977 if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
978 non_bitcast_operand->shape(), start_time + 1, end_time)) {
979 return false;
980 }
981
982 alternate_mem_interval.start = start_time;
983
984 // Prefer the offset that was previously used for the previous allocation.
985 absl::optional<int64> preferred_offset;
986 if (prev_allocation != nullptr) {
987 preferred_offset = prev_allocation->chunk().offset;
    // If there is a previous allocation, set the start time to one past the
    // previous allocation's end time.
990 alternate_mem_interval.start = prev_allocation->end_time() + 1;
991 }
992
993 VLOG(4) << "We can eliminate copy to alternate memory. Preferred offset = "
994 << (preferred_offset ? *preferred_offset : -1);
995 // In case there are additional uses after this use, we rely on the last use
996 // time to try to reserve a chunk in the heap simulator. This is to prevent
997 // the following scenario:
998 //
999 // +-------+
1000 // / \
1001 // Producer--->Use1 +-->Use2
1002 // +---------+---------+
1003 // New buffer: | | |
1004 // +---------+---------+
1005 //
1006 // +-----------+
1007 // Current heap: | offset: 0 |
1008 // --------------------------+-----------+------
1009 //
  // Because we allocate buffers greedily (the Producer-to-Use1 segment first,
  // and then the Use1-to-Use2 segment), it is possible to allocate the first
  // segment at an offset that is available for that segment (e.g. offset 0)
  // but not for the entire live range. This can result in unnecessary copies.
  // By using the last use time, we try to find an allocation that is available
  // for the entire Producer-to-Use2 range.
1016 absl::optional<ChunkCandidate> chunk_candidate = FindBestNoCopyChunkCandidate(
1017 end_time, last_use_time, preferred_offset, &alternate_mem_interval);
  // Check if the new heap size fits within limits. Also ensure that if a
  // preferred offset was provided, that offset was used.
1020 if (chunk_candidate) {
1021 VLOG(3) << "Keep the buffer in alternate memory. Offset = "
1022 << chunk_candidate->chunk.offset
1023 << ", size = " << chunk_candidate->chunk.size
1024 << ", heap_size = " << chunk_candidate->heap_size
1025 << ", prefetch picker = "
1026 << options_.prefetch_interval_picker->ToNoCopyDebugString(
1027 non_bitcast_operand->shape(), start_time, end_time);
1028 AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
1029
1030 // If there was a previous allocation, the buffer location is the
1031 // same as the previous. Otherwise, it is the operand.
1032 if (prev_allocation != nullptr &&
1033 (prev_allocation->is_copy_allocation() ||
1034 prev_allocation->defining_position() == defining_position)) {
1035 prev_allocation->Extend(end_time);
1036 } else {
1037 allocations->push_back(
1038 absl::make_unique<MemorySpaceAssignment::Allocation>(
1039 non_bitcast_operand, defining_position, MemorySpace::kAlternate,
1040 chunk_candidate->chunk, start_time, end_time));
1041 }
1042 allocations->back()->AddUse(use);
1043 return true;
1044 }
1045 return false;
1046 }
1047
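// Finds a chunk for the no-copy allocation, or returns nullopt. Without a
// preferred offset, it looks for the longest-living chunk (up to the last use
// time) that fits in the heap; with a preferred offset, it only accepts a
// chunk at exactly that offset.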
1048 absl::optional<AlternateMemoryBestFitHeap::ChunkCandidate>
AlternateMemoryBestFitHeap::FindBestNoCopyChunkCandidate(
1050 int64 end_time, int64 last_use_time, absl::optional<int64> preferred_offset,
1051 BufferInterval* alternate_mem_interval) const {
1052 if (!preferred_offset) {
1053 // Find a chunk that's as long living as possible.
1054 for (alternate_mem_interval->end = last_use_time;
1055 alternate_mem_interval->end >= end_time;
1056 --alternate_mem_interval->end) {
1057 ChunkCandidate chunk_candidate =
1058 FindChunkCandidate(*alternate_mem_interval);
1059 if (chunk_candidate.heap_size <= available_heap_size()) {
1060 alternate_mem_interval->end = end_time;
1061 return chunk_candidate;
1062 }
1063 }
1064 return absl::nullopt;
1065 }
1066 // If a preferred offset is given, try to find an allocation at that offset
1067 // only.
1068 alternate_mem_interval->end = end_time;
1069 ChunkCandidate chunk_candidate =
1070 FindChunkCandidate(*alternate_mem_interval, *preferred_offset);
1071 if (chunk_candidate.chunk.offset == *preferred_offset) {
1072 return chunk_candidate;
1073 }
1074 return absl::nullopt;
1075 }
1076
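// Computes the maximum number of asynchronous copies in flight at any point
// in the entry computation's schedule by tracking copy-start/copy-done pairs.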
/*static*/ int64 MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(
1078 const HloModule& module) {
1079 int64 max_copies = 0;
1080 int64 current_copies = 0;
1081 for (HloInstruction* instruction :
1082 module.schedule().sequence(module.entry_computation()).instructions()) {
1083 if (instruction->opcode() == HloOpcode::kCopyStart) {
1084 current_copies++;
1085 } else if (instruction->opcode() == HloOpcode::kCopyDone) {
1086 current_copies--;
1087 }
1088 max_copies = std::max(max_copies, current_copies);
1089 }
1090 return max_copies;
1091 }
1092
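// Returns a buffer-interval comparator that prioritizes intervals by a
// memory-boundedness heuristic (larger benefit first) and falls back to the
// default spatial comparison to break ties.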
1093 /*static*/ MemorySpaceAssignment::BufferIntervalCompare
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
1095 const MemorySpaceAssignmentCostAnalysis& cost_analysis) {
1096 return [&](const BufferInterval& x, const BufferInterval& y) {
    // Returns a heuristic value that captures how much putting this tensor in
    // the alternate memory would help if the op is memory bound, or otherwise
    // how far off the op is from memory boundedness. The larger this number,
    // the higher the priority with which it will be placed in the alternate
    // memory.
1101 auto get_alternate_mem_benefit =
1102 [&](const HloInstruction& instruction,
1103 float elapsed_time_due_to_alternate_mem) {
1104 float elapsed_time_due_to_compute =
1105 cost_analysis.GetInstructionElapsedDueToCompute(instruction);
1106 float elapsed_time_due_to_memory =
1107 cost_analysis.GetInstructionElapsedDueToMemory(instruction);
1108 if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
1109 // Memory bound, return how much alternate memory is better.
1110 return elapsed_time_due_to_memory -
1111 elapsed_time_due_to_alternate_mem;
1112 } else {
1113 // Compute bound, return how far off are we to memory boundedness.
1114 return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
1115 }
1116 };
1117
1118 auto get_memory_boundedness = [&](const BufferInterval& interval) {
1119 const HloInstruction& defining_instruction =
1120 *interval.buffer->defining_instruction();
1121 float alternate_mem_benefit = get_alternate_mem_benefit(
1122 defining_instruction, cost_analysis.GetInstructionElapsedDueToMemory(
1123 defining_instruction,
1124 /*operand_in_alternate_mem=*/{},
1125 /*output_in_alternate_mem=*/true));
1126 for (const HloUse& use : interval.buffer->uses()) {
1127 float use_alternate_mem_benefit = get_alternate_mem_benefit(
1128 *use.instruction, cost_analysis.GetInstructionElapsedDueToMemory(
1129 *use.instruction, use.operand_number));
1130 // If the benefit is positive (memory bound), add it to this buffer's
1131 // benefit. If the benefit is negative (compute bound), calculate the
1132 // maximum.
1133 if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
1134 alternate_mem_benefit += use_alternate_mem_benefit;
1135 } else {
1136 alternate_mem_benefit =
1137 std::max(alternate_mem_benefit, use_alternate_mem_benefit);
1138 }
1139 }
1140
      // Get the performance slowdown in seconds that prefetching the current
      // BufferInterval causes to other BufferIntervals.
1143 float alternate_mem_slowdown =
1144 cost_analysis.GetInstructionElapsedDueToMemorySlowdown(interval.size);
1145
      // Scale the slowdown based on the time of this buffer. We would want
      // earlier buffers to have lower slowdown values, because they are less
      // likely to overlap with other HLOs.
      // TODO (yuemmawang) We may want a piecewise function, with a lower
      // slowdown for early HLOs and the full slowdown for mid-to-late HLOs.
      // TODO (yuemmawang) Further, in a smarter way, we want buffers that
      // overlap with more HLOs to have a higher slowdown, and vice versa.
1153 float scale = interval.start * 1.0 / cost_analysis.GetScheduleEndTime();
1154 alternate_mem_slowdown *= scale;
1155
1156 return alternate_mem_benefit - alternate_mem_slowdown;
1157 };
1158
1159 float x_memory_boundedness = get_memory_boundedness(x);
1160 float y_memory_boundedness = get_memory_boundedness(y);
1161 if (x_memory_boundedness != y_memory_boundedness) {
1162 return x_memory_boundedness > y_memory_boundedness;
1163 }
1164 // Tie-break if the memory boundedness is the same.
1165 return GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare()(
1166 x, y);
1167 };
1168 }
1169
1170 /*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
MemorySpaceAssignment::Run(HloModule* module,
1172 const HloLiveRange& hlo_live_range,
1173 const HloAliasAnalysis& alias_analysis,
1174 const Options& options) {
1175 CHECK(module->has_schedule());
1176 VLOG(4) << "Module before memory space assignment: ";
1177 XLA_VLOG_LINES(4, module->ToString());
1178 VLOG(4) << "Schedule: " << module->schedule().ToString();
1179 MemorySpaceAssignment memory_space_assignment(module, options,
1180 hlo_live_range);
1181 auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
1182 &memory_space_assignment.allocation_sequence_list_, options,
1183 alias_analysis, hlo_live_range);
1184
1185 HeapSimulator::Options heap_simulator_options;
1186 heap_simulator_options.may_reuse_operand_buffers = false;
1187 TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
1188 module->schedule(), alias_analysis,
1189 options.size_fn, heap_simulator_options)
1190 .status());
1191
1192 TF_RETURN_IF_ERROR(memory_space_assignment.Process());
1193 memory_space_assignment.ScheduleAsynchronousCopies();
1194 TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph());
1195 TF_RETURN_IF_ERROR(memory_space_assignment.FixSchedule());
1196
1197 VLOG(4) << "Module after memory space assignment: ";
1198 XLA_VLOG_LINES(4, module->ToString());
1199 TF_CHECK_OK(module->schedule().Verify());
1200 VLOG(1) << "Maximum number of outstanding async copies: "
1201 << CountMaximumOutstandingAsyncCopies(*module);
1202
1203 TF_RETURN_IF_ERROR(
1204 memory_space_assignment.VerifyAndExportHeapSimulatorTrace());
1205
1206 return std::move(memory_space_assignment.preset_assignments_);
1207 }
1208
void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
1210 HloInstruction* operand =
1211 use.instruction->mutable_operand(use.operand_number);
1212 // If the use is a tuple, look inside the tuple to find the actual use.
1213 for (int64 index : use.operand_index) {
1214 if (operand->opcode() != HloOpcode::kTuple) {
1215 break;
1216 }
1217 operand = operand->mutable_operand(index);
1218 }
1219
1220 // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts.
1221 std::function<HloInstruction*(HloInstruction*)> get_simplified_operand;
1222 get_simplified_operand = [&](HloInstruction* instruction) {
1223 while (instruction->opcode() == HloOpcode::kGetTupleElement) {
1224 HloInstruction* operand =
1225 get_simplified_operand(instruction->mutable_operand(0));
1226 if (operand->opcode() == HloOpcode::kTuple) {
1227 instruction = operand->mutable_operand(instruction->tuple_index());
1228 } else {
1229 return instruction;
1230 }
1231 }
1232 return instruction;
1233 };
1234 operand = get_simplified_operand(operand);
1235
1236 uses_.push_back(use);
1237 }
1238
Status MemorySpaceAssignment::Allocation::Process(
1240 MemorySpaceAssignment* memory_space_assignment) {
1241 return Status::OK();
1242 }
1243
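// Rebuilds `tuple` with `new_instruction` substituted at `shape_index`:
// untouched elements are extracted with get-tuple-element, nested tuples are
// handled recursively, and a bitcast is inserted if the new instruction's
// shape differs from the target subshape.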
StatusOr<HloInstruction*> MemorySpaceAssignment::Allocation::ReplaceTupleWith(
1245 HloInstruction* new_instruction, HloInstruction* tuple,
1246 ShapeIndex shape_index) {
1247 const Shape& tuple_shape = tuple->shape();
1248 CHECK(tuple->shape().IsTuple())
1249 << "ReplaceTupleWith was called for a non-tuple. Tuple = "
1250 << tuple->ToString()
1251 << ", new_instruction = " << new_instruction->ToString()
1252 << ", shape_index = " << shape_index.ToString();
1253
1254 HloComputation* computation = new_instruction->parent();
1255 std::vector<HloInstruction*> tuple_args(tuple_shape.tuple_shapes_size());
1256 for (int64 i = 0; i < tuple_shape.tuple_shapes_size(); ++i) {
1257 const Shape& subshape = tuple_shape.tuple_shapes(i);
1258 if (i == shape_index[0]) {
1259 // If the subshape is still a tuple, recurse and pass a new shape index
1260 // for the one level deeper.
1261 if (subshape.IsTuple()) {
1262 HloInstruction* get_tuple_element = computation->AddInstruction(
1263 HloInstruction::CreateGetTupleElement(subshape, tuple, i));
1264 TF_ASSIGN_OR_RETURN(tuple_args[i],
1265 ReplaceTupleWith(new_instruction, get_tuple_element,
1266 ShapeIndex(shape_index.begin() + 1,
1267 shape_index.end())));
1268 } else {
1269 if (subshape != new_instruction->shape()) {
1270 VLOG(4) << "Old shape = " << subshape.ToString()
1271 << ", new shape = " << new_instruction->shape().ToString()
1272 << "; inserting a bitcast.";
1273 new_instruction = computation->AddInstruction(
1274 HloInstruction::CreateBitcast(subshape, new_instruction));
1275 }
1276 tuple_args[i] = new_instruction;
1277 }
1278 } else {
1279 HloInstruction* get_tuple_element = computation->AddInstruction(
1280 HloInstruction::CreateGetTupleElement(subshape, tuple, i));
1281 tuple_args[i] = get_tuple_element;
1282 }
1283 }
1284 return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args));
1285 }
1286
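// The asynchronous copy emitted below follows the CopyStart/CopyDone pattern,
// roughly (an illustrative sketch with made-up shapes and names, not actual
// output):
//   %copy-start = (f32[4], f32[4], u32[]) copy-start(f32[4] %producer)
//   %copy-done = f32[4] copy-done(%copy-start)
// The CopyStart result carries two buffers of the copied shape plus a U32
// context scalar, and CopyDone yields the copied array.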
Status MemorySpaceAssignment::CopyAllocation::Process(
    MemorySpaceAssignment* memory_space_assignment) {
  // Copy allocations need to insert asynchronous copy nodes.
  HloInstruction* producing_instruction = defining_position().instruction;
  CHECK_NE(producing_instruction, nullptr);

  Shape shape = defining_position().shape();
  CHECK(shape.IsArray()) << "CopyAllocation shape is not an array. Shape = "
                         << shape.ToString()
                         << " position = " << defining_position().ToString();
  HloComputation* computation = producing_instruction->parent();

  // If the instruction we're copying from produces a tuple, walk down the
  // defining position's index, creating kGetTupleElement instructions, until
  // we reach the array value to copy. Asynchronous copies only support array
  // types.
  if (!producing_instruction->shape().IsArray()) {
    producing_instruction = defining_position().instruction;
    for (int64 index : defining_position().index) {
      producing_instruction =
          computation->AddInstruction(HloInstruction::CreateGetTupleElement(
              producing_instruction->shape().tuple_shapes(index),
              producing_instruction, index));
    }
  }
  copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
      HloOpcode::kCopyStart, producing_instruction));
  copy_done_ = computation->AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
  // Update the allocation with the copy-done instruction so that any further
  // copies made from this allocation find the correct instruction.
  instruction_ = copy_done_;

  // Also update the defining position.
  defining_position_ = HloPosition{copy_done_, {}};

  // Replace all the uses with the new copy instruction.
  for (HloUse use : uses_) {
    // If the operand is a tuple, we need to descend to the actual instruction
    // we want to replace.
    HloInstruction* replacement_instruction;
    Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
    if (operand_shape.IsTuple()) {
      TF_ASSIGN_OR_RETURN(
          replacement_instruction,
          ReplaceTupleWith(copy_done_,
                           use.instruction->mutable_operand(use.operand_number),
                           use.operand_index));
    } else if (operand_shape != copy_done_->shape()) {
      VLOG(4) << "Old shape = " << operand_shape.ToString()
              << ", new shape = " << copy_done_->shape().ToString()
              << "; inserting a bitcast.";
      replacement_instruction = computation->AddInstruction(
          HloInstruction::CreateBitcast(operand_shape, copy_done_));
    } else {
      replacement_instruction = copy_done_;
    }
    TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
        use.operand_number, replacement_instruction));
  }

  return Status::OK();
}

Status MemorySpaceAssignment::Process() {
  // Insert CopyStart/CopyDone pairs.
  int64 alternate_memory_size = 0;
  for (auto& value_and_sequence : allocation_sequence_list_) {
    for (auto& allocation : value_and_sequence.sequence) {
      TF_RETURN_IF_ERROR(allocation->Process(this));
      // Add the offset and size of the allocation in the alternate memory to
      // the output map. Special case for bitcasts: since a bitcast doesn't
      // define its own buffer, it shouldn't be exported as a preset chunk.
      if (allocation->memory_space() == MemorySpace::kAlternate &&
          allocation->instruction()->opcode() != HloOpcode::kBitcast) {
        preset_assignments_->add_chunk(allocation->defining_position(),
                                       allocation->chunk());
        alternate_memory_size =
            std::max(alternate_memory_size, allocation->chunk().chunk_end());
      }
    }
  }

  if (!preset_assignments_->chunks().empty()) {
    preset_assignments_
        ->assignment_information_for_space(options_.alternate_memory_space)
        ->size = alternate_memory_size;
  }

  if (VLOG_IS_ON(3)) {
    VLOG(3) << "Exported alternate memory allocations:";
    for (auto& pair : preset_assignments_->chunks()) {
      VLOG(3) << " [" << pair.second.offset << ", " << pair.second.size
              << "] : " << pair.first.ToString();
    }
    VLOG(3) << "Exported alternate memory sizes:";
    for (auto& pair : preset_assignments_->assignment_informations()) {
      VLOG(3) << " space: " << pair.first << ", size: " << pair.second.size;
    }
  }

  // Color the pending positions and all of their aliased buffers.
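  // For instance (illustrative, assuming alternate_memory_space == 1): a
  // colored f32[128] position ends up with a layout whose memory_space field
  // is 1, and every position aliased with it (e.g. a corresponding while-loop
  // parameter) is colored the same way through the alias analysis below.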
  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
  for (const auto& defining_position_and_chunk :
       preset_assignments_->chunks()) {
    const HloPosition& defining_position = defining_position_and_chunk.first;
    for (auto& buffer : alias_analysis->ComputeBuffersAt(
             defining_position.instruction, defining_position.index)) {
      for (auto& value : buffer->values()) {
        for (auto& position : value->positions()) {
          VLOG(3) << "Coloring " << position.ToString();
          Shape* shape = ShapeUtil::GetMutableSubshape(
              position.instruction->mutable_shape(), position.index);
          CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
                                  << position.ToString();
          shape->mutable_layout()->set_memory_space(
              options_.alternate_memory_space);
        }
      }
    }
  }

  return Status::OK();
}

void PresetAssignments::RemoveAssignmentForInstruction(
    const HloInstruction* instruction) {
  for (auto& position_and_chunk : chunks_) {
    const HloPosition& position = position_and_chunk.first;
    if (position.instruction == instruction) {
      VLOG(3) << "Removing instruction from preset assignments.";
      // Swap the removed position and chunk with the back and pop back.
      position_and_chunk = chunks_.back();
      chunks_.pop_back();
      break;
    }
  }
}

Status MemorySpaceAssignment::SimplifyGraph() {
  for (HloComputation* computation : module_->MakeNonfusionComputations()) {
    // Parallel computations aren't in the schedule and don't need to be
    // modified.
    if (!computations_in_schedule_.contains(computation)) {
      VLOG(4) << "Not simplifying " << computation->name()
              << " because it's not in the schedule.";
      continue;
    }
    // Drop control dependencies. Since the computation is already scheduled,
    // we don't need control dependencies anymore, and having control
    // predecessors/successors prevents us from removing instructions without
    // users (HloComputation::IsSafelyRemovable returns false if there are
    // control dependencies).
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
    }
    // We perform limited DCE and forward the tuple operand in patterns like
    // GetTupleElement(Tuple(a, b), 0). This is mostly because memory space
    // assignment runs late in compilation (after DCE and arithmetic
    // simplification passes) and we don't want to generate redundant code.
    // Run to a fixed point.
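    // An example of why a fixed point is needed (illustrative): forwarding
    //   gte = get-tuple-element(tuple(a, b), 0)  =>  uses of gte become a
    // can leave both gte and the tuple without users, so they only become
    // removable on a later iteration of the loop below.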
    bool computation_modified = true;
    while (computation_modified) {
      computation_modified = false;
      VLOG(4) << "Running simplify graph loop over " << computation->name();
      for (HloInstruction* instruction :
           computation->MakeInstructionPostOrder()) {
        if (computation->IsSafelyRemovable(instruction) &&
            instruction->user_count() == 0 && !instruction->HasSideEffect() &&
            instruction != computation->root_instruction() &&
            instruction->opcode() != HloOpcode::kCopyStart &&
            instruction->opcode() != HloOpcode::kCopyDone) {
          VLOG(4) << "Instruction removed: " << instruction->ToString();
          // Ensure the exported preset assignments don't contain a reference
          // to the removed instruction.
          preset_assignments_->RemoveAssignmentForInstruction(instruction);
          // Instead of deleting the instruction from the schedule, replace it
          // with a nullptr. This is needed because FixSchedule relies on the
          // logical time, i.e. the index into flattened_instructions_, when
          // scheduling asynchronous copies.
          auto instruction_it =
              absl::c_find(flattened_instructions_, instruction);
          if (instruction_it != flattened_instructions_.end()) {
            *instruction_it = nullptr;
          }
          TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
          computation_modified = true;
        } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
          HloInstruction* operand = instruction->mutable_operand(0);
          if (operand->opcode() == HloOpcode::kTuple) {
            HloInstruction* forwarded_instruction =
                operand->mutable_operand(instruction->tuple_index());
            VLOG(4) << "Replacing uses of " << instruction->ToString()
                    << " with " << forwarded_instruction->ToString();
            TF_RETURN_IF_ERROR(
                instruction->ReplaceAllUsesWith(forwarded_instruction));
            computation_modified = true;
          }
        }
      }
    }
  }

  return Status::OK();
}

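// Example of the insertion order produced below (an editorial sketch with
// hypothetical instructions): for d = add(b, c) with b = negate(a) and
// c = exp(a), calling EnsureInstructionAndOperandsInserted(d, ...) on an
// empty sequence appends a, b, c, d; operands are always inserted before
// their users, and instructions already present in inserted_instructions are
// skipped.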
void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted(
    HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
    absl::flat_hash_set<HloInstruction*>* inserted_instructions) const {
  if (inserted_instructions->contains(new_instruction)) {
    return;
  }
  for (HloInstruction* operand : new_instruction->operands()) {
    // CopyStart/CopyDone dependencies should always be already inserted; it
    // is a red flag when they haven't already been inserted.
    CHECK((operand->opcode() != HloOpcode::kCopyStart &&
           operand->opcode() != HloOpcode::kCopyDone) ||
          inserted_instructions->contains(operand))
        << "Inserted instruction " << new_instruction->ToString()
        << " has un-inserted dependency: " << operand->ToString();
    EnsureInstructionAndOperandsInserted(operand, new_sequence,
                                         inserted_instructions);
  }
  VLOG(4) << "inserting: " << new_instruction->ToShortString();
  new_sequence->push_back(new_instruction);
  inserted_instructions->insert(new_instruction);
}

void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
  for (MemorySpace memory_space :
       {MemorySpace::kDefault, MemorySpace::kAlternate}) {
    std::vector<CopyAllocation*> copy_allocations;
    for (auto& value_and_sequence : allocation_sequence_list_) {
      for (auto& allocation : value_and_sequence.sequence) {
        if (allocation->is_copy_allocation()) {
          auto copy_allocation = static_cast<CopyAllocation*>(allocation.get());
          if (copy_allocation->memory_space() == memory_space) {
            copy_allocations.push_back(copy_allocation);
          }
        }
      }
    }

    absl::c_stable_sort(
        copy_allocations, [](CopyAllocation* first, CopyAllocation* second) {
          return std::forward_as_tuple(first->copy_done_schedule_before(),
                                       first->copy_start_schedule_after()) <
                 std::forward_as_tuple(second->copy_done_schedule_before(),
                                       second->copy_start_schedule_after());
        });

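    // Illustrative ordering (hypothetical schedule times): a copy with
    // (copy_done_schedule_before, copy_start_schedule_after) == (10, 2) sorts
    // before (10, 5), which sorts before (12, 1); copies whose CopyDone
    // deadline comes first are scheduled first, with ties broken by the
    // earlier CopyStart.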
    CopyAllocation* prev_copy_allocation = nullptr;
    for (CopyAllocation* copy_allocation : copy_allocations) {
      // If the copy start doesn't happen to be scheduled at the correct
      // computation, delay it until the correct computation starts.
      int64 copy_start_schedule_after =
          copy_allocation->copy_start_schedule_after();
      // Accessing flattened_instructions_ here without checking for nullptr
      // entries is safe because this method is called before SimplifyGraph,
      // which is what replaces removed instructions with nullptr.
      while (copy_allocation->instruction()->parent() !=
             flattened_instructions_[copy_start_schedule_after]->parent()) {
        VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
                << (copy_start_schedule_after + 1) << ") for "
                << copy_allocation->copy_start()->ToString()
                << " because it is not in the correct computation.";
        copy_allocation->set_copy_start_schedule_after(
            ++copy_start_schedule_after);
      }

      schedule_after_[copy_allocation->copy_start_schedule_after()].push_back(
          copy_allocation->copy_start());
      schedule_before_[copy_allocation->copy_done_schedule_before()].push_back(
          copy_allocation->copy_done());
      prev_copy_allocation = copy_allocation;
    }
  }
}

Status MemorySpaceAssignment::FixSchedule() {
  CHECK(module_->has_schedule());
  HloSchedule& schedule = module_->schedule();
  for (const HloComputation* computation :
       module_->MakeNonfusionComputations()) {
    // Parallel computations aren't in the schedule and don't need to be
    // modified.
    if (!computations_in_schedule_.contains(computation)) {
      VLOG(4) << "Not scheduling " << computation->name()
              << " because it's not in the schedule.";
      continue;
    }
    CHECK(schedule.is_computation_scheduled(computation));
    HloInstructionSequence new_sequence;

    absl::flat_hash_set<HloInstruction*> inserted_instructions;

    VLOG(4) << "Scheduling: " << computation->ToString();

    for (int64 instruction_index = 0;
         instruction_index < flattened_instructions_.size();
         ++instruction_index) {
      auto insts_before_iter = schedule_before_.find(instruction_index);
      if (insts_before_iter != schedule_before_.end()) {
        for (HloInstruction* new_instruction : insts_before_iter->second) {
          if (new_instruction->parent() == computation) {
            VLOG(4) << "before " << instruction_index << ": "
                    << new_instruction->name();
            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
                                                 &inserted_instructions);
          }
        }
      }
      HloInstruction* instruction = flattened_instructions_[instruction_index];
      // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
      // it was deleted) and not previously inserted. Also bitcasts and tuples
      // are treated specially and only inserted as a result of operand
      // dependencies.
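      // For example (illustrative): if the instruction at logical time 7 was
      // removed by SimplifyGraph, flattened_instructions_[7] is nullptr and
      // gets skipped here, but a CopyStart registered in schedule_after_[7]
      // is still emitted at this point in the sequence.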
      if (instruction != nullptr &&
          !inserted_instructions.contains(instruction) &&
          instruction->parent() == computation &&
          instruction->opcode() != HloOpcode::kBitcast &&
          instruction->opcode() != HloOpcode::kTuple) {
        VLOG(4) << "inst " << instruction_index << ": " << instruction->name();
        EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
                                             &inserted_instructions);
      }
      auto insts_after_iter = schedule_after_.find(instruction_index);
      if (insts_after_iter != schedule_after_.end()) {
        for (HloInstruction* new_instruction : insts_after_iter->second) {
          if (new_instruction->parent() == computation) {
            VLOG(4) << "after " << instruction_index << ": "
                    << new_instruction->name();
            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
                                                 &inserted_instructions);
          }
        }
      }
    }
    // For rare cases where the original sequence is empty, ensure the root
    // instruction and its dependencies are scheduled.
    EnsureInstructionAndOperandsInserted(computation->root_instruction(),
                                         &new_sequence, &inserted_instructions);
    CHECK_EQ(new_sequence.size(), computation->instruction_count())
        << "New sequence for computation " << computation->name() << " has "
        << new_sequence.size() << " instructions, expected "
        << computation->instruction_count() << ".";
    schedule.set_sequence(computation, new_sequence);
  }

  return Status::OK();
}

Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
  VLOG(3) << "Verifying:";
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module_));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(module_->schedule(), *alias_analysis,
                                        module_->entry_computation()));

  BufferIntervalTree interval_tree;
  absl::flat_hash_set<int64> seen_buffers;
  std::map<std::pair<int64, int64>,
           std::tuple<const HloValue*, Chunk, HeapSimulatorTrace::Event::Kind>>
      events;

  for (const auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    const Chunk& chunk = position_and_chunk.second;
    const HloBuffer& buffer =
        alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
    if (seen_buffers.contains(buffer.id())) {
      continue;
    }
    seen_buffers.insert(buffer.id());

    int64 start_time = INT64_MAX;
    int64 end_time = -1;
    for (const HloValue* value : buffer.values()) {
      const HloLiveRange::TimeBound& time_bound =
          hlo_live_range->buffer_live_ranges().at(value);
      VLOG(3) << " value: " << value->ToShortString() << " ("
              << time_bound.start << ", " << time_bound.end << ")";
      start_time = std::min(start_time, time_bound.start);
      end_time = std::max(end_time, time_bound.end);
      events[std::make_pair(time_bound.start, value->id())] =
          std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC);
      events[std::make_pair(time_bound.end, value->id())] =
          std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE);
    }
    CHECK_GE(start_time, 0);
    CHECK_GT(end_time, 0);
    // Get the chunks overlapping in time and search if they overlap in space
    // as well.
    // TODO(berkin): For now checking against end_time - 1 (exclusive), but we
    // really should check against end_time (inclusive) for cases where the
    // operand can't share buffer with user (see
    // HloDataflowAnalysis::CanShareOperandBufferWithUser).
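    // Illustrative failure case (hypothetical values): a chunk at offset 0
    // with size 128 live over (3, 9) and a chunk at offset 64 with size 128
    // live over (7, 12) overlap both in time and in space, so the check below
    // would report an InternalError.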
    if (options_.verify || VLOG_IS_ON(1)) {
      // Verify only if the option is set or if vlog is on.
      for (const Chunk& overlapping_chunk :
           interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
        if (chunk.OverlapsWith(overlapping_chunk)) {
          return InternalError(
              ("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk"
               " off: %d size: %d"),
              buffer.ToString(), start_time, end_time, chunk.offset, chunk.size,
              overlapping_chunk.offset, overlapping_chunk.size);
        }
      }
    }
    interval_tree.Add(start_time, end_time - 1, chunk);
    VLOG(3) << " buffer: " << buffer.ToString() << ": (" << start_time << ", "
            << end_time << ") off: " << position_and_chunk.second.offset
            << ", size: " << position_and_chunk.second.size;
  }

  HeapSimulatorTrace* heap_trace =
      &preset_assignments_
           ->assignment_information_for_space(options_.alternate_memory_space)
           ->heap_simulator_trace;
  int64 memory_usage = 0;
  int64 max_memory_usage = 0;
  for (const auto& event : events) {
    int64 time = event.first.first;
    int64 buffer_id = event.first.second;
    const HloValue* value;
    Chunk chunk;
    HeapSimulatorTrace::Event::Kind kind;
    std::tie(value, chunk, kind) = event.second;
    HeapSimulatorTrace::Event* heap_trace_event = heap_trace->add_events();
    heap_trace_event->set_kind(kind);
    heap_trace_event->set_buffer_id(buffer_id);
    heap_trace_event->set_instruction_name(value->instruction()->name());
    heap_trace_event->set_computation_name(
        value->instruction()->parent()->name());

    if (kind == HeapSimulatorTrace::Event::ALLOC) {
      memory_usage += chunk.size;
    } else {
      CHECK_EQ(kind, HeapSimulatorTrace::Event::FREE);
      memory_usage -= chunk.size;
    }
    max_memory_usage = std::max(max_memory_usage, memory_usage);
    VLOG(3) << "Memory usage: " << memory_usage << " at time: " << time;
  }
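  // Quick sanity example for the running sum above (hypothetical sizes): the
  // event stream ALLOC(64), ALLOC(128), FREE(64), FREE(128) yields usage
  // 64 -> 192 -> 128 -> 0, so max_memory_usage would be 192.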
  VLOG(1) << "Max memory usage ignoring fragmentation: " << max_memory_usage;

  return Status::OK();
}

}  // namespace xla