/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"
26 
27 namespace xla {
28 
29 namespace memory_space_assignment {
30 // Forward Declaration of Options.
31 class Options;
32 
33 // This class contains pre-set assignments determined by memory space
34 // assignment. It contains two data structures: (1) a chunks vector that maps a
35 // defining HloPosition to a Chunk (offset and size), and (2) an assignment_info
36 // vector that maps the memory space to information like its allocated size and
37 // heap memory trace. If there is only one alternate memory space like there is
38 // currently, there will be one entry in assignment_info.
39 class PresetAssignments {
40  public:
41   // Contains per-memory-space information like the allocated size and heap
42   // simulator trace.
43   struct AssignmentInformation {
44     int64 size;
45     HeapSimulatorTrace heap_simulator_trace;
46   };
47 
48   PresetAssignments() = default;
49 
add_chunk(const HloPosition & position,const HeapSimulator::Chunk & chunk)50   void add_chunk(const HloPosition& position,
51                  const HeapSimulator::Chunk& chunk) {
52     chunks_.emplace_back(position, chunk);
53   }
54 
add_scoped_allocation_chunk(HloInstruction * instruction,const HeapSimulator::Chunk & chunk)55   void add_scoped_allocation_chunk(HloInstruction* instruction,
56                                    const HeapSimulator::Chunk& chunk) {
57     scoped_allocation_chunks_.emplace_back(instruction, chunk);
58   }
59 
assignment_information_for_space(int64_t memory_space)60   AssignmentInformation* assignment_information_for_space(
61       int64_t memory_space) {
62     for (auto& space_and_info : assignment_info_) {
63       if (space_and_info.first == memory_space) {
64         return &space_and_info.second;
65       }
66     }
67     assignment_info_.emplace_back(memory_space, AssignmentInformation());
68     return &assignment_info_.back().second;
69   }
70 
chunks()71   absl::Span<const std::pair<HloPosition, HeapSimulator::Chunk>> chunks()
72       const {
73     return chunks_;
74   }
75 
76   absl::Span<const std::pair<HloInstruction*, HeapSimulator::Chunk>>
scoped_allocation_chunks()77   scoped_allocation_chunks() const {
78     return scoped_allocation_chunks_;
79   }
80 
81   absl::Span<const std::pair<int64, AssignmentInformation>>
assignment_informations()82   assignment_informations() const {
83     return assignment_info_;
84   }
85 
86   // Get debugging information.
buffer_info_str()87   std::string buffer_info_str() const { return buffer_info_str_; }
allocation_info_str()88   std::string allocation_info_str() const { return allocation_info_str_; }
89 
90  private:
91   std::vector<std::pair<HloPosition, HeapSimulator::Chunk>> chunks_;
92   std::vector<std::pair<HloInstruction*, HeapSimulator::Chunk>>
93       scoped_allocation_chunks_;
94   std::vector<std::pair<int64, AssignmentInformation>> assignment_info_;
95   std::string buffer_info_str_;
96   std::string allocation_info_str_;
97 };
98 
99 // A wrapper class around HloCostAnalysis with additional knowledge about the
100 // bandwidths of different memory spaces.
101 class MemorySpaceAssignmentCostAnalysis {
102  public:
103   // An optional Cache object may be provided to some of the methods below to
104   // speed up the lookup.
105   struct Cache {
106     absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
107   };
108 
109   virtual ~MemorySpaceAssignmentCostAnalysis() = default;
110 
111   static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
112       const HloCostAnalysis& cost_analysis, const Options& options,
113       const HloModule& module);
114 
cost_analysis()115   const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
116 
117   // Returns a heuristic value that captures how much putting this tensor to the
118   // alternate memory would help if the op is memory bound, or otherwise how far
119   // off is the op to memory boundedness. The larger this number, the higher
120   // priority it will be placed in the alternate memory.
121   float GetAlternateMemoryBenefit(const HloInstruction& instruction,
122                                   float elapsed_time_due_to_alternate_mem,
123                                   Cache* cache = nullptr) const;
124 
125   // Returns a heuristic value of memory boundedness for the given
126   // BufferInterval.  The larger this number, the higher priority it will be
127   // placed in the alternate memory.
128   float GetMemoryBoundedness(
129       const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
130       Cache* cache = nullptr) const;
131 
132   // Returns the elapsed time in seconds due to compute only.
133   float GetInstructionElapsedDueToCompute(
134       const HloInstruction& instruction) const;
135 
136   // Returns the elapsed time in seconds due to memory only. If
137   // operands_in_alternate_mem or outputs_in_alternate_mem is provided, it will
138   // assume that the corresponding operands or output will be in the alternate
139   // memory space. This is useful for calculating the benefit of placing the
140   // buffer in alternate memory.
141   float GetInstructionElapsedDueToMemory(
142       const HloInstruction& instruction,
143       absl::Span<const std::pair<int64_t, ShapeIndex>>
144           operands_in_alternate_mem = {},
145       absl::Span<const ShapeIndex> outputs_in_alternate_mem = {}) const;
146 
147   // Returns the estimated elapsed duration of the instruction in seconds.  It
148   // assumes all operands and outputs of the instruction are in the default
149   // memory.
150   virtual float GetInstructionElapsed(const HloInstruction& instruction) const;
151 
152   // Returns the estimated elapsed duration of the instruction in seconds.  It
153   // assumes all operands and outputs of the instruction are in the default
154   // memory, except for the operands and outputs specified to be in the
155   // alternate memory.
156   virtual float GetInstructionElapsedInAlternateMemory(
157       const HloInstruction& instruction,
158       absl::Span<const std::pair<int64_t, ShapeIndex>>
159           operands_in_alternate_mem,
160       absl::Span<const ShapeIndex> outputs_in_alternate_mem) const;
161 
162   // Returns the elapsed time it would take to asynchronously copy the shape
163   // from default to alternate memory space (or vice versa).
164   virtual float GetAsyncCopyElapsed(const Shape& shape) const;
165 
166   int64 GetScheduleEndTime() const;
167 
168   // Returns the number of nested computation levels this instruction resides
169   // in. If while_only is true, it returns the while loop nest level and 0
170   // means the instruction is not in a while loop.
171   int CalculateComputationNestLevel(const HloInstruction* instruction,
172                                     bool while_only) const;
173 
hlo_live_range()174   const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
options()175   const Options& options() const { return options_; }
176 
177  protected:
MemorySpaceAssignmentCostAnalysis(const HloCostAnalysis & cost_analysis,const Options & options,std::unique_ptr<HloAliasAnalysis> alias_analysis,std::unique_ptr<HloLiveRange> hlo_live_range,std::unique_ptr<CallGraph> call_graph)178   MemorySpaceAssignmentCostAnalysis(
179       const HloCostAnalysis& cost_analysis, const Options& options,
180       std::unique_ptr<HloAliasAnalysis> alias_analysis,
181       std::unique_ptr<HloLiveRange> hlo_live_range,
182       std::unique_ptr<CallGraph> call_graph)
183       : cost_analysis_(cost_analysis),
184         options_(options),
185         alias_analysis_(std::move(alias_analysis)),
186         hlo_live_range_(std::move(hlo_live_range)),
187         call_graph_(std::move(call_graph)) {}
188 
189  private:
190   const HloCostAnalysis& cost_analysis_;
191   const Options& options_;
192   std::unique_ptr<HloAliasAnalysis> alias_analysis_;
193   std::unique_ptr<HloLiveRange> hlo_live_range_;
194   std::unique_ptr<CallGraph> call_graph_;
195 };
196 
197 // Abstract base class that memory space assignment uses to pick prefetch
198 // intervals.
199 class PrefetchIntervalPicker {
200  public:
201   PrefetchIntervalPicker() = default;
202   virtual ~PrefetchIntervalPicker() = default;
203 
204   // Returns true if the buffer can be allocated in alternate memory space
205   // without any copies (prefetches).
206   virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
207                                                   int64_t start_time,
208                                                   int64_t end_time) const = 0;
209 
210   // Returns the preferred end time for an eviction that starts at a given time
211   // and must end by the given end time.
212   virtual int64 PreferredEvictionEndTime(const Shape& shape, int64_t start_time,
213                                          int64_t latest_end_time) const = 0;
214 
215   // Returns the latest time that a prefetch can start.
216   virtual int64 LatestPrefetchStartTime(const Shape& shape, int64_t start_time,
217                                         int64_t end_time,
218                                         const HloUse* use) const = 0;
219 
220   // Returns the preferred time that a prefetch can start.
221   virtual int64 PreferredPrefetchStartTime(const Shape& shape,
222                                            int64_t earliest_prefetch_start_time,
223                                            int64_t latest_prefetch_start_time,
224                                            int64_t prefetch_end_time) const = 0;
225 
226   // Returns the latest time that a prefetch can end that is less than or equal
227   // to proposed_prefetch_end_time.
LatestPrefetchEndTime(int64_t original_prefetch_end_time,int64_t proposed_prefetch_end_time)228   virtual int64 LatestPrefetchEndTime(
229       int64_t original_prefetch_end_time,
230       int64_t proposed_prefetch_end_time) const {
231     return proposed_prefetch_end_time;
232   }
233 
234   // Begins the iterator for the first start time of the prefetch.
235   virtual void Begin(const HloUse& use, int64_t start_time,
236                      int64_t end_time) = 0;
237 
238   // Advances the start time of the prefetch and returns that value.
239   virtual int64 Next() = 0;
240 
241   // Returns true if the available prefetch intervals have been exhausted.
242   virtual bool Done() const = 0;
243 
244   // The retry number can be used to modify the interval picking policies. The
245   // first attempt will have a retry_number of 0, then 1, etc.
SetRetryNumber(int retry_number)246   virtual void SetRetryNumber(int retry_number) {}
247 
248   // Returns a debug string for the current state of the prefetch interval
249   // picker.
250   virtual std::string ToDebugString() const = 0;
251 
252   // Returns a debug string for no-copy allocation.
253   virtual std::string ToNoCopyDebugString(const Shape& shape,
254                                           int64_t start_time,
255                                           int64_t end_time) const = 0;
256 
257   // Prefetch interval pickers may return a value corresponding to the benefit
258   // of placing the BufferInterval in the alternate memory. The larger value,
259   // the more beneficial.
BufferIntervalAlternateMemoryBenefit(const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval & interval)260   virtual absl::optional<float> BufferIntervalAlternateMemoryBenefit(
261       const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
262       const {
263     return absl::nullopt;
264   }
265 
266  protected:
267   const absl::flat_hash_map<const HloInstruction*, int64>*
268       instruction_schedule_ = nullptr;
269 };
270 
271 // Prefetch interval picker that uses instruction count to overlap asynchronous
272 // copies with independent computation. The min and max overlap counts describe
273 // the number of independent HLOs overlapped while a value is being prefetched
274 // into the alternate memory (between CopyStart and CopyDone HLO instructions).
275 // max_overlap_count attempts to prevent bringing tensors into the alternate
276 // memory too eagerly and hence occupying the space for other tensors which
277 // might use it.  min_overlap_count attempts to prevent cases where tensors are
278 // prefetched into the alternate memory without sufficient time for the copy to
279 // take place.  In those cases, it's just better to keep the tensor in the
280 // default memory instead of hurting the critical path with this copy that
281 // likely won't finish in time.
282 class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker {
283  public:
InstructionCountPrefetchIntervalPicker(int64_t min_overlap_count,int64_t max_overlap_count)284   InstructionCountPrefetchIntervalPicker(int64_t min_overlap_count,
285                                          int64_t max_overlap_count)
286       : min_overlap_count_(min_overlap_count),
287         max_overlap_count_(max_overlap_count) {}
288 
289   bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
290                                           int64_t start_time,
291                                           int64_t end_time) const override;
292 
293   int64 PreferredEvictionEndTime(const Shape& shape, int64_t start_time,
294                                  int64_t latest_end_time) const override;
295 
296   int64 LatestPrefetchStartTime(const Shape& shape, int64_t start_time,
297                                 int64_t end_time,
298                                 const HloUse* use) const override;
299 
300   int64 PreferredPrefetchStartTime(const Shape& shape,
301                                    int64_t earliest_prefetch_start_time,
302                                    int64_t latest_prefetch_start_time,
303                                    int64_t prefetch_end_time) const override;
304 
305   void Begin(const HloUse& use, int64_t start_time, int64_t end_time) override;
306 
307   int64 Next() override;
308   bool Done() const override;
309 
310   std::string ToDebugString() const override;
311   std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time,
312                                   int64_t end_time) const override;
313 
314  private:
315   int64 min_overlap_count_;
316   int64 max_overlap_count_;
317   int64 end_time_;
318   int64 current_prefetch_time_;
319 };
320 
321 // Forward Declaration of MemorySpaceAssignmentCostAnalysis
322 class MemorySpaceAssignmentCostAnalysis;
323 // Prefetch interval picker that uses cost analysis to overlap asynchronous
324 // copies with independent computation. It uses min/max (asynchronous copy
325 // duration) / (independent computation duration) ratios to guide whether the
326 // prefetch is within those bounds. It starts with the preferred ratio in
327 // Begin() and works its way for alternately earlier and later prefetches until
328 // hitting min and max ratios. The value for buffer size for max async copy is a
329 // mechanism to prevent copying small buffers between the two memories
330 // unnecessarily. For calculating the max time that the buffer can reside in
331 // alternate memory, we use the larger of this value and the actual size of the
332 // buffer.
333 class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
334  public:
335   CostAnalysisPrefetchIntervalPicker(
336       const MemorySpaceAssignmentCostAnalysis& cost_analysis,
337       float min_async_copy_to_overlap_ratio,
338       float max_async_copy_to_overlap_ratio,
339       float preferred_async_copy_to_overlap_ratio,
340       int64_t buffer_size_for_max_async_copy);
341 
342   bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
343                                           int64_t start_time,
344                                           int64_t end_time) const override;
345 
346   int64 PreferredEvictionEndTime(const Shape& shape, int64_t start_time,
347                                  int64_t latest_end_time) const override;
348 
349   int64 LatestPrefetchEndTime(
350       int64_t original_prefetch_end_time,
351       int64_t proposed_prefetch_end_time) const override;
352 
353   int64 LatestPrefetchStartTime(const Shape& shape, int64_t start_time,
354                                 int64_t end_time,
355                                 const HloUse* use) const override;
356 
357   int64 PreferredPrefetchStartTime(const Shape& shape,
358                                    int64_t earliest_prefetch_start_time,
359                                    int64_t latest_prefetch_start_time,
360                                    int64_t prefetch_end_time) const override;
361 
362   void Begin(const HloUse& use, int64_t start_time, int64_t end_time) override;
363 
364   int64 Next() override;
365   bool Done() const override;
366 
367   void SetRetryNumber(int retry_number) override;
368 
369   std::string ToDebugString() const override;
370   std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time,
371                                   int64_t end_time) const override;
372 
373   absl::optional<float> BufferIntervalAlternateMemoryBenefit(
374       const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
375       const override;
376 
377  private:
378   // Returns the elapsed time in seconds between the logical interval that
379   // corresponds to the instruction schedule.
380   float GetLogicalIntervalElapsed(int64_t start_time, int64_t end_time) const;
381 
382   // Finds the minimum nest level in the given interval.
383   int GetMinWhileNestLevel(int64_t start_time, int64_t end_time) const;
384 
385   // Given the elapsed time to copy this buffer to the alternate memory, returns
386   // the longest time that this buffer may reside in the alternate memory space.
387   float GetMaxElapsedInAlternateMemory(float async_copy_elapsed) const;
388 
389   // For each instruction in the flattened schedule, maintain their elapsed time
390   // (in cumulative sum) and while nesting level.
391   std::vector<float> elapsed_time_cumsum_;
392   std::vector<int> while_nest_level_;
393   std::vector<int> computation_nest_level_;
394   // Maintain the index of the most recent (before this instruction) nest level
395   // change in order to efficiently determine the minimum nest level in an
396   // interval.
397   std::vector<int> while_nest_level_change_;
398 
399   const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
400   float min_async_copy_to_overlap_ratio_;
401   float max_async_copy_to_overlap_ratio_;
402   float preferred_async_copy_to_overlap_ratio_;
403   int64_t buffer_size_for_max_async_copy_;
404   float max_overlap_multiplier_ = 1.0;
405 
406   float async_copy_elapsed_;
407   float inst_elapsed_reduction_;
408   int64 end_logical_time_;
409   int64 earliest_prefetch_time_;
410   int64 latest_prefetch_time_;
411   bool using_increasing_prefetch_time_iterator_ = true;
412   int64 increasing_prefetch_time_iterator_;
413   int64 decreasing_prefetch_time_iterator_;
414 };
415 
// MemorySpaceAssignment assigns memory spaces (default or alternate) to each
// instruction in the module. It will greedily try placing as many values in
// the alternate memory space as possible. It uses the heap simulator to
// determine the actual allocation offsets of values in the alternate memory
// space to account for fragmentation. The default memory space is assumed to be
// large enough to hold the values that could not be placed in the alternate
// memory space.
423 class MemorySpaceAssignment {
424  public:
425   using Chunk = HeapSimulator::Chunk;
426   using BufferInterval =
427       GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval;
428   using BufferIntervalCompare =
429       GlobalDecreasingSizeBestFitHeap<HloValue>::BufferIntervalCompare;
430   using IsAllowedInAlternateMemoryFunction =
431       std::function<bool(const HloValue&)>;
432   using IsUseAllowedInAlternateMemoryFunction =
433       std::function<bool(const HloUse&)>;
434   using ReservedScopedMemoryFunction =
435       std::function<int64_t(const HloInstruction*)>;
436 
437   // MemorySpaceAssignment uses a notion of a slow and large default memory
438   // space and a fast and small alternate memory space.
439   enum class MemorySpace { kDefault, kAlternate };
440 
441   // Forward declaration for Allocation.
442   class Allocation;
443   class ParentAllocation;
444 
445   // This class represents an allocation that might either be in the default or
446   // alternate memory. An HloValue might live in multiple different allocations
447   // over its lifetime. The lifetimes of the allocations are defined using
448   // start_time and end_time, which corresponds to the instruction indexes in
449   // the flattened schedule. Each of these allocations might partially overlap
450   // with each other. CopyAllocation defined below represents asynchronous
451   // copies between Allocations.
452   //
453   // Consider an instruction Foo, and its users Bar and Baz, and the times given
454   // in terms of the flattened schedule of the entire module:
455   //
456   //      Foo:10
457   //       /   \
458   //    Bar:14  \
459   //           Baz:25
460   //
461   // A valid memory space assignment could be like the following:
462   //
463   //  Time:         10 ... 14        ...      25
464   //                Foo    Bar                Baz
465   //  Alternate     +-------+           +-----+
466   //  Default           +---------------------+
467   //                    ^   ^           ^     ^
468   //                    |   |           |     |
469   //                evict   evict  prefetch  prefetch
470   //                start    end    start      end
471   //
472   // This would be represented with:
473   //   - Allocation(memory_space=kAlternate, start_time=10, end_time=14)
474   //   - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25)
475   //   - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25)
476   class Allocation {
477     friend class ParentAllocation;
478 
479    public:
Allocation(HloPosition defining_position,MemorySpace memory_space,absl::optional<Chunk> chunk,int64_t start_time,int64_t end_time,bool is_scoped_allocation)480     Allocation(HloPosition defining_position, MemorySpace memory_space,
481                absl::optional<Chunk> chunk, int64_t start_time,
482                int64_t end_time, bool is_scoped_allocation)
483         : defining_position_(defining_position),
484           memory_space_(memory_space),
485           chunk_(chunk),
486           start_time_(start_time),
487           end_time_(end_time),
488           is_scoped_allocation_(is_scoped_allocation) {
489       CHECK(!is_scoped_allocation || defining_position.index == ShapeIndex({}));
490     }
491     virtual ~Allocation() = default;
492 
is_copy_allocation()493     virtual bool is_copy_allocation() const { return false; }
494 
495     // Adds a use to this allocation.
496     void AddUse(HloUse use);
497 
498     // Extends the end time of this allocation.
Extend(int64_t end_time)499     void Extend(int64_t end_time) { end_time_ = end_time; }
500 
501     // After all of the time ranges for the allocations have been assigned,
502     // Process morphs the instructions affected to assign the memory spaces and
503     // insert asynchronous copy instructions if necessary.
504     virtual Status Process();
505 
506     // An optional post-process step that will be called after all allocations
507     // have been processed.
PostProcess()508     virtual Status PostProcess() { return Status::OK(); }
509 
510     // Marks (adds this allocation to needed_allocations) if this allocation is
511     // needed. Allocation and CopyAllocations are always needed and
512     // ParentAllocations are needed if they have any uses or if other
513     // CopyAllocation or ParentAllocations depend on them.
514     virtual void MarkIfNeeded(
515         absl::flat_hash_set<const Allocation*>& needed_allocations) const;
516 
517     // Marks this allocation as needed.
518     virtual void MarkNeeded(
519         absl::flat_hash_set<const Allocation*>& needed_allocations) const;
520 
521     // Returns the defining position for this allocation.
defining_position()522     virtual HloPosition defining_position() const { return defining_position_; }
523 
524     // Returns the time the buffer is first available to be used. For
525     // Allocation, this is start_time.
earliest_available_time()526     virtual int64 earliest_available_time() const { return start_time_; }
527 
uses()528     const std::vector<HloUse>& uses() const { return uses_; }
memory_space()529     MemorySpace memory_space() const { return memory_space_; }
chunk()530     Chunk chunk() const { return *chunk_; }
mutable_chunk()531     Chunk* mutable_chunk() { return &*chunk_; }
set_start_time(int64_t start_time)532     void set_start_time(int64_t start_time) { start_time_ = start_time; }
start_time()533     int64 start_time() const { return start_time_; }
end_time()534     int64 end_time() const { return end_time_; }
is_scoped_allocation()535     bool is_scoped_allocation() const { return is_scoped_allocation_; }
536 
537     bool operator==(const Allocation& other) const;
538     virtual std::string ToString() const;
539 
540    protected:
541     // Descend to the shape_index element of the tuple and replace that with
542     // new_instruction.
543     StatusOr<HloInstruction*> ReplaceTupleWith(HloInstruction* new_instruction,
544                                                HloInstruction* tuple,
545                                                ShapeIndex shape_index);
546 
547     // Recursively create kGetTupleElement instructions if the defining position
548     // shape is not an array. Returns the new instruction that has array shape.
549     HloInstruction* AddGetTupleElements() const;
550 
551     HloPosition defining_position_;
552     std::vector<HloUse> uses_;
553     MemorySpace memory_space_;
554     absl::optional<Chunk> chunk_;
555     int64 start_time_;
556     int64 end_time_;
557     const bool is_scoped_allocation_;
558   };
559 
560   // This class represents an allocation as a result of an asynchronous copy.
561   // Note: CopyStart instructions are inserted after `start_time` or later,
562   // while CopyDone instructions are inserted before
563   // `copy_done_schedule_before_time` or earlier.
564   class CopyAllocation : public Allocation {
565    public:
566     CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
567                    absl::optional<Chunk> chunk, int64_t start_time,
568                    int64_t end_time, int64_t copy_done_schedule_before_time,
569                    bool is_cross_program_prefetch = false)
570         : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
571                      start_time, end_time, /*is_scoped_allocation=*/false),
572           prev_allocation_(prev_allocation),
573           copy_start_schedule_after_(start_time),
574           copy_done_schedule_before_(copy_done_schedule_before_time),
575           is_cross_program_prefetch_(is_cross_program_prefetch) {}
576 
is_copy_allocation()577     bool is_copy_allocation() const override { return true; }
578 
579     Status Process() override;
580 
581     void MarkNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
582         const override;
583 
defining_position()584     HloPosition defining_position() const override {
585       // Unless explicitly set, the defining position of a copy allocation in
586       // retrieved from the previous allocation. This is because we don't create
587       // new CopyStart/CopyDone instructions until later and the position should
588       // point to the previous (copy or otherwise) allocation's position for the
589       // original defining position.
590       if (defining_position_.instruction == nullptr) {
591         return prev_allocation_.defining_position();
592       }
593       return defining_position_;
594     }
595 
copy_start()596     HloInstruction* copy_start() const { return copy_start_; }
copy_done()597     HloInstruction* copy_done() const { return copy_done_; }
598 
599     // Returns the time the buffer is first available to be used. For For
600     // CopyAllocation, this is when the copy ends, which is
601     // copy_done_schedule_before.
earliest_available_time()602     int64 earliest_available_time() const override {
603       return copy_done_schedule_before_;
604     }
605 
copy_start_schedule_after()606     int64 copy_start_schedule_after() const {
607       return copy_start_schedule_after_;
608     }
copy_done_schedule_before()609     int64 copy_done_schedule_before() const {
610       return copy_done_schedule_before_;
611     }
612 
set_copy_start_schedule_after(int64_t copy_start_schedule_after)613     void set_copy_start_schedule_after(int64_t copy_start_schedule_after) {
614       copy_start_schedule_after_ = copy_start_schedule_after;
615     }
616 
is_cross_program_prefetch()617     bool is_cross_program_prefetch() const {
618       return is_cross_program_prefetch_;
619     }
620 
621     bool operator==(const CopyAllocation& other) const;
622     std::string ToString() const override;
623 
624    private:
625     const Allocation& prev_allocation_;
626     // These variables define the scheduling boundaries where CopyStart and
627     // CopyDone can be scheduled. The earliest CopyStart can be scheduled is
628     // after copy_start_schedule_after_ and the latest CopyDone can be scheduled
629     // is before copy_done_schedule_before_.
630     int64 copy_start_schedule_after_;
631     int64 copy_done_schedule_before_;
632     bool is_cross_program_prefetch_;
633     HloInstruction* copy_start_;
634     HloInstruction* copy_done_;
635   };
636 
637   // An allocation in default memory space that is defined in the parent
638   // computation. If a value has a copy in the default memory space in the
639   // parent computation, we don't need to evict this buffer in a while loop.
  class ParentAllocation : public Allocation {
   public:
    // Creates a default-memory allocation that mirrors `original_allocation`
    // at `position` in the parent computation. `calling_instruction` is the
    // call site (e.g. the while instruction) through which the value reaches
    // the parent, and `time` is used as both start and end time.
    ParentAllocation(const Allocation& original_allocation,
                     HloInstruction* calling_instruction, HloPosition position,
                     int64_t time)
        : Allocation(position, MemorySpace::kDefault,
                     original_allocation.chunk(), /*start_time=*/time,
                     /*end_time=*/time, /*is_scoped_allocation=*/false),
          original_allocation_(original_allocation),
          calling_instruction_(calling_instruction) {}

    Status Process() override;
    // Run after all allocations have been processed (see Allocation).
    Status PostProcess() override;

    void MarkIfNeeded(absl::flat_hash_set<const Allocation*>&
                          needed_allocations) const override;
    void MarkNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
        const override;

    std::string ToString() const override;

   private:
    // The allocation in the called computation that this parent-side
    // allocation shadows.
    const Allocation& original_allocation_;
    // The instruction in the parent computation that calls into the
    // computation owning original_allocation_.
    HloInstruction* calling_instruction_;
  };
