1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Defines the data returned by the XLA buffer assignment packages.
17 
18 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
19 
20 #include <algorithm>
21 #include <deque>
22 #include <ostream>
23 #include <utility>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/memory/memory.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_format.h"
30 #include "tensorflow/compiler/xla/map_util.h"
31 #include "tensorflow/compiler/xla/service/buffer_value_containers.h"
32 #include "tensorflow/compiler/xla/service/heap_simulator.h"
33 #include "tensorflow/compiler/xla/service/hlo.pb.h"
34 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
35 #include "tensorflow/compiler/xla/shape_util.h"
36 #include "tensorflow/compiler/xla/status_macros.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/compiler/xla/util.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/hash/hash.h"
41 #include "tensorflow/core/lib/strings/numbers.h"
42 
43 namespace xla {
44 namespace {
45 
46 using absl::flat_hash_map;
47 using absl::flat_hash_set;
48 using absl::StrAppend;
49 using absl::StrAppendFormat;
50 using ::tensorflow::strings::HumanReadableNumBytes;
51 
52 template <typename T>
53 string ColocatedBufferSetsToString(const T& container, const char* title) {
54   string result;
55   StrAppend(&result, title, "\n");
56   for (const auto& it : container) {
57     StrAppend(&result, "\t", it->ToString(), "\n");
58   }
59   return result;
60 }
61 
62 // Checks that the points-to set of 'instruction' is unambiguous and distinct
63 // (ensured by CopyInsertion), then adds the buffer from the points-to set at
64 // 'index' to 'colocated_set'.
65 const LogicalBuffer* AddBufferToColocatedSet(
66     const HloInstruction* instruction, const ShapeIndex& index,
67     const TuplePointsToAnalysis& points_to_analysis,
68     std::vector<const LogicalBuffer*>* colocated_set) {
69   // CopyInsertion ensures root points-to set is unambiguous and distinct.
70   const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
71   DCHECK(!points_to.IsAmbiguous());
72   colocated_set->push_back(points_to.element(index)[0]);
73   return colocated_set->back();
74 }
75 
76 // Given the interference map of a graph (the list of interfering node indices
77 // for each node), perform graph coloring such that interfering nodes are
78 // assigned to different colors. Returns the assigned color of the nodes, where
79 // the colors are represented as integer values [0, color_count).
80 std::vector<int64> ColorInterferenceGraph(
81     const std::vector<std::vector<int64>>& interference_map) {
82   const int64 node_count = interference_map.size();
83 
84   // Sort the nodes such that we assign nodes with more interference first. This
85   // relies on the common heuristic of assigning the most constrained node
86   // first, but it would be good to investigate other ordering heuristics too.
87   std::vector<int64> nodes(node_count);
88   std::iota(nodes.begin(), nodes.end(), 0);
89   absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) {
90     return interference_map[i].size() > interference_map[j].size();
91   });
92 
93   const int64 kColorUnassigned = -1;
94   std::vector<int64> assigned_colors(node_count, kColorUnassigned);
95   for (int64 node : nodes) {
96     // Mark the colors that are already assigned to the neighbors.
97     std::vector<bool> available_colors(node_count, true);
98     for (int64 neighbor : interference_map[node]) {
99       int64 color = assigned_colors[neighbor];
100       if (color != kColorUnassigned) {
101         available_colors[color] = false;
102       }
103     }
104 
105     // Find the color that is not yet assigned to the neighbors.
106     int64 color = kColorUnassigned;
107     for (color = 0; color < available_colors.size(); ++color) {
108       if (available_colors[color]) {
109         break;
110       }
111     }
112     CHECK_NE(color, kColorUnassigned);
113     assigned_colors[node] = color;
114   }
115   return assigned_colors;
116 }
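// Editor's note: an illustrative example of the greedy coloring above (a
// sketch, not part of the original source). With
//   interference_map = {{1, 2}, {0}, {0}}
// node 0 interferes with nodes 1 and 2. Node 0 is visited first (it has the
// most neighbors) and receives color 0; nodes 1 and 2 each see only color 0
// taken by their neighbor, so both receive color 1. The returned assignment
// is {0, 1, 1}, i.e. two colors suffice for the three nodes.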
117 
118 }  // namespace
119 
120 Status GatherComputationsByAllocationType(
121     const HloModule* module,
122     std::vector<const HloComputation*>* thread_local_computations,
123     std::vector<const HloComputation*>* global_computations) {
124   // Create a worklist of computations paired with whether the allocation must
125   // be thread-local.
126   std::deque<std::pair<const HloComputation*, bool>> worklist;
127   worklist.push_back(std::make_pair(module->entry_computation(),
128                                     /*is_thread_local*/ false));
129 
130   // Sets for quickly checking membership. Computations are returned in vectors
131   // for stable iteration.
132   flat_hash_set<const HloComputation*> thread_local_set;
133   flat_hash_set<const HloComputation*> global_set;
134 
135   while (!worklist.empty()) {
136     auto worklist_front = worklist.front();
137     worklist.pop_front();
138     const HloComputation* computation = worklist_front.first;
139     bool is_thread_local = worklist_front.second;
140     bool in_thread_local_set = thread_local_set.contains(computation);
141     bool in_global_set = global_set.contains(computation);
142 
143     // If the computation has already been added to the respective set, then
144     // there is nothing to do.
145     if ((is_thread_local && in_thread_local_set) ||
146         (!is_thread_local && in_global_set)) {
147       continue;
148     }
149 
150     // If the computation has already been added to the other set this is an
151     // error condition because the global call to the computation (eg,
152     // while/call) may return a reference to one of the thread-local buffers to
153     // the calling computation which will become a dangling reference when the
154     // thread-local is deallocated with the call return.
155     if ((is_thread_local && in_global_set) ||
156         (!is_thread_local && in_thread_local_set)) {
157       return InvalidArgument(
158           "computation %s has conflicting allocation requirements (global "
159           "and thread-local)",
160           computation->name());
161     }
162 
163     if (is_thread_local) {
164       thread_local_set.insert(computation);
165     } else {
166       global_set.insert(computation);
167     }
168 
169     for (auto* instruction : computation->instructions()) {
170       for (HloComputation* subcomputation :
171            instruction->called_computations()) {
172         switch (instruction->opcode()) {
173           case HloOpcode::kCall:
174           case HloOpcode::kConditional:
175           case HloOpcode::kWhile:
176             // Call, conditional, and while must be called from a computation
177             // with global allocations as they may return references to buffers
178             // inside the called computation which cannot be thread-local.
179             if (is_thread_local) {
180               return InvalidArgument(
181                   "computation %s cannot contain call/while op because it "
182                   "requires thread-local buffer allocations",
183                   computation->name());
184             }
185             worklist.push_back(std::make_pair(subcomputation,
186                                               false));  // Not thread local.
187             break;
188           case HloOpcode::kAllReduce:
189           case HloOpcode::kMap:
190           case HloOpcode::kReduce:
191           case HloOpcode::kReduceWindow:
192           case HloOpcode::kScatter:
193           case HloOpcode::kSelectAndScatter:
194           case HloOpcode::kSort:
195           case HloOpcode::kFusion:
196             // Map/reduce etc computations are always thread-local.
197             worklist.push_back(std::make_pair(subcomputation,
198                                               true));  // Thread local.
199             break;
200           default:
201             return InternalError("Unexpected calling opcode: %s",
202                                  HloOpcodeString(instruction->opcode()));
203         }
204       }
205     }
206   }
207 
208   // Add the computations to the vectors in post order.
209   for (auto* computation : module->MakeComputationPostOrder()) {
210     if (thread_local_set.contains(computation)) {
211       thread_local_computations->push_back(computation);
212     } else if (global_set.contains(computation)) {
213       global_computations->push_back(computation);
214     }
215     // If the computation is not reachable from the entry computation, then it
216     // will not appear in either thread_local_set or global_set. We don't bother
217     // assigning buffers for these.
218   }
219   return Status::OK();
220 }
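// Editor's note: a small worked example of the traversal above (a sketch, not
// part of the original source). If the entry computation contains a kWhile
// whose body B in turn contains a kMap over computation M, then B is reached
// through a global call and lands in global_computations, while M is reached
// through kMap and lands in thread_local_computations. If some computation
// were reachable both ways, the function would instead return the
// InvalidArgument error above.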
221 
222 string BufferAllocation::Slice::ToString() const {
223   return absl::StrCat("{index:", index(), ", offset:", offset_,
224                       ", size:", size_, "}");
225 }
226 
227 BufferAllocation::Slice BufferAllocation::GetSlice(
228     const LogicalBuffer& buffer) const {
229   const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
230   return Slice(this, os.offset, os.size);
231 }
232 
233 void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset,
234                                      int64 size) {
235   VLOG(4) << "Trying to add " << buffer << " to allocation #" << index();
236   CHECK(!assigned_buffers_.contains(&buffer))
237       << "LogicalBuffer " << buffer << " already assigned to allocation "
238       << index_;
239   CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
240                           << " offset out of range";
241   CHECK_LE(offset + size, size_)
242       << "LogicalBuffer " << buffer
243       << " size out of range at offset: " << offset << " with size: " << size;
244   CHECK_EQ(buffer.color(), color())
245       << "Buffer color " << buffer.color() << " for buffer " << buffer
246       << " does not match allocation color " << color() << ".";
247   OffsetSize offset_size;
248   offset_size.offset = offset;
249   offset_size.size = size;
250   assigned_buffers_.emplace(&buffer, offset_size);
251 }
252 
253 BufferAllocationProto BufferAllocation::ToProto() const {
254   BufferAllocationProto proto;
255   proto.set_index(index_);
256   proto.set_size(size_);
257   proto.set_is_thread_local(is_thread_local_);
258   proto.set_is_tuple(is_tuple_);
259   proto.set_color(color_.value());
260   if (is_entry_computation_parameter_) {
261     proto.set_is_entry_computation_parameter(true);
262     for (int64 idx : param_shape_index()) {
263       proto.add_parameter_shape_index(idx);
264     }
265     proto.set_parameter_number(parameter_number_);
266   }
267   proto.set_is_constant(is_constant_);
268   proto.set_maybe_live_out(maybe_live_out_);
269   for (const auto& buffer_offset_size : assigned_buffers_) {
270     BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
271     proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
272     proto_assigned->set_offset(buffer_offset_size.second.offset);
273     proto_assigned->set_size(buffer_offset_size.second.size);
274   }
275   absl::c_sort(*proto.mutable_assigned(),
276                [](const BufferAllocationProto::Assigned& assign1,
277                   const BufferAllocationProto::Assigned& assign2) {
278                  return assign1.logical_buffer_id() <
279                         assign2.logical_buffer_id();
280                });
281   return proto;
282 }
283 
284 string BufferAllocation::ToString() const {
285   string output;
286   StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
287   if (color().value() != 0) {
288     StrAppend(&output, ", color ", color().value());
289   }
290   if (is_entry_computation_parameter()) {
291     StrAppend(&output, ", parameter ", parameter_number(), " at ShapeIndex ",
292               param_shape_index().ToString());
293   }
294   if (is_constant()) {
295     StrAppend(&output, ", constant");
296   }
297   if (is_thread_local()) {
298     StrAppend(&output, ", thread-local");
299   }
300   if (maybe_live_out()) {
301     StrAppend(&output, ", maybe-live-out");
302   }
303   if (IsPreallocatedTempBuffer()) {
304     StrAppend(&output, ", preallocated-temp");
305   }
306   StrAppend(&output, ":\n");
307   // Dump the assigned buffers ordered by id.
308   std::vector<const LogicalBuffer*> sorted_buffers;
309   for (const auto& buffer_offset_size : assigned_buffers_) {
310     sorted_buffers.push_back(buffer_offset_size.first);
311   }
312   absl::c_sort(sorted_buffers,
313                [](const LogicalBuffer* a, const LogicalBuffer* b) {
314                  return a->id() < b->id();
315                });
316   for (const LogicalBuffer* buffer : sorted_buffers) {
317     const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
318     StrAppend(&output, absl::StrFormat(
319                            "  %s [%d,%d]: %s\n", buffer->ToString(),
320                            offset_size.offset, offset_size.size,
321                            ShapeUtil::HumanStringWithLayout(buffer->shape())));
322   }
323   return output;
324 }
325 
326 std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
327   out << buffer.ToString();
328   return out;
329 }
330 
331 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
332   out << s.ToString();
333   return out;
334 }
335 
336 const PointsToSet& BufferAssignment::GetPointsToSet(
337     const HloInstruction* instruction) const {
338   return points_to_analysis().GetPointsToSet(instruction);
339 }
340 
341 bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const {
342   TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
343   return allocation_index_for_buffer_.contains(&buffer);
344 }
345 
346 const BufferAllocation& BufferAssignment::GetAssignedAllocation(
347     const LogicalBuffer& buffer) const {
348   CHECK(HasAllocation(buffer));
349   return GetAllocation(allocation_index_for_buffer_.at(&buffer));
350 }
351 
352 BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
353     const LogicalBuffer& buffer) {
354   return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
355 }
356 
357 std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
358     const HloInstruction* instruction, const ShapeIndex& index) const {
359   std::set<BufferAllocation::Slice> result;
360   for (const LogicalBuffer* buffer : GetSourceBuffers(instruction, index)) {
361     if (HasAllocation(*buffer)) {
362       result.insert(GetAssignedAllocation(*buffer).GetSlice(*buffer));
363     }
364   }
365   return result;
366 }
367 
368 const BufferAllocation& BufferAssignment::GetAllocation(
369     BufferAllocation::Index index) const {
370   CHECK_GE(index, 0);
371   CHECK_LT(index, allocations_.size());
372   return allocations_[index];
373 }
374 
375 const BufferAllocation* BufferAssignment::GetInstructionAllocation(
376     const HloInstruction* hlo, const ShapeIndex& shape_index) const {
377   const PointsToSet& points_to_set = points_to_analysis().GetPointsToSet(hlo);
378   const LogicalBuffer* buffer = points_to_set.element(shape_index)[0];
379 
380   if (!HasAllocation(*buffer)) {
381     return nullptr;
382   }
383 
384   const BufferAllocation& instruction_allocation =
385       GetAssignedAllocation(*buffer);
386   return &instruction_allocation;
387 }
388 
389 BufferAllocation* BufferAssignment::GetMutableAllocation(
390     BufferAllocation::Index index) {
391   return const_cast<BufferAllocation*>(&GetAllocation(index));
392 }
393 
394 bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
395                                        const ShapeIndex& index) const {
396   for (const LogicalBuffer* buffer :
397        GetPointsToSet(instruction).element(index)) {
398     if (allocation_index_for_buffer_.contains(buffer)) {
399       return true;
400     }
401   }
402   return false;
403 }
404 
405 bool BufferAssignment::HasTopLevelAllocation(
406     const HloInstruction* instruction) const {
407   return HasAllocationAt(instruction, /*index=*/{});
408 }
409 
410 StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
411     const HloInstruction* instruction, const ShapeIndex& index) const {
412   VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
413           << index << "]";
414   BufferAllocation::Slice result;
415   for (const LogicalBuffer* buffer :
416        GetPointsToSet(instruction).element(index)) {
417     VLOG(3) << "Examining buffer " << *buffer;
418     if (HasAllocation(*buffer)) {
419       VLOG(3) << "Has allocation";
420       const BufferAllocation::Slice slice =
421           GetAssignedAllocation(*buffer).GetSlice(*buffer);
422       if (result.allocation() == nullptr) {
423         result = slice;
424       } else if (result != slice) {
425         return FailedPrecondition(
426             "BufferAllocation::Slice for instruction %s at index %s cannot "
427             "be determined at compile-time.",
428             instruction->name(), index.ToString());
429       }
430     } else {
431       VLOG(3) << "No allocation";
432     }
433   }
434   if (result.allocation() == nullptr) {
435     return FailedPrecondition(
436         "BufferAllocation::Slice not assigned for instruction %s at index %s",
437         instruction->name(), index.ToString());
438   }
439   return result;
440 }
441 
442 StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
443     const HloInstruction* instruction) const {
444   return GetUniqueSlice(instruction, /*index=*/{});
445 }
446 
447 bool BufferAssignment::SharesSliceAtIndex(
448     const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
449     const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
450   return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
451          GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
452 }
453 
454 bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
455                                           const HloInstruction* hlo_b) const {
456   using SliceSet = flat_hash_set<BufferAllocation::Slice>;
457   // Gets the slices of all of instr's subshapes. If any subshape doesn't
458   // have an assigned slice, returns the empty set.
459   auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
460     SliceSet slices;
461     Status status = ShapeUtil::ForEachSubshapeWithStatus(
462         instr->shape(),
463         [&](const Shape& /*subshape*/, const ShapeIndex& index) {
464           auto shape_slices = GetAllSlices(instr, index);
465           if (shape_slices.empty()) {
466             return InvalidArgument("No slices assigned to part of instr.");
467           }
468           slices.insert(shape_slices.begin(), shape_slices.end());
469           return Status::OK();
470         });
471     if (!status.ok()) {
472       return {};
473     }
474     return slices;
475   };
476 
477   SliceSet slices_a = collect_slices(hlo_a);
478   SliceSet slices_b = collect_slices(hlo_b);
479   // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
480   // didn't return the empty set) for both HLOs, and the two resulting sets of
481   // slices are disjoint.
482   return !slices_a.empty() && !slices_b.empty() &&
483          absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
484            return slices_b.contains(slice);
485          });
486 }
487 
488 StatusOr<BufferAllocation::Slice>
489 BufferAssignment::GetUniqueTopLevelOutputSlice() const {
490   return GetUniqueTopLevelSlice(
491       module_->entry_computation()->root_instruction());
492 }
493 
494 BufferAllocation* BufferAssignment::NewEmptyAllocation(
495     int64 size, LogicalBuffer::Color color) {
496   BufferAllocation::Index index = allocations_.size();
497   allocations_.emplace_back(index, size, color);
498   BufferAllocation* allocation = &allocations_.back();
499   return allocation;
500 }
501 
502 BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer,
503                                                   int64 size) {
504   BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
505   AddAssignment(allocation, buffer, /*offset=*/0, size);
506   allocation->peak_buffers_.push_back(&buffer);
507   return allocation;
508 }
509 
510 // Assigns 'buffer' to 'allocation' at the given offset and size.
511 void BufferAssignment::AddAssignment(BufferAllocation* allocation,
512                                      const LogicalBuffer& buffer, int64 offset,
513                                      int64 size) {
514   CHECK(!allocation_index_for_buffer_.contains(&buffer))
515       << "LogicalBuffer " << buffer << " already has an allocation.";
516   CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
517       << "Non-reusable allocation already assigned a buffer: "
518       << allocation->ToString();
519 
520   TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
521 
522   allocation->AddAssignment(buffer, offset, size);
523   if (liveness().MaybeLiveOut(buffer)) {
524     allocation->set_maybe_live_out(true);
525   }
526   allocation_index_for_buffer_[&buffer] = allocation->index();
527 }
528 
529 // Combines allocations of temporary buffers of the same color into one big
530 // BufferAllocation.
531 void BufferAssignment::CombineTempAllocations() {
532   VLOG(1) << "CombineTempAllocations()";
533   flat_hash_map<LogicalBuffer::Color, BufferAllocation,
534                 LogicalBuffer::Color::Hasher>
535       combined_allocation_map;
536 
537   // Move all temp allocations into a single run at the end of the allocations
538   // vector.
539   const auto first_temp_it =
540       std::partition(allocations_.begin(), allocations_.end(),
541                      [](const BufferAllocation& allocation) {
542                        return !allocation.IsPreallocatedTempBuffer();
543                      });
544 
545   // Walk over the run of temp allocations, collecting the allocations belonging
546   // to the same color.
547   if (first_temp_it != allocations_.end()) {
548     for (auto it = first_temp_it; it != allocations_.end(); ++it) {
549       const BufferAllocation& temp_allocation = *it;
550       LogicalBuffer::Color color = temp_allocation.color();
551       auto combined_it = combined_allocation_map.find(color);
552       if (combined_it == combined_allocation_map.end()) {
553         // We have found the first temp allocation of this color. Collect
554         // the other temp allocations of the same color into it.
555         VLOG(1) << "Combined temp allocation for color " << color
556                 << " is: " << temp_allocation;
557         combined_allocation_map.emplace(color, temp_allocation);
558         continue;
559       }
560 
561       auto* combined_allocation = &combined_it->second;
562       VLOG(1) << "Combined allocation absorbing temp allocation: "
563               << temp_allocation;
564 
565       // Each temp allocation is placed end-to-end, accounting for alignment.
566       // The offset of each buffer in the combined allocation is computed from
567       // the base offset of the allocation.
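      // For example (editor's illustration, not in the original source): with
      // alignment 64 and a combined size of 100 bytes so far, base is rounded
      // up to 128; a 40-byte temp allocation's buffers are then re-added at
      // offsets [128, 168) and the combined size grows to 168.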
568       int64 alignment = color_alignment_(color);
569       const int64 base =
570           RoundUpToNearest(combined_allocation->size(), alignment);
571       combined_allocation->set_size(base + temp_allocation.size());
572       for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) {
573         const LogicalBuffer* buffer = buffer_offset_size.first;
574         const int64 offset = buffer_offset_size.second.offset;
575         const int64 size = buffer_offset_size.second.size;
576         combined_allocation->AddAssignment(*buffer, base + offset, size);
577       }
578       if (!temp_allocation.HeapTraces().empty()) {
579         CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
580         combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
581       }
582       combined_allocation->peak_buffers_.insert(
583           combined_allocation->peak_buffers_.end(),
584           temp_allocation.peak_buffers_.begin(),
585           temp_allocation.peak_buffers_.end());
586     }
587     // Replace all existing temporary allocations with the new combined
588     // allocations.
589     allocations_.erase(first_temp_it, allocations_.end());
590     for (auto& combined : combined_allocation_map) {
591       allocations_.push_back(combined.second);
592       temp_allocation_total_size_ += combined.second.size();
593     }
594   }
595 
596   // Update allocation indices to their new positions.
597   allocation_index_for_buffer_.erase(allocation_index_for_buffer_.begin(),
598                                      allocation_index_for_buffer_.end());
599   for (size_t index = 0; index < allocations_.size(); ++index) {
600     BufferAllocation* allocation = &allocations_[index];
601     allocation->set_index(index);
602     for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
603       const LogicalBuffer* buffer = buffer_offset_size.first;
604       allocation_index_for_buffer_[buffer] = index;
605     }
606   }
607 }
608 
609 Status BufferAssignment::ComputeSummaryStats() {
610   for (auto& allocation : Allocations()) {
611     if (allocation.is_entry_computation_parameter()) {
612       stats_.parameter_allocation_count++;
613       stats_.parameter_allocation_bytes += allocation.size();
614     }
615     if (allocation.is_constant()) {
616       stats_.constant_allocation_count++;
617       stats_.constant_allocation_bytes += allocation.size();
618     }
619     if (allocation.maybe_live_out()) {
620       stats_.maybe_live_out_allocation_count++;
621       stats_.maybe_live_out_allocation_bytes += allocation.size();
622     }
623     if (allocation.IsPreallocatedTempBuffer()) {
624       stats_.preallocated_temp_allocation_count++;
625       stats_.preallocated_temp_allocation_bytes += allocation.size();
626     }
627     stats_.total_allocation_count++;
628     stats_.total_allocation_bytes += allocation.size();
629   }
630 
631   // Only compute total fragmentation if all computations have schedules.
632   HloSchedule schedule(module_);
633   bool schedule_complete = true;
634   for (const auto& computation : module_->computations()) {
635     if (!computation->IsFusionComputation()) {
636       const HloInstructionSequence* sequence =
637           liveness_->hlo_ordering().SequentialOrder(*computation);
638       if (sequence == nullptr) {
639         schedule_complete = false;
640       } else {
641         schedule.set_sequence(computation, *sequence);
642       }
643     }
644   }
645   if (schedule_complete) {
646     TF_RETURN_IF_ERROR(schedule.Verify());
647     TF_ASSIGN_OR_RETURN(
648         const int64 min_size,
649         HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
650     stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
651   }
652 
653   return Status::OK();
654 }
655 
656 string BufferAssignment::Stats::ToString() const {
657   string s;
658   StrAppendFormat(&s, "BufferAssignment stats:\n");
659   StrAppendFormat(&s, "             parameter allocation: %10s\n",
660                   HumanReadableNumBytes(parameter_allocation_bytes));
661   StrAppendFormat(&s, "              constant allocation: %10s\n",
662                   HumanReadableNumBytes(constant_allocation_bytes));
663   StrAppendFormat(&s, "        maybe_live_out allocation: %10s\n",
664                   HumanReadableNumBytes(maybe_live_out_allocation_bytes));
665   StrAppendFormat(&s, "     preallocated temp allocation: %10s\n",
666                   HumanReadableNumBytes(preallocated_temp_allocation_bytes));
667   if (preallocated_temp_fragmentation_bytes >= 0) {
668     const double percent = 100. * preallocated_temp_fragmentation_bytes /
669                            preallocated_temp_allocation_bytes;
670     StrAppendFormat(
671         &s, "  preallocated temp fragmentation: %10s (%.2f%%)\n",
672         HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
673   }
674   StrAppendFormat(&s, "                 total allocation: %10s\n",
675                   HumanReadableNumBytes(total_allocation_bytes));
676   if (total_fragmentation_bytes >= 0) {
677     const double percent =
678         100. * total_fragmentation_bytes / total_allocation_bytes;
679     StrAppendFormat(&s, "              total fragmentation: %10s (%.2f%%)\n",
680                     HumanReadableNumBytes(total_fragmentation_bytes), percent);
681   }
682   return s;
683 }
684 
685 string BufferAssignment::ToString() const {
686   string output;
687   absl::StrAppend(&output, "BufferAssignment:\n");
688   for (auto& allocation : allocations_) {
689     absl::StrAppend(&output, allocation.ToString());
690   }
691   return output;
692 }
693 
694 BufferAssignmentProto BufferAssignment::ToProto() const {
695   BufferAssignmentProto proto;
696   // NOTE: TuplePointsToAnalysis state is serialized here in BufferAssignment,
697   // because we need to do the HasAllocation check for each buffer. Otherwise
698   // the buffer_size_ call might fail for some backends.
699   const TuplePointsToAnalysis& points_to_analysis =
700       liveness_->points_to_analysis();
701   for (LogicalBuffer::Id id = 0; id < points_to_analysis.num_logical_buffers();
702        id++) {
703     auto& buffer = points_to_analysis.logical_buffer(id);
704     if (HasAllocation(buffer)) {
705       LogicalBufferProto proto_buffer = buffer.ToProto(buffer_size_);
706       proto.add_logical_buffers()->Swap(&proto_buffer);
707 
708       // Fill buffer aliases.
709       for (const BufferAlias& alias :
710            points_to_analysis.GetBufferAliases(buffer)) {
711         if (alias.instruction() == buffer.instruction() &&
712             alias.index() == buffer.index()) {
713           continue;  // skip self-aliases
714         }
715         BufferAssignmentProto::BufferAlias* proto_alias =
716             proto.add_buffer_aliases();
717         LogicalBufferProto::Location proto_alias_location =
718             BufferValue::ToLocationProto(*alias.instruction(), alias.index());
719         proto_alias->set_source_buffer_id(buffer.id());
720         proto_alias->mutable_location()->Swap(&proto_alias_location);
721       }
722     }
723   }
724   for (const BufferAllocation& allocation : Allocations()) {
725     BufferAllocationProto proto_allocation = allocation.ToProto();
726     proto.add_buffer_allocations()->Swap(&proto_allocation);
727     for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
728       *proto.add_heap_simulator_traces() = heap_trace;
729     }
730   }
731   return proto;
732 }
733 
734 /* static */
735 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
736     const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
737     LogicalBuffer::SizeFunction buffer_size,
738     LogicalBuffer::AlignmentFunction color_alignment,
739     bool allow_input_output_aliasing, bool allocate_buffers_for_constants,
740     BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker) {
741   BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
742                           std::move(reuse_checker));
743   return assigner.CreateAssignment(module, std::move(hlo_ordering),
744                                    std::move(buffer_size),
745                                    std::move(color_alignment));
746 }
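// Editor's note: a hypothetical call site for Run (an illustrative sketch
// only; the size function, alignment, flag values, and use of
// DependencyHloOrdering and BufferLiveness::DefaultColorer below are
// assumptions, not taken from this file):
//
//   auto size_fn = [](const BufferValue& buffer) {
//     return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
//   };
//   auto align_fn = [](LogicalBuffer::Color) { return int64{64}; };
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<BufferAssignment> assignment,
//       BufferAssigner::Run(module,
//                           absl::make_unique<DependencyHloOrdering>(module),
//                           size_fn, align_fn,
//                           /*allow_input_output_aliasing=*/false,
//                           /*allocate_buffers_for_constants=*/true,
//                           BufferLiveness::DefaultColorer(),
//                           /*reuse_checker=*/nullptr));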
747 
748 namespace {
749 
750 // a and b are in different subcomputations. Check for the case
751 // where a is inside the while body, and b is outside, part of the same while's
752 // init-operand or while-result.
753 bool MayInterfereAcrossSubcomputations(BufferAssignment* assignment,
754                                        const LogicalBuffer& a_buffer,
755                                        const LogicalBuffer& b_buffer) {
756   const CallGraph& call_graph =
757       assignment->liveness().hlo_ordering().call_graph();
758   const HloInstruction* a_ancestor;
759   const HloInstruction* b_ancestor;
760   std::tie(a_ancestor, b_ancestor) =
761       call_graph.NearestAncestorsInSameComputation(a_buffer.instruction(),
762                                                    b_buffer.instruction());
763   if (a_ancestor == nullptr) {
764     // No common ancestor.
765     return true;
766   }
767   if (a_ancestor->opcode() == HloOpcode::kWhile &&
768       call_graph.InstructionIsNestedIn(a_buffer.instruction(),
769                                        a_ancestor->while_body())) {
770     const PointsToSet& init_set =
771         assignment->liveness().points_to_analysis().GetPointsToSet(
772             a_ancestor->operand(0));
773     if (init_set.ContainsBuffer(b_buffer)) {
774       VLOG(4) << "Can't interfere: " << a_buffer << " and " << b_buffer
775               << " (part of while-operand)";
776       return false;
777     }
778     const PointsToSet& while_set =
779         assignment->liveness().points_to_analysis().GetPointsToSet(a_ancestor);
780     if (while_set.ContainsBuffer(b_buffer)) {
781       VLOG(4) << "Can't interfere: " << a_buffer << " and " << b_buffer
782               << " (part of while)";
783       return false;
784     }
785   }
786   return true;
787 }
788 
789 // Returns true if a and b can't possibly interfere (and therefore further
790 // checking for interference can be skipped). This function checks for special
791 // cases where copy insertion guarantees no interference, but the regular buffer
792 // liveness is too conservative:
793 //
794 // Operations inside a while-body can't interfere with operations outside the
795 // while op if their last use is at the while-loop itself as part of the
796 // while-init op, or the while-result.  For ops that are live across a
797 // while-loop, copy insertion will already insert the necessary copies to avoid
798 // such interference.
799 //
800 // This allows sharing buffers in cases like this:
801 // init = {...}
802 // while (init):
803 //  p = param(0)
804 //  gte = get-tuple-element(p), index=i
805 //  t1 = op1 (gte)
806 //  t2 = op2 (t1)
807 //  ROOT tuple = {..., t2, ...}
808 //
809 // where t1 and t2 can share the same buffer.
810 bool MaySkipInterferenceCheck(BufferAssignment* assignment,
811                               const LogicalBuffer& a_buffer,
812                               const LogicalBuffer& b_buffer) {
813   if (a_buffer.instruction()->parent() == b_buffer.instruction()->parent()) {
814     // Ops within the same computation are not handled here. Assume that they
815     // may interfere.
816     return false;
817   }
818   return !MayInterfereAcrossSubcomputations(assignment, a_buffer, b_buffer) ||
819          !MayInterfereAcrossSubcomputations(assignment, b_buffer, a_buffer);
820 }
821 
822 }  // namespace
823 
824 bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
825                                        const LogicalBuffer& buffer,
826                                        BufferAssignment* assignment) {
827   const LogicalBuffer::SizeFunction& buffer_size = assignment->buffer_size_;
828 
829   CHECK(!assignment->HasAllocation(buffer))
830       << "buffer " << buffer << " already has an allocation assigned.";
831 
832   VLOG(4) << "Trying to assign " << buffer << " to allocation: " << *allocation;
833 
834   if (buffer.color() != allocation->color()) {
835     VLOG(4) << "Can't assign: buffer has color" << buffer.color()
836             << " and allocation has color " << allocation->color() << ".";
837     return false;
838   }
839 
840   if (buffer_size(buffer) > allocation->size()) {
841     VLOG(4) << "Can't assign: buffer is larger than allocation ("
842             << buffer_size(buffer) << " > " << allocation->size() << ")";
843     return false;
844   }
845 
846   if (allocation->is_readonly()) {
847     VLOG(4) << "Can't assign: allocation is readonly";
848     return false;
849   }
850 
851   if (reuse_checker_ != nullptr &&
852       !reuse_checker_(*assignment, *allocation, buffer)) {
853     VLOG(4) << "Can't assign: reuse_checker_(allocation, buffer) == false";
854     return false;
855   }
856 
857   if (!allocation->is_reusable()) {
858     VLOG(4) << "Can't assign: allocation is not reusable";
859     return false;
860   }
861 
862   for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
863     const LogicalBuffer& assigned_buffer = *buffer_offset_size.first;
864     if (MaySkipInterferenceCheck(assignment, buffer, assigned_buffer)) {
865       continue;
866     }
867     if (assignment->liveness().MayInterfere(assigned_buffer, buffer)) {
868       VLOG(4) << "Can't assign: assignee " << assigned_buffer
869               << " may interfere with " << buffer;
870       return false;
871     }
872     // Copy instructions don't share a buffer with their input operand.
873     if (buffer.instruction()->IsUserOf(assigned_buffer.instruction()) &&
874         buffer.instruction()->opcode() == HloOpcode::kCopy) {
875       VLOG(4) << "Can't assign: assignee " << assigned_buffer
876               << " is used at copy instruction " << buffer;
877       return false;
878     }
879   }
880 
881   // If the buffer is live out of the computation then it should only be
882   // assigned a buffer which exactly fits the result to avoid wasting memory
883   // (result buffers can have arbitrary lifetimes).
884   if (assignment->liveness().MaybeLiveOut(buffer) &&
885       allocation->size() != buffer_size(buffer)) {
886     VLOG(4) << "Can't assign: buffer " << buffer
887             << "is live out and size not the same as allocation";
888     return false;
889   }
890 
891   assignment->AddAssignment(allocation, buffer, /*offset=*/0,
892                             buffer_size(buffer));
893   return true;
894 }
895 
896 Status BufferAssigner::AssignBuffersForComputation(
897     const HloComputation* computation, bool is_thread_local,
898     const flat_hash_set<const LogicalBuffer*>& colocated_buffers,
899     const flat_hash_set<BufferAllocation::Index>& colocated_allocations,
900     flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>*
901         buffers_to_assign_sequentially,
902     BufferAssignment* assignment) {
903   // Buffers are sorted and assigned to BufferAllocations in decreasing order of
904   // size.
905   std::vector<const LogicalBuffer*> sorted_buffers;
906   for (auto* instruction : computation->instructions()) {
907     // Add all buffers which this instruction defines. Instructions which don't
908     // define buffers (eg, bitcast which just forwards a pointer) don't need
909     // any allocations.
910     for (const LogicalBuffer* buffer :
911          assignment->points_to_analysis().GetBuffersDefinedByInstruction(
912              instruction)) {
913       sorted_buffers.push_back(buffer);
914     }
915   }
916 
917   // Generate a post order sort of instructions for sorting of the
918   // LogicalBuffers.
919   flat_hash_map<const HloInstruction*, int> post_order_position;
920   int position = 0;
921   for (auto* instruction : computation->MakeInstructionPostOrder()) {
922     post_order_position.emplace(instruction, position);
923     position++;
924   }
925 
926   // If there is a sequential instruction ordering, we'll delay assignment of
927   // temp buffers until after the main assignment loop.
928   const BufferLiveness& liveness = assignment->liveness();
929   const bool has_sequential_order =
930       liveness.hlo_ordering().SequentialOrder(*computation) != nullptr;
931   if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
932     // Every sequential computation must get an entry in the
933     // buffers_to_assign_sequentially map, even if we end up with an empty set
934     // of buffers. This ensures we can correctly determine whether to run
935     // whole-module heap simulation.
936     buffers_to_assign_sequentially->emplace(
937         computation, flat_hash_set<const LogicalBuffer*>());
938   }
939 
940   // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
941   // first for simplicity. This means any previously created BufferAllocation is
942   // necessarily large enough to hold the output of the current Buffer in
943   // consideration.
944   //
945   // As a secondary sorting criteria, if the instructions are sequentially
946   // ordered, we assign live-out buffers before others. Note that for sequential
947   // computations, we'll take temp buffers that can't re-use any allocations and
948   // assign them via a heap scheduler. By assigning live-out buffers first, we
949   // increase the odds that temp buffers can re-use an allocation.
950   //
951   // As a final tiebreaker use post order position of the HLO instruction which
952   // defines the buffer. This means an instruction will appear after its
953   // operands (assuming operands are the same/larger size) enabling the
954   // important reuse case where an elementwise instruction reuses one of its
955   // operand's buffer. This improves locality.
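  // For example (editor's illustration, not in the original source): with a
  // sequential order and buffers A (256 bytes), B (128 bytes, maybe-live-out),
  // and C and D (128 bytes each, not live out, C defined before D in post
  // order), the comparator below yields the order A, B, C, D.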
956   absl::c_sort(sorted_buffers,
957                [has_sequential_order, &liveness, &post_order_position,
958                 assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
959                  // Primary sort is by decreasing buffer size.
960                  const int64 a_size = assignment->buffer_size_(*a);
961                  const int64 b_size = assignment->buffer_size_(*b);
962                  if (a_size != b_size) {
963                    return a_size > b_size;  // use ">" for decreasing size.
964                  }
965                  // Otherwise live out buffers come before others, if the
966                  // instructions are sequentially ordered.
967                  if (has_sequential_order) {
968                    const bool a_live_out = liveness.MaybeLiveOut(*a);
969                    const bool b_live_out = liveness.MaybeLiveOut(*b);
970                    if (a_live_out != b_live_out) {
971                      return a_live_out;
972                    }
973                  }
974                  // Final tiebreaker is in instruction post order.
975                  return post_order_position.at(a->instruction()) <
976                         post_order_position.at(b->instruction());
977                });
978 
979   // BufferAllocations are necessarily created in decreasing size order. Keep
980   // indices of previously created BufferAllocations in allocation_indices.
981   std::vector<BufferAllocation::Index> allocation_indices;
982   for (const LogicalBuffer* buffer : sorted_buffers) {
983     VLOG(3) << "Assigning allocation to: " << *buffer;
984     if (colocated_buffers.contains(buffer)) {
985       // Colocated buffers are currently assigned in an earlier pass.
986       VLOG(3) << "Skipping colocated buffer: " << *buffer;
987       continue;
988     }
989 
990     TF_RET_CHECK(!assignment->HasAllocation(*buffer));
991 
992     const HloInstruction* instruction = buffer->instruction();
993     const int64 buffer_size = assignment->buffer_size_(*buffer);
994 
995     if (instruction->opcode() == HloOpcode::kConstant) {
996       if (allocate_buffers_for_constants_) {
997         BufferAllocation* allocation =
998             assignment->NewAllocation(*buffer, buffer_size);
999         allocation->set_constant(true);
1000         VLOG(3) << "New allocation #" << allocation->index() << " for constant "
1001                 << *buffer;
1002       }
1003       continue;
1004     }
1005 
1006     const bool is_entry_parameter =
1007         instruction->opcode() == HloOpcode::kParameter &&
1008         computation == computation->parent()->entry_computation();
1009     if (is_entry_parameter) {
1010       // If the LogicalBuffer is part of an external parameter, creates a new
1011       // allocation and sets its parameter number. Parameters of non-entry
1012       // computations do not need special allocations because they live inside
1013       // callers.
1014       BufferAllocation* allocation =
1015           assignment->NewAllocation(*buffer, buffer_size);
1016       bool parameter_has_alias =
1017           assignment->module().input_output_alias_config().ParameterHasAlias(
1018               instruction->parameter_number(), buffer->index());
1019       allocation->set_entry_computation_parameter(
1020           instruction->parameter_number(), buffer->index(),
1021           parameter_has_alias);
1022       VLOG(3) << "Mark allocation #" << allocation->index()
1023               << " as entry computation parameter: " << *buffer;
1024       continue;
1025     }
1026 
1027     if (is_thread_local) {
1028       BufferAllocation* allocation =
1029           assignment->NewAllocation(*buffer, buffer_size);
1030       allocation->set_is_thread_local(true);
1031       VLOG(3) << "New allocation #" << allocation->index()
1032               << " for thread-local: " << *buffer;
1033       continue;
1034     }
1035 
1036     if (buffer->shape().IsTuple()) {
1037       BufferAllocation* allocation =
1038           assignment->NewAllocation(*buffer, buffer_size);
1039       allocation->set_is_tuple(true);
1040       VLOG(3) << "New allocation #" << allocation->index()
1041               << " for tuple-shaped buffer: " << *buffer;
1042       continue;
1043     }
1044 
1045     // First try to assign a LogicalBuffer to one of its operand allocations to
1046     // improve locality. This is only possible with elementwise operations
1047     // (checked in liveness analysis) which are necessarily top-level
1048     // array-shaped buffers.
1049     if (buffer->IsTopLevel() && !buffer->IsTuple()) {
1050       for (auto* operand : instruction->operands()) {
1051         bool assigned_operand = false;
1052         for (const auto& operand_slice :
1053              assignment->GetAllSlices(operand, /*index=*/{})) {
1054           BufferAllocation* allocation =
1055               assignment->GetMutableAllocation(operand_slice.index());
1056           if (!colocated_allocations.contains(allocation->index())) {
1057             // TODO(b/32491382) Colocated buffers are currently assigned in an
1058             // earlier pass, and so can break the "increasing allocation size"
1059             // invariant in this function (causing this CHECK to fail). However,
1060             // the call to MaybeAssignBuffer is safe as it returns false if
1061             // allocation.size < buffer.size.
1062             CHECK_GE(allocation->size(), buffer_size);
1063           }
1064           if (MaybeAssignBuffer(allocation, *buffer, assignment)) {
1065             VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
1066                     << " for: " << *buffer;
1067             assigned_operand = true;
1068             break;
1069           }
1070         }
1071         if (assigned_operand) {
1072           break;
1073         }
1074       }
1075     }
1076 
1077     if (!assignment->HasAllocation(*buffer)) {
1078       // Find the smallest existing allocation that can be reused, iterating
1079       // from the end of allocation_indices (smallest) to the beginning (largest).
1080       for (int allocation_index = allocation_indices.size() - 1;
1081            allocation_index >= 0; allocation_index--) {
1082         BufferAllocation* allocation = assignment->GetMutableAllocation(
1083             allocation_indices[allocation_index]);
1084         // Buffers are iterated in decreasing size order, so any previously
1085         // created allocation must be large enough to hold this instruction's
1086         // output (with the exception of colocated buffers).
1087         if (!colocated_allocations.contains(allocation->index())) {
1088           // TODO(b/32491382) Colocated buffers are currently assigned in an
1089           // earlier pass, and so can break the "increasing allocation size"
1090           // invariant in this function (causing this CHECK to fail). However,
1091           // the call to MaybeAssignBuffer is safe as it returns false if
1092           // allocation.size < buffer.size.
1093           CHECK_GE(allocation->size(), buffer_size);
1094         }
1095 
1096         if (MaybeAssignBuffer(allocation, *buffer, assignment)) {
1097           VLOG(3) << "Reusing allocation #" << allocation->index()
1098                   << " for: " << *buffer;
1099           break;
1100         }
1101       }
1102     }
1103 
1104     if (!assignment->HasAllocation(*buffer) && has_sequential_order &&
1105         !liveness.MaybeLiveOut(*buffer)) {
1106       // There is a sequential instruction ordering, so we delay assignment of
1107       // temp buffers until after the loop. We do this right before we decide to
1108       // create a new allocation, to ensure we've exhausted all the buffer
1109       // re-use cases above.
1110       //
1111       // Entry parameters and thread local buffers were already handled earlier
1112       // in this loop iteration.  See BufferAllocation::IsPreallocatedTempBuffer
1113       // for the definition of temp buffers.
1114       CHECK(!is_entry_parameter) << *buffer;
1115       CHECK(!is_thread_local) << *buffer;
1116       (*buffers_to_assign_sequentially)[computation].insert(buffer);
1117       VLOG(3) << "Delaying assignment of temp buffer: " << *buffer;
1118       continue;
1119     }
1120 
1121     if (!assignment->HasAllocation(*buffer)) {
1122       BufferAllocation* allocation =
1123           assignment->NewAllocation(*buffer, buffer_size);
1124       allocation_indices.push_back(allocation->index());
1125       VLOG(3) << "New allocation #" << allocation->index()
1126               << " for: " << *buffer;
1127     }
1128   }
1129 
1130   return Status::OK();
1131 }
1132 
1133 flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
1134               LogicalBuffer::Color::Hasher>
1135 BufferAssigner::SplitBuffersByColor(
1136     const flat_hash_set<const LogicalBuffer*>& buffers) {
1137   flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
1138                 LogicalBuffer::Color::Hasher>
1139       color_map;
1140   for (auto buffer : buffers) {
1141     color_map[buffer->color()].insert(buffer);
1142   }
1143   return color_map;
1144 }
1145 
1146 Status BufferAssigner::AssignBuffersWithSequentialOrdering(
1147     const flat_hash_map<const HloComputation*,
1148                         flat_hash_set<const LogicalBuffer*>>&
1149         buffers_to_assign_sequentially,
1150     bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
1151   // Run the sequence of instructions through the heap simulator.  The heuristic
1152   // that seems to give the best results is lazy-best-fit, with all runs of
1153   // alloc / free calls sorted in decreasing size order.
1154   const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering();
1155 
1156   // Returns a heap algorithm that chooses the best result from several
1157   // algorithms.
1158   auto get_heap_algorithm = [&](int64 alignment) {
1159     auto algorithms =
1160         absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
1161     algorithms->push_back(absl::make_unique<DecreasingSizeRunsHeap>(
1162         absl::make_unique<LazyBestFitHeap>(alignment)));
1163     algorithms->push_back(
1164         absl::make_unique<GlobalDecreasingSizeBestFitHeap>(alignment));
1165     return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms));
1166   };
1167 
1168   if (run_whole_module_heap_simulation) {
1169     // Run the heap simulation over the whole module. This reduces memory usage,
1170     // since buffers for kCall, kWhile, and kConditional sub-computations are
1171     // only live for the duration of their calling instructions.
1172     VLOG(1) << "Running whole-module heap simulation";
1173     HloSchedule schedule(&assignment->module());
1174     flat_hash_set<const LogicalBuffer*> all_buffers_to_assign;
1175     for (const auto& pair : buffers_to_assign_sequentially) {
1176       const HloComputation* computation = pair.first;
1177       const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
1178           pair.second;
1179       const HloInstructionSequence* instruction_sequence =
1180           hlo_ordering.SequentialOrder(*computation);
1181       CHECK(instruction_sequence != nullptr) << computation->name();
1182       schedule.set_sequence(computation, *instruction_sequence);
1183       all_buffers_to_assign.insert(buffers_to_assign.begin(),
1184                                    buffers_to_assign.end());
1185     }
1186     auto color_map = SplitBuffersByColor(all_buffers_to_assign);
1187     for (auto& single_colored_set : color_map) {
1188       auto color = single_colored_set.first;
1189       VLOG(2) << "Simulating heap for color " << color;
1190       int64 alignment = assignment->color_alignment_(color);
1191       HeapSimulator::Options options;
1192       options.alloc_constants = allocate_buffers_for_constants_;
1193       BufferValueFlatSet buffer_value_set =
1194           ToBufferValueFlatSet(single_colored_set.second);
1195       options.buffers_to_assign = &buffer_value_set;
1196       TF_ASSIGN_OR_RETURN(
1197           const HeapSimulator::Result result,
1198           HeapSimulator::Run(get_heap_algorithm(alignment),
1199                              assignment->module(), schedule,
1200                              assignment->points_to_analysis(),
1201                              assignment->buffer_size_, options));
1202       AssignBuffersFromHeapSimulator(result, assignment,
1203                                      single_colored_set.first);
1204     }
1205   } else {
1206     // Run the heap-simulation on a per-computation basis. Buffers for
1207     // sub-computations are assigned disjoint BufferAllocations, assuming the
1208     // worst-case that they may all be live concurrently.
1209     VLOG(1) << "Running per-computation heap simulation";
1210     for (const auto& pair : buffers_to_assign_sequentially) {
1211       const HloComputation* computation = pair.first;
1212       const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
1213           pair.second;
1214       const HloInstructionSequence* instruction_sequence =
1215           hlo_ordering.SequentialOrder(*computation);
1216       CHECK(instruction_sequence != nullptr) << computation->name();
1217       auto color_map = SplitBuffersByColor(buffers_to_assign);
1218       for (auto& single_colored_set : color_map) {
1219         auto color = single_colored_set.first;
1220         VLOG(2) << "Simulating heap for color " << color;
1221         int64 alignment = assignment->color_alignment_(color);
1222         HeapSimulator::Options options;
1223         BufferValueFlatSet buffer_value_set =
1224             ToBufferValueFlatSet(single_colored_set.second);
1225         options.buffers_to_assign = &buffer_value_set;
1226         TF_ASSIGN_OR_RETURN(
1227             const HeapSimulator::Result result,
1228             HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
1229                                *instruction_sequence,
1230                                assignment->points_to_analysis(),
1231                                assignment->buffer_size_, options));
1232         AssignBuffersFromHeapSimulator(result, assignment,
1233                                        single_colored_set.first);
1234       }
1235     }
1236   }
1237   return Status::OK();
1238 }
1239 
1240 namespace {
1241 
1242 // Computes and returns the set of logical buffers live at the point of maximal
1243 // liveness in the given heap trace. LogicalBuffers are (stably) sorted by id.
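// Illustrative example (buffer names and sizes are made up, not taken from any
// real trace): for the event sequence
//   ALLOC a(8), ALLOC b(16), FREE a, ALLOC c(24), FREE b, FREE c
// the running live size is 8, 24, 16, 40, 24, 0, so the maximal live size is
// 40 and the returned buffers are {b, c}, sorted by id.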
1244 std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
1245     const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
1246   // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical
1247   // buffers in this allocation.
1248   absl::flat_hash_map<LogicalBuffer::Id, const LogicalBuffer*> id_to_buffer;
1249   absl::flat_hash_map<const LogicalBuffer*, int64> buffer_sizes;
1250   for (const auto& pair : allocation.assigned_buffers()) {
1251     const LogicalBuffer* buffer = pair.first;
1252     const BufferAllocation::OffsetSize& offset_size = pair.second;
1253     id_to_buffer[buffer->id()] = buffer;
1254     buffer_sizes[buffer] = offset_size.size;
1255   }
1256 
1257   // Returns how much the given event increases the total size of live
1258   // buffers. Can be negative.
1259   auto memory_delta = [&id_to_buffer, &buffer_sizes](
1260                           const HeapSimulatorTrace::Event& event) -> int64 {
1261     const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
1262     const int64 buffer_size = buffer_sizes.at(buffer);
1263     if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
1264       return buffer_size;
1265     } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
1266       // Sharing a buffer does not change the live set size for the purposes of
1267       // the heap simulator. Even though the shared-with buffer may be smaller,
1268       // the entire allocation remains live.
1269       return 0;
1270     } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
1271       return -1 * buffer_size;
1272     }
1273     LOG(FATAL) << "Unknown event kind: " << event.kind();
1274   };
1275 
1276   // First compute the size of the maximal live set.
1277   int64 max_live_size = 0;
1278   int64 live_size = 0;
1279   for (const auto& event : heap_trace.events()) {
1280     live_size += memory_delta(event);
1281     if (max_live_size < live_size) {
1282       max_live_size = live_size;
1283     }
1284   }
1285 
1286   // Next gather the set of logical buffers live at the earliest point of
1287   // maximal live set size.
1288   absl::flat_hash_set<const LogicalBuffer*> live_buffers;
1289   live_size = 0;
1290   for (const auto& event : heap_trace.events()) {
1291     const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
1292     if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
1293       InsertOrDie(&live_buffers, buffer);
1294     } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
1295       // Nothing to do.
1296     } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
1297       CHECK(ContainsKey(live_buffers, buffer));
1298       live_buffers.erase(buffer);
1299     }
1300 
1301     live_size += memory_delta(event);
1302     if (live_size == max_live_size) {
1303       break;
1304     }
1305   }
1306   CHECK_EQ(live_size, max_live_size);
1307 
1308   std::vector<const LogicalBuffer*> live_buffers_vector;
1309   live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(),
1310                              live_buffers.end());
1311 
1312   // Stably sort the live buffers.
1313   absl::c_sort(live_buffers_vector,
1314                [](const LogicalBuffer* a, const LogicalBuffer* b) {
1315                  return a->id() < b->id();
1316                });
1317   return live_buffers_vector;
1318 }
1319 
1320 }  // namespace
1321 
1322 void BufferAssigner::AssignBuffersFromHeapSimulator(
1323     const HeapSimulator::Result& result, BufferAssignment* assignment,
1324     LogicalBuffer::Color color) {
1325   if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
1326     assignment->stats_.preallocated_temp_fragmentation_bytes =
1327         result.fragmentation_size;
1328   } else {
1329     assignment->stats_.preallocated_temp_fragmentation_bytes +=
1330         result.fragmentation_size;
1331   }
1332 
1333   BufferAllocation* allocation =
1334       assignment->NewEmptyAllocation(result.heap_size, color);
1335   for (const auto& buffer_chunk : result.chunk_map) {
1336     // TODO(lauj) Remove this down_cast after downstream users of
1337     // BufferAllocation::assigned_buffers() are updated to use BufferValue.
1338     const LogicalBuffer& buffer =
1339         *CHECK_NOTNULL(dynamic_cast<const LogicalBuffer*>(buffer_chunk.first));
1340     const HeapSimulator::Chunk& chunk = buffer_chunk.second;
1341     assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size);
1342   }
1343   allocation->peak_buffers_ =
1344       ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);
1345 
1346   VLOG(1) << "Ran heap simulation for allocation: " << allocation->ToString();
1347   allocation->AddHeapTrace(result.debug_trace);
1348 }
1349 
1350 // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
1351 // the invariant that all sets in 'colocated_buffer_sets' are disjoint.
1352 //
1353 // A practical example of when this is necessary is a chain of kCall ops:
1354 //   computation.entry
1355 //     %a = call() -> computation.1
1356 //   computation.1
1357 //     %b = call() -> computation.2
1358 //   computation.2
1359 //     %c = parameter()
1360 // This yields the logical sets {%a,%b} {%b,%c} {%c}, which need to be merged
1361 // into a single set {%a,%b,%c}.
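// For instance, if the sets above arrive in the order {%a,%b}, {%b,%c}, {%c}:
// the first set starts a new group, the second overlaps on %b and is merged to
// give {%a,%b,%c}, and the third overlaps on %c and adds nothing new.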
1362 void BufferAssigner::AddSetToColocatedBufferSets(
1363     const std::vector<const LogicalBuffer*>& colocated_set,
1364     std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
1365   if (colocated_set.empty()) {
1366     return;
1367   }
1368   VLOG(5) << ColocatedBufferSetsToString(colocated_set,
1369                                          "Adding colocated buffer set");
1370   // Find existing sets that overlap with at least one buffer from the
1371   // colocated_set. The resulting 'overlap_set_indices' will have at most
1372   // colocated_buffer_sets->size() entries, and will be in increasing order.
1373   std::vector<size_t> overlap_set_indices;
1374   for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) {
1375     for (const LogicalBuffer* buffer : colocated_set) {
1376       if ((*colocated_buffer_sets)[index].contains(buffer)) {
1377         VLOG(5) << "Found overlap with existing set on buffer "
1378                 << buffer->ToString() << "\n"
1379                 << ColocatedBufferSetsToString((*colocated_buffer_sets)[index],
1380                                                "Overlapping set");
1381         overlap_set_indices.push_back(index);
1382         break;
1383       }
1384     }
1385   }
1386 
1387   // If there is no overlap with existing sets, create a new set.
1388   if (overlap_set_indices.empty()) {
1389     colocated_buffer_sets->emplace_back();
1390     colocated_buffer_sets->back().insert(colocated_set.begin(),
1391                                          colocated_set.end());
1392     VLOG(5) << "No overlap found, new group created";
1393     return;
1394   }
1395 
1396   // Merge all overlap sets and the colocated set into the first overlap set.
1397   ColocatedBufferSet* first = &(*colocated_buffer_sets)[overlap_set_indices[0]];
1398   for (size_t index = 1; index < overlap_set_indices.size(); ++index) {
1399     const ColocatedBufferSet& overlap_set =
1400         (*colocated_buffer_sets)[overlap_set_indices[index]];
1401     first->insert(overlap_set.begin(), overlap_set.end());
1402   }
1403   first->insert(colocated_set.begin(), colocated_set.end());
1404   VLOG(5) << ColocatedBufferSetsToString(
1405       *first, "Result of the colocated buffer set merging");
1406 
1407   // Remove overlap sets that we just merged. The offset accounts for the fact
1408   // that as elements are erased, the indices need to be adjusted. Keep in mind
1409   // that overlap_set_indices is in increasing order.
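  // For example (hypothetical indices): with overlap_set_indices = {0, 2, 4},
  // everything is merged into the set at index 0; the set at index 2 is erased
  // at offset 2 - 1 + 1 = 2, and the set originally at index 4, now shifted
  // left by one, is erased at offset 4 - 2 + 1 = 3.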
1410   for (size_t index = 1; index < overlap_set_indices.size(); ++index) {
1411     const size_t offset = overlap_set_indices[index] - index + 1;
1412     colocated_buffer_sets->erase(colocated_buffer_sets->begin() + offset);
1413   }
1414 }
1415 
1416 std::vector<BufferAssigner::ColocatedBufferSet>
1417 BufferAssigner::MergeColocatedBufferSets(
1418     const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
1419     const BufferLiveness& buffer_liveness,
1420     const LogicalBuffer::SizeFunction& buffer_size) {
1421   VLOG(1) << "colocation sets count before coalescing:"
1422           << colocated_buffer_sets.size();
1423 
1424   // Returns true if the given buffer is a read-only (non-aliased) entry parameter.
1425   auto is_readonly_entry_parameter = [](const LogicalBuffer& buffer) {
1426     auto* instruction = buffer.instruction();
1427     auto* computation = instruction->parent();
1428     auto* module = computation->parent();
1429     return instruction->opcode() == HloOpcode::kParameter &&
1430            computation == module->entry_computation() &&
1431            !module->input_output_alias_config().ParameterHasAlias(
1432                instruction->parameter_number(), buffer.index());
1433   };
1434 
1435   std::vector<bool> set_can_be_merged(colocated_buffer_sets.size(), true);
1436 
1437   // Do not merge if one of the sets includes live outs, entry parameters or
1438   // constants.
1439   //
1440   // Buffer liveness does not report the correct live range for entry
1441   // parameter and live out buffers so we have to special case them here.  On
1442   // backends that support constant buffer allocations, constant buffers are
1443   // assigned globals in readonly storage so we can't merge colocated buffer
1444   // sets containing constants with colocated buffer sets containing writing
1445   // instructions or other constants.
1446   //
1447   // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to
1448   // the caller of the executable so we can't write to entry parameters
1449   // either, and the argument for not merging constants also applies to entry
1450   // parameters.
1451   for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
1452     for (auto& buffer : colocated_buffer_sets[i]) {
1453       if (buffer_liveness.MaybeLiveOut(*buffer) ||
1454           is_readonly_entry_parameter(*buffer) ||
1455           buffer->instruction()->opcode() == HloOpcode::kConstant) {
1456         set_can_be_merged[i] = false;
1457         break;
1458       }
1459     }
1460   }
1461 
1462   // Returns true if the two colocated buffer sets (specified by their indices
1463   // into colocated_buffer_sets) cannot be merged into a single set.
1464   auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness,
1465                                    &buffer_size,
1466                                    &set_can_be_merged](int64 i, int64 j) {
1467     if (!set_can_be_merged[i] || !set_can_be_merged[j]) {
1468       return true;
1469     }
1470 
1471     // Colocated sets satisfy the invariant that all buffers within a set have
1472     // the same size, so it suffices to compare a single representative buffer
1473     // from each set: if the representatives' sizes differ, the two sets cannot
1474     // be merged.
1475     if (buffer_size(**colocated_buffer_sets[i].begin()) !=
1476         buffer_size(**colocated_buffer_sets[j].begin())) {
1477       return true;
1478     }
1479 
1480     // Do not merge if some pair of buffers interferes with each other.
1481     for (auto& buffer_a : colocated_buffer_sets[i]) {
1482       for (auto& buffer_b : colocated_buffer_sets[j]) {
1483         if (buffer_a->id() != buffer_b->id() &&
1484             buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) {
1485           return true;
1486         }
1487       }
1488     }
1489 
1490     return false;
1491   };
1492 
1493   // Build the interference map among the colocated buffer sets (nodes), by
1494   // adding an edge between any two nodes that cannot be merged into a single
1495   // colocated buffer set.
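  // Illustrative example (the sets and their interference are made up): with
  // three sets where only sets 0 and 1 interfere, interference_map is
  // {{1}, {0}, {}}; the greedy coloring can then assign colors {0, 1, 0}, so
  // sets 0 and 2 are merged below while set 1 stays separate.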
1496   std::vector<std::vector<int64>> interference_map(
1497       colocated_buffer_sets.size());
1498   for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
1499     for (int64 j = i + 1; j < colocated_buffer_sets.size(); ++j) {
1500       if (cannot_merge_buffer_sets(i, j)) {
1501         interference_map[i].push_back(j);
1502         interference_map[j].push_back(i);
1503       }
1504     }
1505   }
1506 
1507   // Assign a color to each colocation set in colocated_buffer_sets, such that
1508   // the sets that can be merged are assigned the same color.
1509   auto assigned_colors = ColorInterferenceGraph(interference_map);
1510 
1511   // Merge the buffer sets with the same color.
1512   CHECK(!assigned_colors.empty());
1513   int64 num_sets =
1514       *std::max_element(assigned_colors.begin(), assigned_colors.end()) + 1;
1515   std::vector<ColocatedBufferSet> new_colocated_buffer_sets(num_sets);
1516   for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
1517     const auto& buffer_set = colocated_buffer_sets[i];
1518     new_colocated_buffer_sets[assigned_colors[i]].insert(buffer_set.begin(),
1519                                                          buffer_set.end());
1520   }
1521 
1522   VLOG(1) << "colocation sets count after coalescing:"
1523           << new_colocated_buffer_sets.size();
1524   return new_colocated_buffer_sets;
1525 }
1526 
1527 // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
1528 // in the same allocation (currently just supports kWhile, kCall,
1529 // kConditional, and input/output aliasing).
1530 void BufferAssigner::BuildColocatedBufferSets(
1531     const HloModule* module, const BufferLiveness& buffer_liveness,
1532     const LogicalBuffer::SizeFunction& buffer_size,
1533     std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
1534   const TuplePointsToAnalysis& points_to_analysis =
1535       buffer_liveness.points_to_analysis();
1536 
1537   // Set up colocated buffer set for input and output.
1538   VLOG(4) << "Input/Output Alias Config: ";
1539   VLOG(4) << module->input_output_alias_config();
1540   module->input_output_alias_config().ForEachAlias(
1541       [&](const ShapeIndex& output_index,
1542           const HloInputOutputAliasConfig::Alias& alias) {
1543         std::vector<const LogicalBuffer*> colocated_set;
1544         AddBufferToColocatedSet(module->entry_computation()->root_instruction(),
1545                                 output_index, points_to_analysis,
1546                                 &colocated_set);
1547         AddBufferToColocatedSet(
1548             module->entry_computation()->parameter_instruction(
1549                 alias.parameter_number),
1550             alias.parameter_index, points_to_analysis, &colocated_set);
1551         AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
1552       });
1553 
1554   for (const HloComputation* computation : module->MakeComputationPostOrder()) {
1555     if (computation->IsFusionComputation()) {
1556       continue;
1557     }
1558     for (const HloInstruction* instruction :
1559          computation->MakeInstructionPostOrder()) {
1560       const HloOpcode opcode = instruction->opcode();
1561       if (opcode == HloOpcode::kWhile) {
1562         const HloInstruction* while_hlo = instruction;
1563         ShapeUtil::ForEachSubshape(
1564             while_hlo->shape(),
1565             [this, while_hlo, &points_to_analysis, buffer_size,
1566              colocated_buffer_sets](const Shape& /*subshape*/,
1567                                     const ShapeIndex& index) {
1568               std::vector<const LogicalBuffer*> colocated_set;
1569               // Add while.init.
1570               AddBufferToColocatedSet(while_hlo->operand(0), index,
1571                                       points_to_analysis, &colocated_set);
1572               // Add while.result.
1573               AddBufferToColocatedSet(while_hlo, index, points_to_analysis,
1574                                       &colocated_set);
1575               // Add while.cond.parameter.
1576               AddBufferToColocatedSet(
1577                   while_hlo->while_condition()->parameter_instruction(0), index,
1578                   points_to_analysis, &colocated_set);
1579               // Add while.body.parameter.
1580               AddBufferToColocatedSet(
1581                   while_hlo->while_body()->parameter_instruction(0), index,
1582                   points_to_analysis, &colocated_set);
1583               // Add while.body.root.
1584               AddBufferToColocatedSet(
1585                   while_hlo->while_body()->root_instruction(), index,
1586                   points_to_analysis, &colocated_set);
1587               AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
1588             });
1589       } else if (opcode == HloOpcode::kCall) {
1590         const HloInstruction* call_hlo = instruction;
1591         const HloComputation* callee = call_hlo->to_apply();
1592         const HloInstruction* root_hlo = callee->root_instruction();
1593         for (int64 i = 0; i < call_hlo->operand_count(); i++) {
1594           const HloInstruction* call_param = callee->parameter_instruction(i);
1595           const HloInstruction* call_operand = call_hlo->operand(i);
1596           ShapeUtil::ForEachSubshape(
1597               call_operand->shape(),
1598               [&](const Shape& /*subshape*/, const ShapeIndex& index) {
1599                 std::vector<const LogicalBuffer*> colocated_set;
1600                 AddBufferToColocatedSet(call_param, index, points_to_analysis,
1601                                         &colocated_set);
1602                 AddBufferToColocatedSet(call_operand, index, points_to_analysis,
1603                                         &colocated_set);
1604                 AddSetToColocatedBufferSets(colocated_set,
1605                                             colocated_buffer_sets);
1606               });
1607         }
1608         ShapeUtil::ForEachSubshape(
1609             call_hlo->shape(),
1610             [this, call_hlo, root_hlo, &points_to_analysis,
1611              colocated_buffer_sets](const Shape& /*subshape*/,
1612                                     const ShapeIndex& index) {
1613               std::vector<const LogicalBuffer*> colocated_set;
1614               // Add call.result.
1615               AddBufferToColocatedSet(call_hlo, index, points_to_analysis,
1616                                       &colocated_set);
1617               // Add call.subcomputation.root.
1618               AddBufferToColocatedSet(root_hlo, index, points_to_analysis,
1619                                       &colocated_set);
1620               AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
1621             });
1622       } else if (opcode == HloOpcode::kConditional) {
1623         const HloInstruction* conditional = instruction;
1624         ShapeUtil::ForEachSubshape(
1625             conditional->shape(),
1626             [this, conditional, &points_to_analysis, colocated_buffer_sets](
1627                 const Shape& /*subshape*/, const ShapeIndex& index) {
1628               std::vector<const LogicalBuffer*> colocated_set;
1629               // Add cond.result.
1630               AddBufferToColocatedSet(conditional, index, points_to_analysis,
1631                                       &colocated_set);
1632               for (int j = 0; j < conditional->branch_count(); ++j) {
1633                 // Add each cond.branch_computation[j].root.
1634                 AddBufferToColocatedSet(
1635                     conditional->branch_computation(j)->root_instruction(),
1636                     index, points_to_analysis, &colocated_set);
1637               }
1638               AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
1639             });
1640 
1641         for (int j = 0; j < conditional->branch_count(); ++j) {
1642           // Add branch_operand[j] (which is operand[j+1]) and
1643           // cond.branch_computation[j].parameter(0) as a colocated
1644           // buffer set. Note that this has to be done for each subshape in the
1645           // branch_operand of the case.
1646           ShapeUtil::ForEachSubshape(
1647               conditional->operand(j + 1)->shape(),
1648               [this, j, conditional, &points_to_analysis,
1649                colocated_buffer_sets](const Shape& /*subshape*/,
1650                                       const ShapeIndex& index) {
1651                 std::vector<const LogicalBuffer*> branch_set;
1652                 // Add cond.operand[j+1].
1653                 AddBufferToColocatedSet(conditional->operand(j + 1), index,
1654                                         points_to_analysis, &branch_set);
1655                 // Add cond.branch_computation[j].parameter_instruction(0).
1656                 AddBufferToColocatedSet(
1657                     conditional->branch_computation(j)->parameter_instruction(
1658                         0),
1659                     index, points_to_analysis, &branch_set);
1660                 AddSetToColocatedBufferSets(branch_set, colocated_buffer_sets);
1661               });
1662         }
1663       }
1664     }
1665   }
1666 
1667   if (colocated_buffer_sets->empty()) {
1668     return;
1669   }
1670 
1671   int64 i = 0;
1672   for (const auto& colocated_set : *colocated_buffer_sets) {
1673     VLOG(4) << "Colocated set " << i++ << ":";
1674     for (const auto& buffer : colocated_set) {
1675       VLOG(4) << "  " << buffer->ToString();
1676     }
1677   }
1678   // Try to find more coalescing opportunities among the colocated buffer sets.
1679   //
1680   // TODO(b/32491382): We should be able to remove this by using the
1681   // module-level liveness analysis, which would let us directly detect buffer
1682   // sharing opportunities between the while instruction buffer and the buffers
1683   // from the predicate and body computation, as well as sharing across
1684   // different while instructions.
1685   std::vector<ColocatedBufferSet> new_colocated_buffer_sets =
1686       MergeColocatedBufferSets(*colocated_buffer_sets, buffer_liveness,
1687                                buffer_size);
1688   std::swap(*colocated_buffer_sets, new_colocated_buffer_sets);
1689 }
1690 
1691 // Assigns the buffers within each colocated buffer set in
1692 // 'colocated_buffer_sets' to a single shared allocation in 'assignment'.
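// For example, for a kWhile instruction the colocated set built in
// BuildColocatedBufferSets ({while.init, while.result, while.cond.parameter,
// while.body.parameter, while.body.root} at a given shape index) ends up in a
// single shared allocation, with every buffer assigned at offset 0.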
1693 void BufferAssigner::AssignColocatedBufferSets(
1694     const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
1695     BufferAssignment* assignment,
1696     flat_hash_set<const LogicalBuffer*>* colocated_buffers,
1697     flat_hash_set<BufferAllocation::Index>* colocated_allocations) {
1698   for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) {
1699     BufferAllocation* allocation = nullptr;
1700     // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if an
1701     // entry parameter is present in 'colocated_buffer_set'.
1702     int64 entry_parameter_number = -1;
1703     const ShapeIndex* entry_parameter_shape_idx = nullptr;
1704     bool is_constant = false;
1705     for (const LogicalBuffer* buffer : colocated_buffer_set) {
1706       const HloInstruction* instruction = buffer->instruction();
1707       const HloComputation* computation = instruction->parent();
1708       if (instruction->opcode() == HloOpcode::kParameter &&
1709           computation == computation->parent()->entry_computation()) {
1710         entry_parameter_number = instruction->parameter_number();
1711         entry_parameter_shape_idx = &buffer->index();
1712       } else if (instruction->opcode() == HloOpcode::kConstant) {
1713         is_constant = true;
1714       }
1715     }
1716 
1717     CHECK(!is_constant || entry_parameter_number == -1)
1718         << "Copy insertion should have inserted copies to prevent this.";
1719 
1720     for (const LogicalBuffer* buffer : colocated_buffer_set) {
1721       const int64 buffer_size = assignment->buffer_size_(*buffer);
1722       if (allocation == nullptr) {
1723         // TODO(b/32491382) Avoid current trivial solution of using new
1724         // allocations for each colocated buffer set. When liveness has
1725         // module-level scope, we can allow buffers to be shared across
1726         // computations (in some cases).
1727         allocation = assignment->NewAllocation(*buffer, buffer_size);
1728         if (is_constant) {
1729           allocation->set_constant(true);
1730         }
1731         colocated_allocations->insert(allocation->index());
1732       } else {
1733         CHECK_EQ(buffer_size, allocation->size())
1734             << "Buffer: " << *buffer << " size mismatch in colocated buffer "
1735             << "allocation: " << *allocation;
1736         assignment->AddAssignment(allocation, *buffer, /*offset=*/0,
1737                                   buffer_size);
1738       }
1739       colocated_buffers->insert(buffer);
1740     }
1741 
1742     // If an allocation contains a parameter, set corresponding fields.
1743     if (entry_parameter_number >= 0) {
1744       bool parameter_has_alias =
1745           assignment->module().input_output_alias_config().ParameterHasAlias(
1746               entry_parameter_number, *entry_parameter_shape_idx);
1747       allocation->set_entry_computation_parameter(entry_parameter_number,
1748                                                   *entry_parameter_shape_idx,
1749                                                   parameter_has_alias);
1750     }
1751   }
1752 }
1753 
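// Usage sketch (illustrative only; callers normally reach this through
// BufferAssigner::Run, and the pointer size and alignment below are
// assumptions, not values mandated by this file):
//   auto buffer_size = [](const LogicalBuffer& b) {
//     return ShapeUtil::ByteSizeOf(b.shape(), /*pointer_size=*/8);
//   };
//   auto color_alignment = [](LogicalBuffer::Color) { return int64{64}; };
// These callbacks, together with an HloOrdering, drive the assignment below.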
1754 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
1755     const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
1756     LogicalBuffer::SizeFunction buffer_size,
1757     LogicalBuffer::AlignmentFunction color_alignment) {
1758   TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness,
1759                       BufferLiveness::Run(module, std::move(hlo_ordering)));
1760 
1761   VLOG(1) << "Assigning buffers to module " << module->name();
1762   XLA_VLOG_LINES(2, module->ToString());
1763   XLA_VLOG_LINES(3, liveness->ToString());
1764   XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString());
1765 
1766   // Can't use absl::make_unique because BufferAssignment constructor is
1767   // private.
1768   std::unique_ptr<BufferAssignment> assignment(
1769       new BufferAssignment(module, std::move(liveness), std::move(buffer_size),
1770                            std::move(color_alignment)));
1771 
1772   // Assign buffers with the tightest constraints first (colocated buffer sets).
1773   // Once b/32491382 enables module-level liveness analysis, we may be able
1774   // to assign colocated buffers (or at least reuse their allocation for
1775   // buffers outside of the set) in AssignBuffersForComputation.
1776   flat_hash_set<const LogicalBuffer*> colocated_buffers;
1777   flat_hash_set<BufferAllocation::Index> colocated_allocations;
1778   std::vector<ColocatedBufferSet> colocated_buffer_sets;
1779   BuildColocatedBufferSets(module, assignment->liveness(),
1780                            assignment->buffer_size_, &colocated_buffer_sets);
1781   TF_RETURN_IF_ERROR(colorer_(assignment->liveness()));
1782   VLOG(3) << "After coloring:";
1783   XLA_VLOG_LINES(3, assignment->points_to_analysis().ToString());
1784 
1785   AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(),
1786                             &colocated_buffers, &colocated_allocations);
1787 
1788   std::vector<const HloComputation*> thread_local_computations;
1789   std::vector<const HloComputation*> global_computations;
1790   TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
1791       module, &thread_local_computations, &global_computations));
1792 
1793   // First assign buffers for global computations. Temporary buffers for
1794   // sequential computations are collected in 'buffers_to_assign_sequentially'.
1795   flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>
1796       buffers_to_assign_sequentially;
1797   for (auto* computation : global_computations) {
1798     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
1799         computation,
1800         /*is_thread_local=*/false, colocated_buffers, colocated_allocations,
1801         &buffers_to_assign_sequentially, assignment.get()));
1802   }
1803   // Assign buffers with sequential ordering, if any. If all global computations
1804   // are sequential, we can run heap simulation on the whole module, which
1805   // reduces memory usage.
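  // (A computation gets an entry in 'buffers_to_assign_sequentially' only if
  // it has a sequential order (see AssignBuffersForComputation earlier in this
  // file), so the sizes match exactly when every global computation is
  // sequentially ordered.)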
1806   const bool run_whole_module_heap_simulation =
1807       buffers_to_assign_sequentially.size() == global_computations.size();
1808   TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
1809       buffers_to_assign_sequentially, run_whole_module_heap_simulation,
1810       assignment.get()));
1811 
1812   // Now assign buffers for thread-local computations. All LogicalBuffers get
1813   // their own BufferAllocation.
1814   for (auto* computation : thread_local_computations) {
1815     TF_RET_CHECK(computation != module->entry_computation());
1816     if (computation->IsFusionComputation()) {
1817       continue;
1818     }
1819     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
1820         computation,
1821         /*is_thread_local=*/true, colocated_buffers, colocated_allocations,
1822         /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));
1823   }
1824 
1825   // Mark all buffers which may be live out of the entry computation as
1826   // "liveout".
1827   for (const LogicalBuffer* buffer :
1828        assignment->liveness().maybe_live_out_buffers()) {
1829     VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
1830     if (assignment->HasAllocation(*buffer)) {
1831       BufferAllocation* alloc =
1832           assignment->GetMutableAssignedAllocation(*buffer);
1833       alloc->set_maybe_live_out(true);
1834       VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
1835     }
1836   }
1837 
1838   // Combines allocations of temporary buffers into one big BufferAllocation.
1839   // This can only be performed after all buffers have been assigned, and after
1840   // maybe_live_out is marked, since it is used to determine whether an
1841   // allocation contains temporary buffers or not.
1842   assignment->CombineTempAllocations();
1843 
1844   XLA_VLOG_LINES(2, assignment->ToString());
1845   TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
1846   XLA_VLOG_LINES(1, assignment->GetStats().ToString());
1847   return std::move(assignment);
1848 }
1849 
1850 }  // namespace xla
1851