1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_ 18 19 #include <functional> 20 #include <iosfwd> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 25 #include "absl/container/flat_hash_map.h" 26 #include "absl/container/flat_hash_set.h" 27 #include "absl/types/span.h" 28 #include "tensorflow/compiler/xla/service/heap_simulator.h" 29 #include "tensorflow/compiler/xla/service/hlo.pb.h" 30 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" 31 #include "tensorflow/compiler/xla/service/hlo_computation.h" 32 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" 33 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 34 #include "tensorflow/compiler/xla/service/hlo_live_range.h" 35 #include "tensorflow/compiler/xla/service/hlo_module.h" 36 #include "tensorflow/compiler/xla/service/logical_buffer.h" 37 #include "tensorflow/compiler/xla/service/memory_space_assignment.h" 38 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 39 #include "tensorflow/compiler/xla/statusor.h" 40 #include "tensorflow/compiler/xla/types.h" 41 #include "tensorflow/core/platform/logging.h" 42 #include "tensorflow/core/platform/macros.h" 43 #include "tensorflow/core/platform/types.h" 44 45 namespace xla { 46 47 // Walk the call graph of the HLO module and place each computation into either 48 // thread_local_computations or global_computations depending upon whether the 49 // computation requires thread-local allocations or global allocations. The 50 // elements in thread_local_computations and global_computations are in post 51 // order (if computation A has an instruction which calls computation B, then A 52 // will appear after B in the vector). 53 Status GatherComputationsByAllocationType( 54 const HloModule* module, 55 std::vector<const HloComputation*>* thread_local_computations, 56 std::vector<const HloComputation*>* global_computations); 57 58 // This class abstracts an allocation of contiguous memory which can hold the 59 // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range 60 // of the allocation, represented by a Slice. A single BufferAllocation may hold 61 // LogicalBuffers with disjoint liveness, which may have overlapping Slices. A 62 // single BufferAllocation may also hold LogicalBuffers with overlapping 63 // liveness, which must have disjoint Slices. 64 // 65 // The abstraction includes information required by the backends for allocation, 66 // use, and deallocation of the buffer. This includes the LogicalBuffers which 67 // are held in this allocation through the execution of the computation. 68 class BufferAllocation { 69 public: 70 // Holds a unique identifier for each allocation. Values are assigned 71 // contiguously and can be used as array indexes. 72 using Index = int64; 73 BufferAllocation(Index index,int64 size,LogicalBuffer::Color color)74 BufferAllocation(Index index, int64 size, LogicalBuffer::Color color) 75 : index_(index), size_(size), color_(color) {} ~BufferAllocation()76 ~BufferAllocation() {} 77 78 // Returns the index of this allocation. index()79 Index index() const { return index_; } 80 81 // Whether this allocation is used in a parallel calling context such as 82 // inside of a map or reduce computation. Such allocations need to be thread 83 // local. is_thread_local()84 bool is_thread_local() const { return is_thread_local_; } set_is_thread_local(bool is_thread_local)85 void set_is_thread_local(bool is_thread_local) { 86 is_thread_local_ = is_thread_local; 87 } 88 89 // Whether this allocation can be used by more than one logical buffer. is_reusable()90 bool is_reusable() const { 91 // We do not reuse thread-local buffers for now, because they are 92 // dynamically allocated and their lifetimes are hard to compute. 93 // 94 // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend 95 // assumes longer buffer liveness than indicated by the analysis. 96 return !is_thread_local() && !is_tuple(); 97 } 98 99 // Whether this allocation is readonly i.e. backed by memory we cannot write 100 // to. is_readonly()101 bool is_readonly() const { 102 // Entry parameters are generally readonly, except when they are aliased 103 // with any output. 104 return (is_entry_computation_parameter() && 105 !is_parameter_aliased_with_output_) || 106 is_constant(); 107 } 108 is_tuple()109 bool is_tuple() const { return is_tuple_; } set_is_tuple(bool is_tuple)110 void set_is_tuple(bool is_tuple) { is_tuple_ = is_tuple; } 111 112 // Whether this allocation holds a LogicalBuffer from a parameter of the entry 113 // computation. These buffers have lifetimes which may be longer than the 114 // XLA computation. is_entry_computation_parameter()115 bool is_entry_computation_parameter() const { 116 return is_entry_computation_parameter_; 117 } 118 119 // Whether this allocation holds a constant. On the CPU and GPU backends 120 // constant allocations are not allocated dynamically, instead we resolve 121 // references to these buffer allocations to a global in the readonly section 122 // of the binary. is_constant()123 bool is_constant() const { return is_constant_; } 124 125 // If this allocation holds a Buffer from a parameter of the entry 126 // computation, this methods returns the parameter number. CHECKs otherwise. parameter_number()127 int64 parameter_number() const { 128 CHECK(is_entry_computation_parameter_); 129 return parameter_number_; 130 } 131 132 // If this allocation is for a parameter of the entry computation, this 133 // function returns which subshape of the parameter the allocation is for. param_shape_index()134 const ShapeIndex& param_shape_index() const { 135 CHECK(is_entry_computation_parameter_); 136 return param_shape_index_; 137 } 138 139 // Returns whether this allocation is assigned a LogicalBuffer which may 140 // be live out of the entry computation. maybe_live_out()141 bool maybe_live_out() const { return maybe_live_out_; } 142 143 // Returns the size of the allocation. Necessarily this must be at least as 144 // large as any LogicalBuffer assigned to this allocation. size()145 int64 size() const { return size_; } 146 147 // Returns the color of the allocation. Only logical buffers with a matching 148 // color can reside in this allocation. color()149 LogicalBuffer::Color color() const { return color_; } 150 151 struct OffsetSize { 152 int64 offset = 0; 153 int64 size = 0; 154 }; 155 156 // Access to the logical buffers assigned to this allocation, and their 157 // associated logical offsets and sizes. assigned_buffers()158 const absl::flat_hash_map<const HloValue*, OffsetSize>& assigned_buffers() 159 const { 160 return assigned_buffers_; 161 } 162 163 // A Slice represents a contiguous portion of a memory allocation. It is used 164 // to identify the memory range that a LogicalBuffer corresponds to. 165 class Slice { 166 public: Slice()167 Slice() {} Slice(const BufferAllocation * allocation,int64 offset,int64 size)168 Slice(const BufferAllocation* allocation, int64 offset, int64 size) 169 : allocation_(allocation), offset_(offset), size_(size) {} 170 allocation()171 const BufferAllocation* allocation() const { return allocation_; } index()172 Index index() const { return allocation_->index(); } offset()173 int64 offset() const { return offset_; } size()174 int64 size() const { return size_; } 175 176 bool operator==(const Slice& other) const { 177 return index() == other.index() && offset_ == other.offset_ && 178 size_ == other.size_; 179 } 180 bool operator!=(const Slice& other) const { return !(*this == other); } 181 bool operator<(const Slice& other) const { 182 if (index() != other.index()) return index() < other.index(); 183 if (offset_ != other.offset_) return offset_ < other.offset_; 184 return size_ < other.size_; 185 } 186 187 // Returns true iff this slice's memory range has a non-empty intersection 188 // with the other slice's memory range. OverlapsWith(const Slice & other)189 bool OverlapsWith(const Slice& other) const { 190 const int64 end = offset_ + size_; 191 const int64 other_end = other.offset_ + other.size_; 192 return index() == other.index() && offset_ < other_end && 193 end > other.offset_; 194 } 195 196 template <typename H> AbslHashValue(H h,const Slice & s)197 friend H AbslHashValue(H h, const Slice& s) { 198 return H::combine(std::move(h), s.index(), s.offset(), s.size()); 199 } 200 201 string ToString() const; 202 203 private: 204 const BufferAllocation* allocation_ = nullptr; 205 int64 offset_ = 0; 206 int64 size_ = 0; 207 }; 208 209 // GetSlice returns the Slice of contiguous memory that holds the value 210 // described by the given 'buffer'. 211 // REQUIRES: 'buffer' must be assigned to this allocation. 212 Slice GetSlice(const HloValue& buffer) const; 213 214 string ToString() const; 215 BufferAllocationProto ToProto() const; 216 217 // Whether the buffer is a parameter to or live out of the entry computation. IsInputOrOutput()218 bool IsInputOrOutput() const { 219 return is_entry_computation_parameter() || maybe_live_out(); 220 } 221 222 // Whether the buffer is a temporary buffer allocated before 223 // Executable::ExecuteOnStream. IsPreallocatedTempBuffer()224 bool IsPreallocatedTempBuffer() const { 225 // Parameters do not need temporary buffers. 226 return !is_entry_computation_parameter() && 227 // LogicalBuffers that maybe pointed to by the output should live out 228 // of the computation. 229 !maybe_live_out() && 230 // Thread-local buffers are allocated using `alloca`s. 231 !is_thread_local() && 232 // Constant buffers are allocated as global values. 233 !is_constant(); 234 } 235 236 // Add a heap trace which was used to assign slices to logical buffers in this 237 // allocation. A single BufferAllocation may include multiple heap traces 238 // in the case of the temporary block where there is a heap trace per 239 // computation. AddHeapTrace(const HeapSimulatorTrace & heap_trace)240 void AddHeapTrace(const HeapSimulatorTrace& heap_trace) { 241 heap_traces_.push_back(heap_trace); 242 } 243 244 // Return the set of heap traces used to assign slices to logical buffers in 245 // this allocation. HeapTraces()246 const std::vector<HeapSimulatorTrace> HeapTraces() const { 247 return heap_traces_; 248 } 249 250 // Returns the LogicalBuffers which are live at the point of peak memory usage 251 // for this allocation. The point of peak memory usage is the point at which 252 // the total size of all live logical buffers is maximal. If peak memory is 253 // reached at multiple points, the set of logical buffers live at the earliest 254 // maximal point is returned. The vector is stably sorted by 255 // BufferValue::Index. PeakMemoryLogicalBuffers()256 const std::vector<const HloValue*>& PeakMemoryLogicalBuffers() const { 257 return peak_buffers_; 258 } 259 260 // Get the number of bytes lost to fragmentation. This is equal to the 261 // difference between the size of the allocation and the size of the maximal 262 // live set. fragmentation_bytes()263 int64 fragmentation_bytes() const { return fragmentation_bytes_; } 264 265 bool operator==(const BufferAllocation& other) const { 266 return index_ == other.index_; 267 } 268 bool operator!=(const BufferAllocation& other) const { 269 return !(*this == other); 270 } 271 bool operator<(const BufferAllocation& other) const { 272 return index() < other.index(); 273 } 274 275 private: 276 // Only BufferAssigner and BufferAssignment can modify BufferAllocation. 277 friend class BufferAssigner; 278 friend class BufferAssignment; 279 280 // Adds a LogicalBuffer to the set assigned to this buffer. 281 void AddAssignment(const HloValue& buffer, int64 offset, int64 size); 282 set_entry_computation_parameter(int64 parameter_number,ShapeIndex param_shape_index,bool parameter_aliased_with_output)283 void set_entry_computation_parameter(int64 parameter_number, 284 ShapeIndex param_shape_index, 285 bool parameter_aliased_with_output) { 286 is_entry_computation_parameter_ = true; 287 is_parameter_aliased_with_output_ = parameter_aliased_with_output; 288 parameter_number_ = parameter_number; 289 param_shape_index_ = std::move(param_shape_index); 290 } 291 set_constant(bool is_constant)292 void set_constant(bool is_constant) { is_constant_ = is_constant; } set_maybe_live_out(bool value)293 void set_maybe_live_out(bool value) { maybe_live_out_ = value; } set_index(Index index)294 void set_index(Index index) { index_ = index; } set_size(int64 size)295 void set_size(int64 size) { size_ = size; } 296 297 // The index of the allocation in the BufferAssignment. 298 Index index_; 299 300 // Size of the allocation in bytes. 301 int64 size_; 302 303 // Whether this buffer needs to be thread-local. 304 bool is_thread_local_ = false; 305 306 // Whether this buffer holds a tuple. 307 bool is_tuple_ = false; 308 309 // Color of the allocation. 310 LogicalBuffer::Color color_; 311 312 // Whether this allocation holds an entry computation parameter. Entry 313 // computation parameters are special because they have lifetimes which may 314 // outlast the computation. 315 bool is_entry_computation_parameter_ = false; 316 317 // Whether this entry computation parameter is aliased with output. 318 bool is_parameter_aliased_with_output_ = false; 319 320 // If this allocation holds an entry computation parameter, this field 321 // indicates the index (starting from 0) of the parameter. 322 int64 parameter_number_ = 0; 323 324 // If this buffer is for an entry computation parameter, which subshape of the 325 // parameter is it for? 326 ShapeIndex param_shape_index_; 327 328 // Whether the allocation contains a LogicalBuffer which may be live-out of 329 // the entry computation. Note that this flag is conservatively computed by 330 // TuplePointsToAnalysis. That is, an allocation marked `maybe_live_out_` 331 // might not actually escape. 332 bool maybe_live_out_ = false; 333 334 // See comment on the is_constant() accessor. 335 bool is_constant_ = false; 336 337 // Mapping from the set of buffers assigned to this allocation to their 338 // logical offsets and sizes. 339 absl::flat_hash_map<const HloValue*, OffsetSize> assigned_buffers_; 340 341 int64 fragmentation_bytes_ = 0; 342 std::vector<HeapSimulatorTrace> heap_traces_; 343 344 // Set of buffers live at the point of peak memory usage for this allocation. 345 std::vector<const HloValue*> peak_buffers_; 346 }; 347 348 // Add stream operators for nicer output of CHECK/RET_CHECK failures. 349 std::ostream& operator<<(std::ostream& out, const BufferAllocation& s); 350 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s); 351 352 // This class encapsulates an assignment of the LogicalBuffers in an XLA 353 // module to a set of BufferAllocations. 354 class BufferAssignment { 355 public: 356 // Returns the vector containing all buffer allocations in this assignment. Allocations()357 const std::vector<BufferAllocation>& Allocations() const { 358 return allocations_; 359 } 360 361 // Returns the total size allocation holding all temporary buffers. temp_allocation_total_size()362 int64 temp_allocation_total_size() const { 363 return temp_allocation_total_size_; 364 } 365 366 // Returns whether the given buffer has been assigned an allocation. 367 bool HasAllocation(const HloValue& value) const; 368 369 bool HasAllocation(const HloBuffer& buffer) const; 370 371 // Returns the allocation that a particular LogicalBuffer has been assigned 372 // to. CHECKs if buffer has not been assigned an allocation. 373 const BufferAllocation& GetAssignedAllocation(const HloValue& value) const; 374 375 const BufferAllocation& GetAssignedAllocation( 376 const HloBuffer& hlo_buffer) const; 377 378 // Returns the allocation with the given index. CHECKs if no allocation exists 379 // with the given index. 380 const BufferAllocation& GetAllocation(BufferAllocation::Index index) const; 381 382 // Returns the allocation with the given instruction and shape index. nullptr 383 // if no allocation exists. 384 const BufferAllocation* GetInstructionAllocation( 385 const HloInstruction* hlo, const ShapeIndex& shape_index) const; 386 387 // Builds and returns a vector containing the slices which might contain the 388 // subvalue at the given index of given instruction. 389 std::set<BufferAllocation::Slice> GetAllSlices( 390 const HloInstruction* instruction, const ShapeIndex& index) const; 391 392 // Convenience function which returns whether the buffer of the 393 // instruction at the given index is assigned an allocation. 394 bool HasAllocationAt(const HloInstruction* instruction, 395 const ShapeIndex& index) const; 396 397 // Convenience function which returns whether the top-level buffer of the 398 // instruction (index == {}) is assigned an allocation. 399 bool HasTopLevelAllocation(const HloInstruction* instruction) const; 400 401 // Convenience function which returns the unique slice containing the buffer 402 // at the given index of the given instruction. If a slice is not assigned or 403 // the slice cannot be determined at compile time then an error is returned. 404 StatusOr<BufferAllocation::Slice> GetUniqueSlice( 405 const HloInstruction* instruction, const ShapeIndex& index) const; 406 // Like GetUniqueSlice but fixes the index to the top-level of the shape 407 // (index = {}). 408 StatusOr<BufferAllocation::Slice> GetUniqueTopLevelSlice( 409 const HloInstruction* instruction) const; 410 // Like GetUniqueTopLevelSlice but returns the slice for the output of the 411 // entry computation of the HLO module (ie, the result of the XLA 412 // computation). 413 StatusOr<BufferAllocation::Slice> GetUniqueTopLevelOutputSlice() const; 414 415 // Returns the set BufferValues which may be the source of the value at the 416 // given index and instruction. GetSourceBuffers(const HloInstruction * instruction,const ShapeIndex & index)417 const std::vector<const HloValue*>& GetSourceBuffers( 418 const HloInstruction* instruction, const ShapeIndex& index) const { 419 return dataflow_analysis().GetValueSet(instruction, index).values(); 420 } 421 422 // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}' 423 // share the same BufferAllocation::Slice. 424 // Returns false otherwise. 425 // REQUIRES: BufferAssignment assigned allocations to both instructions. 426 bool SharesSliceAtIndex(const HloInstruction* hlo_a, 427 const ShapeIndex& shape_index_a, 428 const HloInstruction* hlo_b, 429 const ShapeIndex& shape_index_b) const; 430 431 // Returns true if the top-level buffers of hlo_a and hlo_b are the same. 432 // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b). SharesTopLevelSlice(const HloInstruction * hlo_a,const HloInstruction * hlo_b)433 bool SharesTopLevelSlice(const HloInstruction* hlo_a, 434 const HloInstruction* hlo_b) const { 435 return SharesSliceAtIndex(hlo_a, {}, hlo_b, {}); 436 } 437 438 // Returns true if hlo_a and hlo_b both have at least one buffer assigned for 439 // their top-level and each of their nested shape indices, and if hlo_a's 440 // buffers are all different from hlo_b's buffers. 441 bool HaveDisjointSlices(const HloInstruction* hlo_a, 442 const HloInstruction* hlo_b) const; 443 dataflow_analysis()444 const HloDataflowAnalysis& dataflow_analysis() const { 445 return alias_analysis_->dataflow_analysis(); 446 } 447 alias_analysis()448 HloAliasAnalysis& alias_analysis() const { return *alias_analysis_; } 449 hlo_ordering()450 const HloOrdering& hlo_ordering() const { return *hlo_ordering_; } 451 452 // Returns the HloLiveRange object used to construct this assignment. hlo_live_range()453 const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } 454 455 string ToString() const; 456 BufferAssignmentProto ToProto() const; 457 458 // Statistics for the assignment. Values initialized to -1 are not always 459 // collected; fragmentation is only collected for instructions that have a 460 // sequential total ordering. 461 struct Stats { 462 int64 parameter_allocation_count = 0; 463 int64 parameter_allocation_bytes = 0; 464 int64 constant_allocation_count = 0; 465 int64 constant_allocation_bytes = 0; 466 int64 maybe_live_out_allocation_count = 0; 467 int64 maybe_live_out_allocation_bytes = 0; 468 int64 preallocated_temp_allocation_count = 0; 469 int64 preallocated_temp_allocation_bytes = 0; 470 int64 preallocated_temp_fragmentation_bytes = -1; 471 int64 total_allocation_count = 0; 472 int64 total_allocation_bytes = 0; 473 int64 total_fragmentation_bytes = -1; 474 475 string ToString() const; 476 }; GetStats()477 const Stats& GetStats() const { return stats_; } 478 479 private: 480 // Only BufferAssigner can build or modify BufferAssignments. 481 friend class BufferAssigner; 482 BufferAssignment(const HloModule * module,std::unique_ptr<HloOrdering> hlo_ordering,BufferValue::SizeFunction buffer_size,LogicalBuffer::AlignmentFunction color_alignment,std::unique_ptr<HloAliasAnalysis> alias_analysis,std::unique_ptr<HloLiveRange> hlo_live_range)483 BufferAssignment(const HloModule* module, 484 std::unique_ptr<HloOrdering> hlo_ordering, 485 BufferValue::SizeFunction buffer_size, 486 LogicalBuffer::AlignmentFunction color_alignment, 487 std::unique_ptr<HloAliasAnalysis> alias_analysis, 488 std::unique_ptr<HloLiveRange> hlo_live_range) 489 : module_(module), 490 hlo_ordering_(std::move(hlo_ordering)), 491 buffer_size_(std::move(buffer_size)), 492 color_alignment_(std::move(color_alignment)), 493 alias_analysis_(std::move(alias_analysis)), 494 hlo_live_range_(std::move(hlo_live_range)) {} 495 496 // Creates and returns a new BufferAllocation, with no assigned 497 // LogicalBuffers. Ownership is maintained internally. 498 BufferAllocation* NewEmptyAllocation(int64 size, LogicalBuffer::Color color); 499 500 // Helper that calls NewEmptyAllocation and AddAssignment in one call, 501 // creating an allocation containing a single LogicalBuffer. 502 BufferAllocation* NewAllocation(const HloBuffer& buffer, int64 size); 503 504 // Adds a LogicalBuffer to the set assigned to the given allocation. 505 void AddAssignment(BufferAllocation* allocation, const HloBuffer& buffer, 506 int64 offset, int64 size); 507 508 void AddAssignment(BufferAllocation* allocation, const HloValue& value, 509 int64 offset, int64 size); 510 511 // Returns the HloModule used to construct this assignment. module()512 const HloModule& module() const { return *module_; } 513 514 // Mutable accessors for allocations. 515 BufferAllocation* GetMutableAssignedAllocation(const HloBuffer& buffer); 516 BufferAllocation* GetMutableAllocation(BufferAllocation::Index index); 517 HloBufferSize(const HloBuffer & buffer)518 int64 HloBufferSize(const HloBuffer& buffer) { 519 int64 result = buffer_size_(*buffer.values()[0]); 520 for (const HloValue* value : buffer.values()) { 521 DCHECK_EQ(result, buffer_size_(*value)); 522 } 523 return result; 524 } 525 526 // Combines allocations of temporary buffers into one big BufferAllocation. 527 void CombineTempAllocations(); 528 529 // Computes stats for the assignment, to be retrieved by GetStats. 530 Status ComputeSummaryStats(); 531 532 // The vector of buffer allocations. Indexed by BufferAllocation::Index. 533 std::vector<BufferAllocation> allocations_; 534 535 // The total size of all temporary buffers. 536 int64 temp_allocation_total_size_ = 0; 537 538 // Maps Buffers to the index of the BufferAllocation which holds the buffer. 539 absl::flat_hash_map<const HloValue*, BufferAllocation::Index> 540 allocation_index_for_value_; 541 542 const HloModule* module_; 543 544 const std::unique_ptr<HloOrdering> hlo_ordering_; 545 546 // Function which returns the buffer size for a given logical buffer (shape). 547 BufferValue::SizeFunction buffer_size_; 548 549 // Function which returns the alignment for a given logical buffer color. 550 LogicalBuffer::AlignmentFunction color_alignment_; 551 552 std::unique_ptr<HloAliasAnalysis> alias_analysis_; 553 554 std::unique_ptr<HloLiveRange> hlo_live_range_; 555 556 Stats stats_; 557 558 TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment); 559 }; 560 561 // A class which constructs a buffer assignment. 562 class BufferAssigner { 563 public: 564 using Colorer = std::function<Status(HloAliasAnalysis*, const HloOrdering&)>; 565 DefaultColorer()566 static Colorer DefaultColorer() { 567 return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) { 568 for (HloValue* value : alias_analysis->dataflow_analysis().values()) { 569 const HloPosition& defining_position = value->defining_position(); 570 if (defining_position.shape().has_layout()) { 571 value->set_color(BufferValue::Color( 572 defining_position.shape().layout().memory_space())); 573 } else { 574 value->set_color(BufferValue::Color(0)); 575 } 576 } 577 return Status::OK(); 578 }; 579 } 580 581 // Returns false if a buffer cannot be assigned to given allocation. 582 583 // Build and return a BufferAssignment for the given module. The given 584 // HloOrdering is used to determine buffer liveness. buffer_size and 585 // color_alignment are functions which returns the size and alignment of a 586 // LogicalBuffer. If preset_assignments is provided, those pre-set assignment 587 // offsets will be used. The caller guarantees that those assignments are 588 // valid and they do not overwrite each other. 589 static StatusOr<std::unique_ptr<BufferAssignment>> Run( 590 const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering, 591 BufferValue::SizeFunction buffer_size, 592 LogicalBuffer::AlignmentFunction color_alignment, 593 bool allocate_buffers_for_constants = false, 594 Colorer colorer = DefaultColorer(), 595 const absl::flat_hash_set<HloOpcode>& must_not_live_out = {}, 596 HloDataflowAnalysis::CanShareBuffer can_share_buffer = nullptr, 597 std::unique_ptr<PresetAssignments> preset_assignments = {}); 598 599 private: BufferAssigner(bool allocate_buffers_for_constants,Colorer colorer,const absl::flat_hash_set<HloOpcode> & must_not_live_out,std::unique_ptr<PresetAssignments> preset_assignments)600 BufferAssigner(bool allocate_buffers_for_constants, Colorer colorer, 601 const absl::flat_hash_set<HloOpcode>& must_not_live_out, 602 std::unique_ptr<PresetAssignments> preset_assignments) 603 : allocate_buffers_for_constants_(allocate_buffers_for_constants), 604 colorer_(colorer), 605 must_not_live_out_(must_not_live_out), 606 preset_assignments_(std::move(preset_assignments)) {} 607 virtual ~BufferAssigner() = default; 608 609 // Create a buffer assignment. 610 StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment( 611 const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering, 612 BufferValue::SizeFunction buffer_size, 613 LogicalBuffer::AlignmentFunction color_alignment, 614 HloDataflowAnalysis::CanShareBuffer can_share_buffer); 615 616 // Assigns buffers to the instructions in the given computations. "assignment" 617 // is modified to reflect the new buffer assignments. If is_thread_local is 618 // true, then all assigned buffers have the is_thread_local flag set to 619 // true. 620 Status AssignBuffersForComputations( 621 const std::vector<const HloComputation*>& computations, 622 bool is_thread_local, 623 absl::flat_hash_map<const HloComputation*, 624 absl::flat_hash_set<const HloValue*>>* 625 buffers_to_assign_sequentially, 626 BufferAssignment* assignment); 627 628 // Returns true if buffer's live range interferences with buffer2's. 629 bool LiveRangeInterferes(const HloValue* buffer1, const HloValue* buffer2, 630 BufferAssignment* assignment); 631 632 // Assigns pre-set assignments, if provided. These assignments will be added 633 // to assigned_buffers and skip buffer allocation. 634 Status AssignPresetBuffers( 635 absl::flat_hash_set<const HloBuffer*>* assigned_buffers, 636 BufferAssignment* assignment); 637 638 // Promotes operations (DUS, scatter) to be done in place: If an operation can 639 // be done in place, merge its buffer with its operand buffer. 640 Status MergeInplaceOpBuffers(BufferAssignment* assignment); 641 642 // Assigns a single hlo buffer to an HLO allocation. 643 Status AssignSingleHloBuffer( 644 const HloBuffer* hlo_buffer, bool is_thread_local, 645 absl::flat_hash_map<const HloComputation*, 646 absl::flat_hash_set<const HloValue*>>* 647 buffers_to_assign_sequentially, 648 std::vector<BufferAllocation::Index>* allocation_indices, 649 BufferAssignment* assignment); 650 651 // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming 652 // the HLO instructions will be executed in the sequential order given by 653 // assignment->liveness().hlo_ordering().SequentialOrder. If 654 // 'run_whole_module_heap_simulation' is true, the heap simulation will be run 655 // assuming all global computations are sequentially ordered. 656 Status AssignBuffersWithSequentialOrdering( 657 const absl::flat_hash_map<const HloComputation*, 658 absl::flat_hash_set<const HloValue*>>& 659 buffers_to_assign_sequentially, 660 bool run_whole_module_heap_simulation, BufferAssignment* assignment); 661 662 // Uses the results of the heap simulator to create a single allocation, with 663 // LogicalBuffers packed to specific offsets. 664 void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result, 665 BufferAssignment* assignment, 666 LogicalBuffer::Color color); 667 668 // Tries to assign the given instruction to the given buffer. Returns if the 669 // assignment was successful. 670 bool MaybeAssignBuffer(BufferAllocation* allocation, const HloBuffer& buffer, 671 BufferAssignment* assignment); 672 673 // Split a set of buffers into several sets, each of which contains buffers 674 // colored with the same color. 675 absl::flat_hash_map<LogicalBuffer::Color, 676 absl::flat_hash_set<const HloValue*>, 677 LogicalBuffer::Color::Hasher> 678 SplitBuffersByColor(const absl::flat_hash_set<const HloValue*>& buffers); 679 680 // If true, allocate buffers for constant instructions. 681 bool allocate_buffers_for_constants_; 682 683 // Functor used to assign colors to newly allocated logical buffers. 684 Colorer colorer_; 685 686 // A set of hlo opcodes that can't live out of a computation. 687 absl::flat_hash_set<HloOpcode> must_not_live_out_; 688 689 // Description of any buffer offsets that are already set by an earlier pass. 690 std::unique_ptr<PresetAssignments> preset_assignments_; 691 692 TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner); 693 }; 694 695 } // namespace xla 696 697 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_ 698