665 
666   using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;
667   // AllocationValue is used to break up HloValues for each non-trivial position
668   // (trivial positions are considered Tuple, GetTupleElement, and Bitcast). An
669   // HloValue may include positions and uses that alias with each other across
670   // multiple computations. We use this class to break these HloValues such that
671   // every AllocationValue has one defining position (that may alias with other
672   // AllocationValues). The uses field of the AllocationValue contains only the
673   // direct uses of the AllocationValue's defining position.
674   //
675   // For example, consider the following HLO snippet:
676   //
677   // Body {
678   //   body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
679   //   get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element(body_param),
680   //   index=0
681   //   ...
682   //   ROOT tuple = (f32[4,3]{1,0}, f32[]) tuple(get-tuple-element.3, ...)
683   // }
684   //
685   // Cond {
686   //   cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
687   //   ...
688   // }
689   //
690   // add.4 = f32[4,3]{1,0} add(...)
691   // tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(add.4, ...)
692   // while = (f32[4,3]{1,0}, f32[]) while(tuple.1), body=Body, condition=Cond
693   // get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element(while), index=0
694   // add.5 = f32[4,3]{1,0} add(get-tuple-element.5, ...)
695   //
696   // This contains an HloValue that looks like the following:
697   // positions:
698   //  add.4
699   //  body_param {0}
700   //  get-tuple-element.3
701   //  tuple {0}
702   //  cond_param {0}
703   //  tuple.1 {0}
704   //  while {0}
705   //  get-tuple-element.5
706   // uses:
707   //  add.1, operand 0
708   //  tuple, operand 0
709   //  while, operand 0 {0}
710   //  add.5, operand 0
711   //
712   // We break this HloValue up into the following AllocationValues for each
713   // non-trivial position:
714   // AllocationValue1: computation = Entry
715   //  position:
716   //   add.4
717   //  uses:
718   //   while, operand 0 {0}
719   // AllocationValue2: computation = Cond
720   //  position:
721   //   cond_param {0}
722   //  uses:
723   // AllocationValue3: computation = Body
724   //  position:
725   //   body_param {0}
726   //  uses:
727   //   add.1, operand 0
728   //   tuple, operand 0
729   // AllocationValue4: computation = Entry
730   //  position:
731   //   while {0}
732   //  uses:
733   //   add.5, operand 0
  class AllocationValue {
   public:
    // This data structure wraps an HloUse and adds additional metadata that are
    // useful for allocation.
    struct Use {
      // The wrapped HloUse object.
      HloUse hlo_use;
      // The logical time this use is scheduled.
      int64 time;
      // All the positions where this use aliases with. The aliased positions
      // must get the same allocation.
      std::vector<HloPosition> aliases;

      bool operator==(const Use& other) const {
        return hlo_use == other.hlo_use && time == other.time &&
               aliases == other.aliases;
      }

      template <typename H>
      friend H AbslHashValue(H h, const Use& s) {
        return H::combine(std::move(h), s.hlo_use, s.time, s.aliases);
      }
    };

    // Creates an AllocationValue for `value` with the single defining
    // `position` and buffer `size` in bytes. Uses are added separately via
    // AddUse().
    AllocationValue(const HloValue* value, const HloPosition& position,
                    int64_t size)
        : value_(value),
          defining_position_(position),
          size_(size),
          requires_contiguous_allocation_(false) {}

    const HloPosition& defining_position() const { return defining_position_; }
    const HloInstruction* defining_instruction() const {
      return defining_position().instruction;
    }
    int64 size() const { return size_; }
    const std::vector<Use>& uses() const { return uses_; }
    std::vector<Use>& uses() { return uses_; }
    const HloValue* value() const { return value_; }
    // Computation that contains the defining instruction.
    const HloComputation* computation() const {
      return defining_instruction()->parent();
    }
    AllocationSequence* allocation_sequence() { return &allocation_sequence_; }

    // Sets/gets whether this AllocationValue requires allocating it
    // contiguously throughout its live range (without any copies).
    bool requires_contiguous_allocation() const {
      return requires_contiguous_allocation_;
    }
    void set_requires_contiguous_allocation(
        bool requires_contiguous_allocation) {
      requires_contiguous_allocation_ = requires_contiguous_allocation;
    }

    // Records a direct use of the defining position scheduled at `use_time`,
    // with an initially empty alias list.
    void AddUse(const HloUse& use, int64_t use_time) {
      uses_.push_back({use, use_time, {}});
    }

    std::string ToString() const;
    std::string ToShortString() const;

   private:
    // The HloValue this AllocationValue was broken out of.
    const HloValue* value_;
    // The single (non-trivial) defining position; see the class comment.
    HloPosition defining_position_;
    int64 size_;
    // If true, there must be a contiguous allocation for this buffer without
    // any copies.
    bool requires_contiguous_allocation_;
    // Direct uses of the defining position only (not of aliased positions).
    std::vector<Use> uses_;
    AllocationSequence allocation_sequence_;
  };
