/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
18 
#include <functional>
#include <iosfwd>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 
47 // Walk the call graph of the HLO module and place each computation into either
48 // thread_local_computations or global_computations depending upon whether the
49 // computation requires thread-local allocations or global allocations. The
50 // elements in thread_local_computations and global_computations are in post
51 // order (if computation A has an instruction which calls computation B, then A
52 // will appear after B in the vector).
53 Status GatherComputationsByAllocationType(
54     const HloModule* module,
55     std::vector<const HloComputation*>* thread_local_computations,
56     std::vector<const HloComputation*>* global_computations);
57 
58 // This class abstracts an allocation of contiguous memory which can hold the
59 // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range
60 // of the allocation, represented by a Slice. A single BufferAllocation may hold
61 // LogicalBuffers with disjoint liveness, which may have overlapping Slices. A
62 // single BufferAllocation may also hold LogicalBuffers with overlapping
63 // liveness, which must have disjoint Slices.
64 //
65 // The abstraction includes information required by the backends for allocation,
66 // use, and deallocation of the buffer. This includes the LogicalBuffers which
67 // are held in this allocation through the execution of the computation.
68 class BufferAllocation {
69  public:
70   // Holds a unique identifier for each allocation. Values are assigned
71   // contiguously and can be used as array indexes.
72   using Index = int64;
73 
BufferAllocation(Index index,int64 size,LogicalBuffer::Color color)74   BufferAllocation(Index index, int64 size, LogicalBuffer::Color color)
75       : index_(index), size_(size), color_(color) {}
~BufferAllocation()76   ~BufferAllocation() {}
77 
78   // Returns the index of this allocation.
index()79   Index index() const { return index_; }
80 
81   // Whether this allocation is used in a parallel calling context such as
82   // inside of a map or reduce computation. Such allocations need to be thread
83   // local.
is_thread_local()84   bool is_thread_local() const { return is_thread_local_; }
set_is_thread_local(bool is_thread_local)85   void set_is_thread_local(bool is_thread_local) {
86     is_thread_local_ = is_thread_local;
87   }
88 
89   // Whether this allocation can be used by more than one logical buffer.
is_reusable()90   bool is_reusable() const {
91     // We do not reuse thread-local buffers for now, because they are
92     // dynamically allocated and their lifetimes are hard to compute.
93     //
94     // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
95     // assumes longer buffer liveness than indicated by the analysis.
96     return !is_thread_local() && !is_tuple();
97   }
98 
99   // Whether this allocation is readonly i.e. backed by memory we cannot write
100   // to.
is_readonly()101   bool is_readonly() const {
102     // Entry parameters are generally readonly, except when they are aliased
103     // with any output.
104     return (is_entry_computation_parameter() &&
105             !is_parameter_aliased_with_output_) ||
106            is_constant();
107   }
108 
is_tuple()109   bool is_tuple() const { return is_tuple_; }
set_is_tuple(bool is_tuple)110   void set_is_tuple(bool is_tuple) { is_tuple_ = is_tuple; }
111 
112   // Whether this allocation holds a LogicalBuffer from a parameter of the entry
113   // computation. These buffers have lifetimes which may be longer than the
114   // XLA computation.
is_entry_computation_parameter()115   bool is_entry_computation_parameter() const {
116     return is_entry_computation_parameter_;
117   }
118 
119   // Whether this allocation holds a constant.  On the CPU and GPU backends
120   // constant allocations are not allocated dynamically, instead we resolve
121   // references to these buffer allocations to a global in the readonly section
122   // of the binary.
is_constant()123   bool is_constant() const { return is_constant_; }
124 
125   // If this allocation holds a Buffer from a parameter of the entry
126   // computation, this methods returns the parameter number. CHECKs otherwise.
parameter_number()127   int64 parameter_number() const {
128     CHECK(is_entry_computation_parameter_);
129     return parameter_number_;
130   }
131 
132   // If this allocation is for a parameter of the entry computation, this
133   // function returns which subshape of the parameter the allocation is for.
param_shape_index()134   const ShapeIndex& param_shape_index() const {
135     CHECK(is_entry_computation_parameter_);
136     return param_shape_index_;
137   }
138 
139   // Returns whether this allocation is assigned a LogicalBuffer which may
140   // be live out of the entry computation.
maybe_live_out()141   bool maybe_live_out() const { return maybe_live_out_; }
142 
143   // Returns the size of the allocation. Necessarily this must be at least as
144   // large as any LogicalBuffer assigned to this allocation.
size()145   int64 size() const { return size_; }
146 
147   // Returns the color of the allocation. Only logical buffers with a matching
148   // color can reside in this allocation.
color()149   LogicalBuffer::Color color() const { return color_; }
150 
151   struct OffsetSize {
152     int64 offset = 0;
153     int64 size = 0;
154   };
155 
156   // Access to the logical buffers assigned to this allocation, and their
157   // associated logical offsets and sizes.
assigned_buffers()158   const absl::flat_hash_map<const HloValue*, OffsetSize>& assigned_buffers()
159       const {
160     return assigned_buffers_;
161   }
162 
163   // A Slice represents a contiguous portion of a memory allocation. It is used
164   // to identify the memory range that a LogicalBuffer corresponds to.
165   class Slice {
166    public:
Slice()167     Slice() {}
Slice(const BufferAllocation * allocation,int64 offset,int64 size)168     Slice(const BufferAllocation* allocation, int64 offset, int64 size)
169         : allocation_(allocation), offset_(offset), size_(size) {}
170 
allocation()171     const BufferAllocation* allocation() const { return allocation_; }
index()172     Index index() const { return allocation_->index(); }
offset()173     int64 offset() const { return offset_; }
size()174     int64 size() const { return size_; }
175 
176     bool operator==(const Slice& other) const {
177       return index() == other.index() && offset_ == other.offset_ &&
178              size_ == other.size_;
179     }
180     bool operator!=(const Slice& other) const { return !(*this == other); }
181     bool operator<(const Slice& other) const {
182       if (index() != other.index()) return index() < other.index();
183       if (offset_ != other.offset_) return offset_ < other.offset_;
184       return size_ < other.size_;
185     }
186 
187     // Returns true iff this slice's memory range has a non-empty intersection
188     // with the other slice's memory range.
OverlapsWith(const Slice & other)189     bool OverlapsWith(const Slice& other) const {
190       const int64 end = offset_ + size_;
191       const int64 other_end = other.offset_ + other.size_;
192       return index() == other.index() && offset_ < other_end &&
193              end > other.offset_;
194     }
195 
196     template <typename H>
AbslHashValue(H h,const Slice & s)197     friend H AbslHashValue(H h, const Slice& s) {
198       return H::combine(std::move(h), s.index(), s.offset(), s.size());
199     }
200 
201     string ToString() const;
202 
203    private:
204     const BufferAllocation* allocation_ = nullptr;
205     int64 offset_ = 0;
206     int64 size_ = 0;
207   };
208 
209   // GetSlice returns the Slice of contiguous memory that holds the value
210   // described by the given 'buffer'.
211   // REQUIRES: 'buffer' must be assigned to this allocation.
212   Slice GetSlice(const HloValue& buffer) const;
213 
214   string ToString() const;
215   BufferAllocationProto ToProto() const;
216 
217   // Whether the buffer is a parameter to or live out of the entry computation.
IsInputOrOutput()218   bool IsInputOrOutput() const {
219     return is_entry_computation_parameter() || maybe_live_out();
220   }
221 
222   // Whether the buffer is a temporary buffer allocated before
223   // Executable::ExecuteOnStream.
IsPreallocatedTempBuffer()224   bool IsPreallocatedTempBuffer() const {
225     // Parameters do not need temporary buffers.
226     return !is_entry_computation_parameter() &&
227            // LogicalBuffers that maybe pointed to by the output should live out
228            // of the computation.
229            !maybe_live_out() &&
230            // Thread-local buffers are allocated using `alloca`s.
231            !is_thread_local() &&
232            // Constant buffers are allocated as global values.
233            !is_constant();
234   }
235 
236   // Add a heap trace which was used to assign slices to logical buffers in this
237   // allocation. A single BufferAllocation may include multiple heap traces
238   // in the case of the temporary block where there is a heap trace per
239   // computation.
AddHeapTrace(const HeapSimulatorTrace & heap_trace)240   void AddHeapTrace(const HeapSimulatorTrace& heap_trace) {
241     heap_traces_.push_back(heap_trace);
242   }
243 
244   // Return the set of heap traces used to assign slices to logical buffers in
245   // this allocation.
HeapTraces()246   const std::vector<HeapSimulatorTrace> HeapTraces() const {
247     return heap_traces_;
248   }
249 
250   // Returns the LogicalBuffers which are live at the point of peak memory usage
251   // for this allocation. The point of peak memory usage is the point at which
252   // the total size of all live logical buffers is maximal. If peak memory is
253   // reached at multiple points, the set of logical buffers live at the earliest
254   // maximal point is returned. The vector is stably sorted by
255   // BufferValue::Index.
PeakMemoryLogicalBuffers()256   const std::vector<const HloValue*>& PeakMemoryLogicalBuffers() const {
257     return peak_buffers_;
258   }
259 
260   // Get the number of bytes lost to fragmentation. This is equal to the
261   // difference between the size of the allocation and the size of the maximal
262   // live set.
fragmentation_bytes()263   int64 fragmentation_bytes() const { return fragmentation_bytes_; }
264 
265   bool operator==(const BufferAllocation& other) const {
266     return index_ == other.index_;
267   }
268   bool operator!=(const BufferAllocation& other) const {
269     return !(*this == other);
270   }
271   bool operator<(const BufferAllocation& other) const {
272     return index() < other.index();
273   }
274 
275  private:
276   // Only BufferAssigner and BufferAssignment can modify BufferAllocation.
277   friend class BufferAssigner;
278   friend class BufferAssignment;
279 
280   // Adds a LogicalBuffer to the set assigned to this buffer.
281   void AddAssignment(const HloValue& buffer, int64 offset, int64 size);
282 
set_entry_computation_parameter(int64 parameter_number,ShapeIndex param_shape_index,bool parameter_aliased_with_output)283   void set_entry_computation_parameter(int64 parameter_number,
284                                        ShapeIndex param_shape_index,
285                                        bool parameter_aliased_with_output) {
286     is_entry_computation_parameter_ = true;
287     is_parameter_aliased_with_output_ = parameter_aliased_with_output;
288     parameter_number_ = parameter_number;
289     param_shape_index_ = std::move(param_shape_index);
290   }
291 
set_constant(bool is_constant)292   void set_constant(bool is_constant) { is_constant_ = is_constant; }
set_maybe_live_out(bool value)293   void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
set_index(Index index)294   void set_index(Index index) { index_ = index; }
set_size(int64 size)295   void set_size(int64 size) { size_ = size; }
296 
297   // The index of the allocation in the BufferAssignment.
298   Index index_;
299 
300   // Size of the allocation in bytes.
301   int64 size_;
302 
303   // Whether this buffer needs to be thread-local.
304   bool is_thread_local_ = false;
305 
306   // Whether this buffer holds a tuple.
307   bool is_tuple_ = false;
308 
309   // Color of the allocation.
310   LogicalBuffer::Color color_;
311 
312   // Whether this allocation holds an entry computation parameter. Entry
313   // computation parameters are special because they have lifetimes which may
314   // outlast the computation.
315   bool is_entry_computation_parameter_ = false;
316 
317   // Whether this entry computation parameter is aliased with output.
318   bool is_parameter_aliased_with_output_ = false;
319 
320   // If this allocation holds an entry computation parameter, this field
321   // indicates the index (starting from 0) of the parameter.
322   int64 parameter_number_ = 0;
323 
324   // If this buffer is for an entry computation parameter, which subshape of the
325   // parameter is it for?
326   ShapeIndex param_shape_index_;
327 
328   // Whether the allocation contains a LogicalBuffer which may be live-out of
329   // the entry computation. Note that this flag is conservatively computed by
330   // TuplePointsToAnalysis.  That is, an allocation marked `maybe_live_out_`
331   // might not actually escape.
332   bool maybe_live_out_ = false;
333 
334   // See comment on the is_constant() accessor.
335   bool is_constant_ = false;
336 
337   // Mapping from the set of buffers assigned to this allocation to their
338   // logical offsets and sizes.
339   absl::flat_hash_map<const HloValue*, OffsetSize> assigned_buffers_;
340 
341   int64 fragmentation_bytes_ = 0;
342   std::vector<HeapSimulatorTrace> heap_traces_;
343 
344   // Set of buffers live at the point of peak memory usage for this allocation.
345   std::vector<const HloValue*> peak_buffers_;
346 };
347 
348 // Add stream operators for nicer output of CHECK/RET_CHECK failures.
349 std::ostream& operator<<(std::ostream& out, const BufferAllocation& s);
350 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s);
351 
352 // This class encapsulates an assignment of the LogicalBuffers in an XLA
353 // module to a set of BufferAllocations.
354 class BufferAssignment {
355  public:
356   // Returns the vector containing all buffer allocations in this assignment.
Allocations()357   const std::vector<BufferAllocation>& Allocations() const {
358     return allocations_;
359   }
360 
361   // Returns the total size allocation holding all temporary buffers.
temp_allocation_total_size()362   int64 temp_allocation_total_size() const {
363     return temp_allocation_total_size_;
364   }
365 
366   // Returns whether the given buffer has been assigned an allocation.
367   bool HasAllocation(const HloValue& value) const;
368 
369   bool HasAllocation(const HloBuffer& buffer) const;
370 
371   // Returns the allocation that a particular LogicalBuffer has been assigned
372   // to. CHECKs if buffer has not been assigned an allocation.
373   const BufferAllocation& GetAssignedAllocation(const HloValue& value) const;
374 
375   const BufferAllocation& GetAssignedAllocation(
376       const HloBuffer& hlo_buffer) const;
377 
378   // Returns the allocation with the given index. CHECKs if no allocation exists
379   // with the given index.
380   const BufferAllocation& GetAllocation(BufferAllocation::Index index) const;
381 
382   // Returns the allocation with the given instruction and shape index. nullptr
383   // if no allocation exists.
384   const BufferAllocation* GetInstructionAllocation(
385       const HloInstruction* hlo, const ShapeIndex& shape_index) const;
386 
387   // Builds and returns a vector containing the slices which might contain the
388   // subvalue at the given index of given instruction.
389   std::set<BufferAllocation::Slice> GetAllSlices(
390       const HloInstruction* instruction, const ShapeIndex& index) const;
391 
392   // Convenience function which returns whether the buffer of the
393   // instruction at the given index is assigned an allocation.
394   bool HasAllocationAt(const HloInstruction* instruction,
395                        const ShapeIndex& index) const;
396 
397   // Convenience function which returns whether the top-level buffer of the
398   // instruction (index == {}) is assigned an allocation.
399   bool HasTopLevelAllocation(const HloInstruction* instruction) const;
400 
401   // Convenience function which returns the unique slice containing the buffer
402   // at the given index of the given instruction. If a slice is not assigned or
403   // the slice cannot be determined at compile time then an error is returned.
404   StatusOr<BufferAllocation::Slice> GetUniqueSlice(
405       const HloInstruction* instruction, const ShapeIndex& index) const;
406   // Like GetUniqueSlice but fixes the index to the top-level of the shape
407   // (index = {}).
408   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelSlice(
409       const HloInstruction* instruction) const;
410   // Like GetUniqueTopLevelSlice but returns the slice for the output of the
411   // entry computation of the HLO module (ie, the result of the XLA
412   // computation).
413   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelOutputSlice() const;
414 
415   // Returns the set BufferValues which may be the source of the value at the
416   // given index and instruction.
GetSourceBuffers(const HloInstruction * instruction,const ShapeIndex & index)417   const std::vector<const HloValue*>& GetSourceBuffers(
418       const HloInstruction* instruction, const ShapeIndex& index) const {
419     return dataflow_analysis().GetValueSet(instruction, index).values();
420   }
421 
422   // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
423   // share the same BufferAllocation::Slice.
424   // Returns false otherwise.
425   // REQUIRES: BufferAssignment assigned allocations to both instructions.
426   bool SharesSliceAtIndex(const HloInstruction* hlo_a,
427                           const ShapeIndex& shape_index_a,
428                           const HloInstruction* hlo_b,
429                           const ShapeIndex& shape_index_b) const;
430 
431   // Returns true if the top-level buffers of hlo_a and hlo_b are the same.
432   // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b).
SharesTopLevelSlice(const HloInstruction * hlo_a,const HloInstruction * hlo_b)433   bool SharesTopLevelSlice(const HloInstruction* hlo_a,
434                            const HloInstruction* hlo_b) const {
435     return SharesSliceAtIndex(hlo_a, {}, hlo_b, {});
436   }
437 
438   // Returns true if hlo_a and hlo_b both have at least one buffer assigned for
439   // their top-level and each of their nested shape indices, and if hlo_a's
440   // buffers are all different from hlo_b's buffers.
441   bool HaveDisjointSlices(const HloInstruction* hlo_a,
442                           const HloInstruction* hlo_b) const;
443 
dataflow_analysis()444   const HloDataflowAnalysis& dataflow_analysis() const {
445     return alias_analysis_->dataflow_analysis();
446   }
447 
alias_analysis()448   HloAliasAnalysis& alias_analysis() const { return *alias_analysis_; }
449 
hlo_ordering()450   const HloOrdering& hlo_ordering() const { return *hlo_ordering_; }
451 
452   // Returns the HloLiveRange object used to construct this assignment.
hlo_live_range()453   const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
454 
455   string ToString() const;
456   BufferAssignmentProto ToProto() const;
457 
458   // Statistics for the assignment.  Values initialized to -1 are not always
459   // collected; fragmentation is only collected for instructions that have a
460   // sequential total ordering.
461   struct Stats {
462     int64 parameter_allocation_count = 0;
463     int64 parameter_allocation_bytes = 0;
464     int64 constant_allocation_count = 0;
465     int64 constant_allocation_bytes = 0;
466     int64 maybe_live_out_allocation_count = 0;
467     int64 maybe_live_out_allocation_bytes = 0;
468     int64 preallocated_temp_allocation_count = 0;
469     int64 preallocated_temp_allocation_bytes = 0;
470     int64 preallocated_temp_fragmentation_bytes = -1;
471     int64 total_allocation_count = 0;
472     int64 total_allocation_bytes = 0;
473     int64 total_fragmentation_bytes = -1;
474 
475     string ToString() const;
476   };
GetStats()477   const Stats& GetStats() const { return stats_; }
478 
479  private:
480   // Only BufferAssigner can build or modify BufferAssignments.
481   friend class BufferAssigner;
482 
BufferAssignment(const HloModule * module,std::unique_ptr<HloOrdering> hlo_ordering,BufferValue::SizeFunction buffer_size,LogicalBuffer::AlignmentFunction color_alignment,std::unique_ptr<HloAliasAnalysis> alias_analysis,std::unique_ptr<HloLiveRange> hlo_live_range)483   BufferAssignment(const HloModule* module,
484                    std::unique_ptr<HloOrdering> hlo_ordering,
485                    BufferValue::SizeFunction buffer_size,
486                    LogicalBuffer::AlignmentFunction color_alignment,
487                    std::unique_ptr<HloAliasAnalysis> alias_analysis,
488                    std::unique_ptr<HloLiveRange> hlo_live_range)
489       : module_(module),
490         hlo_ordering_(std::move(hlo_ordering)),
491         buffer_size_(std::move(buffer_size)),
492         color_alignment_(std::move(color_alignment)),
493         alias_analysis_(std::move(alias_analysis)),
494         hlo_live_range_(std::move(hlo_live_range)) {}
495 
496   // Creates and returns a new BufferAllocation, with no assigned
497   // LogicalBuffers. Ownership is maintained internally.
498   BufferAllocation* NewEmptyAllocation(int64 size, LogicalBuffer::Color color);
499 
500   // Helper that calls NewEmptyAllocation and AddAssignment in one call,
501   // creating an allocation containing a single LogicalBuffer.
502   BufferAllocation* NewAllocation(const HloBuffer& buffer, int64 size);
503 
504   // Adds a LogicalBuffer to the set assigned to the given allocation.
505   void AddAssignment(BufferAllocation* allocation, const HloBuffer& buffer,
506                      int64 offset, int64 size);
507 
508   void AddAssignment(BufferAllocation* allocation, const HloValue& value,
509                      int64 offset, int64 size);
510 
511   // Returns the HloModule used to construct this assignment.
module()512   const HloModule& module() const { return *module_; }
513 
514   // Mutable accessors for allocations.
515   BufferAllocation* GetMutableAssignedAllocation(const HloBuffer& buffer);
516   BufferAllocation* GetMutableAllocation(BufferAllocation::Index index);
517 
HloBufferSize(const HloBuffer & buffer)518   int64 HloBufferSize(const HloBuffer& buffer) {
519     int64 result = buffer_size_(*buffer.values()[0]);
520     for (const HloValue* value : buffer.values()) {
521       DCHECK_EQ(result, buffer_size_(*value));
522     }
523     return result;
524   }
525 
526   // Combines allocations of temporary buffers into one big BufferAllocation.
527   void CombineTempAllocations();
528 
529   // Computes stats for the assignment, to be retrieved by GetStats.
530   Status ComputeSummaryStats();
531 
532   // The vector of buffer allocations. Indexed by BufferAllocation::Index.
533   std::vector<BufferAllocation> allocations_;
534 
535   // The total size of all temporary buffers.
536   int64 temp_allocation_total_size_ = 0;
537 
538   // Maps Buffers to the index of the BufferAllocation which holds the buffer.
539   absl::flat_hash_map<const HloValue*, BufferAllocation::Index>
540       allocation_index_for_value_;
541 
542   const HloModule* module_;
543 
544   const std::unique_ptr<HloOrdering> hlo_ordering_;
545 
546   // Function which returns the buffer size for a given logical buffer (shape).
547   BufferValue::SizeFunction buffer_size_;
548 
549   // Function which returns the alignment for a given logical buffer color.
550   LogicalBuffer::AlignmentFunction color_alignment_;
551 
552   std::unique_ptr<HloAliasAnalysis> alias_analysis_;
553 
554   std::unique_ptr<HloLiveRange> hlo_live_range_;
555 
556   Stats stats_;
557 
558   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment);
559 };
560 
561 // A class which constructs a buffer assignment.
562 class BufferAssigner {
563  public:
564   using Colorer = std::function<Status(HloAliasAnalysis*, const HloOrdering&)>;
565 
DefaultColorer()566   static Colorer DefaultColorer() {
567     return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
568       for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
569         const HloPosition& defining_position = value->defining_position();
570         if (defining_position.shape().has_layout()) {
571           value->set_color(BufferValue::Color(
572               defining_position.shape().layout().memory_space()));
573         } else {
574           value->set_color(BufferValue::Color(0));
575         }
576       }
577       return Status::OK();
578     };
579   }
580 
581   // Returns false if a buffer cannot be assigned to given allocation.
582 
583   // Build and return a BufferAssignment for the given module. The given
584   // HloOrdering is used to determine buffer liveness. buffer_size and
585   // color_alignment are functions which returns the size and alignment of a
586   // LogicalBuffer. If preset_assignments is provided, those pre-set assignment
587   // offsets will be used. The caller guarantees that those assignments are
588   // valid and they do not overwrite each other.
589   static StatusOr<std::unique_ptr<BufferAssignment>> Run(
590       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
591       BufferValue::SizeFunction buffer_size,
592       LogicalBuffer::AlignmentFunction color_alignment,
593       bool allocate_buffers_for_constants = false,
594       Colorer colorer = DefaultColorer(),
595       const absl::flat_hash_set<HloOpcode>& must_not_live_out = {},
596       HloDataflowAnalysis::CanShareBuffer can_share_buffer = nullptr,
597       std::unique_ptr<PresetAssignments> preset_assignments = {});
598 
599  private:
BufferAssigner(bool allocate_buffers_for_constants,Colorer colorer,const absl::flat_hash_set<HloOpcode> & must_not_live_out,std::unique_ptr<PresetAssignments> preset_assignments)600   BufferAssigner(bool allocate_buffers_for_constants, Colorer colorer,
601                  const absl::flat_hash_set<HloOpcode>& must_not_live_out,
602                  std::unique_ptr<PresetAssignments> preset_assignments)
603       : allocate_buffers_for_constants_(allocate_buffers_for_constants),
604         colorer_(colorer),
605         must_not_live_out_(must_not_live_out),
606         preset_assignments_(std::move(preset_assignments)) {}
607   virtual ~BufferAssigner() = default;
608 
609   // Create a buffer assignment.
610   StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment(
611       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
612       BufferValue::SizeFunction buffer_size,
613       LogicalBuffer::AlignmentFunction color_alignment,
614       HloDataflowAnalysis::CanShareBuffer can_share_buffer);
615 
616   // Assigns buffers to the instructions in the given computations. "assignment"
617   // is modified to reflect the new buffer assignments. If is_thread_local is
618   // true, then all assigned buffers have the is_thread_local flag set to
619   // true.
620   Status AssignBuffersForComputations(
621       const std::vector<const HloComputation*>& computations,
622       bool is_thread_local,
623       absl::flat_hash_map<const HloComputation*,
624                           absl::flat_hash_set<const HloValue*>>*
625           buffers_to_assign_sequentially,
626       BufferAssignment* assignment);
627 
628   // Returns true if buffer's live range interferences with buffer2's.
629   bool LiveRangeInterferes(const HloValue* buffer1, const HloValue* buffer2,
630                            BufferAssignment* assignment);
631 
632   // Assigns pre-set assignments, if provided. These assignments will be added
633   // to assigned_buffers and skip buffer allocation.
634   Status AssignPresetBuffers(
635       absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
636       BufferAssignment* assignment);
637 
638   // Promotes operations (DUS, scatter) to be done in place: If an operation can
639   // be done in place, merge its buffer with its operand buffer.
640   Status MergeInplaceOpBuffers(BufferAssignment* assignment);
641 
642   // Assigns a single hlo buffer to an HLO allocation.
643   Status AssignSingleHloBuffer(
644       const HloBuffer* hlo_buffer, bool is_thread_local,
645       absl::flat_hash_map<const HloComputation*,
646                           absl::flat_hash_set<const HloValue*>>*
647           buffers_to_assign_sequentially,
648       std::vector<BufferAllocation::Index>* allocation_indices,
649       BufferAssignment* assignment);
650 
651   // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming
652   // the HLO instructions will be executed in the sequential order given by
653   // assignment->liveness().hlo_ordering().SequentialOrder. If
654   // 'run_whole_module_heap_simulation' is true, the heap simulation will be run
655   // assuming all global computations are sequentially ordered.
656   Status AssignBuffersWithSequentialOrdering(
657       const absl::flat_hash_map<const HloComputation*,
658                                 absl::flat_hash_set<const HloValue*>>&
659           buffers_to_assign_sequentially,
660       bool run_whole_module_heap_simulation, BufferAssignment* assignment);
661 
662   // Uses the results of the heap simulator to create a single allocation, with
663   // LogicalBuffers packed to specific offsets.
664   void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result,
665                                       BufferAssignment* assignment,
666                                       LogicalBuffer::Color color);
667 
668   // Tries to assign the given instruction to the given buffer. Returns if the
669   // assignment was successful.
670   bool MaybeAssignBuffer(BufferAllocation* allocation, const HloBuffer& buffer,
671                          BufferAssignment* assignment);
672 
673   // Split a set of buffers into several sets, each of which contains buffers
674   // colored with the same color.
675   absl::flat_hash_map<LogicalBuffer::Color,
676                       absl::flat_hash_set<const HloValue*>,
677                       LogicalBuffer::Color::Hasher>
678   SplitBuffersByColor(const absl::flat_hash_set<const HloValue*>& buffers);
679 
680   // If true, allocate buffers for constant instructions.
681   bool allocate_buffers_for_constants_;
682 
683   // Functor used to assign colors to newly allocated logical buffers.
684   Colorer colorer_;
685 
686   // A set of hlo opcodes that can't live out of a computation.
687   absl::flat_hash_set<HloOpcode> must_not_live_out_;
688 
689   // Description of any buffer offsets that are already set by an earlier pass.
690   std::unique_ptr<PresetAssignments> preset_assignments_;
691 
692   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner);
693 };
694 
695 }  // namespace xla
696 
697 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
698