/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/memory_space_assignment.h"

#include <algorithm>
#include <utility>

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
#include "tensorflow/core/lib/math/math_util.h"
namespace xla {

namespace memory_space_assignment {

namespace {
// Define a dummy chunk for chunks that will be allocated in the default memory
// space and for keeping track of the number of asynchronous copies.
const HeapSimulator::Chunk kDummyChunk{-1, -1};

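// Returns true if 'inst' is consumed in a way that suggests it is an
// activation: it feeds the first operand of a convolution or dot, the indices
// operand of a gather, or any opcode other than the ones special-cased below.
// Fusions and bitcasts are followed recursively.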
bool LooksLikeAnActivation(const HloInstruction* inst) {
  for (HloInstruction* user : inst->users()) {
    switch (user->opcode()) {
      case HloOpcode::kConvolution:
      case HloOpcode::kDot:
        if (user->operand(0) == inst) {
          return true;
        }
        break;
      case HloOpcode::kGather:
        if (user->operand(1) == inst) {
          return true;
        }
        break;
      case HloOpcode::kFusion:
        for (int i = 0; i < user->operand_count(); ++i) {
          if (user->operand(i) == inst &&
              LooksLikeAnActivation(user->fused_parameter(i))) {
            return true;
          }
        }
        break;
      case HloOpcode::kBitcast:
        return LooksLikeAnActivation(user);
      default:
        return true;
    }
  }
  return false;
}

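// Returns true if the given value is a candidate for cross-program prefetch:
// it is an array-shaped, top-level tuple element of an entry computation
// parameter that is not already assigned to the alternate memory space, fits
// within the alternate memory size limit, has at least one use, and all of its
// uses are get-tuple-elements (possibly behind bitcasts) that do not look like
// activations.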
bool IsCrossProgramPrefetchCandidate(const HloValue& value,
                                     const Options& options) {
  return value.instruction()->parent() ==
             value.instruction()->GetModule()->entry_computation() &&
         value.instruction()->opcode() == HloOpcode::kParameter &&
         (!value.shape().has_layout() ||
          value.shape().layout().memory_space() !=
              options.alternate_memory_space) &&
         value.index().size() == 1 && value.shape().IsArray() &&
         !value.uses().empty() &&
         options.size_fn(value) <= options.max_size_in_bytes &&
         absl::c_all_of(value.uses(), [&](const HloUse& use) {
           const HloInstruction* inst =
               use.instruction->operand(use.operand_number);

           // Skip the LooksLikeAnActivation test since we're testing the
           // parent GTE and its children below.
           if (inst->opcode() == HloOpcode::kBitcast &&
               inst->operand(0)->opcode() == HloOpcode::kGetTupleElement &&
               inst->operand(0)->operand(0)->opcode() ==
                   HloOpcode::kParameter) {
             return true;
           }

           return inst->opcode() == HloOpcode::kGetTupleElement &&
                  !LooksLikeAnActivation(inst);
         });
}

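// Scans all buffers in the alias analysis and returns the BufferInterval of
// the best cross-program prefetch candidate, if any exists, selected either by
// the provided buffer interval compare function or by the size-based
// heuristic below.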
absl::optional<MemorySpaceAssignment::BufferInterval>
FindCrossProgramPrefetchCandidate(const HloAliasAnalysis& alias_analysis,
                                  const HloLiveRange& hlo_live_range,
                                  const Options& options) {
  std::vector<MemorySpaceAssignment::BufferInterval> candidates;
  for (const HloBuffer& buffer : alias_analysis.buffers()) {
    CHECK_GE(buffer.values().size(), 1);
    const HloValue* value = buffer.values().at(0);
    if (IsCrossProgramPrefetchCandidate(*value, options)) {
      MemorySpaceAssignment::BufferInterval interval;
      interval.buffer = value;
      interval.size = options.size_fn(*value);
      interval.start = 0;
      interval.end = hlo_live_range.schedule_end_time();
      interval.need_allocation = true;
      interval.colocations = {++buffer.values().begin(), buffer.values().end()};
      candidates.emplace_back(interval);
    }
  }

  // The buffer_interval_compare ought to do a good job picking the most
  // appropriate buffer to cross program prefetch, but empirically, it makes
  // worse choices than just picking the largest buffer.
  // TODO(b/152421603): Investigate.
  auto size_compare = [](const auto& x, const auto& y) {
    if (x.size == y.size) {
      // When both buffers are of the same size, we prefer the one that is used
      // to produce larger tensors in its consumer instructions.
      auto get_use_size =
          [](const MemorySpaceAssignment::BufferInterval& bi) -> int64 {
        int64_t use_size = 0;
        for (const auto& use : bi.buffer->uses()) {
          use_size += ShapeUtil::ElementsInRecursive(use.instruction->shape());
        }
        return use_size;
      };
      return get_use_size(x) < get_use_size(y);
    }
    return x.size < y.size;
  };
  auto& compare = options.default_cross_program_prefetch_heuristic &&
                          options.buffer_interval_compare
                      ? *options.buffer_interval_compare
                      : size_compare;

  auto best_candidate = absl::c_max_element(candidates, compare);
  if (best_candidate == candidates.end()) {
    return absl::nullopt;
  }
  return *best_candidate;
}

}  // namespace

/*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
MemorySpaceAssignmentCostAnalysis::Create(const HloCostAnalysis& cost_analysis,
                                          const Options& options,
                                          const HloModule& module) {
  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
  TF_ASSIGN_OR_RETURN(auto hlo_live_range,
                      HloLiveRange::Run(module.schedule(), *alias_analysis,
                                        module.entry_computation()));
  auto call_graph = CallGraph::Build(&module);
  return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
      cost_analysis, options, std::move(alias_analysis),
      std::move(hlo_live_range), std::move(call_graph)));
}

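// Returns a measure of how much faster the given instruction would run if its
// memory traffic took elapsed_time_due_to_alternate_mem instead: for
// memory-bound instructions, the elapsed-time saving scaled by the estimated
// execution count of the enclosing while loops; for compute-bound
// instructions, the (non-positive) gap between memory and compute time.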
float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
    const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
    MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
  float elapsed_time_due_to_compute =
      GetInstructionElapsedDueToCompute(instruction);
  float elapsed_time_due_to_memory =
      GetInstructionElapsedDueToMemory(instruction);
  if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
    // Memory bound, return how much alternate memory is better.
    float while_nest_multiplier;
    if (cache) {
      // If there is a cache provided, memoize the while nest multiplier.
      auto it = cache->while_nest_multiplier.find(&instruction);
      if (it != cache->while_nest_multiplier.end()) {
        while_nest_multiplier = it->second;
      } else {
        while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
            options_.xla_tpu_memory_space_assignment_while_execution_count,
            CalculateComputationNestLevel(&instruction,
                                          /*while_only=*/true));
        cache->while_nest_multiplier[&instruction] = while_nest_multiplier;
      }
    } else {
      while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
          options_.xla_tpu_memory_space_assignment_while_execution_count,
          CalculateComputationNestLevel(&instruction,
                                        /*while_only=*/true));
    }
    return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
           while_nest_multiplier;
  } else {
    // Compute bound, return how far off we are from memory boundedness.
    return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
  }
}

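// Computes the alternate-memory benefit of an entire buffer interval by
// combining the benefit at the defining position with the benefits at all
// aliased uses (summing positive benefits, otherwise taking the maximum;
// while and conditional uses are skipped since their bodies are examined
// separately), then dividing by the square root of the buffer size.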
float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
    MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
  const HloInstruction& defining_instruction =
      *interval.buffer->defining_instruction();
  float alternate_mem_benefit = GetAlternateMemoryBenefit(
      defining_instruction,
      GetInstructionElapsedDueToMemory(
          defining_instruction,
          /*operands_in_alternate_mem=*/{},
          /*outputs_in_alternate_mem=*/{interval.buffer->defining_index()}),
      cache);
  for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
           interval.buffer->defining_position().instruction,
           interval.buffer->defining_position().index)) {
    for (const HloValue* value : buffer->values()) {
      for (const HloUse& use : value->uses()) {
        // We look inside the called computations of while and conditional, so
        // don't use the benefit of while and conditional directly.
        if (use.instruction->opcode() == HloOpcode::kWhile ||
            use.instruction->opcode() == HloOpcode::kConditional) {
          continue;
        }
        float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
            *use.instruction,
            GetInstructionElapsedDueToMemory(
                *use.instruction,
                /*operands_in_alternate_mem=*/{std::make_pair(
                    use.operand_number, use.operand_index)}),
            cache);
        // If the benefit is positive (memory bound), add it to this buffer's
        // benefit. If the benefit is negative (compute bound), calculate the
        // maximum.
        if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
          alternate_mem_benefit += use_alternate_mem_benefit;
        } else {
          alternate_mem_benefit =
              std::max(alternate_mem_benefit, use_alternate_mem_benefit);
        }
      }
    }
  }

  // Penalize larger buffers by dividing the benefit by the square root of the
  // size. Empirically, we observed this resulted in better performance compared
  // to dividing by the size.
  return alternate_mem_benefit / std::sqrt(interval.size);
}

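// Walks the call graph from the instruction's computation up to the entry
// computation and counts the number of enclosing callsites (only while-loop
// callsites if while_only is true). Requires a flattened module so that each
// computation has exactly one caller.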
int MemorySpaceAssignmentCostAnalysis::CalculateComputationNestLevel(
    const HloInstruction* instruction, bool while_only) const {
  int nest_level = 0;
  const HloComputation* computation = instruction->parent();
  while (!computation->IsEntryComputation()) {
    auto node = call_graph_->GetNode(computation);
    auto callsites = node.caller_callsites();
    CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
    auto callsite = callsites[0];
    if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) {
      ++nest_level;
    }
    computation = callsite.instruction()->parent();
  }
  return nest_level;
}

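// Estimates the compute time of the instruction as the maximum of its FLOP and
// transcendental counts divided by the corresponding per-second rates.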
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
    const HloInstruction& instruction) const {
  return std::max(
      cost_analysis_.flop_count(instruction) /
          cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey),
      cost_analysis_.transcendental_count(instruction) /
          cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey));
}

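// Estimates the memory time of the instruction, assuming the given operands
// and outputs are placed in the alternate memory (served at the alternate
// memory bandwidth) and all remaining bytes are served from default memory.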
float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory(
    const HloInstruction& instruction,
    absl::Span<const std::pair<int64_t, ShapeIndex>> operands_in_alternate_mem,
    absl::Span<const ShapeIndex> outputs_in_alternate_mem) const {
  float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction);
  float bytes_accessed_from_alternate_mem = 0.0;
  for (auto& operand : operands_in_alternate_mem) {
    float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed(
        instruction, operand.first, operand.second);
    bytes_accessed_from_alternate_mem += operand_bytes_accessed;
  }

  for (auto& shape_idx : outputs_in_alternate_mem) {
    float output_bytes_accessed =
        cost_analysis_.output_bytes_accessed(instruction, shape_idx);
    bytes_accessed_from_alternate_mem += output_bytes_accessed;
  }
  float elapsed_due_to_alternate_mem =
      bytes_accessed_from_alternate_mem /
      options().alternate_mem_bandwidth_bytes_per_second;
  float elapsed_due_to_default_mem =
      (total_bytes_accessed - bytes_accessed_from_alternate_mem) /
      cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
  return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem;
}

float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed(
    const HloInstruction& instruction) const {
  return std::max(GetInstructionElapsedDueToCompute(instruction),
                  GetInstructionElapsedDueToMemory(instruction));
}

float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory(
    const HloInstruction& instruction,
    absl::Span<const std::pair<int64_t, ShapeIndex>> operands_in_alternate_mem,
    absl::Span<const ShapeIndex> outputs_in_alternate_mem) const {
  return std::max(
      GetInstructionElapsedDueToCompute(instruction),
      GetInstructionElapsedDueToMemory(instruction, operands_in_alternate_mem,
                                       outputs_in_alternate_mem));
}

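// Estimates the time an asynchronous copy of a buffer with the given shape
// takes, based on the configured async copy bandwidth.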
float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
    const Shape& shape) const {
  int64_t size_in_bytes = cost_analysis_.GetShapeSize(shape);
  return static_cast<float>(size_in_bytes) /
         options().async_copy_bandwidth_bytes_per_second;
}

int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
  return hlo_live_range_->schedule_end_time();
}

bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
    const Shape& shape, int64_t start_time, int64_t end_time) const {
  return end_time - start_time <= max_overlap_count_;
}

int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
    const Shape& shape, int64_t start_time, int64_t latest_end_time) const {
  return std::min(start_time + min_overlap_count_, latest_end_time);
}

int64 InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime(
    const Shape& shape, int64_t start_time, int64_t end_time,
    const HloUse* use) const {
  return end_time - min_overlap_count_;
}

int64 InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime(
    const Shape& shape, int64_t earliest_prefetch_start_time,
    int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const {
  return std::max(earliest_prefetch_start_time,
                  prefetch_end_time - max_overlap_count_);
}

void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
                                                   int64_t start_time,
                                                   int64_t end_time) {
  end_time_ = end_time;
  const Shape& shape = ShapeUtil::GetSubshape(
      use.instruction->operand(use.operand_number)->shape(), use.operand_index);
  current_prefetch_time_ =
      PreferredPrefetchStartTime(shape, start_time, end_time, end_time);
}

int64 InstructionCountPrefetchIntervalPicker::Next() {
  CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
                    "Done() is true";
  return current_prefetch_time_++;
}

bool InstructionCountPrefetchIntervalPicker::Done() const {
  return end_time_ - current_prefetch_time_ <= min_overlap_count_;
}

std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const {
  return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_);
}

std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString(
    const Shape& shape, int64_t start_time, int64_t end_time) const {
  return absl::StrCat("Overlapped HLOs = ", end_time - start_time);
}

CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
    const MemorySpaceAssignmentCostAnalysis& cost_analysis,
    float min_async_copy_to_overlap_ratio,
    float max_async_copy_to_overlap_ratio,
    float preferred_async_copy_to_overlap_ratio,
    int64_t buffer_size_for_max_async_copy)
    : while_nest_level_(
          cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0),
      computation_nest_level_(
          cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0),
      cost_analysis_(cost_analysis),
      min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
      max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio),
      preferred_async_copy_to_overlap_ratio_(
          preferred_async_copy_to_overlap_ratio),
      buffer_size_for_max_async_copy_(buffer_size_for_max_async_copy) {
  instruction_schedule_ =
      &cost_analysis_.hlo_live_range().instruction_schedule();

  // Create a vector of elapsed times and while nesting levels of HLO
  // instructions. The elapsed times are multiplied by
  // pow(while_execution_count, nest_level) to account for executing the HLOs
  // multiple times in while loops.
  std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
                                               0.0);
  for (const auto& instruction_and_logical_time : *instruction_schedule_) {
    // To avoid double counting, don't include the elapsed time of while and
    // conditional HLOs.
    const HloInstruction* instruction = instruction_and_logical_time.first;
    int64_t logical_time = instruction_and_logical_time.second;
    if (logical_time >= instructions_elapsed_time.size()) {
      instructions_elapsed_time.resize(logical_time + 1, 0.0);
      while_nest_level_.resize(logical_time + 1, 0);
    }
    int while_nest_level = cost_analysis_.CalculateComputationNestLevel(
        instruction_and_logical_time.first, /*while_only=*/true);
    while_nest_level_[logical_time] = while_nest_level;
    int computation_nest_level = cost_analysis_.CalculateComputationNestLevel(
        instruction_and_logical_time.first, /*while_only=*/false);
    computation_nest_level_[logical_time] = computation_nest_level;
    if (instruction->opcode() == HloOpcode::kWhile ||
        instruction->opcode() == HloOpcode::kConditional) {
      continue;
    }
    float elapsed_time = cost_analysis_.GetInstructionElapsed(
        *instruction_and_logical_time.first);
    instructions_elapsed_time[logical_time] =
        elapsed_time *
        tensorflow::MathUtil::IPow<float>(
            cost_analysis_.options()
                .xla_tpu_memory_space_assignment_while_execution_count,
            while_nest_level);
  }
  // As an optimization, create a cumulative sum vector of elapsed time.
  float cumsum = 0.0;
  elapsed_time_cumsum_.reserve(instructions_elapsed_time.size());
  for (float elapsed_time : instructions_elapsed_time) {
    cumsum += elapsed_time;
    elapsed_time_cumsum_.push_back(cumsum);
  }
  // To be able to accurately determine the minimum nest level between a start
  // time and an end time efficiently, populate a data structure that stores the
  // closest nest level change index.
  int prev_nest_level = 0;
  int change_idx = -1;
  while_nest_level_change_.reserve(instructions_elapsed_time.size());
  for (int i = 0; i < while_nest_level_.size(); ++i) {
    int nest_level = while_nest_level_[i];
    if (nest_level != prev_nest_level) {
      prev_nest_level = nest_level;
      change_idx = i - 1;
    }
    while_nest_level_change_.push_back(change_idx);
  }
}

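// Returns the maximum elapsed time a prefetch of the given async copy time is
// allowed to span: the maximum overlap ratio times the larger of the
// (retry-scaled) copy time and the copy time of a reference buffer of
// buffer_size_for_max_async_copy_ bytes.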
float CostAnalysisPrefetchIntervalPicker::GetMaxElapsedInAlternateMemory(
    float async_copy_elapsed) const {
  return max_async_copy_to_overlap_ratio_ *
         std::max(max_overlap_multiplier_ * async_copy_elapsed,
                  cost_analysis_.GetAsyncCopyElapsed(ShapeUtil::MakeShape(
                      S32, {buffer_size_for_max_async_copy_ / 4})));
}

bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
    const Shape& shape, int64_t start_time, int64_t end_time) const {
  // Even though this method returns whether we allow the buffer in alternate
  // memory _without_ asynchronous copies, calculate how long it would have
  // taken to copy it and compare that to the elapsed time in the logical
  // interval.
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  float logical_interval_elapsed =
      GetLogicalIntervalElapsed(start_time, end_time);
  return GetMaxElapsedInAlternateMemory(async_copy_elapsed) >
         logical_interval_elapsed;
}

int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
    const Shape& shape, int64_t start_time, int64_t latest_end_time) const {
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  int64_t end_time;
  for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) {
    float logical_interval_elapsed =
        GetLogicalIntervalElapsed(start_time, end_time);
    if (logical_interval_elapsed >=
        min_async_copy_to_overlap_ratio_ * async_copy_elapsed) {
      break;
    }
  }
  return end_time;
}

int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime(
    const Shape& shape, int64_t start_time, int64_t end_time,
    const HloUse* use) const {
  // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  // If there is a use, estimate the time we would save by having this op in
  // alternate memory.
  float inst_elapsed_reduction = 0.0f;
  if (use) {
    float elapsed_time =
        cost_analysis_.GetInstructionElapsed(*use->instruction);
    float elapsed_time_in_alternate_mem =
        cost_analysis_.GetInstructionElapsedInAlternateMemory(
            *use->instruction,
            /*operands_in_alternate_mem=*/
            {std::make_pair(use->operand_number, use->operand_index)},
            /*outputs_in_alternate_mem=*/{});
    inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem;
  }
  int end_nest_level = computation_nest_level_[end_time];

  // Find the latest time we're allowed to start prefetching.
  float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed;
  int latest_prefetch_time;
  for (latest_prefetch_time = end_time - 1;
       latest_prefetch_time >= start_time &&
       (computation_nest_level_[latest_prefetch_time] != end_nest_level ||
        min_interval >
            GetLogicalIntervalElapsed(latest_prefetch_time, end_time) +
                inst_elapsed_reduction);
       --latest_prefetch_time) {
  }

  return latest_prefetch_time;
}

int64 CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime(
    const Shape& shape, int64_t earliest_prefetch_start_time,
    int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const {
  // Between the earliest and latest prefetch interval, find the interval
  // closest to the preferred interval and start iterating from there.
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  int64_t preferred_prefetch_start_time = earliest_prefetch_start_time;
  float preferred_interval =
      preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed;
  float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time,
                                                  prefetch_end_time);
  int end_nest_level = computation_nest_level_[prefetch_end_time];
  for (int64_t prefetch_start_time = earliest_prefetch_start_time + 1;
       prefetch_start_time <= latest_prefetch_start_time;
       ++prefetch_start_time) {
    float interval =
        GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time);
    if (computation_nest_level_[prefetch_start_time] == end_nest_level &&
        std::abs(preferred_interval - interval) <
            std::abs(preferred_interval - best_interval)) {
      best_interval = interval;
      preferred_prefetch_start_time = prefetch_start_time;
    }
  }
  return preferred_prefetch_start_time;
}

int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime(
    int64_t original_prefetch_end_time,
    int64_t proposed_prefetch_end_time) const {
  // Iterate towards the beginning until we find a suitable end time that is at
  // the same computation nest level as the original prefetch end time.
  int64_t original_nest_level =
      computation_nest_level_[original_prefetch_end_time];
  int64_t new_prefetch_end_time;
  for (new_prefetch_end_time = proposed_prefetch_end_time;
       computation_nest_level_[new_prefetch_end_time] != original_nest_level;
       --new_prefetch_end_time) {
  }
  return new_prefetch_end_time;
}

void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
                                               int64_t start_time,
                                               int64_t end_time) {
  const Shape& shape = ShapeUtil::GetSubshape(
      use.instruction->operand(use.operand_number)->shape(), use.operand_index);
  // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
  async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape);
  // Estimate the time we would save by having this op in alternate memory.
  float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
  float elapsed_time_in_alternate_mem =
      cost_analysis_.GetInstructionElapsedInAlternateMemory(
          *use.instruction, /*operands_in_alternate_mem=*/
          {std::make_pair(use.operand_number, use.operand_index)},
          /*outputs_in_alternate_mem=*/{});
  inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
  end_logical_time_ = end_time;
  int end_nest_level = computation_nest_level_[end_logical_time_];

  // Find the latest time we're allowed to start prefetching.
  float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
  latest_prefetch_time_ =
      LatestPrefetchStartTime(shape, start_time, end_time, &use);

  // Find the earliest time we're allowed to start prefetching.
  float max_interval = GetMaxElapsedInAlternateMemory(async_copy_elapsed_);
  for (earliest_prefetch_time_ = start_time;
       earliest_prefetch_time_ < latest_prefetch_time_ &&
       (computation_nest_level_[earliest_prefetch_time_] != end_nest_level ||
        max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_,
                                                 end_logical_time_));
       ++earliest_prefetch_time_) {
  }
  if (earliest_prefetch_time_ > latest_prefetch_time_) {
    // There is no available prefetch interval for the given start and end
    // times. Set the iterators accordingly to ensure Done() returns true.
    increasing_prefetch_time_iterator_ = earliest_prefetch_time_;
    decreasing_prefetch_time_iterator_ = latest_prefetch_time_;
    CHECK(Done());
    return;
  }

  int64_t starting_prefetch_time = PreferredPrefetchStartTime(
      shape, earliest_prefetch_time_, latest_prefetch_time_, end_logical_time_);
  float preferred_interval =
      preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
  VLOG(4) << "Interval min/max/preferred = " << min_interval << " "
          << max_interval << " " << preferred_interval
          << " prefetch time earliest/latest/starting = "
          << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " "
          << starting_prefetch_time;

  increasing_prefetch_time_iterator_ = starting_prefetch_time;
  decreasing_prefetch_time_iterator_ = starting_prefetch_time;
  using_increasing_prefetch_time_iterator_ = true;
  // Since both iterators start at the same position, call Next() once to
  // advance one of the iterators.
  Next();
}

int64 CostAnalysisPrefetchIntervalPicker::Next() {
  CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
                    "Done() is true";
  if (using_increasing_prefetch_time_iterator_) {
    int64_t prefetch_time = increasing_prefetch_time_iterator_++;
    while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ &&
           computation_nest_level_[increasing_prefetch_time_iterator_] !=
               computation_nest_level_[end_logical_time_]) {
      ++increasing_prefetch_time_iterator_;
    }
    if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) {
      using_increasing_prefetch_time_iterator_ = false;
    }
    return prefetch_time;
  } else {
    int64_t prefetch_time = decreasing_prefetch_time_iterator_--;
    while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ &&
           computation_nest_level_[decreasing_prefetch_time_iterator_] !=
               computation_nest_level_[end_logical_time_]) {
      --decreasing_prefetch_time_iterator_;
    }
    if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) {
      using_increasing_prefetch_time_iterator_ = true;
    }
    return prefetch_time;
  }
}

bool CostAnalysisPrefetchIntervalPicker::Done() const {
  return increasing_prefetch_time_iterator_ > latest_prefetch_time_ &&
         decreasing_prefetch_time_iterator_ < earliest_prefetch_time_;
}

void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
  // Double the maximum overlap limit on each retry.
  max_overlap_multiplier_ = 1 << retry_number;
}

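// Returns the minimum while-loop nesting level over the logical time interval
// [start_time, end_time], using the precomputed nest level change indices to
// avoid scanning every instruction in the interval.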
int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
    int64_t start_time, int64_t end_time) const {
  int min_nest_level =
      std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
  int change_idx = while_nest_level_change_[end_time];
  while (change_idx >= start_time) {
    min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]);
    change_idx = while_nest_level_change_[change_idx];
  }
  return min_nest_level;
}

float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
    int64_t start_time, int64_t end_time) const {
  CHECK_LE(start_time, end_time);
  if (start_time == end_time) {
    return 0.0;
  }
  if (start_time < 0) {
    start_time = 0;
  }
  // Since elapsed_time_cumsum_ is already weighted by the while loop nesting
  // level, normalize the elapsed time by dividing by the nesting factor of the
  // interval (start and end times).
  int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time);
  return (elapsed_time_cumsum_[end_time - 1] -
          elapsed_time_cumsum_[start_time]) /
         tensorflow::MathUtil::IPow<float>(
             cost_analysis_.options()
                 .xla_tpu_memory_space_assignment_while_execution_count,
             interval_while_nest_level);
}

std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
  int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_
                                          ? increasing_prefetch_time_iterator_
                                          : decreasing_prefetch_time_iterator_;
  float logical_interval_elapsed = GetLogicalIntervalElapsed(
      current_logical_prefetch_time, end_logical_time_);
  return absl::StrCat(
      "Async copy elapsed (s) = ", async_copy_elapsed_,
      ", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
      ", logical interval elapsed (s) = ", logical_interval_elapsed,
      ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_,
      ")");
}

std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
    const Shape& shape, int64_t start_time, int64_t end_time) const {
  float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
  float logical_interval_elapsed =
      GetLogicalIntervalElapsed(start_time, end_time);
  return absl::StrCat(
      "Async copy elapsed (s) = ", async_copy_elapsed,
      ", logical interval elapsed (s) = ", logical_interval_elapsed);
}

absl::optional<float>
CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
    const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
    const {
  return cost_analysis_.GetMemoryBoundedness(interval);
}

bool MemorySpaceAssignment::Allocation::operator==(
    const MemorySpaceAssignment::Allocation& other) const {
  return defining_position() == other.defining_position() &&
         uses() == other.uses() && memory_space() == other.memory_space() &&
         chunk() == other.chunk() && start_time() == other.start_time() &&
         end_time() == other.end_time() &&
         is_copy_allocation() == other.is_copy_allocation() &&
         is_scoped_allocation() == other.is_scoped_allocation();
}

bool MemorySpaceAssignment::CopyAllocation::operator==(
    const MemorySpaceAssignment::CopyAllocation& other) const {
  return static_cast<const Allocation&>(*this) ==
             static_cast<const Allocation&>(other) &&
         copy_done_schedule_before() == other.copy_done_schedule_before() &&
         copy_start_schedule_after() == other.copy_start_schedule_after() &&
         copy_start() == other.copy_start() && copy_done() == other.copy_done();
}

std::string MemorySpaceAssignment::AllocationValue::ToString() const {
  std::string out = absl::StrCat("computation = ", computation()->name());
  absl::StrAppend(&out,
                  (requires_contiguous_allocation_ ? " (cont alloc)" : ""));
  absl::StrAppend(&out, "\n position:\n");
  absl::StrAppend(&out, "  ", defining_position_.ToString(), "\n");
  absl::StrAppend(&out, " uses:\n");
  for (const Use& use : uses_) {
    absl::StrAppend(&out, "  ", use.hlo_use.ToString(), "\n");
  }
  return out;
}

std::string MemorySpaceAssignment::AllocationValue::ToShortString() const {
  return absl::StrCat("computation = ", computation()->name(),
                      ", position = ", defining_position_.ToString(),
                      ", value = ", value_->ToShortString(),
                      (requires_contiguous_allocation_ ? " (cont alloc)" : ""));
}

void AlternateMemoryBestFitHeap::CreateAllocationValues(
    const AlternateMemoryBestFitHeap::BufferInterval& buffer_interval,
    std::vector<AllocationValue>& allocation_values) const {
  const HloValue* value = buffer_interval.buffer;
  VLOG(3) << "Creating AllocationValues for: " << value->ToString();

  // Find and sort all non-trivial (excluding GTE, Tuple, and bitcast)
  // positions. We create an AllocationValue object for each non-trivial
  // position. And for each AllocationValue object, we create an
  // AllocationSequence consisting of one or more Allocation objects. The
  // reason we exclude the trivial positions from AllocationValue is that
  // Allocation objects have special support for tuples and bitcasts.
  const absl::flat_hash_map<const HloInstruction*, int64>&
      instruction_schedule = hlo_live_range_.instruction_schedule();
  std::vector<HloPosition> positions;
  for (const HloPosition& position : value->positions()) {
    const HloInstruction* instruction = position.instruction;
    if (instruction->opcode() != HloOpcode::kGetTupleElement &&
        instruction->opcode() != HloOpcode::kTuple &&
        instruction->opcode() != HloOpcode::kBitcast) {
      positions.push_back(position);
    }
  }
  absl::c_stable_sort(positions,
                      [&](const HloPosition& pos1, const HloPosition& pos2) {
                        return instruction_schedule.at(pos1.instruction) <
                               instruction_schedule.at(pos2.instruction);
                      });

  // Create an AllocationValue for each non-trivial position.
  absl::flat_hash_set<const HloComputation*> computations;
  int beginning_idx = allocation_values.size();
  for (int i = 0; i < positions.size(); ++i) {
    const HloPosition& position = positions.at(i);
    allocation_values.emplace_back(value, position, buffer_interval.size);
  }

  std::vector<HloUse> uses(value->uses());
  absl::c_stable_sort(uses, [&](const HloUse& use1, const HloUse& use2) {
    return instruction_schedule.at(use1.instruction) <
           instruction_schedule.at(use2.instruction);
  });

  // Associate each use with an AllocationValue. Each AllocationValue contains a
  // position and the uses in the same computation. Furthermore, if the original
  // HloValue had multiple non-trivial positions in the same computation, those
  // will get their own AllocationValue as well. We split these HloValues so
  // that when we insert CopyStart/CopyDone in CopyAllocation::Process, they
  // point to the latest position. We then replace the operand of the use with a
  // CopyStart/CopyDone that takes the latest position as its operand.
  for (const HloUse& use : uses) {
    int64_t use_time = instruction_schedule.at(use.instruction);
    HloComputation* use_computation = use.instruction->parent();

    AllocationValue* last_allocation_value = nullptr;
    for (int i = beginning_idx; i < allocation_values.size(); ++i) {
      AllocationValue* allocation_value = &allocation_values.at(i);
      if (HloDataflowAnalysis::IsAsynchronousOperationDone(
              use.instruction->opcode())) {
        if (allocation_value->defining_instruction() ==
            use.instruction->operand(0)) {
          last_allocation_value = allocation_value;
        }
      } else if (!HloDataflowAnalysis::IsAsynchronousOperationStart(
                     allocation_value->defining_instruction()->opcode()) &&
                 allocation_value->computation() == use_computation &&
                 instruction_schedule.at(
                     allocation_value->defining_position().instruction) <
                     use_time) {
        last_allocation_value = allocation_value;
      }
    }
    CHECK(last_allocation_value != nullptr);
    last_allocation_value->AddUse(use, use_time);
  }

  for (int i = beginning_idx; i < allocation_values.size(); ++i) {
    AllocationValue& allocation_value = allocation_values.at(i);
    if (HloDataflowAnalysis::IsAsynchronousOperationStart(
            allocation_value.defining_instruction()->opcode())) {
      CHECK_EQ(allocation_value.uses().size(), 1);
      CHECK(HloDataflowAnalysis::IsAsynchronousOperationDone(
          allocation_value.uses().at(0).hlo_use.instruction->opcode()));
      VLOG(3) << "Mark " << allocation_value.ToShortString()
              << " to require contiguous allocation.";
      allocation_value.set_requires_contiguous_allocation(true);
    }
    VLOG(3) << "Created allocation value: "
            << allocation_values.at(i).ToString();
  }
}

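// For each use of each allocation value, records positions that must alias
// with that use: defining positions of allocation values defined by the using
// instruction itself or by parameters of its called computations, and, for
// while loops, the corresponding position at the root of the while body.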
void AlternateMemoryBestFitHeap::FindAliases(
    std::vector<AllocationValue>* allocation_values) const {
  absl::flat_hash_map<const HloInstruction*,
                      std::vector<const AllocationValue*>>
      values_by_defining_inst;
  for (AllocationValue& value : *allocation_values) {
    values_by_defining_inst[value.defining_instruction()].push_back(&value);
  }
  auto maybe_add_alias_with_instruction = [&](const HloInstruction* instruction,
                                              AllocationValue::Use* use) {
    auto aliased_values_it = values_by_defining_inst.find(instruction);
    if (aliased_values_it != values_by_defining_inst.end()) {
      for (const AllocationValue* aliased_value : aliased_values_it->second) {
        VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString()
                << " to " << aliased_value->ToShortString();
        use->aliases.push_back(aliased_value->defining_position());
      }
    }
  };

  for (AllocationValue& value : *allocation_values) {
    for (AllocationValue::Use& use : value.uses()) {
      // Find any aliases with the instruction itself (operand and output must
      // alias).
      maybe_add_alias_with_instruction(use.hlo_use.instruction, &use);

      // Find any aliases with the parameters of called computations.
      for (const HloComputation* called_computation :
           use.hlo_use.instruction->called_computations()) {
        for (const HloInstruction* parameter_instruction :
             called_computation->parameter_instructions()) {
          maybe_add_alias_with_instruction(parameter_instruction, &use);
        }
      }

      // Special case for kWhile: the root of the body computation must alias as
      // well.
      if (use.hlo_use.instruction->opcode() == HloOpcode::kWhile) {
        HloPosition root_alias{
            use.hlo_use.instruction->while_body()->root_instruction(),
            use.hlo_use.operand_index};
        VLOG(3) << "Adding while body root aliasing for use "
                << use.hlo_use.ToString() << " to " << root_alias;
        use.aliases.push_back(root_alias);
      }
    }
  }
}

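// Collects the transitive closure of buffer intervals colocated with the given
// interval and returns them sorted by (start, end) time.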
std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>
AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
    const AlternateMemoryBestFitHeap::BufferInterval& interval) const {
  std::vector<const BufferInterval*> colocated_intervals;
  std::vector<const BufferInterval*> worklist = {&interval};
  while (!worklist.empty()) {
    const BufferInterval* item = worklist.back();
    worklist.pop_back();
    colocated_intervals.push_back(item);
    for (const HloValue* buffer_colocated : item->colocations) {
      worklist.push_back(&buffer_intervals_.at(buffer_colocated));
    }
  }

  absl::c_stable_sort(colocated_intervals, [&](const BufferInterval* x,
                                               const BufferInterval* y) {
    return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end);
  });
  return colocated_intervals;
}

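// Determines whether the given use may be placed in alternate memory. Besides
// the user-provided is_use_allowed_in_alternate_mem_fn filter, while and
// conditional uses are only allowed if the corresponding parameter of the
// called computation has a short enough parameter-to-first-use live range
// (and, for while loops, no required assignment in default memory) to make the
// alternate memory allocation worthwhile.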
bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
    const AllocationValue& value, const HloUse& use) const {
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
  if (!options_.is_use_allowed_in_alternate_mem_fn(use)) {
    return false;
  }
  if (use.instruction->opcode() == HloOpcode::kWhile) {
    HloComputation* while_body = use.instruction->while_body();

    // We don't want to allocate this buffer in alternate memory if it will be
    // evicted anyway. Find out if it has an early use or a late definition
    // that would make it worthwhile to keep in the alternate memory.
    HloValue* parameter_value =
        &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
            while_body->parameter_instruction(0), use.operand_index);
    int64_t parameter_time =
        instruction_schedule.at(while_body->parameter_instruction(0));
    int64_t root_time = instruction_schedule.at(while_body->root_instruction());
    int64_t min_use_time = root_time;
    for (const HloUse& parameter_use : parameter_value->uses()) {
      int64_t use_time = instruction_schedule.at(parameter_use.instruction);
      if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement &&
          parameter_use.instruction->opcode() != HloOpcode::kTuple &&
          parameter_use.instruction->opcode() != HloOpcode::kBitcast &&
          use_time > parameter_time) {
        min_use_time = std::min(min_use_time, use_time);
      }
    }
    }
    // If there is no use of this buffer inside the while loop, there is no need
    // to allocate it in the loop.
    if (min_use_time == root_time) {
      VLOG(4) << "While allocation not allowed in alternate memory. "
              << "use time = " << min_use_time << ", root time = " << root_time;
      return false;
    }
    const Shape& shape = parameter_value->shape();
    // Allow the buffer in alternate memory if the buffer has a short live range
    // either at the beginning or end of the while loop body.
    if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
            shape, parameter_time, min_use_time)) {
      VLOG(4) << "While allocation not allowed in alternate memory. "
              << "use time = " << min_use_time << ", root time = " << root_time;
      return false;
    }
    // Check if there is a required assignment for the while loop output.
    HloValue* while_value =
        &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
            use.instruction, use.operand_index);
    int64_t while_time = instruction_schedule.at(use.instruction);
    auto existing_required_assignment =
        RequiredMemoryAssignmentAt(while_value, while_time);
    if (existing_required_assignment &&
        existing_required_assignment->memory_space == MemorySpace::kDefault) {
      VLOG(4) << "While allocation not allowed in alternate memory because "
                 "there is a required default memory assignment.";
      return false;
    }
  } else if (use.instruction->opcode() == HloOpcode::kConditional) {
    // For any use of this conditional (the same value might be passed into
    // multiple called computations), determine if the parameter->first use
    // dependency is short.
    int64_t conditional_time = instruction_schedule.at(use.instruction);
    for (const AllocationValue::Use& other_use : value.uses()) {
      if (other_use.hlo_use.instruction != use.instruction) {
        continue;
      }
      HloComputation* called_computation =
          use.instruction->called_computations().at(
              other_use.hlo_use.operand_number - 1);
      const HloInstruction* parameter_instruction =
          called_computation->parameter_instruction(0);
      HloValue* parameter_value =
          &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
              parameter_instruction, other_use.hlo_use.operand_index);
      int64_t parameter_time = instruction_schedule.at(parameter_instruction);
      int64_t min_use_time = conditional_time;
      for (const HloUse& parameter_use : parameter_value->uses()) {
        if (parameter_use.instruction->parent() == called_computation &&
            parameter_use.instruction->opcode() !=
                HloOpcode::kGetTupleElement &&
            parameter_use.instruction->opcode() != HloOpcode::kTuple &&
            parameter_use.instruction->opcode() != HloOpcode::kBitcast) {
          min_use_time = std::min(
              min_use_time, instruction_schedule.at(parameter_use.instruction));
        }
      }
      if (options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
              parameter_value->shape(), parameter_time, min_use_time)) {
        VLOG(4) << "Conditional allocation allowed in alternate memory for "
                   "computation = "
                << called_computation->name()
                << ", parameter time = " << parameter_time
                << ", min use time = " << min_use_time;
        return true;
      } else {
        VLOG(4) << "Conditional allocation not allowed in alternate memory for "
                   "computation = "
                << called_computation->name()
                << ", parameter time = " << parameter_time
                << ", min use time = " << min_use_time;
      }
    }
    return false;
  }

  return true;
}

void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString(
    const AlternateMemoryBestFitHeap::BufferInterval& interval,
    std::string* debug_str) const {
  // Columns in buffer information:
  // buffer_id: int. This value can be used to match the allocation in
  // allocation information.
  // buffer_name: string.
  // alt_mem_benefit: float. Roughly corresponds to how much the cost analysis
  // thought it would be beneficial to put this in the alternate memory. The
  // higher the value, the more it is memory bound.
  // size: int. In bytes.
  // definition_time: int. Logical time this value was defined in the schedule.
  // use_times: string. This is a semicolon-separated list of integers for all
  // the use times.
  // use_names: string. This is a semicolon-separated list of string
  // representations of uses.
  if (debug_str->empty()) {
    // Append the column names.
    absl::StrAppend(debug_str,
                    "buffer_id,buffer_name,alt_mem_benefit,size,"
                    "definition_time,use_times,use_names\n");
  }
  const HloBuffer& buffer =
      alias_analysis_.GetBufferContainingValue(*interval.buffer);
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
  int64_t definition_time =
      instruction_schedule.at(interval.buffer->defining_position().instruction);
  std::vector<std::pair<int64, std::string>> uses;
  for (const HloValue* value : buffer.values()) {
    for (const HloUse& use : value->uses()) {
      uses.push_back(
          {instruction_schedule.at(use.instruction), use.ToString()});
    }
  }
  absl::c_sort(uses);
  std::vector<int64> use_times;
  std::vector<std::string> use_names;
  use_times.reserve(uses.size());
  use_names.reserve(uses.size());
  for (const auto& use : uses) {
    use_times.push_back(use.first);
    use_names.push_back(use.second);
  }

  absl::StrAppend(debug_str, buffer.id(), ",");
  absl::StrAppend(debug_str, "\"", interval.buffer->ToShortString(), "\",");
  auto alternate_memory_benefit =
      options_.prefetch_interval_picker->BufferIntervalAlternateMemoryBenefit(
          interval);
  absl::StrAppend(
      debug_str, alternate_memory_benefit ? *alternate_memory_benefit : 0, ",");
  absl::StrAppend(debug_str, interval.size, ",");
  absl::StrAppend(debug_str, definition_time, ",");
  absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\",");
  absl::StrAppend(debug_str, "\"", absl::StrJoin(use_names, ";"), "\"");
  absl::StrAppend(debug_str, "\n");
}

1094 void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString(
1095     const AllocationValue& value,
1096     const MemorySpaceAssignment::Allocation& allocation,
1097     std::string& debug_str) const {
1098   // Columns in allocation information:
1099   // buffer_id: int. This value can be used to match with the buffer info.
1100   // size: int. In bytes.
1101   // offset: int. In bytes.
1102   // start_time: int. Logical start time of the allocation.
1103   // end_time: int. Logical end time of the allocation.
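  // For illustration only (hypothetical values), a matching row in this CSV
  // might look like:
  //   6,4096,0,3,17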
1104   if (debug_str.empty()) {
1105     // Append the column names.
1106     absl::StrAppend(&debug_str, "buffer_id,size,offset,start_time,end_time\n");
1107   }
1108   if (allocation.memory_space() == MemorySpace::kAlternate) {
1109     const HloBuffer& buffer =
1110         alias_analysis_.GetBufferContainingValue(*value.value());
1111     absl::StrAppend(&debug_str, buffer.id(), ",");
1112     absl::StrAppend(&debug_str, value.size(), ",");
1113     absl::StrAppend(&debug_str, allocation.chunk().offset, ",");
1114     absl::StrAppend(&debug_str, allocation.start_time(), ",");
1115     absl::StrAppend(&debug_str, allocation.end_time(), "\n");
1116   }
1117 }
1118 
1119 void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
1120   if (!options_.dump_fn) {
1121     return;
1122   }
1123   options_.dump_fn("bufferinfo", buffer_info_str_);
1124   options_.dump_fn("allocinfo", allocation_info_str_);
1125 }
1126 
1127 HeapSimulator::Result<HloValue> AlternateMemoryBestFitHeap::Finish() {
1128   AllocateReservedScopedAllocations();
1129   if (options_.enable_cross_program_prefetch) {
1130     absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
1131         prefetch_candidate = FindCrossProgramPrefetchCandidate(
1132             alias_analysis_, hlo_live_range_, options_);
1133     if (prefetch_candidate) {
1134       HloModule* module =
1135           prefetch_candidate->buffer->instruction()->GetModule();
1136       AllocateCrossProgramPrefetchBuffer(module, prefetch_candidate);
1137     }
1138   }
1139 
1140   std::vector<BufferInterval> sorted_buffer_intervals =
1141       GetSortedBufferIntervals();
1142 
1143   VLOG(1) << "Assigning buffers to alternate memory. Max heap size = "
1144           << options_.max_size_in_bytes;
1145 
1146   AddInputAndOutputRequiredAssignments();
1147 
1148   if (VLOG_IS_ON(3)) {
1149     VLOG(3) << "Flattened instruction sequence:";
1150     const auto& instruction_sequence =
1151         hlo_live_range_.flattened_instruction_sequence().instructions();
1152     for (int i = 0; i < instruction_sequence.size(); ++i) {
1153       VLOG(3) << " " << i << ": " << instruction_sequence[i]->parent()->name()
1154               << " " << instruction_sequence[i]->name();
1155     }
1156   }
1157 
1158   for (const auto& interval : sorted_buffer_intervals) {
1159     auto colocated_intervals = GetSortedColocatedIntervals(interval);
1160     if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
1161       // Increment the reserved part of alternate memory so that it is not
1162       // available for other buffers.
1163       reserved_in_bytes_ += options_.size_fn(*interval.buffer);
1164     }
1165   }
1166   VLOG(2) << "Total reserved bytes = " << reserved_in_bytes_;
1167 
1168   for (auto& interval : sorted_buffer_intervals) {
1169     if (!interval.need_allocation) {
1170       continue;
1171     }
1172 
1173     if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
1174             interval)) {
1175       continue;
1176     }
1177 
1178     HloInstruction* inst = interval.buffer->instruction();
1179     HloModule* module = inst->GetModule();
1180 
1181     // Don't intra-program prefetch a cross-program prefetched buffer.
1182     if (inst->opcode() == HloOpcode::kParameter &&
1183         absl::c_count(module->CrossProgramPrefetches(),
1184                       std::make_pair(inst->parameter_number(),
1185                                      interval.buffer->index())) > 0) {
1186       VLOG(3) << "Skip " << interval.buffer->ToShortString()
1187               << " because it is cross-program prefetched.";
1188       continue;
1189     }
1190 
1191     if (interval.size > available_heap_size()) {
1192       VLOG(3) << "Skip " << interval.buffer->ToShortString()
1193               << " because the buffer is larger than the heap size.";
1194       continue;
1195     }
1196 
1197     auto colocated_intervals = GetSortedColocatedIntervals(interval);
1198 
1199     if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
1200       VLOG(3) << "Interval " << interval.buffer->ToShortString()
1201               << " is reserved in the alternate memory.";
1202       for (const BufferInterval* colocated_interval : colocated_intervals) {
1203         const HloValue* value = colocated_interval->buffer;
1204         // Color all of the aliased reserved buffers here because reserved
1205         // alternate memory allocations will not have an entry in preset
1206         // allocations that is normally used for coloring.
1207         for (auto& position : value->positions()) {
1208           VLOG(4) << "Coloring " << position.ToString();
1209           Shape* shape = ShapeUtil::GetMutableSubshape(
1210               position.instruction->mutable_shape(), position.index);
1211           CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
1212                                   << position.ToString();
1213           shape->mutable_layout()->set_memory_space(
1214               options_.alternate_memory_space);
1215         }
1216       }
1217       continue;
1218     }
1219 
1220     if (colocated_intervals.size() > 1 &&
1221         !options_.allocate_across_sequential_calls) {
1222       VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
1223               << " because it aliases with another interval and "
1224               << "allocate_across_sequential_calls is false.";
1225       continue;
1226     }
1227 
1228     if (!ConsumeFuel("memory_space_assignment", [&] {
1229           return absl::StrCat("Ran out of fuel at buffer: ",
1230                               colocated_intervals[0]->buffer->ToShortString());
1231         })) {
1232       continue;
1233     }
1234 
1235     AppendBufferInfoDebugString(interval, &buffer_info_str_);
1236 
1237     std::vector<AllocationValue> allocation_values;
1238     CreateAllocationValuesFromColocatedIntervals(colocated_intervals,
1239                                                  allocation_values);
1240 
1241     // Retry allocating this value with larger limits if allocation fails.
1242     bool repacked = false;
1243     for (int retry_number = 0; retry_number < options_.max_retries;
1244          retry_number++) {
1245       AddRequiredAssignmentsForColocatedIntervals(colocated_intervals);
1246       bool final_retry = (retry_number == options_.max_retries - 1);
1247       options_.prefetch_interval_picker->SetRetryNumber(retry_number);
1248       Result result =
1249           AllocateAllocationValues(absl::MakeSpan(allocation_values));
1250       VLOG(2) << "Allocation result = "
1251               << absl::StrFormat("%x", static_cast<int>(result));
1252       if (result_requires_uncommit(result) ||
1253           (!final_retry && result_failed_because_of_async_copy(result))) {
1254         UncommitPendingChunks(absl::MakeSpan(allocation_values));
1255         VLOG(2) << "Couldn't allocate. Retry number " << retry_number;
1256       } else if ((result_is(result, Result::kFailOutOfMemory) ||
1257                   options_.repack_after_every_allocation) &&
1258                  num_repacks_ < options_.max_repacks && !repacked) {
1259         UncommitPendingChunks(absl::MakeSpan(allocation_values));
1260         ++num_repacks_;
1261         repacked = true;
1262         CHECK_NE(options_.repacker, nullptr);
1263         std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>
1264             repack_allocation_blocks;
1265         ExportAllocationsForRepacking(repack_allocation_blocks);
1266         VLOG(2) << "Repacking.";
1267         auto repack_status =
1268             options_.repacker->Repack(absl::MakeSpan(repack_allocation_blocks));
1269         CHECK_EQ(repack_status.status(), Status::OK());
1270         VLOG(2) << "Repack complete. Modified = " << *repack_status;
1271         if (*repack_status) {
1272           ImportRepackedAllocations();
1273           --retry_number;
1274         }
1275       } else {
1276         FinalizeAllocations(absl::MakeSpan(allocation_values));
1277         break;
1278       }
1279     }
1280   }
1281 
1282   VLOG(3) << "Debug buffer info: ";
1283   XLA_VLOG_LINES(3, buffer_info_str_);
1284   VLOG(3) << "Debug allocation info: ";
1285   XLA_VLOG_LINES(3, allocation_info_str_);
1286   DumpDebugStringsIfEnabled();
1287 
1288   HeapSimulator::Result<HloValue> result;
1289   result.heap_size = result_.heap_size;
1290   result.heap_results.emplace_back(std::move(result_));
1291   return result;
1292 }
1293 
1294 void AlternateMemoryBestFitHeap::AddRequiredAssignmentsForColocatedIntervals(
1295     absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1296         colocated_intervals) {
1297   // TODO(berkin): For now, place the phi values due to conditionals in
1298   // default memory.
1299   for (const BufferInterval* colocated_interval : colocated_intervals) {
1300     const HloValue* value = colocated_interval->buffer;
1301     for (const auto& position : value->positions()) {
1302       if (position.instruction->opcode() == HloOpcode::kConditional) {
1303         VLOG(3) << "Adding required assignment for condition output: "
1304                 << value->ToShortString();
1305         AddRequiredAssignment(position.instruction, position.index,
1306                               MemorySpace::kDefault);
1307         for (const HloComputation* called_computation :
1308              position.instruction->called_computations()) {
1309           AddRequiredAssignment(called_computation->root_instruction(),
1310                                 position.index, MemorySpace::kDefault);
1311         }
1312       }
1313     }
1314   }
1315 }
1316 
1317 void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals(
1318     absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1319         colocated_intervals,
1320     std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values) {
1321   // Create AllocationValues for all the colocated intervals.
1322   for (const auto& colocated_interval : colocated_intervals) {
1323     CreateAllocationValues(*colocated_interval, allocation_values);
1324   }
1325   // Go through the AllocationValues and delete the ones that have identical
1326   // defining and use instructions. This is useful for async
1327   // operations that can read from and write to the same buffer, e.g., in-place
1328   // asynchronous collective permute. The AllocationValues that correspond to
1329   // collective-permute-start{0} (the input) and collective-permute-start{1}
1330   // (the output) refer to the same buffer by definition (since they are created
1331   // from colocated intervals). If we don't delete one of these buffers, then
1332   // when we try to allocate these AllocationValues, we would think they overlap.
1333   auto create_instruction_vector = [](const AllocationValue& allocation_value) {
1334     std::vector<const HloInstruction*> instruction_vector;
1335     instruction_vector.push_back(allocation_value.defining_instruction());
1336     for (const AllocationValue::Use& use : allocation_value.uses()) {
1337       instruction_vector.push_back(use.hlo_use.instruction);
1338     }
1339     return instruction_vector;
1340   };
1341   for (int i = 0; i < allocation_values.size() - 1; ++i) {
1342     for (int j = i + 1; j < allocation_values.size(); ++j) {
1343       const AllocationValue& allocation_value_1 = allocation_values[i];
1344       const AllocationValue& allocation_value_2 = allocation_values[j];
1345       if (create_instruction_vector(allocation_value_1) ==
1346           create_instruction_vector(allocation_value_2)) {
1347         VLOG(3) << "Allocation values " << allocation_value_1.ToShortString()
1348                 << " and " << allocation_value_2.ToShortString()
1349                 << " are equivalent, deleting the second one.";
1350         allocation_values.erase(allocation_values.begin() + j);
1351         --j;
1352       }
1353     }
1354   }
1355 
1356   FindAliases(&allocation_values);
1357 }
1358 
1359 AlternateMemoryBestFitHeap::Result
1360 AlternateMemoryBestFitHeap::AllocateAllocationValues(
1361     absl::Span<MemorySpaceAssignment::AllocationValue> allocation_values) {
1362   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1363 
1364   // Find the use times across all of the related AllocationValues and sort
1365   // them. We use these to find allocations that are available throughout the
1366   // entire live range of all the AllocationValues.
1367   std::vector<int64_t> all_use_times;
1368   for (const AllocationValue& allocation_value : allocation_values) {
1369     absl::c_transform(allocation_value.uses(),
1370                       std::back_inserter(all_use_times),
1371                       [](const AllocationValue::Use& use) { return use.time; });
1372   }
1373   absl::c_sort(all_use_times);
1374 
1375   // Data structure to contain the preferred offset for a given computation.
1376   // We ensure that the same offset will be allocated outside the while loop
1377   // as well as inside the while loop.
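  // For example (hypothetical offsets): if a value is assigned offset 1024 in
  // alternate memory outside a while loop, the entry recorded here makes the
  // allocations for the corresponding while-body computation prefer offset
  // 1024 as well.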
1378   absl::flat_hash_map<const HloComputation*, AliasedOffset*>
1379       preferred_offset_for_computation;
1380 
1381   Result result = Result::kSuccess;
1382   for (AllocationValue& allocation_value : allocation_values) {
1383     int64_t definition_time =
1384         instruction_schedule.at(allocation_value.defining_instruction());
1385 
1386     AliasedOffset* preferred_offset = nullptr;
1387     auto preferred_offset_it =
1388         preferred_offset_for_computation.find(allocation_value.computation());
1389     if (preferred_offset_it != preferred_offset_for_computation.end()) {
1390       preferred_offset = preferred_offset_it->second;
1391     }
1392 
1393     // Iterate over the uses.
1394     for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
1395       const AllocationValue::Use& use = allocation_value.uses().at(use_idx);
1396       const HloUse hlo_use = use.hlo_use;
1397       int64_t use_time = instruction_schedule.at(hlo_use.instruction);
1398       int64_t latest_prefetch_time = use_time;
1399       bool allow_no_copy_alternate_mem_allocation = true;
1400       absl::optional<int64> earliest_prefetch_time = absl::nullopt;
1401 
1402       // Sequential calls include kWhile, kCall, and kConditional opcodes.
1403       bool is_sequential_call =
1404           (GetInstructionCallContext(hlo_use.instruction->opcode()) ==
1405            CallContext::kSequential);
1406       if (is_sequential_call) {
1407         for (const HloComputation* called_computation :
1408              hlo_use.instruction->called_computations()) {
1409           const HloLiveRange::TimeBound& computation_span =
1410               hlo_live_range_.computation_span_times().at(called_computation);
1411           latest_prefetch_time =
1412               std::min(computation_span.start - 1, latest_prefetch_time);
1413         }
1414         if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
1415           // Given an example while loop and flattened schedule (logical times
1416           // shown on the left):
1417           //
1418           // 0:  a = ...
1419           // 1:  ...
1420           //     cond {
1421           // 2:   p = param(0)
1422           // 3:   ...
1423           //     }
1424           //     body {
1425           // 4:   p = param(0)
1426           // 5:   ...
1427           // 6:   ROOT ...
1428           //     }
1429           // 7:  w = while(a), body=body, cond=cond
1430           //
1431           // When processing "a" (time 0) and its while use (time 7), we update
1432           // the interval to time 0-4. This is so that the remaining interval
1433           // (5-6) can be allocated separately and this buffer doesn't waste
1434           // alternate memory space within the while loop body.
1435           HloComputation* while_body = hlo_use.instruction->while_body();
1436           // We require while body ROOTs to be the last in the schedule.
1437           CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
1438                    instruction_schedule.at(hlo_use.instruction))
1439               << "While body ROOTs need to be the last in the schedule!  "
1440                  "Please run RootInstructionSinker.";
1441           // Replace the use time with the parameter time so that we can decide
1442           // on alternate memory allocations within the while loop body when we
1443           // look at uses within the while loop body.
1444           use_time =
1445               instruction_schedule.at(while_body->parameter_instruction(0));
1446         } else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) {
1447           // Replace the use time with the earliest parameter of called
1448           // computations.
1449           for (const HloComputation* called_computation :
1450                hlo_use.instruction->called_computations()) {
1451             use_time = std::min(
1452                 use_time, instruction_schedule.at(
1453                               called_computation->parameter_instruction(0)));
1454           }
1455         }
1456       }
1457 
1458       // Add a required assignment in default memory if the use is not
1459       // allowed in alternate memory.
1460       if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) {
1461         AddRequiredAssignment(allocation_value.value(), hlo_use.instruction,
1462                               MemorySpace::kDefault, use_time);
1463       } else if (use_idx > 0) {
1464         // We allow buffers in alternate memory that are passed into
1465         // conditionals to give up their alternate memory allocation inside the
1466         // called computation. This means that if a conditional operator has an
1467         // alternate memory allocation, subsequent uses cannot use the same
1468         // alternate memory allocation in order not to clobber data. So we force
1469         // default memory allocation for these subsequent uses.
1470         const AllocationValue::Use& previous_use =
1471             allocation_value.uses().at(use_idx - 1);
1472         if (previous_use.hlo_use.instruction->opcode() ==
1473                 HloOpcode::kConditional &&
1474             previous_use.hlo_use.instruction != hlo_use.instruction) {
1475           allow_no_copy_alternate_mem_allocation = false;
1476           earliest_prefetch_time =
1477               instruction_schedule.at(previous_use.hlo_use.instruction);
1478           VLOG(3) << "Previous use (" << previous_use.hlo_use.ToString()
1479                   << ") of use (" << hlo_use.ToString()
1480                   << ") is a conditional, so this use will need to evict. "
1481                   << "Earliest prefetch time = " << *earliest_prefetch_time;
1482         }
1483       }
1484 
1485       // Bitcasts don't define buffers and don't directly consume buffers. Skip
1486       // allocating buffers for bitcast uses (unless they are the root
1487       // instruction). The uses that feed from bitcasts will be handled
1488       // specially.
1489       if (hlo_use.instruction->opcode() != HloOpcode::kBitcast ||
1490           hlo_use.instruction ==
1491               hlo_use.instruction->parent()->root_instruction()) {
1492         AllocationRequest request;
1493         // Rarely (e.g., when the conditional true and false parameters are
1494         // the same), the definition time can be the time of the conditional
1495         // and the use time that of the parameter use, which is earlier.
1496         request.start_time = std::min(definition_time, use_time);
1497         request.end_time = use_time;
1498         request.latest_prefetch_time = latest_prefetch_time;
1499         request.size = allocation_value.size();
1500         request.allow_no_copy_alternate_mem_allocation =
1501             allow_no_copy_alternate_mem_allocation;
1502         request.earliest_prefetch_time = earliest_prefetch_time;
1503         request.preferred_offset = preferred_offset;
1504         request.use = &use;
1505         request.allocation_value = &allocation_value;
1506         request.all_use_times = all_use_times;
1507         result_mark(AllocateSegment(request), result);
1508         if (result_requires_uncommit(result)) {
1509           // If the allocation finding failed (e.g., due to running out of
1510           // asynchronous copies), then fall back to allocating the buffer
1511           // entirely in the default memory.
1512           return result;
1513         }
1514 
1515         // If there are multiple uses, they can try using the memory
1516         // allocation already in the alternate memory.
1517         definition_time = instruction_schedule.at(hlo_use.instruction);
1518       }
1519 
1520       // Propagate the allocation to any aliases this use might have had.
1521       MemorySpaceAssignment::Allocation* aliased_allocation =
1522           GetLiveAllocationAt(*allocation_value.allocation_sequence(),
1523                               use_time);
1524       for (const HloPosition& aliased_position : use.aliases) {
1525         AddAliasedRequiredAssignment(aliased_position.instruction,
1526                                      aliased_position.index,
1527                                      aliased_allocation);
1528       }
1529 
1530       if (hlo_use.instruction->opcode() == HloOpcode::kWhile &&
1531           aliased_allocation->memory_space() == MemorySpace::kAlternate) {
1532         // For while uses that are allocated in the alternate memory space, if
1533         // they also have an allocation in the default memory space in their
1534         // allocation sequence, create a "parent" allocation that mirrors this
1535         // default memory space allocation. When we process the parent
1536         // allocation, we add an additional parameter to the while that is a
1537         // reference to the buffer in the default memory space. With parent
1538         // allocations, we don't need to unnecessarily evict buffers since they
1539         // already have a copy in the default memory space. We search backwards
1540         // (latest to earliest in execution time) for a suitable allocation in
1541         // order to find the most recent one.
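        // For illustration (hypothetical HLO): if value "a" is in alternate
        // memory at "w = while(a)" and its allocation sequence also contains
        // an earlier default-memory allocation of "a", a ParentAllocation
        // mirroring that default-memory allocation is appended below to the
        // while-body parameter's sequence, so the buffer does not need to be
        // evicted again inside the body.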
1542         if (absl::c_find_if(allocation_value.value()->positions(),
1543                             [&hlo_use](const HloPosition& position) {
1544                               return position.instruction ==
1545                                          hlo_use.instruction &&
1546                                      position.index == hlo_use.operand_index;
1547                             }) != allocation_value.value()->positions().end()) {
1548           auto allocation_sequence = allocation_value.allocation_sequence();
1549           auto prev_allocation_in_default_mem_it = std::find_if(
1550               allocation_sequence->rbegin(), allocation_sequence->rend(),
1551               [&](const auto& allocation) {
1552                 return allocation->memory_space() == MemorySpace::kDefault &&
1553                        allocation->defining_position() ==
1554                            allocation_value.defining_position();
1555               });
1556           if (prev_allocation_in_default_mem_it !=
1557               allocation_sequence->rend()) {
1558             VLOG(3) << "Found a prev allocation in default mem for while use: "
1559                     << (*prev_allocation_in_default_mem_it)->ToString();
1560             auto body_allocation_value_it = absl::c_find_if(
1561                 allocation_values, [&](const AllocationValue& value) {
1562                   return value.computation() ==
1563                              hlo_use.instruction->while_body() &&
1564                          value.defining_instruction()->opcode() ==
1565                              HloOpcode::kParameter;
1566                 });
1567             CHECK_NE(body_allocation_value_it, allocation_values.end());
1568             VLOG(3) << "Body allocation value: "
1569                     << body_allocation_value_it->ToShortString();
1570             int64_t body_parameter_time = instruction_schedule.at(
1571                 body_allocation_value_it->defining_instruction());
1572             body_allocation_value_it->allocation_sequence()->push_back(
1573                 absl::make_unique<MemorySpaceAssignment::ParentAllocation>(
1574                     **prev_allocation_in_default_mem_it, hlo_use.instruction,
1575                     body_allocation_value_it->defining_position(),
1576                     body_parameter_time));
1577             VLOG(3) << "Created: "
1578                     << body_allocation_value_it->allocation_sequence()
1579                            ->back()
1580                            ->ToString();
1581           }
1582         }
1583         // Special case for while loops since the root offset must agree with
1584         // other offsets: remember the preferred offset for the while loop body.
1585         preferred_offset_for_computation[hlo_use.instruction->while_body()] =
1586             GetAliasedOffset(*aliased_allocation);
1587       }
1588     }
1589   }
1590   return result;
1591 }
1592 
1593 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
1594   return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
1595          (a.start_time <= b.start_time && a.end_time < b.end_time);
1596 }
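// Illustrative note (hypothetical times): under this comparison,
// {start_time=2, end_time=5} orders before {start_time=3, end_time=7}, whereas
// {start_time=2, end_time=7} and {start_time=3, end_time=5} are equivalent
// (neither precedes the other); such nested copies are what
// AsynchronousCopyOrdering::ViolatesOrdering detects below.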
1597 
1598 void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) {
1599   auto it_and_inserted = ranges_.insert(copy);
1600   CHECK(it_and_inserted.second ||
1601         it_and_inserted.first->start_time == copy.start_time);
1602 }
1603 
1604 void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) {
1605   auto copy_it = ranges_.find(copy);
1606   CHECK(copy_it != ranges_.end());
1607   ranges_.erase(copy_it);
1608 }
1609 
1610 absl::optional<AsynchronousCopy> AsynchronousCopyOrdering::ViolatesOrdering(
1611     int64_t start_time, int64_t end_time) const {
1612   // We allow identical start and end times. If we find a match in ranges_, it
1613   // is enough to check just the start time, because the found value will
1614   // either be identical to {start_time, end_time} (which doesn't violate) or
1615   // its start_time will be smaller and its end_time larger (which violates).
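  // For example (hypothetical times): if ranges_ contains a copy with
  // (start_time=2, end_time=8), a query for (4, 6) finds that copy as an
  // equivalent element; since its start_time (2) differs from the query's (4),
  // the nesting is reported as a violation. A query for (2, 8) finds the same
  // element but is not a violation because the start times match.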
1616   auto copy_it = ranges_.find(
1617       {start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate});
1618   if (copy_it != ranges_.end() && copy_it->start_time != start_time) {
1619     VLOG(4) << "Violates ordering: (" << start_time << ", " << end_time
1620             << ") and (" << copy_it->start_time << ", " << copy_it->end_time
1621             << ")";
1622     return *copy_it;
1623   }
1624   return absl::nullopt;
1625 }
1626 
1627 AlternateMemoryBestFitHeap::AliasedOffset*
1628 AlternateMemoryBestFitHeap::GetAliasedOffset(
1629     const MemorySpaceAssignment::Allocation& allocation) {
1630   auto aliased_offset_it = aliased_offset_map_.find(&allocation);
1631   CHECK(aliased_offset_it != aliased_offset_map_.end());
1632   return aliased_offset_it->second;
1633 }
1634 
1635 void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset(
1636     const MemorySpaceAssignment::Allocation& allocation,
1637     AlternateMemoryBestFitHeap::AliasedOffset* aliased_offset) {
1638   CHECK(allocation.memory_space() == MemorySpace::kAlternate);
1639   CHECK(!aliased_offset_map_.contains(&allocation));
1640   if (!aliased_offset) {
1641     aliased_offsets_.push_back({allocation.chunk().offset});
1642     aliased_offset = &aliased_offsets_.back();
1643   }
1644   CHECK_EQ(allocation.chunk().offset, aliased_offset->offset);
1645   CHECK(aliased_offset->allocations.insert(&allocation).second);
1646   aliased_offset_map_[&allocation] = aliased_offset;
1647 }
1648 
1649 /*static*/ MemorySpaceAssignment::Allocation*
1650 AlternateMemoryBestFitHeap::GetLiveAllocationAt(
1651     const MemorySpaceAssignment::AllocationSequence& allocations,
1652     int64_t time) {
1653   for (auto allocation_it = allocations.rbegin();
1654        allocation_it != allocations.rend(); ++allocation_it) {
1655     if ((*allocation_it)->start_time() <= time &&
1656         (*allocation_it)->end_time() >= time) {
1657       return allocation_it->get();
1658     }
1659   }
1660   return nullptr;
1661 }
1662 
1663 void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
1664     HloModule* module, absl::optional<BufferInterval> prefetch_candidate) {
1665   if (!prefetch_candidate) {
1666     return;
1667   }
1668 
1669   ChunkCandidate chunk_candidate = FindChunkCandidate(*prefetch_candidate);
1670   if (chunk_candidate.chunk.offset != 0 ||
1671       chunk_candidate.heap_size > available_heap_size()) {
1672     LOG(WARNING)
1673         << "Could not allocate preferred memory for cross-program prefetch";
1674     return;
1675   }
1676 
1677   const HloValue* buffer = prefetch_candidate->buffer;
1678   int64_t parameter = buffer->instruction()->parameter_number();
1679   module->AddCrossProgramPrefetch(parameter, buffer->index());
1680 
1681   MemorySpaceAssignment::AllocationSequence allocations;
1682   allocations.push_back(absl::make_unique<MemorySpaceAssignment::Allocation>(
1683       buffer->defining_position(), MemorySpace::kDefault, kDummyChunk,
1684       prefetch_candidate->start, prefetch_candidate->end,
1685       /*is_scoped_allocation=*/false));
1686 
1687   // Find the earliest use.
1688   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1689   auto uses = buffer->uses();
1690   auto use_schedule_compare = [&](const HloUse& lhs, const HloUse& rhs) {
1691     return instruction_schedule.at(lhs.instruction) <
1692            instruction_schedule.at(rhs.instruction);
1693   };
1694   auto first_use = absl::c_min_element(uses, use_schedule_compare);
1695   int64_t latest_prefetch_time =
1696       instruction_schedule.at(first_use->instruction);
1697 
1698   // Find the latest use time.
1699   int64_t last_use_time = instruction_schedule.at(
1700       absl::c_max_element(uses, use_schedule_compare)->instruction);
1701   for (const HloValue* colocation : prefetch_candidate->colocations) {
1702     last_use_time = std::max(
1703         last_use_time,
1704         instruction_schedule.at(
1705             absl::c_max_element(colocation->uses(), use_schedule_compare)
1706                 ->instruction));
1707   }
1708 
1709   int64_t end_of_program_prefetch_end_time = instruction_schedule.size();
1710   int64_t end_of_program_prefetch_start_time =
1711       options_.prefetch_interval_picker->PreferredPrefetchStartTime(
1712           buffer->defining_position().shape(), last_use_time,
1713           end_of_program_prefetch_end_time, end_of_program_prefetch_end_time);
1714   VLOG(2) << "last use time = " << last_use_time
1715           << ", end-of-program prefetch start time = "
1716           << end_of_program_prefetch_start_time;
1717   bool free_buffer =
1718       (options_.enable_cross_program_prefetch_freeing &&
1719        end_of_program_prefetch_start_time > last_use_time &&
1720        end_of_program_prefetch_start_time < end_of_program_prefetch_end_time);
1721   int64_t cross_program_prefetch_end_time =
1722       free_buffer ? last_use_time : prefetch_candidate->end;
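  // Illustrative example (hypothetical schedule): if the last use of the
  // cross-program-prefetched buffer is at logical time 10, the program ends at
  // time 100, and the picker prefers an end-of-program prefetch starting at
  // time 90, then free_buffer is true: the cross-program prefetch below only
  // occupies alternate memory up to time 10, and a second async copy at times
  // 90-100 brings the buffer back so that it is presumably resident in
  // alternate memory when the next execution of the program begins.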
1723 
1724   AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate,
1725                chunk_candidate.chunk, prefetch_candidate->start,
1726                cross_program_prefetch_end_time, latest_prefetch_time,
1727                &allocations, /*aliased_offset=*/nullptr,
1728                /*is_cross_program_prefetch=*/true);
1729   absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); });
1730   AliasedOffset* cross_program_prefetch_offset =
1731       GetAliasedOffset(*allocations.back());
1732 
1733   if (free_buffer) {
1734     VLOG(2) << "Adding an end-of-program prefetch for a freed "
1735                "cross-program-prefetched buffer.";
1736     AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate,
1737                  chunk_candidate.chunk, end_of_program_prefetch_start_time,
1738                  end_of_program_prefetch_end_time,
1739                  end_of_program_prefetch_end_time, &allocations,
1740                  cross_program_prefetch_offset);
1741     CHECK_EQ(cross_program_prefetch_offset->offset,
1742              allocations.back()->chunk().offset);
1743   }
1744 
1745   const int allocations_initial_size = allocations_->size();
1746   for (auto& allocation : allocations) {
1747     if (allocation->memory_space() == MemorySpace::kAlternate) {
1748       BufferInterval buffer_interval;
1749       buffer_interval.start = allocation->start_time();
1750       buffer_interval.end = allocation->end_time();
1751       buffer_interval.size = allocation->chunk().size;
1752       buffer_interval.buffer = prefetch_candidate->buffer;
1753       AddToPendingChunks(buffer_interval, chunk_candidate);
1754     }
1755     allocations_->push_back(std::move(allocation));
1756   }
1757 
1758   // Add a repack allocation block for the Allocation objects in alternate
1759   // memory.
1760   for (int i = allocations_initial_size; i < allocations_->size(); ++i) {
1761     const auto& allocation = allocations_->at(i);
1762     if (allocation->memory_space() == MemorySpace::kAlternate) {
1763       repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
1764           allocation->start_time(), allocation->end_time(),
1765           allocation->chunk().size, allocation->chunk().offset,
1766           static_cast<int64>(repack_allocation_blocks_.size()),
1767           allocation.get()));
1768       RepackAllocationBlock* inserted = &repack_allocation_blocks_.back();
1769       for (RepackAllocationBlock& colocation : repack_allocation_blocks_) {
1770         colocation.colocations.push_back(inserted);
1771         if (&colocation != inserted) {
1772           inserted->colocations.push_back(&colocation);
1773         }
1774       }
1775     }
1776   }
1777 
1778   ClearPendingChunks();
1779 }
1780 
1781 void AlternateMemoryBestFitHeap::AllocateReservedScopedAllocations() {
1782   const auto& instruction_sequence =
1783       hlo_live_range_.flattened_instruction_sequence().instructions();
1784   std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
1785   for (int i = 0; i < instruction_sequence.size(); ++i) {
1786     int64_t reserved_scoped_memory =
1787         options_.reserved_scoped_memory_fn(instruction_sequence[i]);
1788     if (reserved_scoped_memory != 0) {
1789       VLOG(1) << "Allocate reserved scoped memory at " << i << " ("
1790               << instruction_sequence[i]->name()
1791               << "): " << reserved_scoped_memory;
1792       MemorySpaceAssignment::BufferInterval interval;
1793       interval.buffer = nullptr;
1794       interval.size = reserved_scoped_memory;
1795       interval.start = i;
1796       interval.end = i;
1797       interval.need_allocation = true;
1798       interval.colocations = {};
1799       ChunkCandidate chunk_candidate =
1800           FindChunkCandidate(interval, /*preferred_offset=*/0);
1801       CHECK_EQ(chunk_candidate.chunk.offset, 0);
1802       AddToPendingChunks(interval, chunk_candidate);
1803 
1804       allocations_->push_back(
1805           absl::make_unique<MemorySpaceAssignment::Allocation>(
1806               HloPosition{instruction_sequence[i], {}}, MemorySpace::kAlternate,
1807               chunk_candidate.chunk, i, i, /*is_scoped_allocation=*/true));
1808 
1809       repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
1810           i, i, reserved_scoped_memory,
1811           /*initial_offset=*/0,
1812           static_cast<int64>(repack_allocation_blocks_.size()),
1813           allocations_->back().get()));
1814       colocations.push_back(&repack_allocation_blocks_.back());
1815     }
1816   }
1817   // If requested, make all scoped allocations colocate with each other so
1818   // that when we repack, all scoped allocations get the same offsets. Since
1819   // they will all have the same scoped memory addresses, this increases the
1820   // opportunity to deduplicate different ops. However, this may hurt
1821   // memory packing efficiency.
1822   if (options_.allocate_reserved_scoped_memory_at_same_offset) {
1823     for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
1824          colocations) {
1825       repack_block->colocations = colocations;
1826     }
1827   }
1828 }
1829 
1830 absl::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
1831 AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
1832                                                        int64_t time) const {
1833   auto required_assignment_it = required_assignments_.find(buffer);
1834   absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
1835   if (required_assignment_it != required_assignments_.end()) {
1836     for (const RequiredMemoryAssignment& required_assignment :
1837          required_assignment_it->second) {
1838       if (required_assignment.time == time) {
1839         // Sanity check that there is only one required at time.
1840         CHECK(!required_assignment_at_time);
1841         required_assignment_at_time = required_assignment;
1842       }
1843     }
1844   }
1845   return required_assignment_at_time;
1846 }
1847 
1848 absl::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
1849 AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse(
1850     const AllocationValue::Use& use) const {
1851   absl::optional<RequiredMemoryAssignment> required_assignment;
1852   for (const HloPosition& position : use.aliases) {
1853     const HloValue* value =
1854         &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
1855             position.instruction, position.index);
1856     int64_t time =
1857         hlo_live_range_.instruction_schedule().at(position.instruction);
1858     absl::optional<RequiredMemoryAssignment> required_assignment_for_alias =
1859         RequiredMemoryAssignmentAt(value, time);
1860     if (required_assignment == absl::nullopt) {
1861       required_assignment = required_assignment_for_alias;
1862     } else {
1863       CHECK(required_assignment_for_alias == absl::nullopt ||
1864             required_assignment->equals_ignoring_time(
1865                 *required_assignment_for_alias));
1866     }
1867   }
1868   return required_assignment;
1869 }
1870 
1871 void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
1872     const HloInstruction* instruction, ShapeIndex index,
1873     const MemorySpaceAssignment::Allocation* aliased_allocation) {
1874   AliasedOffset* offset = nullptr;
1875   if (aliased_allocation->memory_space() == MemorySpace::kAlternate) {
1876     offset = GetAliasedOffset(*aliased_allocation);
1877   }
1878   AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(),
1879                         offset);
1880 }
1881 
1882 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
1883     const HloValue* value, const HloInstruction* instruction,
1884     MemorySpaceAssignment::MemorySpace memory_space, int64_t time,
1885     AliasedOffset* offset) {
1886   // Check for an existing required assignment at this time and, if there is
1887   // one, make sure it matches the new assignment.
1888   auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time);
1889   if (existing_required_assignment) {
1890     CHECK(memory_space == existing_required_assignment->memory_space)
1891         << "inst = " << instruction->ToString() << " at " << time;
1892     CHECK((!offset && !existing_required_assignment->offset) ||
1893           offset == existing_required_assignment->offset);
1894     VLOG(3) << "Not adding required assignment because there is one already: "
1895             << value->ToShortString() << " at " << time << " at "
1896             << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1897   } else {
1898     VLOG(3) << "Adding required assignment: " << value->ToShortString()
1899             << " at " << time << " at "
1900             << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1901     RequiredMemoryAssignment required_assignment{memory_space, time, offset};
1902     required_assignments_[value].push_back(required_assignment);
1903     pending_required_assignments_.push_back({value, required_assignment});
1904   }
1905 }
1906 
1907 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
1908     const HloInstruction* instruction, ShapeIndex index,
1909     MemorySpace memory_space, AliasedOffset* offset) {
1910   const HloValue* value =
1911       &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index);
1912   int64_t instruction_time =
1913       hlo_live_range_.instruction_schedule().at(instruction);
1914   AddRequiredAssignment(value, instruction, memory_space, instruction_time,
1915                         offset);
1916 }
1917 
1918 void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
1919   // Go through the parameters, outputs, and constants and pin them to the
1920   // corresponding memory by adding a required assignment.
1921   const HloModule& module = alias_analysis_.dataflow_analysis().module();
1922   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1923   HloComputation* entry_computation = module.entry_computation();
1924   for (HloInstruction* parameter_instruction :
1925        entry_computation->parameter_instructions()) {
1926     int64_t parameter_instruction_time =
1927         instruction_schedule.at(parameter_instruction);
1928     ShapeUtil::ForEachSubshape(
1929         parameter_instruction->shape(),
1930         [&](const Shape& subshape, const ShapeIndex& index) {
1931           MemorySpace memory_space = MemorySpace::kDefault;
1932           if (subshape.has_layout() && subshape.layout().memory_space() ==
1933                                            options_.alternate_memory_space) {
1934             memory_space = MemorySpace::kAlternate;
1935           }
1936           for (const HloBuffer* buffer :
1937                alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
1938             for (const HloValue* value : buffer->values()) {
1939               VLOG(3) << "Adding required assignment for parameter value = "
1940                       << value->ToShortString()
1941                       << " time = " << parameter_instruction_time << " space = "
1942                       << (memory_space == MemorySpace::kDefault ? "def"
1943                                                                 : "alt");
1944               required_assignments_[value].push_back(
1945                   {memory_space, /*time=*/parameter_instruction_time});
1946             }
1947           }
1948         });
1949   }
1950   HloInstruction* root_instruction = entry_computation->root_instruction();
1951   int64_t root_instruction_time = instruction_schedule.at(root_instruction);
1952   ShapeUtil::ForEachSubshape(
1953       root_instruction->shape(),
1954       [&](const Shape& subshape, const ShapeIndex& index) {
1955         MemorySpace memory_space = MemorySpace::kDefault;
1956         if (subshape.has_layout() && subshape.layout().memory_space() ==
1957                                          options_.alternate_memory_space) {
1958           memory_space = MemorySpace::kAlternate;
1959         }
1960         for (const HloBuffer* buffer :
1961              alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
1962           for (const HloValue* value : buffer->values()) {
1963             VLOG(3) << "Adding required assignment for output value = "
1964                     << value->ToShortString()
1965                     << " time = " << root_instruction_time << " space = "
1966                     << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1967             required_assignments_[value].push_back(
1968                 {memory_space, /*time=*/root_instruction_time});
1969           }
1970         }
1971       });
1972 
1973   for (const HloComputation* computation : module.MakeNonfusionComputations()) {
1974     for (HloInstruction* instruction : computation->instructions()) {
1975       if (instruction->opcode() == HloOpcode::kConstant) {
1976         auto constant_instruction_it = instruction_schedule.find(instruction);
1977         if (constant_instruction_it == instruction_schedule.end()) {
1978           continue;
1979         }
1980         int64_t constant_instruction_time = constant_instruction_it->second;
1981         for (const auto& indexed_shape :
1982              ShapeUtil::GetLeafShapes(instruction->shape())) {
1983           const ShapeIndex& index = indexed_shape.index;
1984           for (const HloBuffer* buffer :
1985                alias_analysis_.ComputeBuffersAt(instruction, index)) {
1986             for (const HloValue* value : buffer->values()) {
1987               VLOG(3) << "Adding required assignment for constant value = "
1988                       << value->ToShortString()
1989                       << " time = " << constant_instruction_time
1990                       << " space = def";
1991               required_assignments_[value].push_back(
1992                   {MemorySpace::kDefault, /*time=*/constant_instruction_time});
1993             }
1994           }
1995         }
1996       }
1997     }
1998   }
1999 }
2000 
2001 bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory(
2002     absl::Span<const BufferInterval* const> colocated_intervals) const {
2003   auto is_position_in_alternate_memory = [&](const HloPosition& position) {
2004     const Shape& shape = position.shape();
2005     return shape.has_layout() &&
2006            shape.layout().memory_space() == options_.alternate_memory_space;
2007   };
2008 
2009   const HloModule& module = alias_analysis_.dataflow_analysis().module();
2010   const HloComputation* entry_computation = module.entry_computation();
2011   const HloInstruction* root_instruction =
2012       entry_computation->root_instruction();
2013   for (const BufferInterval* colocated_interval : colocated_intervals) {
2014     const HloValue* value = colocated_interval->buffer;
2015     if (value->defining_instruction()->opcode() == HloOpcode::kParameter &&
2016         value->defining_instruction()->parent() == entry_computation &&
2017         is_position_in_alternate_memory(value->defining_position())) {
2018       return true;
2019     }
2020 
2021     for (const HloPosition& position : value->positions()) {
2022       if (position.instruction == root_instruction &&
2023           is_position_in_alternate_memory(position)) {
2024         return true;
2025       }
2026     }
2027   }
2028   return false;
2029 }
2030 
2031 void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking(
2032     std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>& allocations) {
2033   for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
2034     allocations.push_back(&allocation_block);
2035   }
2036 }
2037 
2038 void AlternateMemoryBestFitHeap::ImportRepackedAllocations() {
2039   interval_tree_ = {};
2040   for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
2041     MemorySpaceAssignment::Allocation* allocation = allocation_block.allocation;
2042     VLOG(3) << "Moved " << allocation->ToString() << ", size "
2043             << allocation->chunk().size << ", (" << allocation_block.start_time
2044             << ", " << allocation_block.end_time << ") from "
2045             << allocation_block.initial_offset << " to "
2046             << allocation_block.offset;
2047     allocation_block.allocation->mutable_chunk()->offset =
2048         allocation_block.offset;
2049     interval_tree_.Add(allocation_block.start_time, allocation_block.end_time,
2050                        {allocation_block.offset, allocation_block.size});
2051     allocation_block.initial_offset = allocation_block.offset;
2052     allocation_block.offset = -1;
2053   }
2054 }
2055 
2056 void AlternateMemoryBestFitHeap::UncommitPendingChunks(
2057     absl::Span<AllocationValue> allocation_values) {
2058   // Clear the allocation sequences of the allocation values so that they are
2059   // empty in case we retry allocation after uncommitting.
2060   for (AllocationValue& allocation_value : allocation_values) {
2061     allocation_value.allocation_sequence()->clear();
2062   }
2063   for (const auto& interval_and_chunk : pending_chunks_) {
2064     const BufferInterval& interval = interval_and_chunk.first;
2065     const Chunk& chunk = interval_and_chunk.second.chunk;
2066     VLOG(3) << "Uncommitting: (" << interval.start << ", " << interval.end
2067             << ") off = " << chunk.offset << " size = " << chunk.size;
2068     interval_tree_.Remove(interval.start, interval.end, chunk);
2069   }
2070   for (const auto& interval : pending_async_copies_) {
2071     if (interval.destination == MemorySpace::kAlternate) {
2072       prefetch_interval_tree_.Remove(interval.start_time, interval.end_time,
2073                                      kDummyChunk);
2074       async_copy_ordering_.RemoveCopy(interval);
2075     } else {
2076       eviction_interval_tree_.Remove(interval.start_time, interval.end_time,
2077                                      kDummyChunk);
2078     }
2079   }
2080   for (const auto& value_and_required_assignment :
2081        pending_required_assignments_) {
2082     auto& required_assignment_vector =
2083         required_assignments_[value_and_required_assignment.first];
2084     const RequiredMemoryAssignment& required_assignment =
2085         value_and_required_assignment.second;
2086     VLOG(3) << "Removing required assignment: "
2087             << (required_assignment.memory_space == MemorySpace::kDefault
2088                     ? "def"
2089                     : "alt")
2090             << " time = " << required_assignment.time << " off = "
2091             << (required_assignment.offset ? required_assignment.offset->offset
2092                                            : -1);
2093     for (auto it = required_assignment_vector.begin();
2094          it != required_assignment_vector.end(); ++it) {
2095       if (*it == value_and_required_assignment.second) {
2096         required_assignment_vector.erase(it);
2097         break;
2098       }
2099     }
2100   }
2101   ClearPendingChunks();
2102 }
2103 
2104 void AlternateMemoryBestFitHeap::FinalizeAllocations(
2105     absl::Span<AllocationValue> allocation_values) {
2106   absl::flat_hash_map<const AliasedOffset*,
2107                       std::vector<MemorySpaceAssignment::Allocation*>>
2108       colocation_map;
2109   for (AllocationValue& allocation_value : allocation_values) {
2110     for (auto& allocation : *allocation_value.allocation_sequence()) {
2111       AppendAllocationInfoDebugString(allocation_value, *allocation,
2112                                       allocation_info_str_);
2113       allocations_->push_back(std::move(allocation));
2114       MemorySpaceAssignment::Allocation* inserted_allocation =
2115           allocations_->back().get();
2116       if (inserted_allocation->memory_space() == MemorySpace::kAlternate) {
2117         colocation_map[GetAliasedOffset(*inserted_allocation)].push_back(
2118             inserted_allocation);
2119       }
2120     }
2121   }
2122   // The allocations that have the same AliasedOffset need to be colocated.
2123   // Export these to repack_allocation_blocks_ so that we can repack them to
2124   // reduce fragmentation.
2125   for (auto& colocation : colocation_map) {
2126     std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
2127     for (MemorySpaceAssignment::Allocation* colocated_allocation :
2128          colocation.second) {
2129       repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
2130           colocated_allocation->start_time(), colocated_allocation->end_time(),
2131           colocated_allocation->chunk().size,
2132           colocated_allocation->chunk().offset,
2133           static_cast<int64>(repack_allocation_blocks_.size()),
2134           colocated_allocation));
2135       colocations.push_back(&repack_allocation_blocks_.back());
2136     }
2137     for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
2138          colocations) {
2139       repack_block->colocations = colocations;
2140     }
2141   }
2142   ClearPendingChunks();
2143 }
2144 
2145 void AlternateMemoryBestFitHeap::ClearPendingChunks() {
2146   pending_chunks_.clear();
2147   pending_async_copies_.clear();
2148   pending_required_assignments_.clear();
2149   aliased_offset_map_.clear();
2150   aliased_offsets_.clear();
2151 }
2152 
2153 void AlternateMemoryBestFitHeap::AddToPendingChunks(
2154     const BufferInterval& buffer_interval,
2155     const ChunkCandidate& chunk_candidate) {
2156   VLOG(3) << "Committing chunk: " << buffer_interval.start << "-"
2157           << buffer_interval.end << " : [" << chunk_candidate.chunk.offset
2158           << ", " << chunk_candidate.chunk.size << "]";
2159   pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
2160   CommitChunk(buffer_interval, chunk_candidate);
2161 }
2162 
2163 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment(
2164     const AllocationRequest& request) {
2165   auto allocation_sequence = request.allocation_value->allocation_sequence();
2166   // start_time == end_time is a special case where the value is consumed
2167   // multiple times by the same instruction. We can just find the previous
2168   // allocation and use that allocation.
2169   if (request.start_time == request.end_time) {
2170     MemorySpaceAssignment::Allocation* allocation =
2171         GetLiveAllocationAt(*allocation_sequence, request.end_time);
2172     CHECK_NE(allocation, nullptr);
2173     allocation->AddUse(request.use->hlo_use);
2174     return Result::kSuccess;
2175   }
2176 
2177   const HloPosition& defining_position =
2178       request.allocation_value->defining_position();
2179   VLOG(2) << "Finding allocation for "
2180           << request.allocation_value->ToShortString() << " ("
2181           << request.start_time << ", " << request.end_time
2182           << ") latest prefetch = " << request.latest_prefetch_time
2183           << " last use = " << request.allocation_value->uses().back().time
2184           << " use = " << request.use->hlo_use.ToString()
2185           << ". Size = " << request.size
2186           << ", def pos = " << defining_position.ToString();
2187   CHECK_LE(request.start_time, request.end_time);
2188 
2189   // There could be a requirement to pin this buffer to default memory either
2190   // because it is a parameter or an output.  If the buffer is a parameter, then
2191   // we're allowed to prefetch. If the use expects the output to be in default
2192   // memory, we cannot prefetch it because if we did, it would be in alternate
2193   // memory instead.
2194   auto required_assignment_at_start = RequiredMemoryAssignmentAt(
2195       request.allocation_value->value(), request.start_time);
2196   absl::optional<MemorySpace> required_memory_space_at_start;
2197   if (required_assignment_at_start) {
2198     required_memory_space_at_start = required_assignment_at_start->memory_space;
2199   }
2200   // Find required assignment both for the use and its aliases. If they are both
2201   // non-nullopt, then make sure they require the same assignment.
2202   auto required_assignment_at_end = RequiredMemoryAssignmentAt(
2203       request.allocation_value->value(), request.end_time);
2204   auto aliased_required_assignment_at_end =
2205       AliasedRequiredAssignmentForUse(*request.use);
2206   if (required_assignment_at_end != aliased_required_assignment_at_end) {
2207     if (required_assignment_at_end == absl::nullopt) {
2208       required_assignment_at_end = aliased_required_assignment_at_end;
2209     } else {
2210       CHECK(aliased_required_assignment_at_end == absl::nullopt ||
2211             aliased_required_assignment_at_end->equals_ignoring_time(
2212                 *required_assignment_at_end));
2213     }
2214   }
2215   absl::optional<MemorySpace> required_memory_space_at_end;
2216   if (required_assignment_at_end) {
2217     required_memory_space_at_end = required_assignment_at_end->memory_space;
2218   }
2219 
2220   if (required_assignment_at_start) {
2221     bool needs_required_allocation = true;
2222     if (!allocation_sequence->empty()) {
2223       auto prev_allocation_it = std::find_if(
2224           allocation_sequence->rbegin(), allocation_sequence->rend(),
2225           [&](const auto& allocation) {
2226             return allocation->memory_space() ==
2227                        required_memory_space_at_start &&
2228                    allocation->defining_position() == defining_position;
2229           });
2230       if (prev_allocation_it != allocation_sequence->rend()) {
2231         (*prev_allocation_it)->Extend(request.start_time);
2232         needs_required_allocation = false;
2233       }
2234     }
2235     if (needs_required_allocation) {
2236       absl::optional<Chunk> aliased_chunk = absl::nullopt;
2237       if (required_assignment_at_start->memory_space ==
2238           MemorySpace::kAlternate) {
2239         aliased_chunk =
2240             Chunk{required_assignment_at_start->offset->offset, request.size};
2241       }
2242       allocation_sequence->push_back(
2243           absl::make_unique<MemorySpaceAssignment::Allocation>(
2244               defining_position, required_assignment_at_start->memory_space,
2245               aliased_chunk, request.start_time, request.start_time,
2246               /*is_scoped_allocation=*/false));
2247       if (required_assignment_at_start->memory_space ==
2248           MemorySpace::kAlternate) {
2249         CreateOrAddToAliasedOffset(*allocation_sequence->back(),
2250                                    required_assignment_at_start->offset);
2251       }
2252     }
2253   }
2254 
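  // The rest of this function tries the allocation strategies in order: (1)
  // keep the value in the alternate memory without a copy, (2) if a previous
  // alternate-memory allocation exists, evict it to the default memory, (3)
  // prefetch the value back into the alternate memory, and (4) otherwise fall
  // back to extending the allocation in the default memory.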
2255   Result allocation_result = Result::kSuccess;
2256   // First try keeping the allocation entirely in the alternate memory.
2257   if (required_memory_space_at_start != MemorySpace::kDefault &&
2258       required_memory_space_at_end != MemorySpace::kDefault &&
2259       request.allow_no_copy_alternate_mem_allocation) {
2260     allocation_result = AllocateInAlternateMemoryNoCopy(request);
2261     if (allocation_result == Result::kSuccess) {
2262       return Result::kSuccess;
2263     }
2264   }
2265 
2266   auto prev_allocation_it = allocation_sequence->rbegin();
2267   // Find a previous allocation that is in the default memory space (not
2268   // necessarily the very last allocation).
2269   auto prev_allocation_in_default_mem_it = std::find_if(
2270       allocation_sequence->rbegin(), allocation_sequence->rend(),
2271       [&](const auto& allocation) {
2272         return allocation->memory_space() == MemorySpace::kDefault &&
2273                allocation->defining_position() == defining_position;
2274       });
2275 
2276   if (prev_allocation_in_default_mem_it == allocation_sequence->rend() &&
2277       prev_allocation_it != allocation_sequence->rend() &&
2278       (*prev_allocation_it)->memory_space() == MemorySpace::kAlternate &&
2279       (*prev_allocation_it)->defining_position() == defining_position &&
2280       !request.allocation_value->requires_contiguous_allocation()) {
2281     // If there was an allocation for this HloValue that was in the alternate
2282     // memory space, we also need to perform an eviction.
2283     Result eviction_result = Evict(request);
2284     if (eviction_result != Result::kSuccess) {
2285       // A non-success eviction requires us to uncommit previous allocations.
2286       return result_mark(Result::kFailRequiresUncommit, eviction_result);
2287     }
2288     prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
2289   } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) {
2290     allocation_sequence->push_back(
2291         absl::make_unique<MemorySpaceAssignment::Allocation>(
2292             defining_position, MemorySpace::kDefault, /*chunk=*/absl::nullopt,
2293             request.start_time, request.end_time,
2294             /*is_scoped_allocation=*/false));
2295     prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
2296   }
2297 
2298   CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
2299   CHECK((*prev_allocation_in_default_mem_it)->memory_space() ==
2300         MemorySpace::kDefault);
2301 
2302   // If the buffer must be in default memory at the end_time, don't prefetch.
2303   if (required_memory_space_at_end == MemorySpace::kDefault) {
2304     VLOG(3)
2305         << "Not trying to prefetch because use requires buffer in default mem.";
2306     (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
2307     (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
2308     return Result::kSuccess;
2309   }
2310 
2311   // Finally, try to prefetch the buffer into alternate memory.
2312   if (!request.allocation_value->requires_contiguous_allocation()) {
2313     Result prefetch_result =
2314         Prefetch(request, **prev_allocation_in_default_mem_it);
2315     if (prefetch_result == Result::kSuccess) {
2316       return Result::kSuccess;
2317     }
2318     result_mark(prefetch_result, allocation_result);
2319   }
2320 
2321   // If the end assignment was required to be in alternate memory but that
2322   // wasn't possible, then this allocation is invalid.
2323   if (required_memory_space_at_end == MemorySpace::kAlternate) {
2324     return result_mark(Result::kFailRequiresUncommit, allocation_result);
2325   }
2326 
2327   // If the start assignment was required to be in alternate memory and the
2328   // buffer needs a contiguous assignment, we couldn't satisfy this requirement
2329   // and must abort.
2330   if (required_memory_space_at_start == MemorySpace::kAlternate &&
2331       request.allocation_value->requires_contiguous_allocation()) {
2332     return result_mark(Result::kFailRequiresUncommit, allocation_result);
2333   }
2334 
2335   // If a copy wasn't inserted, then add this use to the latest allocation in
2336   // default memory.
2337   (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
2338   (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
2339   return allocation_result;
2340 }
2341 
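// Creates a CopyAllocation that asynchronously copies prev_allocation into
// memory_space. The copy starts at start_time and must complete before
// copy_done_schedule_before_time; the resulting buffer is kept live until
// end_time. The copy is also registered in the prefetch or eviction interval
// tree so that the outstanding-copy limits can be enforced.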
2342 void AlternateMemoryBestFitHeap::AddAsyncCopy(
2343     const MemorySpaceAssignment::Allocation& prev_allocation,
2344     MemorySpace memory_space, absl::optional<Chunk> chunk, int64_t start_time,
2345     int64_t end_time, int64_t copy_done_schedule_before_time,
2346     MemorySpaceAssignment::AllocationSequence* allocations,
2347     AliasedOffset* aliased_offset, bool is_cross_program_prefetch) {
2348   VLOG(3) << "Copy to "
2349           << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
2350                   ? "default"
2351                   : "alternate")
2352           << " memory between " << start_time << " and "
2353           << copy_done_schedule_before_time << " keeping until " << end_time;
2354   CHECK_LT(start_time, copy_done_schedule_before_time);
2355 
2356   allocations->push_back(
2357       absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
2358           prev_allocation, memory_space, chunk, start_time, end_time,
2359           copy_done_schedule_before_time, is_cross_program_prefetch));
2360 
2361   // Register the additional async copy with the interval tree to keep track of
2362   // the limit at any given time.
2363   pending_async_copies_.push_back(
2364       {start_time, copy_done_schedule_before_time, memory_space});
2365   if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) {
2366     prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2367                                 kDummyChunk);
2368     async_copy_ordering_.AddCopy(pending_async_copies_.back());
2369     CreateOrAddToAliasedOffset(*allocations->back(), aliased_offset);
2370   } else {
2371     eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2372                                 kDummyChunk);
2373   }
2374 }
2375 
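// Checks whether adding another prefetch (or eviction) overlapping
// [start_time, end_time] would exceed the configured limit. For example, with
// a hypothetical limit of max_outstanding_prefetches = 2 and two prefetches
// already overlapping the interval, a third prefetch is rejected unless
// extra_async_copy_limit raises the bound. A negative limit disables the
// check.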
2376 bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
2377     int64_t start_time, int64_t end_time, bool is_prefetch,
2378     int64_t extra_async_copy_limit) const {
2379   if (options_.max_outstanding_prefetches < 0 && is_prefetch) {
2380     return false;
2381   }
2382   if (options_.max_outstanding_evictions < 0 && !is_prefetch) {
2383     return false;
2384   }
2385 
2386   // Count the prefetches/evictions in the interval tree for the given interval.
2387   if (is_prefetch) {
2388     int64_t num_prefetches =
2389         prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2390             .size();
2391     return num_prefetches >=
2392            options_.max_outstanding_prefetches + extra_async_copy_limit;
2393   } else {
2394     int64_t num_evictions =
2395         eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2396             .size();
2397     return num_evictions >=
2398            options_.max_outstanding_evictions + extra_async_copy_limit;
2399   }
2400 }
2401 
2402 absl::optional<AsynchronousCopy>
2403 AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(int64_t start_time,
2404                                                       int64_t end_time) const {
2405   return async_copy_ordering_.ViolatesOrdering(start_time, end_time);
2406 }
2407 
2408 AlternateMemoryBestFitHeap::Result
2409 AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
2410     const AllocationRequest& request) {
2411   MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
2412   bool can_eliminate_copy = false;
2413   if (request.allocation_value->allocation_sequence()->empty()) {
2414     // There haven't been any allocations for this interval so far. We can
2415     // eliminate the copy if the value can be placed in the alternate memory.
2416     can_eliminate_copy = options_.is_allowed_in_alternate_mem_fn(
2417         *request.allocation_value->value());
2418   } else {
2419     // If there has been a previous allocation, we can eliminate the copy if the
2420     // previous allocation was also in the alternate memory.
2421     prev_allocation =
2422         request.allocation_value->allocation_sequence()->back().get();
2423     can_eliminate_copy =
2424         (prev_allocation->memory_space() == MemorySpace::kAlternate);
2425   }
2426 
2427   if (!can_eliminate_copy) {
2428     return Result::kFailPrevAllocationNotInAlternateMem;
2429   }
2430 
2431   const HloPosition& defining_position =
2432       request.allocation_value->defining_position();
2433   if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
2434           defining_position.shape(), request.start_time + 1,
2435           request.end_time)) {
2436     return Result::kFailLiveRangeTooLong;
2437   }
2438 
2439   BufferInterval alternate_mem_interval;
2440   alternate_mem_interval.buffer = request.allocation_value->value();
2441   alternate_mem_interval.size = request.size;
2442   alternate_mem_interval.end = request.end_time;
2443   alternate_mem_interval.start = request.start_time;
2444 
2445   // Prefer the offset that was used for the previous allocation.
2446   AliasedOffset* preferred_offset = nullptr;
2447   if (prev_allocation != nullptr) {
2448     preferred_offset = GetAliasedOffset(*prev_allocation);
2449     // If there is a previous allocation, set the start time to one past the
2450     // previous allocation's end time.
2451     alternate_mem_interval.start = prev_allocation->end_time() + 1;
2452   }
2453 
2454   if (request.preferred_offset) {
2455     // Sanity check: if a preferred offset is provided in the request, it must
2456     // match the previous allocation's offset.
2457     CHECK(!preferred_offset || request.preferred_offset == preferred_offset)
2458         << "preferred_offset = " << preferred_offset->offset
2459         << ", request.preferred_offset = " << request.preferred_offset->offset;
2460     preferred_offset = request.preferred_offset;
2461   }
2462 
2463   VLOG(3) << "We can eliminate copy to alternate memory. Preferred offset = "
2464           << (preferred_offset ? preferred_offset->offset : -1);
2465   // In case there are additional uses after this use, we rely on the last use
2466   // time to try to reserve a chunk in the heap simulator. This is to prevent
2467   // the following scenario:
2468   //
2469   //                            +-------+
2470   //                           /         \
2471   //                   Producer--->Use1   +-->Use2
2472   //                       +---------+---------+
2473   // New buffer:           |         |         |
2474   //                       +---------+---------+
2475   //
2476   //                                     +-----------+
2477   // Current heap:                       | offset: 0 |
2478   //           --------------------------+-----------+------
2479   //
2480   // Because we allocate buffers greedily (the Producer-to-Use1 segment first,
2481   // then the Use1-to-Use2 segment), it is possible to allocate the first
2482   // segment at an offset that is available for that segment (e.g. offset 0)
2483   // but not for the entire live range. This can result in unnecessary copies.
2484   // By using the last use time, we try to find an allocation that is available
2485   // for the entire Producer-to-Use2 range.
2486   absl::optional<ChunkCandidate> chunk_candidate = FindBestChunkCandidate(
2487       request, preferred_offset, &alternate_mem_interval);
2488   // Check if the new heap size fits within limits. Also ensure that, if a
2489   // preferred offset was provided, that offset was used.
2490   if (chunk_candidate) {
2491     VLOG(3) << "Keep the buffer in alternate memory. Offset = "
2492             << chunk_candidate->chunk.offset
2493             << ", size = " << chunk_candidate->chunk.size
2494             << ", heap_size = " << chunk_candidate->heap_size
2495             << ", prefetch picker = "
2496             << options_.prefetch_interval_picker->ToNoCopyDebugString(
2497                    defining_position.shape(), request.start_time,
2498                    request.end_time);
2499     AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
2500 
2501     // If there was a previous allocation, the buffer location is the
2502     // same as the previous. Otherwise, it is the operand.
2503     if (prev_allocation != nullptr &&
2504         (prev_allocation->is_copy_allocation() ||
2505          prev_allocation->defining_position() == defining_position)) {
2506       prev_allocation->Extend(request.end_time);
2507     } else {
2508       request.allocation_value->allocation_sequence()->push_back(
2509           absl::make_unique<MemorySpaceAssignment::Allocation>(
2510               defining_position, MemorySpace::kAlternate,
2511               chunk_candidate->chunk, request.start_time, request.end_time,
2512               /*is_scoped_allocation=*/false));
2513       CreateOrAddToAliasedOffset(
2514           *request.allocation_value->allocation_sequence()->back(),
2515           preferred_offset);
2516     }
2517     request.allocation_value->allocation_sequence()->back()->AddUse(
2518         request.use->hlo_use);
2519     return Result::kSuccess;
2520   }
2521   return Result::kFailOutOfMemory;
2522 }
2523 
2524 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict(
2525     const AllocationRequest& request) {
2526   CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0);
2527   MemorySpaceAssignment::Allocation* prev_allocation =
2528       request.allocation_value->allocation_sequence()->back().get();
2529   int64_t eviction_start_time = prev_allocation->start_time();
2530   int64_t eviction_end_time = prev_allocation->end_time();
2531   CHECK(eviction_start_time <= eviction_end_time);
2532 
2533   int64_t preferred_eviction_end_time =
2534       std::max(options_.prefetch_interval_picker->PreferredEvictionEndTime(
2535                    request.allocation_value->defining_position().shape(),
2536                    eviction_start_time, request.end_time),
2537                eviction_end_time);
2538   // Evictions must complete by the time of this use.
2539   preferred_eviction_end_time =
2540       std::min(preferred_eviction_end_time, request.latest_prefetch_time);
2541 
2542   BufferInterval eviction_mem_interval;
2543   eviction_mem_interval.buffer = request.allocation_value->value();
2544   eviction_mem_interval.size = request.size;
2545   // Try to reserve a buffer from the end of the previous allocation to the
2546   // preferred eviction end time.
2547   eviction_mem_interval.start = eviction_end_time + 1;
2548   eviction_mem_interval.end = preferred_eviction_end_time;
2549   int64_t preferred_offset = prev_allocation->chunk().offset;
2550   VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
2551           << ") preferred end time = " << eviction_mem_interval.end;
2552 
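  // Walk the preferred eviction end time backwards until the heap can keep the
  // buffer at its current offset for the entire eviction interval.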
2553   for (; eviction_mem_interval.end > eviction_end_time;
2554        --eviction_mem_interval.end) {
2555     ChunkCandidate chunk_candidate =
2556         FindChunkCandidate(eviction_mem_interval, preferred_offset);
2557     if (chunk_candidate.chunk.offset == preferred_offset) {
2558       AddToPendingChunks(eviction_mem_interval, chunk_candidate);
2559       break;
2560     }
2561   }
2562   eviction_end_time = eviction_mem_interval.end;
2563 
2564   VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " ("
2565           << eviction_start_time << ", " << eviction_end_time << ")";
2566 
2567   bool eviction_interval_too_short = (eviction_start_time == eviction_end_time);
2568   bool eviction_violates_outstanding_copies =
2569       ViolatesMaximumOutstandingAsyncCopies(eviction_start_time,
2570                                             eviction_end_time,
2571                                             /*is_prefetch=*/false);
2572 
2573   // See if this interval would violate the asynchronous copy limit.
2574   if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) {
2575     prev_allocation->Extend(eviction_end_time);
2576     AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
2577                  /*chunk=*/absl::nullopt, eviction_start_time,
2578                  prev_allocation->end_time(), eviction_end_time,
2579                  request.allocation_value->allocation_sequence(),
2580                  /*aliased_offset=*/nullptr);
2581   } else {
2582     if (eviction_violates_outstanding_copies) {
2583       VLOG(3) << "This violates the maximum async copies.";
2584     } else {
2585       VLOG(3) << "Eviction interval is too short (" << eviction_start_time
2586               << ", " << eviction_end_time << ").";
2587     }
2588     // If the original interval violated the limit, try sub-intervals within
2589     // this interval.
2590     bool eviction_scheduled = false;
2591     for (int64_t time = eviction_start_time; time < eviction_end_time; ++time) {
2592       VLOG(4) << "Try evicting (" << time << ", " << time + 1 << ")";
2593       if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1,
2594                                                  /*is_prefetch=*/false)) {
2595         VLOG(3) << "Eviction successful.";
2596         AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
2597                      /*chunk=*/absl::nullopt, time, time + 1, time + 1,
2598                      request.allocation_value->allocation_sequence(),
2599                      /*aliased_offset=*/nullptr);
2600         eviction_scheduled = true;
2601         break;
2602       }
2603     }
2604 
2605     if (!eviction_scheduled) {
2606       // If the eviction couldn't be scheduled, then fail. This buffer will be
2607       // kept in the default memory.
2608       VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
2609               << " because we hit the limit of maximum asynchronous copies "
2610               << "between "
2611               << hlo_live_range_.flattened_instruction_sequence()
2612                      .instructions()[eviction_start_time]
2613               << " and "
2614               << hlo_live_range_.flattened_instruction_sequence()
2615                      .instructions()[eviction_end_time];
2617       return Result::kFailOutOfAsyncCopies;
2618     }
2619   }
2621   return Result::kSuccess;
2622 }
2623 
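// Determines the time by which the prefetch for this use must complete.
// Starting from the request's latest prefetch time, the end time is moved
// earlier (up to prefetch_copy_done_reorder_max_retries times) while the
// resulting copy would violate the asynchronous copy ordering.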
2624 int64 AlternateMemoryBestFitHeap::FindPrefetchEndTime(
2625     const AllocationRequest& request, int64_t earliest_prefetch_time) const {
2626   int64_t prefetch_end_time = request.latest_prefetch_time;
2627 
2628   const HloUse& use = request.use->hlo_use;
2629   const Shape& shape = ShapeUtil::GetSubshape(
2630       use.instruction->operand(use.operand_number)->shape(), use.operand_index);
2631   for (int retry_number = 0;
2632        retry_number < options_.prefetch_copy_done_reorder_max_retries;
2633        ++retry_number) {
2634     int64_t latest_prefetch_time =
2635         options_.prefetch_interval_picker->LatestPrefetchStartTime(
2636             shape, earliest_prefetch_time, prefetch_end_time, &use);
2637     VLOG(4) << "Latest prefetch start time = " << latest_prefetch_time
2638             << ", earliest prefetch start time = " << earliest_prefetch_time
2639             << ", prefetch end time = " << prefetch_end_time;
2640     // Return if we couldn't find a suitable prefetch start time.
2641     if (latest_prefetch_time < earliest_prefetch_time) {
2642       break;
2643     }
2644 
2645     // Return either if there is no other violating asynchronous copy (since we
2646     // don't need to change the prefetch end time) or if the violating
2647     // asynchronous copy ends after the prefetch end time.
2648     auto violating_async_copy =
2649         ViolatesAsyncCopyOrdering(latest_prefetch_time, prefetch_end_time);
2650     if (!violating_async_copy ||
2651         violating_async_copy->end_time >= prefetch_end_time) {
2652       break;
2653     }
2654     VLOG(4) << "Violating async copy: (" << violating_async_copy->start_time
2655             << ", " << violating_async_copy->end_time << ")";
2656 
2657     int64_t new_prefetch_end_time =
2658         options_.prefetch_interval_picker->LatestPrefetchEndTime(
2659             prefetch_end_time, violating_async_copy->end_time);
2660     if (new_prefetch_end_time > earliest_prefetch_time) {
2661       VLOG(3) << "Update prefetch end time = " << new_prefetch_end_time;
2662       prefetch_end_time = new_prefetch_end_time;
2663     } else {
2664       VLOG(3) << "Can't update prefetch end time = " << new_prefetch_end_time
2665               << " because earliest prefetch start time = "
2666               << earliest_prefetch_time;
2667       break;
2668     }
2669   }
2670 
2671   return prefetch_end_time;
2672 }
2673 
2674 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch(
2675     const AllocationRequest& request,
2676     const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) {
2677   // Try partially placing the buffer in the alternate space. The time that is
2678   // overlapped will be used to asynchronously copy the buffer from the
2679   // default memory to the alternate memory.
2680   //
2681   //                      start                 end
2682   //                      time                  time
2683   //                      X---------------------X
2684   // Alternate:                          +------+
2685   // Default:             +---------------------+
2686   //                                     ^      ^
2687   //                                   Copy    Copy
2688   //                                   Start   Done
2689   int64_t earliest_prefetch_time =
2690       prev_allocation_in_default_mem.earliest_available_time();
2691   if (request.earliest_prefetch_time) {
2692     earliest_prefetch_time =
2693         std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
2694   }
2695   int64_t prefetch_end_time =
2696       FindPrefetchEndTime(request, earliest_prefetch_time);
2697 
2698   options_.prefetch_interval_picker->Begin(
2699       request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
2700   VLOG(3) << "Trying prefetch picker = "
2701           << options_.prefetch_interval_picker->ToDebugString();
2702 
2703   // Create an alternate memory interval whose start time will be set below
2704   // from the candidate prefetch start times proposed by the interval picker.
2705   BufferInterval alternate_mem_interval;
2706   alternate_mem_interval.buffer = request.allocation_value->value();
2707   alternate_mem_interval.size = request.size;
2708   // Uses by while instructions may be allowed additional outstanding prefetches.
2709   int64_t extra_async_copy_limit =
2710       request.use->hlo_use.instruction->opcode() == HloOpcode::kWhile
2711           ? options_.while_use_extra_outstanding_prefetch_limit
2712           : 0;
2713   Result result = Result::kSuccess;
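  // Try each candidate prefetch start time proposed by the interval picker
  // until one satisfies the asynchronous copy ordering, the outstanding copy
  // limit, and the alternate memory capacity.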
2714   while (!options_.prefetch_interval_picker->Done()) {
2715     alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
2716     CHECK_LT(alternate_mem_interval.start, prefetch_end_time);
2717     VLOG(4) << "Trying alternate memory allocation ("
2718             << alternate_mem_interval.start << ", " << request.end_time << ")";
2719     // If this additional asynchronous copy would violate the copy ordering or
2720     // the outstanding copy limit, try a different interval.
2721     if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
2722                                   prefetch_end_time)) {
2723       VLOG(4) << "This would violate asynchronous copy ordering.";
2724       result_mark(Result::kFailViolatesAsyncCopyOrdering, result);
2725       continue;
2726     }
2727     if (ViolatesMaximumOutstandingAsyncCopies(
2728             alternate_mem_interval.start, prefetch_end_time,
2729             /*is_prefetch=*/true, extra_async_copy_limit)) {
2730       VLOG(4) << "This would violate the outstanding async copy limit.";
2731       result_mark(Result::kFailOutOfAsyncCopies, result);
2732       continue;
2733     }
2734 
2735     auto chunk_candidate = FindBestChunkCandidate(
2736         request, request.preferred_offset, &alternate_mem_interval);
2737     // Check if we could find a suitable chunk.
2738     if (chunk_candidate) {
2739       VLOG(3) << "Move the buffer to alternate memory at "
2740               << alternate_mem_interval.start
2741               << ". Offset = " << chunk_candidate->chunk.offset
2742               << ", size = " << chunk_candidate->chunk.size
2743               << ", heap_size = " << chunk_candidate->heap_size
2744               << ", prefetch picker = "
2745               << options_.prefetch_interval_picker->ToDebugString();
2746       AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
2747 
2748       AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate,
2749                    chunk_candidate->chunk, alternate_mem_interval.start,
2750                    request.end_time, prefetch_end_time,
2751                    request.allocation_value->allocation_sequence(),
2752                    request.preferred_offset);
2753 
2754       request.allocation_value->allocation_sequence()->back()->AddUse(
2755           request.use->hlo_use);
2756       return Result::kSuccess;
2757     }
2758     result_mark(Result::kFailOutOfMemory, result);
2759   }
2760   // If we didn't consider any prefetch intervals, then the live range was too
2761   // short.
2762   if (result == Result::kSuccess) {
2763     return Result::kFailLiveRangeTooShort;
2764   } else {
2765     return result;
2766   }
2767 }
2768 
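// Finds a chunk in the alternate memory for the requested interval. If no
// preferred offset is given, the interval end is tentatively extended towards
// later use times so that the chunk can stay live across as many uses as
// possible; if a preferred offset is given, only that offset is considered.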
2769 absl::optional<AlternateMemoryBestFitHeap::ChunkCandidate>
2770 AlternateMemoryBestFitHeap::FindBestChunkCandidate(
2771     const AllocationRequest& request, const AliasedOffset* preferred_offset,
2772     BufferInterval* alternate_mem_interval) const {
2773   int64_t end_time = request.end_time;
2774   if (!preferred_offset) {
2775     // First find the earliest use that is the same or later than the end time.
2776     const auto& use_times = request.all_use_times;
2777     auto use_time_it = use_times.begin();
2778     for (; *use_time_it < end_time; ++use_time_it) {
2779     }
2780     CHECK(use_time_it != use_times.end());
2781     int64_t earliest_use = *use_time_it;
2782 
2783     // Then find the latest use that can be allocated contiguously without
2784     // copies.
2785     const Shape& shape = request.allocation_value->defining_position().shape();
2786     for (;
2787          (use_time_it + 1) != use_times.end() &&
2788          options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
2789              shape, *use_time_it, *(use_time_it + 1));
2790          ++use_time_it) {
2791     }
2792     CHECK(use_time_it != use_times.end());
2793     int64_t latest_contiguous_use_time = *use_time_it;
2794 
2795     // Find a chunk that is as long-lived as possible, iterating in reverse
2796     // over the use times.
2797     for (; use_time_it >= use_times.begin() && *use_time_it >= end_time;
2798          --use_time_it) {
2799       alternate_mem_interval->end = *use_time_it;
2800       ChunkCandidate chunk_candidate =
2801           FindChunkCandidate(*alternate_mem_interval);
2802       if (chunk_candidate.heap_size <= available_heap_size()) {
2803         alternate_mem_interval->end = end_time;
2804         VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
2805                 << ", latest contiguous use = " << latest_contiguous_use_time
2806                 << ", use with available mem = " << *use_time_it
2807                 << ", offset = " << chunk_candidate.chunk.offset;
2808         return chunk_candidate;
2809       }
2810     }
2811     alternate_mem_interval->end = end_time;
2812     return absl::nullopt;
2813   }
2814   // If a preferred offset is given, try to find an allocation at that offset
2815   // only.
2816   alternate_mem_interval->end = end_time;
2817   ChunkCandidate chunk_candidate =
2818       FindChunkCandidate(*alternate_mem_interval, preferred_offset->offset);
2819   if (chunk_candidate.chunk.offset == preferred_offset->offset) {
2820     return chunk_candidate;
2821   }
2822   return absl::nullopt;
2823 }
2824 
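// Walks the scheduled module and counts CopyStart/CopyDone pairs. Each
// completed copy is classified as a prefetch (its destination is in the
// alternate memory space) or an eviction, and the maximum number of copies in
// flight at any point is also tracked.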
2825 StatusOr<MemorySpaceAssignment::AsyncCopyStats>
2826 MemorySpaceAssignment::CalculateAsyncCopyStats() const {
2827   AsyncCopyStats stats;
2828   stats.max_outstanding_async_copies = 0;
2829   stats.num_prefetches = 0;
2830   stats.prefetch_bytes = 0;
2831   stats.num_evictions = 0;
2832   stats.eviction_bytes = 0;
2833   int64_t current_copies = 0;
2834   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
2835                       HloDataflowAnalysis::Run(*module_));
2836   for (const HloComputation* computation :
2837        module_->MakeNonfusionComputations()) {
2838     for (HloInstruction* instruction : computation->instructions()) {
2839       if (instruction->opcode() == HloOpcode::kCopyStart) {
2840         current_copies++;
2841       } else if (instruction->opcode() == HloOpcode::kCopyDone) {
2842         current_copies--;
2843         int64_t size =
2844             options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction));
2845         if (instruction->shape().layout().memory_space() ==
2846             options_.alternate_memory_space) {
2847           ++stats.num_prefetches;
2848           stats.prefetch_bytes += size;
2849         } else {
2850           ++stats.num_evictions;
2851           stats.eviction_bytes += size;
2852         }
2853       }
2854       stats.max_outstanding_async_copies =
2855           std::max(stats.max_outstanding_async_copies, current_copies);
2856     }
2857   }
2858   return stats;
2859 }
2860 
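// Returns a comparator that orders buffer intervals by decreasing memory
// boundedness, falling back to the spatial (size-based) comparator to break
// ties.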
2861 /*static*/ MemorySpaceAssignment::BufferIntervalCompare
2862 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
2863     const MemorySpaceAssignmentCostAnalysis& cost_analysis,
2864     MemorySpaceAssignmentCostAnalysis::Cache* cache) {
2865   return [&cost_analysis, cache](const BufferInterval& x,
2866                                  const BufferInterval& y) {
2867     float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
2868     float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
2869     if (x_memory_boundedness != y_memory_boundedness) {
2870       return x_memory_boundedness > y_memory_boundedness;
2871     }
2872     // Tie-break if the memory boundedness is the same.
2873     return GlobalDecreasingSizeBestFitHeap<
2874         HloValue>::GetSpatialBufferIntervalCompare()(x, y);
2875   };
2876 }
2877 
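// Runs memory space assignment on the given scheduled module. A minimal usage
// sketch (illustrative only; the Options fields shown are a subset and their
// values are placeholders rather than recommendations):
//
//   Options options;
//   options.alternate_memory_space = 1;
//   options.max_size_in_bytes = 128 * 1024 * 1024;
//   options.size_fn = [](const HloValue& value) {
//     return ShapeUtil::ByteSizeOf(value.shape(), /*pointer_size=*/8);
//   };
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PresetAssignments> preset_assignments,
//       MemorySpaceAssignment::Run(module, hlo_live_range, alias_analysis,
//                                  options));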
2878 /*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
2879 MemorySpaceAssignment::Run(HloModule* module,
2880                            const HloLiveRange& hlo_live_range,
2881                            const HloAliasAnalysis& alias_analysis,
2882                            const Options& options) {
2883   CHECK(module->has_schedule());
2884   VLOG(3) << "Module before memory space assignment: ";
2885   XLA_VLOG_LINES(3, module->ToString());
2886   VLOG(3) << "Schedule: " << module->schedule().ToString();
2887   MemorySpaceAssignment memory_space_assignment(module, options,
2888                                                 hlo_live_range);
2889 
2890   return memory_space_assignment.RunMemorySpaceAssignment(hlo_live_range,
2891                                                           alias_analysis);
2892 }
2893 
2894 StatusOr<std::unique_ptr<PresetAssignments>>
2895 MemorySpaceAssignment::RunMemorySpaceAssignment(
2896     const HloLiveRange& hlo_live_range,
2897     const HloAliasAnalysis& alias_analysis) {
2898   TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis));
2899 
2900   if (options_.cost_analysis) {
2901     float estimated_time =
2902         ComputeEstimatedElapsedTime(hlo_live_range, allocations_);
2903     VLOG(1) << "Estimated elapsed time (sec): " << estimated_time;
2904   }
2905 
2906   TF_RETURN_IF_ERROR(Process());
2907   ScheduleAsynchronousCopies();
2908   TF_RETURN_IF_ERROR(SimplifyGraph());
2909   TF_RETURN_IF_ERROR(FixSchedule());
2910   TF_RETURN_IF_ERROR(ExportAndColorBuffers());
2911 
2912   VLOG(3) << "Module after memory space assignment: ";
2913   XLA_VLOG_LINES(3, module_->ToString());
2914   TF_CHECK_OK(module_->schedule().Verify());
2915   TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats());
2916   VLOG(1) << "Maximum number of outstanding async copies: "
2917           << stats.max_outstanding_async_copies;
2918   VLOG(1) << "Number of prefetches: " << stats.num_prefetches
2919           << ", in bytes: " << stats.prefetch_bytes;
2920   VLOG(1) << "Number of evictions: " << stats.num_evictions
2921           << ", in bytes: " << stats.eviction_bytes;
2922 
2923   TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace());
2924 
2925   return std::move(preset_assignments_);
2926 }
2927 
2928 Status MemorySpaceAssignment::FindAllocationSequence(
2929     const HloLiveRange& hlo_live_range,
2930     const HloAliasAnalysis& alias_analysis) {
2931   auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
2932       &allocations_, options_, alias_analysis, hlo_live_range);
2933 
2934   HeapSimulator::Options heap_simulator_options;
2935   heap_simulator_options.may_reuse_operand_buffers = false;
2936   heap_simulator_options.alloc_constants = true;
2937   TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_,
2938                                         module_->schedule(), alias_analysis,
2939                                         options_.size_fn,
2940                                         heap_simulator_options)
2941                          .status());
2942   return Status::OK();
2943 }
2944 
2945 void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
2946   HloInstruction* operand =
2947       use.instruction->mutable_operand(use.operand_number);
2948   // If the use is a tuple, look inside the tuple to find the actual use.
2949   for (int64_t index : use.operand_index) {
2950     if (operand->opcode() != HloOpcode::kTuple) {
2951       break;
2952     }
2953     operand = operand->mutable_operand(index);
2954   }
2955 
2956   // Look through the GetTupleElement(Tuple()) pattern to simplify the operand.
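  // For example, get-tuple-element(tuple(a, b), 1) simplifies to b.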
2957   std::function<HloInstruction*(HloInstruction*)> get_simplified_operand;
2958   get_simplified_operand = [&](HloInstruction* instruction) {
2959     while (instruction->opcode() == HloOpcode::kGetTupleElement) {
2960       HloInstruction* operand =
2961           get_simplified_operand(instruction->mutable_operand(0));
2962       if (operand->opcode() == HloOpcode::kTuple) {
2963         instruction = operand->mutable_operand(instruction->tuple_index());
2964       } else {
2965         return instruction;
2966       }
2967     }
2968     return instruction;
2969   };
2970   operand = get_simplified_operand(operand);
2971 
2972   uses_.push_back(use);
2973 }
2974 
2975 float MemorySpaceAssignment::ComputeEstimatedElapsedTime(
2976     const HloLiveRange& hlo_live_range, const AllocationSequence& allocations) {
2977   absl::flat_hash_map<const HloInstruction*, std::vector<ShapeIndex>>
2978       outputs_in_alternate_memory_map;
2979   absl::flat_hash_map<const HloInstruction*,
2980                       std::vector<std::pair<int64, ShapeIndex>>>
2981       operands_in_alternate_memory_map;
2982 
2983   for (auto& allocation : allocations) {
2984     if (!allocation->is_copy_allocation()) {
2985       if (allocation->memory_space() == MemorySpace::kAlternate) {
2986         const HloInstruction* defining_instruction =
2987             allocation->defining_position().instruction;
2988         outputs_in_alternate_memory_map[defining_instruction].push_back(
2989             allocation->defining_position().index);
2990       }
2991     }
2992     for (auto& hlo_use : allocation->uses()) {
2993       const HloInstruction* use_instruction = hlo_use.instruction;
2994       operands_in_alternate_memory_map[use_instruction].push_back(
2995           std::make_pair(hlo_use.operand_number, hlo_use.operand_index));
2996     }
2997   }
2998 
2999   const auto& instruction_sequence =
3000       hlo_live_range.flattened_instruction_sequence().instructions();
3001   float total_elapsed = 0.0;
3002   for (const HloInstruction* instruction : instruction_sequence) {
3003     std::vector<ShapeIndex> outputs_in_alternate_memory;
3004     auto output_it = outputs_in_alternate_memory_map.find(instruction);
3005     if (output_it != outputs_in_alternate_memory_map.end()) {
3006       outputs_in_alternate_memory = output_it->second;
3007     }
3008     std::vector<std::pair<int64_t, ShapeIndex>> operands_in_alternate_memory;
3009     auto operand_it = operands_in_alternate_memory_map.find(instruction);
3010     if (operand_it != operands_in_alternate_memory_map.end()) {
3011       operands_in_alternate_memory = operand_it->second;
3012     }
3013     float instruction_elapsed =
3014         options_.cost_analysis->GetInstructionElapsedInAlternateMemory(
3015             *instruction, operands_in_alternate_memory,
3016             outputs_in_alternate_memory);
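    // Scale the elapsed time by the configured while-loop execution count for
    // each level of while-loop nesting that contains this instruction.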
3017     float while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
3018         options_.xla_tpu_memory_space_assignment_while_execution_count,
3019         options_.cost_analysis->CalculateComputationNestLevel(
3020             instruction,
3021             /*while_only=*/true));
3022     total_elapsed += while_nest_multiplier * instruction_elapsed;
3023   }
3024   return total_elapsed;
3025 }
3026 
3027 Status MemorySpaceAssignment::Allocation::Process() {
3028   if (is_scoped_allocation()) {
3029     // Nothing to do here for scoped allocations.
3030     return Status::OK();
3031   }
3032   HloInstruction* producing_instruction = AddGetTupleElements();
3033   HloComputation* computation = producing_instruction->parent();
3034   for (const HloUse& use : uses_) {
3035     Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
3036     HloInstruction* replacement_instruction = producing_instruction;
3037     if (operand_shape.IsTuple()) {
3038       TF_ASSIGN_OR_RETURN(
3039           replacement_instruction,
3040           ReplaceTupleWith(producing_instruction,
3041                            use.instruction->mutable_operand(use.operand_number),
3042                            use.operand_index));
3043     } else if (operand_shape != producing_instruction->shape()) {
3044       VLOG(4) << "Old shape = " << operand_shape.ToString()
3045               << ", new shape = " << producing_instruction->shape().ToString()
3046               << "; inserting a bitcast.";
3047       replacement_instruction = computation->AddInstruction(
3048           HloInstruction::CreateBitcast(operand_shape, producing_instruction));
3049     }
3050     TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
3051         use.operand_number, replacement_instruction));
3052   }
3053   return Status::OK();
3054 }
3055 
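// Returns a tuple in which the element of `tuple` at `shape_index` is replaced
// with `new_instruction`, recreating the enclosing tuples as needed. For
// example, replacing index {1} of tuple(a, b) with c produces tuple(a, c);
// when `tuple` is not itself a kTuple instruction, get-tuple-elements are
// inserted to extract the untouched elements.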
3056 StatusOr<HloInstruction*> MemorySpaceAssignment::Allocation::ReplaceTupleWith(
3057     HloInstruction* new_instruction, HloInstruction* tuple,
3058     ShapeIndex shape_index) {
3059   const Shape& tuple_shape = tuple->shape();
3060   CHECK(tuple->shape().IsTuple())
3061       << "ReplaceTupleWith was called for a non-tuple. Tuple = "
3062       << tuple->ToString()
3063       << ", new_instruction = " << new_instruction->ToString()
3064       << ", shape_index = " << shape_index.ToString();
3065 
3066   HloComputation* computation = new_instruction->parent();
3067   std::vector<HloInstruction*> tuple_args(tuple_shape.tuple_shapes_size());
3068   CHECK_GE(tuple_shape.tuple_shapes_size(), shape_index[0]);
3069   for (int64_t i = 0; i < tuple_shape.tuple_shapes_size(); ++i) {
3070     const Shape& subshape = tuple_shape.tuple_shapes(i);
3071     // If `tuple` is a tuple instruction, we can use its operand directly to
3072     // construct the new tuple instead of creating a get-tuple-element, which
3073     // improves compilation-time performance.
3074     auto get_operand = [&]() {
3075       if (tuple->opcode() == HloOpcode::kTuple) {
3076         return tuple->mutable_operand(i);
3077       } else {
3078         return computation->AddInstruction(
3079             HloInstruction::CreateGetTupleElement(subshape, tuple, i));
3080       }
3081     };
3082     if (i == shape_index[0]) {
3083       // If the subshape is still a tuple, recurse and pass a new shape index
3084       // that is one level deeper.
3085       if (subshape.IsTuple()) {
3086         TF_ASSIGN_OR_RETURN(tuple_args[i],
3087                             ReplaceTupleWith(new_instruction, get_operand(),
3088                                              ShapeIndex(shape_index.begin() + 1,
3089                                                         shape_index.end())));
3090       } else {
3091         if (subshape != new_instruction->shape()) {
3092           VLOG(4) << "Old shape = " << subshape.ToString()
3093                   << ", new shape = " << new_instruction->shape().ToString()
3094                   << "; inserting a bitcast.";
3095           new_instruction = computation->AddInstruction(
3096               HloInstruction::CreateBitcast(subshape, new_instruction));
3097         } else if (tuple->opcode() == HloOpcode::kTuple &&
3098                    tuple->operand(i) == new_instruction) {
3099           // If the tuple element is the same as the new instruction, we
3100           // actually don't have to create a new tuple, just return the original
3101           // tuple.
3102           VLOG(4) << "Tuple already contains the new instruction = "
3103                   << new_instruction->ToShortString()
3104                   << " tuple = " << tuple->ToShortString();
3105           return tuple;
3106         }
3107         tuple_args[i] = new_instruction;
3108       }
3109     } else {
3110       tuple_args[i] = get_operand();
3111     }
3112   }
3113   if (shape_index[0] == tuple_shape.tuple_shapes_size()) {
3114     // If shape_index[0] is equal to the tuple shape size, add the new
3115     // instruction as an additional argument.
3116     tuple_args.push_back(new_instruction);
3117   }
3118   return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args));
3119 }
3120 
3121 HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() const {
3122   HloInstruction* producing_instruction = defining_position().instruction;
3123   CHECK_NE(producing_instruction, nullptr);
3124 
3125   Shape shape = defining_position().shape();
3126   CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = "
3127                          << shape.ToString()
3128                          << " position = " << defining_position().shape();
3129   HloComputation* computation = producing_instruction->parent();
3130 
3131   // If the instruction we're processing is a tuple, we (recursively) search or
3132   // create kGetTupleElement instructions and copy that value. Asynchronous
3133   // copies only support array types.
3134   for (int64_t index : defining_position().index) {
3135     // We first search if there already is a get-tuple-element with the correct
3136     // index. If there is no such get-tuple-element, we create one.
3137     auto gte_it = absl::c_find_if(
3138         producing_instruction->users(), [index](const HloInstruction* use) {
3139           return use != use->parent()->root_instruction() &&
3140                  use->opcode() == HloOpcode::kGetTupleElement &&
3141                  use->tuple_index() == index;
3142         });
3143     if (gte_it != producing_instruction->users().end()) {
3144       producing_instruction = *gte_it;
3145     } else {
3146       producing_instruction =
3147           computation->AddInstruction(HloInstruction::CreateGetTupleElement(
3148               producing_instruction->shape().tuple_shapes(index),
3149               producing_instruction, index));
3150     }
3151   }
3152   return producing_instruction;
3153 }
3154 
3155 std::string MemorySpaceAssignment::Allocation::ToString() const {
3156   std::string memory_space_str = "def";
3157   if (memory_space_ == MemorySpace::kAlternate) {
3158     memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
3159   }
3160   return absl::StrCat((is_scoped_allocation() ? "Scoped " : ""),
3161                       "Allocation in ", memory_space_str, " defined at ",
3162                       defining_position_.ToString());
3163 }
3164 
3165 std::string MemorySpaceAssignment::CopyAllocation::ToString() const {
3166   std::string memory_space_str = "def";
3167   if (memory_space_ == MemorySpace::kAlternate) {
3168     memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
3169   }
3170   return absl::StrCat("Copy Allocation in ", memory_space_str, " from ",
3171                       prev_allocation_.ToString());
3172 }
3173 
3174 std::string MemorySpaceAssignment::ParentAllocation::ToString() const {
3175   return absl::StrCat("Parent Allocation mirrored at ",
3176                       defining_position_.ToString(), ", originally ",
3177                       original_allocation_.ToString());
3178 }
3179 
3180 Status MemorySpaceAssignment::CopyAllocation::Process() {
3181   // Copy allocations need to insert asynchronous copy nodes.
3182   Shape shape = defining_position().shape();
3183   HloInstruction* producing_instruction = AddGetTupleElements();
3184   HloComputation* computation = producing_instruction->parent();
3185   copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart(
3186       ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
3187       producing_instruction, is_cross_program_prefetch_));
3188   copy_done_ = computation->AddInstruction(
3189       HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
3190   VLOG(4) << "Created " << copy_start_->name()
3191           << " for position: " << defining_position().ToString();
3192   // Update the allocation position with the copy done instruction so that if
3193   // there are further copies from it, it can find the correct position.
3194   defining_position_ = HloPosition{copy_done_, {}};
3195 
3196   // Replace all the uses with the new copy instruction.
3197   for (HloUse use : uses_) {
3198     // If the operand is a tuple, we need to descend to the actual instruction
3199     // we want to replace.
3200     HloInstruction* replacement_instruction;
3201     Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
3202     if (operand_shape.IsTuple()) {
3203       TF_ASSIGN_OR_RETURN(
3204           replacement_instruction,
3205           ReplaceTupleWith(copy_done_,
3206                            use.instruction->mutable_operand(use.operand_number),
3207                            use.operand_index));
3208     } else if (operand_shape != copy_done_->shape()) {
3209       VLOG(4) << "Old shape = " << operand_shape.ToString()
3210               << ", new shape = " << copy_done_->shape().ToString()
3211               << "; inserting a bitcast.";
3212       replacement_instruction = computation->AddInstruction(
3213           HloInstruction::CreateBitcast(operand_shape, copy_done_));
3214     } else {
3215       replacement_instruction = copy_done_;
3216     }
3217     TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
3218         use.operand_number, replacement_instruction));
3219   }
3220 
3221   return Status::OK();
3222 }
3223 
3224 Status MemorySpaceAssignment::ParentAllocation::Process() {
3225   // Add an additional parameter to the while HLO with a reference to the buffer
3226   // in the default memory space.
3227   HloInstruction* producing_instruction =
3228       original_allocation_.AddGetTupleElements();
3229   int64_t new_tuple_index = calling_instruction_->shape().tuple_shapes_size();
3230 
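  // Thread the new operand through the while loop: the while instruction
  // itself and the parameter instructions of its condition and body
  // computations all need the extended tuple shape.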
3231   TF_ASSIGN_OR_RETURN(HloInstruction * new_while_operand,
3232                       ReplaceTupleWith(producing_instruction,
3233                                        calling_instruction_->mutable_operand(0),
3234                                        {new_tuple_index}));
3235   TF_RETURN_IF_ERROR(calling_instruction_->ReplaceOperandWithDifferentShape(
3236       0, new_while_operand));
3237   *calling_instruction_->mutable_shape() = new_while_operand->shape();
3238   *calling_instruction_->while_condition()
3239        ->parameter_instruction(0)
3240        ->mutable_shape() = new_while_operand->shape();
3241   *calling_instruction_->while_body()
3242        ->parameter_instruction(0)
3243        ->mutable_shape() = new_while_operand->shape();
3244   defining_position_.index = {new_tuple_index};
3245   return Allocation::Process();
3246 }
3247 
3248 Status MemorySpaceAssignment::ParentAllocation::PostProcess() {
3249   // Update the root of the while body with the new parameter. We need a
3250   // separate post-process step because other allocations may have the while
3251   // body root as a use, so they would update the old root instead of the new
3252   // root. Doing the post-process step later ensures the root has been updated
3253   // with those other changes, and we can safely add the additional parameter.
3254   HloComputation* while_body = calling_instruction_->while_body();
3255   TF_ASSIGN_OR_RETURN(
3256       HloInstruction * new_while_body_root,
3257       ReplaceTupleWith(AddGetTupleElements(), while_body->root_instruction(),
3258                        defining_position_.index));
3259   while_body->set_root_instruction(new_while_body_root,
3260                                    /*accept_different_shape=*/true);
3261   return Status::OK();
3262 }
3263 
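// MarkIfNeeded/MarkNeeded implement a simple reachability marking over the
// allocation sequence: regular allocations and copy allocations are always
// needed (and a copy allocation also keeps its source allocation alive),
// whereas a parent allocation is only kept if it has uses of its own or is
// reached through a copy allocation that copies its value.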
void MemorySpaceAssignment::Allocation::MarkIfNeeded(
    absl::flat_hash_set<const Allocation*>& needed_allocations) const {
  MarkNeeded(needed_allocations);
}

void MemorySpaceAssignment::Allocation::MarkNeeded(
    absl::flat_hash_set<const Allocation*>& needed_allocations) const {
  needed_allocations.insert(this);
}

void MemorySpaceAssignment::CopyAllocation::MarkNeeded(
    absl::flat_hash_set<const Allocation*>& needed_allocations) const {
  needed_allocations.insert(this);
  prev_allocation_.MarkNeeded(needed_allocations);
}

void MemorySpaceAssignment::ParentAllocation::MarkIfNeeded(
    absl::flat_hash_set<const Allocation*>& needed_allocations) const {
  // Parent allocations are only needed if they have any uses or if there is a
  // copy allocation that copies this value (in that case, the copy allocation
  // will call this allocation's MarkNeeded function).
  if (!uses_.empty()) {
    MarkNeeded(needed_allocations);
  }
}

void MemorySpaceAssignment::ParentAllocation::MarkNeeded(
    absl::flat_hash_set<const Allocation*>& needed_allocations) const {
  needed_allocations.insert(this);
  original_allocation_.MarkNeeded(needed_allocations);
}

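// Process() runs in three phases: (1) mark which allocations are actually
// needed, (2) call Process() on each needed allocation, which materializes the
// CopyStart/CopyDone pairs and records the alternate-memory chunks, and
// (3) call PostProcess() on each needed allocation (currently only parent
// allocations do real work there, patching the while-body root).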
Status MemorySpaceAssignment::Process() {
  VLOG(1) << "Processing assigned buffers...";
  // Since some parent allocations may not be needed (e.g. when they don't have
  // any uses and no other, non-parent allocation depends on them), mark all
  // needed allocations before processing them.
  absl::flat_hash_set<const Allocation*> needed_allocations;
  for (auto& allocation : allocations_) {
    allocation->MarkIfNeeded(needed_allocations);
  }
  // Insert CopyStart/CopyDone pairs.
  for (auto& allocation : allocations_) {
    VLOG(3) << "Processing: " << allocation->ToString();
    if (!needed_allocations.contains(allocation.get())) {
      VLOG(3) << "Allocation not needed.";
      continue;
    }
    TF_RETURN_IF_ERROR(allocation->Process());
    // Add the offset and size of the allocation in the alternate memory to
    // the output map.
    if (allocation->is_scoped_allocation()) {
      CHECK(allocation->memory_space() == MemorySpace::kAlternate);
      scoped_memory_assignments_.emplace_back(
          allocation->defining_position().instruction, allocation->chunk());
      alternate_memory_size_ =
          std::max(alternate_memory_size_, allocation->chunk().chunk_end());
    } else if (allocation->memory_space() == MemorySpace::kAlternate) {
      alternate_memory_assignments_.emplace_back(
          allocation->defining_position(), allocation->chunk());
      alternate_memory_size_ =
          std::max(alternate_memory_size_, allocation->chunk().chunk_end());
    }
  }
  // Post-process allocations. This is only used for parent allocations, where
  // we update the while body root with a reference to the buffer in the
  // default memory space.
  for (auto& allocation : allocations_) {
    if (needed_allocations.contains(allocation.get())) {
      VLOG(3) << "Post-Processing: " << allocation->ToString();
      TF_RETURN_IF_ERROR(allocation->PostProcess());
    }
  }
  return Status::OK();
}

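// ExportAndColorBuffers copies the collected alternate-memory assignments into
// preset_assignments_, deduplicating positions that alias the same HloBuffer
// and checking that aliased positions agree on the offset. It also records the
// overall alternate memory size, and then colors the layout of every aliased
// position with the alternate memory space so later passes see the assignment
// in the shapes themselves.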
Status MemorySpaceAssignment::ExportAndColorBuffers() {
  VLOG(1) << "Exporting buffers...";
  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
  absl::flat_hash_map<int64, int64> seen_buffer_offsets;
  VLOG(3) << "Exported alternate memory allocations:";
  for (const auto& position_and_chunk : alternate_memory_assignments_) {
    const HloPosition& defining_position = position_and_chunk.first;
    const Chunk& chunk = position_and_chunk.second;
    const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(
        defining_position.instruction, defining_position.index);
    auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id());
    if (seen_buffer_offset_it != seen_buffer_offsets.end()) {
      CHECK_EQ(chunk.offset, seen_buffer_offset_it->second)
          << "Mismatch in offset for positions that map to the same value: "
          << buffer.ToString() << ", pos: " << defining_position.ToString();
    } else {
      VLOG(3) << " [" << chunk.offset << ", " << chunk.size
              << "] : " << defining_position.ToString() << " ("
              << buffer.ToString() << ")";
      preset_assignments_->add_chunk(defining_position, chunk);
      seen_buffer_offsets[buffer.id()] = chunk.offset;
    }
  }

  VLOG(3) << "Exported scoped allocations in alternate memory:";
  for (const auto& instruction_and_chunk : scoped_memory_assignments_) {
    HloInstruction* instruction = instruction_and_chunk.first;
    const Chunk& chunk = instruction_and_chunk.second;
    VLOG(3) << " [" << chunk.offset << ", " << chunk.size
            << "] : " << instruction->name();
    preset_assignments_->add_scoped_allocation_chunk(instruction, chunk);
  }

  if (!preset_assignments_->chunks().empty() ||
      !preset_assignments_->scoped_allocation_chunks().empty()) {
    preset_assignments_
        ->assignment_information_for_space(options_.alternate_memory_space)
        ->size = alternate_memory_size_;
  }

  VLOG(3) << "Exported alternate memory sizes:";
  for (auto& pair : preset_assignments_->assignment_informations()) {
    VLOG(3) << "  space: " << pair.first << ", size: " << pair.second.size;
  }

  VLOG(1) << "Coloring buffers...";
  // Color the pending positions and all of their aliased buffers.
  for (const auto& defining_position_and_chunk :
       preset_assignments_->chunks()) {
    const HloPosition& defining_position = defining_position_and_chunk.first;
    for (auto& buffer : alias_analysis->ComputeBuffersAt(
             defining_position.instruction, defining_position.index)) {
      for (auto& value : buffer->values()) {
        for (auto& position : value->positions()) {
          VLOG(4) << "Coloring " << position.ToString();
          Shape* shape = ShapeUtil::GetMutableSubshape(
              position.instruction->mutable_shape(), position.index);
          CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
                                  << position.ToString();
          shape->mutable_layout()->set_memory_space(
              options_.alternate_memory_space);
        }
      }
    }
  }
  return Status::OK();
}

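// Removes the alternate-memory assignment (if any) whose defining position is
// the given instruction, using swap-with-back-and-pop since the order of
// alternate_memory_assignments_ does not matter. SimplifyGraph calls this
// before deleting dead instructions so that no assignment points at a removed
// instruction.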
void MemorySpaceAssignment::RemoveAssignmentForInstruction(
    const HloInstruction* instruction) {
  for (auto& position_and_chunk : alternate_memory_assignments_) {
    const HloPosition& position = position_and_chunk.first;
    if (position.instruction == instruction) {
      VLOG(3) << "Removing instruction from alternate memory assignments.";
      // Swap the removed position and chunk with the back and pop back.
      position_and_chunk = alternate_memory_assignments_.back();
      alternate_memory_assignments_.pop_back();
      break;
    }
  }
}

Status MemorySpaceAssignment::SimplifyGraph() {
  VLOG(1) << "Simplifying graph...";
  for (HloComputation* computation : module_->MakeNonfusionComputations()) {
    // Parallel computations aren't in the schedule and don't need to be
    // modified.
    if (!computations_in_schedule_.contains(computation)) {
      VLOG(4) << "Not simplifying " << computation->name()
              << " because it's not in the schedule.";
      continue;
    }
    // Drop control dependencies. Since the computation is already scheduled,
    // we don't need control dependencies anymore, and having control
    // predecessors/successors prevents us from removing instructions without
    // users (HloComputation::IsSafelyRemovable returns false if there are
    // control dependencies).
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
    }
    // We perform limited DCE and forward the tuple operand in patterns like
    // GetTupleElement(Tuple(a, b), 0). This is mostly because memory space
    // assignment is run late in compilation (after DCE and arithmetic
    // simplification passes) and we don't want to generate redundant code. Run
    // to a fixed point.
    bool computation_modified = true;
    while (computation_modified) {
      computation_modified = false;
      VLOG(4) << "Running simplify graph loop over " << computation->name();
      for (HloInstruction* instruction :
           computation->MakeInstructionPostOrder()) {
        if (computation->IsSafelyRemovable(instruction) &&
            instruction->user_count() == 0 && !instruction->HasSideEffect() &&
            instruction != computation->root_instruction() &&
            instruction->opcode() != HloOpcode::kCopyStart &&
            instruction->opcode() != HloOpcode::kCopyDone) {
          VLOG(4) << "Instruction removed: " << instruction->ToString();
          // Ensure the alternate memory assignments don't contain a reference
          // to the removed instruction.
          RemoveAssignmentForInstruction(instruction);
          // Instead of deleting the instruction from the schedule, replace it
          // with a nullptr. This is needed because FixSchedule relies on the
          // logical time that is the index into flattened_instructions_ for
          // scheduling asynchronous copies.
          auto instruction_it =
              absl::c_find(flattened_instructions_, instruction);
          if (instruction_it != flattened_instructions_.end()) {
            *instruction_it = nullptr;
          }
          TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
          computation_modified = true;
        } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
          HloInstruction* operand = instruction->mutable_operand(0);
          if (operand->opcode() == HloOpcode::kTuple) {
            HloInstruction* forwarded_instruction =
                operand->mutable_operand(instruction->tuple_index());
            VLOG(4) << "Replacing uses of " << instruction->ToString()
                    << " with " << forwarded_instruction->ToString();
            TF_RETURN_IF_ERROR(
                instruction->ReplaceAllUsesWith(forwarded_instruction));
            computation_modified = true;
          }
        } else if (instruction->opcode() == HloOpcode::kTuple) {
          // Replace Tuple(GetTupleElement(x), ..., GetTupleElement(x)) pattern
          // with x.
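          // For example (hypothetical HLO, assuming x is a 2-element tuple):
          //   t = tuple(get-tuple-element(x, index=0),
          //             get-tuple-element(x, index=1))
          // can be replaced with x itself, provided every operand is a
          // get-tuple-element of the same x with a matching index and the
          // tuple covers all of x's elements.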
          bool can_replace =
              instruction->operand_count() > 0 &&
              instruction->operand(0)->opcode() ==
                  HloOpcode::kGetTupleElement &&
              instruction->operand(0)
                      ->operand(0)
                      ->shape()
                      .tuple_shapes_size() == instruction->operand_count();
          for (int operand_number = 0;
               operand_number < instruction->operand_count();
               ++operand_number) {
            const HloInstruction* operand =
                instruction->operand(operand_number);
            if (operand->opcode() != HloOpcode::kGetTupleElement ||
                operand->tuple_index() != operand_number ||
                operand->operand(0) != instruction->operand(0)->operand(0)) {
              can_replace = false;
              break;
            }
          }
          if (can_replace) {
            HloInstruction* forwarded_instruction =
                instruction->mutable_operand(0)->mutable_operand(0);
            VLOG(4) << "Replacing uses of " << instruction->ToString()
                    << " with " << forwarded_instruction->ToString();
            TF_RETURN_IF_ERROR(
                instruction->ReplaceAllUsesWith(forwarded_instruction));
            computation_modified = true;
          }
        }
      }
    }
  }

  return Status::OK();
}

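// Appends new_instruction to new_sequence after recursively appending any of
// its operands that haven't been inserted yet (a post-order walk), so that
// every instruction appears after all of its operands in the rebuilt schedule.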
void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted(
    HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
    absl::flat_hash_set<HloInstruction*>* inserted_instructions) const {
  if (inserted_instructions->contains(new_instruction)) {
    return;
  }
  for (HloInstruction* operand : new_instruction->operands()) {
    // CopyStart/CopyDone dependencies should already have been inserted; it is
    // a red flag if they haven't been.
    CHECK((operand->opcode() != HloOpcode::kCopyStart &&
           operand->opcode() != HloOpcode::kCopyDone) ||
          inserted_instructions->contains(operand))
        << "Inserted instruction " << new_instruction->ToString()
        << " has un-inserted dependency: " << operand->ToString();
    EnsureInstructionAndOperandsInserted(operand, new_sequence,
                                         inserted_instructions);
  }
  VLOG(4) << "inserting: " << new_instruction->ToShortString();
  new_sequence->push_back(new_instruction);
  inserted_instructions->insert(new_instruction);
}

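// For each memory space, collects the copy allocations and stably sorts them
// by (copy_done_schedule_before, copy_start_schedule_after) so that copies
// whose CopyDone is needed earlier are issued first. The resulting CopyStart
// and CopyDone instructions are recorded in schedule_after_ and
// schedule_before_, keyed by logical time, for FixSchedule to splice into the
// instruction sequences.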
void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
  VLOG(1) << "Scheduling asynchronous copies...";
  for (MemorySpace memory_space :
       {MemorySpace::kDefault, MemorySpace::kAlternate}) {
    std::vector<CopyAllocation*> copy_allocations;
    for (auto& allocation : allocations_) {
      if (allocation->is_copy_allocation()) {
        auto copy_allocation = static_cast<CopyAllocation*>(allocation.get());
        if (copy_allocation->memory_space() == memory_space) {
          copy_allocations.push_back(copy_allocation);
        }
      }
    }

    absl::c_stable_sort(
        copy_allocations, [](CopyAllocation* first, CopyAllocation* second) {
          return std::forward_as_tuple(first->copy_done_schedule_before(),
                                       first->copy_start_schedule_after()) <
                 std::forward_as_tuple(second->copy_done_schedule_before(),
                                       second->copy_start_schedule_after());
        });
    for (CopyAllocation* copy_allocation : copy_allocations) {
      // If the copy start doesn't happen to be scheduled at the correct
      // computation, delay it until the correct computation starts.
      int64_t copy_start_schedule_after =
          copy_allocation->copy_start_schedule_after();
      // Accessing the elements of flattened_instructions_ without checking for
      // nullptr is safe because this method is called before SimplifyGraph,
      // which is what replaces deleted instructions with nullptr.
      while (copy_allocation->defining_position().instruction->parent() !=
             flattened_instructions_[copy_start_schedule_after]->parent()) {
        VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
                << (copy_start_schedule_after + 1) << ") for "
                << copy_allocation->copy_start()->ToString()
                << " because it is not in the correct computation.";
        copy_allocation->set_copy_start_schedule_after(
            ++copy_start_schedule_after);
      }

      schedule_after_[copy_allocation->copy_start_schedule_after()].push_back(
          copy_allocation->copy_start());
      schedule_before_[copy_allocation->copy_done_schedule_before()].push_back(
          copy_allocation->copy_done());
    }
  }
}

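// Rebuilds the instruction sequence of every scheduled computation: at each
// logical time step it first emits the CopyDones registered in
// schedule_before_, then the original instruction (skipping entries that
// SimplifyGraph replaced with nullptr, and skipping bitcasts/tuples, which are
// pulled in only through operand dependencies), and finally the CopyStarts
// registered in schedule_after_.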
Status MemorySpaceAssignment::FixSchedule() {
  VLOG(1) << "Fixing schedule...";
  CHECK(module_->has_schedule());
  HloSchedule& schedule = module_->schedule();
  for (const HloComputation* computation :
       module_->MakeNonfusionComputations()) {
    // Parallel computations aren't in the schedule and don't need to be
    // modified.
    if (!computations_in_schedule_.contains(computation)) {
      VLOG(4) << "Not scheduling " << computation->name()
              << " because it's not in the schedule.";
      continue;
    }
    CHECK(schedule.is_computation_scheduled(computation));
    HloInstructionSequence new_sequence;

    absl::flat_hash_set<HloInstruction*> inserted_instructions;

    VLOG(4) << "Scheduling: " << computation->ToString();

    for (int64_t instruction_index = 0;; ++instruction_index) {
      auto insts_before_iter = schedule_before_.find(instruction_index);
      if (insts_before_iter != schedule_before_.end()) {
        for (HloInstruction* new_instruction : insts_before_iter->second) {
          if (new_instruction->parent() == computation) {
            VLOG(4) << "before " << instruction_index << ": "
                    << new_instruction->name();
            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
                                                 &inserted_instructions);
          }
        }
      }
      // We allow scheduling copy dones past the root instruction (for
      // end-of-program cross-program prefetch). So the loop exit condition is
      // actually here.
      if (instruction_index >= flattened_instructions_.size()) {
        break;
      }
      HloInstruction* instruction = flattened_instructions_[instruction_index];
      // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
      // it was deleted) and not previously inserted. Also bitcasts and tuples
      // are treated specially and only inserted as a result of operand
      // dependencies.
      if (instruction != nullptr &&
          !inserted_instructions.contains(instruction) &&
          instruction->parent() == computation &&
          instruction->opcode() != HloOpcode::kBitcast &&
          instruction->opcode() != HloOpcode::kTuple) {
        VLOG(4) << "inst " << instruction_index << ": " << instruction->name();
        EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
                                             &inserted_instructions);
      }
      auto insts_after_iter = schedule_after_.find(instruction_index);
      if (insts_after_iter != schedule_after_.end()) {
        for (HloInstruction* new_instruction : insts_after_iter->second) {
          if (new_instruction->parent() == computation) {
            VLOG(4) << "after " << instruction_index << ": "
                    << new_instruction->name();
            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
                                                 &inserted_instructions);
          }
        }
      }
    }
    // For rare cases where the original sequence is empty, ensure the root
    // instruction and its dependencies are scheduled.
    EnsureInstructionAndOperandsInserted(computation->root_instruction(),
                                         &new_sequence, &inserted_instructions);
    CHECK_EQ(new_sequence.size(), computation->instruction_count())
        << "New sequence for computation " << computation->name() << " has "
        << new_sequence.size() << " instructions, expects "
        << computation->instruction_count() << ".";
    schedule.set_sequence(computation, new_sequence);
  }

  return Status::OK();
}

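// Verification builds a list of (time, alloc/free) events for every exported
// chunk and uses a BufferIntervalTree to check that no two chunks that are
// live at the same time overlap in space. The events are then exported into
// the preset assignment's heap simulator trace, and the maximum concurrent
// usage (ignoring fragmentation) is logged at the end.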
Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
  VLOG(1) << "Verifying...";
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module_));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(module_->schedule(), *alias_analysis,
                                        module_->entry_computation()));

  BufferIntervalTree interval_tree;
  absl::flat_hash_set<int64> seen_buffers;
  // The key for events is: time, is_free, value_id. This sorts the events
  // first by time; within the same time, allocations sort before frees; and
  // the value id is the final tie breaker.
  std::map<std::tuple<int64, bool, int64>,
           std::tuple<const HloValue*, Chunk, HeapSimulatorTrace::Event::Kind>>
      events;

  auto add_allocation_and_verify = [&](int64_t start_time, int64_t end_time,
                                       const Chunk& chunk,
                                       const HloValue* value) {
    events[std::make_tuple(start_time, /*is_free=*/false, value->id())] =
        std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC);
    events[std::make_tuple(end_time, /*is_free=*/true, value->id())] =
        std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE);

    // Get the chunks overlapping in time and search if they overlap in space
    // as well.
    // TODO(berkin): For now checking against end_time - 1 (exclusive), but we
    // really should check against end_time (inclusive) for cases where the
    // operand can't share buffer with user (see
    // HloDataflowAnalysis::CanShareOperandBufferWithUser).
    for (const Chunk& overlapping_chunk :
         interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
      if (chunk.OverlapsWith(overlapping_chunk)) {
        return InternalError(
            ("Value %s (%d, %d) off: %d size: %d overlaps with another chunk"
             " off: %d size: %d"),
            value->ToShortString(), start_time, end_time, chunk.offset,
            chunk.size, overlapping_chunk.offset, overlapping_chunk.size);
      }
    }
    interval_tree.Add(start_time, end_time - 1, chunk);
    return Status::OK();
  };

  // Go through all instructions in the module to ensure CopyStart/CopyDone
  // instructions copy between alternate memory and default memory.
  for (const HloComputation* computation :
       module_->MakeNonfusionComputations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopyStart) {
        int64_t from_memory_space =
            ShapeUtil::GetSubshape(instruction->shape(), {1})
                .layout()
                .memory_space();
        int64_t to_memory_space =
            ShapeUtil::GetSubshape(instruction->shape(), {0})
                .layout()
                .memory_space();
        CHECK_NE(from_memory_space, to_memory_space)
            << "Asynchronous copy to the same memory space: "
            << instruction->ToString();
      }
    }
  }

  for (const auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    const Chunk& chunk = position_and_chunk.second;
    const HloBuffer& buffer =
        alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
    CHECK(!seen_buffers.contains(buffer.id()))
        << "Multiple preset assignments for the same buffer: "
        << buffer.ToString() << ", pos: " << position.ToString()
        << ", off: " << chunk.offset << ", size: " << chunk.size;
    seen_buffers.insert(buffer.id());

    for (const HloValue* value : buffer.values()) {
      const HloLiveRange::TimeBound& time_bound =
          hlo_live_range->buffer_live_ranges().at(value);
      const HloInstruction* last_use_instruction = nullptr;
      int64_t last_use_time = time_bound.start;
      for (const HloUse& use : value->uses()) {
        int64_t use_time =
            hlo_live_range->instruction_schedule().at(use.instruction);
        if (use_time > last_use_time) {
          last_use_time = use_time;
          last_use_instruction = use.instruction;
        }
      }

      std::function<Status(const HloInstruction*, int64_t, int64_t,
                           absl::string_view)>
          split_conditional_buffer;
      split_conditional_buffer = [&](const HloInstruction* use_instruction,
                                     int64_t start_time, int64_t end_time,
                                     absl::string_view indent_string) {
        // Special case when verifying conditionals: we internally split the
        // use of alternate memory in conditionals, so fish those uses out of
        // the called computations.
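        // Concretely: the portion of the live range before the first called
        // computation is verified as (start_time, earliest computation start
        // - 1), and within each called computation the buffer is verified
        // from the parameter's schedule time to its last use there, recursing
        // if that last use is itself a conditional.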
        VLOG(3) << indent_string
                << "Splitting conditional buffer: " << buffer.ToString()
                << " value: " << value->ToShortString() << ": (" << start_time
                << ", " << end_time << ") off: " << chunk.offset
                << ", size: " << chunk.size;
        int64_t earliest_computation_start_time = end_time;
        for (const HloComputation* called_computation :
             use_instruction->called_computations()) {
          earliest_computation_start_time =
              std::min(earliest_computation_start_time,
                       hlo_live_range->computation_span_times()
                           .at(called_computation)
                           .start);
          int64_t parameter_time = -1;
          int64_t last_use_time = -1;
          const HloInstruction* last_use_instruction = nullptr;
          for (const HloPosition& position : value->positions()) {
            if (position.instruction->opcode() == HloOpcode::kParameter &&
                position.instruction->parent() == called_computation) {
              parameter_time = hlo_live_range->instruction_schedule().at(
                  position.instruction);
              break;
            }
          }
          for (const HloUse& use : value->uses()) {
            int64_t use_time =
                hlo_live_range->instruction_schedule().at(use.instruction);
            if (use.instruction->parent() == called_computation &&
                use_time > last_use_time) {
              last_use_time = use_time;
              last_use_instruction = use.instruction;
            }
          }
          if (last_use_time != -1) {
            CHECK_NE(parameter_time, -1);
            VLOG(3) << indent_string
                    << " computation: " << called_computation->name() << ": ("
                    << parameter_time << ", " << last_use_time << ")";
            CHECK(last_use_instruction);
            if (last_use_instruction->opcode() == HloOpcode::kConditional) {
              // The last use is another (nested) conditional. Call this
              // function recursively.
              TF_RETURN_IF_ERROR(split_conditional_buffer(
                  last_use_instruction, parameter_time, last_use_time,
                  absl::StrCat(indent_string, "  ")));
            } else {
              last_use_time = std::min(last_use_time, end_time);
              TF_RETURN_IF_ERROR(add_allocation_and_verify(
                  parameter_time, last_use_time, chunk, value));
            }
          }
        }
        VLOG(3) << indent_string << " from beginning until first computation: ("
                << start_time << ", " << (earliest_computation_start_time - 1)
                << ")";
        TF_RETURN_IF_ERROR(add_allocation_and_verify(
            start_time, earliest_computation_start_time - 1, chunk, value));
        return Status::OK();
      };

      if (last_use_instruction &&
          last_use_instruction->opcode() == HloOpcode::kConditional) {
        TF_RETURN_IF_ERROR(split_conditional_buffer(
            last_use_instruction, time_bound.start, time_bound.end, " "));
      } else if (!value->uses().empty()) {
        last_use_time = std::min(last_use_time, time_bound.end);
        VLOG(3) << " buffer: " << buffer.ToString()
                << " value: " << value->ToShortString() << ": ("
                << time_bound.start << ", " << last_use_time
                << ") off: " << chunk.offset << ", size: " << chunk.size;
        TF_RETURN_IF_ERROR(add_allocation_and_verify(
            time_bound.start, last_use_time, chunk, value));
      }
    }
  }

  HeapSimulatorTrace* heap_trace =
      &preset_assignments_
           ->assignment_information_for_space(options_.alternate_memory_space)
           ->heap_simulator_trace;
  int64_t memory_usage = 0;
  int64_t max_memory_usage = 0;
  for (const auto& event : events) {
    int64_t time;
    bool is_free;
    int64_t buffer_id;
    std::tie(time, is_free, buffer_id) = event.first;
    const HloValue* value;
    Chunk chunk;
    HeapSimulatorTrace::Event::Kind kind;
    std::tie(value, chunk, kind) = event.second;
    HeapSimulatorTrace::Event* heap_trace_event = heap_trace->add_events();
    heap_trace_event->set_kind(kind);
    heap_trace_event->set_buffer_id(buffer_id);
    heap_trace_event->set_instruction_name(value->instruction()->name());
    heap_trace_event->set_computation_name(
        value->instruction()->parent()->name());

    if (kind == HeapSimulatorTrace::Event::ALLOC) {
      memory_usage += chunk.size;
    } else {
      CHECK_EQ(kind, HeapSimulatorTrace::Event::FREE);
      memory_usage -= chunk.size;
    }
    max_memory_usage = std::max(max_memory_usage, memory_usage);
    VLOG(4) << "Memory usage: " << memory_usage << " at time: " << time;
  }
  VLOG(1) << "Max memory usage ignoring fragmentation: " << max_memory_usage;

  return Status::OK();
}
}  // namespace memory_space_assignment
}  // namespace xla