805 
806   // Statistics of asynchronous copies.
807   struct AsyncCopyStats {
808     int64 max_outstanding_async_copies;
809     int64 num_prefetches;
810     int64 prefetch_bytes;
811     int64 num_evictions;
812     int64 eviction_bytes;
813   };
814 
815   virtual ~MemorySpaceAssignment() = default;
816 
817   // Runs the MemorySpaceAssignment pass.
818   static StatusOr<std::unique_ptr<PresetAssignments>> Run(
819       HloModule* module, const HloLiveRange& hlo_live_range,
820       const HloAliasAnalysis& alias_analysis, const Options& options);
821 
822   // Calculates asynchronous copy statistics.
823   StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;
824 
825   static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
826       const MemorySpaceAssignmentCostAnalysis& cost_analysis,
827       MemorySpaceAssignmentCostAnalysis::Cache* cache = nullptr);
828 
829   // Verify that the memory space assignment is free of overlapping buffers and
830   // export heap simulator trace to be used by buffer_assignment.
831   Status VerifyAndExportHeapSimulatorTrace();
832 
833  protected:
834   // Main driver of the memory space assignment pass.
835   virtual StatusOr<std::unique_ptr<PresetAssignments>> RunMemorySpaceAssignment(
836       const HloLiveRange& hlo_live_range,
837       const HloAliasAnalysis& alias_analysis);
838 
839   // Finds an AllocationSequence for placing buffers in alternate memory using
840   // the AlternateMemoryBestFitHeap algorithm. Must be set before Process() is
841   // called.
842   virtual Status FindAllocationSequence(const HloLiveRange& hlo_live_range,
843                                         const HloAliasAnalysis& alias_analysis);
844 
options()845   const Options& options() const { return options_; }
846 
  // Constructs the pass state for `module`. `options` and `hlo_live_range`
  // must outlive this object (options is stored by reference); the flattened
  // instruction sequence is copied out of hlo_live_range.
  MemorySpaceAssignment(HloModule* module, const Options& options,
                        const HloLiveRange& hlo_live_range)
      : module_(module),
        options_(options),
        flattened_instructions_(hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .begin(),
                                hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .end()),
        computations_in_schedule_(),
        preset_assignments_(absl::make_unique<PresetAssignments>()) {
    // Cache the set of computations that have a span in the schedule.
    for (const auto& computation_and_bound :
         hlo_live_range.computation_span_times()) {
      computations_in_schedule_.insert(computation_and_bound.first);
    }
  }
