1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_ 18 19 #include <functional> 20 #include <iosfwd> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 25 #include "absl/container/flat_hash_map.h" 26 #include "absl/container/flat_hash_set.h" 27 #include "absl/types/span.h" 28 #include "tensorflow/compiler/xla/service/buffer_liveness.h" 29 #include "tensorflow/compiler/xla/service/heap_simulator.h" 30 #include "tensorflow/compiler/xla/service/hlo.pb.h" 31 #include "tensorflow/compiler/xla/service/hlo_computation.h" 32 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 33 #include "tensorflow/compiler/xla/service/hlo_module.h" 34 #include "tensorflow/compiler/xla/service/logical_buffer.h" 35 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 36 #include "tensorflow/compiler/xla/statusor.h" 37 #include "tensorflow/compiler/xla/types.h" 38 #include "tensorflow/core/platform/logging.h" 39 #include "tensorflow/core/platform/macros.h" 40 #include "tensorflow/core/platform/types.h" 41 42 namespace xla { 43 44 // Walk the call graph of the HLO module and place each computation into either 45 // thread_local_computations or global_computations depending upon whether the 46 // computation requires thread-local allocations or global allocations. The 47 // elements in thread_local_computations and global_computations are in post 48 // order (if computation A has an instruction which calls computation B, then A 49 // will appear after B in the vector). 50 Status GatherComputationsByAllocationType( 51 const HloModule* module, 52 std::vector<const HloComputation*>* thread_local_computations, 53 std::vector<const HloComputation*>* global_computations); 54 55 // This class abstracts an allocation of contiguous memory which can hold the 56 // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range 57 // of the allocation, represented by a Slice. A single BufferAllocation may hold 58 // LogicalBuffers with disjoint liveness, which may have overlapping Slices. A 59 // single BufferAllocation may also hold LogicalBuffers with overlapping 60 // liveness, which must have disjoint Slices. 61 // 62 // The abstraction includes information required by the backends for allocation, 63 // use, and deallocation of the buffer. This includes the LogicalBuffers which 64 // are held in this allocation through the execution of the computation. 65 class BufferAllocation { 66 public: 67 // Holds a unique identifier for each allocation. Values are assigned 68 // contiguously and can be used as array indexes. 69 using Index = int64; 70 BufferAllocation(Index index,int64 size,LogicalBuffer::Color color)71 BufferAllocation(Index index, int64 size, LogicalBuffer::Color color) 72 : index_(index), size_(size), color_(color) {} ~BufferAllocation()73 ~BufferAllocation() {} 74 75 // Returns the index of this allocation. index()76 Index index() const { return index_; } 77 78 // Whether this allocation is used in a parallel calling context such as 79 // inside of a map or reduce computation. Such allocations need to be thread 80 // local. is_thread_local()81 bool is_thread_local() const { return is_thread_local_; } set_is_thread_local(bool is_thread_local)82 void set_is_thread_local(bool is_thread_local) { 83 is_thread_local_ = is_thread_local; 84 } 85 86 // Whether this allocation can be used by more than one logical buffer. is_reusable()87 bool is_reusable() const { 88 // We do not reuse thread-local buffers for now, because they are 89 // dynamically allocated and their lifetimes are hard to compute. 90 // 91 // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend 92 // assumes longer buffer liveness than indicated by the analysis. 93 return !is_thread_local() && !is_tuple(); 94 } 95 96 // Whether this allocation is readonly i.e. backed by memory we cannot write 97 // to. is_readonly()98 bool is_readonly() const { 99 // Entry parameters are generally readonly, except when they are aliased 100 // with any output. 101 return (is_entry_computation_parameter() && 102 !is_parameter_aliased_with_output_) || 103 is_constant(); 104 } 105 is_tuple()106 bool is_tuple() const { return is_tuple_; } set_is_tuple(bool is_tuple)107 void set_is_tuple(bool is_tuple) { is_tuple_ = is_tuple; } 108 109 // Whether this allocation holds a LogicalBuffer from a parameter of the entry 110 // computation. These buffers have lifetimes which may be longer than the 111 // XLA computation. is_entry_computation_parameter()112 bool is_entry_computation_parameter() const { 113 return is_entry_computation_parameter_; 114 } 115 116 // Whether this allocation holds a constant. On the CPU and GPU backends 117 // constant allocations are not allocated dynamically, instead we resolve 118 // references to these buffer allocations to a global in the readonly section 119 // of the binary. is_constant()120 bool is_constant() const { return is_constant_; } 121 122 // If this allocation holds a Buffer from a parameter of the entry 123 // computation, this methods returns the parameter number. CHECKs otherwise. parameter_number()124 int64 parameter_number() const { 125 CHECK(is_entry_computation_parameter_); 126 return parameter_number_; 127 } 128 129 // If this allocation is for a parameter of the entry computation, this 130 // function returns which subshape of the parameter the allocation is for. param_shape_index()131 const ShapeIndex& param_shape_index() const { 132 CHECK(is_entry_computation_parameter_); 133 return param_shape_index_; 134 } 135 136 // Returns whether this allocation is assigned a LogicalBuffer which may 137 // be live out of the entry computation. maybe_live_out()138 bool maybe_live_out() const { return maybe_live_out_; } 139 140 // Returns the size of the allocation. Necessarily this must be at least as 141 // large as any LogicalBuffer assigned to this allocation. size()142 int64 size() const { return size_; } 143 144 // Returns the color of the allocation. Only logical buffers with a matching 145 // color can reside in this allocation. color()146 LogicalBuffer::Color color() const { return color_; } 147 148 struct OffsetSize { 149 int64 offset = 0; 150 int64 size = 0; 151 }; 152 153 // Access to the logical buffers assigned to this allocation, and their 154 // associated logical offsets and sizes. 155 const absl::flat_hash_map<const LogicalBuffer*, OffsetSize>& assigned_buffers()156 assigned_buffers() const { 157 return assigned_buffers_; 158 } 159 160 // A Slice represents a contiguous portion of a memory allocation. It is used 161 // to identify the memory range that a LogicalBuffer corresponds to. 162 class Slice { 163 public: Slice()164 Slice() {} Slice(const BufferAllocation * allocation,int64 offset,int64 size)165 Slice(const BufferAllocation* allocation, int64 offset, int64 size) 166 : allocation_(allocation), offset_(offset), size_(size) {} 167 allocation()168 const BufferAllocation* allocation() const { return allocation_; } index()169 Index index() const { return allocation_->index(); } offset()170 int64 offset() const { return offset_; } size()171 int64 size() const { return size_; } 172 173 bool operator==(const Slice& other) const { 174 return index() == other.index() && offset_ == other.offset_ && 175 size_ == other.size_; 176 } 177 bool operator!=(const Slice& other) const { return !(*this == other); } 178 bool operator<(const Slice& other) const { 179 if (index() != other.index()) return index() < other.index(); 180 if (offset_ != other.offset_) return offset_ < other.offset_; 181 return size_ < other.size_; 182 } 183 184 // Returns true iff this slice's memory range has a non-empty intersection 185 // with the other slice's memory range. OverlapsWith(const Slice & other)186 bool OverlapsWith(const Slice& other) const { 187 const int64 end = offset_ + size_; 188 const int64 other_end = other.offset_ + other.size_; 189 return index() == other.index() && offset_ < other_end && 190 end > other.offset_; 191 } 192 193 template <typename H> AbslHashValue(H h,const Slice & s)194 friend H AbslHashValue(H h, const Slice& s) { 195 return H::combine(std::move(h), s.index(), s.offset(), s.size()); 196 } 197 198 string ToString() const; 199 200 private: 201 const BufferAllocation* allocation_ = nullptr; 202 int64 offset_ = 0; 203 int64 size_ = 0; 204 }; 205 206 // GetSlice returns the Slice of contiguous memory that holds the value 207 // described by the given 'buffer'. 208 // REQUIRES: 'buffer' must be assigned to this allocation. 209 Slice GetSlice(const LogicalBuffer& buffer) const; 210 211 string ToString() const; 212 BufferAllocationProto ToProto() const; 213 214 // Whether the buffer is a parameter to or live out of the entry computation. IsInputOrOutput()215 bool IsInputOrOutput() const { 216 return is_entry_computation_parameter() || maybe_live_out(); 217 } 218 219 // Whether the buffer is a temporary buffer allocated before 220 // Executable::ExecuteOnStream. IsPreallocatedTempBuffer()221 bool IsPreallocatedTempBuffer() const { 222 // Parameters do not need temporary buffers. 223 return !is_entry_computation_parameter() && 224 // LogicalBuffers that maybe pointed to by the output should live out 225 // of the computation. 226 !maybe_live_out() && 227 // Thread-local buffers are allocated using `alloca`s. 228 !is_thread_local() && 229 // Constant buffers are allocated as global values. 230 !is_constant(); 231 } 232 233 // Add a heap trace which was used to assign slices to logical buffers in this 234 // allocation. A single BufferAllocation may include multiple heap traces 235 // in the case of the temporary block where there is a heap trace per 236 // computation. AddHeapTrace(const HeapSimulatorTrace & heap_trace)237 void AddHeapTrace(const HeapSimulatorTrace& heap_trace) { 238 heap_traces_.push_back(heap_trace); 239 } 240 241 // Return the set of heap traces used to assign slices to logical buffers in 242 // this allocation. HeapTraces()243 const std::vector<HeapSimulatorTrace> HeapTraces() const { 244 return heap_traces_; 245 } 246 247 // Returns the LogicalBuffers which are live at the point of peak memory usage 248 // for this allocation. The point of peak memory usage is the point at which 249 // the total size of all live logical buffers is maximal. If peak memory is 250 // reached at multiple points, the set of logical buffers live at the earliest 251 // maximal point is returned. The vector is stabily sorted by 252 // LogicalBuffer::Index. PeakMemoryLogicalBuffers()253 const std::vector<const LogicalBuffer*>& PeakMemoryLogicalBuffers() const { 254 return peak_buffers_; 255 } 256 257 // Get the number of bytes lost to fragmentation. This is equal to the 258 // difference between the size of the allocation and the size of the maximal 259 // live set. fragmentation_bytes()260 int64 fragmentation_bytes() const { return fragmentation_bytes_; } 261 262 bool operator==(const BufferAllocation& other) const { 263 return index_ == other.index_; 264 } 265 bool operator!=(const BufferAllocation& other) const { 266 return !(*this == other); 267 } 268 bool operator<(const BufferAllocation& other) const { 269 return index() < other.index(); 270 } 271 272 private: 273 // Only BufferAssigner and BufferAssignment can modify BufferAllocation. 274 friend class BufferAssigner; 275 friend class BufferAssignment; 276 277 // Adds a LogicalBuffer to the set assigned to this buffer. 278 void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size); 279 set_entry_computation_parameter(int64 parameter_number,ShapeIndex param_shape_index,bool parameter_aliased_with_output)280 void set_entry_computation_parameter(int64 parameter_number, 281 ShapeIndex param_shape_index, 282 bool parameter_aliased_with_output) { 283 is_entry_computation_parameter_ = true; 284 is_parameter_aliased_with_output_ = parameter_aliased_with_output; 285 parameter_number_ = parameter_number; 286 param_shape_index_ = std::move(param_shape_index); 287 } 288 set_constant(bool is_constant)289 void set_constant(bool is_constant) { is_constant_ = is_constant; } set_maybe_live_out(bool value)290 void set_maybe_live_out(bool value) { maybe_live_out_ = value; } set_index(Index index)291 void set_index(Index index) { index_ = index; } set_size(int64 size)292 void set_size(int64 size) { size_ = size; } 293 294 // The index of the allocation in the BufferAssignment. 295 Index index_; 296 297 // Size of the allocation in bytes. 298 int64 size_; 299 300 // Whether this buffer needs to be thread-local. 301 bool is_thread_local_ = false; 302 303 // Whether this buffer holds a tuple. 304 bool is_tuple_ = false; 305 306 // Color of the allocation. 307 LogicalBuffer::Color color_; 308 309 // Whether this allocation holds an entry computation parameter. Entry 310 // computation parameters are special be cause they have lifetimes which may 311 // outlast the computation. 312 bool is_entry_computation_parameter_ = false; 313 314 // Whether this entry computation parameter is aliased with output. 315 bool is_parameter_aliased_with_output_ = false; 316 317 // If this allocation holds an entry computation parameter, this field 318 // indicates the index (starting from 0) of the parameter. 319 int64 parameter_number_ = 0; 320 321 // If this buffer is for an entry computation parameter, which subshape of the 322 // parameter is it for? 323 ShapeIndex param_shape_index_; 324 325 // Whether the allocation contains a LogicalBuffer which may be live-out of 326 // the entry computation. Note that this flag is conservatively computed by 327 // TuplePointsToAnalysis. That is, an allocation marked `maybe_live_out_` 328 // might not actually escape. 329 bool maybe_live_out_ = false; 330 331 // See comment on the is_constant() accessor. 332 bool is_constant_ = false; 333 334 // Mapping from the set of buffers assigned to this allocation to their 335 // logical offsets and sizes. 336 absl::flat_hash_map<const LogicalBuffer*, OffsetSize> assigned_buffers_; 337 338 int64 fragmentation_bytes_ = 0; 339 std::vector<HeapSimulatorTrace> heap_traces_; 340 341 // Set of buffers live at the point of peak memory usage for this allocation. 342 std::vector<const LogicalBuffer*> peak_buffers_; 343 }; 344 345 // Add stream operators for nicer output of CHECK/RET_CHECK failures. 346 std::ostream& operator<<(std::ostream& out, const BufferAllocation& s); 347 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s); 348 349 // This class encapsulates an assignment of the LogicalBuffers in an XLA 350 // module to a set of BufferAllocations. 351 class BufferAssignment { 352 public: 353 // Returns the vector containing all buffer allocations in this assignment. Allocations()354 const std::vector<BufferAllocation>& Allocations() const { 355 return allocations_; 356 } 357 358 // Returns the total size allocation holding all temporary buffers. temp_allocation_total_size()359 int64 temp_allocation_total_size() const { 360 return temp_allocation_total_size_; 361 } 362 363 // Returns whether the given buffer has been assigned an allocation. 364 bool HasAllocation(const LogicalBuffer& buffer) const; 365 366 // Returns the allocation that a particular LogicalBuffer has been assigned 367 // to. CHECKs if buffer has not been assigned an allocation. 368 const BufferAllocation& GetAssignedAllocation( 369 const LogicalBuffer& buffer) const; 370 371 // Returns the allocation with the given index. CHECKs if no allocation exists 372 // with the given index. 373 const BufferAllocation& GetAllocation(BufferAllocation::Index index) const; 374 375 // Returns the allocation with the given instruction and shape index. nullptr 376 // if no allocation exists. 377 const BufferAllocation* GetInstructionAllocation( 378 const HloInstruction* hlo, const ShapeIndex& shape_index) const; 379 380 // Builds and returns a vector containing the slices which might contain the 381 // subvalue at the given index of given instruction. 382 std::set<BufferAllocation::Slice> GetAllSlices( 383 const HloInstruction* instruction, const ShapeIndex& index) const; 384 385 // Convenience function which returns whether the buffer of the 386 // instruction at the given index is assigned an allocation. 387 bool HasAllocationAt(const HloInstruction* instruction, 388 const ShapeIndex& index) const; 389 390 // Convenience function which returns whether the top-level buffer of the 391 // instruction (index == {}) is assigned an allocation. 392 bool HasTopLevelAllocation(const HloInstruction* instruction) const; 393 394 // Convenience function which returns the unique slice containing the buffer 395 // at the given index of the given instruction. If a slice is not assigned or 396 // the slice cannot be determined at compile time then an error is returned. 397 StatusOr<BufferAllocation::Slice> GetUniqueSlice( 398 const HloInstruction* instruction, const ShapeIndex& index) const; 399 // Like GetUniqueSlice but fixes the index to the top-level of the shape 400 // (index = {}). 401 StatusOr<BufferAllocation::Slice> GetUniqueTopLevelSlice( 402 const HloInstruction* instruction) const; 403 // Like GetUniqueTopLevelSlice but returns the slice for the output of the 404 // entry computation of the HLO module (ie, the result of the XLA 405 // computation). 406 StatusOr<BufferAllocation::Slice> GetUniqueTopLevelOutputSlice() const; 407 408 // Returns the set LogicalBuffers which may be the source of the value at the 409 // given index and instruction. GetSourceBuffers(const HloInstruction * instruction,const ShapeIndex & index)410 const PointsToSet::BufferList& GetSourceBuffers( 411 const HloInstruction* instruction, const ShapeIndex& index) const { 412 return GetPointsToSet(instruction).element(index); 413 } 414 415 // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}' 416 // share the same BufferAllocation::Slice. 417 // Returns false otherwise. 418 // REQUIRES: BufferAssignment assigned allocations to both instructions. 419 bool SharesSliceAtIndex(const HloInstruction* hlo_a, 420 const ShapeIndex& shape_index_a, 421 const HloInstruction* hlo_b, 422 const ShapeIndex& shape_index_b) const; 423 424 // Returns true if the top-level buffers of hlo_a and hlo_b are the same. 425 // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b). SharesTopLevelSlice(const HloInstruction * hlo_a,const HloInstruction * hlo_b)426 bool SharesTopLevelSlice(const HloInstruction* hlo_a, 427 const HloInstruction* hlo_b) const { 428 return SharesSliceAtIndex(hlo_a, {}, hlo_b, {}); 429 } 430 431 // Returns true if hlo_a and hlo_b both have at least one buffer assigned for 432 // their top-level and each of their nested shape indices, and if hlo_a's 433 // buffers are all different from hlo_b's buffers. 434 bool HaveDisjointSlices(const HloInstruction* hlo_a, 435 const HloInstruction* hlo_b) const; 436 437 // Returns the underlying points-to analysis used for this assignment. points_to_analysis()438 const TuplePointsToAnalysis& points_to_analysis() const { 439 return liveness_->points_to_analysis(); 440 } 441 442 // Returns the BufferLiveness object used to construct this assignment. liveness()443 const BufferLiveness& liveness() const { return *liveness_; } 444 445 string ToString() const; 446 BufferAssignmentProto ToProto() const; 447 448 // Statistics for the assignment. Values initialized to -1 are not always 449 // collected; fragmentation is only collected for instructions that have a 450 // sequential total ordering. 451 struct Stats { 452 int64 parameter_allocation_count = 0; 453 int64 parameter_allocation_bytes = 0; 454 int64 constant_allocation_count = 0; 455 int64 constant_allocation_bytes = 0; 456 int64 maybe_live_out_allocation_count = 0; 457 int64 maybe_live_out_allocation_bytes = 0; 458 int64 preallocated_temp_allocation_count = 0; 459 int64 preallocated_temp_allocation_bytes = 0; 460 int64 preallocated_temp_fragmentation_bytes = -1; 461 int64 total_allocation_count = 0; 462 int64 total_allocation_bytes = 0; 463 int64 total_fragmentation_bytes = -1; 464 465 string ToString() const; 466 }; GetStats()467 const Stats& GetStats() const { return stats_; } 468 469 private: 470 // Only BufferAssigner can build or modify BufferAssignments. 471 friend class BufferAssigner; 472 BufferAssignment(const HloModule * module,std::unique_ptr<BufferLiveness> liveness,LogicalBuffer::SizeFunction buffer_size,LogicalBuffer::AlignmentFunction color_alignment)473 BufferAssignment(const HloModule* module, 474 std::unique_ptr<BufferLiveness> liveness, 475 LogicalBuffer::SizeFunction buffer_size, 476 LogicalBuffer::AlignmentFunction color_alignment) 477 : module_(module), 478 liveness_(std::move(liveness)), 479 buffer_size_(std::move(buffer_size)), 480 color_alignment_(std::move(color_alignment)) {} 481 482 // Creates and returns a new BufferAllocation, with no assigned 483 // LogicalBuffers. Ownership is maintained internally. 484 BufferAllocation* NewEmptyAllocation(int64 size, LogicalBuffer::Color color); 485 486 // Helper that calls NewEmptyAllocation and AddAssignment in one call, 487 // creating an allocation containing a single LogicalBuffer. 488 BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size); 489 490 // Adds a LogicalBuffer to the set assigned to the given allocation. 491 void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, 492 int64 offset, int64 size); 493 494 // Returns the HloModule used to construct this assignment. module()495 const HloModule& module() const { return *module_; } 496 497 // Convenience function which returns the PointsToSet for the given 498 // instruction. Extracted from the liveness object. 499 const PointsToSet& GetPointsToSet(const HloInstruction* instruction) const; 500 501 // Mutable accessors for allocations. 502 BufferAllocation* GetMutableAssignedAllocation(const LogicalBuffer& buffer); 503 BufferAllocation* GetMutableAllocation(BufferAllocation::Index index); 504 505 // Combines allocations of temporary buffers into one big BufferAllocation. 506 void CombineTempAllocations(); 507 508 // Computes stats for the assignment, to be retrieved by GetStats. 509 Status ComputeSummaryStats(); 510 511 // The vector of buffer allocations. Indexed by BufferAllocation::Index. 512 std::vector<BufferAllocation> allocations_; 513 514 // The total size of all temporary buffers. 515 int64 temp_allocation_total_size_ = 0; 516 517 // Maps Buffers to the index of the BufferAllocation which holds the buffer. 518 absl::flat_hash_map<const LogicalBuffer*, BufferAllocation::Index> 519 allocation_index_for_buffer_; 520 521 const HloModule* module_; 522 const std::unique_ptr<BufferLiveness> liveness_; 523 524 // Function which returns the buffer size for a given logical buffer (shape). 525 LogicalBuffer::SizeFunction buffer_size_; 526 527 // Function which returns the alignment for a given logical buffer color. 528 LogicalBuffer::AlignmentFunction color_alignment_; 529 530 Stats stats_; 531 532 TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment); 533 }; 534 535 // A class which constructs a buffer assignment. 536 class BufferAssigner { 537 public: 538 // Returns false if a buffer cannot be assigned to given allocation. 539 using ReuseAllocationFunction = std::function<bool( 540 const BufferAssignment& assignment, const BufferAllocation& alloc, 541 const LogicalBuffer& buffer)>; 542 543 // Build and return a BufferAssignment for the given module. The given 544 // HloOrdering is used to determine buffer liveness. buffer_size and 545 // color_alignment are functions which returns the size and alignment of a 546 // LogicalBuffer. allow_input_output_aliasing specifies whether input buffer 547 // are allowed to be reused as outbut buffers by the client code. 548 static StatusOr<std::unique_ptr<BufferAssignment>> Run( 549 const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering, 550 LogicalBuffer::SizeFunction buffer_size, 551 LogicalBuffer::AlignmentFunction color_alignment, 552 bool allow_input_output_aliasing = false, 553 bool allocate_buffers_for_constants = false, 554 BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer(), 555 ReuseAllocationFunction reuse_checker = nullptr); 556 557 private: BufferAssigner(bool allocate_buffers_for_constants,BufferLiveness::Colorer colorer,ReuseAllocationFunction reuse_checker)558 BufferAssigner(bool allocate_buffers_for_constants, 559 BufferLiveness::Colorer colorer, 560 ReuseAllocationFunction reuse_checker) 561 : allocate_buffers_for_constants_(allocate_buffers_for_constants), 562 colorer_(colorer), 563 reuse_checker_(reuse_checker) {} 564 virtual ~BufferAssigner() = default; 565 566 // Create a buffer assignment. 567 StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment( 568 const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering, 569 LogicalBuffer::SizeFunction buffer_size, 570 LogicalBuffer::AlignmentFunction color_alignment); 571 572 // Assigns buffers to the instructions in the given computation. "assignment" 573 // is modified to reflect the new buffer assignments. If is_thread_local is 574 // true, then all assigned buffers have the is_thread_local flag set to 575 // true. 576 Status AssignBuffersForComputation( 577 const HloComputation* computation, bool is_thread_local, 578 const absl::flat_hash_set<const LogicalBuffer*>& colocated_buffers, 579 const absl::flat_hash_set<BufferAllocation::Index>& colocated_allocations, 580 absl::flat_hash_map<const HloComputation*, 581 absl::flat_hash_set<const LogicalBuffer*>>* 582 buffers_to_assign_sequentially, 583 BufferAssignment* assignment); 584 585 // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming 586 // the HLO instructions will be executed in the sequential order given by 587 // assignment->liveness().hlo_ordering().SequentialOrder. If 588 // 'run_whole_module_heap_simulation' is true, the heap simulation will be run 589 // assuming all global computations are sequentially ordered. 590 Status AssignBuffersWithSequentialOrdering( 591 const absl::flat_hash_map<const HloComputation*, 592 absl::flat_hash_set<const LogicalBuffer*>>& 593 buffers_to_assign_sequentially, 594 bool run_whole_module_heap_simulation, BufferAssignment* assignment); 595 596 // Uses the results of the heap simulator to create a single allocation, with 597 // LogicalBuffers packed to specific offsets. 598 void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result, 599 BufferAssignment* assignment, 600 LogicalBuffer::Color color); 601 602 // Tries to assign the given instruction to the given buffer. Returns if the 603 // assignment was successful. 604 bool MaybeAssignBuffer(BufferAllocation* allocation, 605 const LogicalBuffer& buffer, 606 BufferAssignment* assignment); 607 608 // Colocated buffers are logical buffers from different computations which 609 // alias. Explicitly handling these colocated buffers is necessary because 610 // points-to analysis is computation level scope and does not recognize 611 // aliasing across computations (b/32491382). 612 using ColocatedBufferSet = absl::flat_hash_set<const LogicalBuffer*>; 613 614 // Returns a vector of ColocatedBufferSet objects, where each 615 // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module' 616 // which should be colocated in the same buffer allocation. 617 void BuildColocatedBufferSets( 618 const HloModule* module, const BufferLiveness& buffer_liveness, 619 const LogicalBuffer::SizeFunction& buffer_size, 620 std::vector<ColocatedBufferSet>* colocated_buffer_sets); 621 622 // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the 623 // same set to the same buffer allocation in 'assignment'. 624 void AssignColocatedBufferSets( 625 const std::vector<ColocatedBufferSet>& colocated_buffer_sets, 626 BufferAssignment* assignment, 627 absl::flat_hash_set<const LogicalBuffer*>* colocated_buffers, 628 absl::flat_hash_set<BufferAllocation::Index>* colocated_allocations); 629 630 // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining 631 // the invariant that all sets in 'colocated_buffer_sets' are disjoint. 632 void AddSetToColocatedBufferSets( 633 const std::vector<const LogicalBuffer*>& colocated_set, 634 std::vector<ColocatedBufferSet>* colocated_buffer_sets); 635 636 // Given a list of colocated buffer sets (each colocated buffer set represents 637 // the logical buffers that would be assigned to the same physical buffer), 638 // try to merge the sets if the buffers can be shared. Returns the merged set. 639 std::vector<ColocatedBufferSet> MergeColocatedBufferSets( 640 const std::vector<ColocatedBufferSet>& colocated_buffer_sets, 641 const BufferLiveness& buffer_liveness, 642 const LogicalBuffer::SizeFunction& buffer_size); 643 644 // Split a set of buffers into several sets, each of which contains buffers 645 // colored with the same color. 646 absl::flat_hash_map<LogicalBuffer::Color, 647 absl::flat_hash_set<const LogicalBuffer*>, 648 LogicalBuffer::Color::Hasher> 649 SplitBuffersByColor(const absl::flat_hash_set<const LogicalBuffer*>& buffers); 650 651 // If true, allocate buffers for constant instructions. 652 bool allocate_buffers_for_constants_; 653 654 // Functor used to assign colors to newly allocated logical buffers. 655 BufferLiveness::Colorer colorer_; 656 657 // Functor to check if a buffer can reuse an allocation. 658 ReuseAllocationFunction reuse_checker_; 659 660 TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner); 661 }; 662 663 } // namespace xla 664 665 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_ 666