/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/memory_space_assignment.h"

#include <algorithm>
#include <utility>

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace xla {

namespace memory_space_assignment {

namespace {
// Define a dummy chunk for chunks that will be allocated in the default memory
// space and for keeping track of the number of asynchronous copies.
const HeapSimulator::Chunk kDummyChunk{-1, -1};

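// Returns true if any user of |inst| (looking through bitcasts and into
// fusion computations) consumes it like an activation: as the LHS of a
// convolution or dot, as the indices operand of a gather, or via any other
// opcode not explicitly recognized below. Used to filter cross-program
// prefetch candidates.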
bool LooksLikeAnActivation(const HloInstruction* inst) {
  for (HloInstruction* user : inst->users()) {
    switch (user->opcode()) {
      case HloOpcode::kConvolution:
      case HloOpcode::kDot:
        if (user->operand(0) == inst) {
          return true;
        }
        break;
      case HloOpcode::kGather:
        if (user->operand(1) == inst) {
          return true;
        }
        break;
      case HloOpcode::kFusion:
        for (int i = 0; i < user->operand_count(); ++i) {
          if (user->operand(i) == inst &&
              LooksLikeAnActivation(user->fused_parameter(i))) {
            return true;
          }
        }
        break;
      case HloOpcode::kBitcast:
        return LooksLikeAnActivation(user);
      default:
        return true;
    }
  }
  return false;
}

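// Returns true if |value| is a single array element of a tuple-shaped entry
// computation parameter that fits in the alternate memory, is not already
// assigned to it, and whose uses are all get-tuple-elements (or bitcasts of
// them) that do not look like activations.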
bool IsCrossProgramPrefetchCandidate(const HloValue& value,
                                     const Options& options) {
  return value.instruction()->parent() ==
             value.instruction()->GetModule()->entry_computation() &&
         value.instruction()->opcode() == HloOpcode::kParameter &&
         (!value.shape().has_layout() ||
          value.shape().layout().memory_space() !=
              options.alternate_memory_space) &&
         value.index().size() == 1 && value.shape().IsArray() &&
         !value.uses().empty() &&
         options.size_fn(value) <= options.max_size_in_bytes &&
         absl::c_all_of(value.uses(), [&](const HloUse& use) {
           const HloInstruction* inst =
               use.instruction->operand(use.operand_number);

           // Skip the LooksLikeAnActivation test since we're testing the
           // parent GTE and its children below.
           if (inst->opcode() == HloOpcode::kBitcast &&
               inst->operand(0)->opcode() == HloOpcode::kGetTupleElement &&
               inst->operand(0)->operand(0)->opcode() ==
                   HloOpcode::kParameter) {
             return true;
           }

           return inst->opcode() == HloOpcode::kGetTupleElement &&
                  !LooksLikeAnActivation(inst);
         });
}

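// Scans all buffers in the module and returns the best cross-program prefetch
// candidate, if any, picked by either the configured buffer interval
// comparator or a size-based heuristic (see size_compare below).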
absl::optional<MemorySpaceAssignment::BufferInterval>
FindCrossProgramPrefetchCandidate(const HloAliasAnalysis& alias_analysis,
                                  const HloLiveRange& hlo_live_range,
                                  const Options& options) {
  std::vector<MemorySpaceAssignment::BufferInterval> candidates;
  for (const HloBuffer& buffer : alias_analysis.buffers()) {
    CHECK_GE(buffer.values().size(), 1);
    const HloValue* value = buffer.values().at(0);
    if (IsCrossProgramPrefetchCandidate(*value, options)) {
      MemorySpaceAssignment::BufferInterval interval;
      interval.buffer = value;
      interval.size = options.size_fn(*value);
      interval.start = 0;
      interval.end = hlo_live_range.schedule_end_time();
      interval.need_allocation = true;
      interval.colocations = {++buffer.values().begin(), buffer.values().end()};
      candidates.emplace_back(interval);
    }
  }

  // The buffer_interval_compare ought to do a good job picking the most
  // appropriate buffer to cross program prefetch, but empirically, it makes
  // worse choices than just picking the largest buffer.
  // TODO(b/152421603): Investigate.
  auto size_compare = [](const auto& x, const auto& y) {
    if (x.size == y.size) {
      // When both buffers are of the same size, we prefer the one that is used
      // to produce larger tensors in its consumer instructions.
      auto get_use_size =
          [](const MemorySpaceAssignment::BufferInterval& bi) -> int64 {
        int64_t use_size = 0;
        for (const auto& use : bi.buffer->uses()) {
          use_size += ShapeUtil::ElementsInRecursive(use.instruction->shape());
        }
        return use_size;
      };
      return get_use_size(x) < get_use_size(y);
    }
    return x.size < y.size;
  };
  auto& compare = options.default_cross_program_prefetch_heuristic &&
                          options.buffer_interval_compare
                      ? *options.buffer_interval_compare
                      : size_compare;

  auto best_candidate = absl::c_max_element(candidates, compare);
  if (best_candidate == candidates.end()) {
    return absl::nullopt;
  }
  return *best_candidate;
}

}  // namespace

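// Builds the alias analysis, live range, and call graph needed by the cost
// analysis and wraps them in a MemorySpaceAssignmentCostAnalysis.
//
// Illustrative usage only; the surrounding objects (hlo_cost_analysis,
// options, module) are assumed to exist in the calling pass:
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<MemorySpaceAssignmentCostAnalysis> cost_analysis,
//       MemorySpaceAssignmentCostAnalysis::Create(hlo_cost_analysis, options,
//                                                 *module));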
/*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
MemorySpaceAssignmentCostAnalysis::Create(const HloCostAnalysis& cost_analysis,
                                          const Options& options,
                                          const HloModule& module) {
  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
  TF_ASSIGN_OR_RETURN(auto hlo_live_range,
                      HloLiveRange::Run(module.schedule(), *alias_analysis,
                                        module.entry_computation()));
  auto call_graph = CallGraph::Build(&module);
  return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
      cost_analysis, options, std::move(alias_analysis),
      std::move(hlo_live_range), std::move(call_graph)));
}

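// Returns the benefit of serving the given instruction from alternate memory:
// for memory-bound instructions, the elapsed-time saving (scaled by the
// estimated execution count of the enclosing while nest); for compute-bound
// instructions, a non-positive value measuring how far the instruction is
// from being memory bound.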
float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
    const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
    MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
  float elapsed_time_due_to_compute =
      GetInstructionElapsedDueToCompute(instruction);
  float elapsed_time_due_to_memory =
      GetInstructionElapsedDueToMemory(instruction);
  if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
    // Memory bound, return how much alternate memory is better.
    float while_nest_multiplier;
    if (cache) {
      // If there is a cache provided, memoize the while nest multiplier.
      auto it = cache->while_nest_multiplier.find(&instruction);
      if (it != cache->while_nest_multiplier.end()) {
        while_nest_multiplier = it->second;
      } else {
        while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
            options_.xla_tpu_memory_space_assignment_while_execution_count,
            CalculateComputationNestLevel(&instruction,
                                          /*while_only=*/true));
        cache->while_nest_multiplier[&instruction] = while_nest_multiplier;
      }
    } else {
      while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
          options_.xla_tpu_memory_space_assignment_while_execution_count,
          CalculateComputationNestLevel(&instruction,
                                        /*while_only=*/true));
    }
    return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
           while_nest_multiplier;
  } else {
    // Compute bound, return how far off we are from memory boundedness.
    return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
  }
}

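// Aggregates the alternate-memory benefit of a buffer interval over its
// defining position and all of its uses (skipping while/conditional uses,
// which are accounted for inside their called computations), then normalizes
// by the square root of the buffer size.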
float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
    MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
  const HloInstruction& defining_instruction =
      *interval.buffer->defining_instruction();
  float alternate_mem_benefit = GetAlternateMemoryBenefit(
      defining_instruction,
      GetInstructionElapsedDueToMemory(
          defining_instruction,
          /*operands_in_alternate_mem=*/{},
          /*outputs_in_alternate_mem=*/{interval.buffer->defining_index()}),
      cache);
  for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
           interval.buffer->defining_position().instruction,
           interval.buffer->defining_position().index)) {
    for (const HloValue* value : buffer->values()) {
      for (const HloUse& use : value->uses()) {
        // We look inside the called computations of while and conditional, so
        // don't use the benefit of while and conditional directly.
        if (use.instruction->opcode() == HloOpcode::kWhile ||
            use.instruction->opcode() == HloOpcode::kConditional) {
          continue;
        }
        float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
            *use.instruction,
            GetInstructionElapsedDueToMemory(
                *use.instruction,
                /*operands_in_alternate_mem=*/{std::make_pair(
                    use.operand_number, use.operand_index)}),
            cache);
        // If the benefit is positive (memory bound), add it to this buffer's
        // benefit. If the benefit is negative (compute bound), calculate the
        // maximum.
        if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
          alternate_mem_benefit += use_alternate_mem_benefit;
        } else {
          alternate_mem_benefit =
              std::max(alternate_mem_benefit, use_alternate_mem_benefit);
        }
      }
    }
  }

  // Penalize larger buffers by dividing the benefit by the square root of
  // their size. Empirically, we observed that this resulted in better
  // performance than dividing by the size itself.
  return alternate_mem_benefit / std::sqrt(interval.size);
}

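// Counts how many nested computations (optionally, only while bodies) enclose
// the given instruction. The module is expected to be flattened so that every
// non-entry computation has exactly one caller.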
int MemorySpaceAssignmentCostAnalysis::CalculateComputationNestLevel(
    const HloInstruction* instruction, bool while_only) const {
  int nest_level = 0;
  const HloComputation* computation = instruction->parent();
  while (!computation->IsEntryComputation()) {
    auto node = call_graph_->GetNode(computation);
    auto callsites = node.caller_callsites();
    CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
    auto callsite = callsites[0];
    if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) {
      ++nest_level;
    }
    computation = callsite.instruction()->parent();
  }
  return nest_level;
}

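// Compute-only elapsed time: the larger of the FLOP time and the
// transcendental time, i.e.
//   max(flop_count / flops_per_second,
//       transcendental_count / transcendentals_per_second).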
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
    const HloInstruction& instruction) const {
  return std::max(
      cost_analysis_.flop_count(instruction) /
          cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey),
      cost_analysis_.transcendental_count(instruction) /
          cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey));
}

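// Memory-only elapsed time: bytes accessed from the alternate memory are
// charged at the alternate-memory bandwidth and the remaining bytes at the
// default-memory bandwidth; the two components are summed.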
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory(
    const HloInstruction& instruction,
    absl::Span<const std::pair<int64_t, ShapeIndex>> operands_in_alternate_mem,
    absl::Span<const ShapeIndex> outputs_in_alternate_mem) const {
  float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction);
  float bytes_accessed_from_alternate_mem = 0.0;
  for (auto& operand : operands_in_alternate_mem) {
    float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed(
        instruction, operand.first, operand.second);
    bytes_accessed_from_alternate_mem += operand_bytes_accessed;
  }

  for (auto& shape_idx : outputs_in_alternate_mem) {
    float output_bytes_accessed =
        cost_analysis_.output_bytes_accessed(instruction, shape_idx);
    bytes_accessed_from_alternate_mem += output_bytes_accessed;
  }
  float elapsed_due_to_alternate_mem =
      bytes_accessed_from_alternate_mem /
      options().alternate_mem_bandwidth_bytes_per_second;
  float elapsed_due_to_default_mem =
      (total_bytes_accessed - bytes_accessed_from_alternate_mem) /
      cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
  return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem;
}

float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed(
    const HloInstruction& instruction) const {
  return std::max(GetInstructionElapsedDueToCompute(instruction),
                  GetInstructionElapsedDueToMemory(instruction));
}

float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory(
    const HloInstruction& instruction,
    absl::Span<const std::pair<int64_t, ShapeIndex>> operands_in_alternate_mem,
    absl::Span<const ShapeIndex> outputs_in_alternate_mem) const {
  return std::max(
      GetInstructionElapsedDueToCompute(instruction),
      GetInstructionElapsedDueToMemory(instruction, operands_in_alternate_mem,
                                       outputs_in_alternate_mem));
}

float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
    const Shape& shape) const {
  int64_t size_in_bytes = cost_analysis_.GetShapeSize(shape);
  return static_cast<float>(size_in_bytes) /
         options().async_copy_bandwidth_bytes_per_second;
}

int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
  return hlo_live_range_->schedule_end_time();
}

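// The instruction-count based picker measures prefetch overlap purely in
// number of scheduled instructions: a no-copy alternate-memory allocation is
// allowed only if the interval spans at most max_overlap_count_ instructions,
// and prefetch start times are explored between max_overlap_count_ and
// min_overlap_count_ instructions before the use.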
bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
    const Shape& shape, int64_t start_time, int64_t end_time) const {
  return end_time - start_time <= max_overlap_count_;
}

int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
    const Shape& shape, int64_t start_time, int64_t latest_end_time) const {
  return std::min(start_time + min_overlap_count_, latest_end_time);
}

int64 InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime(
    const Shape& shape, int64_t start_time, int64_t end_time,
    const HloUse* use) const {
  return end_time - min_overlap_count_;
}

int64 InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime(
    const Shape& shape, int64_t earliest_prefetch_start_time,
    int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const {
  return std::max(earliest_prefetch_start_time,
                  prefetch_end_time - max_overlap_count_);
}

void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
                                                   int64_t start_time,
                                                   int64_t end_time) {
  end_time_ = end_time;
  const Shape& shape = ShapeUtil::GetSubshape(
      use.instruction->operand(use.operand_number)->shape(), use.operand_index);
  current_prefetch_time_ =
      PreferredPrefetchStartTime(shape, start_time, end_time, end_time);
}

int64 InstructionCountPrefetchIntervalPicker::Next() {
  CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
                    "Done() is true";
  return current_prefetch_time_++;
}

bool InstructionCountPrefetchIntervalPicker::Done() const {
  return end_time_ - current_prefetch_time_ <= min_overlap_count_;
}

std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const {
  return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_);
}

std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString(
    const Shape& shape, int64_t start_time, int64_t end_time) const {
  return absl::StrCat("Overlapped HLOs = ", end_time - start_time);
}

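// The cost-analysis based picker measures prefetch overlap in estimated
// elapsed time rather than instruction count. The constructor precomputes,
// per logical time: the while and computation nest levels, a cumulative sum
// of per-instruction elapsed times (weighted by the estimated while execution
// count), and an index of the most recent while-nest-level change.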
CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
    const MemorySpaceAssignmentCostAnalysis& cost_analysis,
    float min_async_copy_to_overlap_ratio,
    float max_async_copy_to_overlap_ratio,
    float preferred_async_copy_to_overlap_ratio,
    int64_t buffer_size_for_max_async_copy)
    : while_nest_level_(
          cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0),
      computation_nest_level_(
          cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0),
      cost_analysis_(cost_analysis),
      min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
      max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio),
      preferred_async_copy_to_overlap_ratio_(
          preferred_async_copy_to_overlap_ratio),
      buffer_size_for_max_async_copy_(buffer_size_for_max_async_copy) {
  instruction_schedule_ =
      &cost_analysis_.hlo_live_range().instruction_schedule();

  // Create a vector of elapsed times and while nesting levels of HLO
  // instructions. The elapsed times are multiplied by
  // pow(while_execution_count, nest_level) to account for executing the HLOs
  // multiple times in while loops.
  std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
                                               0.0);
  for (const auto& instruction_and_logical_time : *instruction_schedule_) {
    // To avoid double counting, don't include the elapsed time of while and
    // conditional HLOs.
    const HloInstruction* instruction = instruction_and_logical_time.first;
    int64_t logical_time = instruction_and_logical_time.second;
    if (logical_time >= instructions_elapsed_time.size()) {
      instructions_elapsed_time.resize(logical_time + 1, 0.0);
      while_nest_level_.resize(logical_time + 1, 0);
    }
    int while_nest_level = cost_analysis_.CalculateComputationNestLevel(
        instruction_and_logical_time.first, /*while_only=*/true);
    while_nest_level_[logical_time] = while_nest_level;
    int computation_nest_level = cost_analysis_.CalculateComputationNestLevel(
        instruction_and_logical_time.first, /*while_only=*/false);
    computation_nest_level_[logical_time] = computation_nest_level;
    if (instruction->opcode() == HloOpcode::kWhile ||
        instruction->opcode() == HloOpcode::kConditional) {
      continue;
    }
    float elapsed_time = cost_analysis_.GetInstructionElapsed(
        *instruction_and_logical_time.first);
    instructions_elapsed_time[logical_time] =
        elapsed_time *
        tensorflow::MathUtil::IPow<float>(
            cost_analysis_.options()
                .xla_tpu_memory_space_assignment_while_execution_count,
            while_nest_level);
  }
  // As an optimization, create a cumulative sum vector of elapsed time.
  float cumsum = 0.0;
  elapsed_time_cumsum_.reserve(instructions_elapsed_time.size());
  for (float elapsed_time : instructions_elapsed_time) {
    cumsum += elapsed_time;
    elapsed_time_cumsum_.push_back(cumsum);
  }
  // To be able to accurately determine the minimum nest level between a start
  // time and an end time efficiently, populate a data structure that stores the
  // closest nest level change index.
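  // For example (illustrative values), if while_nest_level_ is {0, 1, 1, 0},
  // the loop below produces while_nest_level_change_ = {-1, 0, 0, 2}: each
  // entry points at the index just before the most recent nest-level change,
  // or -1 if no change has occurred yet.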
  int prev_nest_level = 0;
  int change_idx = -1;
  while_nest_level_change_.reserve(instructions_elapsed_time.size());
  for (int i = 0; i < while_nest_level_.size(); ++i) {
    int nest_level = while_nest_level_[i];
    if (nest_level != prev_nest_level) {
      prev_nest_level = nest_level;
      change_idx = i - 1;
    }
    while_nest_level_change_.push_back(change_idx);
  }
}

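// Upper bound on how much elapsed time a prefetch is allowed to overlap:
// max_async_copy_to_overlap_ratio_ times the larger of the (retry-scaled)
// copy elapsed time and the elapsed time of a reference copy of
// buffer_size_for_max_async_copy_ bytes (expressed as an S32 array, hence the
// division by 4).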
float CostAnalysisPrefetchIntervalPicker::GetMaxElapsedInAlternateMemory(
    float async_copy_elapsed) const {
  return max_async_copy_to_overlap_ratio_ *
         std::max(max_overlap_multiplier_ * async_copy_elapsed,
                  cost_analysis_.GetAsyncCopyElapsed(ShapeUtil::MakeShape(
                      S32, {buffer_size_for_max_async_copy_ / 4})));
}

bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
    const Shape& shape, int64_t start_time, int64_t end_time) const {
  // Even though this method returns whether we allow the buffer in alternate
  // memory _without_ asynchronous copies, calculate how long it would have
  // taken to copy it and compare that to the elapsed time in the logical
  // interval.
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  float logical_interval_elapsed =
      GetLogicalIntervalElapsed(start_time, end_time);
  return GetMaxElapsedInAlternateMemory(async_copy_elapsed) >
         logical_interval_elapsed;
}

int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
    const Shape& shape, int64_t start_time, int64_t latest_end_time) const {
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  int64_t end_time;
  for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) {
    float logical_interval_elapsed =
        GetLogicalIntervalElapsed(start_time, end_time);
    if (logical_interval_elapsed >=
        min_async_copy_to_overlap_ratio_ * async_copy_elapsed) {
      break;
    }
  }
  return end_time;
}

int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime(
    const Shape& shape, int64_t start_time, int64_t end_time,
    const HloUse* use) const {
  // Find the latest time that still satisfies min_async_copy_to_overlap_ratio_.
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  // If there is a use, estimate the time we would save by having this op in
  // alternate memory.
  float inst_elapsed_reduction = 0.0f;
  if (use) {
    float elapsed_time =
        cost_analysis_.GetInstructionElapsed(*use->instruction);
    float elapsed_time_in_alternate_mem =
        cost_analysis_.GetInstructionElapsedInAlternateMemory(
            *use->instruction,
            /*operands_in_alternate_mem=*/
            {std::make_pair(use->operand_number, use->operand_index)},
            /*outputs_in_alternate_mem=*/{});
    inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem;
  }
  int end_nest_level = computation_nest_level_[end_time];

  // Find the latest time we're allowed to start prefetching.
  float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed;
  int latest_prefetch_time;
  for (latest_prefetch_time = end_time - 1;
       latest_prefetch_time >= start_time &&
       (computation_nest_level_[latest_prefetch_time] != end_nest_level ||
        min_interval >
            GetLogicalIntervalElapsed(latest_prefetch_time, end_time) +
                inst_elapsed_reduction);
       --latest_prefetch_time) {
  }

  return latest_prefetch_time;
}

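// Picks the start time (within the allowed window and at the same computation
// nest level as the prefetch end) whose logical interval is closest to
// preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed.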
int64 CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime(
    const Shape& shape, int64_t earliest_prefetch_start_time,
    int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const {
  // Between the earliest and latest prefetch interval, find the interval
  // closest to the preferred interval and start iterating from there.
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  int64_t preferred_prefetch_start_time = earliest_prefetch_start_time;
  float preferred_interval =
      preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed;
  float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time,
                                                  prefetch_end_time);
  int end_nest_level = computation_nest_level_[prefetch_end_time];
  for (int64_t prefetch_start_time = earliest_prefetch_start_time + 1;
       prefetch_start_time <= latest_prefetch_start_time;
       ++prefetch_start_time) {
    float interval =
        GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time);
    if (computation_nest_level_[prefetch_start_time] == end_nest_level &&
        std::abs(preferred_interval - interval) <
            std::abs(preferred_interval - best_interval)) {
      best_interval = interval;
      preferred_prefetch_start_time = prefetch_start_time;
    }
  }
  return preferred_prefetch_start_time;
}

int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime(
    int64_t original_prefetch_end_time,
    int64_t proposed_prefetch_end_time) const {
  // Iterate towards the beginning until we find a suitable end time that is at
  // the same computation nest level as the original prefetch end time.
  int64_t original_nest_level =
      computation_nest_level_[original_prefetch_end_time];
  int64_t new_prefetch_end_time;
  for (new_prefetch_end_time = proposed_prefetch_end_time;
       computation_nest_level_[new_prefetch_end_time] != original_nest_level;
       --new_prefetch_end_time) {
  }
  return new_prefetch_end_time;
}

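// Begin() computes the legal [earliest, latest] prefetch start window for the
// given use and seeds two iterators at the preferred start time; Next() then
// alternates between the increasing and decreasing iterators so that
// candidate start times are explored outward from the preferred time.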
void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
                                               int64_t start_time,
                                               int64_t end_time) {
  const Shape& shape = ShapeUtil::GetSubshape(
      use.instruction->operand(use.operand_number)->shape(), use.operand_index);
  // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
  async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape);
  // Estimate the time we would save by having this op in alternate memory.
  float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
  float elapsed_time_in_alternate_mem =
      cost_analysis_.GetInstructionElapsedInAlternateMemory(
          *use.instruction, /*operands_in_alternate_mem=*/
          {std::make_pair(use.operand_number, use.operand_index)},
          /*outputs_in_alternate_mem=*/{});
  inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
  end_logical_time_ = end_time;
  int end_nest_level = computation_nest_level_[end_logical_time_];

  // Find the latest time we're allowed to start prefetching.
  float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
  latest_prefetch_time_ =
      LatestPrefetchStartTime(shape, start_time, end_time, &use);

  // Find the earliest time we're allowed to start prefetching.
  float max_interval = GetMaxElapsedInAlternateMemory(async_copy_elapsed_);
  for (earliest_prefetch_time_ = start_time;
       earliest_prefetch_time_ < latest_prefetch_time_ &&
       (computation_nest_level_[earliest_prefetch_time_] != end_nest_level ||
        max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_,
                                                 end_logical_time_));
       ++earliest_prefetch_time_) {
  }
  if (earliest_prefetch_time_ > latest_prefetch_time_) {
    // There is no available prefetch interval for the given start and end
    // times. Set the iterators accordingly to ensure Done() returns true.
    increasing_prefetch_time_iterator_ = earliest_prefetch_time_;
    decreasing_prefetch_time_iterator_ = latest_prefetch_time_;
    CHECK(Done());
    return;
  }

  int64_t starting_prefetch_time = PreferredPrefetchStartTime(
      shape, earliest_prefetch_time_, latest_prefetch_time_, end_logical_time_);
  float preferred_interval =
      preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
  VLOG(4) << "Interval min/max/preferred = " << min_interval << " "
          << max_interval << " " << preferred_interval
          << " prefetch time earliest/latest/starting = "
          << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " "
          << starting_prefetch_time;

  increasing_prefetch_time_iterator_ = starting_prefetch_time;
  decreasing_prefetch_time_iterator_ = starting_prefetch_time;
  using_increasing_prefetch_time_iterator_ = true;
  // Since both iterators start at the same position, call Next() once to
  // advance one of the iterators.
  Next();
}

int64 CostAnalysisPrefetchIntervalPicker::Next() {
  CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
                    "Done() is true";
  if (using_increasing_prefetch_time_iterator_) {
    int64_t prefetch_time = increasing_prefetch_time_iterator_++;
    while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ &&
           computation_nest_level_[increasing_prefetch_time_iterator_] !=
               computation_nest_level_[end_logical_time_]) {
      ++increasing_prefetch_time_iterator_;
    }
    if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) {
      using_increasing_prefetch_time_iterator_ = false;
    }
    return prefetch_time;
  } else {
    int64_t prefetch_time = decreasing_prefetch_time_iterator_--;
    while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ &&
           computation_nest_level_[decreasing_prefetch_time_iterator_] !=
               computation_nest_level_[end_logical_time_]) {
      --decreasing_prefetch_time_iterator_;
    }
    if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) {
      using_increasing_prefetch_time_iterator_ = true;
    }
    return prefetch_time;
  }
}

bool CostAnalysisPrefetchIntervalPicker::Done() const {
  return increasing_prefetch_time_iterator_ > latest_prefetch_time_ &&
         decreasing_prefetch_time_iterator_ < earliest_prefetch_time_;
}

void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
  // Double the max overlap limit on each retry.
  max_overlap_multiplier_ = 1 << retry_number;
}

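// Returns the minimum while nest level over [start_time, end_time] by walking
// the precomputed while_nest_level_change_ links backwards from end_time. For
// example (illustrative values), with while_nest_level_ = {0, 1, 1, 0},
// GetMinWhileNestLevel(1, 3) returns 0.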
int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
    int64_t start_time, int64_t end_time) const {
  int min_nest_level =
      std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
  int change_idx = while_nest_level_change_[end_time];
  while (change_idx >= start_time) {
    min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]);
    change_idx = while_nest_level_change_[change_idx];
  }
  return min_nest_level;
}

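// Estimated elapsed time between two logical times, computed from the
// precomputed cumulative sums and divided by pow(while_execution_count,
// minimum nest level of the interval) to undo the while-loop weighting.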
float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
    int64_t start_time, int64_t end_time) const {
  CHECK_LE(start_time, end_time);
  if (start_time == end_time) {
    return 0.0;
  }
  if (start_time < 0) {
    start_time = 0;
  }
  // Since elapsed_time_cumsum_ is already weighted by the while loop nesting
  // level, normalize the elapsed time by dividing by the nesting factor of the
  // interval (based on its start and end times).
  int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time);
  return (elapsed_time_cumsum_[end_time - 1] -
          elapsed_time_cumsum_[start_time]) /
         tensorflow::MathUtil::IPow<float>(
             cost_analysis_.options()
                 .xla_tpu_memory_space_assignment_while_execution_count,
             interval_while_nest_level);
}

std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
  int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_
                                          ? increasing_prefetch_time_iterator_
                                          : decreasing_prefetch_time_iterator_;
  float logical_interval_elapsed = GetLogicalIntervalElapsed(
      current_logical_prefetch_time, end_logical_time_);
  return absl::StrCat(
      "Async copy elapsed (s) = ", async_copy_elapsed_,
      ", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
      ", logical interval elapsed (s) = ", logical_interval_elapsed,
      ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_,
      ")");
}

std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
    const Shape& shape, int64_t start_time, int64_t end_time) const {
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  float logical_interval_elapsed =
      GetLogicalIntervalElapsed(start_time, end_time);
  return absl::StrCat(
      "Async copy elapsed (s) = ", async_copy_elapsed,
      ", logical interval elapsed (s) = ", logical_interval_elapsed);
}

absl::optional<float>
CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
    const {
  return cost_analysis_.GetMemoryBoundedness(interval);
}

bool MemorySpaceAssignment::Allocation::operator==(
    const MemorySpaceAssignment::Allocation& other) const {
  return defining_position() == other.defining_position() &&
         uses() == other.uses() && memory_space() == other.memory_space() &&
         chunk() == other.chunk() && start_time() == other.start_time() &&
         end_time() == other.end_time() &&
         is_copy_allocation() == other.is_copy_allocation() &&
         is_scoped_allocation() == other.is_scoped_allocation();
}

bool MemorySpaceAssignment::CopyAllocation::operator==(
    const MemorySpaceAssignment::CopyAllocation& other) const {
  return static_cast<const Allocation&>(*this) ==
             static_cast<const Allocation&>(other) &&
         copy_done_schedule_before() == other.copy_done_schedule_before() &&
         copy_start_schedule_after() == other.copy_start_schedule_after() &&
         copy_start() == other.copy_start() && copy_done() == other.copy_done();
}

std::string MemorySpaceAssignment::AllocationValue::ToString() const {
  std::string out = absl::StrCat("computation = ", computation()->name());
  absl::StrAppend(&out,
                  (requires_contiguous_allocation_ ? " (cont alloc)" : ""));
  absl::StrAppend(&out, "\n position:\n");
  absl::StrAppend(&out, " ", defining_position_.ToString(), "\n");
  absl::StrAppend(&out, " uses:\n");
  for (const Use& use : uses_) {
    absl::StrAppend(&out, " ", use.hlo_use.ToString(), "\n");
  }
  return out;
}

std::string MemorySpaceAssignment::AllocationValue::ToShortString() const {
  return absl::StrCat("computation = ", computation()->name(),
                      ", position = ", defining_position_.ToString(),
                      ", value = ", value_->ToShortString(),
                      (requires_contiguous_allocation_ ? " (cont alloc)" : ""));
}

void AlternateMemoryBestFitHeap::CreateAllocationValues(
    const AlternateMemoryBestFitHeap::BufferInterval& buffer_interval,
    std::vector<AllocationValue>& allocation_values) const {
  const HloValue* value = buffer_interval.buffer;
  VLOG(3) << "Creating AllocationValues for: " << value->ToString();

  // Find and sort all non-trivial (excluding GTE, Tuple, and bitcast)
  // positions. We create an AllocationValue object for each non-trivial
  // position, and for each AllocationValue object we create an
  // AllocationSequence consisting of one or more Allocation objects. The
  // reason we exclude the trivial positions from AllocationValue is that
  // Allocation objects have special support for tuples and bitcasts.
  const absl::flat_hash_map<const HloInstruction*, int64>&
      instruction_schedule = hlo_live_range_.instruction_schedule();
  std::vector<HloPosition> positions;
  for (const HloPosition& position : value->positions()) {
    const HloInstruction* instruction = position.instruction;
    if (instruction->opcode() != HloOpcode::kGetTupleElement &&
        instruction->opcode() != HloOpcode::kTuple &&
        instruction->opcode() != HloOpcode::kBitcast) {
      positions.push_back(position);
    }
  }
  absl::c_stable_sort(positions,
                      [&](const HloPosition& pos1, const HloPosition& pos2) {
                        return instruction_schedule.at(pos1.instruction) <
                               instruction_schedule.at(pos2.instruction);
                      });

  // Create an AllocationValue for each non-trivial position.
  absl::flat_hash_set<const HloComputation*> computations;
  int beginning_idx = allocation_values.size();
  for (int i = 0; i < positions.size(); ++i) {
    const HloPosition& position = positions.at(i);
    allocation_values.emplace_back(value, position, buffer_interval.size);
  }

  std::vector<HloUse> uses(value->uses());
  absl::c_stable_sort(uses, [&](const HloUse& use1, const HloUse& use2) {
    return instruction_schedule.at(use1.instruction) <
           instruction_schedule.at(use2.instruction);
  });

  // Associate each use with an AllocationValue. Each AllocationValue contains a
  // position and uses in the same computation. Furthermore, if the original
  // HloValue had multiple non-trivial positions in the same computation, those
  // will get their own AllocationValue as well. We split these HloValues so
  // that when we insert CopyStart/CopyDone in CopyAllocation::Process, they
  // point to the latest position. We then replace the operand of the use with
  // CopyStart/CopyDone with an operand of the latest position.
  for (const HloUse& use : uses) {
    int64_t use_time = instruction_schedule.at(use.instruction);
    HloComputation* use_computation = use.instruction->parent();

    AllocationValue* last_allocation_value = nullptr;
    for (int i = beginning_idx; i < allocation_values.size(); ++i) {
      AllocationValue* allocation_value = &allocation_values.at(i);
      if (HloDataflowAnalysis::IsAsynchronousOperationDone(
              use.instruction->opcode())) {
        if (allocation_value->defining_instruction() ==
            use.instruction->operand(0)) {
          last_allocation_value = allocation_value;
        }
      } else if (!HloDataflowAnalysis::IsAsynchronousOperationStart(
                     allocation_value->defining_instruction()->opcode()) &&
                 allocation_value->computation() == use_computation &&
                 instruction_schedule.at(
                     allocation_value->defining_position().instruction) <
                     use_time) {
        last_allocation_value = allocation_value;
      }
    }
    CHECK(last_allocation_value != nullptr);
    last_allocation_value->AddUse(use, use_time);
  }

  for (int i = beginning_idx; i < allocation_values.size(); ++i) {
    AllocationValue& allocation_value = allocation_values.at(i);
    if (HloDataflowAnalysis::IsAsynchronousOperationStart(
            allocation_value.defining_instruction()->opcode())) {
      CHECK_EQ(allocation_value.uses().size(), 1);
      CHECK(HloDataflowAnalysis::IsAsynchronousOperationDone(
          allocation_value.uses().at(0).hlo_use.instruction->opcode()));
      VLOG(3) << "Mark " << allocation_value.ToShortString()
              << " to require contiguous allocation.";
      allocation_value.set_requires_contiguous_allocation(true);
    }
    VLOG(3) << "Created allocation value: "
            << allocation_values.at(i).ToString();
  }
}

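// For every use of every AllocationValue, records the defining positions that
// must share the use's memory assignment: the using instruction itself when it
// defines one of the values, the parameters of any called computations, and,
// for kWhile, the root of the while body.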
void AlternateMemoryBestFitHeap::FindAliases(
    std::vector<AllocationValue>* allocation_values) const {
  absl::flat_hash_map<const HloInstruction*,
                      std::vector<const AllocationValue*>>
      values_by_defining_inst;
  for (AllocationValue& value : *allocation_values) {
    values_by_defining_inst[value.defining_instruction()].push_back(&value);
  }
  auto maybe_add_alias_with_instruction = [&](const HloInstruction* instruction,
                                              AllocationValue::Use* use) {
    auto aliased_values_it = values_by_defining_inst.find(instruction);
    if (aliased_values_it != values_by_defining_inst.end()) {
      for (const AllocationValue* aliased_value : aliased_values_it->second) {
        VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString()
                << " to " << aliased_value->ToShortString();
        use->aliases.push_back(aliased_value->defining_position());
      }
    }
  };

  for (AllocationValue& value : *allocation_values) {
    for (AllocationValue::Use& use : value.uses()) {
      // Find any aliases with the instruction itself (operand and output must
      // alias).
      maybe_add_alias_with_instruction(use.hlo_use.instruction, &use);

      // Find any aliases with the parameters of called computations.
      for (const HloComputation* called_computation :
           use.hlo_use.instruction->called_computations()) {
        for (const HloInstruction* parameter_instruction :
             called_computation->parameter_instructions()) {
          maybe_add_alias_with_instruction(parameter_instruction, &use);
        }
      }

      // Special case for kWhile: the root of the body computation must alias as
      // well.
      if (use.hlo_use.instruction->opcode() == HloOpcode::kWhile) {
        HloPosition root_alias{
            use.hlo_use.instruction->while_body()->root_instruction(),
            use.hlo_use.operand_index};
        VLOG(3) << "Adding while body root aliasing for use "
                << use.hlo_use.ToString() << " to " << root_alias;
        use.aliases.push_back(root_alias);
      }
    }
  }
}

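// Gathers the transitive closure of an interval's colocations via a worklist
// and returns them sorted by (start, end) time.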
std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>
AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
    const AlternateMemoryBestFitHeap::BufferInterval& interval) const {
  std::vector<const BufferInterval*> colocated_intervals;
  std::vector<const BufferInterval*> worklist = {&interval};
  while (!worklist.empty()) {
    const BufferInterval* item = worklist.back();
    worklist.pop_back();
    colocated_intervals.push_back(item);
    for (const HloValue* buffer_colocated : item->colocations) {
      worklist.push_back(&buffer_intervals_.at(buffer_colocated));
    }
  }

  absl::c_stable_sort(colocated_intervals, [&](const BufferInterval* x,
                                               const BufferInterval* y) {
    return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end);
  });
  return colocated_intervals;
}

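// Decides whether a particular use may be served from alternate memory. For
// while and conditional uses this checks that the value has a sufficiently
// short live range inside the called computation to be worth keeping there.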
bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
    const AllocationValue& value, const HloUse& use) const {
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
  if (!options_.is_use_allowed_in_alternate_mem_fn(use)) {
    return false;
  }
  if (use.instruction->opcode() == HloOpcode::kWhile) {
    HloComputation* while_body = use.instruction->while_body();

    // We don't want to allocate this buffer in alternate memory if it will be
    // evicted anyway. Find out if it has an early use or a late definition
    // that would make it worthwhile to keep in the alternate memory.
    HloValue* parameter_value =
        &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
            while_body->parameter_instruction(0), use.operand_index);
    int64_t parameter_time =
        instruction_schedule.at(while_body->parameter_instruction(0));
    int64_t root_time = instruction_schedule.at(while_body->root_instruction());
    int64_t min_use_time = root_time;
    for (const HloUse& parameter_use : parameter_value->uses()) {
      int64_t use_time = instruction_schedule.at(parameter_use.instruction);
      if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement &&
          parameter_use.instruction->opcode() != HloOpcode::kTuple &&
          parameter_use.instruction->opcode() != HloOpcode::kBitcast &&
          use_time > parameter_time) {
        min_use_time = std::min(min_use_time, use_time);
      }
    }
    // If there is no use of this buffer inside the while loop, there is no need
    // to allocate it in the loop.
    if (min_use_time == root_time) {
      VLOG(4) << "While allocation not allowed in alternate memory. "
              << "use time = " << min_use_time << ", root time = " << root_time;
      return false;
    }
    const Shape& shape = parameter_value->shape();
    // Allow the buffer in alternate memory if the buffer has a short live range
    // either at the beginning or end of the while loop body.
    if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
            shape, parameter_time, min_use_time)) {
      VLOG(4) << "While allocation not allowed in alternate memory. "
              << "use time = " << min_use_time << ", root time = " << root_time;
      return false;
    }
    // Check if there is a required assignment for the while loop output.
    HloValue* while_value =
        &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
            use.instruction, use.operand_index);
    int64_t while_time = instruction_schedule.at(use.instruction);
    auto existing_required_assignment =
        RequiredMemoryAssignmentAt(while_value, while_time);
    if (existing_required_assignment &&
        existing_required_assignment->memory_space == MemorySpace::kDefault) {
      VLOG(4) << "While allocation not allowed in alternate memory because "
                 "there is a required default memory assignment.";
      return false;
    }
  } else if (use.instruction->opcode() == HloOpcode::kConditional) {
    // For any use of this conditional (the same value might be passed into
    // multiple called computations), determine if the parameter->first use
    // dependency is short.
    int64_t conditional_time = instruction_schedule.at(use.instruction);
    for (const AllocationValue::Use& other_use : value.uses()) {
      if (other_use.hlo_use.instruction != use.instruction) {
        continue;
      }
      HloComputation* called_computation =
          use.instruction->called_computations().at(
              other_use.hlo_use.operand_number - 1);
      const HloInstruction* parameter_instruction =
          called_computation->parameter_instruction(0);
      HloValue* parameter_value =
          &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
              parameter_instruction, other_use.hlo_use.operand_index);
      int64_t parameter_time = instruction_schedule.at(parameter_instruction);
      int64_t min_use_time = conditional_time;
      for (const HloUse& parameter_use : parameter_value->uses()) {
        if (parameter_use.instruction->parent() == called_computation &&
            parameter_use.instruction->opcode() !=
                HloOpcode::kGetTupleElement &&
            parameter_use.instruction->opcode() != HloOpcode::kTuple &&
            parameter_use.instruction->opcode() != HloOpcode::kBitcast) {
          min_use_time = std::min(
              min_use_time, instruction_schedule.at(parameter_use.instruction));
        }
      }
      if (options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
              parameter_value->shape(), parameter_time, min_use_time)) {
        VLOG(4) << "Conditional allocation allowed in alternate memory for "
                   "computation = "
                << called_computation->name()
                << ", parameter time = " << parameter_time
                << ", min use time = " << min_use_time;
        return true;
      } else {
        VLOG(4) << "Conditional allocation not allowed in alternate memory for "
                   "computation = "
                << called_computation->name()
                << ", parameter time = " << parameter_time
                << ", min use time = " << min_use_time;
      }
    }
    return false;
  }

  return true;
}

void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString(
    const AlternateMemoryBestFitHeap::BufferInterval& interval,
    std::string* debug_str) const {
  // Columns in buffer information:
  // buffer_id: int. This value can be used to match the allocation in
  // allocation information.
  // buffer_name: string.
  // alt_mem_benefit: float. Roughly corresponds to how much the cost analysis
  // thought it would be beneficial to put this in the alternate memory. The
  // higher the value, the more it is memory bound.
  // size: int. In bytes.
  // definition_time: int. Logical time this value was defined in the schedule.
  // use_times: string. This is a semicolon-separated list of integers for all
  // the use times.
  // use_names: string. This is a semicolon-separated list of string
  // representations of the uses.
  if (debug_str->empty()) {
    // Append the column names.
    absl::StrAppend(debug_str,
                    "buffer_id,buffer_name,alt_mem_benefit,size,"
                    "definition_time,use_times,use_names\n");
  }
  const HloBuffer& buffer =
      alias_analysis_.GetBufferContainingValue(*interval.buffer);
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
  int64_t definition_time =
      instruction_schedule.at(interval.buffer->defining_position().instruction);
  std::vector<std::pair<int64, std::string>> uses;
  for (const HloValue* value : buffer.values()) {
    for (const HloUse& use : value->uses()) {
      uses.push_back(
          {instruction_schedule.at(use.instruction), use.ToString()});
    }
  }
  absl::c_sort(uses);
  std::vector<int64> use_times;
  std::vector<std::string> use_names;
  use_times.reserve(uses.size());
  use_names.reserve(uses.size());
  for (const auto& use : uses) {
    use_times.push_back(use.first);
    use_names.push_back(use.second);
  }

  absl::StrAppend(debug_str, buffer.id(), ",");
  absl::StrAppend(debug_str, "\"", interval.buffer->ToShortString(), "\",");
  auto alternate_memory_benefit =
      options_.prefetch_interval_picker->BufferIntervalAlternateMemoryBenefit(
          interval);
  absl::StrAppend(
      debug_str, alternate_memory_benefit ? *alternate_memory_benefit : 0, ",");
  absl::StrAppend(debug_str, interval.size, ",");
  absl::StrAppend(debug_str, definition_time, ",");
  absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\",");
  absl::StrAppend(debug_str, "\"", absl::StrJoin(use_names, ";"), "\"");
  absl::StrAppend(debug_str, "\n");
}

void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString(
    const AllocationValue& value,
    const MemorySpaceAssignment::Allocation& allocation,
    std::string& debug_str) const {
  // Columns in allocation information:
  // buffer_id: int. This value can be used to match with the buffer info.
  // size: int. In bytes.
  // offset: int. In bytes.
  // start_time: int. Logical start time of the allocation.
  // end_time: int. Logical end time of the allocation.
  if (debug_str.empty()) {
    // Append the column names.
    absl::StrAppend(&debug_str, "buffer_id,size,offset,start_time,end_time\n");
  }
  if (allocation.memory_space() == MemorySpace::kAlternate) {
    const HloBuffer& buffer =
        alias_analysis_.GetBufferContainingValue(*value.value());
    absl::StrAppend(&debug_str, buffer.id(), ",");
    absl::StrAppend(&debug_str, value.size(), ",");
    absl::StrAppend(&debug_str, allocation.chunk().offset, ",");
    absl::StrAppend(&debug_str, allocation.start_time(), ",");
    absl::StrAppend(&debug_str, allocation.end_time(), "\n");
  }
}

void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
  if (!options_.dump_fn) {
    return;
  }
  options_.dump_fn("bufferinfo", buffer_info_str_);
  options_.dump_fn("allocinfo", allocation_info_str_);
}

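// Main entry point of the heap algorithm. At a high level it: (1) reserves
// scoped allocations, (2) optionally sets up a cross-program prefetch, (3)
// walks the sorted buffer intervals and, for each eligible interval, tries to
// allocate its AllocationValues in alternate memory, retrying with relaxed
// limits and optionally repacking on failure, and (4) finalizes or uncommits
// the pending allocations accordingly.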
HeapSimulator::Result<HloValue> AlternateMemoryBestFitHeap::Finish() {
  AllocateReservedScopedAllocations();
  if (options_.enable_cross_program_prefetch) {
    absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
        prefetch_candidate = FindCrossProgramPrefetchCandidate(
            alias_analysis_, hlo_live_range_, options_);
    if (prefetch_candidate) {
      HloModule* module =
          prefetch_candidate->buffer->instruction()->GetModule();
      AllocateCrossProgramPrefetchBuffer(module, prefetch_candidate);
    }
  }

  std::vector<BufferInterval> sorted_buffer_intervals =
      GetSortedBufferIntervals();

  VLOG(1) << "Assigning buffers to alternate memory. Max heap size = "
          << options_.max_size_in_bytes;

  AddInputAndOutputRequiredAssignments();

  if (VLOG_IS_ON(3)) {
    VLOG(3) << "Flattened instruction sequence:";
    const auto& instruction_sequence =
        hlo_live_range_.flattened_instruction_sequence().instructions();
    for (int i = 0; i < instruction_sequence.size(); ++i) {
      VLOG(3) << " " << i << ": " << instruction_sequence[i]->parent()->name()
              << " " << instruction_sequence[i]->name();
    }
  }

  for (const auto& interval : sorted_buffer_intervals) {
    auto colocated_intervals = GetSortedColocatedIntervals(interval);
    if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
      // Increment the reserved part of alternate memory so that it is not
      // available for other buffers.
      reserved_in_bytes_ += options_.size_fn(*interval.buffer);
    }
  }
  VLOG(2) << "Total reserved bytes = " << reserved_in_bytes_;

  for (auto& interval : sorted_buffer_intervals) {
    if (!interval.need_allocation) {
      continue;
    }

    if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
            interval)) {
      continue;
    }

    HloInstruction* inst = interval.buffer->instruction();
    HloModule* module = inst->GetModule();

    // Don't intra-program prefetch a buffer that is cross-program prefetched.
    if (inst->opcode() == HloOpcode::kParameter &&
        absl::c_count(module->CrossProgramPrefetches(),
                      std::make_pair(inst->parameter_number(),
                                     interval.buffer->index())) > 0) {
      VLOG(3) << "Skip " << interval.buffer->ToShortString()
              << " because it is cross-program prefetched.";
      continue;
    }

    if (interval.size > available_heap_size()) {
      VLOG(3) << "Skip " << interval.buffer->ToShortString()
              << " because the buffer is larger than the heap size.";
      continue;
    }

    auto colocated_intervals = GetSortedColocatedIntervals(interval);

    if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
      VLOG(3) << "Interval " << interval.buffer->ToShortString()
              << " is reserved in the alternate memory.";
      for (const BufferInterval* colocated_interval : colocated_intervals) {
        const HloValue* value = colocated_interval->buffer;
        // Color all of the aliased reserved buffers here because reserved
        // alternate memory allocations will not have an entry in preset
        // allocations that is normally used for coloring.
        for (auto& position : value->positions()) {
          VLOG(4) << "Coloring " << position.ToString();
          Shape* shape = ShapeUtil::GetMutableSubshape(
              position.instruction->mutable_shape(), position.index);
          CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
                                  << position.ToString();
          shape->mutable_layout()->set_memory_space(
              options_.alternate_memory_space);
        }
      }
      continue;
    }

    if (colocated_intervals.size() > 1 &&
        !options_.allocate_across_sequential_calls) {
      VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
              << " because it aliases with another interval and "
              << "allocate_across_sequential_calls is false.";
      continue;
    }

1227
1228 if (!ConsumeFuel("memory_space_assignment", [&] {
1229 return absl::StrCat("Ran out of fuel at buffer: ",
1230 colocated_intervals[0]->buffer->ToShortString());
1231 })) {
1232 continue;
1233 }
1234
1235 AppendBufferInfoDebugString(interval, &buffer_info_str_);
1236
1237 std::vector<AllocationValue> allocation_values;
1238 CreateAllocationValuesFromColocatedIntervals(colocated_intervals,
1239 allocation_values);
1240
1241 // Retry allocating this value with larger limits if allocation fails.
1242 bool repacked = false;
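// Each retry may relax the prefetch interval picker via SetRetryNumber. If a
// retry runs out of alternate memory, we may also repack at most once per
// buffer (and at most options_.max_repacks times overall); a successful
// repack re-runs the same retry number with the imported offsets.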
1243 for (int retry_number = 0; retry_number < options_.max_retries;
1244 retry_number++) {
1245 AddRequiredAssignmentsForColocatedIntervals(colocated_intervals);
1246 bool final_retry = (retry_number == options_.max_retries - 1);
1247 options_.prefetch_interval_picker->SetRetryNumber(retry_number);
1248 Result result =
1249 AllocateAllocationValues(absl::MakeSpan(allocation_values));
1250 VLOG(2) << "Allocation result = "
1251 << absl::StrFormat("%x", static_cast<int>(result));
1252 if (result_requires_uncommit(result) ||
1253 (!final_retry && result_failed_because_of_async_copy(result))) {
1254 UncommitPendingChunks(absl::MakeSpan(allocation_values));
1255 VLOG(2) << "Couldn't allocate. Retry number " << retry_number;
1256 } else if ((result_is(result, Result::kFailOutOfMemory) ||
1257 options_.repack_after_every_allocation) &&
1258 num_repacks_ < options_.max_repacks && !repacked) {
1259 UncommitPendingChunks(absl::MakeSpan(allocation_values));
1260 ++num_repacks_;
1261 repacked = true;
1262 CHECK_NE(options_.repacker, nullptr);
1263 std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>
1264 repack_allocation_blocks;
1265 ExportAllocationsForRepacking(repack_allocation_blocks);
1266 VLOG(2) << "Repacking.";
1267 auto repack_status =
1268 options_.repacker->Repack(absl::MakeSpan(repack_allocation_blocks));
1269 CHECK_EQ(repack_status.status(), Status::OK());
1270 VLOG(2) << "Repack complete. Modified = " << *repack_status;
1271 if (*repack_status) {
1272 ImportRepackedAllocations();
1273 --retry_number;
1274 }
1275 } else {
1276 FinalizeAllocations(absl::MakeSpan(allocation_values));
1277 break;
1278 }
1279 }
1280 }
1281
1282 VLOG(3) << "Debug buffer info: ";
1283 XLA_VLOG_LINES(3, buffer_info_str_);
1284 VLOG(3) << "Debug allocation info: ";
1285 XLA_VLOG_LINES(3, allocation_info_str_);
1286 DumpDebugStringsIfEnabled();
1287
1288 HeapSimulator::Result<HloValue> result;
1289 result.heap_size = result_.heap_size;
1290 result.heap_results.emplace_back(std::move(result_));
1291 return result;
1292 }
1293
1294 void AlternateMemoryBestFitHeap::AddRequiredAssignmentsForColocatedIntervals(
1295 absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1296 colocated_intervals) {
1297 // TODO(berkin): For now, place the phi values due to conditionals in
1298 // default memory.
1299 for (const BufferInterval* colocated_interval : colocated_intervals) {
1300 const HloValue* value = colocated_interval->buffer;
1301 for (const auto& position : value->positions()) {
1302 if (position.instruction->opcode() == HloOpcode::kConditional) {
1303 VLOG(3) << "Adding required assignment for condition output: "
1304 << value->ToShortString();
1305 AddRequiredAssignment(position.instruction, position.index,
1306 MemorySpace::kDefault);
1307 for (const HloComputation* called_computation :
1308 position.instruction->called_computations()) {
1309 AddRequiredAssignment(called_computation->root_instruction(),
1310 position.index, MemorySpace::kDefault);
1311 }
1312 }
1313 }
1314 }
1315 }
1316
1317 void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals(
1318 absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1319 colocated_intervals,
1320 std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values) {
1321 // Create AllocationValues for all the colocated intervals.
1322 for (const auto& colocated_interval : colocated_intervals) {
1323 CreateAllocationValues(*colocated_interval, allocation_values);
1324 }
1325 // Go through the AllocationValues and delete the ones that have identical
1326 // defining instructions and use instructions. This is useful for async
1327 // operations that can read and write to the same buffer, e.g., in-place
1328 // asynchronous collective permute. The AllocationValues that correspond to
1329 // collective-permute-start{0} (the input) and collective-permute-start{1}
1330 // (the output) refer to the same buffer by definition (since they are created
1331 // from colocated intervals). If we don't delete one of these buffers, then
1332 // when we try to allocate the AllocationValue, we would think they overlap.
1333 auto create_instruction_vector = [](const AllocationValue& allocation_value) {
1334 std::vector<const HloInstruction*> instruction_vector;
1335 instruction_vector.push_back(allocation_value.defining_instruction());
1336 for (const AllocationValue::Use& use : allocation_value.uses()) {
1337 instruction_vector.push_back(use.hlo_use.instruction);
1338 }
1339 return instruction_vector;
1340 };
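// Two AllocationValues are considered equivalent below exactly when their
// defining instruction and the ordered list of their use instructions match;
// the comparison relies on this vector representation.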
1341 for (int i = 0; i < allocation_values.size() - 1; ++i) {
1342 for (int j = i + 1; j < allocation_values.size(); ++j) {
1343 const AllocationValue& allocation_value_1 = allocation_values[i];
1344 const AllocationValue& allocation_value_2 = allocation_values[j];
1345 if (create_instruction_vector(allocation_value_1) ==
1346 create_instruction_vector(allocation_value_2)) {
1347 VLOG(3) << "Allocation values " << allocation_value_1.ToShortString()
1348 << " and " << allocation_value_2.ToShortString()
1349 << " are equivalent, deleting the second one.";
1350 allocation_values.erase(allocation_values.begin() + j);
1351 --j;
1352 }
1353 }
1354 }
1355
1356 FindAliases(&allocation_values);
1357 }
1358
1359 AlternateMemoryBestFitHeap::Result
1360 AlternateMemoryBestFitHeap::AllocateAllocationValues(
1361 absl::Span<MemorySpaceAssignment::AllocationValue> allocation_values) {
1362 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1363
1364 // Find the use times across all of the related AllocationValues and sort
1365 // them. We use these to find allocations that are available throughout the
1366 // entire live range of all the AllocationValues.
1367 std::vector<int64_t> all_use_times;
1368 for (const AllocationValue& allocation_value : allocation_values) {
1369 absl::c_transform(allocation_value.uses(),
1370 std::back_inserter(all_use_times),
1371 [](const AllocationValue::Use& use) { return use.time; });
1372 }
1373 absl::c_sort(all_use_times);
1374
1375 // Data structure to contain the preferred offset for a given computation.
1376 // We ensure that the same offset will be allocated outside the while loop
1377 // as well as inside the while loop.
1378 absl::flat_hash_map<const HloComputation*, AliasedOffset*>
1379 preferred_offset_for_computation;
1380
1381 Result result = Result::kSuccess;
1382 for (AllocationValue& allocation_value : allocation_values) {
1383 int64_t definition_time =
1384 instruction_schedule.at(allocation_value.defining_instruction());
1385
1386 AliasedOffset* preferred_offset = nullptr;
1387 auto preferred_offset_it =
1388 preferred_offset_for_computation.find(allocation_value.computation());
1389 if (preferred_offset_it != preferred_offset_for_computation.end()) {
1390 preferred_offset = preferred_offset_it->second;
1391 }
1392
1393 // Iterate over the uses.
1394 for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
1395 const AllocationValue::Use& use = allocation_value.uses().at(use_idx);
1396 const HloUse hlo_use = use.hlo_use;
1397 int64_t use_time = instruction_schedule.at(hlo_use.instruction);
1398 int64_t latest_prefetch_time = use_time;
1399 bool allow_no_copy_alternate_mem_allocation = true;
1400 absl::optional<int64> earliest_prefetch_time = absl::nullopt;
1401
1402 // Sequential calls include kWhile, kCall, and kConditional opcodes.
1403 bool is_sequential_call =
1404 (GetInstructionCallContext(hlo_use.instruction->opcode()) ==
1405 CallContext::kSequential);
1406 if (is_sequential_call) {
1407 for (const HloComputation* called_computation :
1408 hlo_use.instruction->called_computations()) {
1409 const HloLiveRange::TimeBound& computation_span =
1410 hlo_live_range_.computation_span_times().at(called_computation);
1411 latest_prefetch_time =
1412 std::min(computation_span.start - 1, latest_prefetch_time);
1413 }
1414 if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
1415 // Given an example while loop and flattened schedule (logical times
1416 // shown on the left):
1417 //
1418 // 0: a = ...
1419 // 1: ...
1420 // cond {
1421 // 2: p = param(0)
1422 // 3: ...
1423 // }
1424 // body {
1425 // 4: p = param(0)
1426 // 5: ...
1427 // 6: ROOT ...
1428 // }
1429 // 7: w = while(a), body=body, cond=cond
1430 //
1431 // When processing "a" (time 0) and its while use (time 7), we update
1432 // the interval to time 0-4. This is so that the remaining interval
1433 // (5-6) can be allocated separately and this buffer doesn't waste
1434 // alternate memory space within the while loop body.
1435 HloComputation* while_body = hlo_use.instruction->while_body();
1436 // We require while body ROOTs to be the last in the schedule.
1437 CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
1438 instruction_schedule.at(hlo_use.instruction))
1439 << "While body ROOTs need to be the last in the schedule! "
1440 "Please run RootInstructionSinker.";
1441 // Replace the use time with the parameter time so that we can decide
1442 // on alternate memory allocations within the while loop body when we
1443 // look at uses within the while loop body.
1444 use_time =
1445 instruction_schedule.at(while_body->parameter_instruction(0));
1446 } else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) {
1447 // Replace the use time with the earliest parameter of called
1448 // computations.
1449 for (const HloComputation* called_computation :
1450 hlo_use.instruction->called_computations()) {
1451 use_time = std::min(
1452 use_time, instruction_schedule.at(
1453 called_computation->parameter_instruction(0)));
1454 }
1455 }
1456 }
1457
1458 // Add a required assignment in default memory if the use is not allowed in
1459 // alternate memory.
1460 if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) {
1461 AddRequiredAssignment(allocation_value.value(), hlo_use.instruction,
1462 MemorySpace::kDefault, use_time);
1463 } else if (use_idx > 0) {
1464 // We allow buffers in alternate memory that are passed into
1465 // conditionals to give up their alternate memory allocation inside the
1466 // called computation. This means that if a conditional operator has an
1467 // alternate memory allocation, subsequent uses cannot use the same
1468 // alternate memory allocation in order not to clobber data. So we force
1469 // default memory allocation for these subsequent uses.
1470 const AllocationValue::Use& previous_use =
1471 allocation_value.uses().at(use_idx - 1);
1472 if (previous_use.hlo_use.instruction->opcode() ==
1473 HloOpcode::kConditional &&
1474 previous_use.hlo_use.instruction != hlo_use.instruction) {
1475 allow_no_copy_alternate_mem_allocation = false;
1476 earliest_prefetch_time =
1477 instruction_schedule.at(previous_use.hlo_use.instruction);
1478 VLOG(3) << "Previous use (" << previous_use.hlo_use.ToString()
1479 << ") of use (" << hlo_use.ToString()
1480 << ") is a conditional, so this use will need to evict. "
1481 << "Earliest prefetch time = " << *earliest_prefetch_time;
1482 }
1483 }
1484
1485 // Bitcasts don't define buffers and don't directly consume buffers. Skip
1486 // allocating buffers for bitcast uses (unless they are the root
1487 // instruction). The uses that feed from bitcasts will be handled
1488 // specially.
1489 if (hlo_use.instruction->opcode() != HloOpcode::kBitcast ||
1490 hlo_use.instruction ==
1491 hlo_use.instruction->parent()->root_instruction()) {
1492 AllocationRequest request;
1493 // Rarely (e.g., when the conditional's true and false parameters are the
1494 // same), the definition time can be the time of the conditional and the
1495 // use time can be that of the parameter use, which is earlier.
1496 request.start_time = std::min(definition_time, use_time);
1497 request.end_time = use_time;
1498 request.latest_prefetch_time = latest_prefetch_time;
1499 request.size = allocation_value.size();
1500 request.allow_no_copy_alternate_mem_allocation =
1501 allow_no_copy_alternate_mem_allocation;
1502 request.earliest_prefetch_time = earliest_prefetch_time;
1503 request.preferred_offset = preferred_offset;
1504 request.use = &use;
1505 request.allocation_value = &allocation_value;
1506 request.all_use_times = all_use_times;
1507 result_mark(AllocateSegment(request), result);
1508 if (result_requires_uncommit(result)) {
1509 // If the allocation finding failed (e.g., due to running out of
1510 // asynchronous copies), then fall back to allocating the buffer
1511 // entirely in the default memory.
1512 return result;
1513 }
1514
1515 // If there are multiple uses, they can try using the memory allocation
1516 // that is already in the alternate memory.
1517 definition_time = instruction_schedule.at(hlo_use.instruction);
1518 }
1519
1520 // Propagate the allocation to any aliases this use might have had.
1521 MemorySpaceAssignment::Allocation* aliased_allocation =
1522 GetLiveAllocationAt(*allocation_value.allocation_sequence(),
1523 use_time);
1524 for (const HloPosition& aliased_position : use.aliases) {
1525 AddAliasedRequiredAssignment(aliased_position.instruction,
1526 aliased_position.index,
1527 aliased_allocation);
1528 }
1529
1530 if (hlo_use.instruction->opcode() == HloOpcode::kWhile &&
1531 aliased_allocation->memory_space() == MemorySpace::kAlternate) {
1532 // For while uses that are allocated in the alternate memory space, if
1533 // they also have an allocation in the default memory space in their
1534 // allocation sequence, create a "parent" allocation that mirrors this
1535 // default memory space allocation. When we process the parent
1536 // allocation, we add an additional parameter to the while that is a
1537 // reference to the buffer in the default memory space. With parent
1538 // allocations, we don't need to unnecessarily evict buffers since they
1539 // already have a copy in the default memory space. We search backwards
1540 // (latest to earliest in execution time) for a suitable allocation in
1541 // order to find the most recent one.
1542 if (absl::c_find_if(allocation_value.value()->positions(),
1543 [&hlo_use](const HloPosition& position) {
1544 return position.instruction ==
1545 hlo_use.instruction &&
1546 position.index == hlo_use.operand_index;
1547 }) != allocation_value.value()->positions().end()) {
1548 auto allocation_sequence = allocation_value.allocation_sequence();
1549 auto prev_allocation_in_default_mem_it = std::find_if(
1550 allocation_sequence->rbegin(), allocation_sequence->rend(),
1551 [&](const auto& allocation) {
1552 return allocation->memory_space() == MemorySpace::kDefault &&
1553 allocation->defining_position() ==
1554 allocation_value.defining_position();
1555 });
1556 if (prev_allocation_in_default_mem_it !=
1557 allocation_sequence->rend()) {
1558 VLOG(3) << "Found a prev allocation in default mem for while use: "
1559 << (*prev_allocation_in_default_mem_it)->ToString();
1560 auto body_allocation_value_it = absl::c_find_if(
1561 allocation_values, [&](const AllocationValue& value) {
1562 return value.computation() ==
1563 hlo_use.instruction->while_body() &&
1564 value.defining_instruction()->opcode() ==
1565 HloOpcode::kParameter;
1566 });
1567 CHECK_NE(body_allocation_value_it, allocation_values.end());
1568 VLOG(3) << "Body allocation value: "
1569 << body_allocation_value_it->ToShortString();
1570 int64_t body_parameter_time = instruction_schedule.at(
1571 body_allocation_value_it->defining_instruction());
1572 body_allocation_value_it->allocation_sequence()->push_back(
1573 absl::make_unique<MemorySpaceAssignment::ParentAllocation>(
1574 **prev_allocation_in_default_mem_it, hlo_use.instruction,
1575 body_allocation_value_it->defining_position(),
1576 body_parameter_time));
1577 VLOG(3) << "Created: "
1578 << body_allocation_value_it->allocation_sequence()
1579 ->back()
1580 ->ToString();
1581 }
1582 }
1583 // Special case for while loops since the root offset must agree with
1584 // other offsets: remember the preferred offset for the while loop body.
1585 preferred_offset_for_computation[hlo_use.instruction->while_body()] =
1586 GetAliasedOffset(*aliased_allocation);
1587 }
1588 }
1589 }
1590 return result;
1591 }
1592
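// Defines a strict partial order on asynchronous copies: a < b only when a
// both starts no later and ends no later than b, with at least one strict
// inequality. For example (illustrative), (start=3, end=8) < (start=5,
// end=10), while (start=3, end=10) and (start=5, end=8) are incomparable and
// are therefore treated as equivalent by the ordered set of pending copies
// (ranges_) in AsynchronousCopyOrdering.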
1593 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
1594 return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
1595 (a.start_time <= b.start_time && a.end_time < b.end_time);
1596 }
1597
1598 void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) {
1599 auto it_and_inserted = ranges_.insert(copy);
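// The insert may fail only when an equivalent copy (per the partial order
// above) is already present, and in that case it must have the identical
// start time.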
1600 CHECK(it_and_inserted.second ||
1601 it_and_inserted.first->start_time == copy.start_time);
1602 }
1603
1604 void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) {
1605 auto copy_it = ranges_.find(copy);
1606 CHECK(copy_it != ranges_.end());
1607 ranges_.erase(copy_it);
1608 }
1609
1610 absl::optional<AsynchronousCopy> AsynchronousCopyOrdering::ViolatesOrdering(
1611 int64_t start_time, int64_t end_time) const {
1612 // We allow identical start and end times. It is enough to check for just the
1613 // start time in case we find a match in ranges_ because the found value will
1614 // either be identical to {start_time, end_time} (and this doesn't violate) or
1615 // its start_time will be smaller and end_time will be larger (this violates).
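// For example (illustrative): if ranges_ already contains a copy with
// (start=2, end=8), a query for (start=4, end=6) finds it as equivalent;
// since the found copy's start time differs from the query's, a violation is
// reported.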
1616 auto copy_it = ranges_.find(
1617 {start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate});
1618 if (copy_it != ranges_.end() && copy_it->start_time != start_time) {
1619 VLOG(4) << "Violates ordering: (" << start_time << ", " << end_time
1620 << ") and (" << copy_it->start_time << ", " << copy_it->end_time
1621 << ")";
1622 return *copy_it;
1623 }
1624 return absl::nullopt;
1625 }
1626
1627 AlternateMemoryBestFitHeap::AliasedOffset*
1628 AlternateMemoryBestFitHeap::GetAliasedOffset(
1629 const MemorySpaceAssignment::Allocation& allocation) {
1630 auto aliased_offset_it = aliased_offset_map_.find(&allocation);
1631 CHECK(aliased_offset_it != aliased_offset_map_.end());
1632 return aliased_offset_it->second;
1633 }
1634
1635 void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset(
1636 const MemorySpaceAssignment::Allocation& allocation,
1637 AlternateMemoryBestFitHeap::AliasedOffset* aliased_offset) {
1638 CHECK(allocation.memory_space() == MemorySpace::kAlternate);
1639 CHECK(!aliased_offset_map_.contains(&allocation));
1640 if (!aliased_offset) {
1641 aliased_offsets_.push_back({allocation.chunk().offset});
1642 aliased_offset = &aliased_offsets_.back();
1643 }
1644 CHECK_EQ(allocation.chunk().offset, aliased_offset->offset);
1645 CHECK(aliased_offset->allocations.insert(&allocation).second);
1646 aliased_offset_map_[&allocation] = aliased_offset;
1647 }
1648
1649 /*static*/ MemorySpaceAssignment::Allocation*
1650 AlternateMemoryBestFitHeap::GetLiveAllocationAt(
1651 const MemorySpaceAssignment::AllocationSequence& allocations,
1652 int64_t time) {
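// Walk the sequence in reverse so that the most recently added allocation
// whose live range covers `time` is returned; nullptr is returned if no
// allocation is live at `time`.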
1653 for (auto allocation_it = allocations.rbegin();
1654 allocation_it != allocations.rend(); ++allocation_it) {
1655 if ((*allocation_it)->start_time() <= time &&
1656 (*allocation_it)->end_time() >= time) {
1657 return allocation_it->get();
1658 }
1659 }
1660 return nullptr;
1661 }
1662
1663 void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
1664 HloModule* module, absl::optional<BufferInterval> prefetch_candidate) {
1665 if (!prefetch_candidate) {
1666 return;
1667 }
1668
1669 ChunkCandidate chunk_candidate = FindChunkCandidate(*prefetch_candidate);
1670 if (chunk_candidate.chunk.offset != 0 ||
1671 chunk_candidate.heap_size > available_heap_size()) {
1672 LOG(WARNING)
1673 << "Could not allocate preferred memory for cross program prefetch";
1674 return;
1675 }
1676
1677 const HloValue* buffer = prefetch_candidate->buffer;
1678 int64_t parameter = buffer->instruction()->parameter_number();
1679 module->AddCrossProgramPrefetch(parameter, buffer->index());
1680
1681 MemorySpaceAssignment::AllocationSequence allocations;
1682 allocations.push_back(absl::make_unique<MemorySpaceAssignment::Allocation>(
1683 buffer->defining_position(), MemorySpace::kDefault, kDummyChunk,
1684 prefetch_candidate->start, prefetch_candidate->end,
1685 /*is_scoped_allocation=*/false));
1686
1687 // Find the earliest use.
1688 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1689 auto uses = buffer->uses();
1690 auto use_schedule_compare = [&](const HloUse& lhs, const HloUse& rhs) {
1691 return instruction_schedule.at(lhs.instruction) <
1692 instruction_schedule.at(rhs.instruction);
1693 };
1694 auto first_use = absl::c_min_element(uses, use_schedule_compare);
1695 int64_t latest_prefetch_time =
1696 instruction_schedule.at(first_use->instruction);
1697
1698 // Find the latest use time.
1699 int64_t last_use_time = instruction_schedule.at(
1700 absl::c_max_element(uses, use_schedule_compare)->instruction);
1701 for (const HloValue* colocation : prefetch_candidate->colocations) {
1702 last_use_time = std::max(
1703 last_use_time,
1704 instruction_schedule.at(
1705 absl::c_max_element(colocation->uses(), use_schedule_compare)
1706 ->instruction));
1707 }
1708
1709 int64_t end_of_program_prefetch_end_time = instruction_schedule.size();
1710 int64_t end_of_program_prefetch_start_time =
1711 options_.prefetch_interval_picker->PreferredPrefetchStartTime(
1712 buffer->defining_position().shape(), last_use_time,
1713 end_of_program_prefetch_end_time, end_of_program_prefetch_end_time);
1714 VLOG(2) << "last use time = " << last_use_time
1715 << ", end-of-program prefetch start time = "
1716 << end_of_program_prefetch_start_time;
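// If the preferred end-of-program prefetch could start strictly after the
// last use of the cross-program-prefetched buffer (and freeing is enabled),
// the buffer can be freed after its last use and re-prefetched near the end
// of the program instead of occupying alternate memory for the whole program.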
1717 bool free_buffer =
1718 (options_.enable_cross_program_prefetch_freeing &&
1719 end_of_program_prefetch_start_time > last_use_time &&
1720 end_of_program_prefetch_start_time < end_of_program_prefetch_end_time);
1721 int64_t cross_program_prefetch_end_time =
1722 free_buffer ? last_use_time : prefetch_candidate->end;
1723
1724 AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate,
1725 chunk_candidate.chunk, prefetch_candidate->start,
1726 cross_program_prefetch_end_time, latest_prefetch_time,
1727 &allocations, /*aliased_offset=*/nullptr,
1728 /*is_cross_program_prefetch=*/true);
1729 absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); });
1730 AliasedOffset* cross_program_prefetch_offset =
1731 GetAliasedOffset(*allocations.back());
1732
1733 if (free_buffer) {
1734 VLOG(2) << "Adding an end-of-program prefetch for freed "
1735 "cross-program-prefetched buffer.";
1736 AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate,
1737 chunk_candidate.chunk, end_of_program_prefetch_start_time,
1738 end_of_program_prefetch_end_time,
1739 end_of_program_prefetch_end_time, &allocations,
1740 cross_program_prefetch_offset);
1741 CHECK_EQ(cross_program_prefetch_offset->offset,
1742 allocations.back()->chunk().offset);
1743 }
1744
1745 const int allocations_initial_size = allocations_->size();
1746 for (auto& allocation : allocations) {
1747 if (allocation->memory_space() == MemorySpace::kAlternate) {
1748 BufferInterval buffer_interval;
1749 buffer_interval.start = allocation->start_time();
1750 buffer_interval.end = allocation->end_time();
1751 buffer_interval.size = allocation->chunk().size;
1752 buffer_interval.buffer = prefetch_candidate->buffer;
1753 AddToPendingChunks(buffer_interval, chunk_candidate);
1754 }
1755 allocations_->push_back(std::move(allocation));
1756 }
1757
1758 // Add a repack allocation block for the Allocation objects in alternate
1759 // memory.
1760 for (int i = allocations_initial_size; i < allocations_->size(); ++i) {
1761 const auto& allocation = allocations_->at(i);
1762 if (allocation->memory_space() == MemorySpace::kAlternate) {
1763 repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
1764 allocation->start_time(), allocation->end_time(),
1765 allocation->chunk().size, allocation->chunk().offset,
1766 static_cast<int64>(repack_allocation_blocks_.size()),
1767 allocation.get()));
1768 RepackAllocationBlock* inserted = &repack_allocation_blocks_.back();
1769 for (RepackAllocationBlock& colocation : repack_allocation_blocks_) {
1770 colocation.colocations.push_back(inserted);
1771 if (&colocation != inserted) {
1772 inserted->colocations.push_back(&colocation);
1773 }
1774 }
1775 }
1776 }
1777
1778 ClearPendingChunks();
1779 }
1780
1781 void AlternateMemoryBestFitHeap::AllocateReservedScopedAllocations() {
1782 const auto& instruction_sequence =
1783 hlo_live_range_.flattened_instruction_sequence().instructions();
1784 std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
1785 for (int i = 0; i < instruction_sequence.size(); ++i) {
1786 int64_t reserved_scoped_memory =
1787 options_.reserved_scoped_memory_fn(instruction_sequence[i]);
1788 if (reserved_scoped_memory != 0) {
1789 VLOG(1) << "Allocate reserved scoped memory at " << i << " ("
1790 << instruction_sequence[i]->name()
1791 << "): " << reserved_scoped_memory;
1792 MemorySpaceAssignment::BufferInterval interval;
1793 interval.buffer = nullptr;
1794 interval.size = reserved_scoped_memory;
1795 interval.start = i;
1796 interval.end = i;
1797 interval.need_allocation = true;
1798 interval.colocations = {};
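// Scoped allocations are pinned to offset 0: the chunk search is seeded with
// a preferred offset of 0 and the result is CHECKed below.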
1799 ChunkCandidate chunk_candidate =
1800 FindChunkCandidate(interval, /*preferred_offset=*/0);
1801 CHECK_EQ(chunk_candidate.chunk.offset, 0);
1802 AddToPendingChunks(interval, chunk_candidate);
1803
1804 allocations_->push_back(
1805 absl::make_unique<MemorySpaceAssignment::Allocation>(
1806 HloPosition{instruction_sequence[i], {}}, MemorySpace::kAlternate,
1807 chunk_candidate.chunk, i, i, /*is_scoped_allocation=*/true));
1808
1809 repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
1810 i, i, reserved_scoped_memory,
1811 /*initial_offset=*/0,
1812 static_cast<int64>(repack_allocation_blocks_.size()),
1813 allocations_->back().get()));
1814 colocations.push_back(&repack_allocation_blocks_.back());
1815 }
1816 }
1817 // If requested, make all scoped allocations colocate with each other so
1818 // that when we repack, all scoped allocations get the same offsets. Since
1819 // they will all have the same scoped memory addresses, this increases the
1820 // opportunity to deduplicate different ops. However, this may hurt the
1821 // memory packing efficiency.
1822 if (options_.allocate_reserved_scoped_memory_at_same_offset) {
1823 for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
1824 colocations) {
1825 repack_block->colocations = colocations;
1826 }
1827 }
1828 }
1829
1830 absl::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
1831 AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
1832 int64_t time) const {
1833 auto required_assignment_it = required_assignments_.find(buffer);
1834 absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
1835 if (required_assignment_it != required_assignments_.end()) {
1836 for (const RequiredMemoryAssignment& required_assignment :
1837 required_assignment_it->second) {
1838 if (required_assignment.time == time) {
1839 // Sanity check that there is only one required assignment at this time.
1840 CHECK(!required_assignment_at_time);
1841 required_assignment_at_time = required_assignment;
1842 }
1843 }
1844 }
1845 return required_assignment_at_time;
1846 }
1847
1848 absl::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
1849 AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse(
1850 const AllocationValue::Use& use) const {
1851 absl::optional<RequiredMemoryAssignment> required_assignment;
1852 for (const HloPosition& position : use.aliases) {
1853 const HloValue* value =
1854 &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
1855 position.instruction, position.index);
1856 int64_t time =
1857 hlo_live_range_.instruction_schedule().at(position.instruction);
1858 absl::optional<RequiredMemoryAssignment> required_assignment_for_alias =
1859 RequiredMemoryAssignmentAt(value, time);
1860 if (required_assignment == absl::nullopt) {
1861 required_assignment = required_assignment_for_alias;
1862 } else {
1863 CHECK(required_assignment_for_alias == absl::nullopt ||
1864 required_assignment->equals_ignoring_time(
1865 *required_assignment_for_alias));
1866 }
1867 }
1868 return required_assignment;
1869 }
1870
1871 void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
1872 const HloInstruction* instruction, ShapeIndex index,
1873 const MemorySpaceAssignment::Allocation* aliased_allocation) {
1874 AliasedOffset* offset = nullptr;
1875 if (aliased_allocation->memory_space() == MemorySpace::kAlternate) {
1876 offset = GetAliasedOffset(*aliased_allocation);
1877 }
1878 AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(),
1879 offset);
1880 }
1881
1882 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
1883 const HloValue* value, const HloInstruction* instruction,
1884 MemorySpaceAssignment::MemorySpace memory_space, int64_t time,
1885 AliasedOffset* offset) {
1886 // Check for an existing required assignment at this time and, if there is
1887 // one, make sure it is the same as this one.
1888 auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time);
1889 if (existing_required_assignment) {
1890 CHECK(memory_space == existing_required_assignment->memory_space)
1891 << "inst = " << instruction->ToString() << " at " << time;
1892 CHECK((!offset && !existing_required_assignment->offset) ||
1893 offset == existing_required_assignment->offset);
1894 VLOG(3) << "Not adding required assignment because there is one already: "
1895 << value->ToShortString() << " at " << time << " at "
1896 << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1897 } else {
1898 VLOG(3) << "Adding required assignment: " << value->ToShortString()
1899 << " at " << time << " at "
1900 << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1901 RequiredMemoryAssignment required_assignment{memory_space, time, offset};
1902 required_assignments_[value].push_back(required_assignment);
1903 pending_required_assignments_.push_back({value, required_assignment});
1904 }
1905 }
1906
1907 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
1908 const HloInstruction* instruction, ShapeIndex index,
1909 MemorySpace memory_space, AliasedOffset* offset) {
1910 const HloValue* value =
1911 &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index);
1912 int64_t instruction_time =
1913 hlo_live_range_.instruction_schedule().at(instruction);
1914 AddRequiredAssignment(value, instruction, memory_space, instruction_time,
1915 offset);
1916 }
1917
1918 void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
1919 // Go through the parameters, outputs, and constants and pin them to the
1920 // corresponding memory by adding a required assignment.
1921 const HloModule& module = alias_analysis_.dataflow_analysis().module();
1922 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1923 HloComputation* entry_computation = module.entry_computation();
1924 for (HloInstruction* parameter_instruction :
1925 entry_computation->parameter_instructions()) {
1926 int64_t parameter_instruction_time =
1927 instruction_schedule.at(parameter_instruction);
1928 ShapeUtil::ForEachSubshape(
1929 parameter_instruction->shape(),
1930 [&](const Shape& subshape, const ShapeIndex& index) {
1931 MemorySpace memory_space = MemorySpace::kDefault;
1932 if (subshape.has_layout() && subshape.layout().memory_space() ==
1933 options_.alternate_memory_space) {
1934 memory_space = MemorySpace::kAlternate;
1935 }
1936 for (const HloBuffer* buffer :
1937 alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
1938 for (const HloValue* value : buffer->values()) {
1939 VLOG(3) << "Adding required assignment for parameter value = "
1940 << value->ToShortString()
1941 << " time = " << parameter_instruction_time << " space = "
1942 << (memory_space == MemorySpace::kDefault ? "def"
1943 : "alt");
1944 required_assignments_[value].push_back(
1945 {memory_space, /*time=*/parameter_instruction_time});
1946 }
1947 }
1948 });
1949 }
1950 HloInstruction* root_instruction = entry_computation->root_instruction();
1951 int64_t root_instruction_time = instruction_schedule.at(root_instruction);
1952 ShapeUtil::ForEachSubshape(
1953 root_instruction->shape(),
1954 [&](const Shape& subshape, const ShapeIndex& index) {
1955 MemorySpace memory_space = MemorySpace::kDefault;
1956 if (subshape.has_layout() && subshape.layout().memory_space() ==
1957 options_.alternate_memory_space) {
1958 memory_space = MemorySpace::kAlternate;
1959 }
1960 for (const HloBuffer* buffer :
1961 alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
1962 for (const HloValue* value : buffer->values()) {
1963 VLOG(3) << "Adding required assignment for output value = "
1964 << value->ToShortString()
1965 << " time = " << root_instruction_time << " space = "
1966 << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1967 required_assignments_[value].push_back(
1968 {memory_space, /*time=*/root_instruction_time});
1969 }
1970 }
1971 });
1972
1973 for (const HloComputation* computation : module.MakeNonfusionComputations()) {
1974 for (HloInstruction* instruction : computation->instructions()) {
1975 if (instruction->opcode() == HloOpcode::kConstant) {
1976 auto constant_instruction_it = instruction_schedule.find(instruction);
1977 if (constant_instruction_it == instruction_schedule.end()) {
1978 continue;
1979 }
1980 int64_t constant_instruction_time = constant_instruction_it->second;
1981 for (const auto& indexed_shape :
1982 ShapeUtil::GetLeafShapes(instruction->shape())) {
1983 const ShapeIndex& index = indexed_shape.index;
1984 for (const HloBuffer* buffer :
1985 alias_analysis_.ComputeBuffersAt(instruction, index)) {
1986 for (const HloValue* value : buffer->values()) {
1987 VLOG(3) << "Adding required assignment for constant value = "
1988 << value->ToShortString()
1989 << " time = " << constant_instruction_time
1990 << " space = def";
1991 required_assignments_[value].push_back(
1992 {MemorySpace::kDefault, /*time=*/constant_instruction_time});
1993 }
1994 }
1995 }
1996 }
1997 }
1998 }
1999 }
2000
2001 bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory(
2002 absl::Span<const BufferInterval* const> colocated_intervals) const {
2003 auto is_position_in_alternate_memory = [&](const HloPosition& position) {
2004 const Shape& shape = position.shape();
2005 return shape.has_layout() &&
2006 shape.layout().memory_space() == options_.alternate_memory_space;
2007 };
2008
2009 const HloModule& module = alias_analysis_.dataflow_analysis().module();
2010 const HloComputation* entry_computation = module.entry_computation();
2011 const HloInstruction* root_instruction =
2012 entry_computation->root_instruction();
2013 for (const BufferInterval* colocated_interval : colocated_intervals) {
2014 const HloValue* value = colocated_interval->buffer;
2015 if (value->defining_instruction()->opcode() == HloOpcode::kParameter &&
2016 value->defining_instruction()->parent() == entry_computation &&
2017 is_position_in_alternate_memory(value->defining_position())) {
2018 return true;
2019 }
2020
2021 for (const HloPosition& position : value->positions()) {
2022 if (position.instruction == root_instruction &&
2023 is_position_in_alternate_memory(position)) {
2024 return true;
2025 }
2026 }
2027 }
2028 return false;
2029 }
2030
2031 void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking(
2032 std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>& allocations) {
2033 for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
2034 allocations.push_back(&allocation_block);
2035 }
2036 }
2037
2038 void AlternateMemoryBestFitHeap::ImportRepackedAllocations() {
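// Discard the pre-repack placement: reset the interval tree and rebuild it
// from the repacked allocation blocks, copying each block's new offset back
// into its Allocation.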
2039 interval_tree_ = {};
2040 for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
2041 MemorySpaceAssignment::Allocation* allocation = allocation_block.allocation;
2042 VLOG(3) << "Moved " << allocation->ToString() << ", size "
2043 << allocation->chunk().size << ", (" << allocation_block.start_time
2044 << ", " << allocation_block.end_time << ") from "
2045 << allocation_block.initial_offset << " to "
2046 << allocation_block.offset;
2047 allocation_block.allocation->mutable_chunk()->offset =
2048 allocation_block.offset;
2049 interval_tree_.Add(allocation_block.start_time, allocation_block.end_time,
2050 {allocation_block.offset, allocation_block.size});
2051 allocation_block.initial_offset = allocation_block.offset;
2052 allocation_block.offset = -1;
2053 }
2054 }
2055
2056 void AlternateMemoryBestFitHeap::UncommitPendingChunks(
2057 absl::Span<AllocationValue> allocation_values) {
2058 // Clear the allocation sequences of the allocation values so that we start
2059 // from a clean state in case we retry allocation after uncommitting.
2060 for (AllocationValue& allocation_value : allocation_values) {
2061 allocation_value.allocation_sequence()->clear();
2062 }
2063 for (const auto& interval_and_chunk : pending_chunks_) {
2064 const BufferInterval& interval = interval_and_chunk.first;
2065 const Chunk& chunk = interval_and_chunk.second.chunk;
2066 VLOG(3) << "Uncommitting: (" << interval.start << ", " << interval.end
2067 << ") off = " << chunk.offset << " size = " << chunk.size;
2068 interval_tree_.Remove(interval.start, interval.end, chunk);
2069 }
2070 for (const auto& interval : pending_async_copies_) {
2071 if (interval.destination == MemorySpace::kAlternate) {
2072 prefetch_interval_tree_.Remove(interval.start_time, interval.end_time,
2073 kDummyChunk);
2074 async_copy_ordering_.RemoveCopy(interval);
2075 } else {
2076 eviction_interval_tree_.Remove(interval.start_time, interval.end_time,
2077 kDummyChunk);
2078 }
2079 }
2080 for (const auto& value_and_required_assignment :
2081 pending_required_assignments_) {
2082 auto& required_assignment_vector =
2083 required_assignments_[value_and_required_assignment.first];
2084 const RequiredMemoryAssignment& required_assignment =
2085 value_and_required_assignment.second;
2086 VLOG(3) << "Removing required assignment: "
2087 << (required_assignment.memory_space == MemorySpace::kDefault
2088 ? "def"
2089 : "alt")
2090 << " time = " << required_assignment.time << " off = "
2091 << (required_assignment.offset ? required_assignment.offset->offset
2092 : -1);
2093 for (auto it = required_assignment_vector.begin();
2094 it != required_assignment_vector.end(); ++it) {
2095 if (*it == value_and_required_assignment.second) {
2096 required_assignment_vector.erase(it);
2097 break;
2098 }
2099 }
2100 }
2101 ClearPendingChunks();
2102 }
2103
2104 void AlternateMemoryBestFitHeap::FinalizeAllocations(
2105 absl::Span<AllocationValue> allocation_values) {
2106 absl::flat_hash_map<const AliasedOffset*,
2107 std::vector<MemorySpaceAssignment::Allocation*>>
2108 colocation_map;
2109 for (AllocationValue& allocation_value : allocation_values) {
2110 for (auto& allocation : *allocation_value.allocation_sequence()) {
2111 AppendAllocationInfoDebugString(allocation_value, *allocation,
2112 allocation_info_str_);
2113 allocations_->push_back(std::move(allocation));
2114 MemorySpaceAssignment::Allocation* inserted_allocation =
2115 allocations_->back().get();
2116 if (inserted_allocation->memory_space() == MemorySpace::kAlternate) {
2117 colocation_map[GetAliasedOffset(*inserted_allocation)].push_back(
2118 inserted_allocation);
2119 }
2120 }
2121 }
2122 // The allocations that have the same AliasedOffset need to be colocated.
2123 // Export these to repack_allocation_blocks_ so that we can repack them to
2124 // reduce fragmentation.
2125 for (auto& colocation : colocation_map) {
2126 std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
2127 for (MemorySpaceAssignment::Allocation* colocated_allocation :
2128 colocation.second) {
2129 repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
2130 colocated_allocation->start_time(), colocated_allocation->end_time(),
2131 colocated_allocation->chunk().size,
2132 colocated_allocation->chunk().offset,
2133 static_cast<int64>(repack_allocation_blocks_.size()),
2134 colocated_allocation));
2135 colocations.push_back(&repack_allocation_blocks_.back());
2136 }
2137 for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
2138 colocations) {
2139 repack_block->colocations = colocations;
2140 }
2141 }
2142 ClearPendingChunks();
2143 }
2144
2145 void AlternateMemoryBestFitHeap::ClearPendingChunks() {
2146 pending_chunks_.clear();
2147 pending_async_copies_.clear();
2148 pending_required_assignments_.clear();
2149 aliased_offset_map_.clear();
2150 aliased_offsets_.clear();
2151 }
2152
2153 void AlternateMemoryBestFitHeap::AddToPendingChunks(
2154 const BufferInterval& buffer_interval,
2155 const ChunkCandidate& chunk_candidate) {
2156 VLOG(3) << "Committing chunk: " << buffer_interval.start << "-"
2157 << buffer_interval.end << " : [" << chunk_candidate.chunk.offset
2158 << ", " << chunk_candidate.chunk.size << "]";
2159 pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
2160 CommitChunk(buffer_interval, chunk_candidate);
2161 }
2162
2163 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment(
2164 const AllocationRequest& request) {
2165 auto allocation_sequence = request.allocation_value->allocation_sequence();
2166 // start_time == end_time is a special case where the value is consumed
2167 // multiple times by the same instruction. We can just find the previous
2168 // allocation and use that allocation.
2169 if (request.start_time == request.end_time) {
2170 MemorySpaceAssignment::Allocation* allocation =
2171 GetLiveAllocationAt(*allocation_sequence, request.end_time);
2172 CHECK_NE(allocation, nullptr);
2173 allocation->AddUse(request.use->hlo_use);
2174 return Result::kSuccess;
2175 }
2176
2177 const HloPosition& defining_position =
2178 request.allocation_value->defining_position();
2179 VLOG(2) << "Finding allocation for "
2180 << request.allocation_value->ToShortString() << " ("
2181 << request.start_time << ", " << request.end_time
2182 << ") latest prefetch = " << request.latest_prefetch_time
2183 << " last use = " << request.allocation_value->uses().back().time
2184 << " use = " << request.use->hlo_use.ToString()
2185 << ". Size = " << request.size
2186 << ", def pos = " << defining_position.ToString();
2187 CHECK_LE(request.start_time, request.end_time);
2188
2189 // There could be a requirement to pin this buffer to default memory either
2190 // because it is a parameter or an output. If the buffer is a parameter, then
2191 // we're allowed to prefetch. If the use expects the output to be in default
2192 // memory, we cannot prefetch it because if we did, it would be in alternate
2193 // memory instead.
2194 auto required_assignment_at_start = RequiredMemoryAssignmentAt(
2195 request.allocation_value->value(), request.start_time);
2196 absl::optional<MemorySpace> required_memory_space_at_start;
2197 if (required_assignment_at_start) {
2198 required_memory_space_at_start = required_assignment_at_start->memory_space;
2199 }
2200 // Find required assignment both for the use and its aliases. If they are both
2201 // non-nullopt, then make sure they require the same assignment.
2202 auto required_assignment_at_end = RequiredMemoryAssignmentAt(
2203 request.allocation_value->value(), request.end_time);
2204 auto aliased_required_assignment_at_end =
2205 AliasedRequiredAssignmentForUse(*request.use);
2206 if (required_assignment_at_end != aliased_required_assignment_at_end) {
2207 if (required_assignment_at_end == absl::nullopt) {
2208 required_assignment_at_end = aliased_required_assignment_at_end;
2209 } else {
2210 CHECK(aliased_required_assignment_at_end == absl::nullopt ||
2211 aliased_required_assignment_at_end->equals_ignoring_time(
2212 *required_assignment_at_end));
2213 }
2214 }
2215 absl::optional<MemorySpace> required_memory_space_at_end;
2216 if (required_assignment_at_end) {
2217 required_memory_space_at_end = required_assignment_at_end->memory_space;
2218 }
2219
2220 if (required_assignment_at_start) {
2221 bool needs_required_allocation = true;
2222 if (!allocation_sequence->empty()) {
2223 auto prev_allocation_it = std::find_if(
2224 allocation_sequence->rbegin(), allocation_sequence->rend(),
2225 [&](const auto& allocation) {
2226 return allocation->memory_space() ==
2227 required_memory_space_at_start &&
2228 allocation->defining_position() == defining_position;
2229 });
2230 if (prev_allocation_it != allocation_sequence->rend()) {
2231 (*prev_allocation_it)->Extend(request.start_time);
2232 needs_required_allocation = false;
2233 }
2234 }
2235 if (needs_required_allocation) {
2236 absl::optional<Chunk> aliased_chunk = absl::nullopt;
2237 if (required_assignment_at_start->memory_space ==
2238 MemorySpace::kAlternate) {
2239 aliased_chunk =
2240 Chunk{required_assignment_at_start->offset->offset, request.size};
2241 }
2242 allocation_sequence->push_back(
2243 absl::make_unique<MemorySpaceAssignment::Allocation>(
2244 defining_position, required_assignment_at_start->memory_space,
2245 aliased_chunk, request.start_time, request.start_time,
2246 /*is_scoped_allocation=*/false));
2247 if (required_assignment_at_start->memory_space ==
2248 MemorySpace::kAlternate) {
2249 CreateOrAddToAliasedOffset(*allocation_sequence->back(),
2250 required_assignment_at_start->offset);
2251 }
2252 }
2253 }
2254
2255 Result allocation_result = Result::kSuccess;
2256 // First try keeping the allocation entirely in the alternate memory.
2257 if (required_memory_space_at_start != MemorySpace::kDefault &&
2258 required_memory_space_at_end != MemorySpace::kDefault &&
2259 request.allow_no_copy_alternate_mem_allocation) {
2260 allocation_result = AllocateInAlternateMemoryNoCopy(request);
2261 if (allocation_result == Result::kSuccess) {
2262 return Result::kSuccess;
2263 }
2264 }
2265
2266 auto prev_allocation_it = allocation_sequence->rbegin();
2267 // Find a previous allocation that is in the default memory space (not
2268 // necessarily the very last allocation).
2269 auto prev_allocation_in_default_mem_it = std::find_if(
2270 allocation_sequence->rbegin(), allocation_sequence->rend(),
2271 [&](const auto& allocation) {
2272 return allocation->memory_space() == MemorySpace::kDefault &&
2273 allocation->defining_position() == defining_position;
2274 });
2275
2276 if (prev_allocation_in_default_mem_it == allocation_sequence->rend() &&
2277 prev_allocation_it != allocation_sequence->rend() &&
2278 (*prev_allocation_it)->memory_space() == MemorySpace::kAlternate &&
2279 (*prev_allocation_it)->defining_position() == defining_position &&
2280 !request.allocation_value->requires_contiguous_allocation()) {
2281 // If there was an allocation for this HloValue that was in the alternate
2282 // memory space, we also need to perform an eviction.
2283 Result eviction_result = Evict(request);
2284 if (eviction_result != Result::kSuccess) {
2285 // A non-success eviction requires us to uncommit previous allocations.
2286 return result_mark(Result::kFailRequiresUncommit, eviction_result);
2287 }
2288 prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
2289 } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) {
2290 allocation_sequence->push_back(
2291 absl::make_unique<MemorySpaceAssignment::Allocation>(
2292 defining_position, MemorySpace::kDefault, /*chunk=*/absl::nullopt,
2293 request.start_time, request.end_time,
2294 /*is_scoped_allocation=*/false));
2295 prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
2296 }
2297
2298 CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
2299 CHECK((*prev_allocation_in_default_mem_it)->memory_space() ==
2300 MemorySpace::kDefault);
2301
2302 // If the buffer must be in default memory at the end_time, don't prefetch.
2303 if (required_memory_space_at_end == MemorySpace::kDefault) {
2304 VLOG(3)
2305 << "Not trying to prefetch because use requires buffer in default mem.";
2306 (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
2307 (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
2308 return Result::kSuccess;
2309 }
2310
2311 // Finally, try to prefetch the buffer into alternate memory.
2312 if (!request.allocation_value->requires_contiguous_allocation()) {
2313 Result prefetch_result =
2314 Prefetch(request, **prev_allocation_in_default_mem_it);
2315 if (prefetch_result == Result::kSuccess) {
2316 return Result::kSuccess;
2317 }
2318 result_mark(prefetch_result, allocation_result);
2319 }
2320
2321 // If the end assignment was required to be in alternate memory but that
2322 // wasn't possible, then this allocation is invalid.
2323 if (required_memory_space_at_end == MemorySpace::kAlternate) {
2324 return result_mark(Result::kFailRequiresUncommit, allocation_result);
2325 }
2326
2327 // If the start assignment was required to be in alternate memory and the
2328 // buffer needs a contiguous assignment, we couldn't satisfy this requirement
2329 // and must abort.
2330 if (required_memory_space_at_start == MemorySpace::kAlternate &&
2331 request.allocation_value->requires_contiguous_allocation()) {
2332 return result_mark(Result::kFailRequiresUncommit, allocation_result);
2333 }
2334
2335 // If a copy wasn't inserted, then add this use to the latest allocation in
2336 // default memory.
2337 (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
2338 (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
2339 return allocation_result;
2340 }
2341
2342 void AlternateMemoryBestFitHeap::AddAsyncCopy(
2343 const MemorySpaceAssignment::Allocation& prev_allocation,
2344 MemorySpace memory_space, absl::optional<Chunk> chunk, int64_t start_time,
2345 int64_t end_time, int64_t copy_done_schedule_before_time,
2346 MemorySpaceAssignment::AllocationSequence* allocations,
2347 AliasedOffset* aliased_offset, bool is_cross_program_prefetch) {
2348 VLOG(3) << "Copy to "
2349 << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
2350 ? "default"
2351 : "alternate")
2352 << " memory between " << start_time << " and "
2353 << copy_done_schedule_before_time << " keeping until " << end_time;
2354 CHECK_LT(start_time, copy_done_schedule_before_time);
2355
2356 allocations->push_back(
2357 absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
2358 prev_allocation, memory_space, chunk, start_time, end_time,
2359 copy_done_schedule_before_time, is_cross_program_prefetch));
2360
2361 // Register the additional async copy with the interval tree to keep track of
2362 // the limit at any given time.
2363 pending_async_copies_.push_back(
2364 {start_time, copy_done_schedule_before_time, memory_space});
2365 if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) {
2366 prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2367 kDummyChunk);
2368 async_copy_ordering_.AddCopy(pending_async_copies_.back());
2369 CreateOrAddToAliasedOffset(*allocations->back(), aliased_offset);
2370 } else {
2371 eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2372 kDummyChunk);
2373 }
2374 }
2375
2376 bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
2377 int64_t start_time, int64_t end_time, bool is_prefetch,
2378 int64_t extra_async_copy_limit) const {
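// A negative limit is treated as unlimited: such copies never violate the
// outstanding-copy constraint.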
2379 if (options_.max_outstanding_prefetches < 0 && is_prefetch) {
2380 return false;
2381 }
2382 if (options_.max_outstanding_evictions < 0 && !is_prefetch) {
2383 return false;
2384 }
2385
2386 // Count the prefetches/evictions in the interval tree for the given interval.
2387 if (is_prefetch) {
2388 int64_t num_prefetches =
2389 prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2390 .size();
2391 return num_prefetches >=
2392 options_.max_outstanding_prefetches + extra_async_copy_limit;
2393 } else {
2394 int64_t num_evictions =
2395 eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2396 .size();
2397 return num_evictions >=
2398 options_.max_outstanding_evictions + extra_async_copy_limit;
2399 }
2400 }
2401
2402 absl::optional<AsynchronousCopy>
2403 AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(int64_t start_time,
2404 int64_t end_time) const {
2405 return async_copy_ordering_.ViolatesOrdering(start_time, end_time);
2406 }
2407
2408 AlternateMemoryBestFitHeap::Result
2409 AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
2410 const AllocationRequest& request) {
2411 MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
2412 bool can_eliminate_copy = false;
2413 if (request.allocation_value->allocation_sequence()->empty()) {
2414 // There haven't been any allocations for this interval so far. We can
2415 // eliminate the copy if the value can be placed in the alternate memory.
2416 can_eliminate_copy = options_.is_allowed_in_alternate_mem_fn(
2417 *request.allocation_value->value());
2418 } else {
2419 // If there has been a previous allocation, we can eliminate the copy if the
2420 // previous allocation was also in the alternate memory.
2421 prev_allocation =
2422 request.allocation_value->allocation_sequence()->back().get();
2423 can_eliminate_copy =
2424 (prev_allocation->memory_space() == MemorySpace::kAlternate);
2425 }
2426
2427 if (!can_eliminate_copy) {
2428 return Result::kFailPrevAllocationNotInAlternateMem;
2429 }
2430
2431 const HloPosition& defining_position =
2432 request.allocation_value->defining_position();
2433 if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
2434 defining_position.shape(), request.start_time + 1,
2435 request.end_time)) {
2436 return Result::kFailLiveRangeTooLong;
2437 }
2438
2439 BufferInterval alternate_mem_interval;
2440 alternate_mem_interval.buffer = request.allocation_value->value();
2441 alternate_mem_interval.size = request.size;
2442 alternate_mem_interval.end = request.end_time;
2443 alternate_mem_interval.start = request.start_time;
2444
2445   // Prefer the offset that was used for the previous allocation.
2446 AliasedOffset* preferred_offset = nullptr;
2447 if (prev_allocation != nullptr) {
2448 preferred_offset = GetAliasedOffset(*prev_allocation);
2449     // If there is a previous allocation, set the start time to one past the
2450     // previous allocation's end time.
2451 alternate_mem_interval.start = prev_allocation->end_time() + 1;
2452 }
2453
2454 if (request.preferred_offset) {
2455     // Sanity check that if a preferred offset is provided in the request, it
2456     // matches the offset of the previous allocation.
2457 CHECK(!preferred_offset || request.preferred_offset == preferred_offset)
2458 << "preferred_offset = " << preferred_offset->offset
2459 << ", request.preferred_offset = " << request.preferred_offset->offset;
2460 preferred_offset = request.preferred_offset;
2461 }
2462
2463 VLOG(3) << "We can eliminate copy to alternate memory. Preferred offset = "
2464 << (preferred_offset ? preferred_offset->offset : -1);
2465 // In case there are additional uses after this use, we rely on the last use
2466 // time to try to reserve a chunk in the heap simulator. This is to prevent
2467 // the following scenario:
2468 //
2469 // +-------+
2470 // / \
2471 // Producer--->Use1 +-->Use2
2472 // +---------+---------+
2473 // New buffer: | | |
2474 // +---------+---------+
2475 //
2476 // +-----------+
2477 // Current heap: | offset: 0 |
2478 // --------------------------+-----------+------
2479 //
2480   // Because we allocate buffers greedily (the Producer-to-Use1 segment first,
2481   // then the Use1-to-Use2 segment), it is possible to allocate the first
2482   // segment at an offset (e.g. offset 0) that is available for that segment
2483   // but not for the entire live range. This can result in unnecessary copies.
2484   // By using the last use time, we try to find an allocation that is
2485   // available for the entire Producer-to-Use2 range.
2486 absl::optional<ChunkCandidate> chunk_candidate = FindBestChunkCandidate(
2487 request, preferred_offset, &alternate_mem_interval);
2488   // Check whether we found a chunk that fits within the heap size limit and,
2489   // if a preferred offset was provided, that this offset was used.
2490 if (chunk_candidate) {
2491 VLOG(3) << "Keep the buffer in alternate memory. Offset = "
2492 << chunk_candidate->chunk.offset
2493 << ", size = " << chunk_candidate->chunk.size
2494 << ", heap_size = " << chunk_candidate->heap_size
2495 << ", prefetch picker = "
2496 << options_.prefetch_interval_picker->ToNoCopyDebugString(
2497 defining_position.shape(), request.start_time,
2498 request.end_time);
2499 AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
2500
2501 // If there was a previous allocation, the buffer location is the
2502 // same as the previous. Otherwise, it is the operand.
2503 if (prev_allocation != nullptr &&
2504 (prev_allocation->is_copy_allocation() ||
2505 prev_allocation->defining_position() == defining_position)) {
2506 prev_allocation->Extend(request.end_time);
2507 } else {
2508 request.allocation_value->allocation_sequence()->push_back(
2509 absl::make_unique<MemorySpaceAssignment::Allocation>(
2510 defining_position, MemorySpace::kAlternate,
2511 chunk_candidate->chunk, request.start_time, request.end_time,
2512 /*is_scoped_allocation=*/false));
2513 CreateOrAddToAliasedOffset(
2514 *request.allocation_value->allocation_sequence()->back(),
2515 preferred_offset);
2516 }
2517 request.allocation_value->allocation_sequence()->back()->AddUse(
2518 request.use->hlo_use);
2519 return Result::kSuccess;
2520 }
2521 return Result::kFailOutOfMemory;
2522 }
2523
2524 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict(
2525 const AllocationRequest& request) {
2526 CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0);
2527 MemorySpaceAssignment::Allocation* prev_allocation =
2528 request.allocation_value->allocation_sequence()->back().get();
2529 int64_t eviction_start_time = prev_allocation->start_time();
2530 int64_t eviction_end_time = prev_allocation->end_time();
2531   CHECK_LE(eviction_start_time, eviction_end_time);
2532
2533 int64_t preferred_eviction_end_time =
2534 std::max(options_.prefetch_interval_picker->PreferredEvictionEndTime(
2535 request.allocation_value->defining_position().shape(),
2536 eviction_start_time, request.end_time),
2537 eviction_end_time);
2538 // Evictions must complete by the time of this use.
2539 preferred_eviction_end_time =
2540 std::min(preferred_eviction_end_time, request.latest_prefetch_time);
2541
2542 BufferInterval eviction_mem_interval;
2543 eviction_mem_interval.buffer = request.allocation_value->value();
2544 eviction_mem_interval.size = request.size;
2545 // Try to reserve a buffer from the end of the previous allocation to the
2546 // preferred eviction end time.
2547 eviction_mem_interval.start = eviction_end_time + 1;
2548 eviction_mem_interval.end = preferred_eviction_end_time;
2549 int64_t preferred_offset = prev_allocation->chunk().offset;
2550 VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
2551 << ") preferred end time = " << eviction_mem_interval.end;
2552
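  // Walk the preferred end time back toward the eviction end time until the
  // buffer can be kept at its current offset for the entire interval.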
2553 for (; eviction_mem_interval.end > eviction_end_time;
2554 --eviction_mem_interval.end) {
2555 ChunkCandidate chunk_candidate =
2556 FindChunkCandidate(eviction_mem_interval, preferred_offset);
2557 if (chunk_candidate.chunk.offset == preferred_offset) {
2558 AddToPendingChunks(eviction_mem_interval, chunk_candidate);
2559 break;
2560 }
2561 }
2562 eviction_end_time = eviction_mem_interval.end;
2563
2564 VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " ("
2565 << eviction_start_time << ", " << eviction_end_time << ")";
2566
2567 bool eviction_interval_too_short = (eviction_start_time == eviction_end_time);
2568 bool eviction_violates_outstanding_copies =
2569 ViolatesMaximumOutstandingAsyncCopies(eviction_start_time,
2570 eviction_end_time,
2571 /*is_prefetch=*/false);
2572
2573 // See if this interval would violate the asynchronous copy limit.
2574 if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) {
2575 prev_allocation->Extend(eviction_end_time);
2576 AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
2577 /*chunk=*/absl::nullopt, eviction_start_time,
2578 prev_allocation->end_time(), eviction_end_time,
2579 request.allocation_value->allocation_sequence(),
2580 /*aliased_offset=*/nullptr);
2581 } else {
2582 if (eviction_violates_outstanding_copies) {
2583 VLOG(3) << "This violates the maximum async copies.";
2584 } else {
2585 VLOG(3) << "Eviction interval is too short (" << eviction_start_time
2586 << ", " << eviction_end_time << ").";
2587 }
2588     // The eviction either violated the copy limit or the interval was too
2589     // short; try scheduling the eviction over sub-intervals of this interval.
2590 bool eviction_scheduled = false;
2591 for (int64_t time = eviction_start_time; time < eviction_end_time; ++time) {
2592 VLOG(4) << "Try evicting (" << time << ", " << time + 1 << ")";
2593 if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1,
2594 /*is_prefetch=*/false)) {
2595 VLOG(3) << "Eviction successful.";
2596 AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
2597 /*chunk=*/absl::nullopt, time, time + 1, time + 1,
2598 request.allocation_value->allocation_sequence(),
2599 /*aliased_offset=*/nullptr);
2600 eviction_scheduled = true;
2601 break;
2602 }
2603 }
2604
2605 if (!eviction_scheduled) {
2606 // If the eviction couldn't be scheduled, then fail. This buffer will be
2607 // kept in the default memory.
2608 VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
2609 << " because we hit the limit of maximum asynchronous copies "
2610 << "between "
2611 << hlo_live_range_.flattened_instruction_sequence()
2612 .instructions()[eviction_start_time]
2613 << " and "
2614 << hlo_live_range_.flattened_instruction_sequence()
2615 .instructions()[eviction_end_time];
2617 return Result::kFailOutOfAsyncCopies;
2618 }
2619 }
2621 return Result::kSuccess;
2622 }
2623
2624 int64_t AlternateMemoryBestFitHeap::FindPrefetchEndTime(
2625 const AllocationRequest& request, int64_t earliest_prefetch_time) const {
2626 int64_t prefetch_end_time = request.latest_prefetch_time;
2627
2628 const HloUse& use = request.use->hlo_use;
2629 const Shape& shape = ShapeUtil::GetSubshape(
2630 use.instruction->operand(use.operand_number)->shape(), use.operand_index);
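  // Move the prefetch end time earlier while the candidate copy window would
  // violate the asynchronous copy ordering, up to a bounded number of retries.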
2631 for (int retry_number = 0;
2632 retry_number < options_.prefetch_copy_done_reorder_max_retries;
2633 ++retry_number) {
2634 int64_t latest_prefetch_time =
2635 options_.prefetch_interval_picker->LatestPrefetchStartTime(
2636 shape, earliest_prefetch_time, prefetch_end_time, &use);
2637 VLOG(4) << "Latest prefetch start time = " << latest_prefetch_time
2638 << ", earliest prefetch start time = " << earliest_prefetch_time
2639 << ", prefetch end time = " << prefetch_end_time;
2640     // Stop searching if we couldn't find a suitable prefetch start time.
2641 if (latest_prefetch_time < earliest_prefetch_time) {
2642 break;
2643 }
2644
2645     // Stop either if there is no violating asynchronous copy (since we don't
2646     // need to change the prefetch end time) or if the violating asynchronous
2647     // copy ends at or after the prefetch end time.
2648 auto violating_async_copy =
2649 ViolatesAsyncCopyOrdering(latest_prefetch_time, prefetch_end_time);
2650 if (!violating_async_copy ||
2651 violating_async_copy->end_time >= prefetch_end_time) {
2652 break;
2653 }
2654 VLOG(4) << "Violating async copy: (" << violating_async_copy->start_time
2655 << ", " << violating_async_copy->end_time << ")";
2656
2657 int64_t new_prefetch_end_time =
2658 options_.prefetch_interval_picker->LatestPrefetchEndTime(
2659 prefetch_end_time, violating_async_copy->end_time);
2660 if (new_prefetch_end_time > earliest_prefetch_time) {
2661 VLOG(3) << "Update prefetch end time = " << new_prefetch_end_time;
2662 prefetch_end_time = new_prefetch_end_time;
2663 } else {
2664 VLOG(3) << "Can't update prefetch end time = " << new_prefetch_end_time
2665 << " because earliest prefetch start time = "
2666 << earliest_prefetch_time;
2667 break;
2668 }
2669 }
2670
2671 return prefetch_end_time;
2672 }
2673
2674 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch(
2675 const AllocationRequest& request,
2676 const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) {
2677 // Try partially placing the buffer in the alternate space. The time that is
2678 // overlapped will be used to asynchronously copy the buffer from the
2679 // default memory to the alternate memory.
2680 //
2681 // start end
2682 // time time
2683 // X---------------------X
2684 // Alternate: +------+
2685 // Default: +---------------------+
2686 // ^ ^
2687 // Copy Copy
2688 // Start Done
2689 int64_t earliest_prefetch_time =
2690 prev_allocation_in_default_mem.earliest_available_time();
2691 if (request.earliest_prefetch_time) {
2692 earliest_prefetch_time =
2693 std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
2694 }
2695 int64_t prefetch_end_time =
2696 FindPrefetchEndTime(request, earliest_prefetch_time);
2697
2698 options_.prefetch_interval_picker->Begin(
2699 request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
2700 VLOG(3) << "Trying prefetch picker = "
2701 << options_.prefetch_interval_picker->ToDebugString();
2702
2703   // Create an alternate memory interval; its start time is set below to the
2704   // candidate prefetch start times proposed by the prefetch interval picker.
2705 BufferInterval alternate_mem_interval;
2706 alternate_mem_interval.buffer = request.allocation_value->value();
2707 alternate_mem_interval.size = request.size;
2708   // Uses by while instructions may be allowed additional outstanding prefetches.
2709 int64_t extra_async_copy_limit =
2710 request.use->hlo_use.instruction->opcode() == HloOpcode::kWhile
2711 ? options_.while_use_extra_outstanding_prefetch_limit
2712 : 0;
2713 Result result = Result::kSuccess;
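  // Try the prefetch start times proposed by the interval picker until one
  // satisfies the async copy constraints and fits a chunk in alternate memory.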
2714 while (!options_.prefetch_interval_picker->Done()) {
2715 alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
2716 CHECK_LT(alternate_mem_interval.start, prefetch_end_time);
2717 VLOG(4) << "Trying alternate memory allocation ("
2718 << alternate_mem_interval.start << ", " << request.end_time << ")";
2719     // If this additional asynchronous copy would violate the copy ordering or
2720     // the outstanding copy limit, try a different interval.
2721 if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
2722 prefetch_end_time)) {
2723 VLOG(4) << "This would violate asynchronous copy ordering.";
2724 result_mark(Result::kFailViolatesAsyncCopyOrdering, result);
2725 continue;
2726 }
2727 if (ViolatesMaximumOutstandingAsyncCopies(
2728 alternate_mem_interval.start, prefetch_end_time,
2729 /*is_prefetch=*/true, extra_async_copy_limit)) {
2730 VLOG(4) << "This would violate the outstanding async copy limit.";
2731 result_mark(Result::kFailOutOfAsyncCopies, result);
2732 continue;
2733 }
2734
2735 auto chunk_candidate = FindBestChunkCandidate(
2736 request, request.preferred_offset, &alternate_mem_interval);
2737 // Check if we could find a suitable chunk.
2738 if (chunk_candidate) {
2739 VLOG(3) << "Move the buffer to alternate memory at "
2740 << alternate_mem_interval.start
2741 << ". Offset = " << chunk_candidate->chunk.offset
2742 << ", size = " << chunk_candidate->chunk.size
2743 << ", heap_size = " << chunk_candidate->heap_size
2744 << ", prefetch picker = "
2745 << options_.prefetch_interval_picker->ToDebugString();
2746 AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
2747
2748 AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate,
2749 chunk_candidate->chunk, alternate_mem_interval.start,
2750 request.end_time, prefetch_end_time,
2751 request.allocation_value->allocation_sequence(),
2752 request.preferred_offset);
2753
2754 request.allocation_value->allocation_sequence()->back()->AddUse(
2755 request.use->hlo_use);
2756 return Result::kSuccess;
2757 }
2758 result_mark(Result::kFailOutOfMemory, result);
2759 }
2760 // If we didn't consider any prefetch intervals, then the live range was too
2761 // short.
2762 if (result == Result::kSuccess) {
2763 return Result::kFailLiveRangeTooShort;
2764 } else {
2765 return result;
2766 }
2767 }
2768
2769 absl::optional<AlternateMemoryBestFitHeap::ChunkCandidate>
2770 AlternateMemoryBestFitHeap::FindBestChunkCandidate(
2771 const AllocationRequest& request, const AliasedOffset* preferred_offset,
2772 BufferInterval* alternate_mem_interval) const {
2773 int64_t end_time = request.end_time;
2774 if (!preferred_offset) {
2775     // First find the earliest use that is at or later than the end time.
2776 const auto& use_times = request.all_use_times;
2777 auto use_time_it = use_times.begin();
2778 for (; *use_time_it < end_time; ++use_time_it) {
2779 }
2780 CHECK(use_time_it != use_times.end());
2781 int64_t earliest_use = *use_time_it;
2782
2783 // Then find the latest use that can be allocated contiguously without
2784 // copies.
2785 const Shape& shape = request.allocation_value->defining_position().shape();
2786 for (;
2787 (use_time_it + 1) != use_times.end() &&
2788 options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
2789 shape, *use_time_it, *(use_time_it + 1));
2790 ++use_time_it) {
2791 }
2792 CHECK(use_time_it != use_times.end());
2793 int64_t latest_contiguous_use_time = *use_time_it;
2794
2795     // Find a chunk that lives as long as possible, iterating in reverse over
2796     // the use times.
2797 for (; use_time_it >= use_times.begin() && *use_time_it >= end_time;
2798 --use_time_it) {
2799 alternate_mem_interval->end = *use_time_it;
2800 ChunkCandidate chunk_candidate =
2801 FindChunkCandidate(*alternate_mem_interval);
2802 if (chunk_candidate.heap_size <= available_heap_size()) {
2803 alternate_mem_interval->end = end_time;
2804 VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
2805 << ", latest contiguous use = " << latest_contiguous_use_time
2806 << ", use with available mem = " << *use_time_it
2807 << ", offset = " << chunk_candidate.chunk.offset;
2808 return chunk_candidate;
2809 }
2810 }
2811 alternate_mem_interval->end = end_time;
2812 return absl::nullopt;
2813 }
2814 // If a preferred offset is given, try to find an allocation at that offset
2815 // only.
2816 alternate_mem_interval->end = end_time;
2817 ChunkCandidate chunk_candidate =
2818 FindChunkCandidate(*alternate_mem_interval, preferred_offset->offset);
2819 if (chunk_candidate.chunk.offset == preferred_offset->offset) {
2820 return chunk_candidate;
2821 }
2822 return absl::nullopt;
2823 }
2824
2825 StatusOr<MemorySpaceAssignment::AsyncCopyStats>
2826 MemorySpaceAssignment::CalculateAsyncCopyStats() const {
2827 AsyncCopyStats stats;
2828 stats.max_outstanding_async_copies = 0;
2829 stats.num_prefetches = 0;
2830 stats.prefetch_bytes = 0;
2831 stats.num_evictions = 0;
2832 stats.eviction_bytes = 0;
2833 int64_t current_copies = 0;
2834 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
2835 HloDataflowAnalysis::Run(*module_));
2836 for (const HloComputation* computation :
2837 module_->MakeNonfusionComputations()) {
2838 for (HloInstruction* instruction : computation->instructions()) {
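      // Each kCopyStart opens an outstanding asynchronous copy; the matching
      // kCopyDone closes it and is counted as a prefetch or an eviction based
      // on its destination memory space.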
2839 if (instruction->opcode() == HloOpcode::kCopyStart) {
2840 current_copies++;
2841 } else if (instruction->opcode() == HloOpcode::kCopyDone) {
2842 current_copies--;
2843 int64_t size =
2844 options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction));
2845 if (instruction->shape().layout().memory_space() ==
2846 options_.alternate_memory_space) {
2847 ++stats.num_prefetches;
2848 stats.prefetch_bytes += size;
2849 } else {
2850 ++stats.num_evictions;
2851 stats.eviction_bytes += size;
2852 }
2853 }
2854 stats.max_outstanding_async_copies =
2855 std::max(stats.max_outstanding_async_copies, current_copies);
2856 }
2857 }
2858 return stats;
2859 }
2860
2861 /*static*/ MemorySpaceAssignment::BufferIntervalCompare
2862 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
2863 const MemorySpaceAssignmentCostAnalysis& cost_analysis,
2864 MemorySpaceAssignmentCostAnalysis::Cache* cache) {
2865 return [&cost_analysis, cache](const BufferInterval& x,
2866 const BufferInterval& y) {
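    // Prefer buffers with higher memory boundedness so the most
    // bandwidth-sensitive values are considered for alternate memory first.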
2867 float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
2868 float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
2869 if (x_memory_boundedness != y_memory_boundedness) {
2870 return x_memory_boundedness > y_memory_boundedness;
2871 }
2872 // Tie-break if the memory boundedness is the same.
2873 return GlobalDecreasingSizeBestFitHeap<
2874 HloValue>::GetSpatialBufferIntervalCompare()(x, y);
2875 };
2876 }
2877
2878 /*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
2879 MemorySpaceAssignment::Run(HloModule* module,
2880 const HloLiveRange& hlo_live_range,
2881 const HloAliasAnalysis& alias_analysis,
2882 const Options& options) {
2883 CHECK(module->has_schedule());
2884 VLOG(3) << "Module before memory space assignment: ";
2885 XLA_VLOG_LINES(3, module->ToString());
2886 VLOG(3) << "Schedule: " << module->schedule().ToString();
2887 MemorySpaceAssignment memory_space_assignment(module, options,
2888 hlo_live_range);
2889
2890 return memory_space_assignment.RunMemorySpaceAssignment(hlo_live_range,
2891 alias_analysis);
2892 }
2893
2894 StatusOr<std::unique_ptr<PresetAssignments>>
2895 MemorySpaceAssignment::RunMemorySpaceAssignment(
2896 const HloLiveRange& hlo_live_range,
2897 const HloAliasAnalysis& alias_analysis) {
2898 TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis));
2899
2900 if (options_.cost_analysis) {
2901 float estimated_time =
2902 ComputeEstimatedElapsedTime(hlo_live_range, allocations_);
2903 VLOG(1) << "Estimated elapsed time (sec): " << estimated_time;
2904 }
2905
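  // Lower the allocation decisions into the HLO graph: insert the asynchronous
  // copies, simplify the graph, rebuild the schedule, and export the chunk
  // assignments for the alternate memory space.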
2906 TF_RETURN_IF_ERROR(Process());
2907 ScheduleAsynchronousCopies();
2908 TF_RETURN_IF_ERROR(SimplifyGraph());
2909 TF_RETURN_IF_ERROR(FixSchedule());
2910 TF_RETURN_IF_ERROR(ExportAndColorBuffers());
2911
2912 VLOG(3) << "Module after memory space assignment: ";
2913 XLA_VLOG_LINES(3, module_->ToString());
2914 TF_CHECK_OK(module_->schedule().Verify());
2915 TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats());
2916 VLOG(1) << "Maximum number of outstanding async copies: "
2917 << stats.max_outstanding_async_copies;
2918 VLOG(1) << "Number of prefetches: " << stats.num_prefetches
2919 << ", in bytes: " << stats.prefetch_bytes;
2920 VLOG(1) << "Number of evictions: " << stats.num_evictions
2921 << ", in bytes: " << stats.eviction_bytes;
2922
2923 TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace());
2924
2925 return std::move(preset_assignments_);
2926 }
2927
2928 Status MemorySpaceAssignment::FindAllocationSequence(
2929 const HloLiveRange& hlo_live_range,
2930 const HloAliasAnalysis& alias_analysis) {
2931 auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
2932 &allocations_, options_, alias_analysis, hlo_live_range);
2933
2934 HeapSimulator::Options heap_simulator_options;
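  // Do not let the heap simulator share operand and output buffers, and
  // include constants in the simulation so they are considered for placement
  // in the alternate memory as well.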
2935 heap_simulator_options.may_reuse_operand_buffers = false;
2936 heap_simulator_options.alloc_constants = true;
2937 TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_,
2938 module_->schedule(), alias_analysis,
2939 options_.size_fn,
2940 heap_simulator_options)
2941 .status());
2942 return Status::OK();
2943 }
2944
2945 void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
2946 HloInstruction* operand =
2947 use.instruction->mutable_operand(use.operand_number);
2948   // If the operand is a tuple, look inside the tuple to find the actual use.
2949 for (int64_t index : use.operand_index) {
2950 if (operand->opcode() != HloOpcode::kTuple) {
2951 break;
2952 }
2953 operand = operand->mutable_operand(index);
2954 }
2955
2956 // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts.
2957 std::function<HloInstruction*(HloInstruction*)> get_simplified_operand;
2958 get_simplified_operand = [&](HloInstruction* instruction) {
2959 while (instruction->opcode() == HloOpcode::kGetTupleElement) {
2960 HloInstruction* operand =
2961 get_simplified_operand(instruction->mutable_operand(0));
2962 if (operand->opcode() == HloOpcode::kTuple) {
2963 instruction = operand->mutable_operand(instruction->tuple_index());
2964 } else {
2965 return instruction;
2966 }
2967 }
2968 return instruction;
2969 };
2970 operand = get_simplified_operand(operand);
2971
2972 uses_.push_back(use);
2973 }
2974
2975 float MemorySpaceAssignment::ComputeEstimatedElapsedTime(
2976 const HloLiveRange& hlo_live_range, const AllocationSequence& allocations) {
2977 absl::flat_hash_map<const HloInstruction*, std::vector<ShapeIndex>>
2978 outputs_in_alternate_memory_map;
2979 absl::flat_hash_map<const HloInstruction*,
2980                       std::vector<std::pair<int64_t, ShapeIndex>>>
2981 operands_in_alternate_memory_map;
2982
2983 for (auto& allocation : allocations) {
2984 if (!allocation->is_copy_allocation()) {
2985 if (allocation->memory_space() == MemorySpace::kAlternate) {
2986 const HloInstruction* defining_instruction =
2987 allocation->defining_position().instruction;
2988 outputs_in_alternate_memory_map[defining_instruction].push_back(
2989 allocation->defining_position().index);
2990 }
2991 }
2992 for (auto& hlo_use : allocation->uses()) {
2993 const HloInstruction* use_instruction = hlo_use.instruction;
2994 operands_in_alternate_memory_map[use_instruction].push_back(
2995 std::make_pair(hlo_use.operand_number, hlo_use.operand_index));
2996 }
2997 }
2998
2999 const auto& instruction_sequence =
3000 hlo_live_range.flattened_instruction_sequence().instructions();
3001 float total_elapsed = 0.0;
3002 for (const HloInstruction* instruction : instruction_sequence) {
3003 std::vector<ShapeIndex> outputs_in_alternate_memory;
3004 auto output_it = outputs_in_alternate_memory_map.find(instruction);
3005 if (output_it != outputs_in_alternate_memory_map.end()) {
3006 outputs_in_alternate_memory = output_it->second;
3007 }
3008 std::vector<std::pair<int64_t, ShapeIndex>> operands_in_alternate_memory;
3009 auto operand_it = operands_in_alternate_memory_map.find(instruction);
3010 if (operand_it != operands_in_alternate_memory_map.end()) {
3011 operands_in_alternate_memory = operand_it->second;
3012 }
3013 float instruction_elapsed =
3014 options_.cost_analysis->GetInstructionElapsedInAlternateMemory(
3015 *instruction, operands_in_alternate_memory,
3016 outputs_in_alternate_memory);
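    // Weight the elapsed time of instructions inside while loops by an assumed
    // per-nesting-level execution count.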
3017 float while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
3018 options_.xla_tpu_memory_space_assignment_while_execution_count,
3019 options_.cost_analysis->CalculateComputationNestLevel(
3020 instruction,
3021 /*while_only=*/true));
3022 total_elapsed += while_nest_multiplier * instruction_elapsed;
3023 }
3024 return total_elapsed;
3025 }
3026
3027 Status MemorySpaceAssignment::Allocation::Process() {
3028 if (is_scoped_allocation()) {
3029 // Nothing to do here for scoped allocations.
3030 return Status::OK();
3031 }
3032 HloInstruction* producing_instruction = AddGetTupleElements();
3033 HloComputation* computation = producing_instruction->parent();
3034 for (const HloUse& use : uses_) {
3035 Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
3036 HloInstruction* replacement_instruction = producing_instruction;
3037 if (operand_shape.IsTuple()) {
3038 TF_ASSIGN_OR_RETURN(
3039 replacement_instruction,
3040 ReplaceTupleWith(producing_instruction,
3041 use.instruction->mutable_operand(use.operand_number),
3042 use.operand_index));
3043 } else if (operand_shape != producing_instruction->shape()) {
3044 VLOG(4) << "Old shape = " << operand_shape.ToString()
3045 << ", new shape = " << producing_instruction->shape().ToString()
3046 << "; inserting a bitcast.";
3047 replacement_instruction = computation->AddInstruction(
3048 HloInstruction::CreateBitcast(operand_shape, producing_instruction));
3049 }
3050 TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
3051 use.operand_number, replacement_instruction));
3052 }
3053 return Status::OK();
3054 }
3055
3056 StatusOr<HloInstruction*> MemorySpaceAssignment::Allocation::ReplaceTupleWith(
3057 HloInstruction* new_instruction, HloInstruction* tuple,
3058 ShapeIndex shape_index) {
3059 const Shape& tuple_shape = tuple->shape();
3060 CHECK(tuple->shape().IsTuple())
3061 << "ReplaceTupleWith was called for a non-tuple. Tuple = "
3062 << tuple->ToString()
3063 << ", new_instruction = " << new_instruction->ToString()
3064 << ", shape_index = " << shape_index.ToString();
3065
3066 HloComputation* computation = new_instruction->parent();
3067 std::vector<HloInstruction*> tuple_args(tuple_shape.tuple_shapes_size());
3068 CHECK_GE(tuple_shape.tuple_shapes_size(), shape_index[0]);
3069 for (int64_t i = 0; i < tuple_shape.tuple_shapes_size(); ++i) {
3070 const Shape& subshape = tuple_shape.tuple_shapes(i);
3071     // If `tuple` is a kTuple instruction, we can reuse its operands directly
3072     // when constructing the new tuple, which improves compile-time
3073     // performance.
3074 auto get_operand = [&]() {
3075 if (tuple->opcode() == HloOpcode::kTuple) {
3076 return tuple->mutable_operand(i);
3077 } else {
3078 return computation->AddInstruction(
3079 HloInstruction::CreateGetTupleElement(subshape, tuple, i));
3080 }
3081 };
3082 if (i == shape_index[0]) {
3083 // If the subshape is still a tuple, recurse and pass a new shape index
3084 // for the one level deeper.
3085 if (subshape.IsTuple()) {
3086 TF_ASSIGN_OR_RETURN(tuple_args[i],
3087 ReplaceTupleWith(new_instruction, get_operand(),
3088 ShapeIndex(shape_index.begin() + 1,
3089 shape_index.end())));
3090 } else {
3091 if (subshape != new_instruction->shape()) {
3092 VLOG(4) << "Old shape = " << subshape.ToString()
3093 << ", new shape = " << new_instruction->shape().ToString()
3094 << "; inserting a bitcast.";
3095 new_instruction = computation->AddInstruction(
3096 HloInstruction::CreateBitcast(subshape, new_instruction));
3097 } else if (tuple->opcode() == HloOpcode::kTuple &&
3098 tuple->operand(i) == new_instruction) {
3099           // If the tuple element is already the new instruction, we don't
3100           // have to create a new tuple; just return the original tuple.
3102 VLOG(4) << "Tuple already contains the new instruction = "
3103 << new_instruction->ToShortString()
3104 << " tuple = " << tuple->ToShortString();
3105 return tuple;
3106 }
3107 tuple_args[i] = new_instruction;
3108 }
3109 } else {
3110 tuple_args[i] = get_operand();
3111 }
3112 }
3113 if (shape_index[0] == tuple_shape.tuple_shapes_size()) {
3114 // If shape_index[0] is equal to the tuple shape size, add the new
3115 // instruction as an additional argument.
3116 tuple_args.push_back(new_instruction);
3117 }
3118 return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args));
3119 }
3120
3121 HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() const {
3122 HloInstruction* producing_instruction = defining_position().instruction;
3123 CHECK_NE(producing_instruction, nullptr);
3124
3125 Shape shape = defining_position().shape();
3126 CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = "
3127 << shape.ToString()
3128                          << " position = " << defining_position().ToString();
3129 HloComputation* computation = producing_instruction->parent();
3130
3131   // If the defining position is inside a tuple, walk down its shape index and
3132   // search for (or create) the kGetTupleElement instructions that extract the
3133   // array value, since asynchronous copies only support array types.
3134 for (int64_t index : defining_position().index) {
3135 // We first search if there already is a get-tuple-element with the correct
3136 // index. If there is no such get-tuple-element, we create one.
3137 auto gte_it = absl::c_find_if(
3138 producing_instruction->users(), [index](const HloInstruction* use) {
3139 return use != use->parent()->root_instruction() &&
3140 use->opcode() == HloOpcode::kGetTupleElement &&
3141 use->tuple_index() == index;
3142 });
3143 if (gte_it != producing_instruction->users().end()) {
3144 producing_instruction = *gte_it;
3145 } else {
3146 producing_instruction =
3147 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
3148 producing_instruction->shape().tuple_shapes(index),
3149 producing_instruction, index));
3150 }
3151 }
3152 return producing_instruction;
3153 }
3154
3155 std::string MemorySpaceAssignment::Allocation::ToString() const {
3156 std::string memory_space_str = "def";
3157 if (memory_space_ == MemorySpace::kAlternate) {
3158 memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
3159 }
3160 return absl::StrCat((is_scoped_allocation() ? "Scoped " : ""),
3161 "Allocation in ", memory_space_str, " defined at ",
3162 defining_position_.ToString());
3163 }
3164
3165 std::string MemorySpaceAssignment::CopyAllocation::ToString() const {
3166 std::string memory_space_str = "def";
3167 if (memory_space_ == MemorySpace::kAlternate) {
3168 memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
3169 }
3170 return absl::StrCat("Copy Allocation in ", memory_space_str, " from ",
3171 prev_allocation_.ToString());
3172 }
3173
3174 std::string MemorySpaceAssignment::ParentAllocation::ToString() const {
3175 return absl::StrCat("Parent Allocation mirrored at ",
3176 defining_position_.ToString(), ", originally ",
3177 original_allocation_.ToString());
3178 }
3179
3180 Status MemorySpaceAssignment::CopyAllocation::Process() {
3181 // Copy allocations need to insert asynchronous copy nodes.
3182 Shape shape = defining_position().shape();
3183 HloInstruction* producing_instruction = AddGetTupleElements();
3184 HloComputation* computation = producing_instruction->parent();
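  // CopyStart produces a tuple containing the destination and source buffers
  // plus a context scalar; CopyDone waits for the transfer to finish and
  // yields the destination buffer.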
3185 copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart(
3186 ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
3187 producing_instruction, is_cross_program_prefetch_));
3188 copy_done_ = computation->AddInstruction(
3189 HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
3190 VLOG(4) << "Created " << copy_start_->name()
3191 << " for position: " << defining_position().ToString();
3192 // Update the allocation position with the copy done instruction so that if
3193 // there are further copies from it, it can find the correct position.
3194 defining_position_ = HloPosition{copy_done_, {}};
3195
3196 // Replace all the uses with the new copy instruction.
3197 for (HloUse use : uses_) {
3198 // If the operand is a tuple, we need to descend to the actual instruction
3199 // we want to replace.
3200 HloInstruction* replacement_instruction;
3201 Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
3202 if (operand_shape.IsTuple()) {
3203 TF_ASSIGN_OR_RETURN(
3204 replacement_instruction,
3205 ReplaceTupleWith(copy_done_,
3206 use.instruction->mutable_operand(use.operand_number),
3207 use.operand_index));
3208 } else if (operand_shape != copy_done_->shape()) {
3209 VLOG(4) << "Old shape = " << operand_shape.ToString()
3210 << ", new shape = " << copy_done_->shape().ToString()
3211 << "; inserting a bitcast.";
3212 replacement_instruction = computation->AddInstruction(
3213 HloInstruction::CreateBitcast(operand_shape, copy_done_));
3214 } else {
3215 replacement_instruction = copy_done_;
3216 }
3217 TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
3218 use.operand_number, replacement_instruction));
3219 }
3220
3221 return Status::OK();
3222 }
3223
3224 Status MemorySpaceAssignment::ParentAllocation::Process() {
3225 // Add an additional parameter to the while HLO with a reference to the buffer
3226 // in the default memory space.
3227 HloInstruction* producing_instruction =
3228 original_allocation_.AddGetTupleElements();
3229 int64_t new_tuple_index = calling_instruction_->shape().tuple_shapes_size();
3230
3231 TF_ASSIGN_OR_RETURN(HloInstruction * new_while_operand,
3232 ReplaceTupleWith(producing_instruction,
3233 calling_instruction_->mutable_operand(0),
3234 {new_tuple_index}));
3235 TF_RETURN_IF_ERROR(calling_instruction_->ReplaceOperandWithDifferentShape(
3236 0, new_while_operand));
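  // The while instruction and the parameters of its condition and body
  // computations must all be updated to the new tuple shape that carries the
  // extra element.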
3237 *calling_instruction_->mutable_shape() = new_while_operand->shape();
3238 *calling_instruction_->while_condition()
3239 ->parameter_instruction(0)
3240 ->mutable_shape() = new_while_operand->shape();
3241 *calling_instruction_->while_body()
3242 ->parameter_instruction(0)
3243 ->mutable_shape() = new_while_operand->shape();
3244 defining_position_.index = {new_tuple_index};
3245 return Allocation::Process();
3246 }
3247
3248 Status MemorySpaceAssignment::ParentAllocation::PostProcess() {
3249   // Update the root of the while body with the new parameter. We need a
3250   // separate post-process step for this because other allocations may use the
3251   // while body root, and they would update the old root instead of the new
3252   // one. Doing the post-process step later ensures the root has been updated
3253   // with the other changes, so we can safely add the additional parameter.
3254 HloComputation* while_body = calling_instruction_->while_body();
3255 TF_ASSIGN_OR_RETURN(
3256 HloInstruction * new_while_body_root,
3257 ReplaceTupleWith(AddGetTupleElements(), while_body->root_instruction(),
3258 defining_position_.index));
3259 while_body->set_root_instruction(new_while_body_root,
3260 /*accept_different_shape=*/true);
3261 return Status::OK();
3262 }
3263
3264 void MemorySpaceAssignment::Allocation::MarkIfNeeded(
3265 absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3266 MarkNeeded(needed_allocations);
3267 }
3268
3269 void MemorySpaceAssignment::Allocation::MarkNeeded(
3270 absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3271 needed_allocations.insert(this);
3272 }
3273
3274 void MemorySpaceAssignment::CopyAllocation::MarkNeeded(
3275 absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3276 needed_allocations.insert(this);
3277 prev_allocation_.MarkNeeded(needed_allocations);
3278 }
3279
3280 void MemorySpaceAssignment::ParentAllocation::MarkIfNeeded(
3281 absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3282 // Parent allocations are only needed if they have any uses or if there is a
3283 // copy allocation that copies this value (in that case, the copy allocation
3284 // will call this allocation's MarkNeeded function).
3285 if (!uses_.empty()) {
3286 MarkNeeded(needed_allocations);
3287 }
3288 }
3289
3290 void MemorySpaceAssignment::ParentAllocation::MarkNeeded(
3291 absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3292 needed_allocations.insert(this);
3293 original_allocation_.MarkNeeded(needed_allocations);
3294 }
3295
3296 Status MemorySpaceAssignment::Process() {
3297 VLOG(1) << "Processing assigned buffers...";
3298   // Some parent allocations may not be needed (e.g. when they don't have any
3299   // uses and no other (non-parent) allocation depends on them), so before we
3300   // process the allocations, mark all allocations that are needed.
3302 absl::flat_hash_set<const Allocation*> needed_allocations;
3303 for (auto& allocation : allocations_) {
3304 allocation->MarkIfNeeded(needed_allocations);
3305 }
3306 // Insert CopyStart/CopyDone pairs.
3307 for (auto& allocation : allocations_) {
3308 VLOG(3) << "Processing: " << allocation->ToString();
3309 if (!needed_allocations.contains(allocation.get())) {
3310 VLOG(3) << "Allocation not needed.";
3311 continue;
3312 }
3313 TF_RETURN_IF_ERROR(allocation->Process());
3314 // Add the offset and size of the allocation in the alternate memory to
3315 // the output map.
3316 if (allocation->is_scoped_allocation()) {
3317 CHECK(allocation->memory_space() == MemorySpace::kAlternate);
3318 scoped_memory_assignments_.emplace_back(
3319 allocation->defining_position().instruction, allocation->chunk());
3320 alternate_memory_size_ =
3321 std::max(alternate_memory_size_, allocation->chunk().chunk_end());
3322 } else if (allocation->memory_space() == MemorySpace::kAlternate) {
3323 alternate_memory_assignments_.emplace_back(
3324 allocation->defining_position(), allocation->chunk());
3325 alternate_memory_size_ =
3326 std::max(alternate_memory_size_, allocation->chunk().chunk_end());
3327 }
3328 }
3329 // Post-process allocations. This is only used for parent allocations where we
3330 // update the body root with a reference to the buffer in default memory
3331 // space.
3332 for (auto& allocation : allocations_) {
3333 if (needed_allocations.contains(allocation.get())) {
3334 VLOG(3) << "Post-Processing: " << allocation->ToString();
3335 TF_RETURN_IF_ERROR(allocation->PostProcess());
3336 }
3337 }
3338 return Status::OK();
3339 }
3340
3341 Status MemorySpaceAssignment::ExportAndColorBuffers() {
3342 VLOG(1) << "Exporting buffers...";
3343 TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
3344 absl::flat_hash_map<int64, int64> seen_buffer_offsets;
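  // Positions that alias the same HloBuffer must be assigned the same offset;
  // remember the offset already exported for each buffer id so we can verify
  // this and avoid exporting duplicate chunks.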
3345 VLOG(3) << "Exported alternate memory allocations:";
3346 for (const auto& position_and_chunk : alternate_memory_assignments_) {
3347 const HloPosition& defining_position = position_and_chunk.first;
3348 const Chunk& chunk = position_and_chunk.second;
3349 const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(
3350 defining_position.instruction, defining_position.index);
3351 auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id());
3352 if (seen_buffer_offset_it != seen_buffer_offsets.end()) {
3353 CHECK_EQ(chunk.offset, seen_buffer_offset_it->second)
3354 << "Mismatch in offset for positions that map to the same value: "
3355 << buffer.ToString() << ", pos: " << defining_position.ToString();
3356 } else {
3357 VLOG(3) << " [" << chunk.offset << ", " << chunk.size
3358 << "] : " << defining_position.ToString() << " ("
3359 << buffer.ToString() << ")";
3360 preset_assignments_->add_chunk(defining_position, chunk);
3361 seen_buffer_offsets[buffer.id()] = chunk.offset;
3362 }
3363 }
3364
3365 VLOG(3) << "Exported scoped allocations in alternate memory:";
3366 for (const auto& instruction_and_chunk : scoped_memory_assignments_) {
3367 HloInstruction* instruction = instruction_and_chunk.first;
3368 const Chunk& chunk = instruction_and_chunk.second;
3369 VLOG(3) << " [" << chunk.offset << ", " << chunk.size
3370 << "] : " << instruction->name();
3371 preset_assignments_->add_scoped_allocation_chunk(instruction, chunk);
3372 }
3373
3374 if (!preset_assignments_->chunks().empty() ||
3375 !preset_assignments_->scoped_allocation_chunks().empty()) {
3376 preset_assignments_
3377 ->assignment_information_for_space(options_.alternate_memory_space)
3378 ->size = alternate_memory_size_;
3379 }
3380
3381 VLOG(3) << "Exported alternate memory sizes:";
3382 for (auto& pair : preset_assignments_->assignment_informations()) {
3383 VLOG(3) << " space: " << pair.first << ", size: " << pair.second.size;
3384 }
3385
3386 VLOG(1) << "Coloring buffers...";
3387 // Color the pending positions and all of their aliased buffers.
3388 for (const auto& defining_position_and_chunk :
3389 preset_assignments_->chunks()) {
3390 const HloPosition& defining_position = defining_position_and_chunk.first;
3391 for (auto& buffer : alias_analysis->ComputeBuffersAt(
3392 defining_position.instruction, defining_position.index)) {
3393 for (auto& value : buffer->values()) {
3394 for (auto& position : value->positions()) {
3395 VLOG(4) << "Coloring " << position.ToString();
3396 Shape* shape = ShapeUtil::GetMutableSubshape(
3397 position.instruction->mutable_shape(), position.index);
3398 CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
3399 << position.ToString();
3400 shape->mutable_layout()->set_memory_space(
3401 options_.alternate_memory_space);
3402 }
3403 }
3404 }
3405 }
3406 return Status::OK();
3407 }
3408
3409 void MemorySpaceAssignment::RemoveAssignmentForInstruction(
3410 const HloInstruction* instruction) {
3411 for (auto& position_and_chunk : alternate_memory_assignments_) {
3412 const HloPosition& position = position_and_chunk.first;
3413 if (position.instruction == instruction) {
3414 VLOG(3) << "Removing instruction from alternate memory assignments.";
3415 // Swap the removed position and chunk with the back and pop back.
3416 position_and_chunk = alternate_memory_assignments_.back();
3417 alternate_memory_assignments_.pop_back();
3418 break;
3419 }
3420 }
3421 }
3422
3423 Status MemorySpaceAssignment::SimplifyGraph() {
3424 VLOG(1) << "Simplifying graph...";
3425 for (HloComputation* computation : module_->MakeNonfusionComputations()) {
3426 // Parallel computations aren't in the schedule and don't need to be
3427 // modified.
3428 if (!computations_in_schedule_.contains(computation)) {
3429 VLOG(4) << "Not simplifying " << computation->name()
3430 << " because it's not in the schedule.";
3431 continue;
3432 }
3433 // Drop control dependencies. Since the computation is already scheduled, we
3434 // don't need control dependencies anymore, and having control
3435 // predecessors/successors prevents us from removing instructions without
3436 // users (HloComputation::IsSafelyRemovable returns false if there are
3437 // control dependencies).
3438 for (HloInstruction* instruction :
3439 computation->MakeInstructionPostOrder()) {
3440 TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
3441 }
3442 // We perform limited DCE and forward the tuple operand in patterns like
3443 // GetTupleElement(Tuple(a, b), 0). This is mostly because memory space
3444     // assignment is run late in compilation (after DCE and arithmetic
3445 // simplification passes) and we don't want to generate redundant code. Run
3446 // to fixed point.
3447 bool computation_modified = true;
3448 while (computation_modified) {
3449 computation_modified = false;
3450 VLOG(4) << "Running simplify graph loop over " << computation->name();
3451 for (HloInstruction* instruction :
3452 computation->MakeInstructionPostOrder()) {
3453 if (computation->IsSafelyRemovable(instruction) &&
3454 instruction->user_count() == 0 && !instruction->HasSideEffect() &&
3455 instruction != computation->root_instruction() &&
3456 instruction->opcode() != HloOpcode::kCopyStart &&
3457 instruction->opcode() != HloOpcode::kCopyDone) {
3458 VLOG(4) << "Instruction removed: " << instruction->ToString();
3459 // Ensure the alternate memory assignments don't contain a reference
3460 // to the removed instruction.
3461 RemoveAssignmentForInstruction(instruction);
3462 // Instead of deleting the instruction from the schedule, replace it
3463 // with a nullptr. This is needed because FixSchedule relies on the
3464 // logical time that is the index into flattened_instructions_ for
3465 // scheduling asynchronous copies.
3466 auto instruction_it =
3467 absl::c_find(flattened_instructions_, instruction);
3468 if (instruction_it != flattened_instructions_.end()) {
3469 *instruction_it = nullptr;
3470 }
3471 TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
3472 computation_modified = true;
3473 } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
3474 HloInstruction* operand = instruction->mutable_operand(0);
3475 if (operand->opcode() == HloOpcode::kTuple) {
3476 HloInstruction* forwarded_instruction =
3477 operand->mutable_operand(instruction->tuple_index());
3478 VLOG(4) << "Replacing uses of " << instruction->ToString()
3479 << " with " << forwarded_instruction->ToString();
3480 TF_RETURN_IF_ERROR(
3481 instruction->ReplaceAllUsesWith(forwarded_instruction));
3482 computation_modified = true;
3483 }
3484 } else if (instruction->opcode() == HloOpcode::kTuple) {
3485 // Replace Tuple(GetTupleElement(x), ..., GetTupleElement(x)) pattern
3486 // with x.
3487 bool can_replace =
3488 instruction->operand_count() > 0 &&
3489 instruction->operand(0)->opcode() ==
3490 HloOpcode::kGetTupleElement &&
3491 instruction->operand(0)
3492 ->operand(0)
3493 ->shape()
3494 .tuple_shapes_size() == instruction->operand_count();
3495 for (int operand_number = 0;
3496 operand_number < instruction->operand_count();
3497 ++operand_number) {
3498 const HloInstruction* operand =
3499 instruction->operand(operand_number);
3500 if (operand->opcode() != HloOpcode::kGetTupleElement ||
3501 operand->tuple_index() != operand_number ||
3502 operand->operand(0) != instruction->operand(0)->operand(0)) {
3503 can_replace = false;
3504 break;
3505 }
3506 }
3507 if (can_replace) {
3508 HloInstruction* forwarded_instruction =
3509 instruction->mutable_operand(0)->mutable_operand(0);
3510 VLOG(4) << "Replacing uses of " << instruction->ToString()
3511 << " with " << forwarded_instruction->ToString();
3512 TF_RETURN_IF_ERROR(
3513 instruction->ReplaceAllUsesWith(forwarded_instruction));
3514 computation_modified = true;
3515 }
3516 }
3517 }
3518 }
3519 }
3520
3521 return Status::OK();
3522 }
3523
3524 void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted(
3525 HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
3526 absl::flat_hash_set<HloInstruction*>* inserted_instructions) const {
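  // Recursively insert operands before their users so the rebuilt sequence
  // stays topologically sorted; already-inserted instructions are skipped.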
3527 if (inserted_instructions->contains(new_instruction)) {
3528 return;
3529 }
3530 for (HloInstruction* operand : new_instruction->operands()) {
3531     // CopyStart/CopyDone dependencies should already have been inserted; it is
3532     // a red flag if they haven't been.
3533 CHECK((operand->opcode() != HloOpcode::kCopyStart &&
3534 operand->opcode() != HloOpcode::kCopyDone) ||
3535 inserted_instructions->contains(operand))
3536 << "Inserted instruction " << new_instruction->ToString()
3537 << " has un-inserted dependency: " << operand->ToString();
3538 EnsureInstructionAndOperandsInserted(operand, new_sequence,
3539 inserted_instructions);
3540 }
3541 VLOG(4) << "inserting: " << new_instruction->ToShortString();
3542 new_sequence->push_back(new_instruction);
3543 inserted_instructions->insert(new_instruction);
3544 }
3545
3546 void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
3547 VLOG(1) << "Scheduling asynchronous copies...";
3548 for (MemorySpace memory_space :
3549 {MemorySpace::kDefault, MemorySpace::kAlternate}) {
3550 std::vector<CopyAllocation*> copy_allocations;
3551 for (auto& allocation : allocations_) {
3552 if (allocation->is_copy_allocation()) {
3553 auto copy_allocation = static_cast<CopyAllocation*>(allocation.get());
3554 if (copy_allocation->memory_space() == memory_space) {
3555 copy_allocations.push_back(copy_allocation);
3556 }
3557 }
3558 }
3559
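    // Sort the copies by the time their copy-done must be scheduled before,
    // breaking ties by the copy-start time, so they are emitted in a
    // deterministic order.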
3560 absl::c_stable_sort(
3561 copy_allocations, [](CopyAllocation* first, CopyAllocation* second) {
3562 return std::forward_as_tuple(first->copy_done_schedule_before(),
3563 first->copy_start_schedule_after()) <
3564 std::forward_as_tuple(second->copy_done_schedule_before(),
3565 second->copy_start_schedule_after());
3566 });
3567 for (CopyAllocation* copy_allocation : copy_allocations) {
3568 // If the copy start doesn't happen to be scheduled at the correct
3569 // computation, delay it until the correct computation starts.
3570 int64_t copy_start_schedule_after =
3571 copy_allocation->copy_start_schedule_after();
3572       // Accessing flattened_instructions_ elements without a nullptr check is
3573       // safe because this runs before SimplifyGraph nulls out removed entries.
3574 while (copy_allocation->defining_position().instruction->parent() !=
3575 flattened_instructions_[copy_start_schedule_after]->parent()) {
3576 VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
3577 << (copy_start_schedule_after + 1) << ") for "
3578 << copy_allocation->copy_start()->ToString()
3579 << " because it is not in the correct computation.";
3580 copy_allocation->set_copy_start_schedule_after(
3581 ++copy_start_schedule_after);
3582 }
3583
3584 schedule_after_[copy_allocation->copy_start_schedule_after()].push_back(
3585 copy_allocation->copy_start());
3586 schedule_before_[copy_allocation->copy_done_schedule_before()].push_back(
3587 copy_allocation->copy_done());
3588 }
3589 }
3590 }
3591
3592 Status MemorySpaceAssignment::FixSchedule() {
3593 VLOG(1) << "Fixing schedule...";
3594 CHECK(module_->has_schedule());
3595 HloSchedule& schedule = module_->schedule();
3596 for (const HloComputation* computation :
3597 module_->MakeNonfusionComputations()) {
3598 // Parallel computations aren't in the schedule and don't need to be
3599 // modified.
3600 if (!computations_in_schedule_.contains(computation)) {
3601 VLOG(4) << "Not scheduling " << computation->name()
3602 << " because it's not in the schedule.";
3603 continue;
3604 }
3605 CHECK(schedule.is_computation_scheduled(computation));
3606 HloInstructionSequence new_sequence;
3607
3608 absl::flat_hash_set<HloInstruction*> inserted_instructions;
3609
3610 VLOG(4) << "Scheduling: " << computation->ToString();
3611
3612 for (int64_t instruction_index = 0;; ++instruction_index) {
3613 auto insts_before_iter = schedule_before_.find(instruction_index);
3614 if (insts_before_iter != schedule_before_.end()) {
3615 for (HloInstruction* new_instruction : insts_before_iter->second) {
3616 if (new_instruction->parent() == computation) {
3617 VLOG(4) << "before " << instruction_index << ": "
3618 << new_instruction->name();
3619 EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
3620 &inserted_instructions);
3621 }
3622 }
3623 }
3624 // We allow scheduling copy dones past the root instruction (for
3625 // end-of-program cross-program prefetch). So the loop exit condition is
3626 // actually here.
3627 if (instruction_index >= flattened_instructions_.size()) {
3628 break;
3629 }
3630 HloInstruction* instruction = flattened_instructions_[instruction_index];
3631 // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
3632 // it was deleted) and not previously inserted. Also bitcasts and tuples
3633 // are treated specially and only inserted as a result of operand
3634 // dependencies.
3635 if (instruction != nullptr &&
3636 !inserted_instructions.contains(instruction) &&
3637 instruction->parent() == computation &&
3638 instruction->opcode() != HloOpcode::kBitcast &&
3639 instruction->opcode() != HloOpcode::kTuple) {
3640 VLOG(4) << "inst " << instruction_index << ": " << instruction->name();
3641 EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
3642 &inserted_instructions);
3643 }
3644 auto insts_after_iter = schedule_after_.find(instruction_index);
3645 if (insts_after_iter != schedule_after_.end()) {
3646 for (HloInstruction* new_instruction : insts_after_iter->second) {
3647 if (new_instruction->parent() == computation) {
3648 VLOG(4) << "after " << instruction_index << ": "
3649 << new_instruction->name();
3650 EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
3651 &inserted_instructions);
3652 }
3653 }
3654 }
3655 }
3656 // For rare cases where the original sequence is empty, ensure the root
3657 // instruction and its dependencies are scheduled.
3658 EnsureInstructionAndOperandsInserted(computation->root_instruction(),
3659 &new_sequence, &inserted_instructions);
3660 CHECK_EQ(new_sequence.size(), computation->instruction_count())
3661 << "New sequence for computation " << computation->name() << " has "
3662 << new_sequence.size() << " instructions, expects "
3663 << computation->instruction_count() << ".";
3664 schedule.set_sequence(computation, new_sequence);
3665 }
3666
3667 return Status::OK();
3668 }
3669
3670 Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
3671 VLOG(1) << "Verifying...";
3672 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
3673 HloAliasAnalysis::Run(module_));
3674 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
3675 HloLiveRange::Run(module_->schedule(), *alias_analysis,
3676 module_->entry_computation()));
3677
3678 BufferIntervalTree interval_tree;
3679 absl::flat_hash_set<int64> seen_buffers;
3680   // The key for events is (time, is_free, value_id), so that events are
3681   // sorted first by time, then allocations before frees within the same time,
3682   // and finally by value id as a tie breaker.
3683 std::map<std::tuple<int64, bool, int64>,
3684 std::tuple<const HloValue*, Chunk, HeapSimulatorTrace::Event::Kind>>
3685 events;
3686
3687 auto add_allocation_and_verify = [&](int64_t start_time, int64_t end_time,
3688 const Chunk& chunk,
3689 const HloValue* value) {
3690 events[std::make_tuple(start_time, /*is_free=*/false, value->id())] =
3691 std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC);
3692 events[std::make_tuple(end_time, /*is_free=*/true, value->id())] =
3693 std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE);
3694
3695     // Get the chunks overlapping in time and check whether they also overlap
3696     // in space.
3697 // TODO(berkin): For now checking against end_time - 1 (exclusive), but we
3698 // really should check against end_time (inclusive) for cases where the
3699 // operand can't share buffer with user (see
3700 // HloDataflowAnalysis::CanShareOperandBufferWithUser).
3701 for (const Chunk& overlapping_chunk :
3702 interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
3703 if (chunk.OverlapsWith(overlapping_chunk)) {
3704 return InternalError(
3705 ("Value %s (%d, %d) off: %d size: %d overlaps with another chunk"
3706 " off: %d size: %d"),
3707 value->ToShortString(), start_time, end_time, chunk.offset,
3708 chunk.size, overlapping_chunk.offset, overlapping_chunk.size);
3709 }
3710 }
3711 interval_tree.Add(start_time, end_time - 1, chunk);
3712 return Status::OK();
3713 };
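  // Illustrative (hypothetical) example: if one value was verified at offset 0,
  // size 128 over times (2, 10) and another value is later added at offset 64,
  // size 128 over times (5, 8), the chunks overlap in both space and time and
  // the lambda above returns an InternalError; a disjoint offset (e.g. 128) or
  // a non-overlapping live range would pass.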
3714
3715 // Go through all instructions in the module to ensure CopyStart/CopyDone
3716 // instructions copy between alternate memory and default memory.
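  // A kCopyStart result is a tuple of (destination, source, context), so shape
  // index {0} carries the destination layout and index {1} the source layout.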
3717 for (const HloComputation* computation :
3718 module_->MakeNonfusionComputations()) {
3719 for (const HloInstruction* instruction : computation->instructions()) {
3720 if (instruction->opcode() == HloOpcode::kCopyStart) {
3721 int64_t from_memory_space =
3722 ShapeUtil::GetSubshape(instruction->shape(), {1})
3723 .layout()
3724 .memory_space();
3725 int64_t to_memory_space =
3726 ShapeUtil::GetSubshape(instruction->shape(), {0})
3727 .layout()
3728 .memory_space();
3729 CHECK_NE(from_memory_space, to_memory_space)
3730 << "Asynchronous copy to the same memory space: "
3731 << instruction->ToString();
3732 }
3733 }
3734 }
3735
3736 for (const auto& position_and_chunk : preset_assignments_->chunks()) {
3737 const HloPosition& position = position_and_chunk.first;
3738 const Chunk& chunk = position_and_chunk.second;
3739 const HloBuffer& buffer =
3740 alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
3741 CHECK(!seen_buffers.contains(buffer.id()))
3742 << "Multiple preset assignments for the same buffer: "
3743 << buffer.ToString() << ", pos: " << position.ToString()
3744 << ", off: " << chunk.offset << ", size: " << chunk.size;
3745 seen_buffers.insert(buffer.id());
3746
3747 for (const HloValue* value : buffer.values()) {
3748 const HloLiveRange::TimeBound& time_bound =
3749 hlo_live_range->buffer_live_ranges().at(value);
3750 const HloInstruction* last_use_instruction = nullptr;
3751 int64_t last_use_time = time_bound.start;
3752 for (const HloUse& use : value->uses()) {
3753 int64_t use_time =
3754 hlo_live_range->instruction_schedule().at(use.instruction);
3755 if (use_time > last_use_time) {
3756 last_use_time = use_time;
3757 last_use_instruction = use.instruction;
3758 }
3759 }
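      // If the value has no uses, last_use_instruction stays nullptr and the
      // value is skipped below; otherwise it points at the latest-scheduled use.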
3760
3761 std::function<Status(const HloInstruction*, int64_t, int64_t,
3762 absl::string_view)>
3763 split_conditional_buffer;
3764 split_conditional_buffer = [&](const HloInstruction* use_instruction,
3765 int64_t start_time, int64_t end_time,
3766 absl::string_view indent_string) {
3767         // Special case when verifying conditionals: we internally split the use
3768         // of alternate memory in conditionals, so recover those per-branch
3769         // segments from the called computations here.
3770 VLOG(3) << indent_string
3771 << "Splitting conditional buffer: " << buffer.ToString()
3772 << " value: " << value->ToShortString() << ": (" << start_time
3773 << ", " << end_time << ") off: " << chunk.offset
3774 << ", size: " << chunk.size;
3775 int64_t earliest_computation_start_time = end_time;
3776 for (const HloComputation* called_computation :
3777 use_instruction->called_computations()) {
3778 earliest_computation_start_time =
3779 std::min(earliest_computation_start_time,
3780 hlo_live_range->computation_span_times()
3781 .at(called_computation)
3782 .start);
3783 int64_t parameter_time = -1;
3784 int64_t last_use_time = -1;
3785 const HloInstruction* last_use_instruction = nullptr;
3786 for (const HloPosition& position : value->positions()) {
3787 if (position.instruction->opcode() == HloOpcode::kParameter &&
3788 position.instruction->parent() == called_computation) {
3789 parameter_time = hlo_live_range->instruction_schedule().at(
3790 position.instruction);
3791 break;
3792 }
3793 }
3794 for (const HloUse& use : value->uses()) {
3795 int64_t use_time =
3796 hlo_live_range->instruction_schedule().at(use.instruction);
3797 if (use.instruction->parent() == called_computation &&
3798 use_time > last_use_time) {
3799 last_use_time = use_time;
3800 last_use_instruction = use.instruction;
3801 }
3802 }
3803 if (last_use_time != -1) {
3804 CHECK_NE(parameter_time, -1);
3805 VLOG(3) << indent_string
3806 << " computation: " << called_computation->name() << ": ("
3807 << parameter_time << ", " << last_use_time << ")";
3808 CHECK(last_use_instruction);
3809 if (last_use_instruction->opcode() == HloOpcode::kConditional) {
3810 // The last use is another (nested) conditional. Call this
3811 // function recursively.
3812 TF_RETURN_IF_ERROR(split_conditional_buffer(
3813 last_use_instruction, parameter_time, last_use_time,
3814 absl::StrCat(indent_string, " ")));
3815 } else {
3816 last_use_time = std::min(last_use_time, end_time);
3817 TF_RETURN_IF_ERROR(add_allocation_and_verify(
3818 parameter_time, last_use_time, chunk, value));
3819 }
3820 }
3821 }
3822 VLOG(3) << indent_string << " from beginning until first computation: ("
3823 << start_time << ", " << (earliest_computation_start_time - 1)
3824 << ")";
3825 TF_RETURN_IF_ERROR(add_allocation_and_verify(
3826 start_time, earliest_computation_start_time - 1, chunk, value));
3827 return Status::OK();
3828 };
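      // Illustrative (hypothetical) example: for a conditional scheduled at time
      // 20 whose branch computations span (21, 30) and (31, 40), a value live
      // over (10, 40) is verified as the segment (10, 20) preceding the
      // conditional plus a (parameter, last_use) segment inside each branch that
      // actually uses it; a nested conditional as the last use recurses.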
3829
3830 if (last_use_instruction &&
3831 last_use_instruction->opcode() == HloOpcode::kConditional) {
3832 TF_RETURN_IF_ERROR(split_conditional_buffer(
3833 last_use_instruction, time_bound.start, time_bound.end, " "));
3834 } else if (!value->uses().empty()) {
3835 last_use_time = std::min(last_use_time, time_bound.end);
3836 VLOG(3) << " buffer: " << buffer.ToString()
3837 << " value: " << value->ToShortString() << ": ("
3838 << time_bound.start << ", " << last_use_time
3839 << ") off: " << chunk.offset << ", size: " << chunk.size;
3840 TF_RETURN_IF_ERROR(add_allocation_and_verify(
3841 time_bound.start, last_use_time, chunk, value));
3842 }
3843 }
3844 }
3845
3846 HeapSimulatorTrace* heap_trace =
3847 &preset_assignments_
3848 ->assignment_information_for_space(options_.alternate_memory_space)
3849 ->heap_simulator_trace;
3850 int64_t memory_usage = 0;
3851 int64_t max_memory_usage = 0;
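  // Replay the verified events in order, exporting each one into the heap
  // simulator trace for the alternate memory space and tracking the peak usage
  // (ignoring fragmentation).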
3852 for (const auto& event : events) {
3853 int64_t time;
3854 bool is_free;
3855 int64_t buffer_id;
3856 std::tie(time, is_free, buffer_id) = event.first;
3857 const HloValue* value;
3858 Chunk chunk;
3859 HeapSimulatorTrace::Event::Kind kind;
3860 std::tie(value, chunk, kind) = event.second;
3861 HeapSimulatorTrace::Event* heap_trace_event = heap_trace->add_events();
3862 heap_trace_event->set_kind(kind);
3863 heap_trace_event->set_buffer_id(buffer_id);
3864 heap_trace_event->set_instruction_name(value->instruction()->name());
3865 heap_trace_event->set_computation_name(
3866 value->instruction()->parent()->name());
3867
3868 if (kind == HeapSimulatorTrace::Event::ALLOC) {
3869 memory_usage += chunk.size;
3870 } else {
3871 CHECK_EQ(kind, HeapSimulatorTrace::Event::FREE);
3872 memory_usage -= chunk.size;
3873 }
3874 max_memory_usage = std::max(max_memory_usage, memory_usage);
3875 VLOG(4) << "Memory usage: " << memory_usage << " at time: " << time;
3876 }
3877 VLOG(1) << "Max memory usage ignoring fragmentation: " << max_memory_usage;
3878
3879 return Status::OK();
3880 }
3881 } // namespace memory_space_assignment
3882 } // namespace xla
3883