864 
865   AllocationSequence allocations_;
866 
module()867   HloModule* module() { return module_; }
868 
869  private:
870   // Process calls Process methods of the allocations after the allocations have
871   // been finalized.
872   Status Process();
873 
874   // Process() might have altered the computation graph by inserting kTuple and
875   // kGetTupleElement instructions. SimplifyGraph performs a simple DCE and
876   // tuple simplification operation (e.g., given GetTupleElement(Tuple(a, b),
877   // 1), simply forwards b). Runs to fixed point.
878   Status SimplifyGraph();
879 
880   // FixSchedule inserts asynchronous copies in the schedule.
881   Status FixSchedule();
882 
883   // Export the alternate memory assignments to the PresetAssignments and color
884   // the HLO graph with the determined memory spaces.
885   Status ExportAndColorBuffers();
886 
887   // Insert an instruction to the schedule, and make sure its dependencies
888   // (operands) are already in the schedule. If not, insert these operands
889   // before the instruction.
890   void EnsureInstructionAndOperandsInserted(
891       HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
892       absl::flat_hash_set<HloInstruction*>* inserted_instructions) const;
893 
894   // Schedules asynchronous copies and ensures that the CopyStarts and their
895   // corresponding CopyDones follow the same order.
896   void ScheduleAsynchronousCopies();
897 
898   // Remove the positions and chunks associated with the instruction from
899   // alternate_memory_assignments_.
900   void RemoveAssignmentForInstruction(const HloInstruction* instruction);
901 
902   // Returns the estimated elapsed duration of the hlo module in seconds. It
903   // uses the 'allocations' argument to determine the location (default memory
904   // or alternate memory) of each operand and output of an instruction.
905   float ComputeEstimatedElapsedTime(const HloLiveRange& hlo_live_range,
906                                     const AllocationSequence& allocations);
907 
908   HloModule* module_;
909   const Options& options_;
910   std::vector<HloInstruction*> flattened_instructions_;
911   absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
912   std::unique_ptr<PresetAssignments> preset_assignments_;
913   std::vector<std::pair<HloPosition, Chunk>> alternate_memory_assignments_;
914   std::vector<std::pair<HloInstruction*, Chunk>> scoped_memory_assignments_;
915   int64 alternate_memory_size_ = 0;
916 
917   // These maps hold vectors of new instructions that need to be scheduled after
918   // (or before) the instruction index in the key. FixSchedule uses these maps
919   // to modify and fix the schedule.
920   absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_after_;
921   absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_before_;
922 };
923 
924 // The different options to be passed to the Run() API.
struct Options {
  // Backend-specific integer value that describes the alternate memory.
  int64 alternate_memory_space = 0;

