• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
18 
#include <functional>
#include <iosfwd>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 
47 // Walk the call graph of the HLO module and place each computation into either
48 // thread_local_computations or global_computations depending upon whether the
49 // computation requires thread-local allocations or global allocations. The
50 // elements in thread_local_computations and global_computations are in post
51 // order (if computation A has an instruction which calls computation B, then A
52 // will appear after B in the vector).
53 Status GatherComputationsByAllocationType(
54     const HloModule* module,
55     std::vector<const HloComputation*>* thread_local_computations,
56     std::vector<const HloComputation*>* global_computations);
57 
58 // This class abstracts an allocation of contiguous memory which can hold the
59 // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range
60 // of the allocation, represented by a Slice. A single BufferAllocation may hold
61 // LogicalBuffers with disjoint liveness, which may have overlapping Slices. A
62 // single BufferAllocation may also hold LogicalBuffers with overlapping
63 // liveness, which must have disjoint Slices.
64 //
65 // The abstraction includes information required by the backends for allocation,
66 // use, and deallocation of the buffer. This includes the LogicalBuffers which
67 // are held in this allocation through the execution of the computation.
68 class BufferAllocation {
69  public:
70   // Holds a unique identifier for each allocation. Values are assigned
71   // contiguously and can be used as array indexes.
72   using Index = int64;
73 
BufferAllocation(Index index,int64 size,LogicalBuffer::Color color)74   BufferAllocation(Index index, int64 size, LogicalBuffer::Color color)
75       : index_(index), size_(size), color_(color) {}
~BufferAllocation()76   ~BufferAllocation() {}
77 
78   // Returns the index of this allocation.
index()79   Index index() const { return index_; }
80 
81   // Whether this allocation is used in a parallel calling context such as
82   // inside of a map or reduce computation. Such allocations need to be thread
83   // local.
is_thread_local()84   bool is_thread_local() const { return is_thread_local_; }
set_is_thread_local(bool is_thread_local)85   void set_is_thread_local(bool is_thread_local) {
86     is_thread_local_ = is_thread_local;
87   }
88 
89   // Whether this allocation can be used by more than one logical buffer.
is_reusable()90   bool is_reusable() const {
91     // We do not reuse thread-local buffers for now, because they are
92     // dynamically allocated and their lifetimes are hard to compute.
93     //
94     // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
95     // assumes longer buffer liveness than indicated by the analysis.
96     return !is_thread_local() && !is_tuple();
97   }
98 
99   // Whether this allocation is readonly i.e. backed by memory we cannot write
100   // to.
is_readonly()101   bool is_readonly() const {
102     // Entry parameters are generally readonly, except when they are aliased
103     // with any output.
104     return (is_entry_computation_parameter() &&
105             !is_parameter_aliased_with_output_) ||
106            is_constant();
107   }
108 
is_tuple()109   bool is_tuple() const { return is_tuple_; }
set_is_tuple(bool is_tuple)110   void set_is_tuple(bool is_tuple) { is_tuple_ = is_tuple; }
111 
112   // Whether this allocation holds a LogicalBuffer from a parameter of the entry
113   // computation. These buffers have lifetimes which may be longer than the
114   // XLA computation.
is_entry_computation_parameter()115   bool is_entry_computation_parameter() const {
116     return is_entry_computation_parameter_;
117   }
118 
119   // Whether this allocation holds a constant.  On the CPU and GPU backends
120   // constant allocations are not allocated dynamically, instead we resolve
121   // references to these buffer allocations to a global in the readonly section
122   // of the binary.
is_constant()123   bool is_constant() const { return is_constant_; }
124 
125   // If this allocation holds a Buffer from a parameter of the entry
126   // computation, this methods returns the parameter number. CHECKs otherwise.
parameter_number()127   int64 parameter_number() const {
128     CHECK(is_entry_computation_parameter_);
129     return parameter_number_;
130   }
131 
132   // If this allocation is for a parameter of the entry computation, this
133   // function returns which subshape of the parameter the allocation is for.
param_shape_index()134   const ShapeIndex& param_shape_index() const {
135     CHECK(is_entry_computation_parameter_);
136     return param_shape_index_;
137   }
138 
139   // Returns whether this allocation is assigned a LogicalBuffer which may
140   // be live out of the entry computation.
maybe_live_out()141   bool maybe_live_out() const { return maybe_live_out_; }
142 
set_maybe_live_out(bool value)143   void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
144 
145   // Returns the size of the allocation. Necessarily this must be at least as
146   // large as any LogicalBuffer assigned to this allocation.
size()147   int64 size() const { return size_; }
148 
149   // Returns the color of the allocation. Only logical buffers with a matching
150   // color can reside in this allocation.
color()151   LogicalBuffer::Color color() const { return color_; }
152 
153   struct OffsetSize {
154     int64 offset = 0;
155     int64 size = 0;
156   };
157 
158   // Access to the logical buffers assigned to this allocation, and their
159   // associated logical offsets and sizes.
assigned_buffers()160   const absl::flat_hash_map<const HloValue*, OffsetSize>& assigned_buffers()
161       const {
162     return assigned_buffers_;
163   }
164 
165   // A Slice represents a contiguous portion of a memory allocation. It is used
166   // to identify the memory range that a LogicalBuffer corresponds to.
167   class Slice {
168    public:
Slice()169     Slice() {}
Slice(const BufferAllocation * allocation,int64 offset,int64 size)170     Slice(const BufferAllocation* allocation, int64 offset, int64 size)
171         : allocation_(allocation), offset_(offset), size_(size) {}
172 
allocation()173     const BufferAllocation* allocation() const { return allocation_; }
index()174     Index index() const { return allocation_->index(); }
offset()175     int64 offset() const { return offset_; }
size()176     int64 size() const { return size_; }
177 
178     bool operator==(const Slice& other) const {
179       return index() == other.index() && offset_ == other.offset_ &&
180              size_ == other.size_;
181     }
182     bool operator!=(const Slice& other) const { return !(*this == other); }
183     bool operator<(const Slice& other) const {
184       if (index() != other.index()) return index() < other.index();
185       if (offset_ != other.offset_) return offset_ < other.offset_;
186       return size_ < other.size_;
187     }
188 
189     // Returns true iff this slice's memory range has a non-empty intersection
190     // with the other slice's memory range.
OverlapsWith(const Slice & other)191     bool OverlapsWith(const Slice& other) const {
192       const int64 end = offset_ + size_;
193       const int64 other_end = other.offset_ + other.size_;
194       return index() == other.index() && offset_ < other_end &&
195              end > other.offset_;
196     }
197 
198     template <typename H>
AbslHashValue(H h,const Slice & s)199     friend H AbslHashValue(H h, const Slice& s) {
200       return H::combine(std::move(h), s.index(), s.offset(), s.size());
201     }
202 
203     string ToString() const;
204 
205    private:
206     const BufferAllocation* allocation_ = nullptr;
207     int64 offset_ = 0;
208     int64 size_ = 0;
209   };
210 
211   // GetSlice returns the Slice of contiguous memory that holds the value
212   // described by the given 'buffer'.
213   // REQUIRES: 'buffer' must be assigned to this allocation.
214   Slice GetSlice(const HloValue& buffer) const;
215 
216   string ToString() const;
217   BufferAllocationProto ToProto() const;
218 
219   // Whether the buffer is a parameter to or live out of the entry computation.
IsInputOrOutput()220   bool IsInputOrOutput() const {
221     return is_entry_computation_parameter() || maybe_live_out();
222   }
223 
224   // Whether the buffer is a temporary buffer allocated before
225   // Executable::ExecuteOnStream.
IsPreallocatedTempBuffer()226   bool IsPreallocatedTempBuffer() const {
227     // Parameters do not need temporary buffers.
228     return !is_entry_computation_parameter() &&
229            // LogicalBuffers that maybe pointed to by the output should live out
230            // of the computation.
231            !maybe_live_out() &&
232            // Thread-local buffers are allocated using `alloca`s.
233            !is_thread_local() &&
234            // Constant buffers are allocated as global values.
235            !is_constant();
236   }
237 
238   // Add a heap trace which was used to assign slices to logical buffers in this
239   // allocation. A single BufferAllocation may include multiple heap traces
240   // in the case of the temporary block where there is a heap trace per
241   // computation.
AddHeapTrace(const HeapSimulatorTrace & heap_trace)242   void AddHeapTrace(const HeapSimulatorTrace& heap_trace) {
243     heap_traces_.push_back(heap_trace);
244   }
245 
246   // Return the set of heap traces used to assign slices to logical buffers in
247   // this allocation.
HeapTraces()248   const std::vector<HeapSimulatorTrace> HeapTraces() const {
249     return heap_traces_;
250   }
251 
252   // Returns the LogicalBuffers which are live at the point of peak memory usage
253   // for this allocation. The point of peak memory usage is the point at which
254   // the total size of all live logical buffers is maximal. If peak memory is
255   // reached at multiple points, the set of logical buffers live at the earliest
256   // maximal point is returned. The vector is stably sorted by
257   // BufferValue::Index.
PeakMemoryLogicalBuffers()258   const std::vector<const HloValue*>& PeakMemoryLogicalBuffers() const {
259     return peak_buffers_;
260   }
261 
262   // Get the number of bytes lost to fragmentation. This is equal to the
263   // difference between the size of the allocation and the size of the maximal
264   // live set.
fragmentation_bytes()265   int64 fragmentation_bytes() const { return fragmentation_bytes_; }
266 
267   bool operator==(const BufferAllocation& other) const {
268     return index_ == other.index_;
269   }
270   bool operator!=(const BufferAllocation& other) const {
271     return !(*this == other);
272   }
273   bool operator<(const BufferAllocation& other) const {
274     return index() < other.index();
275   }
276 
set_entry_computation_parameter(int64 parameter_number,ShapeIndex param_shape_index,bool parameter_aliased_with_output)277   void set_entry_computation_parameter(int64 parameter_number,
278                                        ShapeIndex param_shape_index,
279                                        bool parameter_aliased_with_output) {
280     is_entry_computation_parameter_ = true;
281     is_parameter_aliased_with_output_ = parameter_aliased_with_output;
282     parameter_number_ = parameter_number;
283     param_shape_index_ = std::move(param_shape_index);
284   }
285 
286  private:
287   // Only BufferAssigner and BufferAssignment can modify BufferAllocation.
288   friend class BufferAssigner;
289   friend class BufferAssignment;
290 
291   // Adds a LogicalBuffer to the set assigned to this buffer.
292   void AddAssignment(const HloValue& buffer, int64 offset, int64 size);
293 
set_constant(bool is_constant)294   void set_constant(bool is_constant) { is_constant_ = is_constant; }
set_index(Index index)295   void set_index(Index index) { index_ = index; }
set_size(int64 size)296   void set_size(int64 size) { size_ = size; }
297 
298   // The index of the allocation in the BufferAssignment.
299   Index index_;
300 
301   // Size of the allocation in bytes.
302   int64 size_;
303 
304   // Whether this buffer needs to be thread-local.
305   bool is_thread_local_ = false;
306 
307   // Whether this buffer holds a tuple.
308   bool is_tuple_ = false;
309 
310   // Color of the allocation.
311   LogicalBuffer::Color color_;
312 
313   // Whether this allocation holds an entry computation parameter. Entry
314   // computation parameters are special because they have lifetimes which may
315   // outlast the computation.
316   bool is_entry_computation_parameter_ = false;
317 
318   // Whether this entry computation parameter is aliased with output.
319   bool is_parameter_aliased_with_output_ = false;
320 
321   // If this allocation holds an entry computation parameter, this field
322   // indicates the index (starting from 0) of the parameter.
323   int64 parameter_number_ = 0;
324 
325   // If this buffer is for an entry computation parameter, which subshape of the
326   // parameter is it for?
327   ShapeIndex param_shape_index_;
328 
329   // Whether the allocation contains a LogicalBuffer which may be live-out of
330   // the entry computation. Note that this flag is conservatively computed by
331   // TuplePointsToAnalysis.  That is, an allocation marked `maybe_live_out_`
332   // might not actually escape.
333   bool maybe_live_out_ = false;
334 
335   // See comment on the is_constant() accessor.
336   bool is_constant_ = false;
337 
338   // Mapping from the set of buffers assigned to this allocation to their
339   // logical offsets and sizes.
340   absl::flat_hash_map<const HloValue*, OffsetSize> assigned_buffers_;
341 
342   int64 fragmentation_bytes_ = 0;
343   std::vector<HeapSimulatorTrace> heap_traces_;
344 
345   // Set of buffers live at the point of peak memory usage for this allocation.
346   std::vector<const HloValue*> peak_buffers_;
347 };
348 
349 // Add stream operators for nicer output of CHECK/RET_CHECK failures.
350 std::ostream& operator<<(std::ostream& out, const BufferAllocation& s);
351 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s);
352 
353 // This class encapsulates an assignment of the LogicalBuffers in an XLA
354 // module to a set of BufferAllocations.
355 class BufferAssignment {
356  public:
357   // Returns the vector containing all buffer allocations in this assignment.
Allocations()358   const std::vector<BufferAllocation>& Allocations() const {
359     return allocations_;
360   }
361 
362   // This is similar to copying Allocations(), but since it's moved out, it
363   // preserves the addresses. Since BufferAllocation::Slice keeps a
364   // BufferAllocation*, and some backends keep BufferAllocation::Slice in
365   // xla::Executables, migrating off the use of addresses can be hard.
ReleaseAllocations()366   std::vector<BufferAllocation> ReleaseAllocations() {
367     return std::move(allocations_);
368   }
369 
370   // Returns the total size allocation holding all temporary buffers.
temp_allocation_total_size()371   int64 temp_allocation_total_size() const {
372     return temp_allocation_total_size_;
373   }
374 
multiheap_size_constraint_per_heap()375   uint64 multiheap_size_constraint_per_heap() const {
376     return multiheap_size_constraint_per_heap_;
377   }
378 
379   // Returns whether the given buffer has been assigned an allocation.
380   bool HasAllocation(const HloValue& value) const;
381 
382   bool HasAllocation(const HloBuffer& buffer) const;
383 
384   // Returns the allocation that a particular LogicalBuffer has been assigned
385   // to. CHECKs if buffer has not been assigned an allocation.
386   const BufferAllocation& GetAssignedAllocation(const HloValue& value) const;
387 
388   const BufferAllocation& GetAssignedAllocation(
389       const HloBuffer& hlo_buffer) const;
390 
391   // Returns the allocation with the given index. CHECKs if no allocation exists
392   // with the given index.
393   const BufferAllocation& GetAllocation(BufferAllocation::Index index) const;
394 
395   // Returns the allocation with the given instruction and shape index. nullptr
396   // if no allocation exists.
397   const BufferAllocation* GetInstructionAllocation(
398       const HloInstruction* hlo, const ShapeIndex& shape_index) const;
399 
400   // Builds and returns a vector containing the slices which might contain the
401   // subvalue at the given index of given instruction.
402   std::set<BufferAllocation::Slice> GetAllSlices(
403       const HloInstruction* instruction, const ShapeIndex& index) const;
404 
405   // Convenience function which returns whether the buffer of the
406   // instruction at the given index is assigned an allocation.
407   bool HasAllocationAt(const HloInstruction* instruction,
408                        const ShapeIndex& index) const;
409 
410   // Convenience function which returns whether the top-level buffer of the
411   // instruction (index == {}) is assigned an allocation.
412   bool HasTopLevelAllocation(const HloInstruction* instruction) const;
413 
414   // Convenience function which returns the unique slice containing the buffer
415   // at the given index of the given instruction. If a slice is not assigned or
416   // the slice cannot be determined at compile time then an error is returned.
417   StatusOr<BufferAllocation::Slice> GetUniqueSlice(
418       const HloInstruction* instruction, const ShapeIndex& index) const;
419   // Like GetUniqueSlice but fixes the index to the top-level of the shape
420   // (index = {}).
421   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelSlice(
422       const HloInstruction* instruction) const;
423   // Like GetUniqueTopLevelSlice but returns the slice for the output of the
424   // entry computation of the HLO module (ie, the result of the XLA
425   // computation).
426   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelOutputSlice() const;
427 
428   // Returns the set BufferValues which may be the source of the value at the
429   // given index and instruction.
GetSourceBuffers(const HloInstruction * instruction,const ShapeIndex & index)430   const std::vector<const HloValue*>& GetSourceBuffers(
431       const HloInstruction* instruction, const ShapeIndex& index) const {
432     return dataflow_analysis().GetValueSet(instruction, index).values();
433   }
434 
435   // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
436   // share the same BufferAllocation::Slice.
437   // Returns false otherwise.
438   // REQUIRES: BufferAssignment assigned allocations to both instructions.
439   bool SharesSliceAtIndex(const HloInstruction* hlo_a,
440                           const ShapeIndex& shape_index_a,
441                           const HloInstruction* hlo_b,
442                           const ShapeIndex& shape_index_b) const;
443 
444   // Returns true if the top-level buffers of hlo_a and hlo_b are the same.
445   // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b).
SharesTopLevelSlice(const HloInstruction * hlo_a,const HloInstruction * hlo_b)446   bool SharesTopLevelSlice(const HloInstruction* hlo_a,
447                            const HloInstruction* hlo_b) const {
448     return SharesSliceAtIndex(hlo_a, {}, hlo_b, {});
449   }
450 
451   // Returns true if hlo_a and hlo_b both have at least one buffer assigned for
452   // their top-level and each of their nested shape indices, and if hlo_a's
453   // buffers are all different from hlo_b's buffers.
454   bool HaveDisjointSlices(const HloInstruction* hlo_a,
455                           const HloInstruction* hlo_b) const;
456 
dataflow_analysis()457   const HloDataflowAnalysis& dataflow_analysis() const {
458     return alias_analysis_->dataflow_analysis();
459   }
460 
alias_analysis()461   HloAliasAnalysis& alias_analysis() const { return *alias_analysis_; }
462 
hlo_ordering()463   const HloOrdering& hlo_ordering() const { return *hlo_ordering_; }
464 
465   // Returns the HloLiveRange object used to construct this assignment.
hlo_live_range()466   const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
467 
468   string ToString() const;
469   string BufferInfoString() const;
470   BufferAssignmentProto ToProto() const;
471 
472   // Statistics for the assignment.  Values initialized to -1 are not always
473   // collected; fragmentation is only collected for instructions that have a
474   // sequential total ordering.
475   struct Stats {
476     int64 parameter_allocation_count = 0;
477     int64 parameter_allocation_bytes = 0;
478     int64 constant_allocation_count = 0;
479     int64 constant_allocation_bytes = 0;
480     int64 maybe_live_out_allocation_count = 0;
481     int64 maybe_live_out_allocation_bytes = 0;
482     int64 preallocated_temp_allocation_count = 0;
483     int64 preallocated_temp_allocation_bytes = 0;
484     int64 preallocated_temp_fragmentation_bytes = -1;
485     int64 total_allocation_count = 0;
486     int64 total_allocation_bytes = 0;
487     int64 total_fragmentation_bytes = -1;
488 
489     string ToString() const;
490   };
GetStats()491   const Stats& GetStats() const { return stats_; }
492 
493  private:
494   // Only BufferAssigner can build or modify BufferAssignments.
495   friend class BufferAssigner;
496 
BufferAssignment(const HloModule * module,std::unique_ptr<HloOrdering> hlo_ordering,BufferValue::SizeFunction buffer_size,LogicalBuffer::AlignmentFunction color_alignment,std::unique_ptr<HloAliasAnalysis> alias_analysis,std::unique_ptr<HloLiveRange> hlo_live_range)497   BufferAssignment(const HloModule* module,
498                    std::unique_ptr<HloOrdering> hlo_ordering,
499                    BufferValue::SizeFunction buffer_size,
500                    LogicalBuffer::AlignmentFunction color_alignment,
501                    std::unique_ptr<HloAliasAnalysis> alias_analysis,
502                    std::unique_ptr<HloLiveRange> hlo_live_range)
503       : module_(module),
504         hlo_ordering_(std::move(hlo_ordering)),
505         buffer_size_(std::move(buffer_size)),
506         color_alignment_(std::move(color_alignment)),
507         alias_analysis_(std::move(alias_analysis)),
508         hlo_live_range_(std::move(hlo_live_range)) {
509     int32 raw_value = module->config()
510                           .debug_options()
511                           .xla_multiheap_size_constraint_per_heap();
512     // -1 means no constraint.
513     multiheap_size_constraint_per_heap_ =
514         (raw_value == -1) ? UINT64_MAX : raw_value;
515   }
516 
517   // Creates and returns a new BufferAllocation, with no assigned
518   // LogicalBuffers. Ownership is maintained internally.
519   BufferAllocation* NewEmptyAllocation(int64 size, LogicalBuffer::Color color);
520 
521   // Helper that calls NewEmptyAllocation and AddAssignment in one call,
522   // creating an allocation containing a single LogicalBuffer.
523   BufferAllocation* NewAllocation(const HloBuffer& buffer, int64 size);
524 
525   // Adds a LogicalBuffer to the set assigned to the given allocation.
526   void AddAssignment(BufferAllocation* allocation, const HloBuffer& buffer,
527                      int64 offset, int64 size);
528 
529   void AddAssignment(BufferAllocation* allocation, const HloValue& value,
530                      int64 offset, int64 size);
531 
532   // Returns the HloModule used to construct this assignment.
module()533   const HloModule& module() const { return *module_; }
534 
535   // Mutable accessors for allocations.
536   BufferAllocation* GetMutableAssignedAllocation(const HloBuffer& buffer);
537   BufferAllocation* GetMutableAllocation(BufferAllocation::Index index);
538 
HloBufferSize(const HloBuffer & buffer)539   int64 HloBufferSize(const HloBuffer& buffer) {
540     int64 result = buffer_size_(*buffer.values()[0]);
541     for (const HloValue* value : buffer.values()) {
542       DCHECK_EQ(result, buffer_size_(*value));
543     }
544     return result;
545   }
546 
547   // Combines allocations of temporary buffers into one big BufferAllocation.
548   void CombineTempAllocations();
549 
550   // Computes stats for the assignment, to be retrieved by GetStats.
551   Status ComputeSummaryStats();
552 
553   // The vector of buffer allocations. Indexed by BufferAllocation::Index.
554   std::vector<BufferAllocation> allocations_;
555 
556   // The total size of all temporary buffers.
557   int64 temp_allocation_total_size_ = 0;
558 
559   uint64 multiheap_size_constraint_per_heap_;
560 
561   // Maps Buffers to the index of the BufferAllocation which holds the buffer.
562   absl::flat_hash_map<const HloValue*, BufferAllocation::Index>
563       allocation_index_for_value_;
564 
565   const HloModule* module_;
566 
567   const std::unique_ptr<HloOrdering> hlo_ordering_;
568 
569   // Function which returns the buffer size for a given logical buffer (shape).
570   BufferValue::SizeFunction buffer_size_;
571 
572   // Function which returns the alignment for a given logical buffer color.
573   LogicalBuffer::AlignmentFunction color_alignment_;
574 
575   std::unique_ptr<HloAliasAnalysis> alias_analysis_;
576 
577   std::unique_ptr<HloLiveRange> hlo_live_range_;
578 
579   Stats stats_;
580 
581   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment);
582 };
583 
584 // A class which constructs a buffer assignment.
585 class BufferAssigner {
586  public:
587   using Colorer = std::function<Status(HloAliasAnalysis*, const HloOrdering&)>;
588 
DefaultColorer()589   static Colorer DefaultColorer() {
590     return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
591       for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
592         const HloPosition& defining_position = value->defining_position();
593         if (defining_position.shape().has_layout()) {
594           value->set_color(BufferValue::Color(
595               defining_position.shape().layout().memory_space()));
596         } else {
597           value->set_color(BufferValue::Color(0));
598         }
599       }
600       return Status::OK();
601     };
602   }
603 
604   // Returns false if a buffer cannot be assigned to given allocation.
605 
606   // Build and return a BufferAssignment for the given module. The given
607   // HloOrdering is used to determine buffer liveness. buffer_size and
608   // color_alignment are functions which returns the size and alignment of a
609   // LogicalBuffer. If preset_assignments is provided, those pre-set assignment
610   // offsets will be used. The caller guarantees that those assignments are
611   // valid and they do not overwrite each other.
612   static StatusOr<std::unique_ptr<BufferAssignment>> Run(
613       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
614       BufferValue::SizeFunction buffer_size,
615       LogicalBuffer::AlignmentFunction color_alignment,
616       bool allocate_buffers_for_constants = false,
617       Colorer colorer = DefaultColorer(),
618       const absl::flat_hash_set<HloOpcode>& must_not_live_out = {},
619       HloDataflowAnalysis::CanShareBuffer can_share_buffer = nullptr,
620       std::unique_ptr<PresetAssignments> preset_assignments = {});
621 
622  private:
BufferAssigner(bool allocate_buffers_for_constants,Colorer colorer,const absl::flat_hash_set<HloOpcode> & must_not_live_out,std::unique_ptr<PresetAssignments> preset_assignments)623   BufferAssigner(bool allocate_buffers_for_constants, Colorer colorer,
624                  const absl::flat_hash_set<HloOpcode>& must_not_live_out,
625                  std::unique_ptr<PresetAssignments> preset_assignments)
626       : allocate_buffers_for_constants_(allocate_buffers_for_constants),
627         colorer_(colorer),
628         must_not_live_out_(must_not_live_out),
629         preset_assignments_(std::move(preset_assignments)) {}
630   virtual ~BufferAssigner() = default;
631 
632   // Create a buffer assignment.
633   StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment(
634       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
635       BufferValue::SizeFunction buffer_size,
636       LogicalBuffer::AlignmentFunction color_alignment,
637       HloDataflowAnalysis::CanShareBuffer can_share_buffer);
638 
639   // Assigns buffers to the instructions in the given computations. "assignment"
640   // is modified to reflect the new buffer assignments. If is_thread_local is
641   // true, then all assigned buffers have the is_thread_local flag set to
642   // true.
643   Status AssignBuffersForComputations(
644       const std::vector<const HloComputation*>& computations,
645       bool is_thread_local,
646       absl::flat_hash_map<const HloComputation*,
647                           absl::flat_hash_set<const HloValue*>>*
648           buffers_to_assign_sequentially,
649       BufferAssignment* assignment);
650 
651   // Returns true if buffer's live range interferences with buffer2's.
652   bool LiveRangeInterferes(const HloValue* buffer1, const HloValue* buffer2,
653                            BufferAssignment* assignment);
654 
655   // Assigns pre-set assignments, if provided. These assignments will be added
656   // to assigned_buffers and skip buffer allocation.
657   Status AssignPresetBuffers(
658       absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
659       BufferAssignment* assignment);
660 
661   // Assigns a single hlo buffer to an HLO allocation.
662   Status AssignSingleHloBuffer(
663       const HloBuffer* hlo_buffer, bool is_thread_local,
664       absl::flat_hash_map<const HloComputation*,
665                           absl::flat_hash_set<const HloValue*>>*
666           buffers_to_assign_sequentially,
667       std::vector<BufferAllocation::Index>* allocation_indices,
668       BufferAssignment* assignment);
669 
670   // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming
671   // the HLO instructions will be executed in the sequential order given by
672   // assignment->liveness().hlo_ordering().SequentialOrder. If
673   // 'run_whole_module_heap_simulation' is true, the heap simulation will be run
674   // assuming all global computations are sequentially ordered.
675   Status AssignBuffersWithSequentialOrdering(
676       const absl::flat_hash_map<const HloComputation*,
677                                 absl::flat_hash_set<const HloValue*>>&
678           buffers_to_assign_sequentially,
679       bool run_whole_module_heap_simulation, BufferAssignment* assignment);
680 
681   // Uses the results of the heap simulator to create a single allocation, with
682   // LogicalBuffers packed to specific offsets.
683   void AssignBuffersFromHeapSimulator(
684       const HeapSimulator::Result<HloValue>& result,
685       BufferAssignment* assignment, LogicalBuffer::Color color);
686 
687   // Tries to assign the given instruction to the given buffer. Returns if the
688   // assignment was successful.
689   bool MaybeAssignBuffer(BufferAllocation* allocation, const HloBuffer& buffer,
690                          BufferAssignment* assignment);
691 
692   // Split a set of buffers into several sets, each of which contains buffers
693   // colored with the same color.
694   absl::flat_hash_map<LogicalBuffer::Color,
695                       absl::flat_hash_set<const HloValue*>>
696   SplitBuffersByColor(const absl::flat_hash_set<const HloValue*>& buffers);
697 
698   // If true, allocate buffers for constant instructions.
699   bool allocate_buffers_for_constants_;
700 
701   // Functor used to assign colors to newly allocated logical buffers.
702   Colorer colorer_;
703 
704   // A set of hlo opcodes that can't live out of a computation.
705   absl::flat_hash_set<HloOpcode> must_not_live_out_;
706 
707   // Description of any buffer offsets that are already set by an earlier pass.
708   std::unique_ptr<PresetAssignments> preset_assignments_;
709 
710   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner);
711 };
712 
713 }  // namespace xla
714 
715 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
716