• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
18 
19 #include "tensorflow/compiler/xla/service/heap_simulator.h"
20 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
21 #include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"
22 
23 namespace xla {
24 
25 // This class contains pre-set assignments determined by memory space
26 // assignment. It contains two data structures: (1) a chunks vector that maps a
27 // defining HloPosition to a Chunk (offset and size), and (2) an assignment_info
28 // vector that maps the memory space to information like its allocated size and
29 // heap memory trace. If there is only one alternate memory space like there is
30 // currently, there will be one entry in assignment_info.
31 class PresetAssignments {
32  public:
33   // Contains per-memory-space information like the allocated size and heap
34   // simulator trace.
35   struct AssignmentInformation {
36     int64 size;
37     HeapSimulatorTrace heap_simulator_trace;
38   };
39 
40   PresetAssignments() = default;
41 
add_chunk(const HloPosition & position,const HeapSimulator::Chunk & chunk)42   void add_chunk(const HloPosition& position,
43                  const HeapSimulator::Chunk& chunk) {
44     chunks_.emplace_back(position, chunk);
45   }
46 
assignment_information_for_space(int64 memory_space)47   AssignmentInformation* assignment_information_for_space(int64 memory_space) {
48     for (auto& space_and_info : assignment_info_) {
49       if (space_and_info.first == memory_space) {
50         return &space_and_info.second;
51       }
52     }
53     assignment_info_.emplace_back(memory_space, AssignmentInformation());
54     return &assignment_info_.back().second;
55   }
56 
chunks()57   absl::Span<const std::pair<HloPosition, HeapSimulator::Chunk>> chunks()
58       const {
59     return chunks_;
60   }
61 
62   absl::Span<const std::pair<int64, AssignmentInformation>>
assignment_informations()63   assignment_informations() const {
64     return assignment_info_;
65   }
66 
67   // Get debugging information.
buffer_info_str()68   std::string buffer_info_str() const { return buffer_info_str_; }
allocation_info_str()69   std::string allocation_info_str() const { return allocation_info_str_; }
70 
71  private:
72   std::vector<std::pair<HloPosition, HeapSimulator::Chunk>> chunks_;
73   std::vector<std::pair<int64, AssignmentInformation>> assignment_info_;
74   std::string buffer_info_str_;
75   std::string allocation_info_str_;
76 };
77 
78 // A wrapper class around HloCostAnalysis with additional knowledge about the
79 // bandwidths of different memory spaces.
80 class MemorySpaceAssignmentCostAnalysis {
81  public:
82   // An optional Cache object may be provided to some of the methods below to
83   // speed up the lookup.
84   struct Cache {
85     absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
86   };
87 
88   virtual ~MemorySpaceAssignmentCostAnalysis() = default;
89 
90   static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
91       const HloCostAnalysis& cost_analysis,
92       float async_copy_bandwidth_bytes_per_second,
93       float alternate_mem_bandwidth_bytes_per_second, const HloModule& module);
94 
cost_analysis()95   const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
96 
97   // Returns a heuristic value that captures how much putting this tensor to the
98   // alternate memory would help if the op is memory bound, or otherwise how far
99   // off is the op to memory boundedness. The larger this number, the higher
100   // priority it will be placed in the alternate memory.
101   float GetAlternateMemoryBenefit(const HloInstruction& instruction,
102                                   float elapsed_time_due_to_alternate_mem,
103                                   Cache* cache = nullptr) const;
104 
105   // Returns a heuristic value of memory boundedness for the given
106   // BufferInterval.  The larger this number, the higher priority it will be
107   // placed in the alternate memory.
108   float GetMemoryBoundedness(
109       const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
110       Cache* cache = nullptr) const;
111 
112   // Returns the elapsed time in seconds due to compute only.
113   float GetInstructionElapsedDueToCompute(
114       const HloInstruction& instruction) const;
115 
116   // Returns the elapsed time in seconds due to memory only. If
117   // operand_in_alternate_mem is provided or if output_in_alternate_mem is true,
118   // it will assume that operand or output will be in the alternate memory
119   // space. This is useful for calculating the benefit of placing the buffer in
120   // alternate memory.
121   float GetInstructionElapsedDueToMemory(
122       const HloInstruction& instruction,
123       absl::optional<int64> operand_in_alternate_mem = absl::nullopt,
124       bool output_in_alternate_mem = false) const;
125 
126   // Returns the estimated elapsed duration of the instruction in seconds.  It
127   // assumes all operands and outputs of the instruction are in the default
128   // memory.
129   virtual float GetInstructionElapsed(const HloInstruction& instruction) const;
130 
131   // Returns the estimated elapsed duration of the instruction in seconds.  It
132   // assumes all operands and outputs of the instruction are in the default
133   // memory, except for the operand number that is in the alternate memory, if
134   // provided, or output if output_in_alternate_mem is true.
135   virtual float GetInstructionElapsedInAlternateMemory(
136       const HloInstruction& instruction,
137       absl::optional<int64> operand_in_alternate_mem,
138       bool output_in_alternate_mem) const;
139 
140   // Returns the elapsed time it would take to asynchronously copy the shape
141   // from default to alternate memory space (or vice versa).
142   virtual float GetAsyncCopyElapsed(const Shape& shape) const;
143 
144   int64 GetScheduleEndTime() const;
145 
146   // Returns the number of nested computation levels this instruction resides
147   // in. If while_only is true, it returns the while loop nest level and 0
148   // means the instruction is not in a while loop.
149   int CalculateComputationNestLevel(const HloInstruction* instruction,
150                                     bool while_only) const;
151 
hlo_live_range()152   const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
153 
154  protected:
MemorySpaceAssignmentCostAnalysis(const HloCostAnalysis & cost_analysis,float async_copy_bandwidth_bytes_per_second,float alternate_mem_bandwidth_bytes_per_second,std::unique_ptr<HloAliasAnalysis> alias_analysis,std::unique_ptr<HloLiveRange> hlo_live_range,std::unique_ptr<CallGraph> call_graph)155   MemorySpaceAssignmentCostAnalysis(
156       const HloCostAnalysis& cost_analysis,
157       float async_copy_bandwidth_bytes_per_second,
158       float alternate_mem_bandwidth_bytes_per_second,
159       std::unique_ptr<HloAliasAnalysis> alias_analysis,
160       std::unique_ptr<HloLiveRange> hlo_live_range,
161       std::unique_ptr<CallGraph> call_graph)
162       : cost_analysis_(cost_analysis),
163         async_copy_bandwidth_bytes_per_second_(
164             async_copy_bandwidth_bytes_per_second),
165         alternate_mem_bandwidth_bytes_per_second_(
166             alternate_mem_bandwidth_bytes_per_second),
167         alias_analysis_(std::move(alias_analysis)),
168         hlo_live_range_(std::move(hlo_live_range)),
169         call_graph_(std::move(call_graph)) {}
170 
171  private:
172   const HloCostAnalysis& cost_analysis_;
173   float async_copy_bandwidth_bytes_per_second_;
174   float alternate_mem_bandwidth_bytes_per_second_;
175   std::unique_ptr<HloAliasAnalysis> alias_analysis_;
176   std::unique_ptr<HloLiveRange> hlo_live_range_;
177   std::unique_ptr<CallGraph> call_graph_;
178 };
179 
180 // Abstract base class that memory space assignment uses to pick prefetch
181 // intervals.
182 class PrefetchIntervalPicker {
183  public:
184   PrefetchIntervalPicker() = default;
185   virtual ~PrefetchIntervalPicker() = default;
186 
187   // Returns true if the buffer can be allocated in alternate memory space
188   // without any copies (prefetches).
189   virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
190                                                   int64 start_time,
191                                                   int64 end_time) const = 0;
192 
193   // Returns the preferred end time for an eviction that starts at a given time
194   // and must end by the given end time.
195   virtual int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
196                                          int64 latest_end_time) const = 0;
197 
198   // Returns the latest time that a prefetch can start.
199   virtual int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time,
200                                         int64 end_time,
201                                         const HloUse* use) const = 0;
202 
203   // Returns the preferred time that a prefetch can start.
204   virtual int64 PreferredPrefetchStartTime(const Shape& shape,
205                                            int64 earliest_prefetch_start_time,
206                                            int64 latest_prefetch_start_time,
207                                            int64 prefetch_end_time) const = 0;
208 
209   // Returns the latest time that a prefetch can end that is less than or equal
210   // to proposed_prefetch_end_time.
LatestPrefetchEndTime(int64 original_prefetch_end_time,int64 proposed_prefetch_end_time)211   virtual int64 LatestPrefetchEndTime(int64 original_prefetch_end_time,
212                                       int64 proposed_prefetch_end_time) const {
213     return proposed_prefetch_end_time;
214   }
215 
216   // Begins the iterator for the first start time of the prefetch.
217   virtual void Begin(const HloUse& use, int64 start_time, int64 end_time) = 0;
218 
219   // Advances the start time of the prefetch and returns that value.
220   virtual int64 Next() = 0;
221 
222   // Returns true if the available prefetch intervals have been exhausted.
223   virtual bool Done() const = 0;
224 
225   // The retry number can be used to modify the interval picking policies. The
226   // first attempt will have a retry_number of 0, then 1, etc.
SetRetryNumber(int retry_number)227   virtual void SetRetryNumber(int retry_number) {}
228 
229   // Returns a debug string for the current state of the prefetch interval
230   // picker.
231   virtual std::string ToDebugString() const = 0;
232 
233   // Returns a debug string for no-copy allocation.
234   virtual std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
235                                           int64 end_time) const = 0;
236 
237   // Prefetch interval pickers may return a value corresponding to the benefit
238   // of placing the BufferInterval in the alternate memory. The larger value,
239   // the more beneficial.
BufferIntervalAlternateMemoryBenefit(const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval & interval)240   virtual absl::optional<float> BufferIntervalAlternateMemoryBenefit(
241       const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
242       const {
243     return absl::nullopt;
244   }
245 
246  protected:
247   const absl::flat_hash_map<const HloInstruction*, int64>*
248       instruction_schedule_ = nullptr;
249 };
250 
251 // Prefetch interval picker that uses instruction count to overlap asynchronous
252 // copies with independent computation. The min and max overlap counts describe
253 // the number of independent HLOs overlapped while a value is being prefetched
254 // into the alternate memory (between CopyStart and CopyDone HLO instructions).
255 // max_overlap_count attempts to prevent bringing tensors into the alternate
256 // memory too eagerly and hence occupying the space for other tensors which
257 // might use it.  min_overlap_count attempts to prevent cases where tensors are
258 // prefetched into the alternate memory without sufficient time for the copy to
259 // take place.  In those cases, it's just better to keep the tensor in the
260 // default memory instead of hurting the critical path with this copy that
261 // likely won't finish in time.
262 class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker {
263  public:
InstructionCountPrefetchIntervalPicker(int64 min_overlap_count,int64 max_overlap_count)264   InstructionCountPrefetchIntervalPicker(int64 min_overlap_count,
265                                          int64 max_overlap_count)
266       : min_overlap_count_(min_overlap_count),
267         max_overlap_count_(max_overlap_count) {}
268 
269   bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
270                                           int64 end_time) const override;
271 
272   int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
273                                  int64 latest_end_time) const override;
274 
275   int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time,
276                                 int64 end_time,
277                                 const HloUse* use) const override;
278 
279   int64 PreferredPrefetchStartTime(const Shape& shape,
280                                    int64 earliest_prefetch_start_time,
281                                    int64 latest_prefetch_start_time,
282                                    int64 prefetch_end_time) const override;
283 
284   void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
285 
286   int64 Next() override;
287   bool Done() const override;
288 
289   std::string ToDebugString() const override;
290   std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
291                                   int64 end_time) const override;
292 
293  private:
294   int64 min_overlap_count_;
295   int64 max_overlap_count_;
296   int64 end_time_;
297   int64 current_prefetch_time_;
298 };
299 
300 // Prefetch interval picker that uses cost analysis to overlap asynchronous
301 // copies with independent computation. It uses min/max (asynchronous copy
302 // duration) / (independent computation duration) ratios to guide whether the
303 // prefetch is within those bounds. It starts with the preferred ratio in
304 // Begin() and works its way for alternately earlier and later prefetches until
305 // hitting min and max ratios.
306 class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
307  public:
308   CostAnalysisPrefetchIntervalPicker(
309       const MemorySpaceAssignmentCostAnalysis& cost_analysis,
310       float min_async_copy_to_overlap_ratio,
311       float max_async_copy_to_overlap_ratio,
312       float preferred_async_copy_to_overlap_ratio);
313 
314   bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
315                                           int64 end_time) const override;
316 
317   int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
318                                  int64 latest_end_time) const override;
319 
320   int64 LatestPrefetchEndTime(int64 original_prefetch_end_time,
321                               int64 proposed_prefetch_end_time) const override;
322 
323   int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time,
324                                 int64 end_time,
325                                 const HloUse* use) const override;
326 
327   int64 PreferredPrefetchStartTime(const Shape& shape,
328                                    int64 earliest_prefetch_start_time,
329                                    int64 latest_prefetch_start_time,
330                                    int64 prefetch_end_time) const override;
331 
332   void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
333 
334   int64 Next() override;
335   bool Done() const override;
336 
337   void SetRetryNumber(int retry_number) override;
338 
339   std::string ToDebugString() const override;
340   std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
341                                   int64 end_time) const override;
342 
343   absl::optional<float> BufferIntervalAlternateMemoryBenefit(
344       const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
345       const override;
346 
347  private:
348   // Returns the elapsed time in seconds between the logical interval that
349   // corresponds to the instruction schedule.
350   float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const;
351 
352   // Finds the minimum nest level in the given interval.
353   int GetMinWhileNestLevel(int64 start_time, int64 end_time) const;
354 
355   // For each instruction in the flattened schedule, maintain their elapsed time
356   // (in cumulative sum) and while nesting level.
357   std::vector<float> elapsed_time_cumsum_;
358   std::vector<int> while_nest_level_;
359   std::vector<int> computation_nest_level_;
360   // Maintain the index of the most recent (before this instruction) nest level
361   // change in order to efficiently determine the minimum nest level in an
362   // interval.
363   std::vector<int> while_nest_level_change_;
364 
365   const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
366   float min_async_copy_to_overlap_ratio_;
367   float max_async_copy_to_overlap_ratio_;
368   float preferred_async_copy_to_overlap_ratio_;
369   float max_overlap_multiplier_ = 1.0;
370 
371   float async_copy_elapsed_;
372   float inst_elapsed_reduction_;
373   int64 end_logical_time_;
374   int64 earliest_prefetch_time_;
375   int64 latest_prefetch_time_;
376   bool using_increasing_prefetch_time_iterator_ = true;
377   int64 increasing_prefetch_time_iterator_;
378   int64 decreasing_prefetch_time_iterator_;
379 };
380 
381 // MemorySpaceAssignment assigns memory spaces (default or alternate) to each
382 // instruction in the module. It will greedily try placing as many values in
383 // the alternate memory space as possible. It uses the heap simulator to
384 // determine the actual allocation offsets of values in the alternate memory
385 // space to account for fragmentation. The default memory space is assumed to be
386 // large enough to hold the values that could not be placed in the alternate
387 // memory space.
388 class MemorySpaceAssignment {
389  public:
390   using Chunk = HeapSimulator::Chunk;
391   using BufferInterval =
392       GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval;
393   using BufferIntervalCompare =
394       GlobalDecreasingSizeBestFitHeap<HloValue>::BufferIntervalCompare;
395   using IsAllowedInAlternateMemoryFunction =
396       std::function<bool(const HloValue&)>;
397   using IsUseAllowedInAlternateMemoryFunction =
398       std::function<bool(const HloUse&)>;
399 
400   // MemorySpaceAssignment uses a notion of a slow and large default memory
401   // space and a fast and small alternate memory space.
402   enum class MemorySpace { kDefault, kAlternate };
403 
404   // Forward declaration for Allocation.
405   class Allocation;
406 
407   // The different options to be passed to the Run() API.
408   struct Options {
409     // Backend-specific integer value that describes the alternate memory.
410     int64 alternate_memory_space = 0;
411 
412     // Maximum size of the alternate memory space.
413     int64 max_size_in_bytes = 0;
414 
415     // Memory alignment of the alternate memory space.
416     int64 alignment_in_bytes = 1;
417 
418     // If provided, we sort the buffers using this comparison function
419     // otherwise, we use GlobalDecreasingSizeBestFitHeap::kSpatial.
420     absl::optional<BufferIntervalCompare> buffer_interval_compare =
421         absl::nullopt;
422 
423     // This object determines how early and how late prefetches can occur.
424     PrefetchIntervalPicker* prefetch_interval_picker = nullptr;
425 
426     // Size function for buffer values.
427     BufferValue::SizeFunction size_fn;
428 
429     // This function can be used to prevent certain HloValues (e.g., based on
430     // the opcode) to be placed on the alternate memory.
431     IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_fn;
432 
433     // This function can be used to prevent certain HloUses (e.g., based on
434     // the opcode) to be placed on the alternate memory.
435     IsUseAllowedInAlternateMemoryFunction is_use_allowed_in_alternate_mem_fn =
436         [](const HloUse&) { return true; };
437 
438     // Specifies the upper bound for number of outstanding prefetches and
439     // evictions, -1 for unlimited.
440     int64 max_outstanding_prefetches = -1;
441     int64 max_outstanding_evictions = -1;
442 
443     // Extra outstanding prefetch limit for while uses (in addition to
444     // max_outstanding_prefetches).
445     int64 while_use_extra_outstanding_prefetch_limit = 0;
446 
447     // Specifies the maximum number of times we are willing to move a copy
448     // done of a prefetch earlier due to an asynchronous copy ordering
449     // violation.
450     int64 prefetch_copy_done_reorder_max_retries = 1;
451 
452     // Specifies the maximum number of retries that will be performed for each
453     // value in case prefetching failed due to running out of asynchronous
454     // copies or asynchronous copy ordering.
455     int64 max_retries = 1;
456 
457     // The maximum number of repacks that we are willing to perform in case we
458     // can't allocate a buffer due to running out of memory. If this value is
459     // greater than 0, repacker must be non-nullptr.
460     int64 max_repacks = 0;
461 
462     // The repacking algorithm to reduce fragmentation. Must be non-null if
463     // max_repacks is greater than 0.
464     MemorySpaceAssignmentRepacker* repacker = nullptr;
465 
466     // This is only useful for testing, repack after every allocation.
467     bool repack_after_every_allocation = false;
468 
469     // If true, tries allocating buffers across (e.g., before and inside a while
470     // loop body) sequential calls (kWhile, kCall, and kConditional).
471     bool allocate_across_sequential_calls = false;
472 
473     // If true, verifies the memory space assignment against overlapping
474     // buffers.
475     bool verify = false;
476 
477     // If not nullptr, this function is called to dump debugging information.
478     // The first argument is appended to the file name and the second argument
479     // is the contents of the file.
480     std::function<void(absl::string_view, absl::string_view)> dump_fn = nullptr;
481 
482     // Enable prefetching buffers into preferred memory across program
483     // boundaries
484     bool enable_cross_program_prefetch = true;
485 
486     // If true, use buffer_interval_compare to determine which buffers to
487     // prefetch across program boundaries.
488     bool default_cross_program_prefetch_heuristic = false;
489   };
490 
491   // This class represents an allocation that might either be in the default or
492   // alternate memory. An HloValue might live in multiple different allocations
493   // over its lifetime. The lifetimes of the allocations are defined using
494   // start_time and end_time, which corresponds to the instruction indexes in
495   // the flattened schedule. Each of these allocations might partially overlap
496   // with each other. CopyAllocation defined below represents asynchronous
497   // copies between Allocations.
498   //
499   // Consider an instruction Foo, and its users Bar and Baz, and the times given
500   // in terms of the flattened schedule of the entire module:
501   //
502   //      Foo:10
503   //       /   \
504   //    Bar:14  \
505   //           Baz:25
506   //
507   // A valid memory space assignment could be like the following:
508   //
509   //  Time:         10 ... 14        ...      25
510   //                Foo    Bar                Baz
511   //  Alternate     +-------+           +-----+
512   //  Default           +---------------------+
513   //                    ^   ^           ^     ^
514   //                    |   |           |     |
515   //                evict   evict  prefetch  prefetch
516   //                start    end    start      end
517   //
518   // This would be represented with:
519   //   - Allocation(memory_space=kAlternate, start_time=10, end_time=14)
520   //   - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25)
521   //   - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25)
522   class Allocation {
523    public:
Allocation(HloPosition defining_position,MemorySpace memory_space,absl::optional<Chunk> chunk,int64 start_time,int64 end_time)524     Allocation(HloPosition defining_position, MemorySpace memory_space,
525                absl::optional<Chunk> chunk, int64 start_time, int64 end_time)
526         : defining_position_(defining_position),
527           memory_space_(memory_space),
528           chunk_(chunk),
529           start_time_(start_time),
530           end_time_(end_time) {}
531     virtual ~Allocation() = default;
532 
is_copy_allocation()533     virtual bool is_copy_allocation() const { return false; }
534 
535     // Adds a use to this allocation.
536     void AddUse(HloUse use);
537 
538     // Extends the end time of this allocation.
Extend(int64 end_time)539     void Extend(int64 end_time) { end_time_ = end_time; }
540 
541     // After all of the time ranges for the allocations have been assigned,
542     // Process morphs the instructions affected to assign the memory spaces and
543     // insert asynchronous copy instructions if necessary.
544     virtual Status Process(MemorySpaceAssignment* memory_space_assignment);
545 
546     // Returns the defining position for this allocation.
defining_position()547     virtual HloPosition defining_position() const { return defining_position_; }
548 
549     // Returns the time the buffer is first available to be used. For
550     // Allocation, this is start_time.
earliest_available_time()551     virtual int64 earliest_available_time() const { return start_time_; }
552 
uses()553     const std::vector<HloUse>& uses() const { return uses_; }
memory_space()554     MemorySpace memory_space() const { return memory_space_; }
chunk()555     Chunk chunk() const { return *chunk_; }
mutable_chunk()556     Chunk* mutable_chunk() { return &*chunk_; }
set_start_time(int64 start_time)557     void set_start_time(int64 start_time) { start_time_ = start_time; }
start_time()558     int64 start_time() const { return start_time_; }
end_time()559     int64 end_time() const { return end_time_; }
560 
561     bool operator==(const Allocation& other) const;
562     virtual std::string ToString() const;
563 
564    protected:
565     // Descend to the shape_index element of the tuple and replace that with
566     // new_instruction.
567     StatusOr<HloInstruction*> ReplaceTupleWith(HloInstruction* new_instruction,
568                                                HloInstruction* tuple,
569                                                ShapeIndex shape_index);
570 
571     // Recursively create kGetTupleElement instructions if the defining position
572     // shape is not an array. Returns the new instruction that has array shape.
573     HloInstruction* AddGetTupleElements();
574 
575     HloPosition defining_position_;
576     std::vector<HloUse> uses_;
577     MemorySpace memory_space_;
578     absl::optional<Chunk> chunk_;
579     int64 start_time_;
580     int64 end_time_;
581   };
582 
583   // This class represents an allocation as a result of an asynchronous copy.
584   // Note: CopyStart instructions are inserted after `start_time` or later,
585   // while CopyDone instructions are inserted before
586   // `copy_done_schedule_before_time` or earlier.
587   class CopyAllocation : public Allocation {
588    public:
589     CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
590                    absl::optional<Chunk> chunk, int64 start_time,
591                    int64 end_time, int64 copy_done_schedule_before_time,
592                    bool is_cross_program_prefetch = false)
593         : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
594                      start_time, end_time),
595           prev_allocation_(prev_allocation),
596           copy_start_schedule_after_(start_time),
597           copy_done_schedule_before_(copy_done_schedule_before_time),
598           is_cross_program_prefetch_(is_cross_program_prefetch) {}
599 
is_copy_allocation()600     bool is_copy_allocation() const override { return true; }
601 
602     Status Process(MemorySpaceAssignment* memory_space_assignment) override;
603 
defining_position()604     HloPosition defining_position() const override {
605       // Unless explicitly set, the defining position of a copy allocation in
606       // retrieved from the previous allocation. This is because we don't create
607       // new CopyStart/CopyDone instructions until later and the position should
608       // point to the previous (copy or otherwise) allocation's position for the
609       // original defining position.
610       if (defining_position_.instruction == nullptr) {
611         return prev_allocation_.defining_position();
612       } else {
613         return defining_position_;
614       }
615     }
616 
copy_start()617     HloInstruction* copy_start() const { return copy_start_; }
copy_done()618     HloInstruction* copy_done() const { return copy_done_; }
619 
620     // Returns the time the buffer is first available to be used. For For
621     // CopyAllocation, this is when the copy ends, which is
622     // copy_done_schedule_before.
earliest_available_time()623     int64 earliest_available_time() const override {
624       return copy_done_schedule_before_;
625     }
626 
copy_start_schedule_after()627     int64 copy_start_schedule_after() const {
628       return copy_start_schedule_after_;
629     }
copy_done_schedule_before()630     int64 copy_done_schedule_before() const {
631       return copy_done_schedule_before_;
632     }
633 
set_copy_start_schedule_after(int64 copy_start_schedule_after)634     void set_copy_start_schedule_after(int64 copy_start_schedule_after) {
635       copy_start_schedule_after_ = copy_start_schedule_after;
636     }
637 
is_cross_program_prefetch()638     bool is_cross_program_prefetch() const {
639       return is_cross_program_prefetch_;
640     }
641 
    // Equality comparison; defined out of line.
    bool operator==(const CopyAllocation& other) const;
    // Returns a string describing this allocation.
    std::string ToString() const override;

   private:
    // The allocation that precedes this copy (presumably the copy's source —
    // confirm in the .cc). Held by reference, so the referenced Allocation must
    // outlive this object.
    const Allocation& prev_allocation_;
    // These variables define the scheduling boundaries where CopyStart and
    // CopyDone can be scheduled. The earliest CopyStart can be scheduled is
    // after copy_start_schedule_after_ and the latest CopyDone can be scheduled
    // is before copy_done_schedule_before_.
    int64 copy_start_schedule_after_;
    int64 copy_done_schedule_before_;
    // Exposed through is_cross_program_prefetch().
    bool is_cross_program_prefetch_;
    // The CopyStart/CopyDone instructions, exposed through copy_start() and
    // copy_done().
    HloInstruction* copy_start_;
    HloInstruction* copy_done_;