  // Maximum size of the alternate memory space.
  int64 max_size_in_bytes = 0;

  // Memory alignment of the alternate memory space.
  int64 alignment_in_bytes = 1;

  // If provided, we sort the buffers using this comparison function;
  // otherwise, we use GlobalDecreasingSizeBestFitHeap::kSpatial.
  absl::optional<MemorySpaceAssignment::BufferIntervalCompare>
      buffer_interval_compare = absl::nullopt;

  // This object determines how early and how late prefetches can occur.
  PrefetchIntervalPicker* prefetch_interval_picker = nullptr;

  // This object is used to determine the benefit of a particular allocation.
  MemorySpaceAssignmentCostAnalysis* cost_analysis = nullptr;

  // Size function for buffer values.
  BufferValue::SizeFunction size_fn;

  // This function can be used to prevent certain HloValues (e.g., based on
  // the opcode) to be placed on the alternate memory.
  MemorySpaceAssignment::IsAllowedInAlternateMemoryFunction
      is_allowed_in_alternate_mem_fn;

  // This function can be used to prevent certain HloUses (e.g., based on
  // the opcode) to be placed on the alternate memory.
  MemorySpaceAssignment::IsUseAllowedInAlternateMemoryFunction
      is_use_allowed_in_alternate_mem_fn = [](const HloUse&) { return true; };

  // This function returns the amount of scoped memory in bytes that should be
  // reserved during the execution of this instruction.
  MemorySpaceAssignment::ReservedScopedMemoryFunction
      reserved_scoped_memory_fn = [](const HloInstruction*) { return 0; };

  // If true, we allocate the reserved scoped memory at the same offset. This
  // is useful to enable more deduplication between HLOs that have reserved
  // scoped memories, but may result in less efficient memory packing.
  bool allocate_reserved_scoped_memory_at_same_offset = true;

  // Specifies the upper bound for number of outstanding prefetches and
  // evictions, -1 for unlimited.
  int64 max_outstanding_prefetches = -1;
  int64 max_outstanding_evictions = -1;

  // Extra outstanding prefetch limit for while uses (in addition to
  // max_outstanding_prefetches).
  int64 while_use_extra_outstanding_prefetch_limit = 0;

  // Specifies the maximum number of times we are willing to move a copy
  // done of a prefetch earlier due to an asynchronous copy ordering
  // violation.
  int64 prefetch_copy_done_reorder_max_retries = 1;

  // Specifies the maximum number of retries that will be performed for each
  // value in case prefetching failed due to running out of asynchronous
  // copies or asynchronous copy ordering.
  int64 max_retries = 1;

  // The maximum number of repacks that we are willing to perform in case we
  // can't allocate a buffer due to running out of memory. If this value is
  // greater than 0, repacker must be non-nullptr.
  int64 max_repacks = 0;

