/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Defines the data returned by the XLA buffer assignment packages.

#include "tensorflow/compiler/xla/service/buffer_assignment.h"

#include <algorithm>
#include <deque>
#include <numeric>
#include <ostream>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"

namespace xla {
namespace {

using absl::flat_hash_map;
using absl::flat_hash_set;
using absl::StrAppend;
using absl::StrAppendFormat;
using memory_space_assignment::PresetAssignments;
using ::tensorflow::strings::HumanReadableNumBytes;

// Given the interference map of a graph (the list of interfering node indices
// for each node), perform graph coloring such that interfering nodes are
// assigned different colors. Returns the assigned color of each node, where
// the colors are represented as integer values [0, color_count).
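//
// A minimal illustrative example (not from the original source): for
// interference_map = {{1, 2}, {0}, {0}}, node 0 has the most interference and
// is colored first with color 0; nodes 1 and 2 interfere only with node 0, so
// each greedily receives color 1, giving assigned_colors = {0, 1, 1}.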
std::vector<int64> ColorInterferenceGraph(
    const std::vector<std::vector<int64>>& interference_map) {
  const int64_t node_count = interference_map.size();

  // Sort the nodes such that we assign nodes with more interference first.
  // This relies on the common heuristic of assigning the most constrained node
  // first, but it would be good to investigate other ordering heuristics too.
  std::vector<int64> nodes(node_count);
  std::iota(nodes.begin(), nodes.end(), 0);
  absl::c_sort(nodes, [&interference_map](const int64_t i, const int64_t j) {
    return interference_map[i].size() > interference_map[j].size();
  });

  const int64_t kColorUnassigned = -1;
  std::vector<int64> assigned_colors(node_count, kColorUnassigned);
  for (int64_t node : nodes) {
    // Mark the colors that are already assigned to the neighbors.
    std::vector<bool> available_colors(node_count, true);
    for (int64_t neighbor : interference_map[node]) {
      int64_t color = assigned_colors[neighbor];
      if (color != kColorUnassigned) {
        available_colors[color] = false;
      }
    }

    // Find the color that is not yet assigned to the neighbors.
    int64_t color = kColorUnassigned;
    for (color = 0; color < available_colors.size(); ++color) {
      if (available_colors[color]) {
        break;
      }
    }
    CHECK_NE(color, kColorUnassigned);
    assigned_colors[node] = color;
  }
  return assigned_colors;
}

// Returns true if the buffer holds a constant, or an entry parameter that is
// not aliased with a module output; such buffers are read-only.
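// For example (illustrative), if the module's input_output_alias_config maps
// a module output to entry parameter 0, the buffer holding that parameter may
// be written in place and is therefore not read-only.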
bool HloBufferIsReadOnly(const HloBuffer& buffer) {
  for (const HloValue* value : buffer.values()) {
    const HloInstruction* instruction = value->instruction();
    if (instruction->opcode() == HloOpcode::kConstant) {
      return true;
    }
    const HloModule* module = instruction->parent()->parent();
    const bool is_entry_parameter =
        instruction->opcode() == HloOpcode::kParameter &&
        instruction->parent() == module->entry_computation();

    if (is_entry_parameter) {
      bool parameter_has_alias =
          module->input_output_alias_config().ParameterHasAlias(
              instruction->parameter_number(), value->index());
      // The parameter doesn't have an alias, so it must be read-only.
      if (!parameter_has_alias) {
        return true;
      }
    }
  }
  return false;
}

}  // namespace

Status GatherComputationsByAllocationType(
    const HloModule* module,
    std::vector<const HloComputation*>* thread_local_computations,
    std::vector<const HloComputation*>* global_computations) {
  // Create a worklist of computations paired with whether the allocation must
  // be thread-local.
  std::deque<std::pair<const HloComputation*, bool>> worklist;
  worklist.push_back(std::make_pair(module->entry_computation(),
                                    /*is_thread_local=*/false));

  // Sets for quickly checking membership. Computations are returned in vectors
  // for stable iteration.
  flat_hash_set<const HloComputation*> thread_local_set;
  flat_hash_set<const HloComputation*> global_set;

  while (!worklist.empty()) {
    auto worklist_front = worklist.front();
    worklist.pop_front();
    const HloComputation* computation = worklist_front.first;
    bool is_thread_local = worklist_front.second;
    bool in_thread_local_set = thread_local_set.contains(computation);
    bool in_global_set = global_set.contains(computation);

    // If the computation has already been added to the respective set, then
    // nothing to do.
    if ((is_thread_local && in_thread_local_set) ||
        (!is_thread_local && in_global_set)) {
      continue;
    }

    // If the computation has already been added to the other set, this is an
    // error condition, because the global call to the computation (e.g.,
    // while/call) may return a reference to one of the thread-local buffers to
    // the calling computation, which would become a dangling reference when
    // the thread-local buffer is deallocated with the call return.
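    // (Illustrative example: a computation reached both through a kWhile body,
    // which requires global allocations, and through a kMap operand, which
    // requires thread-local allocations, would trip this check.)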
    if ((is_thread_local && in_global_set) ||
        (!is_thread_local && in_thread_local_set)) {
      return InvalidArgument(
          "computation %s has conflicting allocation requirements (global "
          "and thread-local)",
          computation->name());
    }

    if (is_thread_local) {
      thread_local_set.insert(computation);
    } else {
      global_set.insert(computation);
    }

    for (auto* instruction : computation->instructions()) {
      for (HloComputation* subcomputation :
           instruction->called_computations()) {
        switch (instruction->opcode()) {
          case HloOpcode::kCall:
          case HloOpcode::kConditional:
          case HloOpcode::kWhile:
            // Call and while must be called from a computation with global
            // allocations as they may return references to buffers inside the
            // called computation which cannot be thread-local.
            if (is_thread_local) {
              return InvalidArgument(
                  "computation %s cannot contain call/while op because it "
                  "requires thread-local buffer allocations",
                  computation->name());
            }
            worklist.push_back(std::make_pair(subcomputation,
                                              false));  // Not thread local.
            break;
          case HloOpcode::kCustomCall:
          case HloOpcode::kAllReduce:
          case HloOpcode::kReduceScatter:
          case HloOpcode::kAllReduceStart:
          case HloOpcode::kMap:
          case HloOpcode::kReduce:
          case HloOpcode::kReduceWindow:
          case HloOpcode::kScatter:
          case HloOpcode::kSelectAndScatter:
          case HloOpcode::kSort:
          case HloOpcode::kFusion:
            // Map/reduce etc. computations are always thread-local.
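            // (E.g. a kReduce's reducer computation may run once per output
            // element, potentially in parallel, so its buffers cannot be
            // allocated globally.)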
            worklist.push_back(std::make_pair(subcomputation,
                                              true));  // Thread local.
            break;
          default:
            return InternalError("Unexpected calling opcode: %s",
                                 HloOpcodeString(instruction->opcode()));
        }
      }
    }
  }

  // Add the computations to the vectors in post order.
  for (auto* computation : module->MakeComputationPostOrder()) {
    if (thread_local_set.contains(computation)) {
      thread_local_computations->push_back(computation);
    } else if (global_set.contains(computation)) {
      global_computations->push_back(computation);
    }
    // If the computation is not reachable from the entry computation, then it
    // will not appear in either thread_local_set or global_set. We don't bother
    // assigning buffers for these.
  }
  return Status::OK();
}

string BufferAllocation::Slice::ToString() const {
  return absl::StrCat("{index:", index(), ", offset:", offset_,
                      ", size:", size_, "}");
}
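// Illustrative output (hypothetical values): a slice at offset 64 of size 128
// within allocation 3 prints as "{index:3, offset:64, size:128}".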

BufferAllocation::Slice BufferAllocation::GetSlice(
    const HloValue& buffer) const {
  const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
  return Slice(this, os.offset, os.size);
}

void BufferAllocation::AddAssignment(const HloValue& buffer, int64_t offset,
                                     int64_t size) {
  VLOG(4) << "Adding the following buffer to allocation #" << index()
          << absl::StrFormat(" (size=%d, offset=%d) %s", size, offset,
                             buffer.ToShortString());
  CHECK(!assigned_buffers_.contains(&buffer))
      << "LogicalBuffer " << buffer << " already assigned to allocation "
      << index_;
  CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
                          << " offset out of range";
  CHECK_LE(offset + size, size_)
      << "LogicalBuffer " << buffer
      << " size out of range at offset: " << offset << " with size: " << size;
  CHECK_EQ(buffer.color(), color())
      << "Buffer color " << buffer.color() << " for buffer " << buffer
      << " does not match allocation color " << color() << ".";
  OffsetSize offset_size;
  offset_size.offset = offset;
  offset_size.size = size;
  assigned_buffers_.emplace(&buffer, offset_size);
  // For debugging purposes, store the assigned memory space in the
  // instruction's layout.
  for (HloPosition position : buffer.positions()) {
    Shape* shape = ShapeUtil::GetMutableSubshape(
        position.instruction->mutable_shape(), position.index);
    if (shape->has_layout()) {
      shape->mutable_layout()->set_memory_space(buffer.color());
    }
  }
}

BufferAllocationProto BufferAllocation::ToProto() const {
  BufferAllocationProto proto;
  proto.set_index(index_);
  proto.set_size(size_);
  proto.set_is_thread_local(is_thread_local_);
  proto.set_is_tuple(is_tuple_);
  proto.set_color(color_);
  if (is_entry_computation_parameter_) {
    proto.set_is_entry_computation_parameter(true);
    for (int64_t idx : param_shape_index()) {
      proto.add_parameter_shape_index(idx);
    }
    proto.set_parameter_number(parameter_number_);
  }
  proto.set_is_constant(is_constant_);
  proto.set_maybe_live_out(maybe_live_out_);
  for (const auto& buffer_offset_size : assigned_buffers_) {
    BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
    proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
    proto_assigned->set_offset(buffer_offset_size.second.offset);
    proto_assigned->set_size(buffer_offset_size.second.size);
  }
  absl::c_sort(*proto.mutable_assigned(),
               [](const BufferAllocationProto::Assigned& assign1,
                  const BufferAllocationProto::Assigned& assign2) {
                 return assign1.logical_buffer_id() <
                        assign2.logical_buffer_id();
               });
  return proto;
}

static bool CompareHloValuesById(const HloValue* a, const HloValue* b) {
  return a->id() < b->id();
}

// Returns the entry parameter instruction corresponding to the allocation, or
// nullptr if there is none.
static const HloInstruction* GetEntryParameterInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    const HloInstruction* instr = value->instruction();
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parent() == instr->parent()->parent()->entry_computation()) {
      return instr;
    }
  }
  return nullptr;
}

// Returns the root output instruction of the entry computation corresponding
// to the allocation, or nullptr if there is none.
static const HloInstruction* GetOutputInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    for (const HloPosition& position : value->positions()) {
      const HloInstruction* instr = position.instruction;
      if (position.index.empty() &&
          instr->parent()->root_instruction() == instr &&
          instr->parent()->IsEntryComputation()) {
        return instr;
      }
    }
  }
  return nullptr;
}

string BufferAllocation::ToString() const {
  string output;
  StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
  if (color() != 0) {
    StrAppend(&output, ", color ", color());
  }
  if (is_entry_computation_parameter()) {
    const HloInstruction* param = GetEntryParameterInstruction(*this);
    StrAppend(&output, ", parameter ", parameter_number(), ", shape |",
              param ? param->shape().ToString(/*print_layout=*/false)
                    : "<unknown shape>",
              "| at ShapeIndex ", param_shape_index().ToString());
  }
  if (const HloInstruction* instr = GetOutputInstruction(*this)) {
    StrAppend(&output, ", output shape is |",
              instr->shape().ToString(/*print_layout=*/false), "|");
  }
  if (is_constant()) {
    StrAppend(&output, ", constant");
  }
  if (is_thread_local()) {
    StrAppend(&output, ", thread-local");
  }
  if (maybe_live_out()) {
    StrAppend(&output, ", maybe-live-out");
  }
  if (IsPreallocatedTempBuffer()) {
    StrAppend(&output, ", preallocated-temp");
  }
  StrAppend(&output, ":\n");
  // Dump the assigned buffers ordered by id.
  std::vector<const HloValue*> sorted_buffers;
  for (const auto& buffer_offset_size : assigned_buffers_) {
    sorted_buffers.push_back(buffer_offset_size.first);
  }
  absl::c_sort(sorted_buffers, &CompareHloValuesById);
  for (const HloValue* buffer : sorted_buffers) {
    const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
    StrAppend(&output,
              absl::StrFormat(
                  " value: %s (size=%d,offset=%d): %s\n",
                  buffer->ToShortString(), offset_size.size, offset_size.offset,
                  ShapeUtil::HumanStringWithLayout(buffer->shape())));
  }
  return output;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
  out << buffer.ToString();
  return out;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
  out << s.ToString();
  return out;
}

bool BufferAssignment::HasAllocation(const HloValue& value) const {
  return allocation_index_for_value_.contains(&value);
}

bool BufferAssignment::HasAllocation(const HloBuffer& buffer) const {
  return allocation_index_for_value_.contains(buffer.values()[0]);
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloValue& value) const {
  CHECK(HasAllocation(value));
  return GetAllocation(allocation_index_for_value_.at(&value));
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloBuffer& hlo_buffer) const {
  return GetAssignedAllocation(*hlo_buffer.values()[0]);
}

BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
    const HloBuffer& buffer) {
  return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
}

std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  std::set<BufferAllocation::Slice> result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (HasAllocation(*value)) {
      result.insert(GetAssignedAllocation(*value).GetSlice(*value));
    }
  }
  return result;
}

const BufferAllocation& BufferAssignment::GetAllocation(
    BufferAllocation::Index index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, allocations_.size());
  return allocations_[index];
}

const BufferAllocation* BufferAssignment::GetInstructionAllocation(
    const HloInstruction* hlo, const ShapeIndex& shape_index) const {
  const HloValue* value =
      dataflow_analysis().GetValueSet(hlo, shape_index).values()[0];

  if (!HasAllocation(*value)) {
    return nullptr;
  }

  const BufferAllocation& instruction_allocation =
      GetAssignedAllocation(*value);
  return &instruction_allocation;
}

BufferAllocation* BufferAssignment::GetMutableAllocation(
    BufferAllocation::Index index) {
  return const_cast<BufferAllocation*>(&GetAllocation(index));
}

bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
                                       const ShapeIndex& index) const {
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (allocation_index_for_value_.contains(value)) {
      return true;
    }
  }
  return false;
}

bool BufferAssignment::HasTopLevelAllocation(
    const HloInstruction* instruction) const {
  return HasAllocationAt(instruction, /*index=*/{});
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
          << index << "]";
  BufferAllocation::Slice result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    VLOG(3) << "Examining value " << *value;
    if (HasAllocation(*value)) {
      VLOG(3) << "Has allocation";
      const BufferAllocation::Slice slice =
          GetAssignedAllocation(*value).GetSlice(*value);
      if (result.allocation() == nullptr) {
        result = slice;
      } else if (result != slice) {
        return FailedPrecondition(
            "BufferAllocation::Slice for instruction %s at index %s cannot "
            "be determined at compile-time.",
            instruction->name(), index.ToString());
      }
    } else {
      VLOG(3) << "No allocation";
    }
  }
  if (result.allocation() == nullptr) {
    return FailedPrecondition(
        "BufferAllocation::Slice not assigned for instruction %s at index %s",
        instruction->name(), index.ToString());
  }
  return result;
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
    const HloInstruction* instruction) const {
  return GetUniqueSlice(instruction, /*index=*/{});
}

bool BufferAssignment::SharesSliceAtIndex(
    const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
    const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
  return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
         GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
}

bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
                                          const HloInstruction* hlo_b) const {
  using SliceSet = flat_hash_set<BufferAllocation::Slice>;
  // Gets the slices of all of instr's subshapes. If any subshape doesn't have
  // an assigned slice, returns the empty set.
  auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
    SliceSet slices;
    Status status = ShapeUtil::ForEachSubshapeWithStatus(
        instr->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          auto shape_slices = GetAllSlices(instr, index);
          if (shape_slices.empty()) {
            return InvalidArgument("No slices assigned to part of instr.");
          }
          slices.insert(shape_slices.begin(), shape_slices.end());
          return Status::OK();
        });
    if (!status.ok()) {
      return {};
    }
    return slices;
  };

  SliceSet slices_a = collect_slices(hlo_a);
  SliceSet slices_b = collect_slices(hlo_b);
  // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
  // didn't return the empty set) for both HLOs, and the two resulting sets of
  // slices are disjoint.
  return !slices_a.empty() && !slices_b.empty() &&
         absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
           return slices_b.contains(slice);
         });
}

StatusOr<BufferAllocation::Slice>
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
  return GetUniqueTopLevelSlice(
      module_->entry_computation()->root_instruction());
}

BufferAllocation* BufferAssignment::NewEmptyAllocation(
    int64_t size, LogicalBuffer::Color color) {
  BufferAllocation::Index index = allocations_.size();
  allocations_.emplace_back(index, size, color);
  BufferAllocation* allocation = &allocations_.back();
  return allocation;
}

BufferAllocation* BufferAssignment::NewAllocation(const HloBuffer& buffer,
                                                  int64_t size) {
  BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
  AddAssignment(allocation, buffer, /*offset=*/0, size);
  allocation->peak_buffers_.push_back(buffer.values()[0]);
  return allocation;
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloBuffer& buffer, int64_t offset,
                                     int64_t size) {
  CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
      << "Non-reusable allocation already assigned a buffer: "
      << allocation->ToString();

  for (const HloValue* buffer_value : buffer.values()) {
    CHECK(!allocation_index_for_value_.contains(buffer_value))
        << "BufferValue " << buffer_value << " already has an allocation.";
    allocation->AddAssignment(*buffer_value, offset, size);
    allocation_index_for_value_[buffer_value] = allocation->index();
  }

  if (alias_analysis().BufferLivesOut(buffer)) {
    VLOG(3) << "HloBuffer lives out: " << buffer.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloValue& value, int64_t offset,
                                     int64_t size) {
  allocation->AddAssignment(value, offset, size);
  allocation_index_for_value_[&value] = allocation->index();
  const HloValue& hlo_value =
      *CHECK_NOTNULL(dynamic_cast<const HloValue*>(&value));
  if (alias_analysis().ValueLivesOut(hlo_value)) {
    VLOG(3) << "HloValue lives out: " << hlo_value.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

// Combines allocations of temporary buffers of the same color into one big
// BufferAllocation.
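// Illustrative example (values hypothetical): combining temp allocations of
// sizes 96 and 100 whose color has 64-byte alignment yields a single
// allocation of size RoundUpToNearest(96, 64) + 100 = 128 + 100 = 228, with
// the second allocation's buffers re-assigned at base offset 128.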
void BufferAssignment::CombineTempAllocations() {
  VLOG(1) << "CombineTempAllocations()";
  // Stores the combined allocations.
  std::deque<BufferAllocation> combined_allocations;
  // Holds the pointer to a combined allocation of each color, if any.
  flat_hash_map<BufferValue::Color, BufferAllocation*> combined_allocation_map;

  // Move all temp allocations into a single run at the end of the allocations
  // vector.
  const auto first_temp_it =
      std::partition(allocations_.begin(), allocations_.end(),
                     [](const BufferAllocation& allocation) {
                       return !allocation.IsPreallocatedTempBuffer();
                     });

  // Walk over the run of temp allocations, collecting the allocations
  // belonging to the same color.
  if (first_temp_it != allocations_.end()) {
    for (auto it = first_temp_it; it != allocations_.end(); ++it) {
      BufferAllocation& temp_allocation = *it;
      BufferValue::Color color = temp_allocation.color();
      auto combined_it = combined_allocation_map.find(color);
      if (combined_it == combined_allocation_map.end()) {
        // We have found the first temp allocation of this color. Collect
        // the other temp allocations of the same color into it subject to the
        // size constraint.
        VLOG(1) << "Combined temp allocation for color " << color
                << " is: " << temp_allocation;
        combined_allocations.emplace_back(temp_allocation);
        combined_allocation_map.emplace(color, &combined_allocations.back());
        continue;
      }
      if (combined_it->second->size() + it->size() >=
          multiheap_size_constraint_per_heap_) {
        // We cannot put more into the current combined allocation, so start a
        // new one for this color.
        VLOG(1) << "Due to size constraint, reset temp allocation for color "
                << color << " to: " << temp_allocation;
        combined_allocations.emplace_back(temp_allocation);
        combined_allocation_map.emplace(color, &combined_allocations.back());
        continue;
      }

      BufferAllocation* combined_allocation = combined_it->second;
      VLOG(1) << "Combined allocation absorbing temp allocation: "
              << temp_allocation;

      // Each temp allocation is placed end-to-end, accounting for alignment.
      // The offset of each buffer in the combined allocation is computed from
      // the base offset of the allocation.
      int64_t alignment = color_alignment_(color);
      const int64_t base =
          RoundUpToNearest(combined_allocation->size(), alignment);
      combined_allocation->set_size(base + temp_allocation.size());
      for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) {
        const HloValue* value = buffer_offset_size.first;
        const int64_t offset = buffer_offset_size.second.offset;
        const int64_t size = buffer_offset_size.second.size;
        combined_allocation->AddAssignment(*value, base + offset, size);
      }
      if (!temp_allocation.HeapTraces().empty()) {
        CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
        combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
      }

      combined_allocation->peak_buffers_.insert(
          combined_allocation->peak_buffers_.end(),
          temp_allocation.peak_buffers_.begin(),
          temp_allocation.peak_buffers_.end());
    }
    // Replace all existing temporary allocations with the new combined
    // allocations.
    allocations_.erase(first_temp_it, allocations_.end());
    for (BufferAllocation& combined : combined_allocations) {
      temp_allocation_total_size_ += combined.size();
      allocations_.push_back(std::move(combined));
    }
  }

  // Update allocation indices to their new positions.
  allocation_index_for_value_.erase(allocation_index_for_value_.begin(),
                                    allocation_index_for_value_.end());
  for (size_t index = 0; index < allocations_.size(); ++index) {
    BufferAllocation* allocation = &allocations_[index];
    allocation->set_index(index);
    for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
      const HloValue* value = buffer_offset_size.first;
      allocation_index_for_value_[value] = index;
    }
  }
}

Status BufferAssignment::ComputeSummaryStats() {
  for (auto& allocation : Allocations()) {
    if (allocation.is_entry_computation_parameter()) {
      stats_.parameter_allocation_count++;
      stats_.parameter_allocation_bytes += allocation.size();
    }
    if (allocation.is_constant()) {
      stats_.constant_allocation_count++;
      stats_.constant_allocation_bytes += allocation.size();
    }
    if (allocation.maybe_live_out()) {
      stats_.maybe_live_out_allocation_count++;
      stats_.maybe_live_out_allocation_bytes += allocation.size();
    }
    if (allocation.IsPreallocatedTempBuffer()) {
      stats_.preallocated_temp_allocation_count++;
      stats_.preallocated_temp_allocation_bytes += allocation.size();
    }
    stats_.total_allocation_count++;
    stats_.total_allocation_bytes += allocation.size();
  }

  // Only compute total fragmentation if all computations have schedules.
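  // (E.g., with 1.5 MiB total allocated and a heap-simulator minimum of
  // 1.0 MiB for the same schedule, total_fragmentation_bytes would be
  // 0.5 MiB; the numbers here are purely illustrative.)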
  HloSchedule schedule(module_);
  bool schedule_complete = true;
  for (const auto& computation : module_->computations()) {
    if (!computation->IsFusionComputation()) {
      const HloInstructionSequence* sequence =
          hlo_ordering().SequentialOrder(*computation);
      if (sequence == nullptr) {
        schedule_complete = false;
      } else {
        schedule.set_sequence(computation, *sequence);
      }
    }
  }
  if (schedule_complete) {
    TF_RETURN_IF_ERROR(schedule.Verify());
    TF_ASSIGN_OR_RETURN(
        const int64_t min_size,
        HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
    stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
  }

  return Status::OK();
}

string BufferAssignment::Stats::ToString() const {
  string s;
  StrAppendFormat(&s, "BufferAssignment stats:\n");
  StrAppendFormat(&s, "             parameter allocation: %10s\n",
                  HumanReadableNumBytes(parameter_allocation_bytes));
  StrAppendFormat(&s, "              constant allocation: %10s\n",
                  HumanReadableNumBytes(constant_allocation_bytes));
  StrAppendFormat(&s, "        maybe_live_out allocation: %10s\n",
                  HumanReadableNumBytes(maybe_live_out_allocation_bytes));
  StrAppendFormat(&s, "     preallocated temp allocation: %10s\n",
                  HumanReadableNumBytes(preallocated_temp_allocation_bytes));
  if (preallocated_temp_fragmentation_bytes >= 0) {
    const double percent = 100. * preallocated_temp_fragmentation_bytes /
                           preallocated_temp_allocation_bytes;
    StrAppendFormat(
        &s, "  preallocated temp fragmentation: %10s (%.2f%%)\n",
        HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
  }
  StrAppendFormat(&s, "                 total allocation: %10s\n",
                  HumanReadableNumBytes(total_allocation_bytes));
  if (total_fragmentation_bytes >= 0) {
    const double percent =
        100. * total_fragmentation_bytes / total_allocation_bytes;
    StrAppendFormat(&s, "              total fragmentation: %10s (%.2f%%)\n",
                    HumanReadableNumBytes(total_fragmentation_bytes), percent);
  }
  return s;
}

string BufferAssignment::ToString() const {
  string output;
  absl::StrAppend(&output, "BufferAssignment:\n");
  std::vector<const HloValue*> used_values;
  int64_t total_size = 0;
  for (auto& allocation : allocations_) {
    total_size += allocation.size();
    absl::StrAppend(&output, allocation.ToString());
    for (const auto& p : allocation.assigned_buffers()) {
      used_values.push_back(p.first);
    }
  }
  absl::StrAppend(&output, "\nTotal bytes used: ", total_size, " (",
                  HumanReadableNumBytes(total_size), ")\n");
  absl::StrAppend(&output, "\nUsed values:\n");
  absl::c_sort(used_values, &CompareHloValuesById);
  for (const HloValue* value : used_values) {
    absl::StrAppend(&output, value->ToString());
  }
  return output;
}

string BufferAssignment::BufferInfoString() const {
  string binfo;
  // Columns in buffer information:
  // buffer_id: int. This value can be used to match the allocation in
  // allocation information.
  // buffer_name: string.
  // offset: int. Starting position of the buffer in the memory space.
  // size: int. Size of the buffer in bytes.
  // definition_time: int. Position in the schedule where the buffer starts
  // being live (inclusive).
  // end_time: int. Position in the schedule where the buffer stops being live
  // (exclusive).
  // num_uses: int. Number of uses of the buffer.
  // use_times: string. Semicolon-separated schedule positions of the uses.
  // use_names: string. Semicolon-separated string representations of the uses.
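  // An illustrative row (all values hypothetical):
  //   7,"value.7",128,1024,3,9,2,"4;8","use_a;use_b"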
  // Append the column names.
  absl::StrAppend(&binfo,
                  "buffer_id,buffer_name,offset,size,"
                  "definition_time,end_time,num_uses,use_times,use_names\n");
  const HloLiveRange& live_ranges = hlo_live_range();
  const auto& instruction_schedule = live_ranges.instruction_schedule();
  const auto& buffer_live_ranges = live_ranges.buffer_live_ranges();
  // Sort the buffers by id.
  std::vector<std::pair<const HloValue*, BufferAllocation::OffsetSize>> buffers;
  for (const BufferAllocation& allocation : allocations_) {
    absl::c_copy(allocation.assigned_buffers(), std::back_inserter(buffers));
  }
  absl::c_sort(
      buffers,
      [](const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b1,
         const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b2) {
        return b1.first->id() < b2.first->id();
      });
  for (const auto& buffer_pair : buffers) {
    const HloValue& buffer = *buffer_pair.first;
    const BufferAllocation::OffsetSize& offset_size = buffer_pair.second;
    if (!buffer_live_ranges.contains(&buffer)) {
      continue;
    }
    // Order the uses by their position in the schedule.
    std::vector<std::pair<int64, std::string>> uses;
    uses.reserve(buffer.uses().size());
    for (const HloUse& use : buffer.uses()) {
      uses.emplace_back(instruction_schedule.at(use.instruction),
                        use.ToString());
    }
    absl::c_sort(uses);
    std::vector<int64> use_positions;
    std::vector<std::string> use_names;
    use_positions.reserve(uses.size());
    use_names.reserve(uses.size());
    for (const auto& use : uses) {
      use_positions.push_back(use.first);
      use_names.push_back(use.second);
    }
    const int64_t definition_time =
        instruction_schedule.at(buffer.defining_position().instruction);
    const int64_t end_t = buffer_live_ranges.at(&buffer).end;
    absl::StrAppend(&binfo, buffer.id(), ",");
    absl::StrAppend(&binfo, "\"", buffer.ToShortString(), "\",");
    absl::StrAppend(&binfo, offset_size.offset, ",");
    absl::StrAppend(&binfo, offset_size.size, ",");
    absl::StrAppend(&binfo, definition_time, ",");
    absl::StrAppend(&binfo, end_t, ",");
    absl::StrAppend(&binfo, use_positions.size(), ",");
    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_positions, ";"), "\",");
    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_names, ";"), "\"");
    absl::StrAppend(&binfo, "\n");
  }
  return binfo;
}

BufferAssignmentProto BufferAssignment::ToProto() const {
  BufferAssignmentProto proto;
  // NOTE: DataflowAnalysis state is serialized here in BufferAssignment,
  // because we need to do the HasAllocation check for each buffer. Otherwise
  // the buffer_size_ call might fail for some backends.
  const HloDataflowAnalysis& dataflow = this->dataflow_analysis();
  for (BufferValue::Id id = 0; id < dataflow.values().size(); id++) {
    auto& value = dataflow.values().at(id);
    if (HasAllocation(*value)) {
      LogicalBufferProto proto_buffer = value->ToProto(buffer_size_);
      proto.add_logical_buffers()->Swap(&proto_buffer);

      // Fill buffer aliases.
      for (const HloValue* alias :
           alias_analysis().GetBufferContainingValue(*value).values()) {
        if (alias->instruction() == value->instruction() &&
            alias->index() == value->index()) {
          continue;  // skip self-aliases
        }
        BufferAssignmentProto::BufferAlias* proto_alias =
            proto.add_buffer_aliases();
        LogicalBufferProto::Location proto_alias_location =
            BufferValue::ToLocationProto(*alias->instruction(), alias->index());
        proto_alias->set_source_buffer_id(value->id());
        proto_alias->mutable_location()->Swap(&proto_alias_location);
      }
    }
  }
  for (const BufferAllocation& allocation : Allocations()) {
    BufferAllocationProto proto_allocation = allocation.ToProto();
    proto.add_buffer_allocations()->Swap(&proto_allocation);
    for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
      *proto.add_heap_simulator_traces() = heap_trace;
    }
  }
  return proto;
}

/* static */
StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer,
    const absl::flat_hash_set<HloOpcode>& must_not_live_out,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer,
    std::unique_ptr<PresetAssignments> preset_assignments) {
  BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
                          must_not_live_out, std::move(preset_assignments));
  return assigner.CreateAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(can_share_buffer));
}

bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1,
                                         const HloValue* buffer2,
                                         BufferAssignment* assignment) {
  CHECK(assignment->hlo_live_range().total_order_scheduled());
  const HloLiveRange& hlo_live_range = assignment->hlo_live_range();

  const auto& buffer_live_ranges = hlo_live_range.buffer_live_ranges();

  CHECK(buffer_live_ranges.contains(buffer1))
      << "Buffer doesn't have a proper live range: " << buffer1;

  CHECK(buffer_live_ranges.contains(buffer2))
      << "Buffer doesn't have a proper live range: " << buffer2;

  // Check if a user value can share the same buffer as its operand.
  auto can_share_as_operand = [&assignment, &buffer_live_ranges](
                                  const HloValue* user_value,
                                  const HloValue* operand_value) {
    // An HLO value can occupy multiple positions during its lifetime. We only
    // look at the instruction at the end of the operand's live range and check
    // whether its buffer can be shared with the user.
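    // (Illustrative: an elementwise user that the dataflow analysis says can
    // run in place may share the operand's buffer; a kCopy user never does.)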
    HloPosition operand_end_position =
        buffer_live_ranges.at(operand_value).end_position;
    return user_value->instruction()->IsUserOf(
               operand_end_position.instruction) &&
           assignment->dataflow_analysis().CanShareOperandBufferWithUser(
               operand_end_position.instruction, operand_end_position.index,
               user_value->instruction(), user_value->index()) &&
           user_value->instruction()->opcode() != HloOpcode::kCopy;
  };

  auto live_range_1 = buffer_live_ranges.at(buffer1);
  auto live_range_2 = buffer_live_ranges.at(buffer2);

  if (!(live_range_1.start > live_range_2.end ||
        live_range_2.start > live_range_1.end)) {
    if (live_range_1.end == live_range_2.start) {
      auto operand_value = buffer1;
      auto user_value = buffer2;
      if (!can_share_as_operand(user_value, operand_value)) {
        VLOG(4) << "End of live range of " << buffer1->ToShortString()
                << " is equal to the start of live range of "
                << buffer2->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else if (live_range_2.end == live_range_1.start) {
      auto operand_value = buffer2;
      auto user_value = buffer1;
      if (!can_share_as_operand(user_value, operand_value)) {
        VLOG(4) << "End of live range of " << buffer2->ToShortString()
                << " is equal to the start of live range of "
                << buffer1->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else {
      VLOG(4) << "Can't assign: assignee " << *buffer1 << " may interfere with "
              << *buffer2;
      VLOG(4) << "assigned_buffer.start: " << live_range_1.start;
      VLOG(4) << "assigned_buffer.end: " << live_range_1.end;
      VLOG(4) << "live_range_2.start: " << live_range_2.start;
      VLOG(4) << "live_range_2.end: " << live_range_2.end;
      return true;
    }
  }
  return false;
}

bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
                                       const HloBuffer& hlo_buffer,
                                       BufferAssignment* assignment) {
  CHECK(!assignment->HasAllocation(hlo_buffer))
      << "buffer " << hlo_buffer << " already has an allocation assigned.";

  VLOG(4) << "Trying to assign " << hlo_buffer << " size "
          << assignment->HloBufferSize(hlo_buffer)
          << " to allocation: " << *allocation;

  if (hlo_buffer.color() != allocation->color()) {
    VLOG(4) << "Can't assign: buffer has color " << hlo_buffer.color()
            << " and allocation has color " << allocation->color() << ".";
    return false;
  }

  if (assignment->HloBufferSize(hlo_buffer) > allocation->size()) {
    VLOG(4) << "Can't assign: buffer is larger than allocation ("
            << assignment->HloBufferSize(hlo_buffer) << " > "
            << allocation->size() << ")";
    return false;
  }

  if (allocation->is_readonly()) {
    VLOG(4) << "Can't assign: allocation is readonly";
    return false;
  }

  if (!must_not_live_out_.empty()) {
    if (allocation->maybe_live_out()) {
      // If the allocation may be live out, it cannot contain any value whose
      // defining instruction's opcode is in the "must_not_live_out_" set.
      for (const HloValue* value : hlo_buffer.values()) {
        if (must_not_live_out_.count(value->instruction()->opcode()) > 0) {
          VLOG(4) << "Can't assign: " << value->instruction()->ToString()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
    // The above check is not sufficient: an allocation might currently not be
    // live out while still containing an instruction with an opcode from the
    // "must_not_live_out_" set; assigning a live-out buffer to it would make
    // the allocation live out while still containing such an instruction.
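    // (Illustrative: with must_not_live_out_ = {kCustomCall}, an allocation
    // already holding a custom-call result must reject any live-out buffer,
    // since the assignment would make that result live out.)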
    if (assignment->alias_analysis().BufferLivesOut(hlo_buffer)) {
      for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
        if (must_not_live_out_.count(
                buffer_offset_size.first->instruction()->opcode()) > 0) {
          VLOG(4) << "Can't assign: " << buffer_offset_size.first->instruction()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
  }

  if (!allocation->is_reusable()) {
    VLOG(4) << "Can't assign: allocation is not reusable";
    return false;
  }

  for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
    // Pairwise compare.
    const HloValue& assigned_buffer =
        *CHECK_NOTNULL(dynamic_cast<const HloValue*>(buffer_offset_size.first));
    for (const HloValue* new_value : hlo_buffer.values()) {
      if (assignment->hlo_live_range().total_order_scheduled()) {
        if (LiveRangeInterferes(new_value, &assigned_buffer, assignment)) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " live range interferes with "
                  << new_value->ToShortString();
          return false;
        }
      } else if (assignment->hlo_ordering().MayInterfere(
                     assigned_buffer, *new_value,
                     assignment->dataflow_analysis())) {
        // Fall back to partial-order-based interference detection (slower)
        // when we don't have a totally ordered (scheduled) module.
        VLOG(4) << "Can't assign: assignee " << assigned_buffer
                << " may interfere with " << new_value->ToShortString();
        return false;
      }

      for (const HloPosition& assigned_buffer_position :
           assigned_buffer.positions()) {
        // Copy instructions don't share a buffer with their input operands.
        if (new_value->instruction()->IsUserOf(
                assigned_buffer_position.instruction) &&
            new_value->instruction()->opcode() == HloOpcode::kCopy) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " is used at copy instruction "
                  << new_value->ToShortString();
          return false;
        }
      }
    }
  }

  // If the buffer is live out of the computation then it should only be
  // assigned a buffer which exactly fits the result to avoid wasting memory
  // (result buffers can have arbitrary lifetimes).
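  // (E.g. a 256-byte live-out buffer must not reuse a 512-byte allocation:
  // the extra 256 bytes could never be reclaimed. Sizes illustrative.)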
  if (assignment->alias_analysis().BufferLivesOut(hlo_buffer) &&
      allocation->size() != assignment->HloBufferSize(hlo_buffer)) {
    VLOG(4) << "Can't assign: buffer " << hlo_buffer
            << " is live out and size not the same as allocation";
    return false;
  }

  assignment->AddAssignment(allocation, hlo_buffer, /*offset=*/0,
                            assignment->HloBufferSize(hlo_buffer));
  return true;
}

Status BufferAssigner::AssignSingleHloBuffer(
    const HloBuffer* hlo_buffer, bool is_thread_local,
    absl::flat_hash_map<const HloComputation*,
                        absl::flat_hash_set<const HloValue*>>*
        buffers_to_assign_sequentially,
    std::vector<BufferAllocation::Index>* allocation_indices,
    BufferAssignment* assignment) {
  const int64_t buffer_size = assignment->HloBufferSize(*hlo_buffer);
  for (const HloValue* value : hlo_buffer->values()) {
    if (value->instruction()->opcode() == HloOpcode::kConstant) {
      if (allocate_buffers_for_constants_) {
        BufferAllocation* allocation =
            assignment->NewAllocation(*hlo_buffer, buffer_size);
        allocation->set_constant(true);
        VLOG(3) << "New allocation #" << allocation->index() << " for constant "
                << *hlo_buffer << " value ptr: " << value;
      } else {
        VLOG(3) << "Not allocating buffer for constant";
      }
      return Status::OK();
    }

    const HloInstruction* instruction = value->instruction();
    const bool is_entry_parameter =
        instruction->opcode() == HloOpcode::kParameter &&
        instruction->parent() ==
            instruction->parent()->parent()->entry_computation();

    if (is_entry_parameter) {
      bool parameter_has_alias =
          assignment->module().input_output_alias_config().ParameterHasAlias(
              instruction->parameter_number(), value->index());
      // If the hlo buffer is part of an external parameter, create a new
      // allocation and set its parameter number. Parameters of non-entry
      // computations do not need special allocations because they live inside
      // callers.
      BufferAllocation* allocation =
          assignment->NewAllocation(*hlo_buffer, buffer_size);

      allocation->set_entry_computation_parameter(
          instruction->parameter_number(), value->index(), parameter_has_alias);
      if (parameter_has_alias) {
        allocation_indices->push_back(allocation->index());
      }
      VLOG(3) << "New allocation #" << allocation->index()
              << " marked as entry computation parameter: " << *hlo_buffer;
      return Status::OK();
    }
  }

  if (is_thread_local) {
    BufferAllocation* allocation =
        assignment->NewAllocation(*hlo_buffer, buffer_size);
    allocation->set_is_thread_local(true);
    VLOG(3) << "New allocation #" << allocation->index()
            << " for thread-local: " << *hlo_buffer;
    return Status::OK();
  }

  for (const HloValue* value : hlo_buffer->values()) {
    if (value->shape().IsTuple()) {
      BufferAllocation* allocation =
          assignment->NewAllocation(*hlo_buffer, buffer_size);
      allocation->set_is_tuple(true);
      VLOG(3) << "New allocation #" << allocation->index()
              << " for tuple-shaped buffer: " << *hlo_buffer;
      return Status::OK();
    }

    if (value->IsTopLevel() && !value->IsTuple()) {
      const HloInstruction* instruction = value->instruction();
      for (auto* operand : instruction->operands()) {
        for (const auto& operand_slice :
             assignment->GetAllSlices(operand, /*index=*/{})) {
          BufferAllocation* allocation =
              assignment->GetMutableAllocation(operand_slice.index());
          if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
            VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
                    << " for: " << *hlo_buffer;
            return Status::OK();
          }
        }
      }
    }
  }

  // Find the smallest allocation that can be reused, iterating from the end
  // of allocation_indices (smallest) to the beginning (largest).
1198   for (int allocation_index = allocation_indices->size() - 1;
1199        allocation_index >= 0; allocation_index--) {
1200     BufferAllocation* allocation = assignment->GetMutableAllocation(
1201         allocation_indices->at(allocation_index));
1202     if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
1203       VLOG(3) << "Reusing allocation #" << allocation->index()
1204               << " for: " << *hlo_buffer;
1205       return Status::OK();
1206     }
1207   }
1208 
1209   if (!assignment->HasAllocation(*hlo_buffer) &&
1210       !assignment->alias_analysis().BufferLivesOut(*hlo_buffer)) {
1211     bool all_computations_have_sequential_order = true;
1212     for (const HloValue* hlo_value : hlo_buffer->values()) {
1213       HloComputation* computation = hlo_value->instruction()->parent();
1214       const bool has_sequential_order =
1215           assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
1216       all_computations_have_sequential_order &= has_sequential_order;
1217     }
1218 
1219     if (all_computations_have_sequential_order) {
1220       for (const HloValue* hlo_value : hlo_buffer->values()) {
1221         HloComputation* computation = hlo_value->instruction()->parent();
1222         // There is a sequential instruction ordering, so we delay assignment
1223         // of temp buffers until after the loop. We do this right before we
1224         // decide to create a new allocation, to ensure we've exhausted all
1225         // the buffer re-use cases above.
1226         //
1227         // Entry parameters and thread local buffers were already handled
1228         // earlier in this loop iteration.  See
1229         // BufferAllocation::IsPreallocatedTempBuffer for the definition of
1230         // temp buffers.
1231         (*buffers_to_assign_sequentially)[computation].insert(hlo_value);
1232         VLOG(3) << "Delaying assignment of temp buffer: " << *hlo_value;
1233       }
1234       return Status::OK();
1235     }
1236   }
1237 
1238   if (!assignment->HasAllocation(*hlo_buffer)) {
1239     BufferAllocation* allocation =
1240         assignment->NewAllocation(*hlo_buffer, buffer_size);
1241     allocation_indices->push_back(allocation->index());
1242     VLOG(3) << "New allocation #" << allocation->index()
1243             << " for: " << *hlo_buffer;
1244   }
1245 
1246   TF_RET_CHECK(assignment->HasAllocation(*hlo_buffer));
1247   return Status::OK();
1248 }
1249 
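// Summary of the decision order above (illustrative; buffer "B" is
// hypothetical): for a non-constant, non-parameter, non-thread-local,
// array-shaped buffer B, AssignSingleHloBuffer
//   1. tries to share an allocation already assigned to one of B's operands,
//   2. tries to reuse the smallest compatible existing allocation,
//   3. defers B to the heap simulator if every defining computation has a
//      sequential order and B does not live out of the module,
//   4. otherwise creates a fresh allocation for B.
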
Status BufferAssigner::AssignBuffersForComputations(
    const std::vector<const HloComputation*>& computations,
    bool is_thread_local,
    absl::flat_hash_map<const HloComputation*,
                        absl::flat_hash_set<const HloValue*>>*
        buffers_to_assign_sequentially,
    BufferAssignment* assignment) {
  if (computations.empty()) {
    return Status::OK();
  }
  std::vector<const HloBuffer*> sorted_buffers;

  // First assign the preset allocations.
  absl::flat_hash_set<const HloBuffer*> preset_assigned_buffers;

  TF_RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment));

  const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();

  for (const HloBuffer& buffer : alias_analysis.buffers()) {
    // Skip the buffer if it is already assigned due to a preset allocation.
    if (preset_assigned_buffers.find(&buffer) !=
        preset_assigned_buffers.end()) {
      VLOG(3) << "Skip allocation for buffer: " << buffer;
      continue;
    }
    TF_RET_CHECK(!buffer.values().empty());
    const HloComputation* comp = buffer.values()[0]->instruction()->parent();
    if (absl::c_linear_search(computations, comp)) {
      sorted_buffers.push_back(&buffer);
    }
  }

  // Assign each instruction its position in a post-order traversal; these
  // positions are used below as the tie-breaker when sorting the HloBuffers.
  flat_hash_map<const HloInstruction*, int> post_order_position;
  int position = 0;
  std::vector<const HloComputation*> reverse_post_order_computations;
  std::unique_ptr<CallGraph> call_graph =
      CallGraph::Build(computations[0]->parent());
  TF_RETURN_IF_ERROR(call_graph->VisitNodes([&](const CallGraphNode& node) {
    if (absl::c_linear_search(computations, node.computation())) {
      reverse_post_order_computations.push_back(node.computation());
    }
    return Status::OK();
  }));
  absl::c_reverse(reverse_post_order_computations);
  for (auto* computation : reverse_post_order_computations) {
    for (auto* instruction : computation->MakeInstructionPostOrder()) {
      post_order_position.emplace(instruction, position);
      position++;
    }
  }

  HloSchedule schedule(&assignment->module());

  for (const HloComputation* computation : computations) {
    const HloInstructionSequence* instruction_sequence =
        assignment->hlo_ordering().SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
      // Every sequential computation must get an entry in the
      // buffers_to_assign_sequentially map, even if we end up with an empty
      // set of buffers. This ensures we can correctly determine whether to
      // run whole-module heap simulation.
      buffers_to_assign_sequentially->emplace(computation,
                                              flat_hash_set<const HloValue*>());

      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

  absl::c_sort(
      sorted_buffers, [&post_order_position, &alias_analysis, assignment](
                          const HloBuffer* a, const HloBuffer* b) {
        // Primary sort is by decreasing buffer size.
        const int64_t a_size = assignment->HloBufferSize(*a);
        const int64_t b_size = assignment->HloBufferSize(*b);
        if (a_size != b_size) {
          return a_size > b_size;  // use ">" for decreasing size.
        }

        // Secondary sort: buffers that live out of the module come first.
        const bool a_live_out = alias_analysis.BufferLivesOut(*a);
        const bool b_live_out = alias_analysis.BufferLivesOut(*b);
        if (a_live_out != b_live_out) {
          return a_live_out;
        }

        // Final tie-breaker: the buffer whose earliest value definition
        // appears first in post order comes first.
        auto compare = [&post_order_position](const HloValue* value1,
                                              const HloValue* value2) {
          return post_order_position.at(value1->instruction()) <
                 post_order_position.at(value2->instruction());
        };
        const HloValue* a_min = *absl::c_min_element(a->values(), compare);
        const HloValue* b_min = *absl::c_min_element(b->values(), compare);
        return compare(a_min, b_min);
      });

  std::vector<BufferAllocation::Index> allocation_indices;

  for (const HloBuffer* buffer : sorted_buffers) {
    VLOG(3) << "=================================================";
    VLOG(3) << "Assigning buffer for " << *buffer;
    TF_RETURN_IF_ERROR(AssignSingleHloBuffer(buffer, is_thread_local,
                                             buffers_to_assign_sequentially,
                                             &allocation_indices, assignment));
  }
  return Status::OK();
}

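// Illustrative ordering example (hypothetical buffers, not from the original
// source): given A (size 16, lives out), B (size 32, does not live out), and
// C (size 16, does not live out), the comparator above yields the order
// B, A, C: B is strictly largest, while A and C tie on size and A wins the
// live-out test.
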
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>>
BufferAssigner::SplitBuffersByColor(
    const flat_hash_set<const HloValue*>& buffers) {
  flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>> color_map;
  for (auto buffer : buffers) {
    color_map[buffer->color()].insert(buffer);
  }
  return color_map;
}

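// For example (hypothetical values): splitting {v0: color 0, v1: color 0,
// v2: color 1} yields {0 -> {v0, v1}, 1 -> {v2}}; each colored set is then
// heap-simulated separately, with its own alignment, by the callers below.
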
Status BufferAssigner::AssignPresetBuffers(
    absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
    BufferAssignment* assignment) {
  if (!preset_assignments_) {
    return Status::OK();
  }

  // Create an allocation for each preset color.
  absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*>
      preset_allocations;
  for (auto& color_and_info : preset_assignments_->assignment_informations()) {
    LogicalBuffer::Color color(color_and_info.first);
    auto inserted = preset_allocations.emplace(
        color,
        assignment->NewEmptyAllocation(color_and_info.second.size, color));
    BufferAllocation* inserted_allocation = inserted.first->second;
    inserted_allocation->AddHeapTrace(
        color_and_info.second.heap_simulator_trace);
    VLOG(3) << "Created preset buffer allocation "
            << inserted_allocation->index()
            << ", color: " << inserted_allocation->color()
            << ", size: " << inserted_allocation->size();
  }

  const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();

  for (auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& defining_position = position_and_chunk.first;
    const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt(
        defining_position.instruction, defining_position.index);
    for (const HloValue* value : buffer.values()) {
      VLOG(3) << "Preset allocation for value: " << value->ToShortString();
      const HeapSimulator::Chunk& chunk = position_and_chunk.second;
      auto preset_allocations_iter = preset_allocations.find(value->color());
      CHECK(preset_allocations_iter != preset_allocations.end())
          << "No preset value allocation for color " << value->color()
          << " for " << value->ToShortString() << " found.";
      preset_allocations_iter->second->AddAssignment(*value, chunk.offset,
                                                     chunk.size);
    }

    assigned_buffers->insert(&buffer);
  }

  // The preset assignments have been consumed; clear them so that a second
  // call to this method does not assign the same buffers again.
  preset_assignments_ = {};

  return Status::OK();
}

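// Illustrative flow (hypothetical numbers): if memory-space assignment
// pre-placed a value at offset 4096 with size 1024 in a color-1 space, the
// chunk loop above finds the color-1 allocation created earlier and records
// the (offset, size) pair for every HloValue aliased with that buffer, so
// later passes treat the placement as fixed.
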
Status BufferAssigner::AssignBuffersWithSequentialOrdering(
    const flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>&
        buffers_to_assign_sequentially,
    bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
  // Run the sequence of instructions through the heap simulator. The
  // strategy below tries two global decreasing-size best-fit heuristics
  // (spatial and temporal) and keeps whichever yields the smaller heap.
  const HloOrdering& hlo_ordering = assignment->hlo_ordering();

  // Returns a heap algorithm that chooses the best result from several
  // algorithms.
  auto get_heap_algorithm = [&](int64_t alignment) {
    auto algorithms = absl::make_unique<
        std::vector<std::unique_ptr<HeapAlgorithm<HloValue>>>>();
    algorithms->push_back(
        absl::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
            assignment->multiheap_size_constraint_per_heap(), alignment,
            GlobalDecreasingSizeBestFitHeap<HloValue>::kSpatial));
    algorithms->push_back(
        absl::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
            assignment->multiheap_size_constraint_per_heap(), alignment,
            GlobalDecreasingSizeBestFitHeap<HloValue>::kTemporal));
    return absl::make_unique<ChooseBestHeapAlgorithm<HloValue>>(
        std::move(algorithms));
  };

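  // A note on the two candidate orders (summarizing the behavior of
  // GlobalDecreasingSizeBestFitHeap in the heap simulator): kSpatial places
  // buffers in decreasing size order, while kTemporal favors buffers with
  // longer live ranges. Neither consistently wins, so ChooseBestHeapAlgorithm
  // runs both and keeps the result with the smaller heap.
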
  if (run_whole_module_heap_simulation) {
    // Run the heap simulation over the whole module. This reduces memory
    // usage, since buffers for kCall, kWhile, and kConditional
    // sub-computations are only live for the duration of their calling
    // instructions.
    VLOG(1) << "Running whole-module heap simulation";
    HloSchedule schedule(&assignment->module());
    flat_hash_set<const HloValue*> all_buffers_to_assign;
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      schedule.set_sequence(computation, *instruction_sequence);
      all_buffers_to_assign.insert(buffers_to_assign.begin(),
                                   buffers_to_assign.end());
    }
    auto color_map = SplitBuffersByColor(all_buffers_to_assign);
    for (auto& single_colored_set : color_map) {
      auto color = single_colored_set.first;
      VLOG(2) << "Simulating heap for color " << color;
      int64_t alignment = assignment->color_alignment_(color);
      HeapSimulator::Options options;
      options.alloc_constants = allocate_buffers_for_constants_;
      options.buffers_to_assign = &single_colored_set.second;

      TF_ASSIGN_OR_RETURN(
          HeapSimulator::Result<HloValue> result,
          HeapSimulator::Run(
              get_heap_algorithm(alignment), assignment->module(), schedule,
              assignment->alias_analysis(), assignment->buffer_size_, options));
      AssignBuffersFromHeapSimulator(result, assignment,
                                     single_colored_set.first);
    }
  } else {
    // Run the heap simulation on a per-computation basis. Buffers for
    // sub-computations are assigned disjoint BufferAllocations, assuming the
    // worst case that they may all be live concurrently.
    VLOG(1) << "Running per-computation heap simulation";
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      auto color_map = SplitBuffersByColor(buffers_to_assign);
      for (auto& single_colored_set : color_map) {
        auto color = single_colored_set.first;
        VLOG(2) << "Simulating heap for color " << color;
        int64_t alignment = assignment->color_alignment_(color);
        HeapSimulator::Options options;
        options.buffers_to_assign = &single_colored_set.second;
        TF_ASSIGN_OR_RETURN(
            HeapSimulator::Result<HloValue> result,
            HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
                               *instruction_sequence,
                               assignment->alias_analysis(),
                               assignment->buffer_size_, options));
        AssignBuffersFromHeapSimulator(result, assignment,
                                       single_colored_set.first);
      }
    }
  }
  return Status::OK();
}

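// Illustrative contrast (hypothetical module): if the entry computation calls
// a while loop whose body needs a 1 MiB temp, and the entry needs a 1 MiB
// temp that is not live across the call, whole-module simulation can overlap
// the two (1 MiB total), while per-computation simulation must assume they
// coexist and reserves 2 MiB.
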
namespace {
// Computes and returns the set of logical buffers live at the point of
// maximal liveness in the given heap trace. LogicalBuffers are (stably)
// sorted by id.
std::vector<const HloValue*> ComputePeakMemoryLogicalBuffers(
    const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
  // Create a map from buffer id to HloValue for the logical buffers in this
  // allocation.
  absl::flat_hash_map<BufferValue::Id, const HloValue*> id_to_value;
  absl::flat_hash_map<const HloValue*, int64> buffer_sizes;
  for (const auto& pair : allocation.assigned_buffers()) {
    const HloValue* value = pair.first;
    const BufferAllocation::OffsetSize& offset_size = pair.second;
    id_to_value[value->id()] = value;
    buffer_sizes[value] = offset_size.size;
  }
  VLOG(1) << "Compute peak memory logical buffers";

  // Returns how much the given event increases the total size of live
  // buffers. Can be negative.
  auto memory_delta = [&id_to_value, &buffer_sizes](
                          const HeapSimulatorTrace::Event& event) -> int64 {
    const HloValue* buffer = id_to_value.at(event.buffer_id());
    const int64_t buffer_size = buffer_sizes.at(buffer);
    if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
        event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
      return buffer_size;
    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
      return -1 * buffer_size;
    }
    LOG(FATAL) << "Unknown event kind: " << event.kind();
  };

  // First compute the size of the maximal live set.
  int64_t max_live_size = 0;
  int64_t live_size = 0;
  for (const auto& event : heap_trace.events()) {
    if (!id_to_value.contains(event.buffer_id())) {
      // Skip the event: the buffer associated with it is not placed into this
      // allocation. This can happen when size constraints are given to the
      // heap simulator.
      continue;
    }
    live_size += memory_delta(event);
    if (max_live_size < live_size) {
      max_live_size = live_size;
    }
  }

  // Next gather the set of logical buffers live at the earliest point of
  // maximal live set size.
  absl::flat_hash_set<const HloValue*> live_values;
  live_size = 0;
  for (const auto& event : heap_trace.events()) {
    if (!id_to_value.contains(event.buffer_id())) {
      // Skip the event: the buffer associated with it is not placed into this
      // allocation. This can happen when size constraints are given to the
      // heap simulator.
      continue;
    }
    const HloValue* value = id_to_value.at(event.buffer_id());
    if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
        event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
      InsertOrDie(&live_values, value);
    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
      CHECK(ContainsKey(live_values, value));
      live_values.erase(value);
    }
    live_size += memory_delta(event);

    if (live_size == max_live_size) {
      break;
    }
  }
  CHECK_EQ(live_size, max_live_size);

  std::vector<const HloValue*> live_values_vector;
  live_values_vector.insert(live_values_vector.end(), live_values.begin(),
                            live_values.end());

  // Stably sort the live buffers.
  absl::c_sort(live_values_vector, [](const HloValue* a, const HloValue* b) {
    return a->id() < b->id();
  });
  VLOG(4) << "Peak memory buffers:";
  for (auto value : live_values_vector) {
    VLOG(4) << "  " << value->ToString();
  }
  return live_values_vector;
}
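
// Worked example (hypothetical trace, not from the original source): for the
// events ALLOC(a, 8), ALLOC(b, 16), FREE(a), ALLOC(c, 4), the running live
// sizes are 8, 24, 16, and 20, so max_live_size is 24. The second pass stops
// right after ALLOC(b), the earliest point reaching 24, and returns {a, b}
// sorted by id.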

}  // namespace

void BufferAssigner::AssignBuffersFromHeapSimulator(
    const HeapSimulator::Result<HloValue>& result, BufferAssignment* assignment,
    BufferValue::Color color) {
  if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
    assignment->stats_.preallocated_temp_fragmentation_bytes =
        result.fragmentation_size;
  } else {
    assignment->stats_.preallocated_temp_fragmentation_bytes +=
        result.fragmentation_size;
  }
  VLOG(1) << "Result size from heap simulator: " << result.heap_size;

  // Iterate through heap_results. For each heap_result, create a new
  // allocation in `assignment`.
  for (const HeapSimulator::HeapResult<HloValue>& heap_result :
       result.heap_results) {
    BufferAllocation* allocation =
        assignment->NewEmptyAllocation(heap_result.heap_size, color);
    for (const auto& buffer_chunk : heap_result.chunk_map) {
      const HloValue& value = *buffer_chunk.first;
      const HeapSimulator::Chunk& chunk = buffer_chunk.second;
      assignment->AddAssignment(allocation, value, chunk.offset, chunk.size);
    }
    allocation->peak_buffers_ =
        ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);

    XLA_VLOG_LINES(2, allocation->ToString());

    allocation->AddHeapTrace(result.debug_trace);
  }
}

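// Note (restating the loop above): when a multiheap size constraint is in
// effect, the simulator may return several heap_results; each one becomes its
// own BufferAllocation, with its own peak-buffer set and heap trace.
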
StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer));

  // Set up a schedule for each computation.
  HloSchedule schedule(module);
  for (const HloComputation* computation : module->computations()) {
    const HloInstructionSequence* instruction_sequence =
        hlo_ordering->SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order) {
      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, *alias_analysis,
                                        module->entry_computation(), true));

  VLOG(1) << "Assigning buffers to module " << module->name();
  XLA_VLOG_LINES(3, module->ToString());
  XLA_VLOG_LINES(3, alias_analysis->ToString());
  XLA_VLOG_LINES(3, alias_analysis->dataflow_analysis().ToString());
  VLOG(1) << "Number of buffers to assign: "
          << alias_analysis->buffers().size();

  // Can't use absl::make_unique because BufferAssignment constructor is
  // private.
  std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(alias_analysis),
      std::move(hlo_live_range)));

  TF_RETURN_IF_ERROR(
      colorer_(&assignment->alias_analysis(), assignment->hlo_ordering()));
  VLOG(3) << "After coloring:";
  XLA_VLOG_LINES(3,
                 assignment->alias_analysis().dataflow_analysis().ToString());

  std::vector<const HloComputation*> thread_local_computations;
  std::vector<const HloComputation*> global_computations;
  TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
      module, &thread_local_computations, &global_computations));

  // First assign buffers for global computations. Temporary buffers for
  // sequential computations are collected in
  // 'buffers_to_assign_sequentially'.
  flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>
      buffers_to_assign_sequentially;
  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      global_computations,
      /*is_thread_local=*/false, &buffers_to_assign_sequentially,
      assignment.get()));
  // Assign buffers with sequential ordering, if any. If all global
  // computations are sequential, we can run heap simulation on the whole
  // module, which reduces memory usage.
  const bool run_whole_module_heap_simulation =
      buffers_to_assign_sequentially.size() == global_computations.size();
  VLOG(2) << "Running whole module heap simulation: "
          << run_whole_module_heap_simulation;
  const int32_t multiheap_size_constraint_per_heap =
      module->config().debug_options().xla_multiheap_size_constraint_per_heap();
  VLOG(2) << "Multiheap per heap size limit: "
          << multiheap_size_constraint_per_heap;
  TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
      buffers_to_assign_sequentially, run_whole_module_heap_simulation,
      assignment.get()));

  // Now assign buffers for thread-local computations. All LogicalBuffers get
  // their own BufferAllocation. Fusion computations are excluded here; they
  // do not receive separate allocations.
  std::vector<const HloComputation*> thread_local_computations_no_fusion;
  for (auto* computation : thread_local_computations) {
    TF_RET_CHECK(computation != module->entry_computation());
    if (computation->IsFusionComputation()) {
      continue;
    }
    thread_local_computations_no_fusion.push_back(computation);
  }

  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      thread_local_computations_no_fusion,
      /*is_thread_local=*/true,
      /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));

  // Mark all buffers which may be live out of the entry computation as
  // "liveout".
  for (const HloBuffer* buffer :
       assignment->alias_analysis().LiveOutBuffers()) {
    VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
    if (assignment->HasAllocation(*buffer)) {
      BufferAllocation* alloc =
          assignment->GetMutableAssignedAllocation(*buffer);
      alloc->set_maybe_live_out(true);
      VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
    }
  }

  // Combine allocations of temporary buffers into big BufferAllocations
  // subject to the buffer allocation size constraint. This can only be
  // performed after all buffers have been assigned, and after maybe_live_out
  // is marked, since it is used to determine whether an allocation contains
  // temporary buffers or not.
  assignment->CombineTempAllocations();

  XLA_VLOG_LINES(2, assignment->ToString());
  TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
  XLA_VLOG_LINES(1, assignment->GetStats().ToString());
  VLOG(1) << "Buffer assignment done.";
  return std::move(assignment);
}
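
// Illustrative caller sketch (hypothetical; `module`, `schedule`, and
// `byte_size_fn` are assumed to exist in the caller, and production code
// normally goes through BufferAssigner::Run, which forwards here):
//
//   BufferValue::SizeFunction byte_size_fn = [](const BufferValue& value) {
//     return ShapeUtil::ByteSizeOf(value.shape(), /*pointer_size=*/8);
//   };
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<BufferAssignment> assignment,
//       BufferAssigner::Run(
//           module.get(), absl::make_unique<SequentialHloOrdering>(schedule),
//           byte_size_fn,
//           [](LogicalBuffer::Color) { return /*alignment=*/64; }));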

}  // namespace xla