656   };
657 
  // An ordered list of allocations, each owned through a unique_ptr.
  using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;
659   // AllocationValue is used to break up HloValues for each non-trivial position
660   // (trivial positions are considered Tuple, GetTupleElement, and Bitcast). An
661   // HloValue may include positions and uses that alias with each other across
662   // multiple computations. We use this class to break these HloValues such that
663   // every AllocationValue has one defining position (that may alias with other
664   // AllocationValues). The uses field of the AllocationValue contains only the
665   // direct uses of the AllocationValue's defining position.
666   //
667   // For example, consider the following HLO snippet:
668   //
669   // Body {
670   //   body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
671   //   get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element(body_param),
672   //   index=0
673   //   ...
674   //   ROOT tuple = (f32[4,3]{1,0}, f32[]) tuple(get-tuple-element.3, ...)
675   // }
676   //
677   // Cond {
678   //   cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
679   //   ...
680   // }
681   //
682   // add.4 = f32[4,3]{1,0} add(...)
683   // tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(add.4, ...)
684   // while = (f32[4,3]{1,0}, f32[]) while(tuple.1), body=Body, condition=Cond
685   // get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element(while), index=0
686   // add.5 = f32[4,3]{1,0} add(get-tuple-element.5, ...)
687   //
688   // This contains an HloValue that looks like the following:
689   // positions:
690   //  add.4
691   //  body_param {0}
692   //  get-tuple-element.3
693   //  tuple {0}
694   //  cond_param {0}
695   //  tuple.1 {0}
696   //  while {0}
697   //  get-tuple-element.5
698   // uses:
699   //  add.1, operand 0
700   //  tuple, operand 0
701   //  while, operand 0 {0}
702   //  add.5, operand 0
703   //
704   // We break this HloValue up into the following AllocationValues for each
705   // non-trivial position:
706   // AllocationValue1: computation = Entry
707   //  position:
708   //   add.4
709   //  uses:
710   //   while, operand 0 {0}
711   // AllocationValue2: computation = Cond
712   //  position:
713   //   cond_param {0}
714   //  uses:
715   // AllocationValue3: computation = Body
716   //  position:
717   //   body_param {0}
718   //  uses:
719   //   add.1, operand 0
720   //   tuple, operand 0
721   // AllocationValue4: computation = Entry
722   //  position:
723   //   while {0}
724   //  uses:
725   //   add.5, operand 0
726   class AllocationValue {
727    public:
728     // This data structure wraps an HloUse and adds additional metadata that are
729     // useful for allocation.
730     struct Use {
731       // The wrapped HloUse object.
732       HloUse hlo_use;
733       // The logical time this use is scheduled.
734       int64 time;
735       // All the positions where this use aliases with. The aliased positions
736       // must get the same allocation.
737       std::vector<HloPosition> aliases;
738 
739       bool operator==(const Use& other) const {
740         return hlo_use == other.hlo_use && time == other.time &&
741                aliases == other.aliases;
742       }
743 
744       template <typename H>
AbslHashValueUse745       friend H AbslHashValue(H h, const Use& s) {
746         return H::combine(std::move(h), s.hlo_use, s.time, s.aliases);
747       }
748     };
749 
AllocationValue(const HloValue * value,const HloPosition & position,int64 size)750     AllocationValue(const HloValue* value, const HloPosition& position,
751                     int64 size)
752         : value_(value), defining_position_(position), size_(size) {}
753 
defining_position()754     const HloPosition& defining_position() const { return defining_position_; }
defining_instruction()755     const HloInstruction* defining_instruction() const {
756       return defining_position().instruction;
757     }
size()758     int64 size() const { return size_; }
uses()759     const std::vector<Use>& uses() const { return uses_; }
uses()760     std::vector<Use>& uses() { return uses_; }
value()761     const HloValue* value() const { return value_; }
computation()762     const HloComputation* computation() const {
763       return defining_instruction()->parent();
764     }
allocation_sequence()765     AllocationSequence* allocation_sequence() { return &allocation_sequence_; }
766 
AddUse(const HloUse & use,int64 use_time)767     void AddUse(const HloUse& use, int64 use_time) {
768       uses_.push_back({use, use_time, {}});
769     }
770 
771     std::string ToString() const;
772     std::string ToShortString() const;
773 
774    private:
775     const HloValue* value_;
776     HloPosition defining_position_;
777     int64 size_;
778     std::vector<Use> uses_;
779     AllocationSequence allocation_sequence_;
780   };
781 
782   // Statistics of asynchronous copies.
783   struct AsyncCopyStats {
784     int64 max_outstanding_async_copies;
785     int64 num_prefetches;
786     int64 prefetch_bytes;
787     int64 num_evictions;
788     int64 eviction_bytes;
789   };
790 
  // Virtual because RunMemorySpaceAssignment() and FindAllocationSequence()
  // are virtual and may be overridden by subclasses.
  virtual ~MemorySpaceAssignment() = default;

  // Runs the MemorySpaceAssignment pass on `module` and returns the resulting
  // PresetAssignments (the chunks and per-memory-space information that the
  // pass determined).
  static StatusOr<std::unique_ptr<PresetAssignments>> Run(
      HloModule* module, const HloLiveRange& hlo_live_range,
      const HloAliasAnalysis& alias_analysis, const Options& options);

  // Calculates asynchronous copy statistics.
  StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;

  // Returns a BufferIntervalCompare built from `cost_analysis`. Presumably it
  // orders buffers by how memory-bound they are (per the name) — TODO confirm
  // in the .cc. `cache` is optional and may be null.
  static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
      const MemorySpaceAssignmentCostAnalysis& cost_analysis,
      MemorySpaceAssignmentCostAnalysis::Cache* cache = nullptr);

  // Verify that the memory space assignment is free of overlapping buffers and
  // export heap simulator trace to be used by buffer_assignment.
  Status VerifyAndExportHeapSimulatorTrace();