  // This variable is used by the cost analysis in estimating how many times
  // each while loop will execute. Nested loops will be assumed to have
  // executed pow(while_execution_count, nesting_level) times.
  uint64 xla_tpu_memory_space_assignment_while_execution_count = 5ULL;

  // Bandwidth, in bytes per second, assumed for asynchronous copies.
  float async_copy_bandwidth_bytes_per_second = 0.0f;

  // Bandwidth, in bytes per second, assumed for the alternate memory.
  float alternate_mem_bandwidth_bytes_per_second = 0.0f;

  // The repacking algorithm to reduce fragmentation. Must be non-null if
  // max_repacks is greater than 0.
  MemorySpaceAssignmentRepacker* repacker = nullptr;

  // This is only useful for testing: repack after every allocation.
  bool repack_after_every_allocation = false;

  // If true, tries allocating buffers across (e.g., before and inside a while
  // loop body) sequential calls (kWhile, kCall, and kConditional).
  bool allocate_across_sequential_calls = false;

  // If true, verifies the memory space assignment against overlapping
  // buffers.
  bool verify = false;

  // If not nullptr, this function is called to dump debugging information.
  // The first argument is appended to the file name and the second argument
  // is the contents of the file.
  std::function<void(absl::string_view, absl::string_view)> dump_fn = nullptr;

  // Enable prefetching buffers into preferred memory across program
  // boundaries.
  bool enable_cross_program_prefetch = true;

  // If true, use buffer_interval_compare to determine which buffers to
  // prefetch across program boundaries.
  bool default_cross_program_prefetch_heuristic = false;

  // Enable cross-program prefetch freeing optimization where the
  // cross-program-prefetched buffer can be reused.
  bool enable_cross_program_prefetch_freeing = true;
};
1034 
1035 // A struct representing an asynchronous copy with its logical start and end
1036 // time and its destination memory space.
struct AsynchronousCopy {
  // Logical time the copy starts (CopyStart).
  int64 start_time;
  // Logical time the copy completes (CopyDone).
  int64 end_time;
  // Memory space the copy writes into.
  MemorySpaceAssignment::MemorySpace destination;
};
1042 
1043 // Compare asynchronous copies such that an earlier start time has the same or
1044 // earlier end time and an earlier end time has the same or earlier start time.
1045 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b);
1046 
1047 // Helper class to enforce asynchronous copy ordering. We only allow
1048 // asynchronous copies that are pipelined: if an asynchronous copy ends earlier
1049 // than another asynchronous copy, it must start the same time or earlier than
1050 // the other asynchronous copy; and if an asynchronous copy starts earlier than
1051 // another asynchronous copy, it must end the same time or earlier than the
1052 // other asynchronous copy.
class AsynchronousCopyOrdering {
 public:
  AsynchronousCopyOrdering() = default;

  // Adds an asynchronous copy.
  void AddCopy(const AsynchronousCopy& copy);

  // Removes an asynchronous copy. CHECKs that it is removed.
  void RemoveCopy(const AsynchronousCopy& copy);

  // If the addition of an asynchronous copy in the given time interval would
  // violate the asynchronous copy ordering, returns the violating
  // already-committed asynchronous copy; otherwise returns absl::nullopt.
  // E.g., consider the following scenario:
  //                                  CS          CD
  //  already committed async copy:   +-----------+
  //                new async copy:     +--------+
  //
  // The new asynchronous copy would violate the ordering guarantee because the
  // copy start is after an already committed asynchronous copy while its copy
  // done is before the committed copy.
  absl::optional<AsynchronousCopy> ViolatesOrdering(int64_t start_time,
                                                    int64_t end_time) const;

 private:
  // Stores asynchronous copies in a tree set respecting the pipelining order
  // (ordered by the operator< defined above).
  std::set<AsynchronousCopy> ranges_;
};
1080 
1081 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
1082 // maximum size.
1083 class AlternateMemoryBestFitHeap
1084     : public GlobalDecreasingSizeBestFitHeap<HloValue> {
1085  public:
1086   using MemorySpace = MemorySpaceAssignment::MemorySpace;
1087   using AllocationValue = MemorySpaceAssignment::AllocationValue;
1088 
  // Constructs the heap algorithm. All four arguments are stored by
  // pointer/reference and must outlive this object; `allocations` is where
  // the resulting allocation sequence is accumulated.
  AlternateMemoryBestFitHeap(
      MemorySpaceAssignment::AllocationSequence* allocations,
      const Options& options, const HloAliasAnalysis& alias_analysis,
      const HloLiveRange& hlo_live_range)
      : GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes),
        allocations_(allocations),
        options_(options),
        alias_analysis_(alias_analysis),
        hlo_live_range_(hlo_live_range) {
    // Override buffer interval compare if provided.
    if (options.buffer_interval_compare) {
      buffer_interval_compare_ = *options.buffer_interval_compare;
    }
  }
1103 
1104   // Allocates a buffer in preferred memory with whole program lifetime and
1105   // enables prefetching prefetch_candidate from default memory across program
1106   // boundaries.
1107   void AllocateCrossProgramPrefetchBuffer(
1108       HloModule* module, absl::optional<BufferInterval> prefetch_candidate);
1109 
1110   HeapSimulator::Result<HloValue> Finish() override;
1111 
1112  protected:
1113   // Given a buffer interval, returns the colocated intervals. Unlike the
1114   // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
1115   // returns the colocated intervals sorted by scheduled time.
1116   std::vector<const BufferInterval*> GetSortedColocatedIntervals(
1117       const BufferInterval& interval) const;
1118 
1119   // Given a BufferInterval, creates AllocationValue objects and corresponding
1120   // AllocationSequences and appends them into allocation_sequence_list_.
1121   void CreateAllocationValues(
1122       const BufferInterval& buffer_interval,
1123       std::vector<AllocationValue>& allocation_values) const;
1124 
1125   // Given colocated intervals, populates allocation_values with the
1126   // corresponding AllocationValue objects.
1127   virtual void CreateAllocationValuesFromColocatedIntervals(
1128       absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1129           colocated_intervals,
1130       std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values);
1131 
1132   // Go through all the uses in the AllocationValues and find the aliasing
1133   // positions.
1134   void FindAliases(std::vector<AllocationValue>* allocation_values) const;
1135 
  // Accessors for the state passed into the constructor.
  MemorySpaceAssignment::AllocationSequence* allocations() {
    return allocations_;
  }
  const Options& options() const { return options_; }
  const HloAliasAnalysis& alias_analysis() { return alias_analysis_; }
  const HloLiveRange& hlo_live_range() { return hlo_live_range_; }
1142 
1143  private:
1144   // We inherit AllocationBlock struct to attach the Allocation information to
1145   // make importing repacked offsets easier.
  struct RepackAllocationBlock
      : MemorySpaceAssignmentRepacker::AllocationBlock {
    // The allocation this block was created from, so repacked offsets can be
    // imported back into it.
    MemorySpaceAssignment::Allocation* allocation;
  };
1150 
1151   // A data structure we use to associate Allocation objects that are aliased
1152   // and must get the same offset.
  struct AliasedOffset {
    // The alternate-memory offset shared by all allocations in this group.
    int64 offset;
    // The aliased allocations that must all be assigned `offset`.
    absl::flat_hash_set<const MemorySpaceAssignment::Allocation*> allocations;
  };
