• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
18 
19 #include "tensorflow/compiler/xla/service/heap_simulator.h"
20 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
21 
22 namespace xla {
23 
24 // This class contains pre-set assignments determined by memory space
25 // assignment. It contains two data structures: (1) a chunks vector that maps a
26 // defining HloPosition to a Chunk (offset and size), and (2) a sizes vector
27 // that maps the memory space to its size. If there is only one alternate memory
28 // space like there is currently, there will be one entry in sizes.
29 class PresetAssignments {
30  public:
31   // Contains per-memory-space information like the allocated size and heap
32   // simulator trace.
33   struct AssignmentInformation {
34     int64 size;
35     HeapSimulatorTrace heap_simulator_trace;
36   };
37 
38   PresetAssignments() = default;
39 
add_chunk(const HloPosition & position,const HeapSimulator::Chunk & chunk)40   void add_chunk(const HloPosition& position,
41                  const HeapSimulator::Chunk& chunk) {
42     chunks_.emplace_back(position, chunk);
43   }
44 
assignment_information_for_space(int64 memory_space)45   AssignmentInformation* assignment_information_for_space(int64 memory_space) {
46     for (auto& space_and_info : assignment_info_) {
47       if (space_and_info.first == memory_space) {
48         return &space_and_info.second;
49       }
50     }
51     assignment_info_.emplace_back(memory_space, AssignmentInformation());
52     return &assignment_info_.back().second;
53   }
54 
chunks()55   absl::Span<const std::pair<HloPosition, HeapSimulator::Chunk>> chunks()
56       const {
57     return chunks_;
58   }
59 
60   absl::Span<const std::pair<int64, AssignmentInformation>>
assignment_informations()61   assignment_informations() const {
62     return assignment_info_;
63   }
64 
65   // Remove the chunks_ entry that corresponds to instruction.
66   void RemoveAssignmentForInstruction(const HloInstruction* instruction);
67 
68  private:
69   std::vector<std::pair<HloPosition, HeapSimulator::Chunk>> chunks_;
70   std::vector<std::pair<int64, AssignmentInformation>> assignment_info_;
71 };
72 
73 // A wrapper class around HloCostAnalysis with additional knowledge about the
74 // bandwidths of different memory spaces.
75 class MemorySpaceAssignmentCostAnalysis {
76  public:
MemorySpaceAssignmentCostAnalysis(const HloCostAnalysis & cost_analysis,float async_copy_bandwidth_bytes_per_second,float alternate_mem_bandwidth_bytes_per_second,const HloLiveRange & hlo_live_range)77   MemorySpaceAssignmentCostAnalysis(
78       const HloCostAnalysis& cost_analysis,
79       float async_copy_bandwidth_bytes_per_second,
80       float alternate_mem_bandwidth_bytes_per_second,
81       const HloLiveRange& hlo_live_range)
82       : cost_analysis_(cost_analysis),
83         async_copy_bandwidth_bytes_per_second_(
84             async_copy_bandwidth_bytes_per_second),
85         alternate_mem_bandwidth_bytes_per_second_(
86             alternate_mem_bandwidth_bytes_per_second),
87         hlo_live_range_(hlo_live_range) {}
88 
cost_analysis()89   const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
90 
91   // Returns the elapsed time in seconds due to compute only.
92   float GetInstructionElapsedDueToCompute(
93       const HloInstruction& instruction) const;
94 
95   // Returns the elapsed time in seconds due to memory only. If
96   // operand_in_alternate_mem is provided or if output_in_alternate_mem is true,
97   // it will assume that operand or output will be in the alternate memory
98   // space. This is useful for calculating the benefit of placing the buffer in
99   // alternate memory.
100   float GetInstructionElapsedDueToMemory(
101       const HloInstruction& instruction,
102       absl::optional<int64> operand_in_alternate_mem = absl::nullopt,
103       bool output_in_alternate_mem = false) const;
104 
105   // Returns the elapsed time in seconds that other BufferIntervals are slowed
106   // down, due to the prefetching of current bytes. Assuming other
107   // BufferIntervals needs default memory bandwidth, and only current
108   // BufferInterval is prefetched.
109   float GetInstructionElapsedDueToMemorySlowdown(int64 bytes) const;
110 
111   // Returns the estimated elapsed duration of the instruction in seconds.  It
112   // assumes all operands and outputs of the instruction are in the default
113   // memory, except for the operand number that is in the alternate memory, if
114   // provided, or output if output_in_alternate_mem is true.
115   float GetInstructionElapsed(
116       const HloInstruction& instruction,
117       absl::optional<int64> operand_in_alternate_mem = absl::nullopt,
118       bool output_in_alternate_mem = false) const;
119 
120   // Returns the elapsed time it would take to asynchronously copy the shape
121   // from default to alternate memory space (or vice versa).
122   float GetAsyncCopyElapsed(const Shape& shape) const;
123 
124   int64 GetScheduleEndTime() const;
125 
hlo_live_range()126   const HloLiveRange& hlo_live_range() const { return hlo_live_range_; }
127 
128  private:
129   const HloCostAnalysis& cost_analysis_;
130   float async_copy_bandwidth_bytes_per_second_;
131   float alternate_mem_bandwidth_bytes_per_second_;
132   const HloLiveRange& hlo_live_range_;
133 };
134 
135 // Abstract base class that memory space assignment uses to pick prefetch
136 // intervals.
137 class PrefetchIntervalPicker {
138  public:
139   PrefetchIntervalPicker() = default;
140   virtual ~PrefetchIntervalPicker() = default;
141 
142   // Returns true if the buffer can be allocated in alternate memory space
143   // without any copies (prefetches).
144   virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
145                                                   int64 start_time,
146                                                   int64 end_time) const = 0;
147 
148   // Returns the preferred end time for an eviction that starts at a given time
149   // and must end by the given end time.
150   virtual int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
151                                          int64 latest_end_time) const = 0;
152 
153   // Begins the iterator for the first start time of the prefetch.
154   virtual void Begin(const HloUse& use, int64 start_time, int64 end_time) = 0;
155 
156   // Advances the start time of the prefetch and returns that value.
157   virtual int64 Next() = 0;
158 
159   // Returns true if the available prefetch intervals have been exhausted.
160   virtual bool Done() const = 0;
161 
162   // Returns a debug string for the current state of the prefetch interval
163   // picker.
164   virtual std::string ToDebugString() const = 0;
165 
166   // Returns a debug string for no-copy allocation.
167   virtual std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
168                                           int64 end_time) const = 0;
169 
170  protected:
171   const absl::flat_hash_map<const HloInstruction*, int64>*
172       instruction_schedule_ = nullptr;
173 };
174 
175 // Prefetch interval picker that uses instruction count to overlap asynchronous
176 // copies with independent computation. The min and max overlap counts describe
177 // the number of independent HLOs overlapped while a value is being prefetched
178 // into the alternate memory (between CopyStart and CopyDone HLO instructions).
179 // max_overlap_count attempts to prevent bringing tensors into the alternate
180 // memory too eagerly and hence occupying the space for other tensors which
181 // might use it.  min_overlap_count attempts to prevent cases where tensors are
182 // prefetched into the alternate memory without sufficient time for the copy to
183 // take place.  In those cases, it's just better to keep the tensor in the
184 // default memory instead of hurting the critical path with this copy that
185 // likely won't finish in time.
186 class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker {
187  public:
InstructionCountPrefetchIntervalPicker(int64 min_overlap_count,int64 max_overlap_count)188   InstructionCountPrefetchIntervalPicker(int64 min_overlap_count,
189                                          int64 max_overlap_count)
190       : min_overlap_count_(min_overlap_count),
191         max_overlap_count_(max_overlap_count) {}
192 
193   bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
194                                           int64 end_time) const override;
195 
196   int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
197                                  int64 latest_end_time) const override;
198 
199   void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
200 
201   int64 Next() override;
202   bool Done() const override;
203 
204   std::string ToDebugString() const override;
205   std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
206                                   int64 end_time) const override;
207 
208  private:
209   int64 min_overlap_count_;
210   int64 max_overlap_count_;
211   int64 end_time_;
212   int64 current_prefetch_time_;
213 };
214 
215 // Prefetch interval picker that uses cost analysis to overlap asynchronous
216 // copies with independent computation. It uses min/max (asynchronous copy
217 // duration) / (independent computation duration) ratios to guide whether the
218 // prefetch is within those bounds. It starts with the maximum allowed ratio
219 // (earliest prefetch) in Begin() and works its way for later and later prefetch
220 // with each Next() call until hitting the minimum ratio, in order not to hurt
221 // the critical path.
222 class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
223  public:
224   CostAnalysisPrefetchIntervalPicker(
225       const MemorySpaceAssignmentCostAnalysis& cost_analysis,
226       float min_async_copy_to_overlap_ratio,
227       float max_async_copy_to_overlap_ratio);
228 
229   bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
230                                           int64 end_time) const override;
231 
232   int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
233                                  int64 latest_end_time) const override;
234 
235   void Begin(const HloUse& use, int64 start_time, int64 end_time) override;
236 
237   int64 Next() override;
238   bool Done() const override;
239 
240   std::string ToDebugString() const override;
241   std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
242                                   int64 end_time) const override;
243 
244  private:
245   // Returns the elapsed time in seconds between the logical interval that
246   // corresponds to the instruction schedule.
247   float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const;
248 
249   // For performance reasons, we calculate the prefix sum of the elapsed time so
250   // that it's efficient to find the elapsed time in seconds in any logical
251   // interval.
252   std::vector<float> elapsed_time_cumsum_;
253 
254   const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
255   float min_async_copy_to_overlap_ratio_;
256   float max_async_copy_to_overlap_ratio_;
257 
258   float async_copy_elapsed_;
259   float inst_elapsed_reduction_;
260   int64 end_logical_time_;
261   int64 current_logical_prefetch_time_;
262 };
263 
264 // MemorySpaceAssignment assigns memory spaces (default or alternate) to each
// instruction in the module. It will greedily try placing as many values in
266 // the alternate memory space as possible. It uses the heap simulator to
267 // determine the actual allocation offsets of values in the alternate memory
268 // space to account for fragmentation. The default memory space is assumed to be
269 // large enough to hold the values that could not be placed in the alternate
270 // memory space.
271 class MemorySpaceAssignment {
272  public:
273   using Chunk = HeapSimulator::Chunk;
274   using BufferInterval = GlobalDecreasingSizeBestFitHeap::BufferInterval;
275   using BufferIntervalCompare =
276       GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare;
277   using IsAllowedInAlternateMemoryFunction =
278       std::function<bool(const HloValue&)>;
279 
280   // MemorySpaceAssignment uses a notion of a slow and large default memory
281   // space and a fast and small alternate memory space.
282   enum class MemorySpace { kDefault, kAlternate };
283 
284   // The different options to be passed to the Run() API.
285   struct Options {
286     // Backend-specific integer value that describes the alternate memory.
287     int64 alternate_memory_space = 0;
288 
289     // Maximum size of the alternate memory space.
290     int64 max_size_in_bytes = 0;
291 
292     // Memory alignment of the alternate memory space.
293     int64 alignment_in_bytes = 1;
294 
295     // If provided, we sort the buffers using this comparison function
296     // otherwise, we use GlobalDecreasingSizeBestFitHeap::kSpatial.
297     absl::optional<BufferIntervalCompare> buffer_interval_compare =
298         absl::nullopt;
299 
300     // This object determines how early and how late prefetches can occur.
301     PrefetchIntervalPicker* prefetch_interval_picker = nullptr;
302 
303     // Size function for buffer values.
304     BufferValue::SizeFunction size_fn;
305 
306     // This function can be used to prevent certain HloValues (e.g., based on
307     // the opcode) to be placed on the alternate memory.
308     IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_fn;
309 
310     // Specifies the upper bound for number of outstanding asynchronous copies,
311     // -1 for unlimited.
312     int64 max_outstanding_async_copies = -1;
313 
314     // If true, tries allocating buffers across (e.g., before and inside a while
315     // loop body) sequential calls (kWhile, kCall, and kConditional).
316     bool allocate_across_sequential_calls = false;
317 
318     // If true, verifies the memory space assignment against overlapping
319     // buffers.
320     bool verify = false;
321   };
322 
323   // This class represents an allocation that might either be in the default or
324   // alternate memory. An HloValue might live in multiple different allocations
325   // over its lifetime. The lifetimes of the allocations are defined using
326   // start_time and end_time, which corresponds to the instruction indexes in
327   // the flattened schedule. Each of these allocations might partially overlap
328   // with each other. CopyAllocation defined below represents asynchronous
329   // copies between Allocations.
330   //
331   // Consider an instruction Foo, and its users Bar and Baz, and the times given
332   // in terms of the flattened schedule of the entire module:
333   //
334   //      Foo:10
335   //       /   \
336   //    Bar:14  \
337   //           Baz:25
338   //
339   // A valid memory space assignment could be like the following:
340   //
341   //  Time:         10 ... 14        ...      25
342   //                Foo    Bar                Baz
343   //  Alternate     +-------+           +-----+
344   //  Default           +---------------------+
345   //                    ^   ^           ^     ^
346   //                    |   |           |     |
347   //                evict   evict  prefetch  prefetch
348   //                start    end    start      end
349   //
350   // This would be represented with:
351   //   - Allocation(memory_space=kAlternate, start_time=10, end_time=14)
352   //   - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25)
353   //   - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25)
354   class Allocation {
355    public:
Allocation(HloInstruction * instruction,HloPosition defining_position,MemorySpace memory_space,Chunk chunk,int64 start_time,int64 end_time)356     Allocation(HloInstruction* instruction, HloPosition defining_position,
357                MemorySpace memory_space, Chunk chunk, int64 start_time,
358                int64 end_time)
359         : instruction_(instruction),
360           defining_position_(defining_position),
361           memory_space_(memory_space),
362           chunk_(chunk),
363           start_time_(start_time),
364           end_time_(end_time) {}
365     virtual ~Allocation() = default;
366 
is_copy_allocation()367     virtual bool is_copy_allocation() const { return false; }
368 
369     // Adds a use to this allocation.
370     void AddUse(HloUse use);
371 
372     // Extends the end time of this allocation.
Extend(int64 end_time)373     void Extend(int64 end_time) { end_time_ = end_time; }
374 
375     // After all of the time ranges for the allocations have been assigned,
376     // Process morphs the instructions affected to assign the memory spaces and
377     // insert asynchronous copy instructions if necessary.
378     virtual Status Process(MemorySpaceAssignment* memory_space_assignment);
379 
380     // Returns the instruction that produces this allocation. It might be
381     // different than the instruction in defining_position (e.g., a
382     // GetTupleElement instruction does not define the buffer).
instruction()383     virtual HloInstruction* instruction() const { return instruction_; }
384 
385     // Returns the defining position for this allocation.
defining_position()386     virtual HloPosition defining_position() const { return defining_position_; }
387 
388     // Returns the time the buffer is first available to be used. For
389     // Allocation, this is start_time.
earliest_available_time()390     virtual int64 earliest_available_time() const { return start_time_; }
391 
uses()392     const std::vector<HloUse>& uses() const { return uses_; }
memory_space()393     MemorySpace memory_space() const { return memory_space_; }
chunk()394     Chunk chunk() const { return chunk_; }
set_start_time(int64 start_time)395     void set_start_time(int64 start_time) { start_time_ = start_time; }
start_time()396     int64 start_time() const { return start_time_; }
end_time()397     int64 end_time() const { return end_time_; }
398 
399    protected:
400     // Descend to the shape_index element of the tuple and replace that with
401     // new_instruction.
402     StatusOr<HloInstruction*> ReplaceTupleWith(HloInstruction* new_instruction,
403                                                HloInstruction* tuple,
404                                                ShapeIndex shape_index);
405 
406     HloInstruction* instruction_;
407     HloPosition defining_position_;
408     std::vector<HloUse> uses_;
409     MemorySpace memory_space_;
410     Chunk chunk_;
411     int64 start_time_;
412     int64 end_time_;
413   };
414 
415   // This class represents an allocation as a result of an asynchronous copy.
416   class CopyAllocation : public Allocation {
417    public:
CopyAllocation(const Allocation & prev_allocation,MemorySpace memory_space,Chunk chunk,int64 start_time,int64 end_time,int64 copy_done_schedule_before_time)418     CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
419                    Chunk chunk, int64 start_time, int64 end_time,
420                    int64 copy_done_schedule_before_time)
421         : Allocation(/*instruction=*/nullptr,
422                      /*defining_position=*/{nullptr, {}}, memory_space, chunk,
423                      start_time, end_time),
424           prev_allocation_(prev_allocation),
425           copy_start_schedule_after_(start_time),
426           copy_done_schedule_before_(copy_done_schedule_before_time) {}
427 
is_copy_allocation()428     bool is_copy_allocation() const override { return true; }
429 
430     Status Process(MemorySpaceAssignment* memory_space_assignment) override;
431 
instruction()432     HloInstruction* instruction() const override {
433       // Unless explicitly set, the instruction of a copy allocation in
434       // retrieved from the previous allocation.
435       if (instruction_ != nullptr) {
436         return instruction_;
437       } else {
438         return prev_allocation_.instruction();
439       }
440     }
441 
defining_position()442     HloPosition defining_position() const override {
443       // Unless explicitly set, the defining position of a copy allocation in
444       // retrieved from the previous allocation. This is because we don't create
445       // new CopyStart/CopyDone instructions until later and the position should
446       // point to the previous (copy or otherwise) allocation's position for the
447       // original defining position.
448       if (defining_position_.instruction == nullptr) {
449         return prev_allocation_.defining_position();
450       } else {
451         return defining_position_;
452       }
453     }
454 
copy_start()455     HloInstruction* copy_start() const { return copy_start_; }
copy_done()456     HloInstruction* copy_done() const { return copy_done_; }
457 
458     // Returns the time the buffer is first available to be used. For For
459     // CopyAllocation, this is when the copy ends, which is
460     // copy_done_schedule_before.
earliest_available_time()461     int64 earliest_available_time() const override {
462       return copy_done_schedule_before_;
463     }
464 
copy_start_schedule_after()465     int64 copy_start_schedule_after() const {
466       return copy_start_schedule_after_;
467     }
copy_done_schedule_before()468     int64 copy_done_schedule_before() const {
469       return copy_done_schedule_before_;
470     }
471 
set_copy_start_schedule_after(int64 copy_start_schedule_after)472     void set_copy_start_schedule_after(int64 copy_start_schedule_after) {
473       copy_start_schedule_after_ = copy_start_schedule_after;
474     }
475 
476    private:
477     const Allocation& prev_allocation_;
478     // These variables define the scheduling boundaries where CopyStart and
479     // CopyDone can be scheduled. The earliest CopyStart can be scheduled is
480     // after copy_start_schedule_after_ and the latest CopyDone can be scheduled
481     // is before copy_done_schedule_before_.
482     int64 copy_start_schedule_after_;
483     int64 copy_done_schedule_before_;
484     HloInstruction* copy_start_;
485     HloInstruction* copy_done_;
486   };
487 
488   using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;
489   struct ValueAndAllocationSequence {
490     const HloValue* value;
491     AllocationSequence sequence;
492   };
493   using AllocationSequenceList = std::vector<ValueAndAllocationSequence>;
494 
495   // Runs the MemorySpaceAssignment pass.
496   static StatusOr<std::unique_ptr<PresetAssignments>> Run(
497       HloModule* module, const HloLiveRange& hlo_live_range,
498       const HloAliasAnalysis& alias_analysis, const Options& options);
499 
500   // Returns the maximum number of outstanding asynchronous copies in the
501   // module.
502   static int64 CountMaximumOutstandingAsyncCopies(const HloModule& module);
503 
504   static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
505       const MemorySpaceAssignmentCostAnalysis& cost_analysis);
506 
507   // Verify that the memory space assignment is free of overlapping buffers and
508   // export heap simulator trace to be used by buffer_assignment.
509   Status VerifyAndExportHeapSimulatorTrace();
510 
511  private:
MemorySpaceAssignment(HloModule * module,Options options,const HloLiveRange & hlo_live_range)512   MemorySpaceAssignment(HloModule* module, Options options,
513                         const HloLiveRange& hlo_live_range)
514       : module_(module),
515         options_(options),
516         flattened_instructions_(hlo_live_range.flattened_instruction_sequence()
517                                     .instructions()
518                                     .begin(),
519                                 hlo_live_range.flattened_instruction_sequence()
520                                     .instructions()
521                                     .end()),
522         computations_in_schedule_(),
523         preset_assignments_(absl::make_unique<PresetAssignments>()) {
524     for (const auto& computation_and_bound :
525          hlo_live_range.computation_span_times()) {
526       computations_in_schedule_.insert(computation_and_bound.first);
527     }
528   }
529 
530   // Process calls Process methods of the allocations after the allocations have
531   // been finalized.
532   Status Process();
533 
534   // Process() might have altered the computation graph by inserting kTuple and
535   // kGetTupleElement instructions. SimplifyGraph performs a simple DCE and
536   // tuple simplification operation (e.g., given GetTupleElement(Tuple(a, b),
537   // 1), simply forwards b). Runs to fixed point.
538   Status SimplifyGraph();
539 
540   // FixSchedule inserts asynchronous copies in the schedule.
541   Status FixSchedule();
542 
543   // Insert an instruction to the schedule, and make sure its dependencies
544   // (operands) are already in the schedule. If not, insert these operands
545   // before the instruction.
546   void EnsureInstructionAndOperandsInserted(
547       HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
548       absl::flat_hash_set<HloInstruction*>* inserted_instructions) const;
549 
550   // Schedules asynchronous copies and ensures that the CopyStarts and their
551   // corresponding CopyDones follow the same order.
552   void ScheduleAsynchronousCopies();
553 
554   HloModule* module_;
555   Options options_;
556   std::vector<HloInstruction*> flattened_instructions_;
557   absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
558   AllocationSequenceList allocation_sequence_list_;
559   std::unique_ptr<PresetAssignments> preset_assignments_;
560 
561   // These maps hold vectors of new instructions that need to be scheduled after
562   // (or before) the instruction index in the key. FixSchedule uses these maps
563   // to modify and fix the schedule.
564   absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_after_;
565   absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_before_;
566 };
567 
568 // This struct contains mandatory memory assignments at a given time. E.g., an
569 // input's required memory assignment time would correspond to the definition
570 // time of the parameter instruction, and an output's time would correspond to
571 // the time of last use.
572 struct RequiredMemoryAssignment {
573   MemorySpaceAssignment::MemorySpace memory_space;
574   int64 time;
575 };
576 
577 // A struct representing an asynchronous copy with its logical start and end
578 // time and its destination memory space.
579 struct AsynchronousCopy {
580   int64 start_time;
581   int64 end_time;
582   MemorySpaceAssignment::MemorySpace destination;
583 };
584 
585 // Compare asynchronous copies such that an earlier start time has the same or
586 // earlier end time and an earlier end time has the same or earlier start time.
587 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b);
588 
589 // Helper class to enforce asynchronous copy ordering. We only allow
590 // asynchronous copies that are pipelined: if an asynchronous copy ends earlier
591 // than another asynchronous copy, it must start the same time or earlier than
592 // the other asynchronous copy; and if an asynchronous copy starts earlier than
593 // another asynchronous copy, it must end the same time or earlier than the
594 // other asynchronous copy.
595 class AsynchronousCopyOrdering {
596  public:
597   AsynchronousCopyOrdering() = default;
598 
599   // Adds an asynchronous copy.
600   void AddCopy(const AsynchronousCopy& copy);
601 
602   // Returns true if the addition of an asynchronous copy in the the given time
603   // interval would violate the asynchronous copy ordering. E.g., consider the
604   // following scenario:
605   //                                  CS          CD
606   //  already committed async copy:   +-----------+
607   //                new async copy:     +--------+
608   //
609   // The new asynchronous copy would violate the ordering guarantee because the
610   // copy start is after an already committed asynchronous copy while its copy
611   // done is before the committed copy.
612   bool ViolatesOrdering(int64 start_time, int64 end_time) const;
613 
614  private:
615   // Stores asynchronous copies in a tree set respecting the pipelining order.
616   std::set<AsynchronousCopy> ranges_;
617 };
618 
619 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
620 // maximum size.
621 class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
622  public:
623   using MemorySpace = MemorySpaceAssignment::MemorySpace;
624 
AlternateMemoryBestFitHeap(MemorySpaceAssignment::AllocationSequenceList * allocation_sequence_list,const MemorySpaceAssignment::Options & options,const HloAliasAnalysis & alias_analysis,const HloLiveRange & hlo_live_range)625   AlternateMemoryBestFitHeap(
626       MemorySpaceAssignment::AllocationSequenceList* allocation_sequence_list,
627       const MemorySpaceAssignment::Options& options,
628       const HloAliasAnalysis& alias_analysis,
629       const HloLiveRange& hlo_live_range)
630       : GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes),
631         allocation_sequence_list_(allocation_sequence_list),
632         options_(options),
633         alias_analysis_(alias_analysis),
634         hlo_live_range_(hlo_live_range) {
635     // Override buffer interval compare if provided.
636     if (options.buffer_interval_compare) {
637       buffer_interval_compare_ = *options.buffer_interval_compare;
638     }
639   }
640 
641   HeapSimulator::Result Finish() override;
642 
643  private:
644   // Given an allocation sequence, returns the live allocation at time with a
645   // preference towards allocations in alternate memory. Returns nullptr if no
646   // allocation is alive at that time.
647   static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
648       const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
649 
650   // Returns true if a buffer is required to be in default memory at a
651   // particular time. A buffer may be required to be in default memory because
  // it is a parameter in default memory or an output in default memory.