808 
809  protected:
  // Main driver of the memory space assignment pass. Virtual so that
  // subclasses can customize the pass.
  virtual StatusOr<std::unique_ptr<PresetAssignments>> RunMemorySpaceAssignment(
      const HloLiveRange& hlo_live_range,
      const HloAliasAnalysis& alias_analysis);

  // Finds an AllocationSequence for placing buffers in alternate memory using
  // the AlternateMemoryBestFitHeap algorithm. Must be called before Process(),
  // which operates on the finalized allocations.
  virtual Status FindAllocationSequence(const HloLiveRange& hlo_live_range,
                                        const HloAliasAnalysis& alias_analysis);
820 
options()821   Options options() const { return options_; }
822 
MemorySpaceAssignment(HloModule * module,Options options,const HloLiveRange & hlo_live_range)823   MemorySpaceAssignment(HloModule* module, Options options,
824                         const HloLiveRange& hlo_live_range)
825       : module_(module),
826         options_(options),
827         flattened_instructions_(hlo_live_range.flattened_instruction_sequence()
828                                     .instructions()
829                                     .begin(),
830                                 hlo_live_range.flattened_instruction_sequence()
831                                     .instructions()
832                                     .end()),
833         computations_in_schedule_(),
834         preset_assignments_(absl::make_unique<PresetAssignments>()) {
835     for (const auto& computation_and_bound :
836          hlo_live_range.computation_span_times()) {
837       computations_in_schedule_.insert(computation_and_bound.first);
838     }
839   }
840 
  // The allocation sequence of this pass. Protected so that it is accessible
  // to subclasses.
  AllocationSequence allocations_;

  // Returns the module this pass was constructed for.
  HloModule* module() { return module_; }