1157 
1158   // An allocation request for a use segment. A use segment is the time segment
1159   // between the definition and the first use, and the time segment between the
1160   // uses of a buffer. For example, the time between the definition and Use1, is
1161   // the first segment, and the time between Use1 and Use2 is the second segment
1162   // and so on:
1163   //
1164   //        +------+----------+-------+
1165   //       /        \          \       \
1166   //      /          v          v       v
1167   //    Def         Use1       Use2    Use3
1168   //     <----------> <--------> <----->
1169   //        Segment    Segment   Segment
1170   //
1171   // start_time and end_time are the start and end logical times of the segment.
1172   // use_times is a sorted sequence of the times of all uses.
1173   // latest_prefetch_time is the latest time we can schedule the CopyDone for a
1174   // prefetch.
1175   // If allow_no_copy_alternate_mem_allocation is false, an eviction is forced.
1176   // If earliest_prefetch_time is set, prefetches cannot start before this
1177   // value.
  struct AllocationRequest {
    // Start and end logical times of the use segment (see comment above).
    int64 start_time;
    int64 end_time;
    // Latest time we can schedule the CopyDone for a prefetch.
    int64 latest_prefetch_time;
    // Size of the buffer in bytes.
    int64 size;
    // If false, an eviction is forced.
    bool allow_no_copy_alternate_mem_allocation;
    // If set, prefetches cannot start before this time.
    absl::optional<int64> earliest_prefetch_time;
    // Preferred aliased offset group, if any.
    AliasedOffset* preferred_offset;
    // The use this segment ends at, and the value being allocated.
    const MemorySpaceAssignment::AllocationValue::Use* use;
    MemorySpaceAssignment::AllocationValue* allocation_value;
    // Sorted sequence of the times of all uses.
    absl::Span<const int64_t> all_use_times;
  };
1190 
1191   // This struct contains mandatory memory assignments at a given time. E.g., an
1192   // input's required memory assignment time would correspond to the definition
1193   // time of the parameter instruction, and an output's time would correspond to
1194   // the time of last use.
1195   struct RequiredMemoryAssignment {
1196     MemorySpaceAssignment::MemorySpace memory_space;
1197     int64 time;
1198     AliasedOffset* offset;
1199 
equals_ignoring_timeRequiredMemoryAssignment1200     bool equals_ignoring_time(const RequiredMemoryAssignment& other) const {
1201       return memory_space == other.memory_space && offset == other.offset;
1202     }
1203 
1204     bool operator==(const RequiredMemoryAssignment& other) const {
1205       return memory_space == other.memory_space && time == other.time &&
1206              offset == other.offset;
1207     }
1208 
1209     bool operator!=(const RequiredMemoryAssignment& other) const {
1210       return !(*this == other);
1211     }
1212   };
1213 
1214   // Result of an allocation, prefetch, eviction etc. request.  The result is
1215   // either kSuccess or a bitwise OR of one or more failures. The values are
1216   // unique powers of two. To check if a result contains a particular failure,
1217   // use the result_is method. To add a new failure to a result, use the
1218   // result_mark method.
  enum class Result {
    // Successful allocation. Zero, so ORing in any failure bit makes the
    // result nonzero.
    kSuccess = 0,
    // Allocation failed because we ran out of alternate memory.
    kFailOutOfMemory = 1,
    // A no-copy allocation couldn't be performed because the previous
    // allocation wasn't in the alternate memory space.
    kFailPrevAllocationNotInAlternateMem = 2,
    // A no-copy allocation couldn't be performed because the live range was too
    // long.
    kFailLiveRangeTooLong = 4,
    // A prefetching couldn't be performed because the live range was too short.
    kFailLiveRangeTooShort = 8,
    // Ran out of outstanding asynchronous copy limit either during prefetching
    // or eviction.
    kFailOutOfAsyncCopies = 16,
    // A prefetching couldn't be performed because the asynchronous copy
    // ordering was violated.
    kFailViolatesAsyncCopyOrdering = 32,
    // An allocation failure happened that requires uncommitting all the pending
    // allocations. Usually this is due to a situation requiring an eviction but
    // the eviction couldn't be performed.
    kFailRequiresUncommit = 64
  };
1243 
1244   // Return true if the result belongs to a failure.
result_is(Result result,Result failure)1245   static bool result_is(Result result, Result failure) {
1246     return static_cast<int>(result) & static_cast<int>(failure);
1247   }
1248 
1249   // Mark (bitwise OR) a failure to the result.
result_mark(Result failure,Result & result)1250   static Result result_mark(Result failure, Result& result) {
1251     result = static_cast<Result>(static_cast<int>(result) |
1252                                  static_cast<int>(failure));
1253     return result;
1254   }
1255 
1256   // Return true if the result is a failure that requires us to uncommit pending
1257   // chunks.
  // Return true if the result is a failure that requires us to uncommit pending
  // chunks (i.e., the kFailRequiresUncommit bit is set).
  static bool result_requires_uncommit(Result result) {
    return result_is(result, Result::kFailRequiresUncommit);
  }
1261 
1262   // Return true if the result is a failure either due to running out of
1263   // outstanding asynchronous copies or due to violating asynchronous copy
1264   // ordering.
  // Return true if the result is a failure either due to running out of
  // outstanding asynchronous copies (kFailOutOfAsyncCopies) or due to
  // violating asynchronous copy ordering (kFailViolatesAsyncCopyOrdering).
  static bool result_failed_because_of_async_copy(Result result) {
    return result_is(result, Result::kFailOutOfAsyncCopies) ||
           result_is(result, Result::kFailViolatesAsyncCopyOrdering);
  }
1269 
1270   // Allocates buffers for instructions that need reserved scoped allocations in
1271   // the alternate memory space.
1272   void AllocateReservedScopedAllocations();
1273 
1274   // Returns the AliasedOffset object associated with the allocation.
1275   AliasedOffset* GetAliasedOffset(
1276       const MemorySpaceAssignment::Allocation& allocation);
1277 
1278   // If aliased_offset is non-null, this method adds the allocation to
1279   // aliased_offset. Otherwise, it creates a new AliasedOffset object and adds
1280   // the allocation to this new AliasedOffset.
1281   void CreateOrAddToAliasedOffset(
1282       const MemorySpaceAssignment::Allocation& allocation,
1283       AliasedOffset* aliased_offset);
1284 
1285   // Given an allocation sequence, returns the live allocation at time with a
1286   // preference towards allocations in alternate memory. Returns nullptr if no
1287   // allocation is alive at that time.
1288   static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
1289       const MemorySpaceAssignment::AllocationSequence& allocations,
1290       int64_t time);
1291 
1292   // Returns true if the use is allowed in the alternate memory.
1293   bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
1294                                      const HloUse& use) const;
1295 
1296   // Finds allocations for allocation values generated from colocated intervals.
1297   // All of the allocation values have a must-alias relationship with each
1298   // other. Returns either kSuccess if all of the sites could be placed in the
  // alternate memory or a bitwise OR of failure reasons why they couldn't.
1300   Result AllocateAllocationValues(
1301       absl::Span<AllocationValue> allocation_values);
1302 
1303   // Finds an allocation for an allocation request for a segment (see the
1304   // documentation for AllocationRequest above how a segment is defined).
1305   //
1306   // It performs three things in the following order:
1307   //  1- Allocate the allocation request entirely in the alternate memory, if
1308   //     there is enough space and if the prefetch interval picker allows.
1309   //  2- If (1) was unsuccessful, and the only allocation for
  //     this buffer was in the alternate memory, we try to perform an eviction.
