/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/memory_space_assignment.h"

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace xla {

namespace {
// Define a dummy chunk for chunks that will be allocated in the default memory
// space and for keeping track of the number of asynchronous copies.
const HeapSimulator::Chunk kDummyChunk{-1, -1};
// This constant is used by the cost analysis to estimate how many times each
// while loop will execute. Nested loops are assumed to execute
// pow(kWhileExecutionCount, nesting_level) times.
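// For example, an instruction nested inside two while loops is assumed to run
// pow(5, 2) = 25 times as often as an instruction in the entry computation.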
const int kWhileExecutionCount = 5;

bool LooksLikeAnActivation(const HloInstruction* inst) {
  for (HloInstruction* user : inst->users()) {
    switch (user->opcode()) {
      case HloOpcode::kConvolution:
      case HloOpcode::kDot:
        if (user->operand(0) == inst) {
          return true;
        }
        break;
      case HloOpcode::kGather:
        if (user->operand(1) == inst) {
          return true;
        }
        break;
      case HloOpcode::kFusion:
        for (int i = 0; i < user->operand_count(); ++i) {
          if (user->operand(i) == inst &&
              LooksLikeAnActivation(user->fused_parameter(i))) {
            return true;
          }
        }
        break;
      case HloOpcode::kBitcast:
        return LooksLikeAnActivation(user);
      default:
        return true;
    }
  }
  return false;
}

bool IsCrossProgramPrefetchCandidate(
    const HloValue& value, const MemorySpaceAssignment::Options& options) {
  return value.instruction()->parent() ==
             value.instruction()->GetModule()->entry_computation() &&
         value.instruction()->opcode() == HloOpcode::kParameter &&
         (!value.shape().has_layout() ||
          value.shape().layout().memory_space() !=
              options.alternate_memory_space) &&
         value.index().size() == 1 && value.shape().IsArray() &&
         !value.uses().empty() &&
         options.size_fn(value) <= options.max_size_in_bytes &&
         absl::c_all_of(value.uses(), [&](const HloUse& use) {
           const HloInstruction* inst =
               use.instruction->operand(use.operand_number);

           // Skip the LooksLikeAnActivation test since we're testing the
           // parent GTE and its children below.
           if (inst->opcode() == HloOpcode::kBitcast &&
               inst->operand(0)->opcode() == HloOpcode::kGetTupleElement &&
               inst->operand(0)->operand(0)->opcode() ==
                   HloOpcode::kParameter) {
             return true;
           }

           return inst->opcode() == HloOpcode::kGetTupleElement &&
                  !LooksLikeAnActivation(inst);
         });
}

absl::optional<MemorySpaceAssignment::BufferInterval>
FindCrossProgramPrefetchCandidate(
    const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range,
    const MemorySpaceAssignment::Options& options) {
  std::vector<MemorySpaceAssignment::BufferInterval> candidates;
  for (const HloBuffer& buffer : alias_analysis.buffers()) {
    CHECK_GE(buffer.values().size(), 1);
    const HloValue* value = buffer.values().at(0);
    if (IsCrossProgramPrefetchCandidate(*value, options)) {
      MemorySpaceAssignment::BufferInterval interval;
      interval.buffer = value;
      interval.size = options.size_fn(*value);
      interval.start = 0;
      interval.end = hlo_live_range.schedule_end_time();
      interval.need_allocation = true;
      interval.colocations = {++buffer.values().begin(), buffer.values().end()};
      candidates.emplace_back(interval);
    }
  }

  // The buffer_interval_compare ought to do a good job picking the most
  // appropriate buffer to cross program prefetch, but empirically, it makes
  // worse choices than just picking the largest buffer.
  // TODO(b/152421603): Investigate.
  auto size_compare = [](const auto& x, const auto& y) {
    return x.size < y.size;
  };
  auto& compare = options.default_cross_program_prefetch_heuristic &&
                          options.buffer_interval_compare
                      ? *options.buffer_interval_compare
                      : size_compare;

  auto best_candidate = absl::c_max_element(candidates, compare);
  if (best_candidate == candidates.end()) {
    return absl::nullopt;
  }
  return *best_candidate;
}

}  // namespace

/*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
MemorySpaceAssignmentCostAnalysis::Create(
    const HloCostAnalysis& cost_analysis,
    float async_copy_bandwidth_bytes_per_second,
    float alternate_mem_bandwidth_bytes_per_second, const HloModule& module) {
  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
  TF_ASSIGN_OR_RETURN(auto hlo_live_range,
                      HloLiveRange::Run(module.schedule(), *alias_analysis,
                                        module.entry_computation()));
  auto call_graph = CallGraph::Build(&module);
  return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
      cost_analysis, async_copy_bandwidth_bytes_per_second,
      alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
      std::move(hlo_live_range), std::move(call_graph)));
}

float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
    const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
    MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
  float elapsed_time_due_to_compute =
      GetInstructionElapsedDueToCompute(instruction);
  float elapsed_time_due_to_memory =
      GetInstructionElapsedDueToMemory(instruction);
  if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
    // Memory bound, return how much alternate memory is better.
    float while_nest_multiplier;
    if (cache) {
      // If there is a cache provided, memoize the while nest multiplier.
      auto it = cache->while_nest_multiplier.find(&instruction);
      if (it != cache->while_nest_multiplier.end()) {
        while_nest_multiplier = it->second;
      } else {
        while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
            kWhileExecutionCount,
            CalculateComputationNestLevel(&instruction,
                                          /*while_only=*/true));
        cache->while_nest_multiplier[&instruction] = while_nest_multiplier;
      }
    } else {
      while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
          kWhileExecutionCount,
          CalculateComputationNestLevel(&instruction,
                                        /*while_only=*/true));
    }
    return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
           while_nest_multiplier;
  } else {
    // Compute bound, return how far off we are from memory boundedness.
    return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
  }
}

float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
    MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
  const HloInstruction& defining_instruction =
      *interval.buffer->defining_instruction();
  float alternate_mem_benefit = GetAlternateMemoryBenefit(
      defining_instruction,
      GetInstructionElapsedDueToMemory(defining_instruction,
                                       /*operand_in_alternate_mem=*/{},
                                       /*output_in_alternate_mem=*/true),
      cache);
  for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
           interval.buffer->defining_position().instruction,
           interval.buffer->defining_position().index)) {
    for (const HloValue* value : buffer->values()) {
      for (const HloUse& use : value->uses()) {
        // We look inside the called computations of while and conditional, so
        // don't use the benefit of while and conditional directly.
        if (use.instruction->opcode() == HloOpcode::kWhile ||
            use.instruction->opcode() == HloOpcode::kConditional) {
          continue;
        }
        float use_alternate_mem_benefit =
            GetAlternateMemoryBenefit(*use.instruction,
                                      GetInstructionElapsedDueToMemory(
                                          *use.instruction, use.operand_number),
                                      cache);
        // If the benefit is positive (memory bound), add it to this buffer's
        // benefit. If the benefit is negative (compute bound), calculate the
        // maximum.
        if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
          alternate_mem_benefit += use_alternate_mem_benefit;
        } else {
          alternate_mem_benefit =
              std::max(alternate_mem_benefit, use_alternate_mem_benefit);
        }
      }
    }
  }

  // Penalize larger buffers by dividing the benefit by the square root of the
  // size. Empirically, we observed this resulted in better performance
  // compared to dividing by the size.
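  // For example, with this normalization a buffer four times as large needs
  // roughly twice the raw benefit to receive the same memory boundedness
  // score.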
  return alternate_mem_benefit / std::sqrt(interval.size);
}

int MemorySpaceAssignmentCostAnalysis::CalculateComputationNestLevel(
    const HloInstruction* instruction, bool while_only) const {
  int nest_level = 0;
  const HloComputation* computation = instruction->parent();
  while (!computation->IsEntryComputation()) {
    auto node = call_graph_->GetNode(computation);
    auto callsites = node.caller_callsites();
    CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
    auto callsite = callsites[0];
    if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) {
      ++nest_level;
    }
    computation = callsite.instruction()->parent();
  }
  return nest_level;
}

float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
    const HloInstruction& instruction) const {
  return std::max(
      cost_analysis_.flop_count(instruction) /
          cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey),
      cost_analysis_.transcendental_count(instruction) /
          cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey));
}

float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory(
    const HloInstruction& instruction,
    absl::optional<int64> operand_in_alternate_mem,
    bool output_in_alternate_mem) const {
  float bytes_accessed = cost_analysis_.bytes_accessed(instruction);
  float elapsed_due_to_bytes =
      bytes_accessed /
      cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
  if (operand_in_alternate_mem) {
    // Estimate the elapsed time due to the operand being in the alternate
    // memory space.
    float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed(
        instruction, *operand_in_alternate_mem);
    float elapsed_due_to_operand_bytes =
        operand_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_;
    bytes_accessed -= operand_bytes_accessed;
    elapsed_due_to_bytes =
        elapsed_due_to_operand_bytes +
        bytes_accessed /
            cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
  }
  if (output_in_alternate_mem) {
    // Estimate the elapsed time due to the output being in the alternate
    // memory space.
    float output_bytes_accessed =
        cost_analysis_.output_bytes_accessed(instruction);
    float elapsed_due_to_output_bytes =
        output_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_;
    bytes_accessed -= output_bytes_accessed;
    elapsed_due_to_bytes =
        elapsed_due_to_output_bytes +
        bytes_accessed /
            cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
  }
  return elapsed_due_to_bytes;
}

float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed(
    const HloInstruction& instruction) const {
  return std::max(GetInstructionElapsedDueToCompute(instruction),
                  GetInstructionElapsedDueToMemory(instruction));
}

float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory(
    const HloInstruction& instruction,
    absl::optional<int64> operand_in_alternate_mem,
    bool output_in_alternate_mem) const {
  return std::max(
      GetInstructionElapsedDueToCompute(instruction),
      GetInstructionElapsedDueToMemory(instruction, operand_in_alternate_mem,
                                       output_in_alternate_mem));
}

float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
    const Shape& shape) const {
  int64 size_in_bytes = cost_analysis_.GetShapeSize(shape);
  return static_cast<float>(size_in_bytes) /
         async_copy_bandwidth_bytes_per_second_;
}

int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
  return hlo_live_range_->schedule_end_time();
}

bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
    const Shape& shape, int64 start_time, int64 end_time) const {
  return end_time - start_time <= max_overlap_count_;
}

int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
    const Shape& shape, int64 start_time, int64 latest_end_time) const {
  return std::min(start_time + min_overlap_count_, latest_end_time);
}

int64 InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime(
    const Shape& shape, int64 start_time, int64 end_time,
    const HloUse* use) const {
  return end_time - min_overlap_count_;
}

int64 InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime(
    const Shape& shape, int64 earliest_prefetch_start_time,
    int64 latest_prefetch_start_time, int64 prefetch_end_time) const {
  return std::max(earliest_prefetch_start_time,
                  prefetch_end_time - max_overlap_count_);
}

void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
                                                   int64 start_time,
                                                   int64 end_time) {
  end_time_ = end_time;
  const Shape& shape = ShapeUtil::GetSubshape(
      use.instruction->operand(use.operand_number)->shape(), use.operand_index);
  current_prefetch_time_ =
      PreferredPrefetchStartTime(shape, start_time, end_time, end_time);
}

int64 InstructionCountPrefetchIntervalPicker::Next() {
  CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
                    "Done() is true";
  return current_prefetch_time_++;
}

bool InstructionCountPrefetchIntervalPicker::Done() const {
  return end_time_ - current_prefetch_time_ <= min_overlap_count_;
}

std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const {
  return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_);
}

std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString(
    const Shape& shape, int64 start_time, int64 end_time) const {
  return absl::StrCat("Overlapped HLOs = ", end_time - start_time);
}

CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
    const MemorySpaceAssignmentCostAnalysis& cost_analysis,
    float min_async_copy_to_overlap_ratio,
    float max_async_copy_to_overlap_ratio,
    float preferred_async_copy_to_overlap_ratio)
    : while_nest_level_(
          cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
      computation_nest_level_(
          cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
      cost_analysis_(cost_analysis),
      min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
      max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio),
      preferred_async_copy_to_overlap_ratio_(
          preferred_async_copy_to_overlap_ratio) {
  instruction_schedule_ =
      &cost_analysis_.hlo_live_range().instruction_schedule();

  // Create a vector of elapsed times and while nesting levels of HLO
  // instructions. The elapsed times are multiplied by pow(kWhileExecutionCount,
  // nest_level) to account for executing the HLOs multiple times in while
  // loops.
  std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
                                               0.0);
  for (const auto& instruction_and_logical_time : *instruction_schedule_) {
    // To avoid double counting, don't include the elapsed time of while and
    // conditional HLOs.
    const HloInstruction* instruction = instruction_and_logical_time.first;
    int64 logical_time = instruction_and_logical_time.second;
    if (logical_time >= instructions_elapsed_time.size()) {
      instructions_elapsed_time.resize(logical_time + 1, 0.0);
      while_nest_level_.resize(logical_time + 1, 0);
    }
    int while_nest_level = cost_analysis_.CalculateComputationNestLevel(
        instruction_and_logical_time.first, /*while_only=*/true);
    while_nest_level_[logical_time] = while_nest_level;
    int computation_nest_level = cost_analysis_.CalculateComputationNestLevel(
        instruction_and_logical_time.first, /*while_only=*/false);
    computation_nest_level_[logical_time] = computation_nest_level;
    if (instruction->opcode() == HloOpcode::kWhile ||
        instruction->opcode() == HloOpcode::kConditional) {
      continue;
    }
    float elapsed_time = cost_analysis_.GetInstructionElapsed(
        *instruction_and_logical_time.first);
    instructions_elapsed_time[logical_time] =
        elapsed_time * tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
                                                         while_nest_level);
  }
  // As an optimization, create a cumulative sum vector of elapsed time.
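  // The elapsed time of any logical interval can then be read off in O(1) as
  // elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time];
  // see GetLogicalIntervalElapsed below.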
  float cumsum = 0.0;
  elapsed_time_cumsum_.reserve(instructions_elapsed_time.size());
  for (float elapsed_time : instructions_elapsed_time) {
    cumsum += elapsed_time;
    elapsed_time_cumsum_.push_back(cumsum);
  }
  // To be able to accurately determine the minimum nest level between a start
  // time and an end time efficiently, populate a data structure that stores the
  // closest nest level change index.
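  // For example, if while_nest_level_ is {0, 0, 1, 1, 0}, then
  // while_nest_level_change_ becomes {-1, -1, 1, 1, 3}: each entry points to
  // the last index before the most recent nest level change.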
  int prev_nest_level = 0;
  int change_idx = -1;
  while_nest_level_change_.reserve(instructions_elapsed_time.size());
  for (int i = 0; i < while_nest_level_.size(); ++i) {
    int nest_level = while_nest_level_[i];
    if (nest_level != prev_nest_level) {
      prev_nest_level = nest_level;
      change_idx = i - 1;
    }
    while_nest_level_change_.push_back(change_idx);
  }
}

bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
    const Shape& shape, int64 start_time, int64 end_time) const {
  // Even though this method returns whether we allow the buffer in alternate
  // memory _without_ asynchronous copies, calculate how long it would have
  // taken to copy it and compare it to the elapsed time in the logical
  // interval.
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  float logical_interval_elapsed =
      GetLogicalIntervalElapsed(start_time, end_time);
  return max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
             async_copy_elapsed >
         logical_interval_elapsed;
}

int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
    const Shape& shape, int64 start_time, int64 latest_end_time) const {
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  int64 end_time;
  for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) {
    float logical_interval_elapsed =
        GetLogicalIntervalElapsed(start_time, end_time);
    if (logical_interval_elapsed >=
        min_async_copy_to_overlap_ratio_ * async_copy_elapsed) {
      break;
    }
  }
  return end_time;
}

int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime(
    const Shape& shape, int64 start_time, int64 end_time,
    const HloUse* use) const {
  // Estimate how long the asynchronous copy of this shape will take.
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  // If there is a use, estimate the time we would save by having this op in
  // alternate memory.
  float inst_elapsed_reduction = 0.0f;
  if (use) {
    float elapsed_time =
        cost_analysis_.GetInstructionElapsed(*use->instruction);
    float elapsed_time_in_alternate_mem =
        cost_analysis_.GetInstructionElapsedInAlternateMemory(
            *use->instruction, use->operand_number,
            /*output_in_alternate_mem=*/false);
    inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem;
  }
  int end_nest_level = computation_nest_level_[end_time];

  // Find the latest time we're allowed to start prefetching.
  float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed;
  int latest_prefetch_time;
  for (latest_prefetch_time = end_time - 1;
       latest_prefetch_time >= start_time &&
       (computation_nest_level_[latest_prefetch_time] != end_nest_level ||
        min_interval >
            GetLogicalIntervalElapsed(latest_prefetch_time, end_time) +
                inst_elapsed_reduction);
       --latest_prefetch_time) {
  }

  return latest_prefetch_time;
}

int64 CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime(
    const Shape& shape, int64 earliest_prefetch_start_time,
    int64 latest_prefetch_start_time, int64 prefetch_end_time) const {
  // Between the earliest and latest prefetch interval, find the interval
  // closest to the preferred interval and start iterating from there.
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  int64 preferred_prefetch_start_time = earliest_prefetch_start_time;
  float preferred_interval =
      preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed;
  float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time,
                                                  prefetch_end_time);
  int end_nest_level = computation_nest_level_[prefetch_end_time];
  for (int64 prefetch_start_time = earliest_prefetch_start_time + 1;
       prefetch_start_time <= latest_prefetch_start_time;
       ++prefetch_start_time) {
    float interval =
        GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time);
    if (computation_nest_level_[prefetch_start_time] == end_nest_level &&
        std::abs(preferred_interval - interval) <
            std::abs(preferred_interval - best_interval)) {
      best_interval = interval;
      preferred_prefetch_start_time = prefetch_start_time;
    }
  }
  return preferred_prefetch_start_time;
}

int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime(
    int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const {
  // Iterate towards the beginning until we find a suitable end time that is the
  // same while nest level as the original prefetch end time.
  int64 original_nest_level =
      computation_nest_level_[original_prefetch_end_time];
  int64 new_prefetch_end_time;
  for (new_prefetch_end_time = proposed_prefetch_end_time;
       computation_nest_level_[new_prefetch_end_time] != original_nest_level;
       --new_prefetch_end_time) {
  }
  return new_prefetch_end_time;
}

void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
                                               int64 start_time,
                                               int64 end_time) {
  const Shape& shape = ShapeUtil::GetSubshape(
      use.instruction->operand(use.operand_number)->shape(), use.operand_index);
  // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
  async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape);
  // Estimate the time we would save by having this op in alternate memory.
  float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
  float elapsed_time_in_alternate_mem =
      cost_analysis_.GetInstructionElapsedInAlternateMemory(
          *use.instruction, use.operand_number,
          /*output_in_alternate_mem=*/false);
  inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
  end_logical_time_ = end_time;
  int end_nest_level = computation_nest_level_[end_logical_time_];

  // Find the latest time we're allowed to start prefetching.
  float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
  latest_prefetch_time_ =
      LatestPrefetchStartTime(shape, start_time, end_time, &use);

  // Find the earliest time we're allowed to start prefetching.
  float max_interval = max_async_copy_to_overlap_ratio_ *
                       max_overlap_multiplier_ * async_copy_elapsed_;
  for (earliest_prefetch_time_ = start_time;
       earliest_prefetch_time_ <= end_logical_time_ &&
       (computation_nest_level_[earliest_prefetch_time_] != end_nest_level ||
        max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_,
                                                 end_logical_time_));
       ++earliest_prefetch_time_) {
  }
  if (earliest_prefetch_time_ > latest_prefetch_time_) {
    // There is no available prefetch interval for the given start and end
    // times. Set the iterators accordingly to ensure Done() returns true.
    increasing_prefetch_time_iterator_ = earliest_prefetch_time_;
    decreasing_prefetch_time_iterator_ = latest_prefetch_time_;
    CHECK(Done());
    return;
  }

  int64 starting_prefetch_time = PreferredPrefetchStartTime(
      shape, earliest_prefetch_time_, latest_prefetch_time_, end_logical_time_);
  float preferred_interval =
      preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
  VLOG(4) << "Interval min/max/preferred = " << min_interval << " "
          << max_interval << " " << preferred_interval
          << " prefetch time earliest/latest/starting = "
          << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " "
          << starting_prefetch_time;

  increasing_prefetch_time_iterator_ = starting_prefetch_time;
  decreasing_prefetch_time_iterator_ = starting_prefetch_time;
  using_increasing_prefetch_time_iterator_ = true;
  // Since both iterators start at the same position, call Next() once to
  // advance one of the iterators.
  Next();
}

int64 CostAnalysisPrefetchIntervalPicker::Next() {
  CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
                    "Done() is true";
  if (using_increasing_prefetch_time_iterator_) {
    int64 prefetch_time = increasing_prefetch_time_iterator_++;
    while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ &&
           computation_nest_level_[increasing_prefetch_time_iterator_] !=
               computation_nest_level_[end_logical_time_]) {
      ++increasing_prefetch_time_iterator_;
    }
    if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) {
      using_increasing_prefetch_time_iterator_ = false;
    }
    return prefetch_time;
  } else {
    int64 prefetch_time = decreasing_prefetch_time_iterator_--;
    while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ &&
           computation_nest_level_[decreasing_prefetch_time_iterator_] !=
               computation_nest_level_[end_logical_time_]) {
      --decreasing_prefetch_time_iterator_;
    }
    if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) {
      using_increasing_prefetch_time_iterator_ = true;
    }
    return prefetch_time;
  }
}

bool CostAnalysisPrefetchIntervalPicker::Done() const {
  return increasing_prefetch_time_iterator_ > latest_prefetch_time_ &&
         decreasing_prefetch_time_iterator_ < earliest_prefetch_time_;
}

void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
  // Use twice as large max overlap limit in each retry.
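  // For example, retry 0 uses 1x the limit, retry 1 uses 2x, retry 2 uses 4x.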
  max_overlap_multiplier_ = 1 << retry_number;
}

int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
    int64 start_time, int64 end_time) const {
  int min_nest_level =
      std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
  int change_idx = while_nest_level_change_[end_time];
  while (change_idx >= start_time) {
    min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]);
    change_idx = while_nest_level_change_[change_idx];
  }
  return min_nest_level;
}

float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
    int64 start_time, int64 end_time) const {
  CHECK_LE(start_time, end_time);
  if (start_time == end_time) {
    return 0.0;
  }
  if (start_time < 0) {
    start_time = 0;
  }
  // Since elapsed_time_cumsum_ is already weighted by the while loop nesting
  // level, normalize the elapsed time by dividing by the nesting factor of
  // the interval (start and end times).
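  // For example, an interval that lies entirely inside a single while loop
  // (nest level 1) has its cumulative elapsed time divided by
  // kWhileExecutionCount = 5.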
  int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time);
  return (elapsed_time_cumsum_[end_time - 1] -
          elapsed_time_cumsum_[start_time]) /
         tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
                                           interval_while_nest_level);
}

std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
  int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_
                                          ? increasing_prefetch_time_iterator_
                                          : decreasing_prefetch_time_iterator_;
  float logical_interval_elapsed = GetLogicalIntervalElapsed(
      current_logical_prefetch_time, end_logical_time_);
  return absl::StrCat(
      "Async copy elapsed (s) = ", async_copy_elapsed_,
      ", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
      ", logical interval elapsed (s) = ", logical_interval_elapsed,
      ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_,
      ")");
}

std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
    const Shape& shape, int64 start_time, int64 end_time) const {
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  float logical_interval_elapsed =
      GetLogicalIntervalElapsed(start_time, end_time);
  return absl::StrCat(
      "Async copy elapsed (s) = ", async_copy_elapsed,
      ", logical interval elapsed (s) = ", logical_interval_elapsed);
}

absl::optional<float>
CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
    const {
  return cost_analysis_.GetMemoryBoundedness(interval);
}

bool MemorySpaceAssignment::Allocation::operator==(
    const MemorySpaceAssignment::Allocation& other) const {
  return defining_position() == other.defining_position() &&
         uses() == other.uses() && memory_space() == other.memory_space() &&
         chunk() == other.chunk() && start_time() == other.start_time() &&
         end_time() == other.end_time() &&
         is_copy_allocation() == other.is_copy_allocation();
}

bool MemorySpaceAssignment::CopyAllocation::operator==(
    const MemorySpaceAssignment::CopyAllocation& other) const {
  return static_cast<const Allocation&>(*this) ==
             static_cast<const Allocation&>(other) &&
         copy_done_schedule_before() == other.copy_done_schedule_before() &&
         copy_start_schedule_after() == other.copy_start_schedule_after() &&
         copy_start() == other.copy_start() && copy_done() == other.copy_done();
}

std::string MemorySpaceAssignment::AllocationValue::ToString() const {
  std::string out = absl::StrCat("computation = ", computation()->name());
  absl::StrAppend(&out, "\n position:\n");
  absl::StrAppend(&out, " ", defining_position_.ToString(), "\n");
  absl::StrAppend(&out, " uses:\n");
  for (const Use& use : uses_) {
    absl::StrAppend(&out, " ", use.hlo_use.ToString(), "\n");
  }
  return out;
}

std::string MemorySpaceAssignment::AllocationValue::ToShortString() const {
  return absl::StrCat("computation = ", computation()->name(),
                      ", position = ", defining_position_.ToString(),
                      ", value = ", value_->ToShortString());
}

void AlternateMemoryBestFitHeap::CreateAllocationValues(
    const AlternateMemoryBestFitHeap::BufferInterval& buffer_interval,
    std::vector<AllocationValue>& allocation_values) const {
  const HloValue* value = buffer_interval.buffer;
  VLOG(3) << "Creating AllocationValues for: " << value->ToString();

  // Find and sort all non-trivial (excluding GTE, Tuple, and bitcast)
  // positions. We create an AllocationValue object for each non-trivial
  // position. And for each AllocationValue object, we create an
  // AllocationSequence consisting of one or more Allocation objects. The reason
  // why we exclude the trivial positions from AllocationValue is because
  // Allocation objects have special support for tuples and bitcasts.
  const absl::flat_hash_map<const HloInstruction*, int64>&
      instruction_schedule = hlo_live_range_.instruction_schedule();
  std::vector<HloPosition> positions;
  for (const HloPosition& position : value->positions()) {
    const HloInstruction* instruction = position.instruction;
    if (instruction->opcode() != HloOpcode::kGetTupleElement &&
        instruction->opcode() != HloOpcode::kTuple &&
        instruction->opcode() != HloOpcode::kBitcast) {
      positions.push_back(position);
    }
  }
  absl::c_stable_sort(positions,
                      [&](const HloPosition& pos1, const HloPosition& pos2) {
                        return instruction_schedule.at(pos1.instruction) <
                               instruction_schedule.at(pos2.instruction);
                      });

  // Create an AllocationValue for each non-trivial position.
  absl::flat_hash_set<const HloComputation*> computations;
  int beginning_idx = allocation_values.size();
  for (int i = 0; i < positions.size(); ++i) {
    const HloPosition& position = positions.at(i);
    allocation_values.emplace_back(value, position, buffer_interval.size);
  }

  std::vector<HloUse> uses(value->uses());
  absl::c_stable_sort(uses, [&](const HloUse& use1, const HloUse& use2) {
    return instruction_schedule.at(use1.instruction) <
           instruction_schedule.at(use2.instruction);
  });

  // Associate each use with an AllocationValue. Each AllocationValue contains a
  // position and uses in the same computation. Furthermore, if the original
  // HloValue had multiple non-trivial positions in the same computation, those
  // will get their own AllocationValue as well. We split these HloValues so
  // that when we insert CopyStart/CopyDone in CopyAllocation::Process, they
  // point to the latest position. We then replace the operand of the use with
  // CopyStart/CopyDone with an operand of the latest position.
  for (const HloUse& use : uses) {
    int64 use_time = instruction_schedule.at(use.instruction);
    HloComputation* use_computation = use.instruction->parent();

    AllocationValue* last_allocation_value = nullptr;
    for (int i = beginning_idx; i < allocation_values.size(); ++i) {
      AllocationValue* allocation_value = &allocation_values.at(i);
      if (allocation_value->computation() == use_computation &&
          instruction_schedule.at(
              allocation_value->defining_position().instruction) < use_time) {
        last_allocation_value = allocation_value;
      }
    }
    CHECK(last_allocation_value != nullptr);
    last_allocation_value->AddUse(use, use_time);
  }

  for (int i = beginning_idx; i < allocation_values.size(); ++i) {
    VLOG(3) << "Created allocation value: "
            << allocation_values.at(i).ToString();
  }
}

void AlternateMemoryBestFitHeap::FindAliases(
    std::vector<AllocationValue>* allocation_values,
    bool skip_values_with_no_uses) const {
  absl::flat_hash_map<const HloInstruction*, const AllocationValue*>
      values_by_defining_inst;
  for (AllocationValue& value : *allocation_values) {
    // Skip the value if it doesn't have any uses.
    if (value.uses().empty() && skip_values_with_no_uses) {
      continue;
    }
    CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0);
    values_by_defining_inst[value.defining_instruction()] = &value;
  }
  auto maybe_add_alias_with_instruction = [&](const HloInstruction* instruction,
                                              AllocationValue::Use* use) {
    auto aliased_value_it = values_by_defining_inst.find(instruction);
    if (aliased_value_it != values_by_defining_inst.end()) {
      VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString() << " to "
              << aliased_value_it->second->ToShortString();
      use->aliases.push_back(aliased_value_it->second->defining_position());
    }
  };

  for (AllocationValue& value : *allocation_values) {
    for (AllocationValue::Use& use : value.uses()) {
      // Find any aliases with the instruction itself (operand and output must
      // alias).
      maybe_add_alias_with_instruction(use.hlo_use.instruction, &use);

      // Find any aliases with the parameters of called computations.
      for (const HloComputation* called_computation :
           use.hlo_use.instruction->called_computations()) {
        for (const HloInstruction* parameter_instruction :
             called_computation->parameter_instructions()) {
          maybe_add_alias_with_instruction(parameter_instruction, &use);
        }
      }

      // Special case for kWhile: the root of the body computation must alias as
      // well.
      if (use.hlo_use.instruction->opcode() == HloOpcode::kWhile) {
        HloPosition root_alias{
            use.hlo_use.instruction->while_body()->root_instruction(),
            use.hlo_use.operand_index};
        VLOG(3) << "Adding while body root aliasing for use "
                << use.hlo_use.ToString() << " to " << root_alias;
        use.aliases.push_back(root_alias);
      }
    }
  }
}

std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>
AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
    const AlternateMemoryBestFitHeap::BufferInterval& interval) const {
  std::vector<const BufferInterval*> colocated_intervals;
  std::vector<const BufferInterval*> worklist = {&interval};
  while (!worklist.empty()) {
    const BufferInterval* item = worklist.back();
    worklist.pop_back();
    colocated_intervals.push_back(item);
    for (const HloValue* buffer_colocated : item->colocations) {
      worklist.push_back(&buffer_intervals_.at(buffer_colocated));
    }
  }

  absl::c_stable_sort(colocated_intervals, [&](const BufferInterval* x,
                                               const BufferInterval* y) {
    return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end);
  });
  return colocated_intervals;
}

bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
    const AllocationValue& value, const HloUse& use) const {
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
  if (!options_.is_use_allowed_in_alternate_mem_fn(use)) {
    return false;
  }
  if (use.instruction->opcode() == HloOpcode::kWhile) {
    HloComputation* while_body = use.instruction->while_body();

    // We don't want to allocate this buffer in alternate memory if it will be
    // evicted anyway. Find out if it has an early use or a late definition
    // that would make it worthwhile to keep in the alternate memory.
    HloValue* parameter_value =
        &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
            while_body->parameter_instruction(0), use.operand_index);
    int64 parameter_time =
        instruction_schedule.at(while_body->parameter_instruction(0));
    int64 root_time = instruction_schedule.at(while_body->root_instruction());
    int64 min_use_time = root_time;
    for (const HloUse& parameter_use : parameter_value->uses()) {
      int64 use_time = instruction_schedule.at(parameter_use.instruction);
      if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement &&
          parameter_use.instruction->opcode() != HloOpcode::kTuple &&
          parameter_use.instruction->opcode() != HloOpcode::kBitcast &&
          use_time > parameter_time) {
        min_use_time = std::min(min_use_time, use_time);
      }
    }
    // If there is no use of this buffer inside the while loop, there is no need
    // to allocate it in the loop.
    if (min_use_time == root_time) {
      VLOG(4) << "While allocation not allowed in alternate memory. "
              << "use time = " << min_use_time << ", root time = " << root_time;
      return false;
    }
    const Shape& shape = parameter_value->shape();
    // Allow the buffer in alternate memory if the buffer has a short live range
    // either at the beginning or end of the while loop body.
    if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
            shape, parameter_time, min_use_time)) {
      VLOG(4) << "While allocation not allowed in alternate memory. "
              << "use time = " << min_use_time << ", root time = " << root_time;
      return false;
    }
    // Check if there is a required assignment for the while loop output.
    HloValue* while_value =
        &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
            use.instruction, use.operand_index);
    int64 while_time = instruction_schedule.at(use.instruction);
    auto existing_required_assignment =
        RequiredMemoryAssignmentAt(while_value, while_time);
    if (existing_required_assignment &&
        existing_required_assignment->memory_space == MemorySpace::kDefault) {
      VLOG(4) << "While allocation not allowed in alternate memory because "
                 "there is a required default memory assignment.";
      return false;
    }
  } else if (use.instruction->opcode() == HloOpcode::kConditional) {
    // For any use of this conditional (the same value might be passed into
    // multiple called computations), determine if the parameter->first use
    // dependency is short.
    int64 conditional_time = instruction_schedule.at(use.instruction);
    for (const AllocationValue::Use& other_use : value.uses()) {
      if (other_use.hlo_use.instruction != use.instruction) {
        continue;
      }
      HloComputation* called_computation =
          use.instruction->called_computations().at(
              other_use.hlo_use.operand_number - 1);
      const HloInstruction* parameter_instruction =
          called_computation->parameter_instruction(0);
      HloValue* parameter_value =
          &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
              parameter_instruction, other_use.hlo_use.operand_index);
      int64 parameter_time = instruction_schedule.at(parameter_instruction);
      int64 min_use_time = conditional_time;
      for (const HloUse& parameter_use : parameter_value->uses()) {
        if (parameter_use.instruction->parent() == called_computation &&
            parameter_use.instruction->opcode() !=
                HloOpcode::kGetTupleElement &&
            parameter_use.instruction->opcode() != HloOpcode::kTuple &&
            parameter_use.instruction->opcode() != HloOpcode::kBitcast) {
          min_use_time = std::min(
              min_use_time, instruction_schedule.at(parameter_use.instruction));
        }
      }
      if (options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
              parameter_value->shape(), parameter_time, min_use_time)) {
        VLOG(4) << "Conditional allocation allowed in alternate memory for "
                   "computation = "
                << called_computation->name()
                << ", parameter time = " << parameter_time
                << ", min use time = " << min_use_time;
        return true;
      } else {
        VLOG(4) << "Conditional allocation not allowed in alternate memory for "
                   "computation = "
                << called_computation->name()
                << ", parameter time = " << parameter_time
                << ", min use time = " << min_use_time;
      }
    }
    return false;
  }

  return true;
}

void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString(
    const AlternateMemoryBestFitHeap::BufferInterval& interval,
    std::string* debug_str) const {
  // Columns in buffer information:
  // buffer_id: int. This value can be used to match the allocation in
  // allocation information.
  // buffer_name: string.
  // alt_mem_benefit: float. Roughly corresponds to how much the cost analysis
  // thought it would be beneficial to put this in the alternate memory. The
  // higher the value, the more it is memory bound.
  // size: int. In bytes.
  // definition_time: int. Logical time this value was defined in the schedule.
  // use_times: string. This is a semicolon-separated list of integers for all
  // the use times.
  // use_names: string. This is a semicolon-separated list of string
  // representations of the uses.
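  // An illustrative (hypothetical) row could look like:
  //   12,"p0.1",0.45,1024,3,"10;25","fusion.1, operand 0;dot.2, operand 1"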
  if (debug_str->empty()) {
    // Append the column names.
    absl::StrAppend(debug_str,
                    "buffer_id,buffer_name,alt_mem_benefit,size,"
                    "definition_time,use_times,use_names\n");
  }
  const HloBuffer& buffer =
      alias_analysis_.GetBufferContainingValue(*interval.buffer);
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
  int64 definition_time =
      instruction_schedule.at(interval.buffer->defining_position().instruction);
  std::vector<std::pair<int64, std::string>> uses;
  for (const HloValue* value : buffer.values()) {
    for (const HloUse& use : value->uses()) {
      uses.push_back(
          {instruction_schedule.at(use.instruction), use.ToString()});
    }
  }
  absl::c_sort(uses);
  std::vector<int64> use_times;
  std::vector<std::string> use_names;
  use_times.reserve(uses.size());
  use_names.reserve(uses.size());
  for (const auto& use : uses) {
    use_times.push_back(use.first);
    use_names.push_back(use.second);
  }

  absl::StrAppend(debug_str, buffer.id(), ",");
  absl::StrAppend(debug_str, "\"", interval.buffer->ToShortString(), "\",");
  auto alternate_memory_benefit =
      options_.prefetch_interval_picker->BufferIntervalAlternateMemoryBenefit(
          interval);
  absl::StrAppend(
      debug_str, alternate_memory_benefit ? *alternate_memory_benefit : 0, ",");
  absl::StrAppend(debug_str, interval.size, ",");
  absl::StrAppend(debug_str, definition_time, ",");
  absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\",");
  absl::StrAppend(debug_str, "\"", absl::StrJoin(use_names, ";"), "\"");
  absl::StrAppend(debug_str, "\n");
}

void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString(
    const AllocationValue& value,
    const MemorySpaceAssignment::Allocation& allocation,
    std::string& debug_str) const {
  // Columns in allocation information:
  // buffer_id: int. This value can be used to match with the buffer info above.
  // size: int. In bytes.
  // offset: int. In bytes.
  // start_time: int. Logical start time of the allocation.
  // end_time: int. Logical end time of the allocation.
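  // An illustrative (hypothetical) row could look like: 12,1024,0,10,25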
1063 if (debug_str.empty()) {
1064 // Append the column names.
1065 absl::StrAppend(&debug_str, "buffer_id,size,offset,start_time,end_time\n");
1066 }
1067 if (allocation.memory_space() == MemorySpace::kAlternate) {
1068 const HloBuffer& buffer =
1069 alias_analysis_.GetBufferContainingValue(*value.value());
1070 absl::StrAppend(&debug_str, buffer.id(), ",");
1071 absl::StrAppend(&debug_str, value.size(), ",");
1072 absl::StrAppend(&debug_str, allocation.chunk().offset, ",");
1073 absl::StrAppend(&debug_str, allocation.start_time(), ",");
1074 absl::StrAppend(&debug_str, allocation.end_time(), "\n");
1075 }
1076 }
1077
DumpDebugStringsIfEnabled() const1078 void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
1079 if (!options_.dump_fn) {
1080 return;
1081 }
1082 options_.dump_fn("bufferinfo", buffer_info_str_);
1083 options_.dump_fn("allocinfo", allocation_info_str_);
1084 }
1085
Finish()1086 HeapSimulator::Result<HloValue> AlternateMemoryBestFitHeap::Finish() {
1087 if (options_.enable_cross_program_prefetch) {
1088 absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
1089 prefetch_candidate = FindCrossProgramPrefetchCandidate(
1090 alias_analysis_, hlo_live_range_, options_);
1091 if (prefetch_candidate) {
1092 HloModule* module =
1093 prefetch_candidate->buffer->instruction()->GetModule();
1094 AllocateCrossProgramPrefetchBuffer(module, prefetch_candidate);
1095 }
1096 }
1097
1098 std::vector<BufferInterval> sorted_buffer_intervals =
1099 GetSortedBufferIntervals();
1100
1101 VLOG(1) << "Assigning buffers to alternate memory. Max heap size = "
1102 << options_.max_size_in_bytes;
1103
1104 AddInputAndOutputRequiredAssignments();
1105
1106 if (VLOG_IS_ON(3)) {
1107 VLOG(3) << "Flattened instruction sequence:";
1108 const auto& instruction_sequence =
1109 hlo_live_range_.flattened_instruction_sequence().instructions();
1110 for (int i = 0; i < instruction_sequence.size(); ++i) {
1111 VLOG(3) << " " << i << ": " << instruction_sequence[i]->parent()->name()
1112 << " " << instruction_sequence[i]->name();
1113 }
1114 }
1115
1116 for (const auto& interval : sorted_buffer_intervals) {
1117 auto colocated_intervals = GetSortedColocatedIntervals(interval);
1118 if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
1119 // Increment the reserved part of alternate memory so that it is not
1120 // available for other buffers.
1121 reserved_in_bytes_ += options_.size_fn(*interval.buffer);
1122 }
1123 }
1124 VLOG(2) << "Total reserved bytes = " << reserved_in_bytes_;
1125
1126 for (auto& interval : sorted_buffer_intervals) {
1127 if (!interval.need_allocation) {
1128 continue;
1129 }
1130
1131 if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
1132 interval)) {
1133 continue;
1134 }
1135
1136 HloInstruction* inst = interval.buffer->instruction();
1137 HloModule* module = inst->GetModule();
1138
1139 // Don't intra-program prefetch a cross program prefetch
1140 if (inst->opcode() == HloOpcode::kParameter &&
1141 absl::c_count(module->CrossProgramPrefetches(),
1142 std::make_pair(inst->parameter_number(),
1143 interval.buffer->index())) > 0) {
1144 VLOG(3) << "Skip " << interval.buffer->ToShortString()
1145 << " because it is cross-program prefetched.";
1146 continue;
1147 }
1148
1149 if (interval.size > available_heap_size()) {
1150 VLOG(3) << "Skip " << interval.buffer->ToShortString()
1151 << " because the buffer is larger than the heap size.";
1152 continue;
1153 }
1154
1155 auto colocated_intervals = GetSortedColocatedIntervals(interval);
1156
1157 if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
1158 VLOG(3) << "Interval " << interval.buffer->ToShortString()
1159 << " is reserved in the alternate memory.";
1160 for (const BufferInterval* colocated_interval : colocated_intervals) {
1161 const HloValue* value = colocated_interval->buffer;
1162 // Color all of the aliased reserved buffers here because reserved
1163 // alternate memory allocations will not have an entry in preset
1164 // allocations that is normally used for coloring.
1165 for (auto& position : value->positions()) {
1166 VLOG(4) << "Coloring " << position.ToString();
1167 Shape* shape = ShapeUtil::GetMutableSubshape(
1168 position.instruction->mutable_shape(), position.index);
1169 CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
1170 << position.ToString();
1171 shape->mutable_layout()->set_memory_space(
1172 options_.alternate_memory_space);
1173 }
1174 }
1175 continue;
1176 }
1177
1178 if (colocated_intervals.size() > 1 &&
1179 !options_.allocate_across_sequential_calls) {
1180 VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
1181 << " because it aliases with another interval and "
1182 << " allocate_across_sequential_calls is false.";
1183 continue;
1184 }
1185
1186 if (!ConsumeFuel("memory_space_assignment", [&] {
1187 return absl::StrCat("Ran out of fuel at buffer: ",
1188 colocated_intervals[0]->buffer->ToShortString());
1189 })) {
1190 continue;
1191 }
1192
1193 AppendBufferInfoDebugString(interval, &buffer_info_str_);
1194
1195 std::vector<AllocationValue> allocation_values;
1196 CreateAllocationValuesFromColocatedIntervals(colocated_intervals,
1197 allocation_values);
1198
1199 // Retry allocating this value with larger limits if allocation fails.
1200 bool repacked = false;
1201 for (int retry_number = 0; retry_number < options_.max_retries;
1202 retry_number++) {
1203 AddRequiredAssignmentsForColocatedIntervals(colocated_intervals);
1204 bool final_retry = (retry_number == options_.max_retries - 1);
1205 options_.prefetch_interval_picker->SetRetryNumber(retry_number);
1206 Result result =
1207 AllocateAllocationValues(absl::MakeSpan(allocation_values));
1208 VLOG(2) << "Allocation result = "
1209 << absl::StrFormat("%x", static_cast<int>(result));
1210 if (result_requires_uncommit(result) ||
1211 (!final_retry && result_failed_because_of_async_copy(result))) {
1212 UncommitPendingChunks(absl::MakeSpan(allocation_values));
1213 VLOG(2) << "Couldn't allocate. Retry number " << retry_number;
1214 } else if ((result_is(result, Result::kFailOutOfMemory) ||
1215 options_.repack_after_every_allocation) &&
1216 num_repacks_ < options_.max_repacks && !repacked) {
1217 UncommitPendingChunks(absl::MakeSpan(allocation_values));
1218 ++num_repacks_;
1219 repacked = true;
1220 CHECK_NE(options_.repacker, nullptr);
1221 std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>
1222 repack_allocation_blocks;
1223 ExportAllocationsForRepacking(repack_allocation_blocks);
1224 VLOG(2) << "Repacking.";
1225 auto repack_status =
1226 options_.repacker->Repack(absl::MakeSpan(repack_allocation_blocks));
1227 CHECK_EQ(repack_status.status(), Status::OK());
1228 VLOG(2) << "Repack complete. Modified = " << *repack_status;
1229 if (*repack_status) {
1230 ImportRepackedAllocations();
1231 --retry_number;
1232 }
1233 } else {
1234 FinalizeAllocations(absl::MakeSpan(allocation_values));
1235 break;
1236 }
1237 }
1238 }
1239
1240 VLOG(3) << "Debug buffer info: ";
1241 VLOG(3) << buffer_info_str_;
1242 VLOG(3) << "Debug allocation info: ";
1243 VLOG(3) << allocation_info_str_;
1244 DumpDebugStringsIfEnabled();
1245
1246 HeapSimulator::Result<HloValue> result;
1247 result.heap_size = result_.heap_size;
1248 result.heap_results.emplace_back(std::move(result_));
1249 return result;
1250 }
1251
AddRequiredAssignmentsForColocatedIntervals(absl::Span<const AlternateMemoryBestFitHeap::BufferInterval * const> colocated_intervals)1252 void AlternateMemoryBestFitHeap::AddRequiredAssignmentsForColocatedIntervals(
1253 absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1254 colocated_intervals) {
1255 // TODO(berkin): For now, place the phi values due to conditionals in
1256 // default memory.
1257 for (const BufferInterval* colocated_interval : colocated_intervals) {
1258 const HloValue* value = colocated_interval->buffer;
1259 for (const auto& position : value->positions()) {
1260 if (position.instruction->opcode() == HloOpcode::kConditional) {
1261 VLOG(3) << "Adding required assignment for condition output: "
1262 << value->ToShortString();
1263 AddRequiredAssignment(position.instruction, position.index,
1264 MemorySpace::kDefault);
1265 for (const HloComputation* called_computation :
1266 position.instruction->called_computations()) {
1267 AddRequiredAssignment(called_computation->root_instruction(),
1268 position.index, MemorySpace::kDefault);
1269 }
1270 }
1271 }
1272 }
1273 }
1274
CreateAllocationValuesFromColocatedIntervals(absl::Span<const AlternateMemoryBestFitHeap::BufferInterval * const> colocated_intervals,std::vector<MemorySpaceAssignment::AllocationValue> & allocation_values)1275 void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals(
1276 absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1277 colocated_intervals,
1278 std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values) {
1279 // Create AllocationValues for all the colocated intervals.
1280 for (const auto& colocated_interval : colocated_intervals) {
1281 CreateAllocationValues(*colocated_interval, allocation_values);
1282 }
1283 FindAliases(&allocation_values, /*skip_values_with_no_uses=*/true);
1284 }
1285
1286 AlternateMemoryBestFitHeap::Result
AllocateAllocationValues(absl::Span<MemorySpaceAssignment::AllocationValue> allocation_values)1287 AlternateMemoryBestFitHeap::AllocateAllocationValues(
1288 absl::Span<MemorySpaceAssignment::AllocationValue> allocation_values) {
1289 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1290
1291 // Data structure to contain the preferred offset for a given computation.
1292 // We ensure that the same offset will be allocated outside the while loop
1293 // as well as inside the while loop.
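// For example (hypothetical offsets): if the value is given chunk offset 128 in
// alternate memory for its use by a while op, the entry recorded here makes 128
// the preferred offset when the corresponding while-body values are allocated,
// so the buffer does not have to move across the loop boundary.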
1294 absl::flat_hash_map<const HloComputation*, AliasedOffset*>
1295 preferred_offset_for_computation;
1296
1297 Result result = Result::kSuccess;
1298 for (AllocationValue& allocation_value : allocation_values) {
1299 int64 definition_time =
1300 instruction_schedule.at(allocation_value.defining_instruction());
1301
1302 AliasedOffset* preferred_offset = nullptr;
1303 auto preferred_offset_it =
1304 preferred_offset_for_computation.find(allocation_value.computation());
1305 if (preferred_offset_it != preferred_offset_for_computation.end()) {
1306 preferred_offset = preferred_offset_it->second;
1307 }
1308
1309 // Iterate over the uses.
1310 for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
1311 const AllocationValue::Use& use = allocation_value.uses().at(use_idx);
1312 const HloUse hlo_use = use.hlo_use;
1313 int64 use_time = instruction_schedule.at(hlo_use.instruction);
1314 int64 latest_prefetch_time = use_time;
1315 bool allow_no_copy_alternate_mem_allocation = true;
1316 absl::optional<int64> earliest_prefetch_time = absl::nullopt;
1317
1318 // Sequential calls include kWhile, kCall, and kConditional opcodes.
1319 bool is_sequential_call =
1320 (GetInstructionCallContext(hlo_use.instruction->opcode()) ==
1321 CallContext::kSequential);
1322 if (is_sequential_call) {
1323 for (const HloComputation* called_computation :
1324 hlo_use.instruction->called_computations()) {
1325 const HloLiveRange::TimeBound& computation_span =
1326 hlo_live_range_.computation_span_times().at(called_computation);
1327 latest_prefetch_time =
1328 std::min(computation_span.start - 1, latest_prefetch_time);
1329 }
1330 if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
1331 // Given an example while loop and flattened schedule (logical times
1332 // shown on the left):
1333 //
1334 // 0: a = ...
1335 // 1: ...
1336 // cond {
1337 // 2: p = param(0)
1338 // 3: ...
1339 // }
1340 // body {
1341 // 4: p = param(0)
1342 // 5: ...
1343 // 6: ROOT ...
1344 // }
1345 // 7: w = while(a), body=body, cond=cond
1346 //
1347 // When processing "a" (time 0) and its while use (time 7), we update
1348 // the interval to time 0-4. This is so that the remaining interval
1349 // (5-6) can be allocated separately and this buffer doesn't waste
1350 // alternate memory space within the while loop body.
1351 HloComputation* while_body = hlo_use.instruction->while_body();
1352 // We require while body ROOTs to be the last in the schedule.
1353 CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
1354 instruction_schedule.at(hlo_use.instruction))
1355 << "While body ROOTs need to be the last in the schedule! "
1356 "Please run RootInstructionSinker.";
1357 // Replace the use time with the parameter time so that we can decide
1358 // on alternate memory allocations within the while loop body when we
1359 // look at uses within the while loop body.
1360 use_time =
1361 instruction_schedule.at(while_body->parameter_instruction(0));
1362 } else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) {
1363 // Replace the use time with the earliest parameter of called
1364 // computations.
1365 for (const HloComputation* called_computation :
1366 hlo_use.instruction->called_computations()) {
1367 use_time = std::min(
1368 use_time, instruction_schedule.at(
1369 called_computation->parameter_instruction(0)));
1370 }
1371 }
1372 }
1373
1374 // Add a required assignment in default memory if the use is not allowed in
1375 // the alternate memory.
1376 if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) {
1377 AddRequiredAssignment(allocation_value.value(), hlo_use.instruction,
1378 MemorySpace::kDefault, use_time);
1379 } else if (use_idx > 0) {
1380 // We allow buffers in alternate memory that are passed into
1381 // conditionals to give up their alternate memory allocation inside the
1382 // called computation. This means that if a conditional operator has an
1383 // alternate memory allocation, subsequent uses cannot use the same
1384 // alternate memory allocation in order not to clobber data. So we force
1385 // default memory allocation for these subsequent uses.
1386 const AllocationValue::Use& previous_use =
1387 allocation_value.uses().at(use_idx - 1);
1388 if (previous_use.hlo_use.instruction->opcode() ==
1389 HloOpcode::kConditional &&
1390 previous_use.hlo_use.instruction != hlo_use.instruction) {
1391 allow_no_copy_alternate_mem_allocation = false;
1392 earliest_prefetch_time =
1393 instruction_schedule.at(previous_use.hlo_use.instruction);
1394 VLOG(3) << "Previous use (" << previous_use.hlo_use.ToString()
1395 << ") of use (" << hlo_use.ToString()
1396 << ") is a conditional, so this use will need to evict. "
1397 << "Earliest prefetch time = " << *earliest_prefetch_time;
1398 }
1399 }
1400
1401 // Bitcasts don't define buffers and don't directly consume buffers. Skip
1402 // allocating buffers for bitcast uses (unless they are the root
1403 // instruction). The uses that feed from bitcasts will be handled
1404 // specially.
1405 if (hlo_use.instruction->opcode() != HloOpcode::kBitcast ||
1406 hlo_use.instruction ==
1407 hlo_use.instruction->parent()->root_instruction()) {
1408 AllocationRequest request;
1409 // Rarely (e.g., when the conditional true and false parameters are the
1410 // same), the definition time can be the time of the conditional and the use
1411 // time that of the parameter use, which is earlier.
1412 request.start_time = std::min(definition_time, use_time);
1413 request.end_time = use_time;
1414 request.latest_prefetch_time = latest_prefetch_time;
1415 request.size = allocation_value.size();
1416 request.allow_no_copy_alternate_mem_allocation =
1417 allow_no_copy_alternate_mem_allocation;
1418 request.earliest_prefetch_time = earliest_prefetch_time;
1419 request.preferred_offset = preferred_offset;
1420 request.use = &use;
1421 request.allocation_value = &allocation_value;
1422 result_mark(AllocateSegment(request), result);
1423 if (result_requires_uncommit(result)) {
1424 // If finding the allocation failed (e.g., due to running out of
1425 // asynchronous copies), then fall back to allocating the buffer
1426 // entirely in the default memory.
1427 return result;
1428 }
1429
1430 // If there are multiple uses, subsequent uses can try to reuse the memory
1431 // allocation that is already in alternate memory.
1432 definition_time = instruction_schedule.at(hlo_use.instruction);
1433 }
1434
1435 // Propagate the allocation to any aliases this use might have had.
1436 MemorySpaceAssignment::Allocation* aliased_allocation =
1437 GetLiveAllocationAt(*allocation_value.allocation_sequence(),
1438 use_time);
1439 for (const HloPosition& aliased_position : use.aliases) {
1440 AddAliasedRequiredAssignment(aliased_position.instruction,
1441 aliased_position.index,
1442 aliased_allocation);
1443 }
1444
1445 // Special case for while loops since the root offset must agree with
1446 // other offsets: remember the preferred offset for the while loop body.
1447 if (hlo_use.instruction->opcode() == HloOpcode::kWhile &&
1448 aliased_allocation->memory_space() == MemorySpace::kAlternate) {
1449 preferred_offset_for_computation[hlo_use.instruction->while_body()] =
1450 GetAliasedOffset(*aliased_allocation);
1451 }
1452 }
1453 }
1454 return result;
1455 }
1456
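// The comparator below defines a partial order on asynchronous copies: a < b
// only if a starts no later and ends no later than b (strict in at least one of
// the two). As an illustration with made-up times, (1, 4) < (2, 6), while
// (1, 6) and (2, 4) are incomparable because one is nested inside the other;
// incomparable copies compare equivalent in the ordered ranges_ set of
// AsynchronousCopyOrdering, which is what ViolatesOrdering relies on to detect
// nested copies.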
1457 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
1458 return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
1459 (a.start_time <= b.start_time && a.end_time < b.end_time);
1460 }
1461
1462 void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) {
1463 auto it_and_inserted = ranges_.insert(copy);
1464 CHECK(it_and_inserted.second ||
1465 it_and_inserted.first->start_time == copy.start_time);
1466 }
1467
1468 void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) {
1469 auto copy_it = ranges_.find(copy);
1470 CHECK(copy_it != ranges_.end());
1471 ranges_.erase(copy_it);
1472 }
1473
1474 absl::optional<AsynchronousCopy> AsynchronousCopyOrdering::ViolatesOrdering(
1475 int64 start_time, int64 end_time) const {
1476 // We allow identical start and end times. It is enough to check for just the
1477 // start time in case we find a match in ranges_ because the found value will
1478 // either be identical to {start_time, end_time} (and this doesn't violate) or
1479 // its start_time will be smaller and end_time will be larger (this violates).
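// As a concrete illustration (times are made up): if ranges_ holds a copy
// (2, 8), then ViolatesOrdering(4, 6) finds it, because (4, 6) is nested inside
// (2, 8) and therefore compares equivalent under operator<; since the found
// start time 2 differs from 4, the copy (2, 8) is reported as a violation. If
// ranges_ instead held exactly (4, 6), the lookup would find a copy with the
// same start time and no violation would be reported.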
1480 auto copy_it = ranges_.find(
1481 {start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate});
1482 if (copy_it != ranges_.end() && copy_it->start_time != start_time) {
1483 VLOG(4) << "Violates ordering: (" << start_time << ", " << end_time
1484 << ") and (" << copy_it->start_time << ", " << copy_it->end_time
1485 << ")";
1486 return *copy_it;
1487 }
1488 return absl::nullopt;
1489 }
1490
1491 AlternateMemoryBestFitHeap::AliasedOffset*
1492 AlternateMemoryBestFitHeap::GetAliasedOffset(
1493 const MemorySpaceAssignment::Allocation& allocation) {
1494 auto aliased_offset_it = aliased_offset_map_.find(&allocation);
1495 CHECK(aliased_offset_it != aliased_offset_map_.end());
1496 return aliased_offset_it->second;
1497 }
1498
1499 void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset(
1500 const MemorySpaceAssignment::Allocation& allocation,
1501 AlternateMemoryBestFitHeap::AliasedOffset* aliased_offset) {
1502 CHECK(allocation.memory_space() == MemorySpace::kAlternate);
1503 CHECK(!aliased_offset_map_.contains(&allocation));
1504 if (!aliased_offset) {
1505 aliased_offsets_.push_back({allocation.chunk().offset});
1506 aliased_offset = &aliased_offsets_.back();
1507 }
1508 CHECK_EQ(allocation.chunk().offset, aliased_offset->offset);
1509 CHECK(aliased_offset->allocations.insert(&allocation).second);
1510 aliased_offset_map_[&allocation] = aliased_offset;
1511 }
1512
1513 /*static*/ MemorySpaceAssignment::Allocation*
1514 AlternateMemoryBestFitHeap::GetLiveAllocationAt(
1515 const MemorySpaceAssignment::AllocationSequence& allocations, int64 time) {
1516 for (auto allocation_it = allocations.rbegin();
1517 allocation_it != allocations.rend(); ++allocation_it) {
1518 if ((*allocation_it)->start_time() <= time &&
1519 (*allocation_it)->end_time() >= time) {
1520 return allocation_it->get();
1521 }
1522 }
1523 return nullptr;
1524 }
1525
1526 void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
1527 HloModule* module, absl::optional<BufferInterval> prefetch_candidate) {
1528 if (!prefetch_candidate) {
1529 return;
1530 }
1531
1532 ChunkCandidate chunk_candidate = FindChunkCandidate(*prefetch_candidate);
1533 if (chunk_candidate.chunk.offset != 0 ||
1534 chunk_candidate.heap_size > available_heap_size()) {
1535 LOG(WARNING)
1536 << "Could not allocate preferred memory for cross program prefetch";
1537 return;
1538 }
1539 AddToPendingChunks(*prefetch_candidate, chunk_candidate);
1540
1541 const HloValue* buffer = prefetch_candidate->buffer;
1542 int64 parameter = buffer->instruction()->parameter_number();
1543 module->AddCrossProgramPrefetch(parameter, buffer->index());
1544
1545 MemorySpaceAssignment::AllocationSequence allocations;
1546 allocations.push_back(absl::make_unique<MemorySpaceAssignment::Allocation>(
1547 buffer->defining_position(), MemorySpace::kDefault, kDummyChunk,
1548 prefetch_candidate->start, prefetch_candidate->end));
1549
1550 // Find the earliest use.
1551 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1552 auto uses = buffer->uses();
1553 auto use_schedule_compare = [&](const HloUse& lhs, const HloUse& rhs) {
1554 return instruction_schedule.at(lhs.instruction) <
1555 instruction_schedule.at(rhs.instruction);
1556 };
1557 auto first_use = absl::c_min_element(uses, use_schedule_compare);
1558 int64 latest_prefetch_time = instruction_schedule.at(first_use->instruction);
1559
1560 // Find the latest use time.
1561 int64 last_use_time = instruction_schedule.at(
1562 absl::c_max_element(uses, use_schedule_compare)->instruction);
1563 for (const HloValue* colocation : prefetch_candidate->colocations) {
1564 last_use_time = std::max(
1565 last_use_time,
1566 instruction_schedule.at(
1567 absl::c_max_element(colocation->uses(), use_schedule_compare)
1568 ->instruction));
1569 }
1570
1571 int64 end_of_program_prefetch_end_time = instruction_schedule.size() - 1;
1572 int64 end_of_program_prefetch_start_time =
1573 options_.prefetch_interval_picker->PreferredPrefetchStartTime(
1574 buffer->defining_position().shape(), last_use_time,
1575 end_of_program_prefetch_end_time, end_of_program_prefetch_end_time);
1576 VLOG(2) << "last use time = " << last_use_time
1577 << ", end-of-program prefetch start time = "
1578 << end_of_program_prefetch_start_time;
1579 bool free_buffer =
1580 (end_of_program_prefetch_start_time > last_use_time &&
1581 end_of_program_prefetch_start_time < end_of_program_prefetch_end_time);
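// For example (hypothetical times): with last_use_time = 40 and an
// end-of-program prefetch window ending at time 100, a preferred prefetch start
// time of 70 satisfies 70 > 40 and 70 < 100, so free_buffer is true: the
// cross-program-prefetched chunk only needs to be kept until time 40 and is
// prefetched again at the end of the program.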
1582 int64 cross_program_prefetch_end_time =
1583 free_buffer ? last_use_time : prefetch_candidate->end;
1584
1585 AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate,
1586 chunk_candidate.chunk, prefetch_candidate->start,
1587 cross_program_prefetch_end_time, latest_prefetch_time,
1588 &allocations, /*aliased_offset=*/nullptr,
1589 /*is_cross_program_prefetch=*/true);
1590 absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); });
1591 AliasedOffset* cross_program_prefetch_offset =
1592 GetAliasedOffset(*allocations.back());
1593
1594 if (free_buffer) {
1595 VLOG(2) << "Adding an end-of-program prefetch for freed "
1596 "cross-program-prefetched buffer.";
1597 AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate,
1598 chunk_candidate.chunk, end_of_program_prefetch_start_time,
1599 end_of_program_prefetch_end_time,
1600 end_of_program_prefetch_end_time, &allocations,
1601 cross_program_prefetch_offset);
1602 CHECK_EQ(cross_program_prefetch_offset->offset,
1603 allocations.back()->chunk().offset);
1604 }
1605
1606 for (auto& allocation : allocations) {
1607 allocations_->push_back(std::move(allocation));
1608 }
1609
1610 // Add a repack allocation block for the Allocation objects in alternate
1611 // memory.
1612 CHECK_EQ(repack_allocation_blocks_.size(), 0);
1613 for (const auto& allocation : *allocations_) {
1614 if (allocation->memory_space() == MemorySpace::kAlternate) {
1615 repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
1616 allocation->start_time(), allocation->end_time(),
1617 allocation->chunk().size, allocation->chunk().offset,
1618 static_cast<int64>(repack_allocation_blocks_.size()),
1619 allocation.get()));
1620 RepackAllocationBlock* inserted = &repack_allocation_blocks_.back();
1621 for (RepackAllocationBlock& colocation : repack_allocation_blocks_) {
1622 colocation.colocations.push_back(inserted);
1623 if (&colocation != inserted) {
1624 inserted->colocations.push_back(&colocation);
1625 }
1626 }
1627 }
1628 }
1629
1630 ClearPendingChunks();
1631 }
1632
1633 absl::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
1634 AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
1635 int64 time) const {
1636 auto required_assignment_it = required_assignments_.find(buffer);
1637 absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
1638 if (required_assignment_it != required_assignments_.end()) {
1639 for (const RequiredMemoryAssignment& required_assignment :
1640 required_assignment_it->second) {
1641 if (required_assignment.time == time) {
1642 // Sanity check that there is only one required assignment at this time.
1643 CHECK(!required_assignment_at_time);
1644 required_assignment_at_time = required_assignment;
1645 }
1646 }
1647 }
1648 return required_assignment_at_time;
1649 }
1650
1651 absl::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
1652 AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse(
1653 const AllocationValue::Use& use) const {
1654 absl::optional<RequiredMemoryAssignment> required_assignment;
1655 for (const HloPosition& position : use.aliases) {
1656 const HloValue* value =
1657 &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
1658 position.instruction, position.index);
1659 int64 time =
1660 hlo_live_range_.instruction_schedule().at(position.instruction);
1661 absl::optional<RequiredMemoryAssignment> required_assignment_for_alias =
1662 RequiredMemoryAssignmentAt(value, time);
1663 if (required_assignment == absl::nullopt) {
1664 required_assignment = required_assignment_for_alias;
1665 } else {
1666 CHECK(required_assignment_for_alias == absl::nullopt ||
1667 required_assignment->equals_ignoring_time(
1668 *required_assignment_for_alias));
1669 }
1670 }
1671 return required_assignment;
1672 }
1673
1674 void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
1675 const HloInstruction* instruction, ShapeIndex index,
1676 const MemorySpaceAssignment::Allocation* aliased_allocation) {
1677 AliasedOffset* offset = nullptr;
1678 if (aliased_allocation->memory_space() == MemorySpace::kAlternate) {
1679 offset = GetAliasedOffset(*aliased_allocation);
1680 }
1681 AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(),
1682 offset);
1683 }
1684
1685 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
1686 const HloValue* value, const HloInstruction* instruction,
1687 MemorySpaceAssignment::MemorySpace memory_space, int64 time,
1688 AliasedOffset* offset) {
1689 // If there is an existing required assignment at this time, make sure it is
1690 // the same as this one.
1691 auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time);
1692 if (existing_required_assignment) {
1693 CHECK(memory_space == existing_required_assignment->memory_space)
1694 << "inst = " << instruction->ToString() << " at " << time;
1695 CHECK((!offset && !existing_required_assignment->offset) ||
1696 offset == existing_required_assignment->offset);
1697 VLOG(3) << "Not adding required assignment because there is one already: "
1698 << value->ToShortString() << " at " << time << " at "
1699 << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1700 } else {
1701 VLOG(3) << "Adding required assignment: " << value->ToShortString()
1702 << " at " << time << " at "
1703 << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1704 RequiredMemoryAssignment required_assignment{memory_space, time, offset};
1705 required_assignments_[value].push_back(required_assignment);
1706 pending_required_assignments_.push_back({value, required_assignment});
1707 }
1708 }
1709
1710 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
1711 const HloInstruction* instruction, ShapeIndex index,
1712 MemorySpace memory_space, AliasedOffset* offset) {
1713 const HloValue* value =
1714 &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index);
1715 int64 instruction_time =
1716 hlo_live_range_.instruction_schedule().at(instruction);
1717 AddRequiredAssignment(value, instruction, memory_space, instruction_time,
1718 offset);
1719 }
1720
1721 void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
1722 // Go through the parameters and outputs and pin them to the corresponding
1723 // memory by adding a required assignment.
1724 const HloModule& module = alias_analysis_.dataflow_analysis().module();
1725 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1726 HloComputation* entry_computation = module.entry_computation();
1727 for (HloInstruction* parameter_instruction :
1728 entry_computation->parameter_instructions()) {
1729 int64 parameter_instruction_time =
1730 instruction_schedule.at(parameter_instruction);
1731 ShapeUtil::ForEachSubshape(
1732 parameter_instruction->shape(),
1733 [&](const Shape& subshape, const ShapeIndex& index) {
1734 MemorySpace memory_space = MemorySpace::kDefault;
1735 if (subshape.has_layout() && subshape.layout().memory_space() ==
1736 options_.alternate_memory_space) {
1737 memory_space = MemorySpace::kAlternate;
1738 }
1739 for (const HloBuffer* buffer :
1740 alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
1741 for (const HloValue* value : buffer->values()) {
1742 VLOG(3) << "Adding required assignment for parameter value = "
1743 << value->ToShortString()
1744 << " time = " << parameter_instruction_time << " space = "
1745 << (memory_space == MemorySpace::kDefault ? "def"
1746 : "alt");
1747 required_assignments_[value].push_back(
1748 {memory_space, /*time=*/parameter_instruction_time});
1749 }
1750 }
1751 });
1752 }
1753 HloInstruction* root_instruction = entry_computation->root_instruction();
1754 int64 root_instruction_time = instruction_schedule.at(root_instruction);
1755 ShapeUtil::ForEachSubshape(
1756 root_instruction->shape(),
1757 [&](const Shape& subshape, const ShapeIndex& index) {
1758 MemorySpace memory_space = MemorySpace::kDefault;
1759 if (subshape.has_layout() && subshape.layout().memory_space() ==
1760 options_.alternate_memory_space) {
1761 memory_space = MemorySpace::kAlternate;
1762 }
1763 for (const HloBuffer* buffer :
1764 alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
1765 for (const HloValue* value : buffer->values()) {
1766 VLOG(3) << "Adding required assignment for output value = "
1767 << value->ToShortString()
1768 << " time = " << root_instruction_time << " space = "
1769 << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1770 required_assignments_[value].push_back(
1771 {memory_space, /*time=*/root_instruction_time});
1772 }
1773 }
1774 });
1775 }
1776
1777 bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory(
1778 absl::Span<const BufferInterval* const> colocated_intervals) const {
1779 auto is_position_in_alternate_memory = [&](const HloPosition& position) {
1780 const Shape& shape = position.shape();
1781 return shape.has_layout() &&
1782 shape.layout().memory_space() == options_.alternate_memory_space;
1783 };
1784
1785 const HloModule& module = alias_analysis_.dataflow_analysis().module();
1786 const HloComputation* entry_computation = module.entry_computation();
1787 const HloInstruction* root_instruction =
1788 entry_computation->root_instruction();
1789 for (const BufferInterval* colocated_interval : colocated_intervals) {
1790 const HloValue* value = colocated_interval->buffer;
1791 if (value->defining_instruction()->opcode() == HloOpcode::kParameter &&
1792 value->defining_instruction()->parent() == entry_computation &&
1793 is_position_in_alternate_memory(value->defining_position())) {
1794 return true;
1795 }
1796
1797 for (const HloPosition& position : value->positions()) {
1798 if (position.instruction == root_instruction &&
1799 is_position_in_alternate_memory(position)) {
1800 return true;
1801 }
1802 }
1803 }
1804 return false;
1805 }
1806
1807 void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking(
1808 std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>& allocations) {
1809 for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
1810 allocations.push_back(&allocation_block);
1811 }
1812 }
1813
1814 void AlternateMemoryBestFitHeap::ImportRepackedAllocations() {
1815 interval_tree_ = {};
1816 for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
1817 MemorySpaceAssignment::Allocation* allocation = allocation_block.allocation;
1818 VLOG(3) << "Moved " << allocation->ToString() << ", size "
1819 << allocation->chunk().size << ", (" << allocation_block.start_time
1820 << ", " << allocation_block.end_time << ") from "
1821 << allocation_block.initial_offset << " to "
1822 << allocation_block.offset;
1823 allocation_block.allocation->mutable_chunk()->offset =
1824 allocation_block.offset;
1825 interval_tree_.Add(allocation_block.start_time, allocation_block.end_time,
1826 {allocation_block.offset, allocation_block.size});
1827 allocation_block.initial_offset = allocation_block.offset;
1828 allocation_block.offset = -1;
1829 }
1830 }
1831
1832 void AlternateMemoryBestFitHeap::UncommitPendingChunks(
1833 absl::Span<AllocationValue> allocation_values) {
1834 // Clear the allocation sequences of the allocation values so that we start
1835 // from a clean state in case we retry allocation after uncommitting.
1836 for (AllocationValue& allocation_value : allocation_values) {
1837 allocation_value.allocation_sequence()->clear();
1838 }
1839 for (const auto& interval_and_chunk : pending_chunks_) {
1840 const BufferInterval& interval = interval_and_chunk.first;
1841 const Chunk& chunk = interval_and_chunk.second.chunk;
1842 VLOG(3) << "Uncommitting: (" << interval.start << ", " << interval.end
1843 << ") off = " << chunk.offset << " size = " << chunk.size;
1844 interval_tree_.Remove(interval.start, interval.end, chunk);
1845 }
1846 for (const auto& interval : pending_async_copies_) {
1847 if (interval.destination == MemorySpace::kAlternate) {
1848 prefetch_interval_tree_.Remove(interval.start_time, interval.end_time,
1849 kDummyChunk);
1850 async_copy_ordering_.RemoveCopy(interval);
1851 } else {
1852 eviction_interval_tree_.Remove(interval.start_time, interval.end_time,
1853 kDummyChunk);
1854 }
1855 }
1856 for (const auto& value_and_required_assignment :
1857 pending_required_assignments_) {
1858 auto& required_assignment_vector =
1859 required_assignments_[value_and_required_assignment.first];
1860 const RequiredMemoryAssignment& required_assignment =
1861 value_and_required_assignment.second;
1862 VLOG(3) << "Removing required assignment: "
1863 << (required_assignment.memory_space == MemorySpace::kDefault
1864 ? "def"
1865 : "alt")
1866 << " time = " << required_assignment.time << " off = "
1867 << (required_assignment.offset ? required_assignment.offset->offset
1868 : -1);
1869 for (auto it = required_assignment_vector.begin();
1870 it != required_assignment_vector.end(); ++it) {
1871 if (*it == value_and_required_assignment.second) {
1872 required_assignment_vector.erase(it);
1873 break;
1874 }
1875 }
1876 }
1877 ClearPendingChunks();
1878 }
1879
1880 void AlternateMemoryBestFitHeap::FinalizeAllocations(
1881 absl::Span<AllocationValue> allocation_values) {
1882 absl::flat_hash_map<const AliasedOffset*,
1883 std::vector<MemorySpaceAssignment::Allocation*>>
1884 colocation_map;
1885 for (AllocationValue& allocation_value : allocation_values) {
1886 for (auto& allocation : *allocation_value.allocation_sequence()) {
1887 AppendAllocationInfoDebugString(allocation_value, *allocation,
1888 allocation_info_str_);
1889 allocations_->push_back(std::move(allocation));
1890 MemorySpaceAssignment::Allocation* inserted_allocation =
1891 allocations_->back().get();
1892 if (inserted_allocation->memory_space() == MemorySpace::kAlternate) {
1893 colocation_map[GetAliasedOffset(*inserted_allocation)].push_back(
1894 inserted_allocation);
1895 }
1896 }
1897 }
1898 // The allocations that have the same AliasedOffset need to be colocated.
1899 // Export these to repack_allocation_blocks_ so that we can repack them to
1900 // reduce fragmentation.
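// For example: if two Allocation objects share the same AliasedOffset (say, a
// while-loop buffer and the aliased allocation inside its body), they land in
// the same colocation_map entry, and the AllocationBlocks created below list
// each other as colocations so the repacker keeps them at a common offset.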
1901 for (auto& colocation : colocation_map) {
1902 std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
1903 for (MemorySpaceAssignment::Allocation* colocated_allocation :
1904 colocation.second) {
1905 repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
1906 colocated_allocation->start_time(), colocated_allocation->end_time(),
1907 colocated_allocation->chunk().size,
1908 colocated_allocation->chunk().offset,
1909 static_cast<int64>(repack_allocation_blocks_.size()),
1910 colocated_allocation));
1911 colocations.push_back(&repack_allocation_blocks_.back());
1912 }
1913 for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
1914 colocations) {
1915 repack_block->colocations = colocations;
1916 }
1917 }
1918 ClearPendingChunks();
1919 }
1920
1921 void AlternateMemoryBestFitHeap::ClearPendingChunks() {
1922 pending_chunks_.clear();
1923 pending_async_copies_.clear();
1924 pending_required_assignments_.clear();
1925 aliased_offset_map_.clear();
1926 aliased_offsets_.clear();
1927 }
1928
1929 void AlternateMemoryBestFitHeap::AddToPendingChunks(
1930 const BufferInterval& buffer_interval,
1931 const ChunkCandidate& chunk_candidate) {
1932 VLOG(3) << "Committing chunk: " << buffer_interval.start << "-"
1933 << buffer_interval.end << " : [" << chunk_candidate.chunk.offset
1934 << ", " << chunk_candidate.chunk.size << "]";
1935 pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
1936 CommitChunk(buffer_interval, chunk_candidate);
1937 }
1938
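// Summary of AllocateSegment below, added for readability (the function body
// is authoritative): (1) if the required assignments allow it, try keeping the
// segment in alternate memory without any copy; (2) otherwise, if the previous
// allocation of this value was in alternate memory, evict it to default memory;
// (3) if the use must be in default memory, just extend the default-memory
// allocation; (4) finally, try to prefetch the buffer back into alternate
// memory for this use. Failures that must roll back earlier decisions are
// reported via kFailRequiresUncommit.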
1939 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment(
1940 const AllocationRequest& request) {
1941 auto allocation_sequence = request.allocation_value->allocation_sequence();
1942 // start_time == end_time is a special case where the value is consumed
1943 // multiple times by the same instruction. We can just find the previous
1944 // allocation and use that allocation.
1945 if (request.start_time == request.end_time) {
1946 MemorySpaceAssignment::Allocation* allocation =
1947 GetLiveAllocationAt(*allocation_sequence, request.end_time);
1948 CHECK_NE(allocation, nullptr);
1949 allocation->AddUse(request.use->hlo_use);
1950 return Result::kSuccess;
1951 }
1952
1953 const HloPosition& defining_position =
1954 request.allocation_value->defining_position();
1955 VLOG(2) << "Finding allocation for "
1956 << request.allocation_value->ToShortString() << " ("
1957 << request.start_time << ", " << request.end_time
1958 << ") latest prefetch = " << request.latest_prefetch_time
1959 << " last use = " << request.allocation_value->uses().back().time
1960 << " use = " << request.use->hlo_use.ToString()
1961 << ". Size = " << request.size
1962 << ", def pos = " << defining_position.ToString();
1963 CHECK_LE(request.start_time, request.end_time);
1964
1965 // There could be a requirement to pin this buffer to default memory either
1966 // because it is a parameter or an output. If the buffer is a parameter, then
1967 // we're allowed to prefetch. If the use expects the output to be in default
1968 // memory, we cannot prefetch it because if we did, it would be in alternate
1969 // memory instead.
1970 auto required_assignment_at_start = RequiredMemoryAssignmentAt(
1971 request.allocation_value->value(), request.start_time);
1972 absl::optional<MemorySpace> required_memory_space_at_start;
1973 if (required_assignment_at_start) {
1974 required_memory_space_at_start = required_assignment_at_start->memory_space;
1975 }
1976 // Find required assignment both for the use and its aliases. If they are both
1977 // non-nullopt, then make sure they require the same assignment.
1978 auto required_assignment_at_end = RequiredMemoryAssignmentAt(
1979 request.allocation_value->value(), request.end_time);
1980 auto aliased_required_assignment_at_end =
1981 AliasedRequiredAssignmentForUse(*request.use);
1982 if (required_assignment_at_end != aliased_required_assignment_at_end) {
1983 if (required_assignment_at_end == absl::nullopt) {
1984 required_assignment_at_end = aliased_required_assignment_at_end;
1985 } else {
1986 CHECK(aliased_required_assignment_at_end == absl::nullopt ||
1987 aliased_required_assignment_at_end->equals_ignoring_time(
1988 *required_assignment_at_end));
1989 }
1990 }
1991 absl::optional<MemorySpace> required_memory_space_at_end;
1992 if (required_assignment_at_end) {
1993 required_memory_space_at_end = required_assignment_at_end->memory_space;
1994 }
1995
1996 if (required_assignment_at_start) {
1997 if (!allocation_sequence->empty()) {
1998 // We shouldn't have a situation where the required assignment at start is
1999 // at alternate memory space and we have existing allocations in the
2000 // allocation sequence. The only time we'll have required assignment at
2001 // start to be in the alternate memory space is in called computations
2002 // (e.g., while body) and we shouldn't have any allocations in the
2003 // allocation sequence so far.
2004 CHECK(required_assignment_at_start->memory_space ==
2005 MemorySpace::kDefault);
2006 // Find the previous allocation in default memory (might not be the very
2007 // last one) and extend its lifetime to include the start time of this
2008 // segment.
2009 auto prev_allocation_in_default_mem_it = std::find_if(
2010 allocation_sequence->rbegin(), allocation_sequence->rend(),
2011 [&](const auto& allocation) {
2012 return allocation->memory_space() == MemorySpace::kDefault &&
2013 allocation->defining_position() == defining_position;
2014 });
2015 CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
2016 (*prev_allocation_in_default_mem_it)->Extend(request.start_time);
2017 } else {
2018 absl::optional<Chunk> aliased_chunk = absl::nullopt;
2019 if (required_assignment_at_start->memory_space ==
2020 MemorySpace::kAlternate) {
2021 aliased_chunk =
2022 Chunk{required_assignment_at_start->offset->offset, request.size};
2023 }
2024 allocation_sequence->push_back(
2025 absl::make_unique<MemorySpaceAssignment::Allocation>(
2026 defining_position, required_assignment_at_start->memory_space,
2027 aliased_chunk, request.start_time, request.start_time));
2028 if (required_assignment_at_start->memory_space ==
2029 MemorySpace::kAlternate) {
2030 CreateOrAddToAliasedOffset(*allocation_sequence->back(),
2031 required_assignment_at_start->offset);
2032 }
2033 }
2034 }
2035
2036 Result allocation_result = Result::kSuccess;
2037 // First try keeping the allocation entirely in the alternate memory.
2038 if (required_memory_space_at_start != MemorySpace::kDefault &&
2039 required_memory_space_at_end != MemorySpace::kDefault &&
2040 request.allow_no_copy_alternate_mem_allocation) {
2041 allocation_result = AllocateInAlternateMemoryNoCopy(request);
2042 if (allocation_result == Result::kSuccess) {
2043 return Result::kSuccess;
2044 }
2045 }
2046
2047 auto prev_allocation_it = allocation_sequence->rbegin();
2048 // Find a previous allocation that is in the default memory space (not
2049 // necessarily the very last allocation).
2050 auto prev_allocation_in_default_mem_it = std::find_if(
2051 allocation_sequence->rbegin(), allocation_sequence->rend(),
2052 [&](const auto& allocation) {
2053 return allocation->memory_space() == MemorySpace::kDefault &&
2054 allocation->defining_position() == defining_position;
2055 });
2056
2057 if (prev_allocation_in_default_mem_it == allocation_sequence->rend() &&
2058 prev_allocation_it != allocation_sequence->rend() &&
2059 (*prev_allocation_it)->memory_space() == MemorySpace::kAlternate &&
2060 (*prev_allocation_it)->defining_position() == defining_position) {
2061 // If there was an allocation for this HloValue that was in the alternate
2062 // memory space, we also need to perform an eviction.
2063 Result eviction_result = Evict(request);
2064 if (eviction_result != Result::kSuccess) {
2065 // A non-success eviction requires us to uncommit previous allocations.
2066 return result_mark(Result::kFailRequiresUncommit, eviction_result);
2067 }
2068 prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
2069 } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) {
2070 allocation_sequence->push_back(
2071 absl::make_unique<MemorySpaceAssignment::Allocation>(
2072 defining_position, MemorySpace::kDefault, /*chunk=*/absl::nullopt,
2073 request.start_time, request.end_time));
2074 prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
2075 }
2076
2077 CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
2078 CHECK((*prev_allocation_in_default_mem_it)->memory_space() ==
2079 MemorySpace::kDefault);
2080
2081 // If the buffer must be in default memory at the end_time, don't prefetch.
2082 if (required_memory_space_at_end == MemorySpace::kDefault) {
2083 VLOG(3)
2084 << "Not trying to prefetch because use requires buffer in default mem.";
2085 (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
2086 (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
2087 return Result::kSuccess;
2088 }
2089
2090 // Finally, try to prefetch the buffer into alternate memory.
2091 Result prefetch_result =
2092 Prefetch(request, **prev_allocation_in_default_mem_it);
2093 if (prefetch_result == Result::kSuccess) {
2094 return Result::kSuccess;
2095 }
2096 result_mark(prefetch_result, allocation_result);
2097
2098 // If the end assignment was required to be in alternate memory but that
2099 // wasn't possible, then this allocation is invalid.
2100 if (required_memory_space_at_end == MemorySpace::kAlternate) {
2101 return result_mark(Result::kFailRequiresUncommit, allocation_result);
2102 }
2103
2104 // If a copy wasn't inserted, then add this use to the latest allocation in
2105 // default memory.
2106 (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
2107 (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
2108 return allocation_result;
2109 }
2110
2111 void AlternateMemoryBestFitHeap::AddAsyncCopy(
2112 const MemorySpaceAssignment::Allocation& prev_allocation,
2113 MemorySpace memory_space, absl::optional<Chunk> chunk, int64 start_time,
2114 int64 end_time, int64 copy_done_schedule_before_time,
2115 MemorySpaceAssignment::AllocationSequence* allocations,
2116 AliasedOffset* aliased_offset, bool is_cross_program_prefetch) {
2117 VLOG(3) << "Copy to "
2118 << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
2119 ? "default"
2120 : "alternate")
2121 << " memory between " << start_time << " and "
2122 << copy_done_schedule_before_time << " keeping until " << end_time;
2123 CHECK_LT(start_time, copy_done_schedule_before_time);
2124
2125 allocations->push_back(
2126 absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
2127 prev_allocation, memory_space, chunk, start_time, end_time,
2128 copy_done_schedule_before_time, is_cross_program_prefetch));
2129
2130 // Register the additional async copy with the interval tree to keep track of
2131 // the limit at any given time.
2132 pending_async_copies_.push_back(
2133 {start_time, copy_done_schedule_before_time, memory_space});
2134 if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) {
2135 prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2136 kDummyChunk);
2137 async_copy_ordering_.AddCopy(pending_async_copies_.back());
2138 CreateOrAddToAliasedOffset(*allocations->back(), aliased_offset);
2139 } else {
2140 eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2141 kDummyChunk);
2142 }
2143 }
2144
2145 bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
2146 int64 start_time, int64 end_time, bool is_prefetch,
2147 int64 extra_async_copy_limit) const {
2148 if (options_.max_outstanding_prefetches < 0 && is_prefetch) {
2149 return false;
2150 }
2151 if (options_.max_outstanding_evictions < 0 && !is_prefetch) {
2152 return false;
2153 }
2154
2155 // Count the prefetches/evictions in the interval tree for the given interval.
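// For example (hypothetical limits and times): with
// max_outstanding_prefetches = 2 and prefetches already registered over (3, 7)
// and (5, 9), a candidate prefetch spanning (6, 8) overlaps both, so
// num_prefetches below is 2 >= 2 + extra_async_copy_limit (when the extra
// limit is 0) and the candidate is reported as a violation.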
2156 if (is_prefetch) {
2157 int64 num_prefetches =
2158 prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2159 .size();
2160 return num_prefetches >=
2161 options_.max_outstanding_prefetches + extra_async_copy_limit;
2162 } else {
2163 int64 num_evictions =
2164 eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2165 .size();
2166 return num_evictions >=
2167 options_.max_outstanding_evictions + extra_async_copy_limit;
2168 }
2169 }
2170
2171 absl::optional<AsynchronousCopy>
2172 AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(int64 start_time,
2173 int64 end_time) const {
2174 return async_copy_ordering_.ViolatesOrdering(start_time, end_time);
2175 }
2176
2177 AlternateMemoryBestFitHeap::Result
2178 AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
2179 const AllocationRequest& request) {
2180 MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
2181 bool can_eliminate_copy = false;
2182 if (request.allocation_value->allocation_sequence()->empty()) {
2183 // There haven't been any allocations for this interval so far. We can
2184 // eliminate the copy if the value can be placed in the alternate memory.
2185 can_eliminate_copy = options_.is_allowed_in_alternate_mem_fn(
2186 *request.allocation_value->value());
2187 } else {
2188 // If there has been a previous allocation, we can eliminate the copy if the
2189 // previous allocation was also in the alternate memory.
2190 prev_allocation =
2191 request.allocation_value->allocation_sequence()->back().get();
2192 can_eliminate_copy =
2193 (prev_allocation->memory_space() == MemorySpace::kAlternate);
2194 }
2195
2196 if (!can_eliminate_copy) {
2197 return Result::kFailPrevAllocationNotInAlternateMem;
2198 }
2199
2200 const HloPosition& defining_position =
2201 request.allocation_value->defining_position();
2202 if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
2203 defining_position.shape(), request.start_time + 1,
2204 request.end_time)) {
2205 return Result::kFailLiveRangeTooLong;
2206 }
2207
2208 BufferInterval alternate_mem_interval;
2209 alternate_mem_interval.buffer = request.allocation_value->value();
2210 alternate_mem_interval.size = request.size;
2211 alternate_mem_interval.end = request.end_time;
2212 alternate_mem_interval.start = request.start_time;
2213
2214 // Prefer the offset that was used by the previous allocation.
2215 AliasedOffset* preferred_offset = nullptr;
2216 if (prev_allocation != nullptr) {
2217 preferred_offset = GetAliasedOffset(*prev_allocation);
2218 // If there is a previous allocation, set the start time to one past the
2219 // previous allocation's end time.
2220 alternate_mem_interval.start = prev_allocation->end_time() + 1;
2221 }
2222
2223 if (request.preferred_offset) {
2224 // Sanity check that if a preferred offset is provided in the request, it
2225 // matches the offset of the previous allocation.
2226 CHECK(!preferred_offset || request.preferred_offset == preferred_offset)
2227 << "preferred_offset = " << preferred_offset->offset
2228 << ", request.preferred_offset = " << request.preferred_offset->offset;
2229 preferred_offset = request.preferred_offset;
2230 }
2231
2232 VLOG(3) << "We can eliminate copy to alternate memory. Preferred offset = "
2233 << (preferred_offset ? preferred_offset->offset : -1);
2234 // In case there are additional uses after this use, we rely on the last use
2235 // time to try to reserve a chunk in the heap simulator. This is to prevent
2236 // the following scenario:
2237 //
2238 // +-------+
2239 // / \
2240 // Producer--->Use1 +-->Use2
2241 // +---------+---------+
2242 // New buffer: | | |
2243 // +---------+---------+
2244 //
2245 // +-----------+
2246 // Current heap: | offset: 0 |
2247 // --------------------------+-----------+------
2248 //
2249 // Because we allocate buffers greedily, Producer to Use1 segment first, and
2250 // then Use1 to Use2 segment, it is possible to allocate the first segment at
2251 // an offset that is available for the first segment (e.g. offset 0) but not
2252 // for the entire live range. This can result in unnecessary copies. By using
2253 // the last use time, we try to find an allocation that is available for the
2254 // entire Producer to Use2 range.
2255 absl::optional<ChunkCandidate> chunk_candidate = FindBestChunkCandidate(
2256 request, preferred_offset, &alternate_mem_interval);
2257 // Check if the new heap size fits within limits. Also ensure that, if a
2258 // preferred offset was provided, that offset was used.
2259 if (chunk_candidate) {
2260 VLOG(3) << "Keep the buffer in alternate memory. Offset = "
2261 << chunk_candidate->chunk.offset
2262 << ", size = " << chunk_candidate->chunk.size
2263 << ", heap_size = " << chunk_candidate->heap_size
2264 << ", prefetch picker = "
2265 << options_.prefetch_interval_picker->ToNoCopyDebugString(
2266 defining_position.shape(), request.start_time,
2267 request.end_time);
2268 AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
2269
2270 // If there was a previous allocation, the buffer location is the
2271 // same as the previous. Otherwise, it is the operand.
2272 if (prev_allocation != nullptr &&
2273 (prev_allocation->is_copy_allocation() ||
2274 prev_allocation->defining_position() == defining_position)) {
2275 prev_allocation->Extend(request.end_time);
2276 } else {
2277 request.allocation_value->allocation_sequence()->push_back(
2278 absl::make_unique<MemorySpaceAssignment::Allocation>(
2279 defining_position, MemorySpace::kAlternate,
2280 chunk_candidate->chunk, request.start_time, request.end_time));
2281 CreateOrAddToAliasedOffset(
2282 *request.allocation_value->allocation_sequence()->back(),
2283 preferred_offset);
2284 }
2285 request.allocation_value->allocation_sequence()->back()->AddUse(
2286 request.use->hlo_use);
2287 return Result::kSuccess;
2288 }
2289 return Result::kFailOutOfMemory;
2290 }
2291
2292 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict(
2293 const AllocationRequest& request) {
2294 CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0);
2295 MemorySpaceAssignment::Allocation* prev_allocation =
2296 request.allocation_value->allocation_sequence()->back().get();
2297 int64 eviction_start_time = prev_allocation->start_time();
2298 int64 eviction_end_time = prev_allocation->end_time();
2299 CHECK(eviction_start_time <= eviction_end_time);
2300
2301 int64 preferred_eviction_end_time =
2302 std::max(options_.prefetch_interval_picker->PreferredEvictionEndTime(
2303 request.allocation_value->defining_position().shape(),
2304 eviction_start_time, request.end_time),
2305 eviction_end_time);
2306 // Evictions must complete by the time of this use.
2307 preferred_eviction_end_time =
2308 std::min(preferred_eviction_end_time, request.latest_prefetch_time);
2309
2310 BufferInterval eviction_mem_interval;
2311 eviction_mem_interval.buffer = request.allocation_value->value();
2312 eviction_mem_interval.size = request.size;
2313 // Try to reserve a buffer from the end of the previous allocation to the
2314 // preferred eviction end time.
2315 eviction_mem_interval.start = eviction_end_time + 1;
2316 eviction_mem_interval.end = preferred_eviction_end_time;
2317 int64 preferred_offset = prev_allocation->chunk().offset;
2318 VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
2319 << ") preferred end time = " << eviction_mem_interval.end;
2320
2321 for (; eviction_mem_interval.end > eviction_end_time;
2322 --eviction_mem_interval.end) {
2323 ChunkCandidate chunk_candidate =
2324 FindChunkCandidate(eviction_mem_interval, preferred_offset);
2325 if (chunk_candidate.chunk.offset == preferred_offset) {
2326 AddToPendingChunks(eviction_mem_interval, chunk_candidate);
2327 break;
2328 }
2329 }
2330 eviction_end_time = eviction_mem_interval.end;
2331
2332 VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " ("
2333 << eviction_start_time << ", " << eviction_end_time << ")";
2334
2335 bool eviction_interval_too_short = (eviction_start_time == eviction_end_time);
2336 bool eviction_violates_outstanding_copies =
2337 ViolatesMaximumOutstandingAsyncCopies(eviction_start_time,
2338 eviction_end_time,
2339 /*is_prefetch=*/false);
2340
2341 // See if this interval would violate the asynchronous copy limit.
2342 if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) {
2343 prev_allocation->Extend(eviction_end_time);
2344 AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
2345 /*chunk=*/absl::nullopt, eviction_start_time,
2346 prev_allocation->end_time(), eviction_end_time,
2347 request.allocation_value->allocation_sequence(),
2348 /*aliased_offset=*/nullptr);
2349 } else {
2350 if (eviction_violates_outstanding_copies) {
2351 VLOG(3) << "This violates the maximum async copies.";
2352 } else {
2353 VLOG(3) << "Eviction interval is too short (" << eviction_start_time
2354 << ", " << eviction_end_time << ").";
2355 }
2356 // If the original interval violated the limit, try sub-intervals within
2357 // this interval.
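// For example (hypothetical times): if the whole interval (10, 15) violates
// the outstanding-eviction limit, the loop below tries (10, 11), (11, 12), ...
// and schedules the eviction over the first one-step sub-interval that stays
// within the limit; if none does, the eviction fails with
// kFailOutOfAsyncCopies.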
2358 bool eviction_scheduled = false;
2359 for (int64 time = eviction_start_time; time < eviction_end_time; ++time) {
2360 VLOG(4) << "Try evicting (" << time << ", " << time + 1 << ")";
2361 if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1,
2362 /*is_prefetch=*/false)) {
2363 VLOG(3) << "Eviction successful.";
2364 AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
2365 /*chunk=*/absl::nullopt, time, time + 1, time + 1,
2366 request.allocation_value->allocation_sequence(),
2367 /*aliased_offset=*/nullptr);
2368 eviction_scheduled = true;
2369 break;
2370 }
2371 }
2372
2373 if (!eviction_scheduled) {
2374 // If the eviction couldn't be scheduled, then fail. This buffer will be
2375 // kept in the default memory.
2376 VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
2377 << " because we hit the limit of maximum asynchronous copies "
2378 << "between "
2379 << hlo_live_range_.flattened_instruction_sequence()
2380 .instructions()[eviction_start_time]
2381 << " and "
2382 << hlo_live_range_.flattened_instruction_sequence()
2383 .instructions()[eviction_end_time];
2385 return Result::kFailOutOfAsyncCopies;
2386 }
2387 }
2389 return Result::kSuccess;
2390 }
2391
2392 int64 AlternateMemoryBestFitHeap::FindPrefetchEndTime(
2393 const AllocationRequest& request, int64 earliest_prefetch_time) const {
2394 int64 prefetch_end_time = request.latest_prefetch_time;
2395
2396 const HloUse& use = request.use->hlo_use;
2397 const Shape& shape = ShapeUtil::GetSubshape(
2398 use.instruction->operand(use.operand_number)->shape(), use.operand_index);
2399 for (int retry_number = 0;
2400 retry_number < options_.prefetch_copy_done_reorder_max_retries;
2401 ++retry_number) {
2402 int64 latest_prefetch_time =
2403 options_.prefetch_interval_picker->LatestPrefetchStartTime(
2404 shape, earliest_prefetch_time, prefetch_end_time, &use);
2405 VLOG(4) << "Latest prefetch start time = " << latest_prefetch_time
2406 << ", earliest prefetch start time = " << earliest_prefetch_time
2407 << ", prefetch end time = " << prefetch_end_time;
2408 // Return if we couldn't find a suitable prefetch start time.
2409 if (latest_prefetch_time < earliest_prefetch_time) {
2410 break;
2411 }
2412
2413 // Return either if there is no other violating asynchronous copy (since we
2414 // don't need to change the prefetch end time) or if the violating
2415 // asynchronous copy ends at or after the prefetch end time.
2416 auto violating_async_copy =
2417 ViolatesAsyncCopyOrdering(latest_prefetch_time, prefetch_end_time);
2418 if (!violating_async_copy ||
2419 violating_async_copy->end_time >= prefetch_end_time) {
2420 break;
2421 }
2422 VLOG(4) << "Violating async copy: (" << violating_async_copy->start_time
2423 << ", " << violating_async_copy->end_time << ")";
2424
2425 int64 new_prefetch_end_time =
2426 options_.prefetch_interval_picker->LatestPrefetchEndTime(
2427 prefetch_end_time, violating_async_copy->end_time);
2428 if (new_prefetch_end_time > earliest_prefetch_time) {
2429 VLOG(3) << "Update prefetch end time = " << new_prefetch_end_time;
2430 prefetch_end_time = new_prefetch_end_time;
2431 } else {
2432 VLOG(3) << "Can't update prefetch end time = " << new_prefetch_end_time
2433 << " because earliest prefetch start time = "
2434 << earliest_prefetch_time;
2435 break;
2436 }
2437 }
2438
2439 return prefetch_end_time;
2440 }
2441
2442 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch(
2443 const AllocationRequest& request,
2444 const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) {
2445 // Try partially placing the buffer in the alternate space. The time that is
2446 // overlapped will be used to asynchronously copy the buffer from the
2447 // default memory to the alternate memory.
2448 //
2449 // start end
2450 // time time
2451 // X---------------------X
2452 // Alternate: +------+
2453 // Default: +---------------------+
2454 // ^ ^
2455 // Copy Copy
2456 // Start Done
2457 int64 earliest_prefetch_time =
2458 prev_allocation_in_default_mem.earliest_available_time();
2459 if (request.earliest_prefetch_time) {
2460 earliest_prefetch_time =
2461 std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
2462 }
2463 int64 prefetch_end_time =
2464 FindPrefetchEndTime(request, earliest_prefetch_time);
2465
2466 options_.prefetch_interval_picker->Begin(
2467 request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
2468 VLOG(3) << "Trying prefetch picker = "
2469 << options_.prefetch_interval_picker->ToDebugString();
2470
2471 // Create an alternate memory interval that starts at the earliest position
2472 // allowed by the prefetch interval picker.
2473 BufferInterval alternate_mem_interval;
2474 alternate_mem_interval.buffer = request.allocation_value->value();
2475 alternate_mem_interval.size = request.size;
2476 // Uses by while instructions may be allowed additional outstanding prefetches.
2477 int64 extra_async_copy_limit =
2478 request.use->hlo_use.instruction->opcode() == HloOpcode::kWhile
2479 ? options_.while_use_extra_outstanding_prefetch_limit
2480 : 0;
2481 Result result = Result::kSuccess;
2482 while (!options_.prefetch_interval_picker->Done()) {
2483 alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
2484 CHECK_LT(alternate_mem_interval.start, prefetch_end_time);
2485 VLOG(4) << "Trying alternate memory allocation ("
2486 << alternate_mem_interval.start << ", " << request.end_time << ")";
2487 // If this additional asynchronous copy would violate asynchronous copy
2488 // ordering or the outstanding copy limit, try a different interval.
2489 if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
2490 prefetch_end_time)) {
2491 VLOG(4) << "This would violate asynchronous copy ordering.";
2492 result_mark(Result::kFailViolatesAsyncCopyOrdering, result);
2493 continue;
2494 }
2495 if (ViolatesMaximumOutstandingAsyncCopies(
2496 alternate_mem_interval.start, prefetch_end_time,
2497 /*is_prefetch=*/true, extra_async_copy_limit)) {
2498 VLOG(4) << "This would violate the outstanding async copy limit.";
2499 result_mark(Result::kFailOutOfAsyncCopies, result);
2500 continue;
2501 }
2502
2503 auto chunk_candidate = FindBestChunkCandidate(
2504 request, request.preferred_offset, &alternate_mem_interval);
2505 // Check if we could find a suitable chunk.
2506 if (chunk_candidate) {
2507 VLOG(3) << "Move the buffer to alternate memory at "
2508 << alternate_mem_interval.start
2509 << ". Offset = " << chunk_candidate->chunk.offset
2510 << ", size = " << chunk_candidate->chunk.size
2511 << ", heap_size = " << chunk_candidate->heap_size
2512 << ", prefetch picker = "
2513 << options_.prefetch_interval_picker->ToDebugString();
2514 AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
2515
2516 AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate,
2517 chunk_candidate->chunk, alternate_mem_interval.start,
2518 request.end_time, prefetch_end_time,
2519 request.allocation_value->allocation_sequence(),
2520 request.preferred_offset);
2521
2522 request.allocation_value->allocation_sequence()->back()->AddUse(
2523 request.use->hlo_use);
2524 return Result::kSuccess;
2525 }
2526 result_mark(Result::kFailOutOfMemory, result);
2527 }
2528 // If we didn't consider any prefetch intervals, then the live range was too
2529 // short.
2530 if (result == Result::kSuccess) {
2531 return Result::kFailLiveRangeTooShort;
2532 } else {
2533 return result;
2534 }
2535 }
2536
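// Searches for a chunk in alternate memory for the requested buffer. Without a
// preferred offset, the search starts at the latest use that can be covered
// without copies and walks back toward end_time, returning the first (i.e.
// longest-lived) chunk that fits in the available heap. With a preferred
// offset, only an allocation at exactly that offset is accepted.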
2537 absl::optional<AlternateMemoryBestFitHeap::ChunkCandidate>
2538 AlternateMemoryBestFitHeap::FindBestChunkCandidate(
2539 const AllocationRequest& request, const AliasedOffset* preferred_offset,
2540 BufferInterval* alternate_mem_interval) const {
2541 int64 end_time = request.end_time;
2542 if (!preferred_offset) {
2543 // First find the earliest use that is the same or later than the end time.
2544 const auto& uses = request.allocation_value->uses();
2545 auto use_it = uses.begin();
2546 for (; use_it->time < end_time; ++use_it) {
2547 }
2548 CHECK(use_it != uses.end());
2549 int64 earliest_use = use_it->time;
2550
2551 // Then find the latest use that can be allocated contiguously without
2552 // copies.
2553 const Shape& shape = request.allocation_value->defining_position().shape();
2554 for (;
2555 (use_it + 1) != uses.end() &&
2556 options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
2557 shape, use_it->time, (use_it + 1)->time);
2558 ++use_it) {
2559 }
2560 CHECK(use_it != uses.end());
2561 int64 latest_contiguous_use = use_it->time;
2562
2563 // Find a chunk that is as long-lived as possible by iterating in reverse
2564 // over the use times.
2565 for (; use_it >= uses.begin() && use_it->time >= end_time; --use_it) {
2566 alternate_mem_interval->end = use_it->time;
2567 ChunkCandidate chunk_candidate =
2568 FindChunkCandidate(*alternate_mem_interval);
2569 if (chunk_candidate.heap_size <= available_heap_size()) {
2570 alternate_mem_interval->end = end_time;
2571 VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
2572 << ", latest contiguous use = " << latest_contiguous_use
2573 << ", use with available mem = " << use_it->time
2574 << ", offset = " << chunk_candidate.chunk.offset;
2575 return chunk_candidate;
2576 }
2577 }
2578 alternate_mem_interval->end = end_time;
2579 return absl::nullopt;
2580 }
2581 // If a preferred offset is given, try to find an allocation at that offset
2582 // only.
2583 alternate_mem_interval->end = end_time;
2584 ChunkCandidate chunk_candidate =
2585 FindChunkCandidate(*alternate_mem_interval, preferred_offset->offset);
2586 if (chunk_candidate.chunk.offset == preferred_offset->offset) {
2587 return chunk_candidate;
2588 }
2589 return absl::nullopt;
2590 }
2591
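// Walks every non-fusion computation, pairing CopyStart/CopyDone instructions
// to compute the maximum number of copies in flight and to classify each copy
// as a prefetch (destination in the alternate memory space) or an eviction.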
2592 StatusOr<MemorySpaceAssignment::AsyncCopyStats>
2593 MemorySpaceAssignment::CalculateAsyncCopyStats() const {
2594 AsyncCopyStats stats;
2595 stats.max_outstanding_async_copies = 0;
2596 stats.num_prefetches = 0;
2597 stats.prefetch_bytes = 0;
2598 stats.num_evictions = 0;
2599 stats.eviction_bytes = 0;
2600 int64 current_copies = 0;
2601 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
2602 HloDataflowAnalysis::Run(*module_));
2603 for (const HloComputation* computation :
2604 module_->MakeNonfusionComputations()) {
2605 for (HloInstruction* instruction : computation->instructions()) {
2606 if (instruction->opcode() == HloOpcode::kCopyStart) {
2607 current_copies++;
2608 } else if (instruction->opcode() == HloOpcode::kCopyDone) {
2609 current_copies--;
2610 int64 size =
2611 options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction));
2612 if (instruction->shape().layout().memory_space() ==
2613 options_.alternate_memory_space) {
2614 ++stats.num_prefetches;
2615 stats.prefetch_bytes += size;
2616 } else {
2617 ++stats.num_evictions;
2618 stats.eviction_bytes += size;
2619 }
2620 }
2621 stats.max_outstanding_async_copies =
2622 std::max(stats.max_outstanding_async_copies, current_copies);
2623 }
2624 }
2625 return stats;
2626 }
2627
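// Returns a comparator that orders buffer intervals by decreasing memory
// boundedness, falling back to the spatial (size-based) ordering to break
// ties.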
2628 /*static*/ MemorySpaceAssignment::BufferIntervalCompare
2629 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
2630 const MemorySpaceAssignmentCostAnalysis& cost_analysis,
2631 MemorySpaceAssignmentCostAnalysis::Cache* cache) {
2632 return [&cost_analysis, cache](const BufferInterval& x,
2633 const BufferInterval& y) {
2634 float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
2635 float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
2636 if (x_memory_boundedness != y_memory_boundedness) {
2637 return x_memory_boundedness > y_memory_boundedness;
2638 }
2639 // Tie-break if the memory boundedness is the same.
2640 return GlobalDecreasingSizeBestFitHeap<
2641 HloValue>::GetSpatialBufferIntervalCompare()(x, y);
2642 };
2643 }
2644
2645
2646 /*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
2647 MemorySpaceAssignment::Run(HloModule* module,
2648 const HloLiveRange& hlo_live_range,
2649 const HloAliasAnalysis& alias_analysis,
2650 const Options& options) {
2651 CHECK(module->has_schedule());
2652 VLOG(3) << "Module before memory space assignment: ";
2653 XLA_VLOG_LINES(3, module->ToString());
2654 VLOG(3) << "Schedule: " << module->schedule().ToString();
2655 MemorySpaceAssignment memory_space_assignment(module, options,
2656 hlo_live_range);
2657
2658 return memory_space_assignment.RunMemorySpaceAssignment(hlo_live_range,
2659 alias_analysis);
2660 }
2661
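// Top-level driver: finds the allocation sequence, materializes the
// allocations (inserting asynchronous copies), schedules those copies,
// simplifies the graph, fixes the schedule, exports and colors the buffers,
// and finally verifies the result.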
2662 StatusOr<std::unique_ptr<PresetAssignments>>
2663 MemorySpaceAssignment::RunMemorySpaceAssignment(
2664 const HloLiveRange& hlo_live_range,
2665 const HloAliasAnalysis& alias_analysis) {
2666 TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis));
2667 TF_RETURN_IF_ERROR(Process());
2668 ScheduleAsynchronousCopies();
2669 TF_RETURN_IF_ERROR(SimplifyGraph());
2670 TF_RETURN_IF_ERROR(FixSchedule());
2671 TF_RETURN_IF_ERROR(ExportAndColorBuffers());
2672
2673 VLOG(3) << "Module after memory space assignment: ";
2674 XLA_VLOG_LINES(3, module_->ToString());
2675 TF_CHECK_OK(module_->schedule().Verify());
2676 TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats());
2677 VLOG(1) << "Maximum number of outstanding async copies: "
2678 << stats.max_outstanding_async_copies;
2679 VLOG(1) << "Number of prefetches: " << stats.num_prefetches
2680 << ", in bytes: " << stats.prefetch_bytes;
2681 VLOG(1) << "Number of evictions: " << stats.num_evictions
2682 << ", in bytes: " << stats.eviction_bytes;
2683
2684 TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace());
2685
2686 return std::move(preset_assignments_);
2687 }
2688
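// Runs the heap simulator with the AlternateMemoryBestFitHeap algorithm to
// build allocations_. Operand buffer reuse is disabled in the simulation
// (may_reuse_operand_buffers = false).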
2689 Status MemorySpaceAssignment::FindAllocationSequence(
2690 const HloLiveRange& hlo_live_range,
2691 const HloAliasAnalysis& alias_analysis) {
2692 auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
2693 &allocations_, options_, alias_analysis, hlo_live_range);
2694
2695 HeapSimulator::Options heap_simulator_options;
2696 heap_simulator_options.may_reuse_operand_buffers = false;
2697 TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_,
2698 module_->schedule(), alias_analysis,
2699 options_.size_fn,
2700 heap_simulator_options)
2701 .status());
2702 return Status::OK();
2703 }
2704
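// Records a use of this allocation so that Process() can later rewrite it to
// consume the allocation's value.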
2705 void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
2706 HloInstruction* operand =
2707 use.instruction->mutable_operand(use.operand_number);
2708 // If the use is a tuple, look inside the tuple to find the actual use.
2709 for (int64 index : use.operand_index) {
2710 if (operand->opcode() != HloOpcode::kTuple) {
2711 break;
2712 }
2713 operand = operand->mutable_operand(index);
2714 }
2715
2716 // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts.
2717 std::function<HloInstruction*(HloInstruction*)> get_simplified_operand;
2718 get_simplified_operand = [&](HloInstruction* instruction) {
2719 while (instruction->opcode() == HloOpcode::kGetTupleElement) {
2720 HloInstruction* operand =
2721 get_simplified_operand(instruction->mutable_operand(0));
2722 if (operand->opcode() == HloOpcode::kTuple) {
2723 instruction = operand->mutable_operand(instruction->tuple_index());
2724 } else {
2725 return instruction;
2726 }
2727 }
2728 return instruction;
2729 };
2730 operand = get_simplified_operand(operand);
2731
2732 uses_.push_back(use);
2733 }
2734
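// Rewrites all recorded uses to consume this allocation's value, inserting
// get-tuple-elements, tuple reconstructions, or bitcasts as needed to match
// the operand shapes.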
2735 Status MemorySpaceAssignment::Allocation::Process(
2736 MemorySpaceAssignment* memory_space_assignment) {
2737 HloInstruction* producing_instruction = AddGetTupleElements();
2738 HloComputation* computation = producing_instruction->parent();
2739 for (const HloUse& use : uses_) {
2740 Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
2741 HloInstruction* replacement_instruction = producing_instruction;
2742 if (operand_shape.IsTuple()) {
2743 TF_ASSIGN_OR_RETURN(
2744 replacement_instruction,
2745 ReplaceTupleWith(producing_instruction,
2746 use.instruction->mutable_operand(use.operand_number),
2747 use.operand_index));
2748 } else if (operand_shape != producing_instruction->shape()) {
2749 VLOG(4) << "Old shape = " << operand_shape.ToString()
2750 << ", new shape = " << producing_instruction->shape().ToString()
2751 << "; inserting a bitcast.";
2752 replacement_instruction = computation->AddInstruction(
2753 HloInstruction::CreateBitcast(operand_shape, producing_instruction));
2754 }
2755 TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
2756 use.operand_number, replacement_instruction));
2757 }
2758 return Status::OK();
2759 }
2760
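// Recursively rebuilds `tuple` so that the element at `shape_index` is
// replaced by `new_instruction`. For example, with tuple = {a, {b, c}} and
// shape_index = {1, 0}, the result is Tuple(a, Tuple(new_instruction, c)).
// Bitcasts are inserted when shapes differ, and the original tuple is returned
// unchanged if it already contains `new_instruction` at that index.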
2761 StatusOr<HloInstruction*> MemorySpaceAssignment::Allocation::ReplaceTupleWith(
2762 HloInstruction* new_instruction, HloInstruction* tuple,
2763 ShapeIndex shape_index) {
2764 const Shape& tuple_shape = tuple->shape();
2765 CHECK(tuple->shape().IsTuple())
2766 << "ReplaceTupleWith was called for a non-tuple. Tuple = "
2767 << tuple->ToString()
2768 << ", new_instruction = " << new_instruction->ToString()
2769 << ", shape_index = " << shape_index.ToString();
2770
2771 HloComputation* computation = new_instruction->parent();
2772 std::vector<HloInstruction*> tuple_args(tuple_shape.tuple_shapes_size());
2773 for (int64 i = 0; i < tuple_shape.tuple_shapes_size(); ++i) {
2774 const Shape& subshape = tuple_shape.tuple_shapes(i);
2775 // If tuple is a tuple instruction, reuse its operand directly when
2776 // constructing the new tuple instead of creating a get-tuple-element;
2777 // this improves compilation time.
2778 auto get_operand = [&]() {
2779 if (tuple->opcode() == HloOpcode::kTuple) {
2780 return tuple->mutable_operand(i);
2781 } else {
2782 return computation->AddInstruction(
2783 HloInstruction::CreateGetTupleElement(subshape, tuple, i));
2784 }
2785 };
2786 if (i == shape_index[0]) {
2787 // If the subshape is still a tuple, recurse with a shape index that is
2788 // one level deeper.
2789 if (subshape.IsTuple()) {
2790 TF_ASSIGN_OR_RETURN(tuple_args[i],
2791 ReplaceTupleWith(new_instruction, get_operand(),
2792 ShapeIndex(shape_index.begin() + 1,
2793 shape_index.end())));
2794 } else {
2795 if (subshape != new_instruction->shape()) {
2796 VLOG(4) << "Old shape = " << subshape.ToString()
2797 << ", new shape = " << new_instruction->shape().ToString()
2798 << "; inserting a bitcast.";
2799 new_instruction = computation->AddInstruction(
2800 HloInstruction::CreateBitcast(subshape, new_instruction));
2801 } else if (tuple->opcode() == HloOpcode::kTuple &&
2802 tuple->operand(i) == new_instruction) {
2803 // If the tuple element is already the new instruction, we don't have
2804 // to create a new tuple; just return the original tuple.
2806 VLOG(4) << "Tuple already contains the new instruction = "
2807 << new_instruction->ToShortString()
2808 << " tuple = " << tuple->ToShortString();
2809 return tuple;
2810 }
2811 tuple_args[i] = new_instruction;
2812 }
2813 } else {
2814 tuple_args[i] = get_operand();
2815 }
2816 }
2817 return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args));
2818 }
2819
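// Returns an array-shaped instruction for this allocation's defining position,
// reusing or creating get-tuple-element instructions for each level of the
// position's shape index (asynchronous copies only operate on arrays).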
2820 HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() {
2821 HloInstruction* producing_instruction = defining_position().instruction;
2822 CHECK_NE(producing_instruction, nullptr);
2823
2824 Shape shape = defining_position().shape();
2825 CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = "
2826 << shape.ToString()
2827 << " position = " << defining_position().ToString();
2828 HloComputation* computation = producing_instruction->parent();
2829
2830 // If the instruction we're processing is a tuple, we walk down the shape
2831 // index, searching for or creating kGetTupleElement instructions until we
2832 // reach the array value to copy. Asynchronous copies only support array types.
2833 for (int64 index : defining_position().index) {
2834 // We first search if there already is a get-tuple-element with the correct
2835 // index. If there is no such get-tuple-element, we create one.
2836 auto gte_it = absl::c_find_if(
2837 producing_instruction->users(), [index](const HloInstruction* use) {
2838 return use != use->parent()->root_instruction() &&
2839 use->opcode() == HloOpcode::kGetTupleElement &&
2840 use->tuple_index() == index;
2841 });
2842 if (gte_it != producing_instruction->users().end()) {
2843 producing_instruction = *gte_it;
2844 } else {
2845 producing_instruction =
2846 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
2847 producing_instruction->shape().tuple_shapes(index),
2848 producing_instruction, index));
2849 }
2850 }
2851 return producing_instruction;
2852 }
2853
2854 std::string MemorySpaceAssignment::Allocation::ToString() const {
2855 std::string memory_space_str = "def";
2856 if (memory_space_ == MemorySpace::kAlternate) {
2857 memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
2858 }
2859 return absl::StrCat("Allocation in ", memory_space_str, " defined at ",
2860 defining_position_.ToString());
2861 }
2862
2863 std::string MemorySpaceAssignment::CopyAllocation::ToString() const {
2864 std::string memory_space_str = "def";
2865 if (memory_space_ == MemorySpace::kAlternate) {
2866 memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
2867 }
2868 return absl::StrCat("Copy Allocation in ", memory_space_str, " from ",
2869 prev_allocation_.ToString());
2870 }
2871
2872 Status MemorySpaceAssignment::CopyAllocation::Process(
2873 MemorySpaceAssignment* memory_space_assignment) {
2874 // Copy allocations need to insert asynchronous copy nodes.
2875 Shape shape = defining_position().shape();
2876 HloInstruction* producing_instruction = AddGetTupleElements();
2877 HloComputation* computation = producing_instruction->parent();
2878 copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart(
2879 ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
2880 producing_instruction, is_cross_program_prefetch_));
2881 copy_done_ = computation->AddInstruction(
2882 HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
2883 VLOG(4) << "Created " << copy_start_->name()
2884 << " for position: " << defining_position().ToString();
2885 // Update the allocation's defining position to the copy-done instruction so
2886 // that any further copies from this allocation find the correct position.
2887 defining_position_ = HloPosition{copy_done_, {}};
2888
2889 // Replace all the uses with the new copy instruction.
2890 for (HloUse use : uses_) {
2891 // If the operand is a tuple, we need to descend to the actual instruction
2892 // we want to replace.
2893 HloInstruction* replacement_instruction;
2894 Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
2895 if (operand_shape.IsTuple()) {
2896 TF_ASSIGN_OR_RETURN(
2897 replacement_instruction,
2898 ReplaceTupleWith(copy_done_,
2899 use.instruction->mutable_operand(use.operand_number),
2900 use.operand_index));
2901 } else if (operand_shape != copy_done_->shape()) {
2902 VLOG(4) << "Old shape = " << operand_shape.ToString()
2903 << ", new shape = " << copy_done_->shape().ToString()
2904 << "; inserting a bitcast.";
2905 replacement_instruction = computation->AddInstruction(
2906 HloInstruction::CreateBitcast(operand_shape, copy_done_));
2907 } else {
2908 replacement_instruction = copy_done_;
2909 }
2910 TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
2911 use.operand_number, replacement_instruction));
2912 }
2913
2914 return Status::OK();
2915 }
2916
2917 Status MemorySpaceAssignment::Process() {
2918 VLOG(1) << "Processing assigned buffers...";
2919 // Insert CopyStart/CopyDone pairs.
2920 for (auto& allocation : allocations_) {
2921 VLOG(3) << "Processing: " << allocation->ToString();
2922 TF_RETURN_IF_ERROR(allocation->Process(this));
2923 // Add the offset and size of the allocation in the alternate memory to
2924 // the output map.
2925 if (allocation->memory_space() == MemorySpace::kAlternate) {
2926 alternate_memory_assignments_.emplace_back(
2927 allocation->defining_position(), allocation->chunk());
2928 alternate_memory_size_ =
2929 std::max(alternate_memory_size_, allocation->chunk().chunk_end());
2930 }
2931 }
2932 return Status::OK();
2933 }
2934
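// Exports the alternate memory chunks into preset_assignments_ (checking that
// aliased positions agree on their offsets) and colors the layouts of all
// aliased values with the alternate memory space.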
2935 Status MemorySpaceAssignment::ExportAndColorBuffers() {
2936 VLOG(1) << "Exporting buffers...";
2937 TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
2938 absl::flat_hash_map<int64, int64> seen_buffer_offsets;
2939 VLOG(3) << "Exported alternate memory allocations:";
2940 for (const auto& position_and_chunk : alternate_memory_assignments_) {
2941 const HloPosition& defining_position = position_and_chunk.first;
2942 const Chunk& chunk = position_and_chunk.second;
2943 const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(
2944 defining_position.instruction, defining_position.index);
2945 auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id());
2946 if (seen_buffer_offset_it != seen_buffer_offsets.end()) {
2947 CHECK_EQ(chunk.offset, seen_buffer_offset_it->second)
2948 << "Mismatch in offset for positions that map to the same value: "
2949 << buffer.ToString() << ", pos: " << defining_position.ToString();
2950 } else {
2951 VLOG(3) << " [" << chunk.offset << ", " << chunk.size
2952 << "] : " << defining_position.ToString() << " ("
2953 << buffer.ToString() << ")";
2954 preset_assignments_->add_chunk(defining_position, chunk);
2955 seen_buffer_offsets[buffer.id()] = chunk.offset;
2956 }
2957 }
2958
2959 if (!preset_assignments_->chunks().empty()) {
2960 preset_assignments_
2961 ->assignment_information_for_space(options_.alternate_memory_space)
2962 ->size = alternate_memory_size_;
2963 }
2964
2965 VLOG(3) << "Exported alternate memory sizes:";
2966 for (auto& pair : preset_assignments_->assignment_informations()) {
2967 VLOG(3) << " space: " << pair.first << ", size: " << pair.second.size;
2968 }
2969
2970 VLOG(1) << "Coloring buffers...";
2971 // Color the pending positions and all of their aliased buffers.
2972 for (const auto& defining_position_and_chunk :
2973 preset_assignments_->chunks()) {
2974 const HloPosition& defining_position = defining_position_and_chunk.first;
2975 for (auto& buffer : alias_analysis->ComputeBuffersAt(
2976 defining_position.instruction, defining_position.index)) {
2977 for (auto& value : buffer->values()) {
2978 for (auto& position : value->positions()) {
2979 VLOG(4) << "Coloring " << position.ToString();
2980 Shape* shape = ShapeUtil::GetMutableSubshape(
2981 position.instruction->mutable_shape(), position.index);
2982 CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
2983 << position.ToString();
2984 shape->mutable_layout()->set_memory_space(
2985 options_.alternate_memory_space);
2986 }
2987 }
2988 }
2989 }
2990 return Status::OK();
2991 }
2992
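// Removes the alternate memory assignment (if any) recorded for `instruction`;
// called when SimplifyGraph deletes the instruction.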
2993 void MemorySpaceAssignment::RemoveAssignmentForInstruction(
2994 const HloInstruction* instruction) {
2995 for (auto& position_and_chunk : alternate_memory_assignments_) {
2996 const HloPosition& position = position_and_chunk.first;
2997 if (position.instruction == instruction) {
2998 VLOG(3) << "Removing instruction from alternate memory assignments.";
2999 // Swap the removed position and chunk with the back and pop back.
3000 position_and_chunk = alternate_memory_assignments_.back();
3001 alternate_memory_assignments_.pop_back();
3002 break;
3003 }
3004 }
3005 }
3006
3007 Status MemorySpaceAssignment::SimplifyGraph() {
3008 VLOG(1) << "Simplifying graph...";
3009 for (HloComputation* computation : module_->MakeNonfusionComputations()) {
3010 // Parallel computations aren't in the schedule and don't need to be
3011 // modified.
3012 if (!computations_in_schedule_.contains(computation)) {
3013 VLOG(4) << "Not simplifying " << computation->name()
3014 << " because it's not in the schedule.";
3015 continue;
3016 }
3017 // Drop control dependencies. Since the computation is already scheduled, we
3018 // don't need control dependencies anymore, and having control
3019 // predecessors/successors prevents us from removing instructions without
3020 // users (HloComputation::IsSafelyRemovable returns false if there are
3021 // control dependencies).
3022 for (HloInstruction* instruction :
3023 computation->MakeInstructionPostOrder()) {
3024 TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
3025 }
3026 // We perform limited DCE and forward the tuple operand in patterns like
3027 // GetTupleElement(Tuple(a, b), 0). This is mostly because memory space
3028 // assignment is run late in compilation (after DCE and arithmetic
3029 // simplification passes) and we don't want to generate redundant code. Run
3030 // to fixed point.
3031 bool computation_modified = true;
3032 while (computation_modified) {
3033 computation_modified = false;
3034 VLOG(4) << "Running simplify graph loop over " << computation->name();
3035 for (HloInstruction* instruction :
3036 computation->MakeInstructionPostOrder()) {
3037 if (computation->IsSafelyRemovable(instruction) &&
3038 instruction->user_count() == 0 && !instruction->HasSideEffect() &&
3039 instruction != computation->root_instruction() &&
3040 instruction->opcode() != HloOpcode::kCopyStart &&
3041 instruction->opcode() != HloOpcode::kCopyDone) {
3042 VLOG(4) << "Instruction removed: " << instruction->ToString();
3043 // Ensure the alternate memory assignments don't contain a reference
3044 // to the removed instruction.
3045 RemoveAssignmentForInstruction(instruction);
3046 // Instead of deleting the instruction from the schedule, replace it
3047 // with a nullptr. This is needed because FixSchedule relies on the
3048 // logical time that is the index into flattened_instructions_ for
3049 // scheduling asynchronous copies.
3050 auto instruction_it =
3051 absl::c_find(flattened_instructions_, instruction);
3052 if (instruction_it != flattened_instructions_.end()) {
3053 *instruction_it = nullptr;
3054 }
3055 TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
3056 computation_modified = true;
3057 } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
3058 HloInstruction* operand = instruction->mutable_operand(0);
3059 if (operand->opcode() == HloOpcode::kTuple) {
3060 HloInstruction* forwarded_instruction =
3061 operand->mutable_operand(instruction->tuple_index());
3062 VLOG(4) << "Replacing uses of " << instruction->ToString()
3063 << " with " << forwarded_instruction->ToString();
3064 TF_RETURN_IF_ERROR(
3065 instruction->ReplaceAllUsesWith(forwarded_instruction));
3066 computation_modified = true;
3067 }
3068 } else if (instruction->opcode() == HloOpcode::kTuple) {
3069 // Replace Tuple(GetTupleElement(x), ..., GetTupleElement(x)) pattern
3070 // with x.
3071 bool can_replace =
3072 instruction->operand_count() > 0 &&
3073 instruction->operand(0)->opcode() ==
3074 HloOpcode::kGetTupleElement &&
3075 instruction->operand(0)
3076 ->operand(0)
3077 ->shape()
3078 .tuple_shapes_size() == instruction->operand_count();
3079 for (int operand_number = 0;
3080 operand_number < instruction->operand_count();
3081 ++operand_number) {
3082 const HloInstruction* operand =
3083 instruction->operand(operand_number);
3084 if (operand->opcode() != HloOpcode::kGetTupleElement ||
3085 operand->tuple_index() != operand_number ||
3086 operand->operand(0) != instruction->operand(0)->operand(0)) {
3087 can_replace = false;
3088 break;
3089 }
3090 }
3091 if (can_replace) {
3092 HloInstruction* forwarded_instruction =
3093 instruction->mutable_operand(0)->mutable_operand(0);
3094 VLOG(4) << "Replacing uses of " << instruction->ToString()
3095 << " with " << forwarded_instruction->ToString();
3096 TF_RETURN_IF_ERROR(
3097 instruction->ReplaceAllUsesWith(forwarded_instruction));
3098 computation_modified = true;
3099 }
3100 }
3101 }
3102 }
3103 }
3104
3105 return Status::OK();
3106 }
3107
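// Appends `new_instruction` to `new_sequence`, first recursively appending any
// of its operands that haven't been inserted yet. CopyStart/CopyDone operands
// are expected to have been inserted already.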
3108 void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted(
3109 HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
3110 absl::flat_hash_set<HloInstruction*>* inserted_instructions) const {
3111 if (inserted_instructions->contains(new_instruction)) {
3112 return;
3113 }
3114 for (HloInstruction* operand : new_instruction->operands()) {
3115 // CopyStart/CopyDone dependencies should already have been inserted; it is
3116 // a red flag if they haven't been.
3117 CHECK((operand->opcode() != HloOpcode::kCopyStart &&
3118 operand->opcode() != HloOpcode::kCopyDone) ||
3119 inserted_instructions->contains(operand))
3120 << "Inserted instruction " << new_instruction->ToString()
3121 << " has un-inserted dependency: " << operand->ToString();
3122 EnsureInstructionAndOperandsInserted(operand, new_sequence,
3123 inserted_instructions);
3124 }
3125 VLOG(4) << "inserting: " << new_instruction->ToShortString();
3126 new_sequence->push_back(new_instruction);
3127 inserted_instructions->insert(new_instruction);
3128 }
3129
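// Decides where each CopyStart/CopyDone pair goes in the schedule: copies are
// sorted by (copy_done_schedule_before, copy_start_schedule_after), CopyStarts
// are delayed until the schedule reaches the computation that defines the
// copied value, and the results are recorded in schedule_after_ and
// schedule_before_ for FixSchedule to consume.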
3130 void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
3131 VLOG(1) << "Scheduling asynchronous copies...";
3132 for (MemorySpace memory_space :
3133 {MemorySpace::kDefault, MemorySpace::kAlternate}) {
3134 std::vector<CopyAllocation*> copy_allocations;
3135 for (auto& allocation : allocations_) {
3136 if (allocation->is_copy_allocation()) {
3137 auto copy_allocation = static_cast<CopyAllocation*>(allocation.get());
3138 if (copy_allocation->memory_space() == memory_space) {
3139 copy_allocations.push_back(copy_allocation);
3140 }
3141 }
3142 }
3143
3144 absl::c_stable_sort(
3145 copy_allocations, [](CopyAllocation* first, CopyAllocation* second) {
3146 return std::forward_as_tuple(first->copy_done_schedule_before(),
3147 first->copy_start_schedule_after()) <
3148 std::forward_as_tuple(second->copy_done_schedule_before(),
3149 second->copy_start_schedule_after());
3150 });
3151
3152 CopyAllocation* prev_copy_allocation = nullptr;
3153 for (CopyAllocation* copy_allocation : copy_allocations) {
3154 // If the copy start doesn't happen to be scheduled at the correct
3155 // computation, delay it until the correct computation starts.
3156 int64 copy_start_schedule_after =
3157 copy_allocation->copy_start_schedule_after();
3158 // Accessing flattened_instructions_ here without checking if it is
3159 // nullptr is safe because this method is called before SimplifyGraph.
3160 while (copy_allocation->defining_position().instruction->parent() !=
3161 flattened_instructions_[copy_start_schedule_after]->parent()) {
3162 VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
3163 << (copy_start_schedule_after + 1) << ") for "
3164 << copy_allocation->copy_start()->ToString()
3165 << " because it is not in the correct computation.";
3166 copy_allocation->set_copy_start_schedule_after(
3167 ++copy_start_schedule_after);
3168 }
3169
3170 schedule_after_[copy_allocation->copy_start_schedule_after()].push_back(
3171 copy_allocation->copy_start());
3172 schedule_before_[copy_allocation->copy_done_schedule_before()].push_back(
3173 copy_allocation->copy_done());
3174 prev_copy_allocation = copy_allocation;
3175 }
3176 }
3177 }
3178
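// Rebuilds the schedule of every scheduled computation, inserting the newly
// created CopyStart/CopyDone instructions at the positions recorded in
// schedule_after_ and schedule_before_, and skipping instructions that
// SimplifyGraph removed.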
3179 Status MemorySpaceAssignment::FixSchedule() {
3180 VLOG(1) << "Fixing schedule...";
3181 CHECK(module_->has_schedule());
3182 HloSchedule& schedule = module_->schedule();
3183 for (const HloComputation* computation :
3184 module_->MakeNonfusionComputations()) {
3185 // Parallel computations aren't in the schedule and don't need to be
3186 // modified.
3187 if (!computations_in_schedule_.contains(computation)) {
3188 VLOG(4) << "Not scheduling " << computation->name()
3189 << " because it's not in the schedule.";
3190 continue;
3191 }
3192 CHECK(schedule.is_computation_scheduled(computation));
3193 HloInstructionSequence new_sequence;
3194
3195 absl::flat_hash_set<HloInstruction*> inserted_instructions;
3196
3197 VLOG(4) << "Scheduling: " << computation->ToString();
3198
3199 for (int64 instruction_index = 0;
3200 instruction_index < flattened_instructions_.size();
3201 ++instruction_index) {
3202 auto insts_before_iter = schedule_before_.find(instruction_index);
3203 if (insts_before_iter != schedule_before_.end()) {
3204 for (HloInstruction* new_instruction : insts_before_iter->second) {
3205 if (new_instruction->parent() == computation) {
3206 VLOG(4) << "before " << instruction_index << ": "
3207 << new_instruction->name();
3208 EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
3209 &inserted_instructions);
3210 }
3211 }
3212 }
3213 HloInstruction* instruction = flattened_instructions_[instruction_index];
3214 // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
3215 // it was deleted) and not previously inserted. Also bitcasts and tuples
3216 // are treated specially and only inserted as a result of operand
3217 // dependencies.
3218 if (instruction != nullptr &&
3219 !inserted_instructions.contains(instruction) &&
3220 instruction->parent() == computation &&
3221 instruction->opcode() != HloOpcode::kBitcast &&
3222 instruction->opcode() != HloOpcode::kTuple) {
3223 VLOG(4) << "inst " << instruction_index << ": " << instruction->name();
3224 EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
3225 &inserted_instructions);
3226 }
3227 auto insts_after_iter = schedule_after_.find(instruction_index);
3228 if (insts_after_iter != schedule_after_.end()) {
3229 for (HloInstruction* new_instruction : insts_after_iter->second) {
3230 if (new_instruction->parent() == computation) {
3231 VLOG(4) << "after " << instruction_index << ": "
3232 << new_instruction->name();
3233 EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
3234 &inserted_instructions);
3235 }
3236 }
3237 }
3238 }
3239 // For rare cases where the original sequence is empty, ensure the root
3240 // instruction and its dependencies are scheduled.
3241 EnsureInstructionAndOperandsInserted(computation->root_instruction(),
3242 &new_sequence, &inserted_instructions);
3243 CHECK_EQ(new_sequence.size(), computation->instruction_count())
3244 << "New sequence for computation " << computation->name() << " has "
3245 << new_sequence.size() << " instructions, expects "
3246 << computation->instruction_count() << ".";
3247 schedule.set_sequence(computation, new_sequence);
3248 }
3249
3250 return Status::OK();
3251 }
3252
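// Verifies that no two alternate memory chunks overlap in both space and time
// and that every asynchronous copy crosses memory spaces, then exports a heap
// simulator trace for the alternate memory space and logs the maximum memory
// usage ignoring fragmentation.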
3253 Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
3254 VLOG(1) << "Verifying...";
3255 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
3256 HloAliasAnalysis::Run(module_));
3257 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
3258 HloLiveRange::Run(module_->schedule(), *alias_analysis,
3259 module_->entry_computation()));
3260
3261 BufferIntervalTree interval_tree;
3262 absl::flat_hash_set<int64> seen_buffers;
3263 // The key for events is: time, is_free, value_id. Events are thus sorted
3264 // first by time; within the same time, allocations come before frees; and
3265 // the value id acts as the final tie breaker.
3266 std::map<std::tuple<int64, bool, int64>,
3267 std::tuple<const HloValue*, Chunk, HeapSimulatorTrace::Event::Kind>>
3268 events;
3269
3270 auto add_allocation_and_verify = [&](int64 start_time, int64 end_time,
3271 const Chunk& chunk,
3272 const HloValue* value) {
3273 events[std::make_tuple(start_time, /*is_free=*/false, value->id())] =
3274 std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC);
3275 events[std::make_tuple(end_time, /*is_free=*/true, value->id())] =
3276 std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE);
3277
3278 // Get the chunks overlapping in time and search if they overlap in space
3279 // as well.
3280 // TODO(berkin): For now checking against end_time - 1 (exclusive), but we
3281 // really should check against end_time (inclusive) for cases where the
3282 // operand can't share buffer with user (see
3283 // HloDataflowAnalysis::CanShareOperandBufferWithUser).
3284 for (const Chunk& overlapping_chunk :
3285 interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
3286 if (chunk.OverlapsWith(overlapping_chunk)) {
3287 return InternalError(
3288 ("Value %s (%d, %d) off: %d size: %d overlaps with another chunk"
3289 " off: %d size: %d"),
3290 value->ToShortString(), start_time, end_time, chunk.offset,
3291 chunk.size, overlapping_chunk.offset, overlapping_chunk.size);
3292 }
3293 }
3294 interval_tree.Add(start_time, end_time - 1, chunk);
3295 return Status::OK();
3296 };
3297
3298 // Go through all instructions in the module to ensure CopyStart/CopyDone
3299 // instructions copy between alternate memory and default memory.
3300 for (const HloComputation* computation :
3301 module_->MakeNonfusionComputations()) {
3302 for (const HloInstruction* instruction : computation->instructions()) {
3303 if (instruction->opcode() == HloOpcode::kCopyStart) {
3304 int64 from_memory_space =
3305 ShapeUtil::GetSubshape(instruction->shape(), {1})
3306 .layout()
3307 .memory_space();
3308 int64 to_memory_space =
3309 ShapeUtil::GetSubshape(instruction->shape(), {0})
3310 .layout()
3311 .memory_space();
3312 CHECK_NE(from_memory_space, to_memory_space)
3313 << "Asynchronous copy to the same memory space: "
3314 << instruction->ToString();
3315 }
3316 }
3317 }
3318
3319 for (const auto& position_and_chunk : preset_assignments_->chunks()) {
3320 const HloPosition& position = position_and_chunk.first;
3321 const Chunk& chunk = position_and_chunk.second;
3322 const HloBuffer& buffer =
3323 alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
3324 CHECK(!seen_buffers.contains(buffer.id()))
3325 << "Multiple preset assignments for the same buffer: "
3326 << buffer.ToString() << ", pos: " << position.ToString()
3327 << ", off: " << chunk.offset << ", size: " << chunk.size;
3328 seen_buffers.insert(buffer.id());
3329
3330 for (const HloValue* value : buffer.values()) {
3331 const HloLiveRange::TimeBound& time_bound =
3332 hlo_live_range->buffer_live_ranges().at(value);
3333 const HloInstruction* last_use_instruction = nullptr;
3334 int64 last_use_time = time_bound.start;
3335 for (const HloUse& use : value->uses()) {
3336 int64 use_time =
3337 hlo_live_range->instruction_schedule().at(use.instruction);
3338 if (use_time > last_use_time) {
3339 last_use_time = use_time;
3340 last_use_instruction = use.instruction;
3341 }
3342 }
3343
3344 std::function<Status(const HloInstruction*, int64, int64,
3345 absl::string_view)>
3346 split_conditional_buffer;
3347 split_conditional_buffer = [&](const HloInstruction* use_instruction,
3348 int64 start_time, int64 end_time,
3349 absl::string_view indent_string) {
3350 // Special case when verifying conditional: we internally split the use
3351 // of alternate memory in conditionals, so fish them out from the
3352 // conditionals.
3353 VLOG(3) << indent_string
3354 << "Splitting conditional buffer: " << buffer.ToString()
3355 << " value: " << value->ToShortString() << ": (" << start_time
3356 << ", " << end_time << ") off: " << chunk.offset
3357 << ", size: " << chunk.size;
3358 int64 earliest_computation_start_time = end_time;
3359 for (const HloComputation* called_computation :
3360 use_instruction->called_computations()) {
3361 earliest_computation_start_time =
3362 std::min(earliest_computation_start_time,
3363 hlo_live_range->computation_span_times()
3364 .at(called_computation)
3365 .start);
3366 int64 parameter_time = -1;
3367 int64 last_use_time = -1;
3368 const HloInstruction* last_use_instruction = nullptr;
3369 for (const HloPosition& position : value->positions()) {
3370 if (position.instruction->opcode() == HloOpcode::kParameter &&
3371 position.instruction->parent() == called_computation) {
3372 parameter_time = hlo_live_range->instruction_schedule().at(
3373 position.instruction);
3374 break;
3375 }
3376 }
3377 for (const HloUse& use : value->uses()) {
3378 int64 use_time =
3379 hlo_live_range->instruction_schedule().at(use.instruction);
3380 if (use.instruction->parent() == called_computation &&
3381 use_time > last_use_time) {
3382 last_use_time = use_time;
3383 last_use_instruction = use.instruction;
3384 }
3385 }
3386 if (last_use_time != -1) {
3387 CHECK_NE(parameter_time, -1);
3388 VLOG(3) << indent_string
3389 << " computation: " << called_computation->name() << ": ("
3390 << parameter_time << ", " << last_use_time << ")";
3391 CHECK(last_use_instruction);
3392 if (last_use_instruction->opcode() == HloOpcode::kConditional) {
3393 // The last use is another (nested) conditional. Call this
3394 // function recursively.
3395 TF_RETURN_IF_ERROR(split_conditional_buffer(
3396 last_use_instruction, parameter_time, last_use_time,
3397 absl::StrCat(indent_string, " ")));
3398 } else {
3399 last_use_time = std::min(last_use_time, end_time);
3400 TF_RETURN_IF_ERROR(add_allocation_and_verify(
3401 parameter_time, last_use_time, chunk, value));
3402 }
3403 }
3404 }
3405 VLOG(3) << indent_string << " from beginning until first computation: ("
3406 << start_time << ", " << (earliest_computation_start_time - 1)
3407 << ")";
3408 TF_RETURN_IF_ERROR(add_allocation_and_verify(
3409 start_time, earliest_computation_start_time - 1, chunk, value));
3410 return Status::OK();
3411 };
3412
3413 if (last_use_instruction &&
3414 last_use_instruction->opcode() == HloOpcode::kConditional) {
3415 TF_RETURN_IF_ERROR(split_conditional_buffer(
3416 last_use_instruction, time_bound.start, time_bound.end, " "));
3417 } else if (!value->uses().empty()) {
3418 last_use_time = std::min(last_use_time, time_bound.end);
3419 VLOG(3) << " buffer: " << buffer.ToString()
3420 << " value: " << value->ToShortString() << ": ("
3421 << time_bound.start << ", " << last_use_time
3422 << ") off: " << chunk.offset << ", size: " << chunk.size;
3423 TF_RETURN_IF_ERROR(add_allocation_and_verify(
3424 time_bound.start, last_use_time, chunk, value));
3425 }
3426 }
3427 }
3428
3429 HeapSimulatorTrace* heap_trace =
3430 &preset_assignments_
3431 ->assignment_information_for_space(options_.alternate_memory_space)
3432 ->heap_simulator_trace;
3433 int64 memory_usage = 0;
3434 int64 max_memory_usage = 0;
3435 for (const auto& event : events) {
3436 int64 time;
3437 bool is_free;
3438 int64 buffer_id;
3439 std::tie(time, is_free, buffer_id) = event.first;
3440 const HloValue* value;
3441 Chunk chunk;
3442 HeapSimulatorTrace::Event::Kind kind;
3443 std::tie(value, chunk, kind) = event.second;
3444 HeapSimulatorTrace::Event* heap_trace_event = heap_trace->add_events();
3445 heap_trace_event->set_kind(kind);
3446 heap_trace_event->set_buffer_id(buffer_id);
3447 heap_trace_event->set_instruction_name(value->instruction()->name());
3448 heap_trace_event->set_computation_name(
3449 value->instruction()->parent()->name());
3450
3451 if (kind == HeapSimulatorTrace::Event::ALLOC) {
3452 memory_usage += chunk.size;
3453 } else {
3454 CHECK_EQ(kind, HeapSimulatorTrace::Event::FREE);
3455 memory_usage -= chunk.size;
3456 }
3457 max_memory_usage = std::max(max_memory_usage, memory_usage);
3458 VLOG(4) << "Memory usage: " << memory_usage << " at time: " << time;
3459 }
3460 VLOG(1) << "Max memory usage ignoring fragmentation: " << max_memory_usage;
3461
3462 return Status::OK();
3463 }
3464
3465 } // namespace xla
3466