844 
845  private:
  // Process calls the Process methods of the allocations once the allocations
  // have been finalized.
  Status Process();

  // Process() might have altered the computation graph by inserting kTuple and
  // kGetTupleElement instructions. SimplifyGraph performs a simple DCE and
  // tuple simplification operation (e.g., given GetTupleElement(Tuple(a, b),
  // 1), simply forwards b). Runs to fixed point.
  Status SimplifyGraph();

  // FixSchedule inserts asynchronous copies in the schedule, using the
  // schedule_after_ and schedule_before_ maps.
  Status FixSchedule();

  // Export the alternate memory assignments to the PresetAssignments and color
  // the HLO graph with the determined memory spaces.
  Status ExportAndColorBuffers();

  // Insert an instruction into the schedule, and make sure its dependencies
  // (operands) are already in the schedule. If not, insert these operands
  // before the instruction. `inserted_instructions` presumably tracks what has
  // already been added to `new_sequence` — confirm in the .cc.
  void EnsureInstructionAndOperandsInserted(
      HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
      absl::flat_hash_set<HloInstruction*>* inserted_instructions) const;

  // Schedules asynchronous copies and ensures that the CopyStarts and their
  // corresponding CopyDones follow the same order.
  void ScheduleAsynchronousCopies();

  // Remove the positions and chunks associated with the instruction from
  // alternate_memory_assignments_.
  void RemoveAssignmentForInstruction(const HloInstruction* instruction);