1311   //  3- If (1) was unsuccessful, prefetch the buffer into the alternate memory,
1312   //     if there is enough space and if the prefetch interval picker allows.
1313   //
1314   // If an eviction (2) was requested and was unsuccessful, this method returns
1315   // Result::kFailRequiresUncommit. This means we could not find a suitable
1316   // allocation, so all previous allocations for this buffer must be removed and
1317   // allocated in the default memory. Otherwise, this method may return
1318   // Result::kSuccess if the buffer could be placed in alternate memory or some
1319   // other Result with an OR of reasons why the buffer couldn't be placed in
1320   // alternate memory.
1321   Result AllocateSegment(const AllocationRequest& request);
1322 
1323   // Try allocating in alternate memory without any copies.
1324   Result AllocateInAlternateMemoryNoCopy(const AllocationRequest& request);
1325 
1326   // Try evicting to default memory space.
1327   Result Evict(const AllocationRequest& request);
1328 
1329   // Returns the time a copy done of a prefetch should be scheduled.
1330   int64 FindPrefetchEndTime(const AllocationRequest& request,
1331                             int64_t earliest_prefetch_time) const;
1332 
1333   // Try prefetching to alternate memory space.
1334   Result Prefetch(
1335       const AllocationRequest& request,
1336       const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem);
1337 
1338   // Find the best possible chunk candidate, where it has the longest possible
1339   // availability if no preferred offset is given, or at the preferred_offset if
1340   // it is given.
1341   absl::optional<ChunkCandidate> FindBestChunkCandidate(
1342       const AllocationRequest& request, const AliasedOffset* preferred_offset,
1343       BufferInterval* alternate_mem_interval) const;
1344 
1345   // Returns the required assignment at a particular time, if available.
1346   absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
1347       const HloValue* buffer, int64_t time) const;
1348 
1349   // Searches for aliases in the use for a required assignment, and returns it
1350   // if found.
1351   absl::optional<RequiredMemoryAssignment> AliasedRequiredAssignmentForUse(
1352       const AllocationValue::Use& use) const;
1353 
1354   // Goes through the colocated intervals and adds any required assignment.
1355   void AddRequiredAssignmentsForColocatedIntervals(
1356       absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1357           colocated_intervals);
1358 
1359   // Propagates aliased required assignment for a given position.
1360   void AddAliasedRequiredAssignment(
1361       const HloInstruction* instruction, ShapeIndex index,
1362       const MemorySpaceAssignment::Allocation* aliased_allocation);
1363 
1364   // This sets a required assignment. CHECK fails if there is a conflicting
1365   // required assignment at the same time.
1366   void AddRequiredAssignment(const HloValue* value,
1367                              const HloInstruction* instruction,
1368                              MemorySpace memory_space, int64_t time,
1369                              AliasedOffset* offset = nullptr);
1370   void AddRequiredAssignment(const HloInstruction* instruction,
1371                              ShapeIndex index, MemorySpace memory_space,
1372                              AliasedOffset* offset = nullptr);
1373 
1374   // Adds input and outputs as required assignments.
1375   void AddInputAndOutputRequiredAssignments();
1376 
1377   // Returns true if the colocated intervals in the argument are in a parameter
1378   // or root instruction of the entry computation and are reserved by the user
1379   // to be in the alternate memory space.
1380   bool AreIntervalsReservedInAlternateMemory(
1381       absl::Span<const BufferInterval* const> colocated_intervals) const;
1382 
  // Since the allocations are recorded to the AllocationSequence, we don't
  // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
  // to avoid unnecessarily adding the chunk to the chunk map.
  // Intentional no-op: both parameters are ignored.
  void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}
1387 
1388   // Returns true if the addition of an asynchronous copy in the given time
1389   // interval would violate the maximum number of asynchronous copies. An extra
1390   // async copy limit can be provided to increase the limit of asynchronous
1391   // copies for this instance.
1392   bool ViolatesMaximumOutstandingAsyncCopies(
1393       int64_t start_time, int64_t end_time, bool is_prefetch,
1394       int64_t extra_async_copy_limit = 0) const;
1395 
1396   // If the asynchronous copy would violate the pipelining order, returns the
1397   // violating asynchronous copy.
1398   absl::optional<AsynchronousCopy> ViolatesAsyncCopyOrdering(
1399       int64_t start_time, int64_t end_time) const;
1400 
1401   // Exports the allocations for repacking and puts them into the vector in the
1402   // parameter.
1403   void ExportAllocationsForRepacking(
1404       std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>&
1405           allocations);
1406 
1407   // Imports repacked allocations and updates the internal data structures
1408   // consistent with the new packing.
1409   void ImportRepackedAllocations();
1410 
1411   // Adds an asynchronous copy to the allocations.
1412   void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
1413                     MemorySpace memory_space, absl::optional<Chunk> chunk,
1414                     int64_t start_time, int64_t end_time,
1415                     int64_t copy_done_schedule_before_time,
1416                     MemorySpaceAssignment::AllocationSequence* allocations,
1417                     AliasedOffset* aliased_offset,
1418                     bool is_cross_program_prefetch = false);
1419 
1420   // This method is used for committing the chunk candidate but adding it to
1421   // pending_chunks_ so that we can "uncommit" them in case we need to roll back
1422   // this allocation sequence.
1423   void AddToPendingChunks(const BufferInterval& buffer_interval,
1424                           const ChunkCandidate& chunk_candidate);
1425   // If we need to remove the allocations for this allocation sequence, this
1426   // removes pending chunks and asynchronous copies in the respective pending
1427   // buffers from the interval trees. If an allocation request returns
1428   // kFailRequiresUncommit, this method must be called.
1429   void UncommitPendingChunks(absl::Span<AllocationValue> allocation_values);
1430 
1431   // Finalizes the allocations where they can no longer be uncommitted.
1432   void FinalizeAllocations(absl::Span<AllocationValue> allocation_values);
1433 
1434   // Clears all pending chunks and asynchronous copies.
1435   void ClearPendingChunks();
1436 
1437   // Append buffer and allocation infos for debugging and dump it into a file,
1438   // if enabled.
1439   void AppendBufferInfoDebugString(const BufferInterval& interval,
1440                                    std::string* debug_str) const;
1441   void AppendAllocationInfoDebugString(
1442       const AllocationValue& value,
1443       const MemorySpaceAssignment::Allocation& allocation,
1444       std::string& debug_str) const;
1445   void DumpDebugStringsIfEnabled() const;
1446 
1447   // Returns the available heap size in the alternate memory.
available_heap_size()1448   int64 available_heap_size() const {
1449     return options_.max_size_in_bytes - reserved_in_bytes_;
1450   }
1451 
1452   // Creates and returns a RepackAllocationBlock.
MakeRepackAllocationBlock(int64_t start_time,int64_t end_time,int64_t size,int64_t initial_offset,int64_t id,MemorySpaceAssignment::Allocation * allocation)1453   static RepackAllocationBlock MakeRepackAllocationBlock(
1454       int64_t start_time, int64_t end_time, int64_t size,
1455       int64_t initial_offset, int64_t id,
1456       MemorySpaceAssignment::Allocation* allocation) {
1457     RepackAllocationBlock allocation_block;
1458     allocation_block.start_time = start_time;
1459     allocation_block.end_time = end_time;
1460     allocation_block.size = size;
1461     allocation_block.offset = -1;
1462     allocation_block.initial_offset = initial_offset;
1463     allocation_block.id = id;
1464     allocation_block.colocations = {};
1465     allocation_block.allocation = allocation;
1466     return allocation_block;
1467   }
1468 
1469   MemorySpaceAssignment::AllocationSequence* allocations_;
1470   const Options& options_;
1471   const HloAliasAnalysis& alias_analysis_;
1472   const HloLiveRange& hlo_live_range_;
  // We use an interval tree to keep track of the number of outstanding
  // prefetches and evictions.
1475   BufferIntervalTree prefetch_interval_tree_;
1476   BufferIntervalTree eviction_interval_tree_;
1477   AsynchronousCopyOrdering async_copy_ordering_;
1478   // A list of RepackAllocationBlock objects that mirrors allocation sequences,
1479   // used for repacking. We use a list here because we need pointer stability
1480   // for aliased allocations.
1481   std::list<RepackAllocationBlock> repack_allocation_blocks_;
1482   int64 num_repacks_ = 0;
1483   std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
1484   std::vector<AsynchronousCopy> pending_async_copies_;
1485   std::vector<std::pair<const HloValue*, RequiredMemoryAssignment>>
1486       pending_required_assignments_;
1487   // The data structure that contains AliasedOffset objects and Allocation to
1488   // AliasedOffset map for efficient lookup.
1489   std::list<AliasedOffset> aliased_offsets_;
1490   absl::flat_hash_map<const MemorySpaceAssignment::Allocation*, AliasedOffset*>
1491       aliased_offset_map_;
1492   // This map contains required memory assignments for HloValues (e.g., input
1493   // and outputs).
1494   absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
1495       required_assignments_;
1496   // Number of bytes reserved in alternate memory space.
1497   int64 reserved_in_bytes_ = 0;
1498   // Debug strings.
1499   std::string buffer_info_str_;
1500   std::string allocation_info_str_;
1501 };
1502 }  // namespace memory_space_assignment
1503 }  // namespace xla
1504 
1505 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
1506