/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Defines the data returned by the XLA buffer assignment packages.

#include "tensorflow/compiler/xla/service/buffer_assignment.h"

#include <algorithm>
#include <deque>
#include <numeric>
#include <ostream>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"

namespace xla {
namespace {

using absl::flat_hash_map;
using absl::flat_hash_set;
using absl::StrAppend;
using absl::StrAppendFormat;
using ::tensorflow::strings::HumanReadableNumBytes;

// Given the interference map of a graph (the list of interfering node indices
// for each node), perform graph coloring such that interfering nodes are
// assigned to different colors. Returns the assigned color of the nodes, where
// the colors are represented as integer values [0, color_count).
std::vector<int64> ColorInterferenceGraph(
    const std::vector<std::vector<int64>>& interference_map) {
  const int64 node_count = interference_map.size();

  // Sort the nodes such that we assign nodes with more interference first. This
  // relies on the common heuristic of assigning the most constrained node
  // first, but it would be good to investigate other ordering heuristics too.
  std::vector<int64> nodes(node_count);
  std::iota(nodes.begin(), nodes.end(), 0);
  absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) {
    return interference_map[i].size() > interference_map[j].size();
  });

  const int64 kColorUnassigned = -1;
  std::vector<int64> assigned_colors(node_count, kColorUnassigned);
  for (int64 node : nodes) {
    // Mark the colors that are already assigned to the neighbors.
    std::vector<bool> available_colors(node_count, true);
    for (int64 neighbor : interference_map[node]) {
      int64 color = assigned_colors[neighbor];
      if (color != kColorUnassigned) {
        available_colors[color] = false;
      }
    }

    // Find the color that is not yet assigned to the neighbors.
    int64 color = kColorUnassigned;
    for (color = 0; color < available_colors.size(); ++color) {
      if (available_colors[color]) {
        break;
      }
    }
    CHECK_NE(color, kColorUnassigned);
    assigned_colors[node] = color;
  }
  return assigned_colors;
}
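
// Illustrative sketch (not part of the original file): with the interference
// map below, node 0 interferes with nodes 1 and 2, so it is colored first and
// receives color 0; nodes 1 and 2 each interfere only with node 0 and both
// receive color 1:
//
//   std::vector<std::vector<int64>> interference_map = {{1, 2}, {0}, {0}};
//   std::vector<int64> colors = ColorInterferenceGraph(interference_map);
//   // colors == {0, 1, 1}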

// If an HLO buffer contains an entry parameter, the buffer is read-only unless
// it is aliased with an output.
bool HloBufferIsReadOnly(const HloBuffer& buffer) {
  for (const HloValue* value : buffer.values()) {
    const HloInstruction* instruction = value->instruction();
    const HloModule* module = instruction->parent()->parent();
    const bool is_entry_parameter =
        instruction->opcode() == HloOpcode::kParameter &&
        instruction->parent() == module->entry_computation();

    if (is_entry_parameter) {
      bool parameter_has_alias =
          module->input_output_alias_config().ParameterHasAlias(
              instruction->parameter_number(), value->index());
      // If the parameter doesn't have an alias, it must be read-only.
      if (!parameter_has_alias) {
        return true;
      }
    }
  }
  return false;
}
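
// Illustrative sketch (hypothetical module, not part of the original file):
// for an entry computation f(p0) that writes its result to a fresh buffer,
// the buffer holding p0 is read-only. If the module instead declares an
// input/output alias, e.g.
//
//   module->input_output_alias_config().SetUpAlias(
//       /*output_index=*/{}, /*param_number=*/0, /*param_index=*/{});
//
// then ParameterHasAlias(0, {}) returns true and the parameter's buffer may
// be written in place.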

}  // namespace

Status GatherComputationsByAllocationType(
    const HloModule* module,
    std::vector<const HloComputation*>* thread_local_computations,
    std::vector<const HloComputation*>* global_computations) {
  // Create a worklist of computations paired with whether the allocation must
  // be thread-local.
  std::deque<std::pair<const HloComputation*, bool>> worklist;
  worklist.push_back(std::make_pair(module->entry_computation(),
                                    /*is_thread_local*/ false));

  // Sets for quickly checking membership. Computations are returned in vectors
  // for stable iteration.
  flat_hash_set<const HloComputation*> thread_local_set;
  flat_hash_set<const HloComputation*> global_set;

  while (!worklist.empty()) {
    auto worklist_front = worklist.front();
    worklist.pop_front();
    const HloComputation* computation = worklist_front.first;
    bool is_thread_local = worklist_front.second;
    bool in_thread_local_set = thread_local_set.contains(computation);
    bool in_global_set = global_set.contains(computation);

    // If the computation has already been added to the respective set, then
    // there is nothing to do.
    if ((is_thread_local && in_thread_local_set) ||
        (!is_thread_local && in_global_set)) {
      continue;
    }

    // If the computation has already been added to the other set, this is an
    // error condition because the global call to the computation (e.g.,
    // while/call) may return a reference to one of the thread-local buffers to
    // the calling computation which will become a dangling reference when the
    // thread-local is deallocated with the call return.
    if ((is_thread_local && in_global_set) ||
        (!is_thread_local && in_thread_local_set)) {
      return InvalidArgument(
          "computation %s has conflicting allocation requirements (global "
          "and thread-local)",
          computation->name());
    }

    if (is_thread_local) {
      thread_local_set.insert(computation);
    } else {
      global_set.insert(computation);
    }

    for (auto* instruction : computation->instructions()) {
      for (HloComputation* subcomputation :
           instruction->called_computations()) {
        switch (instruction->opcode()) {
          case HloOpcode::kCall:
          case HloOpcode::kConditional:
          case HloOpcode::kWhile:
            // Call and while must be called from a computation with global
            // allocations as they may return references to buffers inside the
            // called computation which cannot be thread-local.
            if (is_thread_local) {
              return InvalidArgument(
                  "computation %s cannot contain call/while op because it "
                  "requires thread-local buffer allocations",
                  computation->name());
            }
            worklist.push_back(std::make_pair(subcomputation,
                                              false));  // Not thread local.
            break;
          case HloOpcode::kAllReduce:
          case HloOpcode::kMap:
          case HloOpcode::kReduce:
          case HloOpcode::kReduceWindow:
          case HloOpcode::kScatter:
          case HloOpcode::kSelectAndScatter:
          case HloOpcode::kSort:
          case HloOpcode::kFusion:
            // Map/reduce etc. computations are always thread-local.
            worklist.push_back(std::make_pair(subcomputation,
                                              true));  // Thread local.
            break;
          default:
            return InternalError("Unexpected calling opcode: %s",
                                 HloOpcodeString(instruction->opcode()));
        }
      }
    }
  }

  // Add the computations to the vectors in post order.
  for (auto* computation : module->MakeComputationPostOrder()) {
    if (thread_local_set.contains(computation)) {
      thread_local_computations->push_back(computation);
    } else if (global_set.contains(computation)) {
      global_computations->push_back(computation);
    }
    // If the computation is not reachable from the entry computation, then it
    // will not appear in either thread_local_set or global_set. We don't bother
    // assigning buffers for these.
  }
  return Status::OK();
}
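
// A minimal calling sketch (the output vector names are local to the
// example):
//
//   std::vector<const HloComputation*> thread_local_computations;
//   std::vector<const HloComputation*> global_computations;
//   TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
//       module, &thread_local_computations, &global_computations));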

string BufferAllocation::Slice::ToString() const {
  return absl::StrCat("{index:", index(), ", offset:", offset_,
                      ", size:", size_, "}");
}

BufferAllocation::Slice BufferAllocation::GetSlice(
    const HloValue& buffer) const {
  const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
  return Slice(this, os.offset, os.size);
}

void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset,
                                     int64 size) {
  VLOG(4) << "Adding the following buffer to allocation #" << index()
          << absl::StrFormat(" (size=%d, offset=%d) %s", size, offset,
                             buffer.ToShortString());
  CHECK(!assigned_buffers_.contains(&buffer))
      << "LogicalBuffer " << buffer << " already assigned to allocation "
      << index_;
  CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
                          << " offset out of range";
  CHECK_LE(offset + size, size_)
      << "LogicalBuffer " << buffer
      << " size out of range at offset: " << offset << " with size: " << size;
  CHECK_EQ(buffer.color(), color())
      << "Buffer color " << buffer.color() << " for buffer " << buffer
      << " does not match allocation color " << color() << ".";
  OffsetSize offset_size;
  offset_size.offset = offset;
  offset_size.size = size;
  assigned_buffers_.emplace(&buffer, offset_size);
  // For debugging purposes, store the assigned memory space in the
  // instruction's layout.
  for (HloPosition position : buffer.positions()) {
    Shape* shape = ShapeUtil::GetMutableSubshape(
        position.instruction->mutable_shape(), position.index);
    if (shape->has_layout()) {
      shape->mutable_layout()->set_memory_space(buffer.color().value());
    }
  }
}
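
// Worked example of the bounds checks above (illustrative numbers): for an
// allocation of size 1024, assigning a 256-byte buffer at offset 896 fails
// CHECK_LE(offset + size, size_) since 896 + 256 = 1152 > 1024, while offset
// 768 is accepted since 768 + 256 = 1024.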

BufferAllocationProto BufferAllocation::ToProto() const {
  BufferAllocationProto proto;
  proto.set_index(index_);
  proto.set_size(size_);
  proto.set_is_thread_local(is_thread_local_);
  proto.set_is_tuple(is_tuple_);
  proto.set_color(color_.value());
  if (is_entry_computation_parameter_) {
    proto.set_is_entry_computation_parameter(true);
    for (int64 idx : param_shape_index()) {
      proto.add_parameter_shape_index(idx);
    }
    proto.set_parameter_number(parameter_number_);
  }
  proto.set_is_constant(is_constant_);
  proto.set_maybe_live_out(maybe_live_out_);
  for (const auto& buffer_offset_size : assigned_buffers_) {
    BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
    proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
    proto_assigned->set_offset(buffer_offset_size.second.offset);
    proto_assigned->set_size(buffer_offset_size.second.size);
  }
  absl::c_sort(*proto.mutable_assigned(),
               [](const BufferAllocationProto::Assigned& assign1,
                  const BufferAllocationProto::Assigned& assign2) {
                 return assign1.logical_buffer_id() <
                        assign2.logical_buffer_id();
               });
  return proto;
}

static bool CompareHloValuesById(const HloValue* a, const HloValue* b) {
  return a->id() < b->id();
}

// Returns the parameter instruction corresponding to the allocation, or
// nullptr.
static const HloInstruction* GetEntryParameterInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    const HloInstruction* instr = value->instruction();
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parent() == instr->parent()->parent()->entry_computation()) {
      return instr;
    }
  }
  return nullptr;
}

// Returns the root module output instruction corresponding to the allocation,
// or nullptr.
static const HloInstruction* GetOutputInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    for (const HloPosition& position : value->positions()) {
      const HloInstruction* instr = position.instruction;
      if (position.index.empty() &&
          instr->parent()->root_instruction() == instr &&
          instr->parent()->IsEntryComputation()) {
        return instr;
      }
    }
  }
  return nullptr;
}

string BufferAllocation::ToString() const {
  string output;
  StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
  if (color().value() != 0) {
    StrAppend(&output, ", color ", color().value());
  }
  if (is_entry_computation_parameter()) {
    const HloInstruction* param = GetEntryParameterInstruction(*this);
    CHECK(param);
    StrAppend(&output, ", parameter ", parameter_number(), ", shape |",
              param->shape().ToString(/*print_layout=*/false),
              "| at ShapeIndex ", param_shape_index().ToString());
  }
  if (const HloInstruction* instr = GetOutputInstruction(*this)) {
    StrAppend(&output, ", output shape is |",
              instr->shape().ToString(/*print_layout=*/false), "|");
  }
  if (is_constant()) {
    StrAppend(&output, ", constant");
  }
  if (is_thread_local()) {
    StrAppend(&output, ", thread-local");
  }
  if (maybe_live_out()) {
    StrAppend(&output, ", maybe-live-out");
  }
  if (IsPreallocatedTempBuffer()) {
    StrAppend(&output, ", preallocated-temp");
  }
  StrAppend(&output, ":\n");
  // Dump the assigned buffers ordered by id.
  std::vector<const HloValue*> sorted_buffers;
  for (const auto& buffer_offset_size : assigned_buffers_) {
    sorted_buffers.push_back(buffer_offset_size.first);
  }
  absl::c_sort(sorted_buffers, &CompareHloValuesById);
  for (const HloValue* buffer : sorted_buffers) {
    const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
    StrAppend(&output,
              absl::StrFormat(
                  " value: %s (size=%d,offset=%d): %s\n",
                  buffer->ToShortString(), offset_size.size, offset_size.offset,
                  ShapeUtil::HumanStringWithLayout(buffer->shape())));
  }
  return output;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
  out << buffer.ToString();
  return out;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
  out << s.ToString();
  return out;
}

bool BufferAssignment::HasAllocation(const HloValue& value) const {
  return allocation_index_for_value_.contains(&value);
}

bool BufferAssignment::HasAllocation(const HloBuffer& buffer) const {
  return allocation_index_for_value_.contains(buffer.values()[0]);
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloValue& value) const {
  CHECK(HasAllocation(value));
  return GetAllocation(allocation_index_for_value_.at(&value));
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloBuffer& hlo_buffer) const {
  return GetAssignedAllocation(*hlo_buffer.values()[0]);
}

BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
    const HloBuffer& buffer) {
  return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
}

std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  std::set<BufferAllocation::Slice> result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (HasAllocation(*value)) {
      result.insert(GetAssignedAllocation(*value).GetSlice(*value));
    }
  }
  return result;
}

const BufferAllocation& BufferAssignment::GetAllocation(
    BufferAllocation::Index index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, allocations_.size());
  return allocations_[index];
}

const BufferAllocation* BufferAssignment::GetInstructionAllocation(
    const HloInstruction* hlo, const ShapeIndex& shape_index) const {
  const HloValue* value =
      dataflow_analysis().GetValueSet(hlo, shape_index).values()[0];

  if (!HasAllocation(*value)) {
    return nullptr;
  }

  const BufferAllocation& instruction_allocation =
      GetAssignedAllocation(*value);
  return &instruction_allocation;
}

BufferAllocation* BufferAssignment::GetMutableAllocation(
    BufferAllocation::Index index) {
  return const_cast<BufferAllocation*>(&GetAllocation(index));
}

bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
                                       const ShapeIndex& index) const {
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (allocation_index_for_value_.contains(value)) {
      return true;
    }
  }
  return false;
}

bool BufferAssignment::HasTopLevelAllocation(
    const HloInstruction* instruction) const {
  return HasAllocationAt(instruction, /*index=*/{});
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
          << index << "]";
  BufferAllocation::Slice result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    VLOG(3) << "Examining value " << *value;
    if (HasAllocation(*value)) {
      VLOG(3) << "Has allocation";
      const BufferAllocation::Slice slice =
          GetAssignedAllocation(*value).GetSlice(*value);
      if (result.allocation() == nullptr) {
        result = slice;
      } else if (result != slice) {
        return FailedPrecondition(
            "BufferAllocation::Slice for instruction %s at index %s cannot "
            "be determined at compile-time.",
            instruction->name(), index.ToString());
      }
    } else {
      VLOG(3) << "No allocation";
    }
  }
  if (result.allocation() == nullptr) {
    return FailedPrecondition(
        "BufferAllocation::Slice not assigned for instruction %s at index %s",
        instruction->name(), index.ToString());
  }
  return result;
}
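
// A minimal calling sketch (`instr` is a hypothetical instruction pointer):
// the returned StatusOr is an error when the slice is ambiguous or
// unassigned.
//
//   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
//                       assignment.GetUniqueSlice(instr, /*index=*/{}));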

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
    const HloInstruction* instruction) const {
  return GetUniqueSlice(instruction, /*index=*/{});
}

bool BufferAssignment::SharesSliceAtIndex(
    const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
    const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
  return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
         GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
}

bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
                                          const HloInstruction* hlo_b) const {
  using SliceSet = flat_hash_set<BufferAllocation::Slice>;
  // Gets the slices of all of instr's subshapes.  If any subshape doesn't have
  // an assigned slice, returns the empty set.
  auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
    SliceSet slices;
    Status status = ShapeUtil::ForEachSubshapeWithStatus(
        instr->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          auto shape_slices = GetAllSlices(instr, index);
          if (shape_slices.empty()) {
            return InvalidArgument("No slices assigned to part of instr.");
          }
          slices.insert(shape_slices.begin(), shape_slices.end());
          return Status::OK();
        });
    if (!status.ok()) {
      return {};
    }
    return slices;
  };

  SliceSet slices_a = collect_slices(hlo_a);
  SliceSet slices_b = collect_slices(hlo_b);
  // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
  // didn't return the empty set) for both HLOs, and the two resulting sets of
  // slices are disjoint.
  return !slices_a.empty() && !slices_b.empty() &&
         absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
           return slices_b.contains(slice);
         });
}

StatusOr<BufferAllocation::Slice>
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
  return GetUniqueTopLevelSlice(
      module_->entry_computation()->root_instruction());
}

BufferAllocation* BufferAssignment::NewEmptyAllocation(
    int64 size, LogicalBuffer::Color color) {
  BufferAllocation::Index index = allocations_.size();
  allocations_.emplace_back(index, size, color);
  BufferAllocation* allocation = &allocations_.back();
  return allocation;
}

BufferAllocation* BufferAssignment::NewAllocation(const HloBuffer& buffer,
                                                  int64 size) {
  BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
  AddAssignment(allocation, buffer, /*offset=*/0, size);
  allocation->peak_buffers_.push_back(buffer.values()[0]);
  return allocation;
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloBuffer& buffer, int64 offset,
                                     int64 size) {
  CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
      << "Non-reusable allocation already assigned a buffer: "
      << allocation->ToString();

  for (const HloValue* buffer_value : buffer.values()) {
    CHECK(!allocation_index_for_value_.contains(buffer_value))
        << "BufferValue " << buffer_value << " already has an allocation.";
    allocation->AddAssignment(*buffer_value, offset, size);
    allocation_index_for_value_[buffer_value] = allocation->index();
  }

  if (alias_analysis().BufferLivesOut(buffer)) {
    VLOG(3) << "HloBuffer lives out: " << buffer.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloValue& value, int64 offset,
                                     int64 size) {
  allocation->AddAssignment(value, offset, size);
  allocation_index_for_value_[&value] = allocation->index();
  const HloValue& hlo_value =
      *CHECK_NOTNULL(dynamic_cast<const HloValue*>(&value));
  if (alias_analysis().ValueLivesOut(hlo_value)) {
    VLOG(3) << "HloValue lives out: " << hlo_value.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

// Combines allocations of temporary buffers of the same color into one big
// BufferAllocation.
void BufferAssignment::CombineTempAllocations() {
  VLOG(1) << "CombineTempAllocations()";
  flat_hash_map<BufferValue::Color, BufferAllocation,
                BufferValue::Color::Hasher>
      combined_allocation_map;

  // Move all temp allocations into a single run at the end of the allocations
  // vector.
  const auto first_temp_it =
      std::partition(allocations_.begin(), allocations_.end(),
                     [](const BufferAllocation& allocation) {
                       return !allocation.IsPreallocatedTempBuffer();
                     });

  // Walk over the run of temp allocations, collecting the allocations belonging
  // to the same color.
  if (first_temp_it != allocations_.end()) {
    for (auto it = first_temp_it; it != allocations_.end(); ++it) {
      const BufferAllocation& temp_allocation = *it;
      BufferValue::Color color = temp_allocation.color();
      auto combined_it = combined_allocation_map.find(color);
      if (combined_it == combined_allocation_map.end()) {
        // We have found the first temp allocation of this color. Collect
        // the other temp allocations of the same color into it.
        VLOG(1) << "Combined temp allocation for color " << color
                << " is: " << temp_allocation;
        combined_allocation_map.emplace(color, temp_allocation);
        continue;
      }

      auto* combined_allocation = &combined_it->second;
      VLOG(1) << "Combined allocation absorbing temp allocation: "
              << temp_allocation;

      // Each temp allocation is placed end-to-end, accounting for alignment.
      // The offset of each buffer in the combined allocation is computed from
      // the base offset of the allocation.
      int64 alignment = color_alignment_(color);
      const int64 base =
          RoundUpToNearest(combined_allocation->size(), alignment);
      combined_allocation->set_size(base + temp_allocation.size());
      for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) {
        const HloValue* value = buffer_offset_size.first;
        const int64 offset = buffer_offset_size.second.offset;
        const int64 size = buffer_offset_size.second.size;
        combined_allocation->AddAssignment(*value, base + offset, size);
      }
      if (!temp_allocation.HeapTraces().empty()) {
        CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
        combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
      }

      combined_allocation->peak_buffers_.insert(
          combined_allocation->peak_buffers_.end(),
          temp_allocation.peak_buffers_.begin(),
          temp_allocation.peak_buffers_.end());
    }
    // Replace all existing temporary allocations with the new combined
    // allocations.
    allocations_.erase(first_temp_it, allocations_.end());
    for (auto& combined : combined_allocation_map) {
      allocations_.push_back(combined.second);
      temp_allocation_total_size_ += combined.second.size();
    }
  }

  // Update allocation indices to their new positions.
  allocation_index_for_value_.erase(allocation_index_for_value_.begin(),
                                    allocation_index_for_value_.end());
  for (size_t index = 0; index < allocations_.size(); ++index) {
    BufferAllocation* allocation = &allocations_[index];
    allocation->set_index(index);
    for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
      const HloValue* value = buffer_offset_size.first;
      allocation_index_for_value_[value] = index;
    }
  }
}
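
// Worked example of the end-to-end placement above (illustrative numbers):
// with a color alignment of 64 bytes, absorbing a temp allocation of size 100
// into a combined allocation of current size 130 gives
// base = RoundUpToNearest(130, 64) = 192, so the combined allocation grows to
// 192 + 100 = 292 bytes and every absorbed buffer's offset is shifted by 192.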

Status BufferAssignment::ComputeSummaryStats() {
  for (auto& allocation : Allocations()) {
    if (allocation.is_entry_computation_parameter()) {
      stats_.parameter_allocation_count++;
      stats_.parameter_allocation_bytes += allocation.size();
    }
    if (allocation.is_constant()) {
      stats_.constant_allocation_count++;
      stats_.constant_allocation_bytes += allocation.size();
    }
    if (allocation.maybe_live_out()) {
      stats_.maybe_live_out_allocation_count++;
      stats_.maybe_live_out_allocation_bytes += allocation.size();
    }
    if (allocation.IsPreallocatedTempBuffer()) {
      stats_.preallocated_temp_allocation_count++;
      stats_.preallocated_temp_allocation_bytes += allocation.size();
    }
    stats_.total_allocation_count++;
    stats_.total_allocation_bytes += allocation.size();
  }

  // Only compute total fragmentation if all computations have schedules.
  HloSchedule schedule(module_);
  bool schedule_complete = true;
  for (const auto& computation : module_->computations()) {
    if (!computation->IsFusionComputation()) {
      const HloInstructionSequence* sequence =
          hlo_ordering().SequentialOrder(*computation);
      if (sequence == nullptr) {
        schedule_complete = false;
      } else {
        schedule.set_sequence(computation, *sequence);
      }
    }
  }
  if (schedule_complete) {
    TF_RETURN_IF_ERROR(schedule.Verify());
    TF_ASSIGN_OR_RETURN(
        const int64 min_size,
        HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
    stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
  }

  return Status::OK();
}
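
// Worked example of the fragmentation computation (illustrative numbers): if
// total_allocation_bytes is 96 MiB and the heap simulator's lower bound for
// the whole module is 80 MiB, then total_fragmentation_bytes is
// 96 MiB - 80 MiB = 16 MiB, i.e. memory the assignment uses beyond the
// simulated minimum.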

string BufferAssignment::Stats::ToString() const {
  string s;
  StrAppendFormat(&s, "BufferAssignment stats:\n");
  StrAppendFormat(&s, "             parameter allocation: %10s\n",
                  HumanReadableNumBytes(parameter_allocation_bytes));
  StrAppendFormat(&s, "              constant allocation: %10s\n",
                  HumanReadableNumBytes(constant_allocation_bytes));
  StrAppendFormat(&s, "        maybe_live_out allocation: %10s\n",
                  HumanReadableNumBytes(maybe_live_out_allocation_bytes));
  StrAppendFormat(&s, "     preallocated temp allocation: %10s\n",
                  HumanReadableNumBytes(preallocated_temp_allocation_bytes));
  if (preallocated_temp_fragmentation_bytes >= 0) {
    const double percent = 100. * preallocated_temp_fragmentation_bytes /
                           preallocated_temp_allocation_bytes;
    StrAppendFormat(
        &s, "  preallocated temp fragmentation: %10s (%.2f%%)\n",
        HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
  }
  StrAppendFormat(&s, "                 total allocation: %10s\n",
                  HumanReadableNumBytes(total_allocation_bytes));
  if (total_fragmentation_bytes >= 0) {
    const double percent =
        100. * total_fragmentation_bytes / total_allocation_bytes;
    StrAppendFormat(&s, "              total fragmentation: %10s (%.2f%%)\n",
                    HumanReadableNumBytes(total_fragmentation_bytes), percent);
  }
  return s;
}

string BufferAssignment::ToString() const {
  string output;
  absl::StrAppend(&output, "BufferAssignment:\n");
  std::vector<const HloValue*> used_values;
  int64 total_size = 0;
  for (auto& allocation : allocations_) {
    total_size += allocation.size();
    absl::StrAppend(&output, allocation.ToString());
    for (const auto& p : allocation.assigned_buffers()) {
      used_values.push_back(p.first);
    }
  }
  absl::StrAppend(&output, "\nTotal bytes used: ", total_size, "\n");
  absl::StrAppend(&output, "\nUsed values:\n");
  absl::c_sort(used_values, &CompareHloValuesById);
  for (const HloValue* value : used_values) {
    absl::StrAppend(&output, value->ToString());
  }
  return output;
}

BufferAssignmentProto BufferAssignment::ToProto() const {
  BufferAssignmentProto proto;
  // NOTE: DataflowAnalysis state is serialized here in BufferAssignment,
  // because we need to do the HasAllocation check for each buffer. Otherwise
  // the buffer_size_ call might fail for some backends.
  const HloDataflowAnalysis& dataflow = this->dataflow_analysis();
  for (BufferValue::Id id = 0; id < dataflow.values().size(); id++) {
    auto& value = dataflow.values().at(id);
    if (HasAllocation(*value)) {
      LogicalBufferProto proto_buffer = value->ToProto(buffer_size_);
      proto.add_logical_buffers()->Swap(&proto_buffer);

      // Fill buffer aliases.
      for (const HloValue* alias :
           alias_analysis().GetBufferContainingValue(*value).values()) {
        if (alias->instruction() == value->instruction() &&
            alias->index() == value->index()) {
          continue;  // skip self-aliases
        }
        BufferAssignmentProto::BufferAlias* proto_alias =
            proto.add_buffer_aliases();
        LogicalBufferProto::Location proto_alias_location =
            BufferValue::ToLocationProto(*alias->instruction(), alias->index());
        proto_alias->set_source_buffer_id(value->id());
        proto_alias->mutable_location()->Swap(&proto_alias_location);
      }
    }
  }
  for (const BufferAllocation& allocation : Allocations()) {
    BufferAllocationProto proto_allocation = allocation.ToProto();
    proto.add_buffer_allocations()->Swap(&proto_allocation);
    for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
      *proto.add_heap_simulator_traces() = heap_trace;
    }
  }
  return proto;
}

/* static */
StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer,
    const absl::flat_hash_set<HloOpcode>& reuse_checker,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer,
    std::unique_ptr<PresetAssignments> preset_assignments) {
  BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
                          reuse_checker, std::move(preset_assignments));
  return assigner.CreateAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(can_share_buffer));
}
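
// A minimal calling sketch (the size and alignment functions are
// hypothetical; assumes the module has a schedule so a SequentialHloOrdering
// can be built from it, and that the trailing parameters have the defaults
// declared in buffer_assignment.h):
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<BufferAssignment> assignment,
//       BufferAssigner::Run(
//           module,
//           absl::make_unique<SequentialHloOrdering>(module->schedule()),
//           [](const BufferValue& buf) {
//             return ShapeUtil::ByteSizeOf(buf.shape(), /*pointer_size=*/8);
//           },
//           [](LogicalBuffer::Color) { return int64{64}; },  // alignment
//           /*allocate_buffers_for_constants=*/true));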

bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1,
                                         const HloValue* buffer2,
                                         BufferAssignment* assignment) {
  CHECK((assignment->hlo_live_range().total_order_scheduled()));
  const HloLiveRange& hlo_live_range = assignment->hlo_live_range();

  const auto& buffer_live_ranges = hlo_live_range.buffer_live_ranges();

  CHECK(buffer_live_ranges.contains(buffer1))
      << "Buffer doesn't have a proper live range: " << buffer1;

  CHECK(buffer_live_ranges.contains(buffer2))
      << "Buffer doesn't have a proper live range: " << buffer2;

  // Check if a user value can share the same buffer as its operand.
  auto can_share_as_operand = [&assignment](const HloValue* user_value,
                                            const HloValue* operand_value) {
    return user_value->instruction()->IsUserOf(operand_value->instruction()) &&
           assignment->dataflow_analysis().CanShareOperandBufferWithUser(
               operand_value->instruction(), operand_value->index(),
               user_value->instruction(), user_value->index()) &&
           user_value->instruction()->opcode() != HloOpcode::kCopy;
  };

  auto live_range_1 = buffer_live_ranges.at(buffer1);
  auto live_range_2 = buffer_live_ranges.at(buffer2);

  if (!(live_range_1.start > live_range_2.end ||
        live_range_2.start > live_range_1.end)) {
    if (live_range_1.end == live_range_2.start) {
      auto operand_value = buffer1;
      auto user_value = buffer2;
      if (!can_share_as_operand(user_value, operand_value)) {
        VLOG(4) << "End of live range of " << buffer1->ToShortString()
                << " is equal to the start of live range of "
                << buffer2->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else if (live_range_2.end == live_range_1.start) {
      auto operand_value = buffer2;
      auto user_value = buffer1;
      if (!can_share_as_operand(user_value, operand_value)) {
        VLOG(4) << "End of live range of " << buffer2->ToShortString()
                << " is equal to the start of live range of "
                << buffer1->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else {
      VLOG(4) << "Can't assign: assignee " << *buffer1 << " may interfere with "
              << *buffer2;
      VLOG(4) << "assigned_buffer.start: " << live_range_1.start;
      VLOG(4) << "assigned_buffer.end: " << live_range_1.end;
      VLOG(4) << "live_range_2.start: " << live_range_2.start;
      VLOG(4) << "live_range_2.end: " << live_range_2.end;
      return true;
    }
  }
  return false;
}
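
// Worked example of the interval logic above (illustrative positions in the
// total order): live ranges [10, 20] and [25, 30] satisfy the disjointness
// test (25 > 20), so they do not interfere. Ranges [10, 20] and [20, 30]
// touch at a single point; that is tolerated only when the later value can
// share its operand's buffer (and is not a kCopy). Any wider overlap, e.g.
// [10, 20] and [15, 30], interferes.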

bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
                                       const HloBuffer& hlo_buffer,
                                       BufferAssignment* assignment) {
  CHECK(!assignment->HasAllocation(hlo_buffer))
      << "buffer " << hlo_buffer << " already has an allocation assigned.";

  VLOG(4) << "Trying to assign " << hlo_buffer << " size "
          << assignment->HloBufferSize(hlo_buffer)
          << " to allocation: " << *allocation;

  if (hlo_buffer.color() != allocation->color()) {
    VLOG(4) << "Can't assign: buffer has color " << hlo_buffer.color()
            << " and allocation has color " << allocation->color() << ".";
    return false;
  }

  if (assignment->HloBufferSize(hlo_buffer) > allocation->size()) {
    VLOG(4) << "Can't assign: buffer is larger than allocation ("
            << assignment->HloBufferSize(hlo_buffer) << " > "
            << allocation->size() << ")";
    return false;
  }

  if (allocation->is_readonly()) {
    VLOG(4) << "Can't assign: allocation is readonly";
    return false;
  }

  if (!must_not_live_out_.empty()) {
    if (allocation->maybe_live_out()) {
      // If a buffer may be live out, the allocation cannot contain any node
      // from the "must_not_live_out_" set.
      for (const HloValue* value : hlo_buffer.values()) {
        if (must_not_live_out_.count(value->instruction()->opcode()) > 0) {
          VLOG(4) << "Can't assign: " << value->instruction()->ToString()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
    // The above check is not enough: an allocation that is currently not live
    // out may already contain an instruction whose opcode is in the
    // "must_not_live_out_" set; assigning a live-out buffer to that
    // allocation would make the allocation live out while it still contains
    // such an instruction.
    if (assignment->alias_analysis().BufferLivesOut(hlo_buffer)) {
      for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
        if (must_not_live_out_.count(
                buffer_offset_size.first->instruction()->opcode()) > 0) {
          VLOG(4) << "Can't assign: " << buffer_offset_size.first->instruction()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
  }

  if (!allocation->is_reusable()) {
    VLOG(4) << "Can't assign: allocation is not reusable";
    return false;
  }

  for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
    // Pairwise compare.
    const HloValue& assigned_buffer =
        *CHECK_NOTNULL(dynamic_cast<const HloValue*>(buffer_offset_size.first));
    for (const HloValue* new_value : hlo_buffer.values()) {
      if (assignment->hlo_live_range().total_order_scheduled()) {
        if (LiveRangeInterferes(new_value, &assigned_buffer, assignment)) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " live range interferes with "
                  << new_value->ToShortString();
          return false;
        }
      } else if (assignment->hlo_ordering().MayInterfere(
                     assigned_buffer, *new_value,
                     assignment->dataflow_analysis())) {
        // Fall back to partial-order based interference detection (slower)
        // when we don't have a total order scheduled module.
        VLOG(4) << "Can't assign: assignee " << assigned_buffer
                << " may interfere with " << new_value->ToShortString();
        return false;
      }

      for (const HloPosition& assigned_buffer_position :
           assigned_buffer.positions()) {
        // Copy instructions don't share a buffer with their input operand.
        if (new_value->instruction()->IsUserOf(
                assigned_buffer_position.instruction) &&
            new_value->instruction()->opcode() == HloOpcode::kCopy) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " is used at copy instruction "
                  << new_value->ToShortString();
          return false;
        }
      }
    }
  }

  // If the buffer is live out of the computation then it should only be
  // assigned a buffer which exactly fits the result to avoid wasting memory
  // (result buffers can have arbitrary lifetimes).
  if (assignment->alias_analysis().BufferLivesOut(hlo_buffer) &&
      allocation->size() != assignment->HloBufferSize(hlo_buffer)) {
    VLOG(4) << "Can't assign: buffer " << hlo_buffer
            << " is live out and its size does not match the allocation";
    return false;
  }

  assignment->AddAssignment(allocation, hlo_buffer, /*offset=*/0,
                            assignment->HloBufferSize(hlo_buffer));
  return true;
}

Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
  // Try to allocate the same buffer for a dynamic-update-slice's operand and
  // output.
  //
  // TODO(yunxing): Move this logic into alias analysis and add a must-alias
  // rule for operations that can be done in place.
  for (HloComputation* computation : assignment->module().computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (!(instruction->opcode() == HloOpcode::kDynamicUpdateSlice ||
            (instruction->opcode() == HloOpcode::kFusion &&
             (instruction->fused_expression_root()->opcode() ==
              HloOpcode::kDynamicUpdateSlice)))) {
        continue;
      }
      if (instruction->parent()->IsFusionComputation()) {
        continue;
      }
      if (instruction->operand_count() == 0) {
        continue;
      }

      // Skip if the operand can't share the buffer with the user, based on
      // dataflow analysis.
      if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser(
              instruction->mutable_operand(0), {}, instruction, {})) {
        continue;
      }
      HloBuffer& instruction_buffer =
          assignment->alias_analysis().GetUniqueBufferAt(instruction, {});

      HloBuffer& operand_buffer =
          assignment->alias_analysis().GetUniqueBufferAt(
              instruction->operand(0), {});

      // They already share the same buffer. No need to merge.
      if (instruction_buffer.id() == operand_buffer.id()) {
        continue;
      }

      // Do not perform in-place dynamic update slice if the operand buffer is
      // read-only.
      if (HloBufferIsReadOnly(operand_buffer)) {
        continue;
      }

      bool interfere = false;

      for (const HloValue* instruction_value : instruction_buffer.values()) {
        for (const HloValue* operand_value : operand_buffer.values()) {
          if (assignment->hlo_ordering().MayInterfere(
                  *instruction_value, *operand_value,
                  assignment->dataflow_analysis())) {
            interfere = true;
            break;
          }
        }
      }
      if (interfere) {
        continue;
      }
      if (assignment->alias_analysis().BufferLivesOut(instruction_buffer)) {
        continue;
      }
      if (instruction_buffer.color() != operand_buffer.color()) {
        continue;
      }
      VLOG(3) << "Merging inplace " << instruction_buffer << " and "
              << operand_buffer;
      assignment->alias_analysis().MergeBuffers(instruction_buffer,
                                                operand_buffer);
    }
  }
  return Status::OK();
}
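
// In short: for out = dynamic-update-slice(operand, update, indices) (or a
// fusion rooted at one), merging the operand and output buffers lets the
// update happen in place, but only when dataflow analysis permits sharing,
// the operand buffer is not read-only, the values do not interfere, the
// output buffer does not live out, and the two buffers have the same color.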

Status BufferAssigner::AssignSingleHloBuffer(
    const HloBuffer* hlo_buffer, bool is_thread_local,
    absl::flat_hash_map<const HloComputation*,
                        absl::flat_hash_set<const HloValue*>>*
        buffers_to_assign_sequentially,
    std::vector<BufferAllocation::Index>* allocation_indices,
    BufferAssignment* assignment) {
  const int64 buffer_size = assignment->HloBufferSize(*hlo_buffer);
  for (const HloValue* value : hlo_buffer->values()) {
    if (value->instruction()->opcode() == HloOpcode::kConstant) {
      if (allocate_buffers_for_constants_) {
        BufferAllocation* allocation =
            assignment->NewAllocation(*hlo_buffer, buffer_size);
        allocation->set_constant(true);
        VLOG(3) << "New allocation #" << allocation->index() << " for constant "
                << *hlo_buffer << " value ptr: " << value;
      } else {
        VLOG(3) << "Not allocating buffer for constant";
      }
      return Status::OK();
    }

    const HloInstruction* instruction = value->instruction();
    const bool is_entry_parameter =
        instruction->opcode() == HloOpcode::kParameter &&
        instruction->parent() ==
            instruction->parent()->parent()->entry_computation();

    if (is_entry_parameter) {
      bool parameter_has_alias =
          assignment->module().input_output_alias_config().ParameterHasAlias(
              instruction->parameter_number(), value->index());
      // If the HLO buffer is part of an external parameter, create a new
      // allocation and set its parameter number. Parameters of non-entry
      // computations do not need special allocations because they live inside
      // callers.
      BufferAllocation* allocation =
          assignment->NewAllocation(*hlo_buffer, buffer_size);

      allocation->set_entry_computation_parameter(
          instruction->parameter_number(), value->index(), parameter_has_alias);
      if (parameter_has_alias) {
        allocation_indices->push_back(allocation->index());
      }
      VLOG(3) << "New allocation #" << allocation->index()
              << " marked as entry computation parameter: " << *hlo_buffer;
      return Status::OK();
    }
  }

  if (is_thread_local) {
    BufferAllocation* allocation =
        assignment->NewAllocation(*hlo_buffer, buffer_size);
    allocation->set_is_thread_local(true);
    VLOG(3) << "New allocation #" << allocation->index()
            << " for thread-local: " << *hlo_buffer;
    return Status::OK();
  }

  for (const HloValue* value : hlo_buffer->values()) {
    if (value->shape().IsTuple()) {
      BufferAllocation* allocation =
          assignment->NewAllocation(*hlo_buffer, buffer_size);
      allocation->set_is_tuple(true);
      VLOG(3) << "New allocation #" << allocation->index()
              << " for tuple-shaped buffer: " << *hlo_buffer;
      return Status::OK();
    }

    if (value->IsTopLevel() && !value->IsTuple()) {
      const HloInstruction* instruction = value->instruction();
      for (auto* operand : instruction->operands()) {
        for (const auto& operand_slice :
             assignment->GetAllSlices(operand, /*index=*/{})) {
          BufferAllocation* allocation =
              assignment->GetMutableAllocation(operand_slice.index());
          if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
            VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
                    << " for: " << *hlo_buffer;
            return Status::OK();
          }
        }
      }
    }
  }

  // Find the smallest buffer which can be reused, iterating from the end of
  // allocation_indices (smallest) to the beginning (largest).
  for (int allocation_index = allocation_indices->size() - 1;
       allocation_index >= 0; allocation_index--) {
    BufferAllocation* allocation = assignment->GetMutableAllocation(
        allocation_indices->at(allocation_index));
    if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
      VLOG(3) << "Reusing allocation #" << allocation->index()
              << " for: " << *hlo_buffer;
      return Status::OK();
    }
  }

  if (!assignment->HasAllocation(*hlo_buffer) &&
      !assignment->alias_analysis().BufferLivesOut(*hlo_buffer)) {
    bool all_computations_have_sequential_order = true;
    for (const HloValue* hlo_value : hlo_buffer->values()) {
      HloComputation* computation = hlo_value->instruction()->parent();
      const bool has_sequential_order =
          assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
      all_computations_have_sequential_order &= has_sequential_order;
    }

    if (all_computations_have_sequential_order) {
      for (const HloValue* hlo_value : hlo_buffer->values()) {
        HloComputation* computation = hlo_value->instruction()->parent();
        // There is a sequential instruction ordering, so we delay assignment
        // of temp buffers until after the loop. We do this right before we
        // decide to create a new allocation, to ensure we've exhausted all
        // the buffer re-use cases above.
        //
        // Entry parameters and thread local buffers were already handled
        // earlier in this loop iteration.  See
        // BufferAllocation::IsPreallocatedTempBuffer for the definition of
        // temp buffers.
        (*buffers_to_assign_sequentially)[computation].insert(hlo_value);
        VLOG(3) << "Delaying assignment of temp buffer: " << *hlo_value;
      }
      return Status::OK();
    }
  }

  if (!assignment->HasAllocation(*hlo_buffer)) {
    BufferAllocation* allocation =
        assignment->NewAllocation(*hlo_buffer, buffer_size);
    allocation_indices->push_back(allocation->index());
    VLOG(3) << "New allocation #" << allocation->index()
            << " for: " << *hlo_buffer;
  }

  TF_RET_CHECK(assignment->HasAllocation(*hlo_buffer));
  return Status::OK();
}
1221 
AssignBuffersForComputations(const std::vector<const HloComputation * > & computations,bool is_thread_local,absl::flat_hash_map<const HloComputation *,absl::flat_hash_set<const HloValue * >> * buffers_to_assign_sequentially,BufferAssignment * assignment)1222 Status BufferAssigner::AssignBuffersForComputations(
1223     const std::vector<const HloComputation*>& computations,
1224     bool is_thread_local,
1225     absl::flat_hash_map<const HloComputation*,
1226                         absl::flat_hash_set<const HloValue*>>*
1227         buffers_to_assign_sequentially,
1228     BufferAssignment* assignment) {
1229   if (computations.empty()) {
1230     return Status::OK();
1231   }
1232   std::vector<const HloBuffer*> sorted_buffers;
1233 
1234   // First assign the preset allocations.
1235   absl::flat_hash_set<const HloBuffer*> preset_assigned_buffers;
1236 
1237   TF_RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment));
1238 
1239   const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
1240 
1241   for (const HloBuffer& buffer : alias_analysis.buffers()) {
1242     // Skip if the buffer is already assigned since it had a preset allocation.
1243     if (preset_assigned_buffers.find(&buffer) !=
1244         preset_assigned_buffers.end()) {
1245       VLOG(3) << "Skip allocation for buffer: " << buffer;
1246       continue;
1247     }
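    // Attribute the buffer to the computation defining its first value; only
    // buffers defined within 'computations' are assigned in this pass.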
    TF_RET_CHECK(!buffer.values().empty());
    const HloComputation* comp = buffer.values()[0]->instruction()->parent();
    if (absl::c_linear_search(computations, comp)) {
      sorted_buffers.push_back(&buffer);
    }
  }

  // Generate a post-order numbering of instructions, used below when sorting
  // the HloBuffers.
  flat_hash_map<const HloInstruction*, int> post_order_position;
  int position = 0;
  std::vector<const HloComputation*> reverse_post_order_computations;
  std::unique_ptr<CallGraph> call_graph =
      CallGraph::Build(computations[0]->parent());
  TF_RETURN_IF_ERROR(call_graph->VisitNodes([&](const CallGraphNode& node) {
    if (absl::c_linear_search(computations, node.computation())) {
      reverse_post_order_computations.push_back(node.computation());
    }
    return Status::OK();
  }));
  absl::c_reverse(reverse_post_order_computations);
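  // The call graph's post order visits callees before callers; reversing it
  // gives callers first, so instructions in outer computations receive
  // smaller positions below.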
  for (auto* computation : reverse_post_order_computations) {
    for (auto* instruction : computation->MakeInstructionPostOrder()) {
      post_order_position.emplace(instruction, position);
      position++;
    }
  }

  HloSchedule schedule(&assignment->module());

  for (const HloComputation* computation : computations) {
    const HloInstructionSequence* instruction_sequence =
        assignment->hlo_ordering().SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
      // Every sequential computation must get an entry in the
      // buffers_to_assign_sequentially map, even if we end up with an empty
      // set of buffers. This ensures we can correctly determine whether to
      // run whole-module heap simulation.
      buffers_to_assign_sequentially->emplace(computation,
                                              flat_hash_set<const HloValue*>());

      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

  absl::c_sort(
      sorted_buffers, [&post_order_position, &alias_analysis, assignment](
                          const HloBuffer* a, const HloBuffer* b) {
        // Primary sort is by decreasing buffer size.
        const int64 a_size = assignment->HloBufferSize(*a);
        const int64 b_size = assignment->HloBufferSize(*b);
        if (a_size != b_size) {
          return a_size > b_size;  // use ">" for decreasing size.
        }

        const bool a_live_out = alias_analysis.BufferLivesOut(*a);
        const bool b_live_out = alias_analysis.BufferLivesOut(*b);
        if (a_live_out != b_live_out) {
          return a_live_out;
        }
        auto compare = [&post_order_position](const HloValue* value1,
                                              const HloValue* value2) {
          return post_order_position.at(value1->instruction()) <
                 post_order_position.at(value2->instruction());
        };
        const HloValue* a_min = *absl::c_min_element(a->values(), compare);
        const HloValue* b_min = *absl::c_min_element(b->values(), compare);
        return compare(a_min, b_min);
      });
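  // Tie-breaking within equal sizes: a live-out buffer sorts before a
  // non-live-out one, and otherwise the buffer whose earliest defining
  // instruction has the smaller post-order position comes first.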

  std::vector<BufferAllocation::Index> allocation_indices;

  for (const HloBuffer* buffer : sorted_buffers) {
    VLOG(3) << "=================================================";
    VLOG(3) << "Assigning buffer for " << *buffer;
    TF_RETURN_IF_ERROR(AssignSingleHloBuffer(buffer, is_thread_local,
                                             buffers_to_assign_sequentially,
                                             &allocation_indices, assignment));
  }
  return Status::OK();
}

flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>,
              LogicalBuffer::Color::Hasher>
BufferAssigner::SplitBuffersByColor(
    const flat_hash_set<const HloValue*>& buffers) {
  flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>,
                LogicalBuffer::Color::Hasher>
      color_map;
  for (auto buffer : buffers) {
    color_map[buffer->color()].insert(buffer);
  }
  return color_map;
}

Status BufferAssigner::AssignPresetBuffers(
    absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
    BufferAssignment* assignment) {
  if (!preset_assignments_) {
    return Status::OK();
  }

  // Create an allocation for each preset color.
  absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*,
                      LogicalBuffer::Color::Hasher>
      preset_allocations;
  for (auto& color_and_info : preset_assignments_->assignment_informations()) {
    LogicalBuffer::Color color(color_and_info.first);
    auto inserted = preset_allocations.emplace(
        color,
        assignment->NewEmptyAllocation(color_and_info.second.size, color));
    BufferAllocation* inserted_allocation = inserted.first->second;
    inserted_allocation->AddHeapTrace(
        color_and_info.second.heap_simulator_trace);
    VLOG(3) << "Created preset buffer allocation "
            << inserted_allocation->index()
            << ", color: " << inserted_allocation->color()
            << ", size: " << inserted_allocation->size();
  }

  const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
  const HloDataflowAnalysis& dataflow_analysis =
      alias_analysis.dataflow_analysis();

  for (auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    const HloValue& value = dataflow_analysis.GetUniqueValueAt(
        position.instruction, position.index);
    VLOG(3) << "Preset allocation for value: " << value.ToShortString();
    const HeapSimulator::Chunk& chunk = position_and_chunk.second;
    auto preset_allocations_iter = preset_allocations.find(value.color());
    CHECK(preset_allocations_iter != preset_allocations.end())
        << "No preset value allocation for color " << value.color() << " for "
        << value.ToShortString() << " found.";
    preset_allocations_iter->second->AddAssignment(value, chunk.offset,
                                                   chunk.size);

    const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(value);
    assigned_buffers->insert(&buffer);
  }

  // The preset assignments have now been consumed; clear them so that a
  // second call to this method does not assign the same buffers again.
  preset_assignments_ = {};

  return Status::OK();
}

Status BufferAssigner::AssignBuffersWithSequentialOrdering(
    const flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>&
        buffers_to_assign_sequentially,
    bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
  // Run the sequence of instructions through the heap simulator. The
  // heuristic used below is global decreasing-size best-fit, tried in both
  // its spatial and temporal flavors, keeping whichever result is better.
  const HloOrdering& hlo_ordering = assignment->hlo_ordering();

  // Returns a heap algorithm that chooses the best result from several
  // algorithms.
  auto get_heap_algorithm = [&](int64 alignment) {
    auto algorithms =
        absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
    algorithms->push_back(absl::make_unique<GlobalDecreasingSizeBestFitHeap>(
        alignment, GlobalDecreasingSizeBestFitHeap::kSpatial));
    algorithms->push_back(absl::make_unique<GlobalDecreasingSizeBestFitHeap>(
        alignment, GlobalDecreasingSizeBestFitHeap::kTemporal));
    return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms));
  };

  if (run_whole_module_heap_simulation) {
    // Run the heap simulation over the whole module. This reduces memory
    // usage, since buffers for kCall, kWhile, and kConditional
    // sub-computations are only live for the duration of their calling
    // instructions.
    VLOG(1) << "Running whole-module heap simulation";
    HloSchedule schedule(&assignment->module());
    flat_hash_set<const HloValue*> all_buffers_to_assign;
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      schedule.set_sequence(computation, *instruction_sequence);
      all_buffers_to_assign.insert(buffers_to_assign.begin(),
                                   buffers_to_assign.end());
    }
    auto color_map = SplitBuffersByColor(all_buffers_to_assign);
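    // Each color is simulated separately: buffers of different colors never
    // share an allocation, and alignment is defined per color.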
    for (auto& single_colored_set : color_map) {
      auto color = single_colored_set.first;
      VLOG(2) << "Simulating heap for color " << color;
      int64 alignment = assignment->color_alignment_(color);
      HeapSimulator::Options options;
      options.alloc_constants = allocate_buffers_for_constants_;
      options.buffers_to_assign = &single_colored_set.second;

      TF_ASSIGN_OR_RETURN(
          HeapSimulator::Result result,
          HeapSimulator::Run(
              get_heap_algorithm(alignment), assignment->module(), schedule,
              assignment->alias_analysis(), assignment->buffer_size_, options));
      AssignBuffersFromHeapSimulator(result, assignment,
                                     single_colored_set.first);
    }
  } else {
    // Run the heap-simulation on a per-computation basis. Buffers for
    // sub-computations are assigned disjoint BufferAllocations, assuming the
    // worst-case that they may all be live concurrently.
    VLOG(1) << "Running per-computation heap simulation";
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      auto color_map = SplitBuffersByColor(buffers_to_assign);
      for (auto& single_colored_set : color_map) {
        auto color = single_colored_set.first;
        VLOG(2) << "Simulating heap for color " << color;
        int64 alignment = assignment->color_alignment_(color);
        HeapSimulator::Options options;
        options.buffers_to_assign = &single_colored_set.second;
        TF_ASSIGN_OR_RETURN(
            HeapSimulator::Result result,
            HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
                               *instruction_sequence,
                               assignment->alias_analysis(),
                               assignment->buffer_size_, options));
        AssignBuffersFromHeapSimulator(result, assignment,
                                       single_colored_set.first);
      }
    }
  }
  return Status::OK();
}

namespace {
// Computes and returns the set of logical buffers live at the point of
// maximal liveness in the given heap trace. LogicalBuffers are (stably)
// sorted by id.
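//
// For example (illustrative ids and sizes only), the trace
//   ALLOC(0, 8), ALLOC(1, 4), FREE(0), ALLOC(2, 8)
// first reaches its peak of 12 live bytes right after ALLOC(1), so the
// values with ids {0, 1} are returned, sorted by id.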
std::vector<const HloValue*> ComputePeakMemoryLogicalBuffers(
    const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
  // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical
  // buffers in this allocation.
  absl::flat_hash_map<BufferValue::Id, const HloValue*> id_to_value;
  absl::flat_hash_map<const HloValue*, int64> buffer_sizes;
  for (const auto& pair : allocation.assigned_buffers()) {
    const HloValue* value = pair.first;
    const BufferAllocation::OffsetSize& offset_size = pair.second;
    id_to_value[value->id()] = value;
    buffer_sizes[value] = offset_size.size;
  }
  VLOG(1) << "Compute peak memory logical buffers";

  // Returns how much the given event increases the total size of live
  // buffers. Can be negative.
  auto memory_delta = [&id_to_value, &buffer_sizes](
                          const HeapSimulatorTrace::Event& event) -> int64 {
    const HloValue* buffer = id_to_value.at(event.buffer_id());
    const int64 buffer_size = buffer_sizes.at(buffer);
    if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
        event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
      return buffer_size;
    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
      return -1 * buffer_size;
    }
    LOG(FATAL) << "Unknown event kind: " << event.kind();
  };

  // First compute the size of the maximal live set.
  int64 max_live_size = 0;
  int64 live_size = 0;
  for (const auto& event : heap_trace.events()) {
    live_size += memory_delta(event);
    if (max_live_size < live_size) {
      max_live_size = live_size;
    }
  }

  // Next gather the set of logical buffers live at the earliest point of
  // maximal live set size.
  absl::flat_hash_set<const HloValue*> live_values;
  live_size = 0;
  for (const auto& event : heap_trace.events()) {
    const HloValue* value = id_to_value.at(event.buffer_id());
    if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
        event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
      InsertOrDie(&live_values, value);
    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
      CHECK(ContainsKey(live_values, value));
      live_values.erase(value);
    }
    live_size += memory_delta(event);

    if (live_size == max_live_size) {
      break;
    }
  }
  CHECK_EQ(live_size, max_live_size);

  std::vector<const HloValue*> live_values_vector;
  live_values_vector.insert(live_values_vector.end(), live_values.begin(),
                            live_values.end());

  // Stably sort the live buffers by id.
  absl::c_sort(live_values_vector, [](const HloValue* a, const HloValue* b) {
    return a->id() < b->id();
  });
  VLOG(4) << "Peak memory buffer:";
  for (auto value : live_values_vector) {
    VLOG(4) << "  " << value->ToString();
  }
  return live_values_vector;
}

}  // namespace

void BufferAssigner::AssignBuffersFromHeapSimulator(
    const HeapSimulator::Result& result, BufferAssignment* assignment,
    BufferValue::Color color) {
  if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
    assignment->stats_.preallocated_temp_fragmentation_bytes =
        result.fragmentation_size;
  } else {
    assignment->stats_.preallocated_temp_fragmentation_bytes +=
        result.fragmentation_size;
  }
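  // preallocated_temp_fragmentation_bytes starts out as -1 ("unset"): the
  // first simulated color sets it, and subsequent colors accumulate into it.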
  VLOG(1) << "Result size from heap simulator: " << result.heap_size;

  BufferAllocation* allocation =
      assignment->NewEmptyAllocation(result.heap_size, color);
  for (const auto& buffer_chunk : result.chunk_map) {
    const HloValue& value = *buffer_chunk.first;
    const HeapSimulator::Chunk& chunk = buffer_chunk.second;
    assignment->AddAssignment(allocation, value, chunk.offset, chunk.size);
  }
  allocation->peak_buffers_ =
      ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);

  VLOG(1) << "Ran heap simulation for allocation: ";
  XLA_VLOG_LINES(2, allocation->ToString());

  allocation->AddHeapTrace(result.debug_trace);
}

StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer));

  // Set up a schedule for each computation.
  HloSchedule schedule(module);
  for (const HloComputation* computation : module->computations()) {
    const HloInstructionSequence* instruction_sequence =
        hlo_ordering->SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order) {
      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, *alias_analysis,
                                        module->entry_computation(), true));

  VLOG(1) << "Assigning buffers to module " << module->name();
  XLA_VLOG_LINES(3, module->ToString());
  XLA_VLOG_LINES(3, alias_analysis->ToString());
  XLA_VLOG_LINES(3, alias_analysis->dataflow_analysis().ToString());
  VLOG(1) << "Number of buffers to assign: "
          << alias_analysis->buffers().size();

  // Can't use absl::make_unique because BufferAssignment constructor is
  // private.
  std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(alias_analysis),
      std::move(hlo_live_range)));

  TF_RETURN_IF_ERROR(
      colorer_(&assignment->alias_analysis(), assignment->hlo_ordering()));
  VLOG(3) << "After coloring:";
  XLA_VLOG_LINES(3,
                 assignment->alias_analysis().dataflow_analysis().ToString());
  TF_RETURN_IF_ERROR(MergeInplaceOpBuffers(assignment.get()));

  std::vector<const HloComputation*> thread_local_computations;
  std::vector<const HloComputation*> global_computations;
  TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
      module, &thread_local_computations, &global_computations));

  // First assign buffers for global computations. Temporary buffers for
  // sequential computations are collected in
  // 'buffers_to_assign_sequentially'.
  flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>
      buffers_to_assign_sequentially;
  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      global_computations,
      /*is_thread_local=*/false, &buffers_to_assign_sequentially,
      assignment.get()));
  // Assign buffers with sequential ordering, if any. If all global
  // computations are sequential, we can run heap simulation on the whole
  // module, which reduces memory usage.
  const bool run_whole_module_heap_simulation =
      buffers_to_assign_sequentially.size() == global_computations.size();
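  // This size comparison works because AssignBuffersForComputations inserted
  // an entry (possibly an empty set) for every sequential computation, so
  // equality holds exactly when all global computations are sequential.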
  VLOG(2) << "Running whole module heap simulation: "
          << run_whole_module_heap_simulation;
  TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
      buffers_to_assign_sequentially, run_whole_module_heap_simulation,
      assignment.get()));

  // Now assign buffers for thread-local computations. All LogicalBuffers get
  // their own BufferAllocation; fusion computations are excluded.
  std::vector<const HloComputation*> thread_local_computations_no_fusion;
  for (auto* computation : thread_local_computations) {
    TF_RET_CHECK(computation != module->entry_computation());
    if (computation->IsFusionComputation()) {
      continue;
    }
    thread_local_computations_no_fusion.push_back(computation);
  }

  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      thread_local_computations_no_fusion,
      /*is_thread_local=*/true,
      /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));

  // Mark all buffers which may be live out of the entry computation as
  // "liveout".
  for (const HloBuffer* buffer :
       assignment->alias_analysis().LiveOutBuffers()) {
    VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
    if (assignment->HasAllocation(*buffer)) {
      BufferAllocation* alloc =
          assignment->GetMutableAssignedAllocation(*buffer);
      alloc->set_maybe_live_out(true);
      VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
    }
  }

  // Combines allocations of temporary buffers into one big BufferAllocation.
  // This can only be performed after all buffers have been assigned, and
  // after maybe_live_out is marked, since it is used to determine whether an
  // allocation contains temporary buffers or not.
  assignment->CombineTempAllocations();

  XLA_VLOG_LINES(2, assignment->ToString());
  TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
  XLA_VLOG_LINES(1, assignment->GetStats().ToString());
  VLOG(1) << "Buffer assignment done.";
  return std::move(assignment);
}

}  // namespace xla