/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_

#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"

namespace xla {

// This class contains pre-set assignments determined by memory space
// assignment. It contains two data structures: (1) a chunks vector that maps a
// defining HloPosition to a Chunk (offset and size), and (2) an
// assignment_info vector that maps each memory space to information such as
// its allocated size and heap memory trace. Since there is currently only one
// alternate memory space, assignment_info has a single entry.
class PresetAssignments {
 public:
  // Contains per-memory-space information like the allocated size and heap
  // simulator trace.
  struct AssignmentInformation {
    int64 size;
    HeapSimulatorTrace heap_simulator_trace;
  };

  PresetAssignments() = default;

  void add_chunk(const HloPosition& position,
                 const HeapSimulator::Chunk& chunk) {
    chunks_.emplace_back(position, chunk);
  }

  AssignmentInformation* assignment_information_for_space(int64 memory_space) {
    for (auto& space_and_info : assignment_info_) {
      if (space_and_info.first == memory_space) {
        return &space_and_info.second;
      }
    }
    assignment_info_.emplace_back(memory_space, AssignmentInformation());
    return &assignment_info_.back().second;
  }

  absl::Span<const std::pair<HloPosition, HeapSimulator::Chunk>> chunks()
      const {
    return chunks_;
  }

  absl::Span<const std::pair<int64, AssignmentInformation>>
  assignment_informations() const {
    return assignment_info_;
  }

  // Get debugging information.
  std::string buffer_info_str() const { return buffer_info_str_; }
  std::string allocation_info_str() const { return allocation_info_str_; }

 private:
  std::vector<std::pair<HloPosition, HeapSimulator::Chunk>> chunks_;
  std::vector<std::pair<int64, AssignmentInformation>> assignment_info_;
  std::string buffer_info_str_;
  std::string allocation_info_str_;
};
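
// Illustrative usage of PresetAssignments (a minimal sketch; the position,
// chunk values, and memory space id below are hypothetical):
//
//   PresetAssignments assignments;
//   assignments.add_chunk(defining_position,
//                         HeapSimulator::Chunk{/*offset=*/0, /*size=*/1024});
//   PresetAssignments::AssignmentInformation* info =
//       assignments.assignment_information_for_space(/*memory_space=*/1);
//   info->size = 1024;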

// A wrapper class around HloCostAnalysis with additional knowledge about the
// bandwidths of different memory spaces.
class MemorySpaceAssignmentCostAnalysis {
 public:
  // An optional Cache object may be provided to some of the methods below to
  // speed up the lookup.
  struct Cache {
    absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
  };

  virtual ~MemorySpaceAssignmentCostAnalysis() = default;

  static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
      const HloCostAnalysis& cost_analysis,
      float async_copy_bandwidth_bytes_per_second,
      float alternate_mem_bandwidth_bytes_per_second, const HloModule& module);

  const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }

  // Returns a heuristic value that captures how much putting this tensor in
  // the alternate memory would help if the op is memory bound, or otherwise
  // how far off the op is from being memory bound. The larger this number,
  // the higher priority it will have for placement in the alternate memory.
  float GetAlternateMemoryBenefit(const HloInstruction& instruction,
                                  float elapsed_time_due_to_alternate_mem,
                                  Cache* cache = nullptr) const;

  // Returns a heuristic value of memory boundedness for the given
  // BufferInterval. The larger this number, the higher priority it will have
  // for placement in the alternate memory.
  float GetMemoryBoundedness(
      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
      Cache* cache = nullptr) const;

  // Returns the elapsed time in seconds due to compute only.
  float GetInstructionElapsedDueToCompute(
      const HloInstruction& instruction) const;

  // Returns the elapsed time in seconds due to memory only. If
  // operand_in_alternate_mem is provided or if output_in_alternate_mem is
  // true, it will assume that the operand or output will be in the alternate
  // memory space. This is useful for calculating the benefit of placing the
  // buffer in alternate memory.
  float GetInstructionElapsedDueToMemory(
      const HloInstruction& instruction,
      absl::optional<int64> operand_in_alternate_mem = absl::nullopt,
      bool output_in_alternate_mem = false) const;

  // Returns the estimated elapsed duration of the instruction in seconds. It
  // assumes all operands and outputs of the instruction are in the default
  // memory.
  virtual float GetInstructionElapsed(const HloInstruction& instruction) const;

  // Returns the estimated elapsed duration of the instruction in seconds. It
  // assumes all operands and outputs of the instruction are in the default
  // memory, except for the operand number that is in the alternate memory, if
  // provided, or the output if output_in_alternate_mem is true.
  virtual float GetInstructionElapsedInAlternateMemory(
      const HloInstruction& instruction,
      absl::optional<int64> operand_in_alternate_mem,
      bool output_in_alternate_mem) const;

  // Returns the elapsed time it would take to asynchronously copy the shape
  // from default to alternate memory space (or vice versa).
  virtual float GetAsyncCopyElapsed(const Shape& shape) const;

  int64 GetScheduleEndTime() const;
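
  // Illustrative creation and use of the cost analysis (a sketch; the
  // bandwidth values are hypothetical and backend-specific):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       std::unique_ptr<MemorySpaceAssignmentCostAnalysis> msa_analysis,
  //       MemorySpaceAssignmentCostAnalysis::Create(
  //           hlo_cost_analysis,
  //           /*async_copy_bandwidth_bytes_per_second=*/1e9,
  //           /*alternate_mem_bandwidth_bytes_per_second=*/4e9, *module));
  //   float copy_seconds = msa_analysis->GetAsyncCopyElapsed(shape);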

  // Returns the number of nested computation levels this instruction resides
  // in. If while_only is true, it returns the while loop nest level; 0 means
  // the instruction is not in a while loop.
  int CalculateComputationNestLevel(const HloInstruction* instruction,
                                    bool while_only) const;

  const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }

 protected:
  MemorySpaceAssignmentCostAnalysis(
      const HloCostAnalysis& cost_analysis,
      float async_copy_bandwidth_bytes_per_second,
      float alternate_mem_bandwidth_bytes_per_second,
      std::unique_ptr<HloAliasAnalysis> alias_analysis,
      std::unique_ptr<HloLiveRange> hlo_live_range,
      std::unique_ptr<CallGraph> call_graph)
      : cost_analysis_(cost_analysis),
        async_copy_bandwidth_bytes_per_second_(
            async_copy_bandwidth_bytes_per_second),
        alternate_mem_bandwidth_bytes_per_second_(
            alternate_mem_bandwidth_bytes_per_second),
        alias_analysis_(std::move(alias_analysis)),
        hlo_live_range_(std::move(hlo_live_range)),
        call_graph_(std::move(call_graph)) {}

 private:
  const HloCostAnalysis& cost_analysis_;
  float async_copy_bandwidth_bytes_per_second_;
  float alternate_mem_bandwidth_bytes_per_second_;
  std::unique_ptr<HloAliasAnalysis> alias_analysis_;
  std::unique_ptr<HloLiveRange> hlo_live_range_;
  std::unique_ptr<CallGraph> call_graph_;
};

// Abstract base class that memory space assignment uses to pick prefetch
// intervals.
class PrefetchIntervalPicker {
 public:
  PrefetchIntervalPicker() = default;
  virtual ~PrefetchIntervalPicker() = default;

  // Returns true if the buffer can be allocated in alternate memory space
  // without any copies (prefetches).
  virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
                                                  int64 start_time,
                                                  int64 end_time) const = 0;

  // Returns the preferred end time for an eviction that starts at a given time
  // and must end by the given end time.
  virtual int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
                                         int64 latest_end_time) const = 0;

  // Returns the latest time that a prefetch can start.
  virtual int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time,
                                        int64 end_time,
                                        const HloUse* use) const = 0;

  // Returns the preferred time that a prefetch can start.
  virtual int64 PreferredPrefetchStartTime(const Shape& shape,
                                           int64 earliest_prefetch_start_time,
                                           int64 latest_prefetch_start_time,
                                           int64 prefetch_end_time) const = 0;

  // Returns the latest time that a prefetch can end that is less than or
  // equal to proposed_prefetch_end_time.
  virtual int64 LatestPrefetchEndTime(int64 original_prefetch_end_time,
                                      int64 proposed_prefetch_end_time) const {
    return proposed_prefetch_end_time;
  }

  // Begins the iterator for the first start time of the prefetch.
  virtual void Begin(const HloUse& use, int64 start_time, int64 end_time) = 0;

  // Advances the start time of the prefetch and returns that value.
  virtual int64 Next() = 0;

  // Returns true if the available prefetch intervals have been exhausted.
  virtual bool Done() const = 0;

  // The retry number can be used to modify the interval picking policies. The
  // first attempt will have a retry_number of 0, then 1, etc.
  virtual void SetRetryNumber(int retry_number) {}

  // Returns a debug string for the current state of the prefetch interval
  // picker.
  virtual std::string ToDebugString() const = 0;

  // Returns a debug string for no-copy allocation.
  virtual std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
                                          int64 end_time) const = 0;

  // Prefetch interval pickers may return a value corresponding to the benefit
  // of placing the BufferInterval in the alternate memory. The larger the
  // value, the more beneficial.
  virtual absl::optional<float> BufferIntervalAlternateMemoryBenefit(
      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
      const {
    return absl::nullopt;
  }

 protected:
  const absl::flat_hash_map<const HloInstruction*, int64>*
      instruction_schedule_ = nullptr;
};
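
// The picker is driven as an iterator over candidate prefetch start times. A
// minimal sketch of the intended protocol (`picker` and `use` are
// hypothetical):
//
//   picker.Begin(use, /*start_time=*/10, /*end_time=*/25);
//   while (!picker.Done()) {
//     int64 prefetch_start_time = picker.Next();
//     // Try to allocate a prefetch whose CopyStart is scheduled at
//     // prefetch_start_time; stop iterating on success.
//   }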

// Prefetch interval picker that uses instruction count to overlap asynchronous
// copies with independent computation. The min and max overlap counts describe
// the number of independent HLOs overlapped while a value is being prefetched
// into the alternate memory (between CopyStart and CopyDone HLO instructions).
// max_overlap_count attempts to prevent bringing tensors into the alternate
// memory too eagerly and hence occupying space that other tensors might use.
// min_overlap_count attempts to prevent cases where tensors are prefetched
// into the alternate memory without sufficient time for the copy to take
// place. In those cases, it is better to keep the tensor in the default
// memory instead of hurting the critical path with a copy that likely won't
// finish in time.
class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker {
 public:
  InstructionCountPrefetchIntervalPicker(int64 min_overlap_count,
                                         int64 max_overlap_count)
      : min_overlap_count_(min_overlap_count),
        max_overlap_count_(max_overlap_count) {}

  bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
                                          int64 end_time) const override;

  int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
                                 int64 latest_end_time) const override;

  int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time,
                                int64 end_time,
                                const HloUse* use) const override;

  int64 PreferredPrefetchStartTime(const Shape& shape,
                                   int64 earliest_prefetch_start_time,
                                   int64 latest_prefetch_start_time,
                                   int64 prefetch_end_time) const override;

  void Begin(const HloUse& use, int64 start_time, int64 end_time) override;

  int64 Next() override;
  bool Done() const override;

  std::string ToDebugString() const override;
  std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
                                  int64 end_time) const override;

 private:
  int64 min_overlap_count_;
  int64 max_overlap_count_;
  int64 end_time_;
  int64 current_prefetch_time_;
};
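
// For example, a picker constructed as below requires every prefetch to
// overlap at least 1 and at most 4 independent instructions (the counts are
// illustrative):
//
//   InstructionCountPrefetchIntervalPicker picker(/*min_overlap_count=*/1,
//                                                 /*max_overlap_count=*/4);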

// Prefetch interval picker that uses cost analysis to overlap asynchronous
// copies with independent computation. It uses the ratio of (asynchronous
// copy duration) / (independent computation duration) and requires a prefetch
// to stay within the given min/max bounds on that ratio. It starts with the
// preferred ratio in Begin() and alternately tries earlier and later
// prefetches until it hits the min and max ratios.
class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
 public:
  CostAnalysisPrefetchIntervalPicker(
      const MemorySpaceAssignmentCostAnalysis& cost_analysis,
      float min_async_copy_to_overlap_ratio,
      float max_async_copy_to_overlap_ratio,
      float preferred_async_copy_to_overlap_ratio);

  bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time,
                                          int64 end_time) const override;

  int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time,
                                 int64 latest_end_time) const override;

  int64 LatestPrefetchEndTime(int64 original_prefetch_end_time,
                              int64 proposed_prefetch_end_time) const override;

  int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time,
                                int64 end_time,
                                const HloUse* use) const override;

  int64 PreferredPrefetchStartTime(const Shape& shape,
                                   int64 earliest_prefetch_start_time,
                                   int64 latest_prefetch_start_time,
                                   int64 prefetch_end_time) const override;

  void Begin(const HloUse& use, int64 start_time, int64 end_time) override;

  int64 Next() override;
  bool Done() const override;

  void SetRetryNumber(int retry_number) override;

  std::string ToDebugString() const override;
  std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
                                  int64 end_time) const override;

  absl::optional<float> BufferIntervalAlternateMemoryBenefit(
      const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
      const override;

 private:
  // Returns the elapsed time in seconds of the logical interval between
  // start_time and end_time, according to the instruction schedule.
  float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const;

  // Finds the minimum nest level in the given interval.
  int GetMinWhileNestLevel(int64 start_time, int64 end_time) const;

  // For each instruction in the flattened schedule, maintain its elapsed time
  // (as a cumulative sum) and while nesting level.
  std::vector<float> elapsed_time_cumsum_;
  std::vector<int> while_nest_level_;
  std::vector<int> computation_nest_level_;
  // Maintain the index of the most recent (before this instruction) nest
  // level change in order to efficiently determine the minimum nest level in
  // an interval.
  std::vector<int> while_nest_level_change_;

  const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
  float min_async_copy_to_overlap_ratio_;
  float max_async_copy_to_overlap_ratio_;
  float preferred_async_copy_to_overlap_ratio_;
  float max_overlap_multiplier_ = 1.0;

  float async_copy_elapsed_;
  float inst_elapsed_reduction_;
  int64 end_logical_time_;
  int64 earliest_prefetch_time_;
  int64 latest_prefetch_time_;
  bool using_increasing_prefetch_time_iterator_ = true;
  int64 increasing_prefetch_time_iterator_;
  int64 decreasing_prefetch_time_iterator_;
};
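
// A construction sketch (the ratio values are illustrative, not tuned
// recommendations):
//
//   CostAnalysisPrefetchIntervalPicker picker(
//       *cost_analysis, /*min_async_copy_to_overlap_ratio=*/1.0,
//       /*max_async_copy_to_overlap_ratio=*/4.0,
//       /*preferred_async_copy_to_overlap_ratio=*/1.5);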

// MemorySpaceAssignment assigns memory spaces (default or alternate) to each
// instruction in the module. It will greedily try placing as many values in
// the alternate memory space as possible. It uses the heap simulator to
// determine the actual allocation offsets of values in the alternate memory
// space to account for fragmentation. The default memory space is assumed to
// be large enough to hold the values that could not be placed in the
// alternate memory space.
class MemorySpaceAssignment {
 public:
  using Chunk = HeapSimulator::Chunk;
  using BufferInterval =
      GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval;
  using BufferIntervalCompare =
      GlobalDecreasingSizeBestFitHeap<HloValue>::BufferIntervalCompare;
  using IsAllowedInAlternateMemoryFunction =
      std::function<bool(const HloValue&)>;
  using IsUseAllowedInAlternateMemoryFunction =
      std::function<bool(const HloUse&)>;

  // MemorySpaceAssignment uses a notion of a slow and large default memory
  // space and a fast and small alternate memory space.
  enum class MemorySpace { kDefault, kAlternate };

  // Forward declaration for Allocation.
  class Allocation;

  // The different options to be passed to the Run() API.
  struct Options {
    // Backend-specific integer value that describes the alternate memory.
    int64 alternate_memory_space = 0;

    // Maximum size of the alternate memory space.
    int64 max_size_in_bytes = 0;

    // Memory alignment of the alternate memory space.
    int64 alignment_in_bytes = 1;

    // If provided, we sort the buffers using this comparison function;
    // otherwise, we use GlobalDecreasingSizeBestFitHeap::kSpatial.
    absl::optional<BufferIntervalCompare> buffer_interval_compare =
        absl::nullopt;

    // This object determines how early and how late prefetches can occur.
    PrefetchIntervalPicker* prefetch_interval_picker = nullptr;

    // Size function for buffer values.
    BufferValue::SizeFunction size_fn;

    // This function can be used to prevent certain HloValues (e.g., based on
    // the opcode) from being placed in the alternate memory.
    IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_fn;

    // This function can be used to prevent certain HloUses (e.g., based on
    // the opcode) from being placed in the alternate memory.
    IsUseAllowedInAlternateMemoryFunction is_use_allowed_in_alternate_mem_fn =
        [](const HloUse&) { return true; };

    // Specifies the upper bound for the number of outstanding prefetches and
    // evictions; -1 for unlimited.
    int64 max_outstanding_prefetches = -1;
    int64 max_outstanding_evictions = -1;

    // Extra outstanding prefetch limit for while uses (in addition to
    // max_outstanding_prefetches).
    int64 while_use_extra_outstanding_prefetch_limit = 0;

    // Specifies the maximum number of times we are willing to move the copy
    // done of a prefetch earlier due to an asynchronous copy ordering
    // violation.
    int64 prefetch_copy_done_reorder_max_retries = 1;

    // Specifies the maximum number of retries that will be performed for each
    // value in case prefetching failed due to running out of asynchronous
    // copies or asynchronous copy ordering.
    int64 max_retries = 1;

    // The maximum number of repacks that we are willing to perform in case we
    // can't allocate a buffer due to running out of memory. If this value is
    // greater than 0, repacker must be non-nullptr.
    int64 max_repacks = 0;

    // The repacking algorithm to reduce fragmentation. Must be non-null if
    // max_repacks is greater than 0.
    MemorySpaceAssignmentRepacker* repacker = nullptr;

    // This is only useful for testing: repack after every allocation.
    bool repack_after_every_allocation = false;

    // If true, tries allocating buffers across sequential calls (kWhile,
    // kCall, and kConditional), e.g., before and inside a while loop body.
    bool allocate_across_sequential_calls = false;

    // If true, verifies the memory space assignment against overlapping
    // buffers.
    bool verify = false;

    // If not nullptr, this function is called to dump debugging information.
    // The first argument is appended to the file name and the second argument
    // is the contents of the file.
    std::function<void(absl::string_view, absl::string_view)> dump_fn =
        nullptr;

    // Enable prefetching buffers into preferred memory across program
    // boundaries.
    bool enable_cross_program_prefetch = true;

    // If true, use buffer_interval_compare to determine which buffers to
    // prefetch across program boundaries.
    bool default_cross_program_prefetch_heuristic = false;
  };
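
  // A configuration sketch (all values are illustrative and
  // backend-specific; `picker` is a previously constructed
  // PrefetchIntervalPicker):
  //
  //   MemorySpaceAssignment::Options options;
  //   options.alternate_memory_space = 1;
  //   options.max_size_in_bytes = 128 * 1024 * 1024;
  //   options.alignment_in_bytes = 64;
  //   options.prefetch_interval_picker = &picker;
  //   options.size_fn = [](const BufferValue& value) {
  //     return ShapeUtil::ByteSizeOf(value.shape(), /*pointer_size=*/8);
  //   };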

  // This class represents an allocation that might either be in the default
  // or alternate memory. An HloValue might live in multiple different
  // allocations over its lifetime. The lifetimes of the allocations are
  // defined using start_time and end_time, which correspond to the
  // instruction indexes in the flattened schedule. Each of these allocations
  // might partially overlap with each other. CopyAllocation defined below
  // represents asynchronous copies between Allocations.
  //
  // Consider an instruction Foo, and its users Bar and Baz, and the times
  // given in terms of the flattened schedule of the entire module:
  //
  //      Foo:10
  //       /   \
  //    Bar:14  \
  //           Baz:25
  //
  // A valid memory space assignment could be like the following:
  //
  //  Time:         10 ... 14        ...      25
  //                Foo    Bar                Baz
  //  Alternate     +-------+           +-----+
  //  Default           +---------------------+
  //                    ^   ^           ^     ^
  //                    |   |           |     |
  //                evict   evict  prefetch  prefetch
  //                start    end     start      end
  //
  // This would be represented with:
  //   - Allocation(memory_space=kAlternate, start_time=10, end_time=14)
  //   - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25)
  //   - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25)
  class Allocation {
   public:
    Allocation(HloPosition defining_position, MemorySpace memory_space,
               absl::optional<Chunk> chunk, int64 start_time, int64 end_time)
        : defining_position_(defining_position),
          memory_space_(memory_space),
          chunk_(chunk),
          start_time_(start_time),
          end_time_(end_time) {}
    virtual ~Allocation() = default;

    virtual bool is_copy_allocation() const { return false; }

    // Adds a use to this allocation.
    void AddUse(HloUse use);

    // Extends the end time of this allocation.
    void Extend(int64 end_time) { end_time_ = end_time; }

    // After all of the time ranges for the allocations have been assigned,
    // Process morphs the instructions affected to assign the memory spaces
    // and insert asynchronous copy instructions if necessary.
    virtual Status Process(MemorySpaceAssignment* memory_space_assignment);

    // Returns the defining position for this allocation.
    virtual HloPosition defining_position() const { return defining_position_; }

    // Returns the time the buffer is first available to be used. For
    // Allocation, this is start_time.
    virtual int64 earliest_available_time() const { return start_time_; }

    const std::vector<HloUse>& uses() const { return uses_; }
    MemorySpace memory_space() const { return memory_space_; }
    Chunk chunk() const { return *chunk_; }
    Chunk* mutable_chunk() { return &*chunk_; }
    void set_start_time(int64 start_time) { start_time_ = start_time; }
    int64 start_time() const { return start_time_; }
    int64 end_time() const { return end_time_; }

    bool operator==(const Allocation& other) const;
    virtual std::string ToString() const;

   protected:
    // Descend to the shape_index element of the tuple and replace that with
    // new_instruction.
    StatusOr<HloInstruction*> ReplaceTupleWith(HloInstruction* new_instruction,
                                               HloInstruction* tuple,
                                               ShapeIndex shape_index);

    // Recursively create kGetTupleElement instructions if the defining
    // position shape is not an array. Returns the new instruction that has an
    // array shape.
    HloInstruction* AddGetTupleElements();

    HloPosition defining_position_;
    std::vector<HloUse> uses_;
    MemorySpace memory_space_;
    absl::optional<Chunk> chunk_;
    int64 start_time_;
    int64 end_time_;
  };

  // This class represents an allocation as a result of an asynchronous copy.
  // Note: CopyStart instructions are inserted after `start_time` or later,
  // while CopyDone instructions are inserted before
  // `copy_done_schedule_before_time` or earlier.
  class CopyAllocation : public Allocation {
   public:
    CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
                   absl::optional<Chunk> chunk, int64 start_time,
                   int64 end_time, int64 copy_done_schedule_before_time,
                   bool is_cross_program_prefetch = false)
        : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
                     start_time, end_time),
          prev_allocation_(prev_allocation),
          copy_start_schedule_after_(start_time),
          copy_done_schedule_before_(copy_done_schedule_before_time),
          is_cross_program_prefetch_(is_cross_program_prefetch) {}

    bool is_copy_allocation() const override { return true; }

    Status Process(MemorySpaceAssignment* memory_space_assignment) override;

    HloPosition defining_position() const override {
      // Unless explicitly set, the defining position of a copy allocation is
      // retrieved from the previous allocation. This is because we don't
      // create new CopyStart/CopyDone instructions until later, and the
      // position should point to the previous (copy or otherwise)
      // allocation's position for the original defining position.
      if (defining_position_.instruction == nullptr) {
        return prev_allocation_.defining_position();
      } else {
        return defining_position_;
      }
    }

    HloInstruction* copy_start() const { return copy_start_; }
    HloInstruction* copy_done() const { return copy_done_; }

    // Returns the time the buffer is first available to be used. For
    // CopyAllocation, this is when the copy ends, which is
    // copy_done_schedule_before.
    int64 earliest_available_time() const override {
      return copy_done_schedule_before_;
    }

    int64 copy_start_schedule_after() const {
      return copy_start_schedule_after_;
    }
    int64 copy_done_schedule_before() const {
      return copy_done_schedule_before_;
    }

    void set_copy_start_schedule_after(int64 copy_start_schedule_after) {
      copy_start_schedule_after_ = copy_start_schedule_after;
    }

    bool is_cross_program_prefetch() const {
      return is_cross_program_prefetch_;
    }

    bool operator==(const CopyAllocation& other) const;
    std::string ToString() const override;

   private:
    const Allocation& prev_allocation_;
    // These variables define the scheduling boundaries where CopyStart and
    // CopyDone can be scheduled. The earliest a CopyStart can be scheduled is
    // after copy_start_schedule_after_, and the latest a CopyDone can be
    // scheduled is before copy_done_schedule_before_.
    int64 copy_start_schedule_after_;
    int64 copy_done_schedule_before_;
    bool is_cross_program_prefetch_;
    HloInstruction* copy_start_;
    HloInstruction* copy_done_;
  };

  using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;

  // AllocationValue is used to break up HloValues for each non-trivial
  // position (positions of Tuple, GetTupleElement, and Bitcast instructions
  // are considered trivial). An HloValue may include positions and uses that
  // alias with each other across multiple computations. We use this class to
  // break these HloValues up such that every AllocationValue has one defining
  // position (that may alias with other AllocationValues). The uses field of
  // the AllocationValue contains only the direct uses of the
  // AllocationValue's defining position.
  //
  // For example, consider the following HLO snippet:
  //
  //   Body {
  //     body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
  //     get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element(body_param),
  //     index=0
  //     ...
  //     ROOT tuple = (f32[4,3]{1,0}, f32[]) tuple(get-tuple-element.3, ...)
  //   }
  //
  //   Cond {
  //     cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
  //     ...
  //   }
  //
  //   add.4 = f32[4,3]{1,0} add(...)
  //   tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(add.4, ...)
  //   while = (f32[4,3]{1,0}, f32[]) while(tuple.1), body=Body, condition=Cond
  //   get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element(while), index=0
  //   add.5 = f32[4,3]{1,0} add(get-tuple-element.5, ...)
  //
  // This contains an HloValue that looks like the following:
  // positions:
  //  add.4
  //  body_param {0}
  //  get-tuple-element.3
  //  tuple {0}
  //  cond_param {0}
  //  tuple.1 {0}
  //  while {0}
  //  get-tuple-element.5
  // uses:
  //  add.1, operand 0
  //  tuple, operand 0
  //  while, operand 0 {0}
  //  add.5, operand 0
  //
  // We break this HloValue up into the following AllocationValues for each
  // non-trivial position:
  // AllocationValue1: computation = Entry
  //  position:
  //   add.4
  //  uses:
  //   while, operand 0 {0}
  // AllocationValue2: computation = Cond
  //  position:
  //   cond_param {0}
  //  uses:
  // AllocationValue3: computation = Body
  //  position:
  //   body_param {0}
  //  uses:
  //   add.1, operand 0
  //   tuple, operand 0
  // AllocationValue4: computation = Entry
  //  position:
  //   while {0}
  //  uses:
  //   add.5, operand 0
  class AllocationValue {
   public:
    // This data structure wraps an HloUse and adds additional metadata that
    // is useful for allocation.
    struct Use {
      // The wrapped HloUse object.
      HloUse hlo_use;
      // The logical time this use is scheduled.
      int64 time;
      // All the positions with which this use aliases. The aliased positions
      // must get the same allocation.
      std::vector<HloPosition> aliases;

      bool operator==(const Use& other) const {
        return hlo_use == other.hlo_use && time == other.time &&
               aliases == other.aliases;
      }

      template <typename H>
      friend H AbslHashValue(H h, const Use& s) {
        return H::combine(std::move(h), s.hlo_use, s.time, s.aliases);
      }
    };

    AllocationValue(const HloValue* value, const HloPosition& position,
                    int64 size)
        : value_(value), defining_position_(position), size_(size) {}

    const HloPosition& defining_position() const { return defining_position_; }
    const HloInstruction* defining_instruction() const {
      return defining_position().instruction;
    }
    int64 size() const { return size_; }
    const std::vector<Use>& uses() const { return uses_; }
    std::vector<Use>& uses() { return uses_; }
    const HloValue* value() const { return value_; }
    const HloComputation* computation() const {
      return defining_instruction()->parent();
    }
    AllocationSequence* allocation_sequence() { return &allocation_sequence_; }

    void AddUse(const HloUse& use, int64 use_time) {
      uses_.push_back({use, use_time, {}});
    }

    std::string ToString() const;
    std::string ToShortString() const;

   private:
    const HloValue* value_;
    HloPosition defining_position_;
    int64 size_;
    std::vector<Use> uses_;
    AllocationSequence allocation_sequence_;
  };

  // Statistics of asynchronous copies.
  struct AsyncCopyStats {
    int64 max_outstanding_async_copies;
    int64 num_prefetches;
    int64 prefetch_bytes;
    int64 num_evictions;
    int64 eviction_bytes;
  };

  virtual ~MemorySpaceAssignment() = default;

  // Runs the MemorySpaceAssignment pass.
  static StatusOr<std::unique_ptr<PresetAssignments>> Run(
      HloModule* module, const HloLiveRange& hlo_live_range,
      const HloAliasAnalysis& alias_analysis, const Options& options);

  // Calculates asynchronous copy statistics.
  StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;

  static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
      const MemorySpaceAssignmentCostAnalysis& cost_analysis,
      MemorySpaceAssignmentCostAnalysis::Cache* cache = nullptr);

  // Verifies that the memory space assignment is free of overlapping buffers
  // and exports the heap simulator trace to be used by buffer_assignment.
  Status VerifyAndExportHeapSimulatorTrace();
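
  // A sketch of a typical invocation (assuming `module`, `hlo_live_range`,
  // `alias_analysis`, and a populated `options` are in scope):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       std::unique_ptr<PresetAssignments> preset_assignments,
  //       MemorySpaceAssignment::Run(module, hlo_live_range, alias_analysis,
  //                                  options));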

 protected:
  // Main driver of the memory space assignment pass.
  virtual StatusOr<std::unique_ptr<PresetAssignments>> RunMemorySpaceAssignment(
      const HloLiveRange& hlo_live_range,
      const HloAliasAnalysis& alias_analysis);

  // Finds an AllocationSequence for placing buffers in alternate memory using
  // the AlternateMemoryBestFitHeap algorithm. Must be called before Process().
  virtual Status FindAllocationSequence(const HloLiveRange& hlo_live_range,
                                        const HloAliasAnalysis& alias_analysis);

  Options options() const { return options_; }

  MemorySpaceAssignment(HloModule* module, Options options,
                        const HloLiveRange& hlo_live_range)
      : module_(module),
        options_(options),
        flattened_instructions_(hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .begin(),
                                hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .end()),
        computations_in_schedule_(),
        preset_assignments_(absl::make_unique<PresetAssignments>()) {
    for (const auto& computation_and_bound :
         hlo_live_range.computation_span_times()) {
      computations_in_schedule_.insert(computation_and_bound.first);
    }
  }

  AllocationSequence allocations_;

  HloModule* module() { return module_; }

 private:
  // Process calls the Process methods of the allocations after the
  // allocations have been finalized.
  Status Process();

  // Process() might have altered the computation graph by inserting kTuple
  // and kGetTupleElement instructions. SimplifyGraph performs a simple DCE
  // and tuple simplification operation (e.g., given GetTupleElement(Tuple(a,
  // b), 1), simply forwards b). Runs to fixed point.
  Status SimplifyGraph();

  // FixSchedule inserts asynchronous copies in the schedule.
  Status FixSchedule();

  // Exports the alternate memory assignments to the PresetAssignments and
  // colors the HLO graph with the determined memory spaces.
  Status ExportAndColorBuffers();

  // Inserts an instruction into the schedule, and makes sure its dependencies
  // (operands) are already in the schedule. If not, inserts these operands
  // before the instruction.
  void EnsureInstructionAndOperandsInserted(
      HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
      absl::flat_hash_set<HloInstruction*>* inserted_instructions) const;

  // Schedules asynchronous copies and ensures that the CopyStarts and their
  // corresponding CopyDones follow the same order.
  void ScheduleAsynchronousCopies();

  // Removes the positions and chunks associated with the instruction from
  // alternate_memory_assignments_.
  void RemoveAssignmentForInstruction(const HloInstruction* instruction);

  HloModule* module_;
  Options options_;
  std::vector<HloInstruction*> flattened_instructions_;
  absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
  std::unique_ptr<PresetAssignments> preset_assignments_;
  std::vector<std::pair<HloPosition, Chunk>> alternate_memory_assignments_;
  int64 alternate_memory_size_ = 0;

  // These maps hold vectors of new instructions that need to be scheduled
  // after (or before) the instruction index in the key. FixSchedule uses
  // these maps to modify and fix the schedule.
  absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_after_;
  absl::flat_hash_map<int64, std::vector<HloInstruction*>> schedule_before_;
};

// A struct representing an asynchronous copy with its logical start and end
// time and its destination memory space.
struct AsynchronousCopy {
  int64 start_time;
  int64 end_time;
  MemorySpaceAssignment::MemorySpace destination;
};

// Compare asynchronous copies such that an earlier start time has the same or
// earlier end time and an earlier end time has the same or earlier start
// time.
bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b);

// Helper class to enforce asynchronous copy ordering. We only allow
// asynchronous copies that are pipelined: if an asynchronous copy ends
// earlier than another asynchronous copy, it must start at the same time or
// earlier than the other asynchronous copy; and if an asynchronous copy
// starts earlier than another asynchronous copy, it must end at the same time
// or earlier than the other asynchronous copy.
class AsynchronousCopyOrdering {
 public:
  AsynchronousCopyOrdering() = default;

  // Adds an asynchronous copy.
  void AddCopy(const AsynchronousCopy& copy);

  // Removes an asynchronous copy. CHECKs that it is removed.
  void RemoveCopy(const AsynchronousCopy& copy);

  // If the addition of an asynchronous copy in the given time interval would
  // violate the asynchronous copy ordering, returns the violating
  // already-committed asynchronous copy. E.g., consider the following
  // scenario:
  //                                       CS          CD
  //      already committed async copy:  +-----------+
  //                   new async copy:     +--------+
  //
  // The new asynchronous copy would violate the ordering guarantee because
  // its copy start is after the already-committed copy's CopyStart while its
  // copy done is before the committed copy's CopyDone.
  absl::optional<AsynchronousCopy> ViolatesOrdering(int64 start_time,
                                                    int64 end_time) const;

 private:
  // Stores asynchronous copies in a tree set respecting the pipelining order.
  std::set<AsynchronousCopy> ranges_;
};
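
// An ordering sketch mirroring the scenario above (the times are
// illustrative):
//
//   AsynchronousCopyOrdering ordering;
//   ordering.AddCopy({/*start_time=*/5, /*end_time=*/20,
//                     MemorySpaceAssignment::MemorySpace::kAlternate});
//   // Starts after (10 > 5) but ends before (15 < 20) the committed copy,
//   // so it violates the pipelining order and the committed copy is
//   // returned:
//   absl::optional<AsynchronousCopy> violation =
//       ordering.ViolatesOrdering(/*start_time=*/10, /*end_time=*/15);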

// This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
// maximum size.
class AlternateMemoryBestFitHeap
    : public GlobalDecreasingSizeBestFitHeap<HloValue> {
 public:
  using MemorySpace = MemorySpaceAssignment::MemorySpace;
  using AllocationValue = MemorySpaceAssignment::AllocationValue;

  AlternateMemoryBestFitHeap(
      MemorySpaceAssignment::AllocationSequence* allocations,
      const MemorySpaceAssignment::Options& options,
      const HloAliasAnalysis& alias_analysis,
      const HloLiveRange& hlo_live_range)
      : GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes),
        allocations_(allocations),
        options_(options),
        alias_analysis_(alias_analysis),
        hlo_live_range_(hlo_live_range) {
    // Override the buffer interval compare if provided.
    if (options.buffer_interval_compare) {
      buffer_interval_compare_ = *options.buffer_interval_compare;
    }
  }

  // Allocates a buffer in preferred memory with whole program lifetime and
  // enables prefetching prefetch_candidate from default memory across program
  // boundaries.
  void AllocateCrossProgramPrefetchBuffer(
      HloModule* module, absl::optional<BufferInterval> prefetch_candidate);

  HeapSimulator::Result<HloValue> Finish() override;

 protected:
  // Given a buffer interval, returns the colocated intervals. Unlike the
  // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
  // returns the colocated intervals sorted by scheduled time.
  std::vector<const BufferInterval*> GetSortedColocatedIntervals(
      const BufferInterval& interval) const;

  // Given a BufferInterval, creates AllocationValue objects and corresponding
  // AllocationSequences and appends them into allocation_sequence_list_.
  void CreateAllocationValues(
      const BufferInterval& buffer_interval,
      std::vector<AllocationValue>& allocation_values) const;

  // Given colocated intervals, populates allocation_values with the
  // corresponding AllocationValue objects.
  virtual void CreateAllocationValuesFromColocatedIntervals(
      absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
          colocated_intervals,
      std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values);

  // Goes through all the uses in the AllocationValues and finds the aliasing
  // positions.
  void FindAliases(std::vector<AllocationValue>* allocation_values,
                   bool skip_values_with_no_uses) const;

  MemorySpaceAssignment::AllocationSequence* allocations() {
    return allocations_;
  }
  const MemorySpaceAssignment::Options& options() { return options_; }
  const HloAliasAnalysis& alias_analysis() { return alias_analysis_; }
  const HloLiveRange& hlo_live_range() { return hlo_live_range_; }

 private:
  // We inherit from the AllocationBlock struct to attach the Allocation
  // information, which makes importing repacked offsets easier.
  struct RepackAllocationBlock
      : MemorySpaceAssignmentRepacker::AllocationBlock {
    MemorySpaceAssignment::Allocation* allocation;
  };

  // A data structure we use to associate Allocation objects that are aliased
  // and must get the same offset.
  struct AliasedOffset {
    int64 offset;
    absl::flat_hash_set<const MemorySpaceAssignment::Allocation*> allocations;
  };

  // An allocation request for a use segment. A use segment is the time
  // segment between the definition and the first use of a buffer, or between
  // two consecutive uses. For example, the time between the definition and
  // Use1 is the first segment, the time between Use1 and Use2 is the second
  // segment, and so on:
  //
  //        +------+----------+-------+
  //       /        \          \       \
  //      /          v          v       v
  //    Def         Use1       Use2    Use3
  //     <----------> <--------> <----->
  //        Segment    Segment   Segment
  //
  // start_time and end_time are the start and end logical times of the
  // segment. use_times is a sorted sequence of the times of all uses.
  // latest_prefetch_time is the latest time we can schedule the CopyDone for
  // a prefetch.
  // If allow_no_copy_alternate_mem_allocation is false, an eviction is
  // forced.
  // If earliest_prefetch_time is set, prefetches cannot start before this
  // value.
  struct AllocationRequest {
    int64 start_time;
    int64 end_time;
    int64 latest_prefetch_time;
    int64 size;
    bool allow_no_copy_alternate_mem_allocation;
    absl::optional<int64> earliest_prefetch_time;
    AliasedOffset* preferred_offset;
    const MemorySpaceAssignment::AllocationValue::Use* use;
    MemorySpaceAssignment::AllocationValue* allocation_value;
  };

  // This struct contains mandatory memory assignments at a given time. E.g.,
  // an input's required memory assignment time would correspond to the
  // definition time of the parameter instruction, and an output's time would
  // correspond to the time of last use.
  struct RequiredMemoryAssignment {
    MemorySpaceAssignment::MemorySpace memory_space;
    int64 time;
    AliasedOffset* offset;

    bool equals_ignoring_time(const RequiredMemoryAssignment& other) const {
      return memory_space == other.memory_space && offset == other.offset;
    }

    bool operator==(const RequiredMemoryAssignment& other) const {
      return memory_space == other.memory_space && time == other.time &&
             offset == other.offset;
    }

    bool operator!=(const RequiredMemoryAssignment& other) const {
      return !(*this == other);
    }
  };

  // Result of an allocation, prefetch, eviction, etc. request. The result is
  // either kSuccess or a bitwise OR of one or more failures. The values are
  // unique powers of two. To check if a result contains a particular failure,
  // use the result_is method. To add a new failure to a result, use the
  // result_mark method.
  enum class Result {
    // Successful allocation.
    kSuccess = 0,
    // Allocation failed because we ran out of alternate memory.
    kFailOutOfMemory = 1,
    // A no-copy allocation couldn't be performed because the previous
    // allocation wasn't in the alternate memory space.
    kFailPrevAllocationNotInAlternateMem = 2,
    // A no-copy allocation couldn't be performed because the live range was
    // too long.
    kFailLiveRangeTooLong = 4,
    // A prefetch couldn't be performed because the live range was too short.
    kFailLiveRangeTooShort = 8,
    // Ran out of the outstanding asynchronous copy limit either during
    // prefetching or eviction.
    kFailOutOfAsyncCopies = 16,
    // A prefetch couldn't be performed because the asynchronous copy ordering
    // was violated.
    kFailViolatesAsyncCopyOrdering = 32,
    // An allocation failure happened that requires uncommitting all the
    // pending allocations. Usually this is due to a situation requiring an
    // eviction but the eviction couldn't be performed.
    kFailRequiresUncommit = 64
  };

  // Returns true if the result belongs to a failure.
  static bool result_is(Result result, Result failure) {
    return static_cast<int>(result) & static_cast<int>(failure);
  }

  // Marks (bitwise ORs) a failure in the result.
  static Result result_mark(Result failure, Result& result) {
    result = static_cast<Result>(static_cast<int>(result) |
                                 static_cast<int>(failure));
    return result;
  }

  // Returns true if the result is a failure that requires us to uncommit
  // pending chunks.
  static bool result_requires_uncommit(Result result) {
    return result_is(result, Result::kFailRequiresUncommit);
  }

  // Returns true if the result is a failure either due to running out of
  // outstanding asynchronous copies or due to violating asynchronous copy
  // ordering.
  static bool result_failed_because_of_async_copy(Result result) {
    return result_is(result, Result::kFailOutOfAsyncCopies) ||
           result_is(result, Result::kFailViolatesAsyncCopyOrdering);
  }
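
  // An illustrative sketch of the bit-flag semantics:
  //
  //   Result result = Result::kSuccess;
  //   result_mark(Result::kFailOutOfMemory, result);
  //   result_mark(Result::kFailOutOfAsyncCopies, result);
  //   result_is(result, Result::kFailOutOfMemory);       // true
  //   result_failed_because_of_async_copy(result);       // true
  //   result_requires_uncommit(result);                  // false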

  // Returns the AliasedOffset object associated with the allocation.
  AliasedOffset* GetAliasedOffset(
      const MemorySpaceAssignment::Allocation& allocation);

  // If aliased_offset is non-null, this method adds the allocation to
  // aliased_offset. Otherwise, it creates a new AliasedOffset object and adds
  // the allocation to this new AliasedOffset.
  void CreateOrAddToAliasedOffset(
      const MemorySpaceAssignment::Allocation& allocation,
      AliasedOffset* aliased_offset);

  // Given an allocation sequence, returns the live allocation at the given
  // time, with a preference towards allocations in alternate memory. Returns
  // nullptr if no allocation is alive at that time.
  static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
      const MemorySpaceAssignment::AllocationSequence& allocations,
      int64 time);

  // Returns true if the use is allowed in the alternate memory.
  bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
                                     const HloUse& use) const;

  // Finds allocations for allocation values generated from colocated
  // intervals. All of the allocation values have a must-alias relationship
  // with each other. Returns kSuccess if all of the sites could be placed in
  // the alternate memory, or a bitwise OR of the reasons why they couldn't.
  Result AllocateAllocationValues(
      absl::Span<AllocationValue> allocation_values);

  // Finds an allocation for an allocation request for a segment (see the
  // documentation for AllocationRequest above for how a segment is defined).
  //
  // It performs three things in the following order:
  //  1- Allocate the allocation request entirely in the alternate memory, if
  //     there is enough space and if the prefetch interval picker allows.
  //  2- If (1) was unsuccessful, and the only allocation for this buffer was
  //     in the alternate memory, we try to perform an eviction.
  //  3- If (1) was unsuccessful, prefetch the buffer into the alternate
  //     memory, if there is enough space and if the prefetch interval picker
  //     allows.
  //
  // If an eviction (2) was requested and was unsuccessful, this method
  // returns Result::kFailRequiresUncommit. This means we could not find a
  // suitable allocation, so all previous allocations for this buffer must be
  // removed and allocated in the default memory. Otherwise, this method
  // returns Result::kSuccess if the buffer could be placed in alternate
  // memory, or some other Result with an OR of the reasons why the buffer
  // couldn't be placed in alternate memory.
  Result AllocateSegment(const AllocationRequest& request);

  // Try allocating in alternate memory without any copies.
  Result AllocateInAlternateMemoryNoCopy(const AllocationRequest& request);

  // Try evicting to default memory space.
  Result Evict(const AllocationRequest& request);

  // Returns the time the copy done of a prefetch should be scheduled.
  int64 FindPrefetchEndTime(const AllocationRequest& request,
                            int64 earliest_prefetch_time) const;

  // Try prefetching to alternate memory space.
  Result Prefetch(
      const AllocationRequest& request,
      const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem);

  // Finds the best possible chunk candidate: one with the longest possible
  // availability if no preferred offset is given, or one at preferred_offset
  // if it is given.
  absl::optional<ChunkCandidate> FindBestChunkCandidate(
      const AllocationRequest& request, const AliasedOffset* preferred_offset,
      BufferInterval* alternate_mem_interval) const;

  // Returns the required assignment at a particular time, if available.
  absl::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
      const HloValue* buffer, int64 time) const;

  // Searches for aliases in the use for a required assignment, and returns it
  // if found.
  absl::optional<RequiredMemoryAssignment> AliasedRequiredAssignmentForUse(
      const AllocationValue::Use& use) const;

  // Goes through the colocated intervals and adds any required assignment.
  void AddRequiredAssignmentsForColocatedIntervals(
      absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
          colocated_intervals);

  // Propagates an aliased required assignment for a given position.
  void AddAliasedRequiredAssignment(
      const HloInstruction* instruction, ShapeIndex index,
      const MemorySpaceAssignment::Allocation* aliased_allocation);

  // This sets a required assignment. CHECK fails if there is a conflicting
  // required assignment at the same time.
  void AddRequiredAssignment(const HloValue* value,
                             const HloInstruction* instruction,
                             MemorySpace memory_space, int64 time,
                             AliasedOffset* offset = nullptr);
  void AddRequiredAssignment(const HloInstruction* instruction,
                             ShapeIndex index, MemorySpace memory_space,
                             AliasedOffset* offset = nullptr);

  // Adds inputs and outputs as required assignments.
  void AddInputAndOutputRequiredAssignments();

  // Returns true if the colocated intervals in the argument are in a
  // parameter or root instruction of the entry computation and are reserved
  // by the user to be in the alternate memory space.
  bool AreIntervalsReservedInAlternateMemory(
      absl::Span<const BufferInterval* const> colocated_intervals) const;

  // Since the allocations are recorded to the AllocationSequence, we don't
  // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override
  // AddToChunkMap to avoid unnecessarily adding the chunk to the chunk map.
  void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}

  // Returns true if the addition of an asynchronous copy in the given time
  // interval would violate the maximum number of asynchronous copies. An
  // extra async copy limit can be provided to increase the limit of
  // asynchronous copies for this instance.
  bool ViolatesMaximumOutstandingAsyncCopies(
      int64 start_time, int64 end_time, bool is_prefetch,
      int64 extra_async_copy_limit = 0) const;

  // If the asynchronous copy would violate the pipelining order, returns the
  // violating asynchronous copy.
  absl::optional<AsynchronousCopy> ViolatesAsyncCopyOrdering(
      int64 start_time, int64 end_time) const;

  // Exports the allocations for repacking and puts them into the vector given
  // as a parameter.
  void ExportAllocationsForRepacking(
      std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>&
          allocations);

  // Imports repacked allocations and updates the internal data structures
  // consistent with the new packing.
  void ImportRepackedAllocations();

  // Adds an asynchronous copy to the allocations.
  void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
                    MemorySpace memory_space, absl::optional<Chunk> chunk,
                    int64 start_time, int64 end_time,
                    int64 copy_done_schedule_before_time,
                    MemorySpaceAssignment::AllocationSequence* allocations,
                    AliasedOffset* aliased_offset,
                    bool is_cross_program_prefetch = false);

  // This method commits the chunk candidate but also adds it to
  // pending_chunks_ so that we can "uncommit" it in case we need to roll back
  // this allocation sequence.
  void AddToPendingChunks(const BufferInterval& buffer_interval,
                          const ChunkCandidate& chunk_candidate);

  // If we need to remove the allocations for this allocation sequence, this
  // removes pending chunks and asynchronous copies in the respective pending
  // buffers from the interval trees. If an allocation request returns
  // kFailRequiresUncommit, this method must be called.
  void UncommitPendingChunks(absl::Span<AllocationValue> allocation_values);

  // Finalizes the allocations where they can no longer be uncommitted.
  void FinalizeAllocations(absl::Span<AllocationValue> allocation_values);

  // Clears all pending chunks and asynchronous copies.
  void ClearPendingChunks();
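
  // A sketch of the intended commit/rollback flow around these methods (the
  // control flow here is illustrative, not the exact implementation):
  //
  //   Result result = AllocateAllocationValues(absl::MakeSpan(values));
  //   if (result_requires_uncommit(result)) {
  //     UncommitPendingChunks(absl::MakeSpan(values));
  //   } else {
  //     FinalizeAllocations(absl::MakeSpan(values));
  //   }
  //   ClearPendingChunks();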

  // Appends buffer and allocation infos for debugging and dumps them into a
  // file, if enabled.
  void AppendBufferInfoDebugString(const BufferInterval& interval,
                                   std::string* debug_str) const;
  void AppendAllocationInfoDebugString(
      const AllocationValue& value,
      const MemorySpaceAssignment::Allocation& allocation,
      std::string& debug_str) const;
  void DumpDebugStringsIfEnabled() const;

  // Returns the available heap size in the alternate memory.
  int64 available_heap_size() const {
    return options_.max_size_in_bytes - reserved_in_bytes_;
  }

  // Creates and returns a RepackAllocationBlock.
  static RepackAllocationBlock MakeRepackAllocationBlock(
      int64 start_time, int64 end_time, int64 size, int64 initial_offset,
      int64 id, MemorySpaceAssignment::Allocation* allocation) {
    RepackAllocationBlock allocation_block;
    allocation_block.start_time = start_time;
    allocation_block.end_time = end_time;
    allocation_block.size = size;
    allocation_block.offset = -1;
    allocation_block.initial_offset = initial_offset;
    allocation_block.id = id;
    allocation_block.colocations = {};
    allocation_block.allocation = allocation;
    return allocation_block;
  }

  MemorySpaceAssignment::AllocationSequence* allocations_;
  const MemorySpaceAssignment::Options& options_;
  const HloAliasAnalysis& alias_analysis_;
  const HloLiveRange& hlo_live_range_;
  // We use an interval tree to keep track of the number of outstanding
  // prefetches and evictions.
  BufferIntervalTree prefetch_interval_tree_;
  BufferIntervalTree eviction_interval_tree_;
  AsynchronousCopyOrdering async_copy_ordering_;
  // A list of RepackAllocationBlock objects that mirrors allocation
  // sequences, used for repacking. We use a list here because we need pointer
  // stability for aliased allocations.
  std::list<RepackAllocationBlock> repack_allocation_blocks_;
  int64 num_repacks_ = 0;
  std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
  std::vector<AsynchronousCopy> pending_async_copies_;
  std::vector<std::pair<const HloValue*, RequiredMemoryAssignment>>
      pending_required_assignments_;
  // The data structure that contains AliasedOffset objects and an Allocation
  // to AliasedOffset map for efficient lookup.
  std::list<AliasedOffset> aliased_offsets_;
  absl::flat_hash_map<const MemorySpaceAssignment::Allocation*, AliasedOffset*>
      aliased_offset_map_;
  // This map contains required memory assignments for HloValues (e.g., inputs
  // and outputs).
  absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
      required_assignments_;
  // Number of bytes reserved in the alternate memory space.
  int64 reserved_in_bytes_ = 0;
  // Debug strings.
  std::string buffer_info_str_;
  std::string allocation_info_str_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_