• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
18 
#include <cstdint>
#include <functional>
#include <iosfwd>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 
47 // Walk the call graph of the HLO module and place each computation into either
48 // thread_local_computations or global_computations depending upon whether the
49 // computation requires thread-local allocations or global allocations. The
50 // elements in thread_local_computations and global_computations are in post
51 // order (if computation A has an instruction which calls computation B, then A
52 // will appear after B in the vector).
53 Status GatherComputationsByAllocationType(
54     const HloModule* module,
55     std::vector<const HloComputation*>* thread_local_computations,
56     std::vector<const HloComputation*>* global_computations);
57 
58 // This class abstracts an allocation of contiguous memory which can hold the
59 // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range
60 // of the allocation, represented by a Slice. A single BufferAllocation may hold
61 // LogicalBuffers with disjoint liveness, which may have overlapping Slices. A
62 // single BufferAllocation may also hold LogicalBuffers with overlapping
63 // liveness, which must have disjoint Slices.
64 //
65 // The abstraction includes information required by the backends for allocation,
66 // use, and deallocation of the buffer. This includes the LogicalBuffers which
67 // are held in this allocation through the execution of the computation.
68 class BufferAllocation {
69  public:
70   // Holds a unique identifier for each allocation. Values are assigned
71   // contiguously and can be used as array indexes.
72   using Index = int64;
73 
BufferAllocation(Index index,int64_t size,LogicalBuffer::Color color)74   BufferAllocation(Index index, int64_t size, LogicalBuffer::Color color)
75       : index_(index), size_(size), color_(color) {}
~BufferAllocation()76   ~BufferAllocation() {}
77 
78   // Returns the index of this allocation.
index()79   Index index() const { return index_; }
80 
81   // Whether this allocation is used in a parallel calling context such as
82   // inside of a map or reduce computation. Such allocations need to be thread
83   // local.
is_thread_local()84   bool is_thread_local() const { return is_thread_local_; }
set_is_thread_local(bool is_thread_local)85   void set_is_thread_local(bool is_thread_local) {
86     is_thread_local_ = is_thread_local;
87   }
88 
89   // Whether this allocation can be used by more than one logical buffer.
is_reusable()90   bool is_reusable() const {
91     // We do not reuse thread-local buffers for now, because they are
92     // dynamically allocated and their lifetimes are hard to compute.
93     //
94     // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
95     // assumes longer buffer liveness than indicated by the analysis.
96     return !is_thread_local() && !is_tuple();
97   }
98 
99   // Whether this allocation is readonly i.e. backed by memory we cannot write
100   // to.
is_readonly()101   bool is_readonly() const {
102     // Entry parameters are generally readonly, except when they are aliased
103     // with any output.
104     return (is_entry_computation_parameter() &&
105             !is_parameter_aliased_with_output_) ||
106            is_constant();
107   }
108 
is_tuple()109   bool is_tuple() const { return is_tuple_; }
set_is_tuple(bool is_tuple)110   void set_is_tuple(bool is_tuple) { is_tuple_ = is_tuple; }
111 
112   // Whether this allocation holds a LogicalBuffer from a parameter of the entry
113   // computation. These buffers have lifetimes which may be longer than the
114   // XLA computation.
is_entry_computation_parameter()115   bool is_entry_computation_parameter() const {
116     return is_entry_computation_parameter_;
117   }
118 
119   // Whether this allocation holds a constant.  On the CPU and GPU backends
120   // constant allocations are not allocated dynamically, instead we resolve
121   // references to these buffer allocations to a global in the readonly section
122   // of the binary.
is_constant()123   bool is_constant() const { return is_constant_; }
124 
125   // If this allocation holds a Buffer from a parameter of the entry
126   // computation, this methods returns the parameter number. CHECKs otherwise.
parameter_number()127   int64 parameter_number() const {
128     CHECK(is_entry_computation_parameter_);
129     return parameter_number_;
130   }
131 
132   // If this allocation is for a parameter of the entry computation, this
133   // function returns which subshape of the parameter the allocation is for.
param_shape_index()134   const ShapeIndex& param_shape_index() const {
135     CHECK(is_entry_computation_parameter_);
136     return param_shape_index_;
137   }
138 
139   // Returns whether this allocation is assigned a LogicalBuffer which may
140   // be live out of the entry computation.
maybe_live_out()141   bool maybe_live_out() const { return maybe_live_out_; }
142 
set_maybe_live_out(bool value)143   void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
144 
145   // Returns the size of the allocation. Necessarily this must be at least as
146   // large as any LogicalBuffer assigned to this allocation.
size()147   int64 size() const { return size_; }
148 
149   // Returns the color of the allocation. Only logical buffers with a matching
150   // color can reside in this allocation.
color()151   LogicalBuffer::Color color() const { return color_; }
152 
153   struct OffsetSize {
154     int64 offset = 0;
155     int64 size = 0;
156   };
157 
158   // Access to the logical buffers assigned to this allocation, and their
159   // associated logical offsets and sizes.
assigned_buffers()160   const absl::flat_hash_map<const HloValue*, OffsetSize>& assigned_buffers()
161       const {
162     return assigned_buffers_;
163   }
164 
165   // A Slice represents a contiguous portion of a memory allocation. It is used
166   // to identify the memory range that a LogicalBuffer corresponds to.
167   class Slice {
168    public:
Slice()169     Slice() {}
Slice(const BufferAllocation * allocation,int64_t offset,int64_t size)170     Slice(const BufferAllocation* allocation, int64_t offset, int64_t size)
171         : allocation_(allocation), offset_(offset), size_(size) {}
172 
allocation()173     const BufferAllocation* allocation() const { return allocation_; }
index()174     Index index() const { return allocation_->index(); }
offset()175     int64 offset() const { return offset_; }
size()176     int64 size() const { return size_; }
177 
178     bool operator==(const Slice& other) const {
179       return index() == other.index() && offset_ == other.offset_ &&
180              size_ == other.size_;
181     }
182     bool operator!=(const Slice& other) const { return !(*this == other); }
183     bool operator<(const Slice& other) const {
184       if (index() != other.index()) return index() < other.index();
185       if (offset_ != other.offset_) return offset_ < other.offset_;
186       return size_ < other.size_;
187     }
188 
189     // Returns true iff this slice's memory range has a non-empty intersection
190     // with the other slice's memory range.
OverlapsWith(const Slice & other)191     bool OverlapsWith(const Slice& other) const {
192       const int64_t end = offset_ + size_;
193       const int64_t other_end = other.offset_ + other.size_;
194       return index() == other.index() && offset_ < other_end &&
195              end > other.offset_;
196     }
197 
198     template <typename H>
AbslHashValue(H h,const Slice & s)199     friend H AbslHashValue(H h, const Slice& s) {
200       return H::combine(std::move(h), s.index(), s.offset(), s.size());
201     }
202 
203     string ToString() const;
204 
205    private:
206     const BufferAllocation* allocation_ = nullptr;
207     int64 offset_ = 0;
208     int64 size_ = 0;
209   };
210 
211   // GetSlice returns the Slice of contiguous memory that holds the value
212   // described by the given 'buffer'.
213   // REQUIRES: 'buffer' must be assigned to this allocation.
214   Slice GetSlice(const HloValue& buffer) const;
215 
216   string ToString() const;
217   BufferAllocationProto ToProto() const;
218 
219   // Whether the buffer is a parameter to or live out of the entry computation.
IsInputOrOutput()220   bool IsInputOrOutput() const {
221     return is_entry_computation_parameter() || maybe_live_out();
222   }
223 
224   // Whether the buffer is a temporary buffer allocated before
225   // Executable::ExecuteOnStream.
IsPreallocatedTempBuffer()226   bool IsPreallocatedTempBuffer() const {
227     // Parameters do not need temporary buffers.
228     return !is_entry_computation_parameter() &&
229            // LogicalBuffers that maybe pointed to by the output should live out
230            // of the computation.
231            !maybe_live_out() &&
232            // Thread-local buffers are allocated using `alloca`s.
233            !is_thread_local() &&
234            // Constant buffers are allocated as global values.
235            !is_constant();
236   }
237 
238   // Add a heap trace which was used to assign slices to logical buffers in this
239   // allocation. A single BufferAllocation may include multiple heap traces
240   // in the case of the temporary block where there is a heap trace per
241   // computation.
AddHeapTrace(const HeapSimulatorTrace & heap_trace)242   void AddHeapTrace(const HeapSimulatorTrace& heap_trace) {
243     heap_traces_.push_back(heap_trace);
244     heap_traces_.back().set_buffer_allocation_index(index());
245   }
246 
247   // Return the set of heap traces used to assign slices to logical buffers in
248   // this allocation.
HeapTraces()249   const std::vector<HeapSimulatorTrace> HeapTraces() const {
250     return heap_traces_;
251   }
252 
253   // Returns the LogicalBuffers which are live at the point of peak memory usage
254   // for this allocation. The point of peak memory usage is the point at which
255   // the total size of all live logical buffers is maximal. If peak memory is
256   // reached at multiple points, the set of logical buffers live at the earliest
257   // maximal point is returned. The vector is stably sorted by
258   // BufferValue::Index.
PeakMemoryLogicalBuffers()259   const std::vector<const HloValue*>& PeakMemoryLogicalBuffers() const {
260     return peak_buffers_;
261   }
262 
263   // Get the number of bytes lost to fragmentation. This is equal to the
264   // difference between the size of the allocation and the size of the maximal
265   // live set.
fragmentation_bytes()266   int64 fragmentation_bytes() const { return fragmentation_bytes_; }
267 
268   bool operator==(const BufferAllocation& other) const {
269     return index_ == other.index_;
270   }
271   bool operator!=(const BufferAllocation& other) const {
272     return !(*this == other);
273   }
274   bool operator<(const BufferAllocation& other) const {
275     return index() < other.index();
276   }
277 
set_entry_computation_parameter(int64_t parameter_number,ShapeIndex param_shape_index,bool parameter_aliased_with_output)278   void set_entry_computation_parameter(int64_t parameter_number,
279                                        ShapeIndex param_shape_index,
280                                        bool parameter_aliased_with_output) {
281     is_entry_computation_parameter_ = true;
282     is_parameter_aliased_with_output_ = parameter_aliased_with_output;
283     parameter_number_ = parameter_number;
284     param_shape_index_ = std::move(param_shape_index);
285   }
286 
set_constant(bool is_constant)287   void set_constant(bool is_constant) { is_constant_ = is_constant; }
288 
289  private:
290   // Only BufferAssigner and BufferAssignment can modify BufferAllocation.
291   friend class BufferAssigner;
292   friend class BufferAssignment;
293 
294   // Adds a LogicalBuffer to the set assigned to this buffer.
295   void AddAssignment(const HloValue& buffer, int64_t offset, int64_t size);
296 
set_index(Index index)297   void set_index(Index index) { index_ = index; }
set_size(int64_t size)298   void set_size(int64_t size) { size_ = size; }
299 
300   // The index of the allocation in the BufferAssignment.
301   Index index_;
302 
303   // Size of the allocation in bytes.
304   int64 size_;
305 
306   // Whether this buffer needs to be thread-local.
307   bool is_thread_local_ = false;
308 
309   // Whether this buffer holds a tuple.
310   bool is_tuple_ = false;
311 
312   // Color of the allocation.
313   LogicalBuffer::Color color_;
314 
315   // Whether this allocation holds an entry computation parameter. Entry
316   // computation parameters are special because they have lifetimes which may
317   // outlast the computation.
318   bool is_entry_computation_parameter_ = false;
319 
320   // Whether this entry computation parameter is aliased with output.
321   bool is_parameter_aliased_with_output_ = false;
322 
323   // If this allocation holds an entry computation parameter, this field
324   // indicates the index (starting from 0) of the parameter.
325   int64 parameter_number_ = 0;
326 
327   // If this buffer is for an entry computation parameter, which subshape of the
328   // parameter is it for?
329   ShapeIndex param_shape_index_;
330 
331   // Whether the allocation contains a LogicalBuffer which may be live-out of
332   // the entry computation. Note that this flag is conservatively computed by
333   // TuplePointsToAnalysis.  That is, an allocation marked `maybe_live_out_`
334   // might not actually escape.
335   bool maybe_live_out_ = false;
336 
337   // See comment on the is_constant() accessor.
338   bool is_constant_ = false;
339 
340   // Mapping from the set of buffers assigned to this allocation to their
341   // logical offsets and sizes.
342   absl::flat_hash_map<const HloValue*, OffsetSize> assigned_buffers_;
343 
344   int64 fragmentation_bytes_ = 0;
345   std::vector<HeapSimulatorTrace> heap_traces_;
346 
347   // Set of buffers live at the point of peak memory usage for this allocation.
348   std::vector<const HloValue*> peak_buffers_;
349 };
350 
351 // Add stream operators for nicer output of CHECK/RET_CHECK failures.
352 std::ostream& operator<<(std::ostream& out, const BufferAllocation& s);
353 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s);
354 
355 // This class encapsulates an assignment of the LogicalBuffers in an XLA
356 // module to a set of BufferAllocations.
357 class BufferAssignment {
358  public:
359   // Returns the vector containing all buffer allocations in this assignment.
Allocations()360   const std::vector<BufferAllocation>& Allocations() const {
361     return allocations_;
362   }
363 
364   // This is similar to copying Allocations(), but since it's moved out, it
365   // preserves the addresses. Since BufferAllocation::Slice keeps a
366   // BufferAllocation*, and some backends keep BufferAllocation::Slice in
367   // xla::Executables, migrating off the use of addresses can be hard.
ReleaseAllocations()368   std::vector<BufferAllocation> ReleaseAllocations() {
369     return std::move(allocations_);
370   }
371 
372   // Returns the total size allocation holding all temporary buffers.
temp_allocation_total_size()373   int64 temp_allocation_total_size() const {
374     return temp_allocation_total_size_;
375   }
376 
multiheap_size_constraint_per_heap()377   uint64 multiheap_size_constraint_per_heap() const {
378     return multiheap_size_constraint_per_heap_;
379   }
380 
381   // Returns whether the given buffer has been assigned an allocation.
382   bool HasAllocation(const HloValue& value) const;
383 
384   bool HasAllocation(const HloBuffer& buffer) const;
385 
386   // Returns the allocation that a particular LogicalBuffer has been assigned
387   // to. CHECKs if buffer has not been assigned an allocation.
388   const BufferAllocation& GetAssignedAllocation(const HloValue& value) const;
389 
390   const BufferAllocation& GetAssignedAllocation(
391       const HloBuffer& hlo_buffer) const;
392 
393   // Returns the allocation with the given index. CHECKs if no allocation exists
394   // with the given index.
395   const BufferAllocation& GetAllocation(BufferAllocation::Index index) const;
396 
397   // Returns the allocation with the given instruction and shape index. nullptr
398   // if no allocation exists.
399   const BufferAllocation* GetInstructionAllocation(
400       const HloInstruction* hlo, const ShapeIndex& shape_index) const;
401 
402   // Builds and returns a vector containing the slices which might contain the
403   // subvalue at the given index of given instruction.
404   std::set<BufferAllocation::Slice> GetAllSlices(
405       const HloInstruction* instruction, const ShapeIndex& index) const;
406 
407   // Convenience function which returns whether the buffer of the
408   // instruction at the given index is assigned an allocation.
409   bool HasAllocationAt(const HloInstruction* instruction,
410                        const ShapeIndex& index) const;
411 
412   // Convenience function which returns whether the top-level buffer of the
413   // instruction (index == {}) is assigned an allocation.
414   bool HasTopLevelAllocation(const HloInstruction* instruction) const;
415 
416   // Convenience function which returns the unique slice containing the buffer
417   // at the given index of the given instruction. If a slice is not assigned or
418   // the slice cannot be determined at compile time then an error is returned.
419   StatusOr<BufferAllocation::Slice> GetUniqueSlice(
420       const HloInstruction* instruction, const ShapeIndex& index) const;
421   // Like GetUniqueSlice but fixes the index to the top-level of the shape
422   // (index = {}).
423   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelSlice(
424       const HloInstruction* instruction) const;
425   // Like GetUniqueTopLevelSlice but returns the slice for the output of the
426   // entry computation of the HLO module (ie, the result of the XLA
427   // computation).
428   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelOutputSlice() const;
429 
430   // Returns the set BufferValues which may be the source of the value at the
431   // given index and instruction.
GetSourceBuffers(const HloInstruction * instruction,const ShapeIndex & index)432   const std::vector<const HloValue*>& GetSourceBuffers(
433       const HloInstruction* instruction, const ShapeIndex& index) const {
434     return dataflow_analysis().GetValueSet(instruction, index).values();
435   }
436 
437   // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
438   // share the same BufferAllocation::Slice.
439   // Returns false otherwise.
440   // REQUIRES: BufferAssignment assigned allocations to both instructions.
441   bool SharesSliceAtIndex(const HloInstruction* hlo_a,
442                           const ShapeIndex& shape_index_a,
443                           const HloInstruction* hlo_b,
444                           const ShapeIndex& shape_index_b) const;
445 
446   // Returns true if the top-level buffers of hlo_a and hlo_b are the same.
447   // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b).
SharesTopLevelSlice(const HloInstruction * hlo_a,const HloInstruction * hlo_b)448   bool SharesTopLevelSlice(const HloInstruction* hlo_a,
449                            const HloInstruction* hlo_b) const {
450     return SharesSliceAtIndex(hlo_a, {}, hlo_b, {});
451   }
452 
453   // Returns true if hlo_a and hlo_b both have at least one buffer assigned for
454   // their top-level and each of their nested shape indices, and if hlo_a's
455   // buffers are all different from hlo_b's buffers.
456   bool HaveDisjointSlices(const HloInstruction* hlo_a,
457                           const HloInstruction* hlo_b) const;
458 
dataflow_analysis()459   const HloDataflowAnalysis& dataflow_analysis() const {
460     return alias_analysis_->dataflow_analysis();
461   }
462 
alias_analysis()463   HloAliasAnalysis& alias_analysis() const { return *alias_analysis_; }
464 
hlo_ordering()465   const HloOrdering& hlo_ordering() const { return *hlo_ordering_; }
466 
467   // Returns the HloLiveRange object used to construct this assignment.
hlo_live_range()468   const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
469 
470   string ToString() const;
471   string BufferInfoString() const;
472   BufferAssignmentProto ToProto() const;
473 
474   // Statistics for the assignment.  Values initialized to -1 are not always
475   // collected; fragmentation is only collected for instructions that have a
476   // sequential total ordering.
477   struct Stats {
478     int64 parameter_allocation_count = 0;
479     int64 parameter_allocation_bytes = 0;
480     int64 constant_allocation_count = 0;
481     int64 constant_allocation_bytes = 0;
482     int64 maybe_live_out_allocation_count = 0;
483     int64 maybe_live_out_allocation_bytes = 0;
484     int64 preallocated_temp_allocation_count = 0;
485     int64 preallocated_temp_allocation_bytes = 0;
486     int64 preallocated_temp_fragmentation_bytes = -1;
487     int64 total_allocation_count = 0;
488     int64 total_allocation_bytes = 0;
489     int64 total_fragmentation_bytes = -1;
490 
491     string ToString() const;
492   };
GetStats()493   const Stats& GetStats() const { return stats_; }
494 
495  private:
496   // Only BufferAssigner can build or modify BufferAssignments.
497   friend class BufferAssigner;
498 
BufferAssignment(const HloModule * module,std::unique_ptr<HloOrdering> hlo_ordering,BufferValue::SizeFunction buffer_size,LogicalBuffer::AlignmentFunction color_alignment,std::unique_ptr<HloAliasAnalysis> alias_analysis,std::unique_ptr<HloLiveRange> hlo_live_range)499   BufferAssignment(const HloModule* module,
500                    std::unique_ptr<HloOrdering> hlo_ordering,
501                    BufferValue::SizeFunction buffer_size,
502                    LogicalBuffer::AlignmentFunction color_alignment,
503                    std::unique_ptr<HloAliasAnalysis> alias_analysis,
504                    std::unique_ptr<HloLiveRange> hlo_live_range)
505       : module_(module),
506         hlo_ordering_(std::move(hlo_ordering)),
507         buffer_size_(std::move(buffer_size)),
508         color_alignment_(std::move(color_alignment)),
509         alias_analysis_(std::move(alias_analysis)),
510         hlo_live_range_(std::move(hlo_live_range)) {
511     int32_t raw_value = module->config()
512                             .debug_options()
513                             .xla_multiheap_size_constraint_per_heap();
514     // -1 means no constraint.
515     multiheap_size_constraint_per_heap_ =
516         (raw_value == -1) ? UINT64_MAX : raw_value;
517   }
518 
519   // Creates and returns a new BufferAllocation, with no assigned
520   // LogicalBuffers. Ownership is maintained internally.
521   BufferAllocation* NewEmptyAllocation(int64_t size,
522                                        LogicalBuffer::Color color);
523 
524   // Helper that calls NewEmptyAllocation and AddAssignment in one call,
525   // creating an allocation containing a single LogicalBuffer.
526   BufferAllocation* NewAllocation(const HloBuffer& buffer, int64_t size);
527 
528   // Adds a LogicalBuffer to the set assigned to the given allocation.
529   void AddAssignment(BufferAllocation* allocation, const HloBuffer& buffer,
530                      int64_t offset, int64_t size);
531 
532   void AddAssignment(BufferAllocation* allocation, const HloValue& value,
533                      int64_t offset, int64_t size);
534 
535   // Returns the HloModule used to construct this assignment.
module()536   const HloModule& module() const { return *module_; }
537 
538   // Mutable accessors for allocations.
539   BufferAllocation* GetMutableAssignedAllocation(const HloBuffer& buffer);
540   BufferAllocation* GetMutableAllocation(BufferAllocation::Index index);
541 
HloBufferSize(const HloBuffer & buffer)542   int64 HloBufferSize(const HloBuffer& buffer) {
543     int64_t result = buffer_size_(*buffer.values()[0]);
544     for (const HloValue* value : buffer.values()) {
545       DCHECK_EQ(result, buffer_size_(*value));
546     }
547     return result;
548   }
549 
550   // Combines allocations of temporary buffers into one big BufferAllocation.
551   void CombineTempAllocations();
552 
553   // Computes stats for the assignment, to be retrieved by GetStats.
554   Status ComputeSummaryStats();
555 
556   // The vector of buffer allocations. Indexed by BufferAllocation::Index.
557   std::vector<BufferAllocation> allocations_;
558 
559   // The total size of all temporary buffers.
560   int64 temp_allocation_total_size_ = 0;
561 
562   uint64 multiheap_size_constraint_per_heap_;
563 
564   // Maps Buffers to the index of the BufferAllocation which holds the buffer.
565   absl::flat_hash_map<const HloValue*, BufferAllocation::Index>
566       allocation_index_for_value_;
567 
568   const HloModule* module_;
569 
570   const std::unique_ptr<HloOrdering> hlo_ordering_;
571 
572   // Function which returns the buffer size for a given logical buffer (shape).
573   BufferValue::SizeFunction buffer_size_;
574 
575   // Function which returns the alignment for a given logical buffer color.
576   LogicalBuffer::AlignmentFunction color_alignment_;
577 
578   std::unique_ptr<HloAliasAnalysis> alias_analysis_;
579 
580   std::unique_ptr<HloLiveRange> hlo_live_range_;
581 
582   Stats stats_;
583 
584   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment);
585 };
586 
587 // A class which constructs a buffer assignment.
588 class BufferAssigner {
589  public:
590   using Colorer = std::function<Status(HloAliasAnalysis*, const HloOrdering&)>;
591 
DefaultColorer()592   static Colorer DefaultColorer() {
593     return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
594       for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
595         const HloPosition& defining_position = value->defining_position();
596         if (defining_position.shape().has_layout()) {
597           value->set_color(BufferValue::Color(
598               defining_position.shape().layout().memory_space()));
599         } else {
600           value->set_color(BufferValue::Color(0));
601         }
602       }
603       return Status::OK();
604     };
605   }
606 
607   // Returns false if a buffer cannot be assigned to given allocation.
608 
609   // Build and return a BufferAssignment for the given module. The given
610   // HloOrdering is used to determine buffer liveness. buffer_size and
611   // color_alignment are functions which returns the size and alignment of a
612   // LogicalBuffer. If preset_assignments is provided, those pre-set assignment
613   // offsets will be used. The caller guarantees that those assignments are
614   // valid and they do not overwrite each other.
615   static StatusOr<std::unique_ptr<BufferAssignment>> Run(
616       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
617       BufferValue::SizeFunction buffer_size,
618       LogicalBuffer::AlignmentFunction color_alignment,
619       bool allocate_buffers_for_constants = false,
620       Colorer colorer = DefaultColorer(),
621       const absl::flat_hash_set<HloOpcode>& must_not_live_out = {},
622       HloDataflowAnalysis::CanShareBuffer can_share_buffer = nullptr,
623       std::unique_ptr<memory_space_assignment::PresetAssignments>
624           preset_assignments = {});
625 
626  private:
BufferAssigner(bool allocate_buffers_for_constants,Colorer colorer,const absl::flat_hash_set<HloOpcode> & must_not_live_out,std::unique_ptr<memory_space_assignment::PresetAssignments> preset_assignments)627   BufferAssigner(bool allocate_buffers_for_constants, Colorer colorer,
628                  const absl::flat_hash_set<HloOpcode>& must_not_live_out,
629                  std::unique_ptr<memory_space_assignment::PresetAssignments>
630                      preset_assignments)
631       : allocate_buffers_for_constants_(allocate_buffers_for_constants),
632         colorer_(colorer),
633         must_not_live_out_(must_not_live_out),
634         preset_assignments_(std::move(preset_assignments)) {}
635   virtual ~BufferAssigner() = default;
636 
637   // Create a buffer assignment.
638   StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment(
639       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
640       BufferValue::SizeFunction buffer_size,
641       LogicalBuffer::AlignmentFunction color_alignment,
642       HloDataflowAnalysis::CanShareBuffer can_share_buffer);
643 
644   // Assigns buffers to the instructions in the given computations. "assignment"
645   // is modified to reflect the new buffer assignments. If is_thread_local is
646   // true, then all assigned buffers have the is_thread_local flag set to
647   // true.
648   Status AssignBuffersForComputations(
649       const std::vector<const HloComputation*>& computations,
650       bool is_thread_local,
651       absl::flat_hash_map<const HloComputation*,
652                           absl::flat_hash_set<const HloValue*>>*
653           buffers_to_assign_sequentially,
654       BufferAssignment* assignment);
655 
656   // Returns true if buffer's live range interferences with buffer2's.
657   bool LiveRangeInterferes(const HloValue* buffer1, const HloValue* buffer2,
658                            BufferAssignment* assignment);
659 
660   // Assigns pre-set assignments, if provided. These assignments will be added
661   // to assigned_buffers and skip buffer allocation.
662   Status AssignPresetBuffers(
663       absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
664       BufferAssignment* assignment);
665 
666   // Assigns a single hlo buffer to an HLO allocation.
667   Status AssignSingleHloBuffer(
668       const HloBuffer* hlo_buffer, bool is_thread_local,
669       absl::flat_hash_map<const HloComputation*,
670                           absl::flat_hash_set<const HloValue*>>*
671           buffers_to_assign_sequentially,
672       std::vector<BufferAllocation::Index>* allocation_indices,
673       BufferAssignment* assignment);
674 
675   // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming
676   // the HLO instructions will be executed in the sequential order given by
677   // assignment->liveness().hlo_ordering().SequentialOrder. If
678   // 'run_whole_module_heap_simulation' is true, the heap simulation will be run
679   // assuming all global computations are sequentially ordered.
680   Status AssignBuffersWithSequentialOrdering(
681       const absl::flat_hash_map<const HloComputation*,
682                                 absl::flat_hash_set<const HloValue*>>&
683           buffers_to_assign_sequentially,
684       bool run_whole_module_heap_simulation, BufferAssignment* assignment);
685 
686   // Uses the results of the heap simulator to create a single allocation, with
687   // LogicalBuffers packed to specific offsets.
688   void AssignBuffersFromHeapSimulator(
689       const HeapSimulator::Result<HloValue>& result,
690       BufferAssignment* assignment, LogicalBuffer::Color color);
691 
692   // Tries to assign the given instruction to the given buffer. Returns if the
693   // assignment was successful.
694   bool MaybeAssignBuffer(BufferAllocation* allocation, const HloBuffer& buffer,
695                          BufferAssignment* assignment);
696 
697   // Split a set of buffers into several sets, each of which contains buffers
698   // colored with the same color.
699   absl::flat_hash_map<LogicalBuffer::Color,
700                       absl::flat_hash_set<const HloValue*>>
701   SplitBuffersByColor(const absl::flat_hash_set<const HloValue*>& buffers);
702 
703   // If true, allocate buffers for constant instructions.
704   bool allocate_buffers_for_constants_;
705 
706   // Functor used to assign colors to newly allocated logical buffers.
707   Colorer colorer_;
708 
709   // A set of hlo opcodes that can't live out of a computation.
710   absl::flat_hash_set<HloOpcode> must_not_live_out_;
711 
712   // Description of any buffer offsets that are already set by an earlier pass.
713   std::unique_ptr<memory_space_assignment::PresetAssignments>
714       preset_assignments_;
715 
716   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner);
717 };
718 
719 }  // namespace xla
720 
721 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
722