653   bool RequiredInDefaultMemory(const HloValue* buffer, int64 time) const;
654 
655   // Returns true if this buffer is allowed to be placed in the alternate
656   // memory.
657   bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
658 
659   // Finds an allocation for the given interval. Internally, it will attempt to
660   // find a suitable chunk candidate within the heap size and prefetch interval
661   // limits, and append the new allocation(s) to allocations. The new
662   // allocations can be in default or alternate memory spaces, or can be
663   // prefetches or evictions. Returns true if successful.
664   bool FindAllocation(int64 start_time, int64 end_time, int64 last_use_time,
665                       int64 latest_prefetch_time, HloPosition defining_position,
666                       HloUse use, const HloValue* buffer, int64 size,
667                       MemorySpaceAssignment::AllocationSequence* allocations);
668 
669   // Try allocating in alternate memory without any copies. Returns true if
670   // successful.
671   bool TryAllocatingInAlternateMemoryNoCopy(
672       int64 start_time, int64 end_time, int64 last_use_time,
673       HloPosition defining_position, HloUse use,
674       BufferInterval alternate_mem_interval,
675       HloInstruction* non_bitcast_operand,
676       MemorySpaceAssignment::AllocationSequence* allocations);
677 
678   // For a no-copy allocation, find the best possible chunk candidate, where it
679   // has the longest possible availability if no preferred offset is given, or
680   // at the preferred_offset if it is given.
681   absl::optional<ChunkCandidate> FindBestNoCopyChunkCandidate(
682       int64 end_time, int64 last_use_time,
683       absl::optional<int64> preferred_offset,
684       BufferInterval* alternate_mem_interval) const;
685 
686   // Adds input and outputs as required assignments.
687   void AddInputAndOutputRequiredAssignments();
688 
689   // Returns true if the colocated intervals in the argument are in a parameter
690   // or root instruction of the entry computation and are reserved by the user
691   // to be in the alternate memory space.
692   bool AreIntervalsReservedInAlternateMemory(
693       absl::Span<const BufferInterval* const> colocated_intervals) const;
694 
695   // Given a buffer interval, returns the colocated intervals. Unlike the
696   // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
697   // returns the colocated intervals sorted by scheduled time.
698   std::vector<const BufferInterval*> GetSortedColocatedIntervals(
699       const BufferInterval& interval) const;
700 
701   // Since the allocations are recorded to the AllocationSequenceList, we don't
702   // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
703   // to avoid unnecessarily adding the chunk to the chunk map.
AddToChunkMap(const HloValue * buffer,Chunk chunk)704   void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}
705 
706   // Returns true if adding an asynchronous copy over the given time interval
707   // would violate the maximum number of outstanding asynchronous copies.
708   bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time,
709                                              int64 end_time) const;
710 
711   // Returns true if the asynchronous copy would violate the pipelining order.
712   bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const;
713 
714   // Adds an asynchronous copy to the allocations.
      // NOTE(review): presumably the copy reads from prev_allocation and lands in
      // chunk within memory_space, with copy_done_schedule_before_time bounding
      // where the copy-done may be scheduled; confirm against the definition.