877 
  HloModule* module_;
  Options options_;
  // Copy of the live range's flattened instruction sequence, taken by the
  // constructor.
  std::vector<HloInstruction*> flattened_instructions_;
  // Computations listed in hlo_live_range.computation_span_times(), filled in
  // by the constructor.
  absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
  std::unique_ptr<PresetAssignments> preset_assignments_;
  // Positions placed in alternate memory together with their chunks;
  // RemoveAssignmentForInstruction() removes entries from here.
  std::vector<std::pair<HloPosition, Chunk>> alternate_memory_assignments_;
  int64 alternate_memory_size_ = 0;

  // These maps hold vectors of new instructions that need to be scheduled after
  // (or before) the instruction index in the key. FixSchedule uses these maps
  // to modify and fix the schedule.
  absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_after_;
  absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_before_;
891 };
892 
// A struct representing an asynchronous copy with its logical start and end
// time and its destination memory space. This is a plain aggregate, so the
// field order below is part of its brace-initialization interface.
struct AsynchronousCopy {
  int64 start_time;
  int64 end_time;
  MemorySpaceAssignment::MemorySpace destination;
};
900 
// Compare asynchronous copies such that an earlier start time has the same or
// earlier end time and an earlier end time has the same or earlier start time.
// This is the ordering std::set<AsynchronousCopy> uses in
// AsynchronousCopyOrdering. It is defined out of line.
bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b);
904 
// Helper class to enforce asynchronous copy ordering. We only allow
// asynchronous copies that are pipelined: if an asynchronous copy ends earlier
// than another asynchronous copy, it must start the same time or earlier than
// the other asynchronous copy; and if an asynchronous copy starts earlier than
// another asynchronous copy, it must end the same time or earlier than the
// other asynchronous copy.
class AsynchronousCopyOrdering {
 public:
  AsynchronousCopyOrdering() = default;

