/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_

#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"

namespace xla {

// This class contains pre-set assignments determined by memory space
// assignment. It contains two data structures: (1) a chunks vector that maps a
// defining HloPosition to a Chunk (offset and size), and (2) an assignment
// information vector that maps each memory space to its assignment
// information (allocated size and heap simulator trace). If there is only one
// alternate memory space, as is currently the case, there will be one entry in
// that vector.
class PresetAssignments {
 public:
  // Contains per-memory-space information like the allocated size and heap
  // simulator trace.
  struct AssignmentInformation {
    int64 size;
    HeapSimulatorTrace heap_simulator_trace;
  };

  PresetAssignments() = default;

  void add_chunk(const HloPosition& position,
                 const HeapSimulator::Chunk& chunk) {
    chunks_.emplace_back(position, chunk);
  }

  AssignmentInformation* assignment_information_for_space(int64 memory_space) {
    for (auto& space_and_info : assignment_info_) {
      if (space_and_info.first == memory_space) {
        return &space_and_info.second;
      }
    }
    assignment_info_.emplace_back(memory_space, AssignmentInformation());
    return &assignment_info_.back().second;
  }

  absl::Span<const std::pair<HloPosition, HeapSimulator::Chunk>> chunks()
      const {
    return chunks_;
  }

  absl::Span<const std::pair<int64, AssignmentInformation>>
  assignment_informations() const {
    return assignment_info_;
  }

  // Removes the chunks_ entry that corresponds to instruction.
  void RemoveAssignmentForInstruction(const HloInstruction* instruction);

 private:
  std::vector<std::pair<HloPosition, HeapSimulator::Chunk>> chunks_;
  std::vector<std::pair<int64, AssignmentInformation>> assignment_info_;
};
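
// Illustrative sketch (not part of the API): how a consumer such as a
// backend's buffer assignment might read the results out of a
// PresetAssignments instance. The function and variable names below are
// hypothetical.
//
//   void DumpPresetAssignments(const PresetAssignments& preset_assignments) {
//     for (const auto& position_and_chunk : preset_assignments.chunks()) {
//       const HloPosition& position = position_and_chunk.first;
//       const HeapSimulator::Chunk& chunk = position_and_chunk.second;
//       LOG(INFO) << position.ToString() << " -> offset " << chunk.offset
//                 << ", size " << chunk.size;
//     }
//     for (const auto& space_and_info :
//          preset_assignments.assignment_informations()) {
//       LOG(INFO) << "Memory space " << space_and_info.first << " uses "
//                 << space_and_info.second.size << " bytes.";
//     }
//   }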

// A wrapper class around HloCostAnalysis with additional knowledge about the
// bandwidths of different memory spaces.
class MemorySpaceAssignmentCostAnalysis {
 public:
  MemorySpaceAssignmentCostAnalysis(
      const HloCostAnalysis& cost_analysis,
      float async_copy_bandwidth_bytes_per_second,
      float alternate_mem_bandwidth_bytes_per_second,
      const HloLiveRange& hlo_live_range)
      : cost_analysis_(cost_analysis),
        async_copy_bandwidth_bytes_per_second_(
            async_copy_bandwidth_bytes_per_second),
        alternate_mem_bandwidth_bytes_per_second_(
            alternate_mem_bandwidth_bytes_per_second),
        hlo_live_range_(hlo_live_range) {}

  const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }

  // Returns the elapsed time in seconds due to compute only.
  float GetInstructionElapsedDueToCompute(
      const HloInstruction& instruction) const;

  // Returns the elapsed time in seconds due to memory only. If
  // operand_in_alternate_mem is provided or if output_in_alternate_mem is
  // true, it will assume that the operand or output is in the alternate memory
  // space. This is useful for calculating the benefit of placing the buffer in
  // alternate memory.
  float GetInstructionElapsedDueToMemory(
      const HloInstruction& instruction,
      absl::optional<int64> operand_in_alternate_mem = absl::nullopt,
      bool output_in_alternate_mem = false) const;

  // Returns the elapsed time in seconds by which other BufferIntervals are
  // slowed down due to prefetching the given number of bytes, assuming the
  // other BufferIntervals need default memory bandwidth and only the current
  // BufferInterval is prefetched.
  float GetInstructionElapsedDueToMemorySlowdown(int64 bytes) const;

  // Returns the estimated elapsed duration of the instruction in seconds. It
  // assumes all operands and outputs of the instruction are in the default
  // memory, except for the operand number that is in the alternate memory, if
  // provided, or the output if output_in_alternate_mem is true.
  float GetInstructionElapsed(
      const HloInstruction& instruction,
      absl::optional<int64> operand_in_alternate_mem = absl::nullopt,
      bool output_in_alternate_mem = false) const;

  // Returns the elapsed time it would take to asynchronously copy the shape
  // from default to alternate memory space (or vice versa).
  float GetAsyncCopyElapsed(const Shape& shape) const;

  int64 GetScheduleEndTime() const;

  const HloLiveRange& hlo_live_range() const { return hlo_live_range_; }

 private:
  const HloCostAnalysis& cost_analysis_;
  float async_copy_bandwidth_bytes_per_second_;
  float alternate_mem_bandwidth_bytes_per_second_;
  const HloLiveRange& hlo_live_range_;
};
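
// Illustrative sketch (not part of the API): estimating the benefit of placing
// operand 0 of `instruction` in the alternate memory by comparing the elapsed
// time with and without the alternate-memory placement. `cost_analysis` and
// `instruction` are hypothetical, already-constructed objects.
//
//   float elapsed_default = cost_analysis.GetInstructionElapsed(*instruction);
//   float elapsed_alternate = cost_analysis.GetInstructionElapsed(
//       *instruction, /*operand_in_alternate_mem=*/0);
//   float benefit_in_seconds = elapsed_default - elapsed_alternate;
//   // A prefetch only pays off if the asynchronous copy can be hidden behind
//   // enough independent compute; its duration can be estimated with:
//   float copy_elapsed = cost_analysis.GetAsyncCopyElapsed(
//       instruction->operand(0)->shape());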

// Abstract base class that memory space assignment uses to pick prefetch
// intervals.
class PrefetchIntervalPicker {
 public:
  PrefetchIntervalPicker() = default;
  virtual ~PrefetchIntervalPicker() = default;

  // Returns true if the buffer can be allocated in alternate memory space
  // without any copies (prefetches).
  virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
                                                  int64 start_time,
                                                  int64 end_time) const = 0;

  // Returns the preferred end time for an eviction that starts at a given time
  // and must end by the given end time.
  virtual int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
                                         int64 latest_end_time) const = 0;

  // Begins the iteration over candidate prefetch start times for the given use
  // and time interval.
  virtual void Begin(const HloUse& use, int64 start_time, int64 end_time) = 0;

  // Advances the start time of the prefetch and returns that value.
  virtual int64 Next() = 0;

  // Returns true if the available prefetch intervals have been exhausted.
  virtual bool Done() const = 0;

  // Returns a debug string for the current state of the prefetch interval
  // picker.
  virtual std::string ToDebugString() const = 0;

  // Returns a debug string for a no-copy allocation.
  virtual std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
                                          int64 end_time) const = 0;

 protected:
  const absl::flat_hash_map<const HloInstruction*, int64>*
      instruction_schedule_ = nullptr;
};

// Prefetch interval picker that uses instruction count to overlap asynchronous
// copies with independent computation. The min and max overlap counts describe
// the number of independent HLOs overlapped while a value is being prefetched
// into the alternate memory (between CopyStart and CopyDone HLO instructions).
// max_overlap_count attempts to prevent bringing tensors into the alternate
// memory too eagerly and hence occupying the space for other tensors which
// might use it. min_overlap_count attempts to prevent cases where tensors are
// prefetched into the alternate memory without sufficient time for the copy to
// take place. In those cases, it's better to keep the tensor in the default
// memory instead of hurting the critical path with a copy that likely won't
// finish in time.
class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker {
 public:
  InstructionCountPrefetchIntervalPicker(int64 min_overlap_count,
                                         int64 max_overlap_count)
      : min_overlap_count_(min_overlap_count),
        max_overlap_count_(max_overlap_count) {}

  bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
                                          int64 end_time) const override;

  int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
                                 int64 latest_end_time) const override;

  void Begin(const HloUse& use, int64 start_time, int64 end_time) override;

  int64 Next() override;
  bool Done() const override;

  std::string ToDebugString() const override;
  std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
                                  int64 end_time) const override;

 private:
  int64 min_overlap_count_;
  int64 max_overlap_count_;
  int64 end_time_;
  int64 current_prefetch_time_;
};
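
// Illustrative sketch (not part of the API): the Begin()/Next()/Done()
// iteration protocol, here with an InstructionCountPrefetchIntervalPicker
// configured with made-up bounds. `use`, `start_time`, and `end_time` are
// hypothetical.
//
//   InstructionCountPrefetchIntervalPicker picker(/*min_overlap_count=*/2,
//                                                 /*max_overlap_count=*/10);
//   for (picker.Begin(use, start_time, end_time); !picker.Done();) {
//     int64 prefetch_start_time = picker.Next();
//     // Try scheduling an asynchronous copy whose CopyStart is at
//     // prefetch_start_time; stop at the first candidate that fits.
//   }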

// Prefetch interval picker that uses cost analysis to overlap asynchronous
// copies with independent computation. It uses min/max ratios of (asynchronous
// copy duration) / (independent computation duration) to decide whether a
// prefetch is within bounds. Begin() starts at the maximum allowed ratio
// (earliest prefetch), and each Next() call moves toward later and later
// prefetches until hitting the minimum ratio, in order not to hurt the
// critical path.
class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
 public:
  CostAnalysisPrefetchIntervalPicker(
      const MemorySpaceAssignmentCostAnalysis& cost_analysis,
      float min_async_copy_to_overlap_ratio,
      float max_async_copy_to_overlap_ratio);

  bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
                                          int64 end_time) const override;

  int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
                                 int64 latest_end_time) const override;

  void Begin(const HloUse& use, int64 start_time, int64 end_time) override;

  int64 Next() override;
  bool Done() const override;

  std::string ToDebugString() const override;
  std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
                                  int64 end_time) const override;

 private:
  // Returns the elapsed time in seconds of the logical interval
  // [start_time, end_time) according to the instruction schedule.
  float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const;

  // For performance reasons, we calculate the prefix sum of the elapsed times
  // so that the elapsed time of any logical interval can be found efficiently.
  std::vector<float> elapsed_time_cumsum_;

  const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
  float min_async_copy_to_overlap_ratio_;
  float max_async_copy_to_overlap_ratio_;

  float async_copy_elapsed_;
  float inst_elapsed_reduction_;
  int64 end_logical_time_;
  int64 current_logical_prefetch_time_;
};
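
// Illustrative sketch (hypothetical values, not a recommendation): a
// cost-analysis based picker whose ratio bounds control how early and how late
// a prefetch may start relative to the duration of the asynchronous copy.
// `cost_analysis` is assumed to be an already-constructed
// MemorySpaceAssignmentCostAnalysis.
//
//   CostAnalysisPrefetchIntervalPicker interval_picker(
//       cost_analysis, /*min_async_copy_to_overlap_ratio=*/1.0,
//       /*max_async_copy_to_overlap_ratio=*/10.0);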

// MemorySpaceAssignment assigns memory spaces (default or alternate) to each
// instruction in the module. It greedily tries to place as many values in the
// alternate memory space as possible. It uses the heap simulator to determine
// the actual allocation offsets of values in the alternate memory space in
// order to account for fragmentation. The default memory space is assumed to
// be large enough to hold the values that could not be placed in the alternate
// memory space.
class MemorySpaceAssignment {
 public:
  using Chunk = HeapSimulator::Chunk;
  using BufferInterval = GlobalDecreasingSizeBestFitHeap::BufferInterval;
  using BufferIntervalCompare =
      GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare;
  using IsAllowedInAlternateMemoryFunction =
      std::function<bool(const HloValue&)>;

  // MemorySpaceAssignment uses a notion of a slow and large default memory
  // space and a fast and small alternate memory space.
  enum class MemorySpace { kDefault, kAlternate };

  // The different options to be passed to the Run() API.
  struct Options {
    // Backend-specific integer value that describes the alternate memory.
    int64 alternate_memory_space = 0;

    // Maximum size of the alternate memory space.
    int64 max_size_in_bytes = 0;

    // Memory alignment of the alternate memory space.
    int64 alignment_in_bytes = 1;

    // If provided, we sort the buffers using this comparison function;
    // otherwise, we use GlobalDecreasingSizeBestFitHeap::kSpatial.
    absl::optional<BufferIntervalCompare> buffer_interval_compare =
        absl::nullopt;

    // This object determines how early and how late prefetches can occur.
    PrefetchIntervalPicker* prefetch_interval_picker = nullptr;

    // Size function for buffer values.
    BufferValue::SizeFunction size_fn;

    // This function can be used to prevent certain HloValues (e.g., based on
    // the opcode) from being placed in the alternate memory.
    IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_fn;

    // Specifies the upper bound for the number of outstanding asynchronous
    // copies; -1 for unlimited.
    int64 max_outstanding_async_copies = -1;

    // If true, tries allocating buffers across sequential calls (kWhile,
    // kCall, and kConditional), e.g., before and inside a while loop body.
    bool allocate_across_sequential_calls = false;

    // If true, verifies the memory space assignment against overlapping
    // buffers.
    bool verify = false;
  };
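
  // Illustrative sketch (hypothetical values): how a backend might populate
  // Options before calling Run(). The size, alignment, and predicate choices
  // below are made up for the example.
  //
  //   MemorySpaceAssignment::Options options;
  //   options.alternate_memory_space = /*backend-specific id=*/1;
  //   options.max_size_in_bytes = 16 * 1024 * 1024;
  //   options.alignment_in_bytes = 64;
  //   options.prefetch_interval_picker = &interval_picker;
  //   options.size_fn = [](const BufferValue& buffer) {
  //     return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  //   };
  //   options.is_allowed_in_alternate_mem_fn = [](const HloValue& value) {
  //     // E.g., disallow placing tuple-shaped values in the alternate memory.
  //     return !value.shape().IsTuple();
  //   };
  //   options.max_outstanding_async_copies = 4;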

  // This class represents an allocation that might either be in the default or
  // the alternate memory. An HloValue might live in multiple different
  // allocations over its lifetime. The lifetimes of the allocations are
  // defined using start_time and end_time, which correspond to instruction
  // indices in the flattened schedule. Each of these allocations might
  // partially overlap with each other. CopyAllocation, defined below,
  // represents asynchronous copies between Allocations.
  //
  // Consider an instruction Foo, and its users Bar and Baz, and the times
  // given in terms of the flattened schedule of the entire module:
  //
  //       Foo:10
  //        /   \
  //    Bar:14   \
  //            Baz:25
  //
  // A valid memory space assignment could be like the following:
  //
  //   Time:         10 ... 14        ...      25
  //                 Foo    Bar                Baz
  //   Alternate     +-------+           +-----+
  //   Default           +---------------------+
  //                      ^   ^          ^     ^
  //                      |   |          |     |
  //                    evict evict  prefetch prefetch
  //                    start  end     start    end
  //
  // This would be represented with:
  //   - Allocation(memory_space=kAlternate, start_time=10, end_time=14)
  //   - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25)
  //   - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25)
  class Allocation {
   public:
    Allocation(HloInstruction* instruction, HloPosition defining_position,
               MemorySpace memory_space, Chunk chunk, int64 start_time,
               int64 end_time)
        : instruction_(instruction),
          defining_position_(defining_position),
          memory_space_(memory_space),
          chunk_(chunk),
          start_time_(start_time),
          end_time_(end_time) {}
    virtual ~Allocation() = default;

    virtual bool is_copy_allocation() const { return false; }

    // Adds a use to this allocation.
    void AddUse(HloUse use);

    // Extends the end time of this allocation.
    void Extend(int64 end_time) { end_time_ = end_time; }

    // After all of the time ranges for the allocations have been assigned,
    // Process morphs the affected instructions to assign the memory spaces and
    // inserts asynchronous copy instructions if necessary.
    virtual Status Process(MemorySpaceAssignment* memory_space_assignment);

    // Returns the instruction that produces this allocation. It might be
    // different than the instruction in defining_position (e.g., a
    // GetTupleElement instruction does not define the buffer).
    virtual HloInstruction* instruction() const { return instruction_; }

    // Returns the defining position for this allocation.
    virtual HloPosition defining_position() const {
      return defining_position_;
    }

    // Returns the time the buffer is first available to be used. For
    // Allocation, this is start_time.
    virtual int64 earliest_available_time() const { return start_time_; }

    const std::vector<HloUse>& uses() const { return uses_; }
    MemorySpace memory_space() const { return memory_space_; }
    Chunk chunk() const { return chunk_; }
    void set_start_time(int64 start_time) { start_time_ = start_time; }
    int64 start_time() const { return start_time_; }
    int64 end_time() const { return end_time_; }

   protected:
    // Descends to the shape_index element of the tuple and replaces that with
    // new_instruction.
    StatusOr<HloInstruction*> ReplaceTupleWith(HloInstruction* new_instruction,
                                               HloInstruction* tuple,
                                               ShapeIndex shape_index);

    HloInstruction* instruction_;
    HloPosition defining_position_;
    std::vector<HloUse> uses_;
    MemorySpace memory_space_;
    Chunk chunk_;
    int64 start_time_;
    int64 end_time_;
  };

  // This class represents an allocation resulting from an asynchronous copy.
  class CopyAllocation : public Allocation {
   public:
    CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
                   Chunk chunk, int64 start_time, int64 end_time,
                   int64 copy_done_schedule_before_time)
        : Allocation(/*instruction=*/nullptr,
                     /*defining_position=*/{nullptr, {}}, memory_space, chunk,
                     start_time, end_time),
          prev_allocation_(prev_allocation),
          copy_start_schedule_after_(start_time),
          copy_done_schedule_before_(copy_done_schedule_before_time) {}

    bool is_copy_allocation() const override { return true; }

    Status Process(MemorySpaceAssignment* memory_space_assignment) override;

    HloInstruction* instruction() const override {
      // Unless explicitly set, the instruction of a copy allocation is
      // retrieved from the previous allocation.
      if (instruction_ != nullptr) {
        return instruction_;
      } else {
        return prev_allocation_.instruction();
      }
    }

    HloPosition defining_position() const override {
      // Unless explicitly set, the defining position of a copy allocation is
      // retrieved from the previous allocation. This is because we don't
      // create new CopyStart/CopyDone instructions until later, and the
      // position should point to the previous (copy or otherwise) allocation's
      // position for the original defining position.
      if (defining_position_.instruction == nullptr) {
        return prev_allocation_.defining_position();
      } else {
        return defining_position_;
      }
    }

    HloInstruction* copy_start() const { return copy_start_; }
    HloInstruction* copy_done() const { return copy_done_; }

    // Returns the time the buffer is first available to be used. For
    // CopyAllocation, this is when the copy ends, which is
    // copy_done_schedule_before.
    int64 earliest_available_time() const override {
      return copy_done_schedule_before_;
    }

    int64 copy_start_schedule_after() const {
      return copy_start_schedule_after_;
    }
    int64 copy_done_schedule_before() const {
      return copy_done_schedule_before_;
    }

    void set_copy_start_schedule_after(int64 copy_start_schedule_after) {
      copy_start_schedule_after_ = copy_start_schedule_after;
    }

   private:
    const Allocation& prev_allocation_;
    // These variables define the scheduling boundaries where CopyStart and
    // CopyDone can be scheduled: the earliest a CopyStart can be scheduled is
    // after copy_start_schedule_after_, and the latest a CopyDone can be
    // scheduled is before copy_done_schedule_before_.
    int64 copy_start_schedule_after_;
    int64 copy_done_schedule_before_;
    HloInstruction* copy_start_;
    HloInstruction* copy_done_;
  };
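
  // Illustrative sketch (hypothetical, mirroring the Foo/Bar/Baz example in
  // the Allocation comment above): the initial alternate-memory allocation
  // followed by an eviction and a prefetch, expressed with the classes above.
  // `foo`, `foo_position`, and `chunk` are assumed to exist, and the chunk
  // values are placeholders.
  //
  //   Allocation alternate_mem_allocation(
  //       foo, foo_position, MemorySpace::kAlternate, chunk,
  //       /*start_time=*/10, /*end_time=*/14);
  //   CopyAllocation eviction(
  //       alternate_mem_allocation, MemorySpace::kDefault, /*chunk=*/Chunk{},
  //       /*start_time=*/12, /*end_time=*/25,
  //       /*copy_done_schedule_before_time=*/14);
  //   CopyAllocation prefetch(
  //       eviction, MemorySpace::kAlternate, chunk, /*start_time=*/22,
  //       /*end_time=*/25, /*copy_done_schedule_before_time=*/25);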

  using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;
  struct ValueAndAllocationSequence {
    const HloValue* value;
    AllocationSequence sequence;
  };
  using AllocationSequenceList = std::vector<ValueAndAllocationSequence>;

  // Runs the MemorySpaceAssignment pass.
  static StatusOr<std::unique_ptr<PresetAssignments>> Run(
      HloModule* module, const HloLiveRange& hlo_live_range,
      const HloAliasAnalysis& alias_analysis, const Options& options);

  // Returns the maximum number of outstanding asynchronous copies in the
  // module.
  static int64 CountMaximumOutstandingAsyncCopies(const HloModule& module);

  static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
      const MemorySpaceAssignmentCostAnalysis& cost_analysis);

  // Verifies that the memory space assignment is free of overlapping buffers
  // and exports the heap simulator trace to be used by buffer_assignment.
  Status VerifyAndExportHeapSimulatorTrace();

 private:
  MemorySpaceAssignment(HloModule* module, Options options,
                        const HloLiveRange& hlo_live_range)
      : module_(module),
        options_(options),
        flattened_instructions_(hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .begin(),
                                hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .end()),
        computations_in_schedule_(),
        preset_assignments_(absl::make_unique<PresetAssignments>()) {
    for (const auto& computation_and_bound :
         hlo_live_range.computation_span_times()) {
      computations_in_schedule_.insert(computation_and_bound.first);
    }
  }

  // Process calls the Process methods of the allocations after the allocations
  // have been finalized.
  Status Process();

  // Process() might have altered the computation graph by inserting kTuple and
  // kGetTupleElement instructions. SimplifyGraph performs a simple DCE and
  // tuple simplification operation (e.g., given GetTupleElement(Tuple(a, b),
  // 1), simply forwards b). Runs to fixed point.
  Status SimplifyGraph();

  // FixSchedule inserts asynchronous copies in the schedule.
  Status FixSchedule();

  // Inserts an instruction into the schedule, making sure its dependencies
  // (operands) are already in the schedule. If not, inserts these operands
  // before the instruction.
  void EnsureInstructionAndOperandsInserted(
      HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
      absl::flat_hash_set<HloInstruction*>* inserted_instructions) const;

  // Schedules asynchronous copies and ensures that the CopyStarts and their
  // corresponding CopyDones follow the same order.
  void ScheduleAsynchronousCopies();

  HloModule* module_;
  Options options_;
  std::vector<HloInstruction*> flattened_instructions_;
  absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
  AllocationSequenceList allocation_sequence_list_;
  std::unique_ptr<PresetAssignments> preset_assignments_;

  // These maps hold vectors of new instructions that need to be scheduled
  // after (or before) the instruction index in the key. FixSchedule uses these
  // maps to modify and fix the schedule.
  absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_after_;
  absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_before_;
};
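
// Illustrative sketch (not a definitive recipe): invoking the pass, assuming
// `module` has a schedule and that `hlo_live_range` and `alias_analysis` have
// already been computed (e.g., via HloLiveRange::Run and
// HloAliasAnalysis::Run), and `options` has been populated as sketched above.
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PresetAssignments> preset_assignments,
//       MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis,
//                                  options));
//   // preset_assignments now holds the chunk offsets chosen for the
//   // alternate memory space.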

// This struct contains mandatory memory assignments at a given time. E.g., an
// input's required memory assignment time would correspond to the definition
// time of the parameter instruction, and an output's time would correspond to
// the time of last use.
struct RequiredMemoryAssignment {
  MemorySpaceAssignment::MemorySpace memory_space;
  int64 time;
};

// A struct representing an asynchronous copy with its logical start and end
// time and its destination memory space.
struct AsynchronousCopy {
  int64 start_time;
  int64 end_time;
  MemorySpaceAssignment::MemorySpace destination;
};

// Compare asynchronous copies such that an earlier start time has the same or
// earlier end time and an earlier end time has the same or earlier start time.
bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b);

// Helper class to enforce asynchronous copy ordering. We only allow
// asynchronous copies that are pipelined: if an asynchronous copy ends earlier
// than another asynchronous copy, it must start at the same time or earlier
// than the other asynchronous copy; and if an asynchronous copy starts earlier
// than another asynchronous copy, it must end at the same time or earlier than
// the other asynchronous copy.
class AsynchronousCopyOrdering {
 public:
  AsynchronousCopyOrdering() = default;

  // Adds an asynchronous copy.
  void AddCopy(const AsynchronousCopy& copy);

  // Returns true if the addition of an asynchronous copy in the given time
  // interval would violate the asynchronous copy ordering. E.g., consider the
  // following scenario:
  //                                    CS          CD
  //   already committed async copy:    +-----------+
  //                  new async copy:      +--------+
  //
  // The new asynchronous copy would violate the ordering guarantee because its
  // copy start is after the already committed asynchronous copy's start while
  // its copy done is before the committed copy's done.
  bool ViolatesOrdering(int64 start_time, int64 end_time) const;

 private:
  // Stores asynchronous copies in a tree set respecting the pipelining order.
  std::set<AsynchronousCopy> ranges_;
};
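
// Illustrative sketch (hypothetical times): how the ordering helper rejects a
// copy that would be nested inside an already committed copy, matching the
// diagram in ViolatesOrdering above.
//
//   AsynchronousCopyOrdering ordering;
//   ordering.AddCopy({/*start_time=*/10, /*end_time=*/30,
//                     MemorySpaceAssignment::MemorySpace::kAlternate});
//   bool violates = ordering.ViolatesOrdering(/*start_time=*/15,
//                                             /*end_time=*/25);
//   // violates is expected to be true: the new copy starts after the
//   // committed one but would finish before it.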

// This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
// maximum size.
class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
 public:
  using MemorySpace = MemorySpaceAssignment::MemorySpace;

  AlternateMemoryBestFitHeap(
      MemorySpaceAssignment::AllocationSequenceList* allocation_sequence_list,
      const MemorySpaceAssignment::Options& options,
      const HloAliasAnalysis& alias_analysis,
      const HloLiveRange& hlo_live_range)
      : GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes),
        allocation_sequence_list_(allocation_sequence_list),
        options_(options),
        alias_analysis_(alias_analysis),
        hlo_live_range_(hlo_live_range) {
    // Override buffer interval compare if provided.
    if (options.buffer_interval_compare) {
      buffer_interval_compare_ = *options.buffer_interval_compare;
    }
  }

  HeapSimulator::Result Finish() override;

 private:
  // Given an allocation sequence, returns the live allocation at time with a
  // preference towards allocations in alternate memory. Returns nullptr if no
  // allocation is alive at that time.
  static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
      const MemorySpaceAssignment::AllocationSequence& allocations,
      int64 time);

  // Returns true if a buffer is required to be in default memory at a
  // particular time. A buffer may be required to be in default memory because
  // it is a parameter in default memory or an output in default memory.
  bool RequiredInDefaultMemory(const HloValue* buffer, int64 time) const;

  // Returns true if this buffer is allowed to be placed in the alternate
  // memory.
  bool IsIntervalAllowedInAlternateMemory(
      const BufferInterval& interval) const;

  // Finds an allocation for the given interval. Internally, it will attempt to
  // find a suitable chunk candidate within the heap size and prefetch interval
  // limits, and append the new allocation(s) to allocations. The new
  // allocations can be in default or alternate memory spaces, or can be
  // prefetches or evictions. Returns true if successful.
  bool FindAllocation(int64 start_time, int64 end_time, int64 last_use_time,
                      int64 latest_prefetch_time,
                      HloPosition defining_position, HloUse use,
                      const HloValue* buffer, int64 size,
                      MemorySpaceAssignment::AllocationSequence* allocations);

  // Tries allocating in alternate memory without any copies. Returns true if
  // successful.
  bool TryAllocatingInAlternateMemoryNoCopy(
      int64 start_time, int64 end_time, int64 last_use_time,
      HloPosition defining_position, HloUse use,
      BufferInterval alternate_mem_interval,
      HloInstruction* non_bitcast_operand,
      MemorySpaceAssignment::AllocationSequence* allocations);

  // For a no-copy allocation, finds the best possible chunk candidate, i.e.,
  // the one with the longest possible availability if no preferred offset is
  // given, or the one at preferred_offset if it is given.
  absl::optional<ChunkCandidate> FindBestNoCopyChunkCandidate(
      int64 end_time, int64 last_use_time,
      absl::optional<int64> preferred_offset,
      BufferInterval* alternate_mem_interval) const;

  // Adds inputs and outputs as required assignments.
  void AddInputAndOutputRequiredAssignments();

  // Returns true if the colocated intervals in the argument are in a parameter
  // or root instruction of the entry computation and are reserved by the user
  // to be in the alternate memory space.
  bool AreIntervalsReservedInAlternateMemory(
      absl::Span<const BufferInterval* const> colocated_intervals) const;

  // Given a buffer interval, returns the colocated intervals. Unlike the
  // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
  // returns the colocated intervals sorted by scheduled time.
  std::vector<const BufferInterval*> GetSortedColocatedIntervals(
      const BufferInterval& interval) const;

  // Since the allocations are recorded in the AllocationSequenceList, we don't
  // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
  // to avoid unnecessarily adding the chunk to the chunk map.
  void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}

  // Returns true if the addition of an asynchronous copy in the given time
  // interval would violate the maximum number of asynchronous copies.
  bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time,
                                             int64 end_time) const;

  // Returns true if the asynchronous copy would violate the pipelining order.
  bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const;

  // Adds an asynchronous copy to the allocations.
  void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
                    MemorySpace memory_space, Chunk chunk, int64 start_time,
                    int64 end_time, int64 copy_done_schedule_before_time,
                    MemorySpaceAssignment::AllocationSequence* allocations);

  // These methods are used to delay committing the chunk candidate until the
  // entire live range of the buffer has been considered.
  void AddToPendingChunks(const BufferInterval& buffer_interval,
                          const ChunkCandidate& chunk_candidate);
  void CommitPendingChunks();

  // Returns the available heap size in the alternate memory.
  int64 available_heap_size() const {
    return options_.max_size_in_bytes - reserved_in_bytes_;
  }

  MemorySpaceAssignment::AllocationSequenceList* allocation_sequence_list_;
  const MemorySpaceAssignment::Options& options_;
  const HloAliasAnalysis& alias_analysis_;
  const HloLiveRange& hlo_live_range_;
  // We use an interval tree to keep track of the number of outstanding
  // asynchronous copies.
  BufferIntervalTree async_copy_interval_tree_;
  AsynchronousCopyOrdering async_copy_ordering_;
  std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
  std::vector<AsynchronousCopy> pending_async_copies_;
  // This map contains required memory assignments for HloValues (e.g., inputs
  // and outputs).
  absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
      required_assignments_;
  // Number of bytes reserved in the alternate memory space.
  int64 reserved_in_bytes_ = 0;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_