715   void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
716                     MemorySpace memory_space, Chunk chunk, int64 start_time,
717                     int64 end_time, int64 copy_done_schedule_before_time,
718                     MemorySpaceAssignment::AllocationSequence* allocations);
719 
720   // These two methods are used to delay committing a chunk candidate until
721   // the entire live range of the buffer has been considered.
722   void AddToPendingChunks(const BufferInterval& buffer_interval,
723                           const ChunkCandidate& chunk_candidate);
724   void CommitPendingChunks();
725 
726   // Returns the available heap size in the alternate memory.
available_heap_size()727   int64 available_heap_size() const {
728     return options_.max_size_in_bytes - reserved_in_bytes_;
729   }
730 
731   MemorySpaceAssignment::AllocationSequenceList* allocation_sequence_list_;
      // The pointer above and the three references below are bound from the
      // constructor arguments; they are held as a raw pointer and references.
732   const MemorySpaceAssignment::Options& options_;
733   const HloAliasAnalysis& alias_analysis_;
734   const HloLiveRange& hlo_live_range_;
735   // We use an interval tree to keep track of the number of outstanding
736   // asynchronous copies.
737   BufferIntervalTree async_copy_interval_tree_;
738   AsynchronousCopyOrdering async_copy_ordering_;
      // NOTE(review): presumably these two hold chunks and asynchronous copies
      // that are not yet committed (see AddToPendingChunks and
      // CommitPendingChunks); confirm against the definitions.
739   std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
740   std::vector<AsynchronousCopy> pending_async_copies_;
741   // This map contains required memory assignments for HloValues (e.g., inputs
742   // and outputs).
743   absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
744       required_assignments_;
745   // Number of bytes reserved in alternate memory space. available_heap_size()
      // subtracts this from options_.max_size_in_bytes.
746   int64 reserved_in_bytes_ = 0;
747 };
748 
749 }  // namespace xla
750 
751 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
752