  // Adds an asynchronous copy.
  void AddCopy(const AsynchronousCopy& copy);

  // Removes an asynchronous copy. CHECKs that it is removed (i.e. that the
  // copy was present).
  void RemoveCopy(const AsynchronousCopy& copy);

  // If the addition of an asynchronous copy in the given time interval would
  // violate the asynchronous copy ordering, returns the violating
  // already-committed asynchronous copy; otherwise returns an empty optional.
  // E.g., consider the following scenario:
  //                                  CS          CD
  //  already committed async copy:   +-----------+
  //                new async copy:     +--------+
  //
  // The new asynchronous copy would violate the ordering guarantee because the
  // copy start is after an already committed asynchronous copy while its copy
  // done is before the committed copy.
  absl::optional<AsynchronousCopy> ViolatesOrdering(int64 start_time,
                                                    int64 end_time) const;

 private:
  // Stores asynchronous copies in a tree set respecting the pipelining order
  // (the set is ordered by operator< on AsynchronousCopy).
  std::set<AsynchronousCopy> ranges_;
};
938 
939 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
940 // maximum size.
941 class AlternateMemoryBestFitHeap
942     : public GlobalDecreasingSizeBestFitHeap<HloValue> {
943  public:
944   using MemorySpace = MemorySpaceAssignment::MemorySpace;
945   using AllocationValue = MemorySpaceAssignment::AllocationValue;
946 
AlternateMemoryBestFitHeap(MemorySpaceAssignment::AllocationSequence * allocations,const MemorySpaceAssignment::Options & options,const HloAliasAnalysis & alias_analysis,const HloLiveRange & hlo_live_range)947   AlternateMemoryBestFitHeap(
948       MemorySpaceAssignment::AllocationSequence* allocations,
949       const MemorySpaceAssignment::Options& options,
950       const HloAliasAnalysis& alias_analysis,
951       const HloLiveRange& hlo_live_range)
952       : GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes),
953         allocations_(allocations),
954         options_(options),
955         alias_analysis_(alias_analysis),
956         hlo_live_range_(hlo_live_range) {
957     // Override buffer interval compare if provided.
958     if (options.buffer_interval_compare) {
959       buffer_interval_compare_ = *options.buffer_interval_compare;
960     }
961   }
962 
  // Allocates a buffer in preferred memory (presumably the alternate memory
  // space — confirm) with whole program lifetime and enables prefetching
  // prefetch_candidate from default memory across program boundaries.
  void AllocateCrossProgramPrefetchBuffer(
      HloModule* module, absl::optional<BufferInterval> prefetch_candidate);

  // Overrides the base heap's Finish() and returns the heap simulator result.
  HeapSimulator::Result<HloValue> Finish() override;
970 
971  protected:
  // Given a buffer interval, returns the colocated intervals. Unlike the
  // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
  // returns the colocated intervals sorted by scheduled time.
  std::vector<const BufferInterval*> GetSortedColocatedIntervals(
      const BufferInterval& interval) const;

  // Given a BufferInterval, creates AllocationValue objects (each with its own,
  // initially empty, AllocationSequence) and appends them to
  // `allocation_values`.
  void CreateAllocationValues(
      const BufferInterval& buffer_interval,
      std::vector<AllocationValue>& allocation_values) const;

  // Given colocated intervals, populates allocation_values with the
  // corresponding AllocationValue objects. Virtual so subclasses can override
  // it.
  virtual void CreateAllocationValuesFromColocatedIntervals(
      absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
          colocated_intervals,
      std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values);

  // Go through all the uses in the AllocationValues and find the aliasing
  // positions (stored in each Use's `aliases`). The exact effect of
  // skip_values_with_no_uses is defined in the .cc — TODO document.
  void FindAliases(std::vector<AllocationValue>* allocation_values,
                   bool skip_values_with_no_uses) const;
995 
allocations()996   MemorySpaceAssignment::AllocationSequence* allocations() {
997     return allocations_;
998   }
options()999   const MemorySpaceAssignment::Options& options() { return options_; }
alias_analysis()1000   const HloAliasAnalysis& alias_analysis() { return alias_analysis_; }
hlo_live_range()1001   const HloLiveRange& hlo_live_range() { return hlo_live_range_; }
1002 
1003  private:
1004   // We inherit AllocationBlock struct to attach the Allocation information to
1005   // make importing repacked offsets easier.
1006   struct RepackAllocationBlock
1007       : MemorySpaceAssignmentRepacker::AllocationBlock {
1008     MemorySpaceAssignment::Allocation* allocation;
1009   };
1010 
  // A data structure we use to associate Allocation objects that are aliased
  // and must get the same offset.
  struct AliasedOffset {
    // The offset shared by every allocation in `allocations`.
    int64 offset;
    // Pointers to the allocations that share `offset`.
    absl::flat_hash_set<const MemorySpaceAssignment::Allocation*> allocations;
  };
1017 
  // An allocation request for a use segment. A use segment is the time segment
  // between the definition and the first use, and the time segment between the
  // uses of a buffer. For example, the time between the definition and Use1, is
  // the first segment, and the time between Use1 and Use2 is the second segment
  // and so on:
  //
  //     +-----------+----------+-------+
  //     |           v          v       v
  //    Def         Use1       Use2    Use3
  //     <----------> <--------> <----->
  //        Segment    Segment   Segment
  //
  // Fields:
  //  - start_time and end_time are the start and end logical times of the
  //    segment.
  //  - latest_prefetch_time is the latest time we can schedule the CopyDone
  //    for a prefetch.
  //  - size is the requested allocation size (presumably the AllocationValue's
  //    size — confirm at the call sites).
  //  - If allow_no_copy_alternate_mem_allocation is false, an eviction is
  //    forced.
  //  - If earliest_prefetch_time is set, prefetches cannot start before this
  //    value.
  //  - preferred_offset, if non-null, is the AliasedOffset whose offset the
  //    allocation should use.
  //  - use is the AllocationValue use associated with this segment
  //    (presumably the use that ends it — TODO confirm).
  //  - allocation_value is the AllocationValue the segment belongs to.
  struct AllocationRequest {
    int64 start_time;
    int64 end_time;
    int64 latest_prefetch_time;
    int64 size;
    bool allow_no_copy_alternate_mem_allocation;
    absl::optional<int64> earliest_prefetch_time;
    AliasedOffset* preferred_offset;
    const MemorySpaceAssignment::AllocationValue::Use* use;
    MemorySpaceAssignment::AllocationValue* allocation_value;
  };
