/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_

#include <functional>
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"

namespace xla {

namespace memory_space_assignment {

// Forward declaration of Options.
class Options;

// This class contains pre-set assignments determined by memory space
// assignment. It contains two data structures: (1) a chunks vector that maps
// a defining HloPosition to a Chunk (offset and size), and (2) an
// assignment_info vector that maps the memory space to information like its
// allocated size and heap memory trace. If there is only one alternate memory
// space, as is currently the case, there will be one entry in
// assignment_info.
class PresetAssignments {
 public:
  // Contains per-memory-space information like the allocated size and heap
  // simulator trace.
  struct AssignmentInformation {
    int64 size;
    HeapSimulatorTrace heap_simulator_trace;
  };

  PresetAssignments() = default;

  void add_chunk(const HloPosition& position,
                 const HeapSimulator::Chunk& chunk) {
    chunks_.emplace_back(position, chunk);
  }

  void add_scoped_allocation_chunk(HloInstruction* instruction,
                                   const HeapSimulator::Chunk& chunk) {
    scoped_allocation_chunks_.emplace_back(instruction, chunk);
  }

  AssignmentInformation* assignment_information_for_space(
      int64_t memory_space) {
    for (auto& space_and_info : assignment_info_) {
      if (space_and_info.first == memory_space) {
        return &space_and_info.second;
      }
    }
    assignment_info_.emplace_back(memory_space, AssignmentInformation());
    return &assignment_info_.back().second;
  }

  absl::Span<const std::pair<HloPosition, HeapSimulator::Chunk>> chunks()
      const {
    return chunks_;
  }

  absl::Span<const std::pair<HloInstruction*, HeapSimulator::Chunk>>
  scoped_allocation_chunks() const {
    return scoped_allocation_chunks_;
  }

  absl::Span<const std::pair<int64, AssignmentInformation>>
  assignment_informations() const {
    return assignment_info_;
  }
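  // A minimal usage sketch (hypothetical; `position` and `instruction` are
  // placeholders for a real defining HloPosition and HloInstruction):
  //
  //   PresetAssignments preset;
  //   preset.add_chunk(position, HeapSimulator::Chunk{/*offset=*/0,
  //                                                   /*size=*/1024});
  //   auto* info = preset.assignment_information_for_space(/*memory_space=*/1);
  //   info->size += 1024;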
  // Get debugging information.
  std::string buffer_info_str() const { return buffer_info_str_; }
  std::string allocation_info_str() const { return allocation_info_str_; }

 private:
  std::vector<std::pair<HloPosition, HeapSimulator::Chunk>> chunks_;
  std::vector<std::pair<HloInstruction*, HeapSimulator::Chunk>>
      scoped_allocation_chunks_;
  std::vector<std::pair<int64, AssignmentInformation>> assignment_info_;
  std::string buffer_info_str_;
  std::string allocation_info_str_;
};

// A wrapper class around HloCostAnalysis with additional knowledge about the
// bandwidths of different memory spaces.
class MemorySpaceAssignmentCostAnalysis {
 public:
  // An optional Cache object may be provided to some of the methods below to
  // speed up the lookup.
  struct Cache {
    absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
  };

  virtual ~MemorySpaceAssignmentCostAnalysis() = default;

  static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
      const HloCostAnalysis& cost_analysis, const Options& options,
      const HloModule& module);

  const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }

  // Returns a heuristic value that captures how much placing this tensor in
  // the alternate memory would help if the op is memory bound, or otherwise
  // how far off the op is from memory boundedness. The larger this number,
  // the higher the priority for placement in the alternate memory.
  float GetAlternateMemoryBenefit(const HloInstruction& instruction,
                                  float elapsed_time_due_to_alternate_mem,
                                  Cache* cache = nullptr) const;

  // Returns a heuristic value of memory boundedness for the given
  // BufferInterval. The larger this number, the higher the priority for
  // placement in the alternate memory.
  float GetMemoryBoundedness(
      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval&
          interval,
      Cache* cache = nullptr) const;

  // Returns the elapsed time in seconds due to compute only.
  float GetInstructionElapsedDueToCompute(
      const HloInstruction& instruction) const;

  // Returns the elapsed time in seconds due to memory only. If
  // operands_in_alternate_mem or outputs_in_alternate_mem is provided, it
  // will assume that the corresponding operands or output will be in the
  // alternate memory space. This is useful for calculating the benefit of
  // placing the buffer in alternate memory.
  float GetInstructionElapsedDueToMemory(
      const HloInstruction& instruction,
      absl::Span<const std::pair<int64_t, ShapeIndex>>
          operands_in_alternate_mem = {},
      absl::Span<const ShapeIndex> outputs_in_alternate_mem = {}) const;

  // Returns the estimated elapsed duration of the instruction in seconds. It
  // assumes all operands and outputs of the instruction are in the default
  // memory.
  virtual float GetInstructionElapsed(const HloInstruction& instruction) const;
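  // As a rough mental model (an assumption for illustration, not necessarily
  // the exact implementation), the estimates above can be combined as:
  //
  //   GetInstructionElapsed(instr) ~=
  //       max(GetInstructionElapsedDueToCompute(instr),
  //           GetInstructionElapsedDueToMemory(instr))
  //
  // i.e., an op is memory bound when the memory term dominates, which is
  // precisely when moving its operands/outputs to the faster alternate memory
  // pays off.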
  // Returns the estimated elapsed duration of the instruction in seconds. It
  // assumes all operands and outputs of the instruction are in the default
  // memory, except for the operands and outputs specified to be in the
  // alternate memory.
  virtual float GetInstructionElapsedInAlternateMemory(
      const HloInstruction& instruction,
      absl::Span<const std::pair<int64_t, ShapeIndex>>
          operands_in_alternate_mem,
      absl::Span<const ShapeIndex> outputs_in_alternate_mem) const;

  // Returns the elapsed time it would take to asynchronously copy the shape
  // from default to alternate memory space (or vice versa).
  virtual float GetAsyncCopyElapsed(const Shape& shape) const;

  int64 GetScheduleEndTime() const;

  // Returns the number of nested computation levels this instruction resides
  // in. If while_only is true, it returns the while loop nest level instead,
  // where 0 means the instruction is not in a while loop.
  int CalculateComputationNestLevel(const HloInstruction* instruction,
                                    bool while_only) const;

  const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
  const Options& options() const { return options_; }

 protected:
  MemorySpaceAssignmentCostAnalysis(
      const HloCostAnalysis& cost_analysis, const Options& options,
      std::unique_ptr<HloAliasAnalysis> alias_analysis,
      std::unique_ptr<HloLiveRange> hlo_live_range,
      std::unique_ptr<CallGraph> call_graph)
      : cost_analysis_(cost_analysis),
        options_(options),
        alias_analysis_(std::move(alias_analysis)),
        hlo_live_range_(std::move(hlo_live_range)),
        call_graph_(std::move(call_graph)) {}

 private:
  const HloCostAnalysis& cost_analysis_;
  const Options& options_;
  std::unique_ptr<HloAliasAnalysis> alias_analysis_;
  std::unique_ptr<HloLiveRange> hlo_live_range_;
  std::unique_ptr<CallGraph> call_graph_;
};

// Abstract base class that memory space assignment uses to pick prefetch
// intervals.
class PrefetchIntervalPicker {
 public:
  PrefetchIntervalPicker() = default;
  virtual ~PrefetchIntervalPicker() = default;

  // Returns true if the buffer can be allocated in alternate memory space
  // without any copies (prefetches).
  virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
                                                  int64_t start_time,
                                                  int64_t end_time) const = 0;

  // Returns the preferred end time for an eviction that starts at a given
  // time and must end by the given end time.
  virtual int64 PreferredEvictionEndTime(const Shape& shape,
                                         int64_t start_time,
                                         int64_t latest_end_time) const = 0;

  // Returns the latest time that a prefetch can start.
  virtual int64 LatestPrefetchStartTime(const Shape& shape, int64_t start_time,
                                        int64_t end_time,
                                        const HloUse* use) const = 0;

  // Returns the preferred time that a prefetch can start.
  virtual int64 PreferredPrefetchStartTime(
      const Shape& shape, int64_t earliest_prefetch_start_time,
      int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const = 0;

  // Returns the latest time that a prefetch can end that is less than or
  // equal to proposed_prefetch_end_time.
  virtual int64 LatestPrefetchEndTime(
      int64_t original_prefetch_end_time,
      int64_t proposed_prefetch_end_time) const {
    return proposed_prefetch_end_time;
  }

  // Begins the iterator for the first start time of the prefetch.
  virtual void Begin(const HloUse& use, int64_t start_time,
                     int64_t end_time) = 0;

  // Advances the start time of the prefetch and returns that value.
  virtual int64 Next() = 0;

  // Returns true if the available prefetch intervals have been exhausted.
  virtual bool Done() const = 0;

  // The retry number can be used to modify the interval picking policies. The
  // first attempt will have a retry_number of 0, then 1, etc.
  virtual void SetRetryNumber(int retry_number) {}

  // Returns a debug string for the current state of the prefetch interval
  // picker.
  virtual std::string ToDebugString() const = 0;

  // Returns a debug string for no-copy allocation.
  virtual std::string ToNoCopyDebugString(const Shape& shape,
                                          int64_t start_time,
                                          int64_t end_time) const = 0;

  // Prefetch interval pickers may return a value corresponding to the benefit
  // of placing the BufferInterval in the alternate memory. The larger the
  // value, the more beneficial.
  virtual absl::optional<float> BufferIntervalAlternateMemoryBenefit(
      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval&
          interval) const {
    return absl::nullopt;
  }

 protected:
  const absl::flat_hash_map<const HloInstruction*, int64>*
      instruction_schedule_ = nullptr;
};

// Prefetch interval picker that uses instruction count to overlap
// asynchronous copies with independent computation. The min and max overlap
// counts describe the number of independent HLOs overlapped while a value is
// being prefetched into the alternate memory (between CopyStart and CopyDone
// HLO instructions). max_overlap_count attempts to prevent bringing tensors
// into the alternate memory too eagerly and hence occupying the space for
// other tensors which might use it. min_overlap_count attempts to prevent
// cases where tensors are prefetched into the alternate memory without
// sufficient time for the copy to take place. In those cases, it's better to
// keep the tensor in the default memory instead of hurting the critical path
// with a copy that likely won't finish in time.
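//
// For illustration, a caller might drive any PrefetchIntervalPicker
// (including the two concrete pickers below) roughly like this hypothetical
// sketch, where `use`, `start`, and `end` are placeholders:
//
//   picker->Begin(use, start, end);
//   while (!picker->Done()) {
//     int64 prefetch_start_time = picker->Next();
//     // ... try to commit an async copy starting at prefetch_start_time,
//     // and stop iterating on success ...
//   }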
class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker {
 public:
  InstructionCountPrefetchIntervalPicker(int64_t min_overlap_count,
                                         int64_t max_overlap_count)
      : min_overlap_count_(min_overlap_count),
        max_overlap_count_(max_overlap_count) {}

  bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
                                          int64_t start_time,
                                          int64_t end_time) const override;

  int64 PreferredEvictionEndTime(const Shape& shape, int64_t start_time,
                                 int64_t latest_end_time) const override;

  int64 LatestPrefetchStartTime(const Shape& shape, int64_t start_time,
                                int64_t end_time,
                                const HloUse* use) const override;

  int64 PreferredPrefetchStartTime(const Shape& shape,
                                   int64_t earliest_prefetch_start_time,
                                   int64_t latest_prefetch_start_time,
                                   int64_t prefetch_end_time) const override;

  void Begin(const HloUse& use, int64_t start_time, int64_t end_time) override;

  int64 Next() override;
  bool Done() const override;

  std::string ToDebugString() const override;
  std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time,
                                  int64_t end_time) const override;

 private:
  int64 min_overlap_count_;
  int64 max_overlap_count_;
  int64 end_time_;
  int64 current_prefetch_time_;
};

// Forward declaration of MemorySpaceAssignmentCostAnalysis.
class MemorySpaceAssignmentCostAnalysis;

// Prefetch interval picker that uses cost analysis to overlap asynchronous
// copies with independent computation. It uses min/max (asynchronous copy
// duration) / (independent computation duration) ratios to guide whether the
// prefetch is within those bounds. It starts with the preferred ratio in
// Begin() and works its way through alternately earlier and later prefetches
// until hitting the min and max ratios. The buffer size for max async copy is
// a mechanism to prevent copying small buffers between the two memories
// unnecessarily. For calculating the max time that the buffer can reside in
// alternate memory, we use the larger of this value and the actual size of
// the buffer.
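//
// A worked instance with illustrative numbers only: a prefetch whose copy
// takes 10us overlapped with 25us of independent computation has a ratio of
// 10/25 = 0.4. Begin() proposes the start time whose ratio is closest to the
// preferred ratio, and Next() alternates between earlier and later start
// times for as long as the resulting ratio stays within the [min, max]
// bounds.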
class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
 public:
  CostAnalysisPrefetchIntervalPicker(
      const MemorySpaceAssignmentCostAnalysis& cost_analysis,
      float min_async_copy_to_overlap_ratio,
      float max_async_copy_to_overlap_ratio,
      float preferred_async_copy_to_overlap_ratio,
      int64_t buffer_size_for_max_async_copy);

  bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
                                          int64_t start_time,
                                          int64_t end_time) const override;

  int64 PreferredEvictionEndTime(const Shape& shape, int64_t start_time,
                                 int64_t latest_end_time) const override;

  int64 LatestPrefetchEndTime(
      int64_t original_prefetch_end_time,
      int64_t proposed_prefetch_end_time) const override;

  int64 LatestPrefetchStartTime(const Shape& shape, int64_t start_time,
                                int64_t end_time,
                                const HloUse* use) const override;

  int64 PreferredPrefetchStartTime(const Shape& shape,
                                   int64_t earliest_prefetch_start_time,
                                   int64_t latest_prefetch_start_time,
                                   int64_t prefetch_end_time) const override;

  void Begin(const HloUse& use, int64_t start_time, int64_t end_time) override;

  int64 Next() override;
  bool Done() const override;

  void SetRetryNumber(int retry_number) override;

  std::string ToDebugString() const override;
  std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time,
                                  int64_t end_time) const override;

  absl::optional<float> BufferIntervalAlternateMemoryBenefit(
      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval&
          interval) const override;

 private:
  // Returns the elapsed time in seconds of the logical interval
  // [start_time, end_time] in the instruction schedule.
  float GetLogicalIntervalElapsed(int64_t start_time, int64_t end_time) const;

  // Finds the minimum nest level in the given interval.
  int GetMinWhileNestLevel(int64_t start_time, int64_t end_time) const;

  // Given the elapsed time to copy this buffer to the alternate memory,
  // returns the longest time that this buffer may reside in the alternate
  // memory space.
  float GetMaxElapsedInAlternateMemory(float async_copy_elapsed) const;

  // For each instruction in the flattened schedule, maintain their elapsed
  // time (as a cumulative sum) and while nesting level.
  std::vector<float> elapsed_time_cumsum_;
  std::vector<int> while_nest_level_;
  std::vector<int> computation_nest_level_;
  // Maintain the index of the most recent (before this instruction) nest
  // level change in order to efficiently determine the minimum nest level in
  // an interval.
  std::vector<int> while_nest_level_change_;

  const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
  float min_async_copy_to_overlap_ratio_;
  float max_async_copy_to_overlap_ratio_;
  float preferred_async_copy_to_overlap_ratio_;
  int64_t buffer_size_for_max_async_copy_;
  float max_overlap_multiplier_ = 1.0;

  float async_copy_elapsed_;
  float inst_elapsed_reduction_;
  int64 end_logical_time_;
  int64 earliest_prefetch_time_;
  int64 latest_prefetch_time_;
  bool using_increasing_prefetch_time_iterator_ = true;
  int64 increasing_prefetch_time_iterator_;
  int64 decreasing_prefetch_time_iterator_;
};

// MemorySpaceAssignment assigns memory spaces (default or alternate) to each
// instruction in the module. It will greedily try placing as many values in
// the alternate memory space as possible.
// It uses the heap simulator to determine the actual allocation offsets of
// values in the alternate memory space to account for fragmentation. The
// default memory space is assumed to be large enough to hold the values that
// could not be placed in the alternate memory space.
class MemorySpaceAssignment {
 public:
  using Chunk = HeapSimulator::Chunk;
  using BufferInterval =
      GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval;
  using BufferIntervalCompare =
      GlobalDecreasingSizeBestFitHeap<HloValue>::BufferIntervalCompare;
  using IsAllowedInAlternateMemoryFunction =
      std::function<bool(const HloValue&)>;
  using IsUseAllowedInAlternateMemoryFunction =
      std::function<bool(const HloUse&)>;
  using ReservedScopedMemoryFunction =
      std::function<int64_t(const HloInstruction*)>;

  // MemorySpaceAssignment uses a notion of a slow and large default memory
  // space and a fast and small alternate memory space.
  enum class MemorySpace { kDefault, kAlternate };

  // Forward declarations for Allocation and ParentAllocation.
  class Allocation;
  class ParentAllocation;

  // This class represents an allocation that might either be in the default
  // or alternate memory. An HloValue might live in multiple different
  // allocations over its lifetime. The lifetimes of the allocations are
  // defined using start_time and end_time, which correspond to the
  // instruction indexes in the flattened schedule. Each of these allocations
  // might partially overlap with each other. CopyAllocation defined below
  // represents asynchronous copies between Allocations.
  //
  // Consider an instruction Foo, and its users Bar and Baz, and the times
  // given in terms of the flattened schedule of the entire module:
  //
  //      Foo:10
  //       /   \
  //    Bar:14  \
  //           Baz:25
  //
  // A valid memory space assignment could be like the following:
  //
  //  Time:         10 ... 14        ...      25
  //                Foo    Bar                Baz
  //  Alternate     +-------+           +-----+
  //  Default           +---------------------+
  //                    ^   ^           ^     ^
  //                    |   |           |     |
  //                evict   evict  prefetch  prefetch
  //                start    end    start      end
  //
  // This would be represented with:
  //  - Allocation(memory_space=kAlternate, start_time=10, end_time=14)
  //  - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25)
  //  - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25)
  class Allocation {
    friend class ParentAllocation;

   public:
    Allocation(HloPosition defining_position, MemorySpace memory_space,
               absl::optional<Chunk> chunk, int64_t start_time,
               int64_t end_time, bool is_scoped_allocation)
        : defining_position_(defining_position),
          memory_space_(memory_space),
          chunk_(chunk),
          start_time_(start_time),
          end_time_(end_time),
          is_scoped_allocation_(is_scoped_allocation) {
      CHECK(!is_scoped_allocation ||
            defining_position.index == ShapeIndex({}));
    }
    virtual ~Allocation() = default;

    virtual bool is_copy_allocation() const { return false; }

    // Adds a use to this allocation.
    void AddUse(HloUse use);

    // Extends the end time of this allocation.
    void Extend(int64_t end_time) { end_time_ = end_time; }

    // After all of the time ranges for the allocations have been assigned,
    // Process morphs the affected instructions to assign the memory spaces
    // and inserts asynchronous copy instructions if necessary.
    virtual Status Process();

    // An optional post-process step that will be called after all allocations
    // have been processed.
    virtual Status PostProcess() { return Status::OK(); }

    // Marks (adds this allocation to needed_allocations) if this allocation
    // is needed. Allocations and CopyAllocations are always needed, and
    // ParentAllocations are needed if they have any uses or if other
    // CopyAllocations or ParentAllocations depend on them.
    virtual void MarkIfNeeded(
        absl::flat_hash_set<const Allocation*>& needed_allocations) const;

    // Marks this allocation as needed.
    virtual void MarkNeeded(
        absl::flat_hash_set<const Allocation*>& needed_allocations) const;

    // Returns the defining position for this allocation.
    virtual HloPosition defining_position() const {
      return defining_position_;
    }

    // Returns the time the buffer is first available to be used. For
    // Allocation, this is start_time.
    virtual int64 earliest_available_time() const { return start_time_; }

    const std::vector<HloUse>& uses() const { return uses_; }
    MemorySpace memory_space() const { return memory_space_; }
    Chunk chunk() const { return *chunk_; }
    Chunk* mutable_chunk() { return &*chunk_; }
    void set_start_time(int64_t start_time) { start_time_ = start_time; }
    int64 start_time() const { return start_time_; }
    int64 end_time() const { return end_time_; }
    bool is_scoped_allocation() const { return is_scoped_allocation_; }

    bool operator==(const Allocation& other) const;
    virtual std::string ToString() const;

   protected:
    // Descend to the shape_index element of the tuple and replace that with
    // new_instruction.
    StatusOr<HloInstruction*> ReplaceTupleWith(HloInstruction* new_instruction,
                                               HloInstruction* tuple,
                                               ShapeIndex shape_index);

    // Recursively create kGetTupleElement instructions if the defining
    // position shape is not an array. Returns the new instruction that has
    // array shape.
    HloInstruction* AddGetTupleElements() const;

    HloPosition defining_position_;
    std::vector<HloUse> uses_;
    MemorySpace memory_space_;
    absl::optional<Chunk> chunk_;
    int64 start_time_;
    int64 end_time_;
    const bool is_scoped_allocation_;
  };

  // This class represents an allocation as a result of an asynchronous copy.
  // Note: CopyStart instructions are inserted after `start_time` or later,
  // while CopyDone instructions are inserted before
  // `copy_done_schedule_before_time` or earlier.
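  //
  // For example, a prefetch from default to alternate memory materializes as
  // an asynchronous copy pair in HLO, roughly like the following illustrative
  // snippet (the shapes and the S(1) memory-space layout annotation will vary
  // by backend):
  //
  //   copy-start = (f32[4,3]{1,0:S(1)}, f32[4,3]{1,0}, u32[])
  //                  copy-start(operand)
  //   ... independent computation overlapped with the copy ...
  //   copy-done = f32[4,3]{1,0:S(1)} copy-done(copy-start)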
  class CopyAllocation : public Allocation {
   public:
    CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
                   absl::optional<Chunk> chunk, int64_t start_time,
                   int64_t end_time, int64_t copy_done_schedule_before_time,
                   bool is_cross_program_prefetch = false)
        : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
                     start_time, end_time, /*is_scoped_allocation=*/false),
          prev_allocation_(prev_allocation),
          copy_start_schedule_after_(start_time),
          copy_done_schedule_before_(copy_done_schedule_before_time),
          is_cross_program_prefetch_(is_cross_program_prefetch) {}

    bool is_copy_allocation() const override { return true; }

    Status Process() override;

    void MarkNeeded(absl::flat_hash_set<const Allocation*>&
                        needed_allocations) const override;

    HloPosition defining_position() const override {
      // Unless explicitly set, the defining position of a copy allocation is
      // retrieved from the previous allocation. This is because we don't
      // create new CopyStart/CopyDone instructions until later, and the
      // position should point to the previous (copy or otherwise)
      // allocation's position for the original defining position.
      if (defining_position_.instruction == nullptr) {
        return prev_allocation_.defining_position();
      }
      return defining_position_;
    }

    HloInstruction* copy_start() const { return copy_start_; }
    HloInstruction* copy_done() const { return copy_done_; }

    // Returns the time the buffer is first available to be used. For
    // CopyAllocation, this is when the copy ends, which is
    // copy_done_schedule_before.
    int64 earliest_available_time() const override {
      return copy_done_schedule_before_;
    }

    int64 copy_start_schedule_after() const {
      return copy_start_schedule_after_;
    }
    int64 copy_done_schedule_before() const {
      return copy_done_schedule_before_;
    }

    void set_copy_start_schedule_after(int64_t copy_start_schedule_after) {
      copy_start_schedule_after_ = copy_start_schedule_after;
    }

    bool is_cross_program_prefetch() const {
      return is_cross_program_prefetch_;
    }

    bool operator==(const CopyAllocation& other) const;
    std::string ToString() const override;

   private:
    const Allocation& prev_allocation_;
    // These variables define the scheduling boundaries where CopyStart and
    // CopyDone can be scheduled. The earliest a CopyStart can be scheduled is
    // after copy_start_schedule_after_ and the latest a CopyDone can be
    // scheduled is before copy_done_schedule_before_.
    int64 copy_start_schedule_after_;
    int64 copy_done_schedule_before_;
    bool is_cross_program_prefetch_;
    HloInstruction* copy_start_;
    HloInstruction* copy_done_;
  };

  // An allocation in default memory space that is defined in the parent
  // computation. If a value has a copy in the default memory space in the
  // parent computation, we don't need to evict this buffer in a while loop.
  class ParentAllocation : public Allocation {
   public:
    ParentAllocation(const Allocation& original_allocation,
                     HloInstruction* calling_instruction, HloPosition position,
                     int64_t time)
        : Allocation(position, MemorySpace::kDefault,
                     original_allocation.chunk(), /*start_time=*/time,
                     /*end_time=*/time, /*is_scoped_allocation=*/false),
          original_allocation_(original_allocation),
          calling_instruction_(calling_instruction) {}

    Status Process() override;
    Status PostProcess() override;

    void MarkIfNeeded(absl::flat_hash_set<const Allocation*>&
                          needed_allocations) const override;
    void MarkNeeded(absl::flat_hash_set<const Allocation*>&
                        needed_allocations) const override;

    std::string ToString() const override;

   private:
    const Allocation& original_allocation_;
    HloInstruction* calling_instruction_;
  };

  using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;

  // AllocationValue is used to break up HloValues for each non-trivial
  // position (trivial positions are considered Tuple, GetTupleElement, and
  // Bitcast). An HloValue may include positions and uses that alias with each
  // other across multiple computations. We use this class to break these
  // HloValues such that every AllocationValue has one defining position (that
  // may alias with other AllocationValues). The uses field of the
  // AllocationValue contains only the direct uses of the AllocationValue's
  // defining position.
  //
  // For example, consider the following HLO snippet:
  //
  //   Body {
  //     body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
  //     get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element(body_param),
  //     index=0
  //     ...
  //     ROOT tuple = (f32[4,3]{1,0}, f32[]) tuple(get-tuple-element.3, ...)
  //   }
  //
  //   Cond {
  //     cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
  //     ...
  //   }
  //
  //   add.4 = f32[4,3]{1,0} add(...)
  //   tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(add.4, ...)
  //   while = (f32[4,3]{1,0}, f32[]) while(tuple.1), body=Body, condition=Cond
  //   get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element(while), index=0
  //   add.5 = f32[4,3]{1,0} add(get-tuple-element.5, ...)
  //
  // This contains an HloValue that looks like the following:
  // positions:
  //  add.4
  //  body_param {0}
  //  get-tuple-element.3
  //  tuple {0}
  //  cond_param {0}
  //  tuple.1 {0}
  //  while {0}
  //  get-tuple-element.5
  // uses:
  //  add.1, operand 0
  //  tuple, operand 0
  //  while, operand 0 {0}
  //  add.5, operand 0
  //
  // We break this HloValue up into the following AllocationValues for each
  // non-trivial position:
  // AllocationValue1: computation = Entry
  //  position:
  //   add.4
  //  uses:
  //   while, operand 0 {0}
  // AllocationValue2: computation = Cond
  //  position:
  //   cond_param {0}
  //  uses:
  // AllocationValue3: computation = Body
  //  position:
  //   body_param {0}
  //  uses:
  //   add.1, operand 0
  //   tuple, operand 0
  // AllocationValue4: computation = Entry
  //  position:
  //   while {0}
  //  uses:
  //   add.5, operand 0
  class AllocationValue {
   public:
    // This data structure wraps an HloUse and adds additional metadata that
    // are useful for allocation.
    struct Use {
      // The wrapped HloUse object.
      HloUse hlo_use;
      // The logical time this use is scheduled.
      int64 time;
      // All the positions that this use aliases with. The aliased positions
      // must get the same allocation.
      std::vector<HloPosition> aliases;

      bool operator==(const Use& other) const {
        return hlo_use == other.hlo_use && time == other.time &&
               aliases == other.aliases;
      }

      template <typename H>
      friend H AbslHashValue(H h, const Use& s) {
        return H::combine(std::move(h), s.hlo_use, s.time, s.aliases);
      }
    };

    AllocationValue(const HloValue* value, const HloPosition& position,
                    int64_t size)
        : value_(value),
          defining_position_(position),
          size_(size),
          requires_contiguous_allocation_(false) {}

    const HloPosition& defining_position() const { return defining_position_; }
    const HloInstruction* defining_instruction() const {
      return defining_position().instruction;
    }
    int64 size() const { return size_; }
    const std::vector<Use>& uses() const { return uses_; }
    std::vector<Use>& uses() { return uses_; }
    const HloValue* value() const { return value_; }
    const HloComputation* computation() const {
      return defining_instruction()->parent();
    }
    AllocationSequence* allocation_sequence() { return &allocation_sequence_; }

    // Sets/gets whether this AllocationValue requires allocating it
    // contiguously throughout its live range (without any copies).
    bool requires_contiguous_allocation() const {
      return requires_contiguous_allocation_;
    }
    void set_requires_contiguous_allocation(
        bool requires_contiguous_allocation) {
      requires_contiguous_allocation_ = requires_contiguous_allocation;
    }

    void AddUse(const HloUse& use, int64_t use_time) {
      uses_.push_back({use, use_time, {}});
    }

    std::string ToString() const;
    std::string ToShortString() const;

   private:
    const HloValue* value_;
    HloPosition defining_position_;
    int64 size_;
    // If true, there must be a contiguous allocation for this buffer without
    // any copies.
    bool requires_contiguous_allocation_;
    std::vector<Use> uses_;
    AllocationSequence allocation_sequence_;
  };

  // Statistics of asynchronous copies.
  struct AsyncCopyStats {
    int64 max_outstanding_async_copies;
    int64 num_prefetches;
    int64 prefetch_bytes;
    int64 num_evictions;
    int64 eviction_bytes;
  };

  virtual ~MemorySpaceAssignment() = default;

  // Runs the MemorySpaceAssignment pass.
  static StatusOr<std::unique_ptr<PresetAssignments>> Run(
      HloModule* module, const HloLiveRange& hlo_live_range,
      const HloAliasAnalysis& alias_analysis, const Options& options);

  // Calculates asynchronous copy statistics.
  StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;

  static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
      const MemorySpaceAssignmentCostAnalysis& cost_analysis,
      MemorySpaceAssignmentCostAnalysis::Cache* cache = nullptr);
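  // A caller would typically invoke the pass roughly as follows (hypothetical
  // sketch; constructing `options`, `hlo_live_range`, and `alias_analysis` is
  // elided):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       std::unique_ptr<PresetAssignments> preset_assignments,
  //       MemorySpaceAssignment::Run(module, hlo_live_range, alias_analysis,
  //                                  options));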
  // Verify that the memory space assignment is free of overlapping buffers
  // and export the heap simulator trace to be used by buffer_assignment.
  Status VerifyAndExportHeapSimulatorTrace();

 protected:
  // Main driver of the memory space assignment pass.
  virtual StatusOr<std::unique_ptr<PresetAssignments>> RunMemorySpaceAssignment(
      const HloLiveRange& hlo_live_range,
      const HloAliasAnalysis& alias_analysis);

  // Finds an AllocationSequence for placing buffers in alternate memory using
  // the AlternateMemoryBestFitHeap algorithm. Must be called before
  // Process().
  virtual Status FindAllocationSequence(const HloLiveRange& hlo_live_range,
                                        const HloAliasAnalysis& alias_analysis);

  const Options& options() const { return options_; }

  MemorySpaceAssignment(HloModule* module, const Options& options,
                        const HloLiveRange& hlo_live_range)
      : module_(module),
        options_(options),
        flattened_instructions_(hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .begin(),
                                hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .end()),
        computations_in_schedule_(),
        preset_assignments_(absl::make_unique<PresetAssignments>()) {
    for (const auto& computation_and_bound :
         hlo_live_range.computation_span_times()) {
      computations_in_schedule_.insert(computation_and_bound.first);
    }
  }

  AllocationSequence allocations_;

  HloModule* module() { return module_; }

 private:
  // Process calls the Process methods of the allocations after the
  // allocations have been finalized.
  Status Process();

  // Process() might have altered the computation graph by inserting kTuple
  // and kGetTupleElement instructions. SimplifyGraph performs a simple DCE
  // and tuple simplification operation (e.g., given
  // GetTupleElement(Tuple(a, b), 1), simply forwards b). Runs to fixed point.
  Status SimplifyGraph();

  // FixSchedule inserts asynchronous copies in the schedule.
  Status FixSchedule();

  // Export the alternate memory assignments to the PresetAssignments and
  // color the HLO graph with the determined memory spaces.
  Status ExportAndColorBuffers();

  // Inserts an instruction into the schedule, and makes sure its dependencies
  // (operands) are already in the schedule. If not, inserts these operands
  // before the instruction.
  void EnsureInstructionAndOperandsInserted(
      HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
      absl::flat_hash_set<HloInstruction*>* inserted_instructions) const;

  // Schedules asynchronous copies and ensures that the CopyStarts and their
  // corresponding CopyDones follow the same order.
  void ScheduleAsynchronousCopies();

  // Removes the positions and chunks associated with the instruction from
  // alternate_memory_assignments_.
  void RemoveAssignmentForInstruction(const HloInstruction* instruction);
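  // Taken together, the steps above run in roughly this order inside the pass
  // (a sketch inferred from the comments in this header, not a binding
  // contract):
  //
  //   TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range,
  //                                             alias_analysis));
  //   TF_RETURN_IF_ERROR(Process());        // morph HLOs, insert async copies
  //   TF_RETURN_IF_ERROR(SimplifyGraph());  // DCE + tuple forwarding
  //   TF_RETURN_IF_ERROR(FixSchedule());    // splice new copies into schedule
  //   TF_RETURN_IF_ERROR(ExportAndColorBuffers());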
  // Returns the estimated elapsed duration of the HLO module in seconds. It
  // uses the 'allocations' argument to determine the location (default memory
  // or alternate memory) of each operand and output of an instruction.
  float ComputeEstimatedElapsedTime(const HloLiveRange& hlo_live_range,
                                    const AllocationSequence& allocations);

  HloModule* module_;
  const Options& options_;
  std::vector<HloInstruction*> flattened_instructions_;
  absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
  std::unique_ptr<PresetAssignments> preset_assignments_;
  std::vector<std::pair<HloPosition, Chunk>> alternate_memory_assignments_;
  std::vector<std::pair<HloInstruction*, Chunk>> scoped_memory_assignments_;
  int64 alternate_memory_size_ = 0;

  // These maps hold vectors of new instructions that need to be scheduled
  // after (or before) the instruction index in the key. FixSchedule uses
  // these maps to modify and fix the schedule.
  absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_after_;
  absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_before_;
};

// The different options to be passed to the Run() API.
struct Options {
  // Backend-specific integer value that describes the alternate memory.
  int64 alternate_memory_space = 0;

  // Maximum size of the alternate memory space.
  int64 max_size_in_bytes = 0;

  // Memory alignment of the alternate memory space.
  int64 alignment_in_bytes = 1;

  // If provided, we sort the buffers using this comparison function;
  // otherwise, we use GlobalDecreasingSizeBestFitHeap::kSpatial.
  absl::optional<MemorySpaceAssignment::BufferIntervalCompare>
      buffer_interval_compare = absl::nullopt;

  // This object determines how early and how late prefetches can occur.
  PrefetchIntervalPicker* prefetch_interval_picker = nullptr;

  // This object is used to determine the benefit of a particular allocation.
  MemorySpaceAssignmentCostAnalysis* cost_analysis = nullptr;

  // Size function for buffer values.
  BufferValue::SizeFunction size_fn;

  // This function can be used to prevent certain HloValues (e.g., based on
  // the opcode) from being placed in the alternate memory.
  MemorySpaceAssignment::IsAllowedInAlternateMemoryFunction
      is_allowed_in_alternate_mem_fn;

  // This function can be used to prevent certain HloUses (e.g., based on the
  // opcode) from being placed in the alternate memory.
  MemorySpaceAssignment::IsUseAllowedInAlternateMemoryFunction
      is_use_allowed_in_alternate_mem_fn = [](const HloUse&) { return true; };

  // This function returns the amount of scoped memory in bytes that should be
  // reserved during the execution of this instruction.
  MemorySpaceAssignment::ReservedScopedMemoryFunction
      reserved_scoped_memory_fn = [](const HloInstruction*) { return 0; };

  // If true, we allocate the reserved scoped memory at the same offset. This
  // is useful to enable more deduplication between HLOs that have reserved
  // scoped memories, but may result in less efficient memory packing.
  bool allocate_reserved_scoped_memory_at_same_offset = true;

  // Specifies the upper bound for number of outstanding prefetches and
  // evictions, -1 for unlimited.
  int64 max_outstanding_prefetches = -1;
  int64 max_outstanding_evictions = -1;

  // Extra outstanding prefetch limit for while uses (in addition to
  // max_outstanding_prefetches).
  int64 while_use_extra_outstanding_prefetch_limit = 0;

  // Specifies the maximum number of times we are willing to move a copy done
  // of a prefetch earlier due to an asynchronous copy ordering violation.
  int64 prefetch_copy_done_reorder_max_retries = 1;

  // Specifies the maximum number of retries that will be performed for each
  // value in case prefetching failed due to running out of asynchronous
  // copies or asynchronous copy ordering.
  int64 max_retries = 1;

  // The maximum number of repacks that we are willing to perform in case we
  // can't allocate a buffer due to running out of memory. If this value is
  // greater than 0, repacker must be non-nullptr.
  int64 max_repacks = 0;

  // This variable is used by the cost analysis in estimating how many times
  // each while loop will execute. Nested loops will be assumed to have
  // executed pow(while_execution_count, nesting_level) times.
  uint64 xla_tpu_memory_space_assignment_while_execution_count = 5ULL;

  float async_copy_bandwidth_bytes_per_second = 0.0f;

  float alternate_mem_bandwidth_bytes_per_second = 0.0f;

  // The repacking algorithm to reduce fragmentation. Must be non-null if
  // max_repacks is greater than 0.
  MemorySpaceAssignmentRepacker* repacker = nullptr;

  // This is only useful for testing: repack after every allocation.
  bool repack_after_every_allocation = false;

  // If true, tries allocating buffers across (e.g., before and inside a while
  // loop body) sequential calls (kWhile, kCall, and kConditional).
  bool allocate_across_sequential_calls = false;

  // If true, verifies the memory space assignment against overlapping
  // buffers.
  bool verify = false;

  // If not nullptr, this function is called to dump debugging information.
  // The first argument is appended to the file name and the second argument
  // is the contents of the file.
  std::function<void(absl::string_view, absl::string_view)> dump_fn = nullptr;

  // Enable prefetching buffers into preferred memory across program
  // boundaries.
  bool enable_cross_program_prefetch = true;

  // If true, use buffer_interval_compare to determine which buffers to
  // prefetch across program boundaries.
  bool default_cross_program_prefetch_heuristic = false;

  // Enable the cross-program prefetch freeing optimization, where the
  // cross-program-prefetched buffer can be reused.
  bool enable_cross_program_prefetch_freeing = true;
};

// A struct representing an asynchronous copy with its logical start and end
// time and its destination memory space.
struct AsynchronousCopy {
  int64 start_time;
  int64 end_time;
  MemorySpaceAssignment::MemorySpace destination;
};

// Compare asynchronous copies such that an earlier start time has the same or
// earlier end time and an earlier end time has the same or earlier start
// time.
bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b);
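// For example (illustrative times only): committed copy A = {start=3, end=10}
// and candidate B = {start=5, end=9} cannot coexist under this order, because
// B starts after A but would finish before it. B' = {start=5, end=12} would
// be fine, since the copies then complete in the order they started.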
We only allow 1048 // asynchronous copies that are pipelined: if an asynchronous copy ends earlier 1049 // than another asynchronous copy, it must start the same time or earlier than 1050 // the other asynchronous copy; and if an asynchronous copy starts earlier than 1051 // another asynchronous copy, it must end the same time or earlier than the 1052 // other asynchronous copy. 1053 class AsynchronousCopyOrdering { 1054 public: 1055 AsynchronousCopyOrdering() = default; 1056 1057 // Adds an asynchronous copy. 1058 void AddCopy(const AsynchronousCopy& copy); 1059 1060 // Removes an asynchronous copy. CHECKs that it is removed. 1061 void RemoveCopy(const AsynchronousCopy& copy); 1062 1063 // If the addition of an asynchronous copy in the given time interval would 1064 // violate the asynchronous copy ordering, returns the violating 1065 // already-committed asynchronous copy. E.g., consider the following scenario: 1066 // CS CD 1067 // already committed async copy: +-----------+ 1068 // new async copy: +--------+ 1069 // 1070 // The new asynchronous copy would violate the ordering guarantee because the 1071 // copy start is after an already committed asynchronous copy while its copy 1072 // done is before the committed copy. 1073 absl::optional<AsynchronousCopy> ViolatesOrdering(int64_t start_time, 1074 int64_t end_time) const; 1075 1076 private: 1077 // Stores asynchronous copies in a tree set respecting the pipelining order. 1078 std::set<AsynchronousCopy> ranges_; 1079 }; 1080 1081 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of 1082 // maximum size. 1083 class AlternateMemoryBestFitHeap 1084 : public GlobalDecreasingSizeBestFitHeap<HloValue> { 1085 public: 1086 using MemorySpace = MemorySpaceAssignment::MemorySpace; 1087 using AllocationValue = MemorySpaceAssignment::AllocationValue; 1088 AlternateMemoryBestFitHeap(MemorySpaceAssignment::AllocationSequence * allocations,const Options & options,const HloAliasAnalysis & alias_analysis,const HloLiveRange & hlo_live_range)1089 AlternateMemoryBestFitHeap( 1090 MemorySpaceAssignment::AllocationSequence* allocations, 1091 const Options& options, const HloAliasAnalysis& alias_analysis, 1092 const HloLiveRange& hlo_live_range) 1093 : GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes), 1094 allocations_(allocations), 1095 options_(options), 1096 alias_analysis_(alias_analysis), 1097 hlo_live_range_(hlo_live_range) { 1098 // Override buffer interval compare if provided. 1099 if (options.buffer_interval_compare) { 1100 buffer_interval_compare_ = *options.buffer_interval_compare; 1101 } 1102 } 1103 1104 // Allocates a buffer in preferred memory with whole program lifetime and 1105 // enables prefetching prefetch_candidate from default memory across program 1106 // boundaries. 1107 void AllocateCrossProgramPrefetchBuffer( 1108 HloModule* module, absl::optional<BufferInterval> prefetch_candidate); 1109 1110 HeapSimulator::Result<HloValue> Finish() override; 1111 1112 protected: 1113 // Given a buffer interval, returns the colocated intervals. Unlike the 1114 // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it 1115 // returns the colocated intervals sorted by scheduled time. 1116 std::vector<const BufferInterval*> GetSortedColocatedIntervals( 1117 const BufferInterval& interval) const; 1118 1119 // Given a BufferInterval, creates AllocationValue objects and corresponding 1120 // AllocationSequences and appends them into allocation_sequence_list_. 
  void CreateAllocationValues(
      const BufferInterval& buffer_interval,
      std::vector<AllocationValue>& allocation_values) const;

  // Given colocated intervals, populates allocation_values with the
  // corresponding AllocationValue objects.
  virtual void CreateAllocationValuesFromColocatedIntervals(
      absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
          colocated_intervals,
      std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values);

  // Go through all the uses in the AllocationValues and find the aliasing
  // positions.
  void FindAliases(std::vector<AllocationValue>* allocation_values) const;

  MemorySpaceAssignment::AllocationSequence* allocations() {
    return allocations_;
  }
  const Options& options() const { return options_; }
  const HloAliasAnalysis& alias_analysis() { return alias_analysis_; }
  const HloLiveRange& hlo_live_range() { return hlo_live_range_; }

 private:
  // We inherit the AllocationBlock struct to attach the Allocation
  // information, to make importing repacked offsets easier.
  struct RepackAllocationBlock
      : MemorySpaceAssignmentRepacker::AllocationBlock {
    MemorySpaceAssignment::Allocation* allocation;
  };

  // A data structure we use to associate Allocation objects that are aliased
  // and must get the same offset.
  struct AliasedOffset {
    int64 offset;
    absl::flat_hash_set<const MemorySpaceAssignment::Allocation*> allocations;
  };

  // An allocation request for a use segment. A use segment is the time
  // segment between the definition and the first use, and the time segment
  // between the uses of a buffer. For example, the time between the
  // definition and Use1 is the first segment, and the time between Use1 and
  // Use2 is the second segment, and so on:
  //
  //        +------+----------+-------+
  //       /        \          \       \
  //      /          v          v       v
  //    Def         Use1       Use2    Use3
  //     <----------> <--------> <----->
  //        Segment    Segment   Segment
  //
  // start_time and end_time are the start and end logical times of the
  // segment. use_times is a sorted sequence of the times of all uses.
  // latest_prefetch_time is the latest time we can schedule the CopyDone for
  // a prefetch.
  // If allow_no_copy_alternate_mem_allocation is false, an eviction is
  // forced.
  // If earliest_prefetch_time is set, prefetches cannot start before this
  // value.
  struct AllocationRequest {
    int64 start_time;
    int64 end_time;
    int64 latest_prefetch_time;
    int64 size;
    bool allow_no_copy_alternate_mem_allocation;
    absl::optional<int64> earliest_prefetch_time;
    AliasedOffset* preferred_offset;
    const MemorySpaceAssignment::AllocationValue::Use* use;
    MemorySpaceAssignment::AllocationValue* allocation_value;
    absl::Span<const int64_t> all_use_times;
  };

  // This struct contains mandatory memory assignments at a given time. E.g.,
  // an input's required memory assignment time would correspond to the
  // definition time of the parameter instruction, and an output's time would
  // correspond to the time of the last use.
  struct RequiredMemoryAssignment {
    MemorySpaceAssignment::MemorySpace memory_space;
    int64 time;
    AliasedOffset* offset;

    bool equals_ignoring_time(const RequiredMemoryAssignment& other) const {
      return memory_space == other.memory_space && offset == other.offset;
    }

    bool operator==(const RequiredMemoryAssignment& other) const {
      return memory_space == other.memory_space && time == other.time &&
             offset == other.offset;
    }

    bool operator!=(const RequiredMemoryAssignment& other) const {
      return !(*this == other);
    }
  };

  // Result of an allocation, prefetch, eviction etc. request. The result is
  // either kSuccess or a bitwise OR of one or more failures. The values are
  // unique powers of two. To check if a result contains a particular failure,
  // use the result_is method. To add a new failure to a result, use the
  // result_mark method.
  enum class Result {
    // Successful allocation.
    kSuccess = 0,
    // Allocation failed because we ran out of alternate memory.
    kFailOutOfMemory = 1,
    // A no-copy allocation couldn't be performed because the previous
    // allocation wasn't in the alternate memory space.
    kFailPrevAllocationNotInAlternateMem = 2,
    // A no-copy allocation couldn't be performed because the live range was
    // too long.
    kFailLiveRangeTooLong = 4,
    // A prefetch couldn't be performed because the live range was too short.
    kFailLiveRangeTooShort = 8,
    // Ran out of outstanding asynchronous copy limit either during
    // prefetching or eviction.
    kFailOutOfAsyncCopies = 16,
    // A prefetch couldn't be performed because the asynchronous copy ordering
    // was violated.
    kFailViolatesAsyncCopyOrdering = 32,
    // An allocation failure happened that requires uncommitting all the
    // pending allocations. Usually this is due to a situation requiring an
    // eviction but the eviction couldn't be performed.
    kFailRequiresUncommit = 64
  };

  // Return true if the result belongs to a failure.
  static bool result_is(Result result, Result failure) {
    return static_cast<int>(result) & static_cast<int>(failure);
  }

  // Mark (bitwise OR) a failure to the result.
  static Result result_mark(Result failure, Result& result) {
    result = static_cast<Result>(static_cast<int>(result) |
                                 static_cast<int>(failure));
    return result;
  }

  // Return true if the result is a failure that requires us to uncommit
  // pending chunks.
  static bool result_requires_uncommit(Result result) {
    return result_is(result, Result::kFailRequiresUncommit);
  }

  // Return true if the result is a failure either due to running out of
  // outstanding asynchronous copies or due to violating asynchronous copy
  // ordering.
  static bool result_failed_because_of_async_copy(Result result) {
    return result_is(result, Result::kFailOutOfAsyncCopies) ||
           result_is(result, Result::kFailViolatesAsyncCopyOrdering);
  }
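  // For example (hypothetical usage): multiple failures can be accumulated
  // into one result and queried like so:
  //
  //   Result result = Result::kSuccess;
  //   result_mark(Result::kFailOutOfMemory, result);
  //   result_mark(Result::kFailOutOfAsyncCopies, result);
  //   result_is(result, Result::kFailOutOfMemory);   // true
  //   result_failed_because_of_async_copy(result);   // true
  //   result_requires_uncommit(result);              // false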
  // Allocates buffers for instructions that need reserved scoped allocations
  // in the alternate memory space.
  void AllocateReservedScopedAllocations();

  // Returns the AliasedOffset object associated with the allocation.
  AliasedOffset* GetAliasedOffset(
      const MemorySpaceAssignment::Allocation& allocation);

  // If aliased_offset is non-null, this method adds the allocation to
  // aliased_offset. Otherwise, it creates a new AliasedOffset object and adds
  // the allocation to this new AliasedOffset.
  void CreateOrAddToAliasedOffset(
      const MemorySpaceAssignment::Allocation& allocation,
      AliasedOffset* aliased_offset);

  // Given an allocation sequence, returns the live allocation at time with a
  // preference towards allocations in alternate memory. Returns nullptr if no
  // allocation is alive at that time.
  static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
      const MemorySpaceAssignment::AllocationSequence& allocations,
      int64_t time);

  // Returns true if the use is allowed in the alternate memory.
  bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
                                     const HloUse& use) const;

  // Finds allocations for allocation values generated from colocated
  // intervals. All of the allocation values have a must-alias relationship
  // with each other. Returns kSuccess if all of the sites could be placed in
  // the alternate memory, or a bitwise OR of the failure reasons why they
  // couldn't be.
  Result AllocateAllocationValues(
      absl::Span<AllocationValue> allocation_values);

  // Finds an allocation for an allocation request for a segment (see the
  // documentation for AllocationRequest above for how a segment is defined).
  //
  // It performs three things in the following order:
  //  1- Allocate the allocation request entirely in the alternate memory, if
  //     there is enough space and if the prefetch interval picker allows.
  //  2- If (1) was unsuccessful, and the only allocation for this buffer was
  //     in the alternate memory, we try to perform an eviction.
  //  3- If (1) was unsuccessful, prefetch the buffer into the alternate
  //     memory, if there is enough space and if the prefetch interval picker
  //     allows.
  //
  // If an eviction (2) was requested and was unsuccessful, this method
  // returns Result::kFailRequiresUncommit. This means we could not find a
  // suitable allocation, so all previous allocations for this buffer must be
  // removed and allocated in the default memory. Otherwise, this method may
  // return Result::kSuccess if the buffer could be placed in alternate memory
  // or some other Result with an OR of reasons why the buffer couldn't be
  // placed in alternate memory.
  Result AllocateSegment(const AllocationRequest& request);

  // Try allocating in alternate memory without any copies.
  Result AllocateInAlternateMemoryNoCopy(const AllocationRequest& request);

  // Try evicting to default memory space.
  Result Evict(const AllocationRequest& request);

  // Returns the time a copy done of a prefetch should be scheduled.
  int64 FindPrefetchEndTime(const AllocationRequest& request,
                            int64_t earliest_prefetch_time) const;
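  // A pseudocode sketch of the decision order described for AllocateSegment
  // above (illustrative only, not the exact implementation):
  //
  //   Result result = Result::kSuccess;
  //   // (1) Try the whole segment in alternate memory with no copy.
  //   result_mark(AllocateInAlternateMemoryNoCopy(request), result);
  //   if (result == Result::kSuccess) return result;
  //   // (2) If the value currently lives only in alternate memory, try an
  //   //     eviction; on failure, return Result::kFailRequiresUncommit so
  //   //     the caller rolls this buffer back to default memory.
  //   // (3) Otherwise, try to prefetch the value back into alternate memory.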
  // Try prefetching to the alternate memory space.
  Result Prefetch(
      const AllocationRequest& request,
      const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem);

  // Finds the best possible chunk candidate: the one with the longest
  // possible availability if no preferred offset is given, or the one at
  // preferred_offset if it is given.
  absl::optional<ChunkCandidate> FindBestChunkCandidate(
      const AllocationRequest& request, const AliasedOffset* preferred_offset,
      BufferInterval* alternate_mem_interval) const;

  // Returns the required assignment at a particular time, if available.
  absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
      const HloValue* buffer, int64_t time) const;

  // Searches for aliases in the use for a required assignment, and returns it
  // if found.
  absl::optional<RequiredMemoryAssignment> AliasedRequiredAssignmentForUse(
      const AllocationValue::Use& use) const;

  // Goes through the colocated intervals and adds any required assignments.
  void AddRequiredAssignmentsForColocatedIntervals(
      absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
          colocated_intervals);

  // Propagates an aliased required assignment for a given position.
  void AddAliasedRequiredAssignment(
      const HloInstruction* instruction, ShapeIndex index,
      const MemorySpaceAssignment::Allocation* aliased_allocation);

  // Sets a required assignment. CHECK-fails if there is a conflicting
  // required assignment at the same time.
  void AddRequiredAssignment(const HloValue* value,
                             const HloInstruction* instruction,
                             MemorySpace memory_space, int64_t time,
                             AliasedOffset* offset = nullptr);
  void AddRequiredAssignment(const HloInstruction* instruction,
                             ShapeIndex index, MemorySpace memory_space,
                             AliasedOffset* offset = nullptr);

  // Adds inputs and outputs as required assignments.
  void AddInputAndOutputRequiredAssignments();

  // Returns true if the colocated intervals in the argument are in a
  // parameter or root instruction of the entry computation and are reserved
  // by the user to be in the alternate memory space.
  bool AreIntervalsReservedInAlternateMemory(
      absl::Span<const BufferInterval* const> colocated_intervals) const;

  // Since the allocations are recorded to the AllocationSequence, we don't
  // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override
  // AddToChunkMap to avoid unnecessarily adding the chunk to the chunk map.
  void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}

  // Returns true if the addition of an asynchronous copy in the given time
  // interval would violate the maximum number of asynchronous copies. An
  // extra async copy limit can be provided to increase the limit of
  // asynchronous copies for this instance.
  bool ViolatesMaximumOutstandingAsyncCopies(
      int64_t start_time, int64_t end_time, bool is_prefetch,
      int64_t extra_async_copy_limit = 0) const;

  // If the asynchronous copy would violate the pipelining order, returns the
  // violating asynchronous copy.
  absl::optional<AsynchronousCopy> ViolatesAsyncCopyOrdering(
      int64_t start_time, int64_t end_time) const;
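  // A plausible sketch of the lookup that RequiredMemoryAssignmentAt performs
  // over the required_assignments_ map (a data member declared further
  // below); illustrative only, not necessarily the exact implementation:
  //
  //   absl::optional<RequiredMemoryAssignment> result;
  //   auto it = required_assignments_.find(buffer);
  //   if (it != required_assignments_.end()) {
  //     for (const RequiredMemoryAssignment& required : it->second) {
  //       if (required.time == time) {
  //         result = required;
  //         break;
  //       }
  //     }
  //   }
  //   return result;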
  // Exports the allocations for repacking and puts them into the vector in
  // the parameter.
  void ExportAllocationsForRepacking(
      std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>&
          allocations);

  // Imports repacked allocations and updates the internal data structures to
  // be consistent with the new packing.
  void ImportRepackedAllocations();

  // Adds an asynchronous copy to the allocations.
  void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
                    MemorySpace memory_space, absl::optional<Chunk> chunk,
                    int64_t start_time, int64_t end_time,
                    int64_t copy_done_schedule_before_time,
                    MemorySpaceAssignment::AllocationSequence* allocations,
                    AliasedOffset* aliased_offset,
                    bool is_cross_program_prefetch = false);

  // Commits the chunk candidate, but also adds it to pending_chunks_ so that
  // we can "uncommit" it if we need to roll back this allocation sequence.
  void AddToPendingChunks(const BufferInterval& buffer_interval,
                          const ChunkCandidate& chunk_candidate);

  // If we need to remove the allocations for this allocation sequence, this
  // removes the pending chunks and asynchronous copies in the respective
  // pending buffers from the interval trees. If an allocation request returns
  // kFailRequiresUncommit, this method must be called.
  void UncommitPendingChunks(absl::Span<AllocationValue> allocation_values);

  // Finalizes the allocations so that they can no longer be uncommitted.
  void FinalizeAllocations(absl::Span<AllocationValue> allocation_values);

  // Clears all pending chunks and asynchronous copies.
  void ClearPendingChunks();

  // Appends buffer and allocation information for debugging and dumps it to a
  // file, if enabled.
  void AppendBufferInfoDebugString(const BufferInterval& interval,
                                   std::string* debug_str) const;
  void AppendAllocationInfoDebugString(
      const AllocationValue& value,
      const MemorySpaceAssignment::Allocation& allocation,
      std::string& debug_str) const;
  void DumpDebugStringsIfEnabled() const;

  // Returns the available heap size in the alternate memory.
  int64 available_heap_size() const {
    return options_.max_size_in_bytes - reserved_in_bytes_;
  }

  // Creates and returns a RepackAllocationBlock.
  static RepackAllocationBlock MakeRepackAllocationBlock(
      int64_t start_time, int64_t end_time, int64_t size,
      int64_t initial_offset, int64_t id,
      MemorySpaceAssignment::Allocation* allocation) {
    RepackAllocationBlock allocation_block;
    allocation_block.start_time = start_time;
    allocation_block.end_time = end_time;
    allocation_block.size = size;
    allocation_block.offset = -1;
    allocation_block.initial_offset = initial_offset;
    allocation_block.id = id;
    allocation_block.colocations = {};
    allocation_block.allocation = allocation;
    return allocation_block;
  }
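  // A repack round-trip over these hooks might look like the following
  // (hypothetical driver code; `repacker` stands for any
  // MemorySpaceAssignmentRepacker implementation):
  //
  //   std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> blocks;
  //   ExportAllocationsForRepacking(blocks);
  //   TF_ASSIGN_OR_RETURN(bool success,
  //                       repacker->Repack(absl::MakeSpan(blocks)));
  //   if (success) {
  //     ImportRepackedAllocations();
  //   } else {
  //     // The original offsets remain in effect.
  //   }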
  MemorySpaceAssignment::AllocationSequence* allocations_;
  const Options& options_;
  const HloAliasAnalysis& alias_analysis_;
  const HloLiveRange& hlo_live_range_;
  // We use an interval tree to keep track of the number of outstanding
  // prefetches and evictions.
  BufferIntervalTree prefetch_interval_tree_;
  BufferIntervalTree eviction_interval_tree_;
  AsynchronousCopyOrdering async_copy_ordering_;
  // A list of RepackAllocationBlock objects that mirrors the allocation
  // sequences, used for repacking. We use a list here because we need pointer
  // stability for aliased allocations.
  std::list<RepackAllocationBlock> repack_allocation_blocks_;
  int64 num_repacks_ = 0;
  std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
  std::vector<AsynchronousCopy> pending_async_copies_;
  std::vector<std::pair<const HloValue*, RequiredMemoryAssignment>>
      pending_required_assignments_;
  // The data structure that contains the AliasedOffset objects and an
  // Allocation-to-AliasedOffset map for efficient lookup.
  std::list<AliasedOffset> aliased_offsets_;
  absl::flat_hash_map<const MemorySpaceAssignment::Allocation*, AliasedOffset*>
      aliased_offset_map_;
  // This map contains the required memory assignments for HloValues (e.g.,
  // inputs and outputs).
  absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
      required_assignments_;
  // Number of bytes reserved in the alternate memory space.
  int64 reserved_in_bytes_ = 0;
  // Debug strings.
  std::string buffer_info_str_;
  std::string allocation_info_str_;
};
}  // namespace memory_space_assignment
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_