1049 
1050   // This struct contains mandatory memory assignments at a given time. E.g., an
1051   // input's required memory assignment time would correspond to the definition
1052   // time of the parameter instruction, and an output's time would correspond to
1053   // the time of last use.
1054   struct RequiredMemoryAssignment {
1055     MemorySpaceAssignment::MemorySpace memory_space;
1056     int64 time;
1057     AliasedOffset* offset;
1058 
equals_ignoring_timeRequiredMemoryAssignment1059     bool equals_ignoring_time(const RequiredMemoryAssignment& other) const {
1060       return memory_space == other.memory_space && offset == other.offset;
1061     }
1062 
1063     bool operator==(const RequiredMemoryAssignment& other) const {
1064       return memory_space == other.memory_space && time == other.time &&
1065              offset == other.offset;
1066     }
1067 
1068     bool operator!=(const RequiredMemoryAssignment& other) const {
1069       return !(*this == other);
1070     }
1071   };
1072 
1073   // Result of an allocation, prefetch, eviction etc. request.  The result is
1074   // either kSuccess or a bitwise OR of one or more failures. The values are
1075   // unique powers of two. To check if a result contains a particular failure,
1076   // use the result_is method. To add a new failure to a result, use the
1077   // result_mark method.
1078   enum class Result {
1079     // Successful allocation.
1080     kSuccess = 0,
1081     // Allocation failed because we ran out of alternate memory.
1082     kFailOutOfMemory = 1,
1083     // A no-copy allocation couldn't be performed because the previous
1084     // allocation wasn't in the alternate memory space.
1085     kFailPrevAllocationNotInAlternateMem = 2,
1086     // A no-copy allocation couldn't be performed because the live range was too
1087     // long.
1088     kFailLiveRangeTooLong = 4,
1089     // A prefetching couldn't be performed because the live range was too short.
1090     kFailLiveRangeTooShort = 8,
1091     // Ran out of outstanding asynchronous copy limit either during prefetching
1092     // or eviction.
1093     kFailOutOfAsyncCopies = 16,
1094     // A prefetching couldn't be performed because the asynchronous copy
1095     // ordering was violated.
1096     kFailViolatesAsyncCopyOrdering = 32,
1097     // An allocation failure happened that requires uncommitting all the pending
1098     // allocations. Usually this is due to a situation requiring an eviction but
1099     // the eviction couldn't be performed.
1100     kFailRequiresUncommit = 64
1101   };
1102 
1103   // Return true if the result belongs to a failure.
result_is(Result result,Result failure)1104   static bool result_is(Result result, Result failure) {
1105     return static_cast<int>(result) & static_cast<int>(failure);
1106   }
1107 
1108   // Mark (bitwise OR) a failure to the result.
result_mark(Result failure,Result & result)1109   static Result result_mark(Result failure, Result& result) {
1110     result = static_cast<Result>(static_cast<int>(result) |
1111                                  static_cast<int>(failure));
1112     return result;
1113   }
1114 
1115   // Return true if the result is a failure that requires us to uncommit pending
1116   // chunks.
result_requires_uncommit(Result result)1117   static bool result_requires_uncommit(Result result) {
1118     return result_is(result, Result::kFailRequiresUncommit);
1119   }
1120 
1121   // Return true if the result is a failure either due to running out of
1122   // outstanding asynchronous copies or due to violating asynchronous copy
1123   // ordering.
result_failed_because_of_async_copy(Result result)1124   static bool result_failed_because_of_async_copy(Result result) {
1125     return result_is(result, Result::kFailOutOfAsyncCopies) ||
1126            result_is(result, Result::kFailViolatesAsyncCopyOrdering);
1127   }
1128 
  // Returns the AliasedOffset object associated with the allocation.
  AliasedOffset* GetAliasedOffset(
      const MemorySpaceAssignment::Allocation& allocation);

  // If aliased_offset is non-null, this method adds the allocation to
  // aliased_offset. Otherwise, it creates a new AliasedOffset object and adds
  // the allocation to this new AliasedOffset.
  void CreateOrAddToAliasedOffset(
      const MemorySpaceAssignment::Allocation& allocation,
      AliasedOffset* aliased_offset);

  // Given an allocation sequence, returns the live allocation at time with a
  // preference towards allocations in alternate memory. Returns nullptr if no
  // allocation is alive at that time.
  static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
      const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);

  // Returns true if the use is allowed in the alternate memory.
  bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
                                     const HloUse& use) const;

  // Finds allocations for allocation values generated from colocated intervals.
  // All of the allocation values have a must-alias relationship with each
  // other. Returns either kSuccess if all of the sites could be placed in the
  // alternate memory or a bitwise OR of failure reasons why they couldn't be.
  Result AllocateAllocationValues(
      absl::Span<AllocationValue> allocation_values);
1156 
  // Finds an allocation for an allocation request for a segment (see the
  // documentation for AllocationRequest above for how a segment is defined).
  //
  // It performs three things in the following order:
  //  1- Allocate the allocation request entirely in the alternate memory, if
  //     there is enough space and if the prefetch interval picker allows.
  //  2- If (1) was unsuccessful, and the only allocation for this buffer was
  //     in the alternate memory, we try to perform an eviction.
  //  3- If (1) was unsuccessful, prefetch the buffer into the alternate memory,
  //     if there is enough space and if the prefetch interval picker allows.
  //
  // If an eviction (2) was requested and was unsuccessful, this method returns
  // Result::kFailRequiresUncommit. This means we could not find a suitable
  // allocation, so all previous allocations for this buffer must be removed and
  // allocated in the default memory. Otherwise, this method may return
  // Result::kSuccess if the buffer could be placed in alternate memory or some
  // other Result with an OR of reasons why the buffer couldn't be placed in
  // alternate memory.
  Result AllocateSegment(const AllocationRequest& request);

  // Try allocating in alternate memory without any copies.
  Result AllocateInAlternateMemoryNoCopy(const AllocationRequest& request);

  // Try evicting to default memory space.
  Result Evict(const AllocationRequest& request);

  // Returns the time a copy done of a prefetch should be scheduled.
  int64 FindPrefetchEndTime(const AllocationRequest& request,
                            int64 earliest_prefetch_time) const;

  // Try prefetching to alternate memory space from
  // prev_allocation_in_default_mem.
  Result Prefetch(
      const AllocationRequest& request,
      const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem);
1191 
  // Find the best possible chunk candidate. If no preferred offset is given,
  // this is the candidate with the longest possible availability; if one is
  // given, the candidate is at preferred_offset. Returns an empty optional if
  // no candidate is found.
  absl::optional<ChunkCandidate> FindBestChunkCandidate(
      const AllocationRequest& request, const AliasedOffset* preferred_offset,
      BufferInterval* alternate_mem_interval) const;

  // Returns the required assignment at a particular time, if available.
  absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
      const HloValue* buffer, int64 time) const;

  // Searches for aliases in the use for a required assignment, and returns it
  // if found.
  absl::optional<RequiredMemoryAssignment> AliasedRequiredAssignmentForUse(
      const AllocationValue::Use& use) const;

  // Goes through the colocated intervals and adds any required assignment.
  void AddRequiredAssignmentsForColocatedIntervals(
      absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
          colocated_intervals);

  // Propagates aliased required assignment for a given position.
  void AddAliasedRequiredAssignment(
      const HloInstruction* instruction, ShapeIndex index,
      const MemorySpaceAssignment::Allocation* aliased_allocation);

  // This sets a required assignment. CHECK fails if there is a conflicting
  // required assignment at the same time. The second overload takes a
  // position (instruction + index) instead of a value and time.
  void AddRequiredAssignment(const HloValue* value,
                             const HloInstruction* instruction,
                             MemorySpace memory_space, int64 time,
                             AliasedOffset* offset = nullptr);
  void AddRequiredAssignment(const HloInstruction* instruction,
                             ShapeIndex index, MemorySpace memory_space,
                             AliasedOffset* offset = nullptr);

  // Adds inputs and outputs as required assignments.
  void AddInputAndOutputRequiredAssignments();

  // Returns true if the colocated intervals in the argument are in a parameter
  // or root instruction of the entry computation and are reserved by the user
  // to be in the alternate memory space.
  bool AreIntervalsReservedInAlternateMemory(
      absl::Span<const BufferInterval* const> colocated_intervals) const;
1236 
1237   // Since the allocations are recorded to the AllocationSequence, we don't
1238   // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
1239   // to avoid unnecessarily adding the chunk to the chunk map.
AddToChunkMap(const HloValue * buffer,Chunk chunk)1240   void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}
1241 
  // Returns true if the addition of an asynchronous copy in the given time
  // interval would violate the maximum number of asynchronous copies. An extra
  // async copy limit can be provided to increase the limit of asynchronous
  // copies for this instance.
  bool ViolatesMaximumOutstandingAsyncCopies(
      int64 start_time, int64 end_time, bool is_prefetch,
      int64 extra_async_copy_limit = 0) const;

  // If the asynchronous copy would violate the pipelining order, returns the
  // violating asynchronous copy.
  absl::optional<AsynchronousCopy> ViolatesAsyncCopyOrdering(
      int64 start_time, int64 end_time) const;

  // Exports the allocations for repacking and appends them to `allocations`.
  void ExportAllocationsForRepacking(
      std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>&
          allocations);

  // Imports repacked allocations and updates the internal data structures
  // consistent with the new packing.
  void ImportRepackedAllocations();

  // Adds an asynchronous copy to the allocations.
  void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
                    MemorySpace memory_space, absl::optional<Chunk> chunk,
                    int64 start_time, int64 end_time,
                    int64 copy_done_schedule_before_time,
                    MemorySpaceAssignment::AllocationSequence* allocations,
                    AliasedOffset* aliased_offset,
                    bool is_cross_program_prefetch = false);

  // Commits the chunk candidate and also records it in pending_chunks_, so
  // that we can "uncommit" it if we need to roll back this allocation
  // sequence.
  void AddToPendingChunks(const BufferInterval& buffer_interval,
                          const ChunkCandidate& chunk_candidate);
  // If we need to remove the allocations for this allocation sequence, this
  // removes pending chunks and asynchronous copies in the respective pending
  // buffers from the interval trees. If an allocation request returns
  // kFailRequiresUncommit, this method must be called.
  void UncommitPendingChunks(absl::Span<AllocationValue> allocation_values);

  // Finalizes the allocations where they can no longer be uncommitted.
  void FinalizeAllocations(absl::Span<AllocationValue> allocation_values);

  // Clears all pending chunks and asynchronous copies.
  void ClearPendingChunks();

  // Append buffer and allocation infos for debugging and dump it into a file,
  // if enabled.
  void AppendBufferInfoDebugString(const BufferInterval& interval,
                                   std::string* debug_str) const;
  void AppendAllocationInfoDebugString(
      const AllocationValue& value,
      const MemorySpaceAssignment::Allocation& allocation,
      std::string& debug_str) const;
  void DumpDebugStringsIfEnabled() const;
1300 
1301   // Returns the available heap size in the alternate memory.
available_heap_size()1302   int64 available_heap_size() const {
1303     return options_.max_size_in_bytes - reserved_in_bytes_;
1304   }
1305 
1306   // Creates and returns a RepackAllocationBlock.
MakeRepackAllocationBlock(int64 start_time,int64 end_time,int64 size,int64 initial_offset,int64 id,MemorySpaceAssignment::Allocation * allocation)1307   static RepackAllocationBlock MakeRepackAllocationBlock(
1308       int64 start_time, int64 end_time, int64 size, int64 initial_offset,
1309       int64 id, MemorySpaceAssignment::Allocation* allocation) {
1310     RepackAllocationBlock allocation_block;
1311     allocation_block.start_time = start_time;
1312     allocation_block.end_time = end_time;
1313     allocation_block.size = size;
1314     allocation_block.offset = -1;
1315     allocation_block.initial_offset = initial_offset;
1316     allocation_block.id = id;
1317     allocation_block.colocations = {};
1318     allocation_block.allocation = allocation;
1319     return allocation_block;
1320   }
1321 
  // Non-owning pointer to the allocation sequence this heap fills in
  // (presumably owned by the caller that constructs the heap — confirm).
  MemorySpaceAssignment::AllocationSequence* allocations_;
  // The following are held by reference, so the referenced objects must
  // outlive this heap.
  const MemorySpaceAssignment::Options& options_;
  const HloAliasAnalysis& alias_analysis_;
  const HloLiveRange& hlo_live_range_;
  // We use an interval tree to keep track of the number of outstanding
  // prefetches and evictions.
  BufferIntervalTree prefetch_interval_tree_;
  BufferIntervalTree eviction_interval_tree_;
  // Ordering bookkeeping for asynchronous copies (semantics are defined by
  // AsynchronousCopyOrdering, which is not visible here).
  AsynchronousCopyOrdering async_copy_ordering_;
  // A list of RepackAllocationBlock objects that mirrors allocation sequences,
  // used for repacking. We use a list here because we need pointer stability
  // for aliased allocations.
  std::list<RepackAllocationBlock> repack_allocation_blocks_;
  // Presumably the number of repacks performed so far — confirm in the .cc.
  int64 num_repacks_ = 0;
  // Chunks added via AddToPendingChunks for the current allocation sequence,
  // kept so they can be uncommitted if the sequence has to be rolled back.
  std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
  // Asynchronous copies of the current allocation sequence that
  // UncommitPendingChunks may need to remove from the interval trees.
  std::vector<AsynchronousCopy> pending_async_copies_;
  // Required assignments added during the current allocation sequence
  // (presumably rolled back together with pending_chunks_ — confirm).
  std::vector<std::pair<const HloValue*, RequiredMemoryAssignment>>
      pending_required_assignments_;
  // Storage for AliasedOffset objects plus an Allocation -> AliasedOffset map
  // for efficient lookup. std::list keeps element addresses stable, which the
  // raw AliasedOffset* values in the map presumably rely on.
  std::list<AliasedOffset> aliased_offsets_;
  absl::flat_hash_map<const MemorySpaceAssignment::Allocation*, AliasedOffset*>
      aliased_offset_map_;
  // This map contains required memory assignments for HloValues (e.g., inputs
  // and outputs).
  absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
      required_assignments_;
  // Number of bytes reserved in the alternate memory space; subtracted from
  // options_.max_size_in_bytes by available_heap_size().
  int64 reserved_in_bytes_ = 0;
  // Debug strings, presumably accumulated by the Append*DebugString helpers
  // and emitted by DumpDebugStringsIfEnabled() — confirm.
  std::string buffer_info_str_;
  std::string allocation_info_str_;
1354 };
1355 
1356 }  // namespace xla
1357 
1358 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
1359