/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Defines the data returned by the XLA buffer assignment packages.

#include "tensorflow/compiler/xla/service/buffer_assignment.h"

#include <algorithm>
#include <deque>
#include <numeric>
#include <ostream>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"

namespace xla {
namespace {

using absl::flat_hash_map;
using absl::flat_hash_set;
using absl::StrAppend;
using absl::StrAppendFormat;
using ::tensorflow::strings::HumanReadableNumBytes;

// Given the interference map of a graph (the list of interfering node indices
// for each node), perform graph coloring such that interfering nodes are
// assigned to different colors. Returns the assigned colors of the nodes,
// where colors are represented as integer values in [0, color_count).
std::vector<int64> ColorInterferenceGraph(
    const std::vector<std::vector<int64>>& interference_map) {
  const int64 node_count = interference_map.size();

  // Sort the nodes such that we assign nodes with more interference first.
  // This relies on the common heuristic of assigning the most constrained node
  // first, but it would be good to investigate other ordering heuristics too.
  std::vector<int64> nodes(node_count);
  std::iota(nodes.begin(), nodes.end(), 0);
  absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) {
    return interference_map[i].size() > interference_map[j].size();
  });

  const int64 kColorUnassigned = -1;
  std::vector<int64> assigned_colors(node_count, kColorUnassigned);
  for (int64 node : nodes) {
    // Mark the colors that are already assigned to the neighbors.
    std::vector<bool> available_colors(node_count, true);
    for (int64 neighbor : interference_map[node]) {
      int64 color = assigned_colors[neighbor];
      if (color != kColorUnassigned) {
        available_colors[color] = false;
      }
    }

    // Find the first color that is not yet assigned to a neighbor.
    int64 color = kColorUnassigned;
    for (color = 0; color < available_colors.size(); ++color) {
      if (available_colors[color]) {
        break;
      }
    }
    CHECK_NE(color, kColorUnassigned);
    assigned_colors[node] = color;
  }
  return assigned_colors;
}
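
// Illustrative example (values chosen for illustration only): for the
// interference map {{1, 2}, {0}, {0}}, node 0 has the most neighbors and is
// colored first, so
//   ColorInterferenceGraph({{1, 2}, {0}, {0}})
// returns {0, 1, 1}: node 0 gets color 0 and nodes 1 and 2 share color 1.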

// If an HLO buffer contains an entry parameter, the buffer is read-only unless
// it is aliased with an output.
bool HloBufferIsReadOnly(const HloBuffer& buffer) {
  for (const HloValue* value : buffer.values()) {
    const HloInstruction* instruction = value->instruction();
    const HloModule* module = instruction->parent()->parent();
    const bool is_entry_parameter =
        instruction->opcode() == HloOpcode::kParameter &&
        instruction->parent() == module->entry_computation();

    if (is_entry_parameter) {
      bool parameter_has_alias =
          module->input_output_alias_config().ParameterHasAlias(
              instruction->parameter_number(), value->index());
      // If the parameter doesn't have an alias, it must be read-only.
      if (!parameter_has_alias) {
        return true;
      }
    }
  }
  return false;
}
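
// Illustrative example: the buffer of an entry parameter that is not listed
// in the module's input_output_alias_config is read-only, so
// HloBufferIsReadOnly returns true for it; for a donated (aliased) parameter
// it returns false, since that buffer may be written in place.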

}  // namespace

Status GatherComputationsByAllocationType(
    const HloModule* module,
    std::vector<const HloComputation*>* thread_local_computations,
    std::vector<const HloComputation*>* global_computations) {
  // Create a worklist of computations paired with whether the allocation must
  // be thread-local.
  std::deque<std::pair<const HloComputation*, bool>> worklist;
  worklist.push_back(std::make_pair(module->entry_computation(),
                                    /*is_thread_local=*/false));

  // Sets for quickly checking membership. Computations are returned in vectors
  // for stable iteration.
  flat_hash_set<const HloComputation*> thread_local_set;
  flat_hash_set<const HloComputation*> global_set;

  while (!worklist.empty()) {
    auto worklist_front = worklist.front();
    worklist.pop_front();
    const HloComputation* computation = worklist_front.first;
    bool is_thread_local = worklist_front.second;
    bool in_thread_local_set = thread_local_set.contains(computation);
    bool in_global_set = global_set.contains(computation);

    // If the computation has already been added to the respective set, there
    // is nothing to do.
    if ((is_thread_local && in_thread_local_set) ||
        (!is_thread_local && in_global_set)) {
      continue;
    }

    // If the computation has already been added to the other set, this is an
    // error condition, because the global call to the computation (e.g.,
    // while/call) may return a reference to one of the thread-local buffers to
    // the calling computation, which would become a dangling reference when
    // the thread-local is deallocated with the call return.
    if ((is_thread_local && in_global_set) ||
        (!is_thread_local && in_thread_local_set)) {
      return InvalidArgument(
          "computation %s has conflicting allocation requirements (global "
          "and thread-local)",
          computation->name());
    }

    if (is_thread_local) {
      thread_local_set.insert(computation);
    } else {
      global_set.insert(computation);
    }

    for (auto* instruction : computation->instructions()) {
      for (HloComputation* subcomputation :
           instruction->called_computations()) {
        switch (instruction->opcode()) {
          case HloOpcode::kCall:
          case HloOpcode::kConditional:
          case HloOpcode::kWhile:
            // Call, conditional, and while must be called from a computation
            // with global allocations, as they may return references to
            // buffers inside the called computation, which cannot be
            // thread-local.
            if (is_thread_local) {
              return InvalidArgument(
                  "computation %s cannot contain call/while op because it "
                  "requires thread-local buffer allocations",
                  computation->name());
            }
            worklist.push_back(std::make_pair(subcomputation,
                                              false));  // Not thread local.
            break;
          case HloOpcode::kAllReduce:
          case HloOpcode::kMap:
          case HloOpcode::kReduce:
          case HloOpcode::kReduceWindow:
          case HloOpcode::kScatter:
          case HloOpcode::kSelectAndScatter:
          case HloOpcode::kSort:
          case HloOpcode::kFusion:
            // Map/reduce etc. computations are always thread-local.
            worklist.push_back(std::make_pair(subcomputation,
                                              true));  // Thread local.
            break;
          default:
            return InternalError("Unexpected calling opcode: %s",
                                 HloOpcodeString(instruction->opcode()));
        }
      }
    }
  }

  // Add the computations to the vectors in post order.
  for (auto* computation : module->MakeComputationPostOrder()) {
    if (thread_local_set.contains(computation)) {
      thread_local_computations->push_back(computation);
    } else if (global_set.contains(computation)) {
      global_computations->push_back(computation);
    }
    // If the computation is not reachable from the entry computation, then it
    // will not appear in either thread_local_set or global_set. We don't
    // bother assigning buffers for these.
  }
  return Status::OK();
}
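
// Illustrative example (hypothetical module): if the entry computation E
// contains a kWhile with body B and condition C, and B contains a kReduce
// with reducer R, the worklist classifies E, B, and C as global and R as
// thread-local, so the function returns:
//   global_computations       = {E, B, C}  (in post order)
//   thread_local_computations = {R}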

string BufferAllocation::Slice::ToString() const {
  return absl::StrCat("{index:", index(), ", offset:", offset_,
                      ", size:", size_, "}");
}

BufferAllocation::Slice BufferAllocation::GetSlice(
    const HloValue& buffer) const {
  const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
  return Slice(this, os.offset, os.size);
}

void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset,
                                     int64 size) {
  VLOG(4) << "Adding the following buffer to allocation #" << index()
          << absl::StrFormat(" (size=%d, offset=%d) %s", size, offset,
                             buffer.ToShortString());
  CHECK(!assigned_buffers_.contains(&buffer))
      << "LogicalBuffer " << buffer << " already assigned to allocation "
      << index_;
  CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
                          << " offset out of range";
  CHECK_LE(offset + size, size_)
      << "LogicalBuffer " << buffer
      << " size out of range at offset: " << offset << " with size: " << size;
  CHECK_EQ(buffer.color(), color())
      << "Buffer color " << buffer.color() << " for buffer " << buffer
      << " does not match allocation color " << color() << ".";
  OffsetSize offset_size;
  offset_size.offset = offset;
  offset_size.size = size;
  assigned_buffers_.emplace(&buffer, offset_size);
  // For debugging purposes, store the assigned memory space in the
  // instruction's layout.
  for (HloPosition position : buffer.positions()) {
    Shape* shape = ShapeUtil::GetMutableSubshape(
        position.instruction->mutable_shape(), position.index);
    if (shape->has_layout()) {
      shape->mutable_layout()->set_memory_space(buffer.color().value());
    }
  }
}
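
// Illustrative example (hypothetical values): for a 128-byte allocation,
// adding a buffer with offset=64 and size=32 passes both range checks
// (64 <= 128 and 64 + 32 <= 128), whereas offset=112 with size=32 would
// trip the CHECK_LE(offset + size, size_) above.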

BufferAllocationProto BufferAllocation::ToProto() const {
  BufferAllocationProto proto;
  proto.set_index(index_);
  proto.set_size(size_);
  proto.set_is_thread_local(is_thread_local_);
  proto.set_is_tuple(is_tuple_);
  proto.set_color(color_.value());
  if (is_entry_computation_parameter_) {
    proto.set_is_entry_computation_parameter(true);
    for (int64 idx : param_shape_index()) {
      proto.add_parameter_shape_index(idx);
    }
    proto.set_parameter_number(parameter_number_);
  }
  proto.set_is_constant(is_constant_);
  proto.set_maybe_live_out(maybe_live_out_);
  for (const auto& buffer_offset_size : assigned_buffers_) {
    BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
    proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
    proto_assigned->set_offset(buffer_offset_size.second.offset);
    proto_assigned->set_size(buffer_offset_size.second.size);
  }
  absl::c_sort(*proto.mutable_assigned(),
               [](const BufferAllocationProto::Assigned& assign1,
                  const BufferAllocationProto::Assigned& assign2) {
                 return assign1.logical_buffer_id() <
                        assign2.logical_buffer_id();
               });
  return proto;
}

static bool CompareHloValuesById(const HloValue* a, const HloValue* b) {
  return a->id() < b->id();
}

// Returns the parameter instruction corresponding to the allocation, or
// nullptr.
static const HloInstruction* GetEntryParameterInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    const HloInstruction* instr = value->instruction();
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parent() == instr->parent()->parent()->entry_computation()) {
      return instr;
    }
  }
  return nullptr;
}

// Returns the root module output instruction corresponding to the allocation,
// or nullptr.
static const HloInstruction* GetOutputInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    for (const HloPosition& position : value->positions()) {
      const HloInstruction* instr = position.instruction;
      if (position.index.empty() &&
          instr->parent()->root_instruction() == instr &&
          instr->parent()->IsEntryComputation()) {
        return instr;
      }
    }
  }
  return nullptr;
}

string BufferAllocation::ToString() const {
  string output;
  StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
  if (color().value() != 0) {
    StrAppend(&output, ", color ", color().value());
  }
  if (is_entry_computation_parameter()) {
    const HloInstruction* param = GetEntryParameterInstruction(*this);
    CHECK(param);
    StrAppend(&output, ", parameter ", parameter_number(), ", shape |",
              param->shape().ToString(/*print_layout=*/false),
              "| at ShapeIndex ", param_shape_index().ToString());
  }
  if (const HloInstruction* instr = GetOutputInstruction(*this)) {
    StrAppend(&output, ", output shape is |",
              instr->shape().ToString(/*print_layout=*/false), "|");
  }
  if (is_constant()) {
    StrAppend(&output, ", constant");
  }
  if (is_thread_local()) {
    StrAppend(&output, ", thread-local");
  }
  if (maybe_live_out()) {
    StrAppend(&output, ", maybe-live-out");
  }
  if (IsPreallocatedTempBuffer()) {
    StrAppend(&output, ", preallocated-temp");
  }
  StrAppend(&output, ":\n");
  // Dump the assigned buffers ordered by id.
  std::vector<const HloValue*> sorted_buffers;
  for (const auto& buffer_offset_size : assigned_buffers_) {
    sorted_buffers.push_back(buffer_offset_size.first);
  }
  absl::c_sort(sorted_buffers, &CompareHloValuesById);
  for (const HloValue* buffer : sorted_buffers) {
    const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
    StrAppend(&output,
              absl::StrFormat(
                  " value: %s (size=%d,offset=%d): %s\n",
                  buffer->ToShortString(), offset_size.size, offset_size.offset,
                  ShapeUtil::HumanStringWithLayout(buffer->shape())));
  }
  return output;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
  out << buffer.ToString();
  return out;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
  out << s.ToString();
  return out;
}

bool BufferAssignment::HasAllocation(const HloValue& value) const {
  return allocation_index_for_value_.contains(&value);
}

bool BufferAssignment::HasAllocation(const HloBuffer& buffer) const {
  return allocation_index_for_value_.contains(buffer.values()[0]);
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloValue& value) const {
  CHECK(HasAllocation(value));
  return GetAllocation(allocation_index_for_value_.at(&value));
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloBuffer& hlo_buffer) const {
  return GetAssignedAllocation(*hlo_buffer.values()[0]);
}

BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
    const HloBuffer& buffer) {
  return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
}

std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  std::set<BufferAllocation::Slice> result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (HasAllocation(*value)) {
      result.insert(GetAssignedAllocation(*value).GetSlice(*value));
    }
  }
  return result;
}

const BufferAllocation& BufferAssignment::GetAllocation(
    BufferAllocation::Index index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, allocations_.size());
  return allocations_[index];
}

const BufferAllocation* BufferAssignment::GetInstructionAllocation(
    const HloInstruction* hlo, const ShapeIndex& shape_index) const {
  const HloValue* value =
      dataflow_analysis().GetValueSet(hlo, shape_index).values()[0];

  if (!HasAllocation(*value)) {
    return nullptr;
  }

  const BufferAllocation& instruction_allocation =
      GetAssignedAllocation(*value);
  return &instruction_allocation;
}

BufferAllocation* BufferAssignment::GetMutableAllocation(
    BufferAllocation::Index index) {
  return const_cast<BufferAllocation*>(&GetAllocation(index));
}

bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
                                       const ShapeIndex& index) const {
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (allocation_index_for_value_.contains(value)) {
      return true;
    }
  }
  return false;
}

bool BufferAssignment::HasTopLevelAllocation(
    const HloInstruction* instruction) const {
  return HasAllocationAt(instruction, /*index=*/{});
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
          << index << "]";
  BufferAllocation::Slice result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    VLOG(3) << "Examining value " << *value;
    if (HasAllocation(*value)) {
      VLOG(3) << "Has allocation";
      const BufferAllocation::Slice slice =
          GetAssignedAllocation(*value).GetSlice(*value);
      if (result.allocation() == nullptr) {
        result = slice;
      } else if (result != slice) {
        return FailedPrecondition(
            "BufferAllocation::Slice for instruction %s at index %s cannot "
            "be determined at compile-time.",
            instruction->name(), index.ToString());
      }
    } else {
      VLOG(3) << "No allocation";
    }
  }
  if (result.allocation() == nullptr) {
    return FailedPrecondition(
        "BufferAllocation::Slice not assigned for instruction %s at index %s",
        instruction->name(), index.ToString());
  }
  return result;
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
    const HloInstruction* instruction) const {
  return GetUniqueSlice(instruction, /*index=*/{});
}

bool BufferAssignment::SharesSliceAtIndex(
    const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
    const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
  return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
         GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
}

bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
                                          const HloInstruction* hlo_b) const {
  using SliceSet = flat_hash_set<BufferAllocation::Slice>;
  // Gets the slices for all of instr's subshapes. If any subshape doesn't
  // have an assigned slice, returns the empty set.
  auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
    SliceSet slices;
    Status status = ShapeUtil::ForEachSubshapeWithStatus(
        instr->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          auto shape_slices = GetAllSlices(instr, index);
          if (shape_slices.empty()) {
            return InvalidArgument("No slices assigned to part of instr.");
          }
          slices.insert(shape_slices.begin(), shape_slices.end());
          return Status::OK();
        });
    if (!status.ok()) {
      return {};
    }
    return slices;
  };

  SliceSet slices_a = collect_slices(hlo_a);
  SliceSet slices_b = collect_slices(hlo_b);
  // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
  // didn't return the empty set) for both HLOs, and the two resulting sets of
  // slices are disjoint.
  return !slices_a.empty() && !slices_b.empty() &&
         absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
           return slices_b.contains(slice);
         });
}
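
// Illustrative example (hypothetical slices): if collect_slices yields
//   slices_a = {{index:0, offset:0, size:128}}
//   slices_b = {{index:1, offset:0, size:64}}
// then no slice of hlo_a appears in slices_b and HaveDisjointSlices returns
// true. If any subshape of either HLO lacks an assigned slice, collect_slices
// returns {} and the result is conservatively false.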

StatusOr<BufferAllocation::Slice>
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
  return GetUniqueTopLevelSlice(
      module_->entry_computation()->root_instruction());
}

BufferAllocation* BufferAssignment::NewEmptyAllocation(
    int64 size, LogicalBuffer::Color color) {
  BufferAllocation::Index index = allocations_.size();
  allocations_.emplace_back(index, size, color);
  BufferAllocation* allocation = &allocations_.back();
  return allocation;
}

BufferAllocation* BufferAssignment::NewAllocation(const HloBuffer& buffer,
                                                  int64 size) {
  BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
  AddAssignment(allocation, buffer, /*offset=*/0, size);
  allocation->peak_buffers_.push_back(buffer.values()[0]);
  return allocation;
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloBuffer& buffer, int64 offset,
                                     int64 size) {
  CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
      << "Non-reusable allocation already assigned a buffer: "
      << allocation->ToString();

  for (const HloValue* buffer_value : buffer.values()) {
    CHECK(!allocation_index_for_value_.contains(buffer_value))
        << "BufferValue " << buffer_value << " already has an allocation.";
    allocation->AddAssignment(*buffer_value, offset, size);
    allocation_index_for_value_[buffer_value] = allocation->index();
  }

  if (alias_analysis().BufferLivesOut(buffer)) {
    VLOG(3) << "HloBuffer lives out: " << buffer.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloValue& value, int64 offset,
                                     int64 size) {
  allocation->AddAssignment(value, offset, size);
  allocation_index_for_value_[&value] = allocation->index();
  const HloValue& hlo_value =
      *CHECK_NOTNULL(dynamic_cast<const HloValue*>(&value));
  if (alias_analysis().ValueLivesOut(hlo_value)) {
    VLOG(3) << "HloValue lives out: " << hlo_value.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

// Combines allocations of temporary buffers of the same color into one big
// BufferAllocation.
void BufferAssignment::CombineTempAllocations() {
  VLOG(1) << "CombineTempAllocations()";
  flat_hash_map<BufferValue::Color, BufferAllocation,
                BufferValue::Color::Hasher>
      combined_allocation_map;

  // Move all temp allocations into a single run at the end of the allocations
  // vector.
  const auto first_temp_it =
      std::partition(allocations_.begin(), allocations_.end(),
                     [](const BufferAllocation& allocation) {
                       return !allocation.IsPreallocatedTempBuffer();
                     });

  // Walk over the run of temp allocations, collecting the allocations
  // belonging to the same color.
  if (first_temp_it != allocations_.end()) {
    for (auto it = first_temp_it; it != allocations_.end(); ++it) {
      const BufferAllocation& temp_allocation = *it;
      BufferValue::Color color = temp_allocation.color();
      auto combined_it = combined_allocation_map.find(color);
      if (combined_it == combined_allocation_map.end()) {
        // We have found the first temp allocation of this color. Collect
        // the other temp allocations of the same color into it.
        VLOG(1) << "Combined temp allocation for color " << color
                << " is: " << temp_allocation;
        combined_allocation_map.emplace(color, temp_allocation);
        continue;
      }

      auto* combined_allocation = &combined_it->second;
      VLOG(1) << "Combined allocation absorbing temp allocation: "
              << temp_allocation;

      // Each temp allocation is placed end-to-end, accounting for alignment.
      // The offset of each buffer in the combined allocation is computed from
      // the base offset of the allocation.
      int64 alignment = color_alignment_(color);
      const int64 base =
          RoundUpToNearest(combined_allocation->size(), alignment);
      combined_allocation->set_size(base + temp_allocation.size());
      for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) {
        const HloValue* value = buffer_offset_size.first;
        const int64 offset = buffer_offset_size.second.offset;
        const int64 size = buffer_offset_size.second.size;
        combined_allocation->AddAssignment(*value, base + offset, size);
      }
      if (!temp_allocation.HeapTraces().empty()) {
        CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
        combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
      }

      combined_allocation->peak_buffers_.insert(
          combined_allocation->peak_buffers_.end(),
          temp_allocation.peak_buffers_.begin(),
          temp_allocation.peak_buffers_.end());
    }
    // Replace all existing temporary allocations with the new combined
    // allocations.
    allocations_.erase(first_temp_it, allocations_.end());
    for (auto& combined : combined_allocation_map) {
      allocations_.push_back(combined.second);
      temp_allocation_total_size_ += combined.second.size();
    }
  }

  // Update allocation indices to their new positions.
  allocation_index_for_value_.erase(allocation_index_for_value_.begin(),
                                    allocation_index_for_value_.end());
  for (size_t index = 0; index < allocations_.size(); ++index) {
    BufferAllocation* allocation = &allocations_[index];
    allocation->set_index(index);
    for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
      const HloValue* value = buffer_offset_size.first;
      allocation_index_for_value_[value] = index;
    }
  }
}
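
// Illustrative example (hypothetical sizes): absorbing a 16-byte temp
// allocation into a combined allocation of size 24 with
// color_alignment_(color) == 32 gives
//   base = RoundUpToNearest(24, 32) = 32
// so each absorbed buffer is re-added at (base + its old offset) and the
// combined allocation grows to 32 + 16 = 48 bytes.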

Status BufferAssignment::ComputeSummaryStats() {
  for (auto& allocation : Allocations()) {
    if (allocation.is_entry_computation_parameter()) {
      stats_.parameter_allocation_count++;
      stats_.parameter_allocation_bytes += allocation.size();
    }
    if (allocation.is_constant()) {
      stats_.constant_allocation_count++;
      stats_.constant_allocation_bytes += allocation.size();
    }
    if (allocation.maybe_live_out()) {
      stats_.maybe_live_out_allocation_count++;
      stats_.maybe_live_out_allocation_bytes += allocation.size();
    }
    if (allocation.IsPreallocatedTempBuffer()) {
      stats_.preallocated_temp_allocation_count++;
      stats_.preallocated_temp_allocation_bytes += allocation.size();
    }
    stats_.total_allocation_count++;
    stats_.total_allocation_bytes += allocation.size();
  }

  // Only compute total fragmentation if all computations have schedules.
  HloSchedule schedule(module_);
  bool schedule_complete = true;
  for (const auto& computation : module_->computations()) {
    if (!computation->IsFusionComputation()) {
      const HloInstructionSequence* sequence =
          hlo_ordering().SequentialOrder(*computation);
      if (sequence == nullptr) {
        schedule_complete = false;
      } else {
        schedule.set_sequence(computation, *sequence);
      }
    }
  }
  if (schedule_complete) {
    TF_RETURN_IF_ERROR(schedule.Verify());
    TF_ASSIGN_OR_RETURN(
        const int64 min_size,
        HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
    stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
  }

  return Status::OK();
}
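
// Illustrative example (hypothetical numbers): if the allocations sum to
// total_allocation_bytes = 1024 and the heap simulator reports that a
// perfectly packed execution of the schedule needs min_size = 896 bytes,
// then total_fragmentation_bytes = 1024 - 896 = 128.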

string BufferAssignment::Stats::ToString() const {
  string s;
  StrAppendFormat(&s, "BufferAssignment stats:\n");
  StrAppendFormat(&s, " parameter allocation: %10s\n",
                  HumanReadableNumBytes(parameter_allocation_bytes));
  StrAppendFormat(&s, " constant allocation: %10s\n",
                  HumanReadableNumBytes(constant_allocation_bytes));
  StrAppendFormat(&s, " maybe_live_out allocation: %10s\n",
                  HumanReadableNumBytes(maybe_live_out_allocation_bytes));
  StrAppendFormat(&s, " preallocated temp allocation: %10s\n",
                  HumanReadableNumBytes(preallocated_temp_allocation_bytes));
  if (preallocated_temp_fragmentation_bytes >= 0) {
    const double percent = 100. * preallocated_temp_fragmentation_bytes /
                           preallocated_temp_allocation_bytes;
    StrAppendFormat(
        &s, " preallocated temp fragmentation: %10s (%.2f%%)\n",
        HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
  }
  StrAppendFormat(&s, " total allocation: %10s\n",
                  HumanReadableNumBytes(total_allocation_bytes));
  if (total_fragmentation_bytes >= 0) {
    const double percent =
        100. * total_fragmentation_bytes / total_allocation_bytes;
    StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n",
                    HumanReadableNumBytes(total_fragmentation_bytes), percent);
  }
  return s;
}

string BufferAssignment::ToString() const {
  string output;
  absl::StrAppend(&output, "BufferAssignment:\n");
  std::vector<const HloValue*> used_values;
  int64 total_size = 0;
  for (auto& allocation : allocations_) {
    total_size += allocation.size();
    absl::StrAppend(&output, allocation.ToString());
    for (const auto& p : allocation.assigned_buffers()) {
      used_values.push_back(p.first);
    }
  }
  absl::StrAppend(&output, "\nTotal bytes used: ", total_size, "\n");
  absl::StrAppend(&output, "\nUsed values:\n");
  absl::c_sort(used_values, &CompareHloValuesById);
  for (const HloValue* value : used_values) {
    absl::StrAppend(&output, value->ToString());
  }
  return output;
}

BufferAssignmentProto BufferAssignment::ToProto() const {
  BufferAssignmentProto proto;
  // NOTE: DataflowAnalysis state is serialized here in BufferAssignment,
  // because we need to do the HasAllocation check for each buffer. Otherwise
  // the buffer_size_ call might fail for some backends.
  const HloDataflowAnalysis& dataflow = this->dataflow_analysis();
  for (BufferValue::Id id = 0; id < dataflow.values().size(); id++) {
    auto& value = dataflow.values().at(id);
    if (HasAllocation(*value)) {
      LogicalBufferProto proto_buffer = value->ToProto(buffer_size_);
      proto.add_logical_buffers()->Swap(&proto_buffer);

      // Fill buffer aliases.
      for (const HloValue* alias :
           alias_analysis().GetBufferContainingValue(*value).values()) {
        if (alias->instruction() == value->instruction() &&
            alias->index() == value->index()) {
          continue;  // skip self-aliases
        }
        BufferAssignmentProto::BufferAlias* proto_alias =
            proto.add_buffer_aliases();
        LogicalBufferProto::Location proto_alias_location =
            BufferValue::ToLocationProto(*alias->instruction(), alias->index());
        proto_alias->set_source_buffer_id(value->id());
        proto_alias->mutable_location()->Swap(&proto_alias_location);
      }
    }
  }
  for (const BufferAllocation& allocation : Allocations()) {
    BufferAllocationProto proto_allocation = allocation.ToProto();
    proto.add_buffer_allocations()->Swap(&proto_allocation);
    for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
      *proto.add_heap_simulator_traces() = heap_trace;
    }
  }
  return proto;
}

/* static */
StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer,
    const absl::flat_hash_set<HloOpcode>& reuse_checker,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer,
    std::unique_ptr<PresetAssignments> preset_assignments) {
  BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
                          reuse_checker, std::move(preset_assignments));
  return assigner.CreateAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(can_share_buffer));
}

bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1,
                                         const HloValue* buffer2,
                                         BufferAssignment* assignment) {
  CHECK(assignment->hlo_live_range().total_order_scheduled());
  const HloLiveRange& hlo_live_range = assignment->hlo_live_range();

  const auto& buffer_live_ranges = hlo_live_range.buffer_live_ranges();

  CHECK(buffer_live_ranges.contains(buffer1))
      << "Buffer doesn't have a proper live range: " << buffer1;

  CHECK(buffer_live_ranges.contains(buffer2))
      << "Buffer doesn't have a proper live range: " << buffer2;

  // Check if a user value can share the same buffer as its operand.
  auto can_share_as_operand = [&assignment](const HloValue* user_value,
                                            const HloValue* operand_value) {
    return user_value->instruction()->IsUserOf(operand_value->instruction()) &&
           assignment->dataflow_analysis().CanShareOperandBufferWithUser(
               operand_value->instruction(), operand_value->index(),
               user_value->instruction(), user_value->index()) &&
           user_value->instruction()->opcode() != HloOpcode::kCopy;
  };

  auto live_range_1 = buffer_live_ranges.at(buffer1);
  auto live_range_2 = buffer_live_ranges.at(buffer2);

  if (!(live_range_1.start > live_range_2.end ||
        live_range_2.start > live_range_1.end)) {
    if (live_range_1.end == live_range_2.start) {
      auto operand_value = buffer1;
      auto user_value = buffer2;
      if (!can_share_as_operand(user_value, operand_value)) {
        VLOG(4) << "End of live range of " << buffer1->ToShortString()
                << " is equal to the start of live range of "
                << buffer2->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else if (live_range_2.end == live_range_1.start) {
      auto operand_value = buffer2;
      auto user_value = buffer1;
      if (!can_share_as_operand(user_value, operand_value)) {
        VLOG(4) << "End of live range of " << buffer2->ToShortString()
                << " is equal to the start of live range of "
                << buffer1->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else {
      VLOG(4) << "Can't assign: assignee " << *buffer1 << " may interfere with "
              << *buffer2;
      VLOG(4) << "assigned_buffer.start: " << live_range_1.start;
      VLOG(4) << "assigned_buffer.end: " << live_range_1.end;
      VLOG(4) << "live_range_2.start: " << live_range_2.start;
      VLOG(4) << "live_range_2.end: " << live_range_2.end;
      return true;
    }
  }
  return false;
}
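
// Illustrative example (hypothetical live ranges): ranges [0, 5] and [5, 9]
// touch at a single program point, so the second buffer may reuse the first
// buffer's space only when can_share_as_operand holds, e.g. a user that
// dataflow analysis allows to update its operand in place; a kCopy user
// never qualifies. Ranges [0, 5] and [3, 9] overlap outright and always
// interfere.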

bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
                                       const HloBuffer& hlo_buffer,
                                       BufferAssignment* assignment) {
  CHECK(!assignment->HasAllocation(hlo_buffer))
      << "buffer " << hlo_buffer << " already has an allocation assigned.";

  VLOG(4) << "Trying to assign " << hlo_buffer << " size "
          << assignment->HloBufferSize(hlo_buffer)
          << " to allocation: " << *allocation;

  if (hlo_buffer.color() != allocation->color()) {
    VLOG(4) << "Can't assign: buffer has color " << hlo_buffer.color()
            << " and allocation has color " << allocation->color() << ".";
    return false;
  }

  if (assignment->HloBufferSize(hlo_buffer) > allocation->size()) {
    VLOG(4) << "Can't assign: buffer is larger than allocation ("
            << assignment->HloBufferSize(hlo_buffer) << " > "
            << allocation->size() << ")";
    return false;
  }

  if (allocation->is_readonly()) {
    VLOG(4) << "Can't assign: allocation is readonly";
    return false;
  }

  if (!must_not_live_out_.empty()) {
    if (allocation->maybe_live_out()) {
      // If the allocation may be live out, it cannot contain any node from
      // the "must_not_live_out_" set.
      for (const HloValue* value : hlo_buffer.values()) {
        if (must_not_live_out_.count(value->instruction()->opcode()) > 0) {
          VLOG(4) << "Can't assign: " << value->instruction()->ToString()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
    // The above check is not enough: the allocation could currently be
    // not-live-out while containing an instruction with an opcode from the
    // "must_not_live_out_" set; assigning a live-out buffer to it would make
    // the allocation live out while still containing such an instruction.
    if (assignment->alias_analysis().BufferLivesOut(hlo_buffer)) {
      for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
        if (must_not_live_out_.count(
                buffer_offset_size.first->instruction()->opcode()) > 0) {
          VLOG(4) << "Can't assign: "
                  << buffer_offset_size.first->instruction()->ToString()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
  }

  if (!allocation->is_reusable()) {
    VLOG(4) << "Can't assign: allocation is not reusable";
    return false;
  }

  for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
    // Pairwise compare.
    const HloValue& assigned_buffer =
        *CHECK_NOTNULL(dynamic_cast<const HloValue*>(buffer_offset_size.first));
    for (const HloValue* new_value : hlo_buffer.values()) {
      if (assignment->hlo_live_range().total_order_scheduled()) {
        if (LiveRangeInterferes(new_value, &assigned_buffer, assignment)) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " live range interferes with "
                  << new_value->ToShortString();
          return false;
        }
      } else if (assignment->hlo_ordering().MayInterfere(
                     assigned_buffer, *new_value,
                     assignment->dataflow_analysis())) {
        // Fall back to partial-order-based interference detection (slower)
        // when we don't have a total order scheduled module.
        VLOG(4) << "Can't assign: assignee " << assigned_buffer
                << " may interfere with " << new_value->ToShortString();
        return false;
      }

      for (const HloPosition& assigned_buffer_position :
           assigned_buffer.positions()) {
        // Copy instructions don't share a buffer with their input operand.
        if (new_value->instruction()->IsUserOf(
                assigned_buffer_position.instruction) &&
            new_value->instruction()->opcode() == HloOpcode::kCopy) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " is used at copy instruction "
                  << new_value->ToShortString();
          return false;
        }
      }
    }
  }

  // If the buffer is live out of the computation, then it should only be
  // assigned an allocation which exactly fits the result, to avoid wasting
  // memory (result buffers can have arbitrary lifetimes).
  if (assignment->alias_analysis().BufferLivesOut(hlo_buffer) &&
      allocation->size() != assignment->HloBufferSize(hlo_buffer)) {
    VLOG(4) << "Can't assign: buffer " << hlo_buffer
            << " is live out and size not the same as allocation";
    return false;
  }

  assignment->AddAssignment(allocation, hlo_buffer, /*offset=*/0,
                            assignment->HloBufferSize(hlo_buffer));
  return true;
}
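
// Illustrative example (hypothetical sizes): a live-out buffer of size 512
// is rejected by the exact-fit check above when offered a 1024-byte
// allocation, because the surplus 512 bytes could never be reclaimed once
// the result outlives the computation; only an exactly 512-byte allocation
// may be reused for it.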

Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
  // Try to allocate the same buffer for a dynamic update slice's operand and
  // output.
  //
  // TODO(yunxing): Move this logic into alias analysis, and add a must-alias
  // rule to operations that can be done in place.
  for (HloComputation* computation : assignment->module().computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (!(instruction->opcode() == HloOpcode::kDynamicUpdateSlice ||
            (instruction->opcode() == HloOpcode::kFusion &&
             (instruction->fused_expression_root()->opcode() ==
              HloOpcode::kDynamicUpdateSlice)))) {
        continue;
      }
      if (instruction->parent()->IsFusionComputation()) {
        continue;
      }
      if (instruction->operand_count() == 0) {
        continue;
      }

      // Skip if the operand can't share the same buffer with the user, based
      // on dataflow analysis.
      if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser(
              instruction->mutable_operand(0), {}, instruction, {})) {
        continue;
      }
      HloBuffer& instruction_buffer =
          assignment->alias_analysis().GetUniqueBufferAt(instruction, {});

      HloBuffer& operand_buffer =
          assignment->alias_analysis().GetUniqueBufferAt(
              instruction->operand(0), {});

      // They already share the same buffer; nothing to merge.
      if (instruction_buffer.id() == operand_buffer.id()) {
        continue;
      }

      // Do not perform in-place dynamic update slice if the operand buffer is
      // read-only.
      if (HloBufferIsReadOnly(operand_buffer)) {
        continue;
      }

      bool interfere = false;

      for (const HloValue* instruction_value : instruction_buffer.values()) {
        for (const HloValue* operand_value : operand_buffer.values()) {
          if (assignment->hlo_ordering().MayInterfere(
                  *instruction_value, *operand_value,
                  assignment->dataflow_analysis())) {
            interfere = true;
            break;
          }
        }
      }
      if (interfere) {
        continue;
      }
      if (assignment->alias_analysis().BufferLivesOut(instruction_buffer)) {
        continue;
      }
      if (instruction_buffer.color() != operand_buffer.color()) {
        continue;
      }
      VLOG(3) << "Merging inplace " << instruction_buffer << " and "
              << operand_buffer;
      assignment->alias_analysis().MergeBuffers(instruction_buffer,
                                                operand_buffer);
    }
  }
  return Status::OK();
}

Status BufferAssigner::AssignSingleHloBuffer(
    const HloBuffer* hlo_buffer, bool is_thread_local,
    absl::flat_hash_map<const HloComputation*,
                        absl::flat_hash_set<const HloValue*>>*
        buffers_to_assign_sequentially,
    std::vector<BufferAllocation::Index>* allocation_indices,
    BufferAssignment* assignment) {
  const int64 buffer_size = assignment->HloBufferSize(*hlo_buffer);
  for (const HloValue* value : hlo_buffer->values()) {
    if (value->instruction()->opcode() == HloOpcode::kConstant) {
      if (allocate_buffers_for_constants_) {
        BufferAllocation* allocation =
            assignment->NewAllocation(*hlo_buffer, buffer_size);
        allocation->set_constant(true);
        VLOG(3) << "New allocation #" << allocation->index() << " for constant "
                << *hlo_buffer << " value ptr: " << value;
      }
      VLOG(3) << "Not allocating buffer for constant";
      return Status::OK();
    }

    const HloInstruction* instruction = value->instruction();
    const bool is_entry_parameter =
        instruction->opcode() == HloOpcode::kParameter &&
        instruction->parent() ==
            instruction->parent()->parent()->entry_computation();

    if (is_entry_parameter) {
      bool parameter_has_alias =
          assignment->module().input_output_alias_config().ParameterHasAlias(
              instruction->parameter_number(), value->index());
      // If the HLO buffer is part of an external parameter, create a new
      // allocation and set its parameter number. Parameters of non-entry
      // computations do not need special allocations because they live inside
      // callers.
      BufferAllocation* allocation =
          assignment->NewAllocation(*hlo_buffer, buffer_size);

      allocation->set_entry_computation_parameter(
          instruction->parameter_number(), value->index(), parameter_has_alias);
      if (parameter_has_alias) {
        allocation_indices->push_back(allocation->index());
      }
      VLOG(3) << "New allocation #" << allocation->index()
              << " marked as entry computation parameter: " << *hlo_buffer;
      return Status::OK();
    }
  }

  if (is_thread_local) {
    BufferAllocation* allocation =
        assignment->NewAllocation(*hlo_buffer, buffer_size);
    allocation->set_is_thread_local(true);
    VLOG(3) << "New allocation #" << allocation->index()
            << " for thread-local: " << *hlo_buffer;
    return Status::OK();
  }

  for (const HloValue* value : hlo_buffer->values()) {
    if (value->shape().IsTuple()) {
      BufferAllocation* allocation =
          assignment->NewAllocation(*hlo_buffer, buffer_size);
      allocation->set_is_tuple(true);
      VLOG(3) << "New allocation #" << allocation->index()
              << " for tuple-shaped buffer: " << *hlo_buffer;
      return Status::OK();
    }

    if (value->IsTopLevel() && !value->IsTuple()) {
      const HloInstruction* instruction = value->instruction();
      for (auto* operand : instruction->operands()) {
        for (const auto& operand_slice :
             assignment->GetAllSlices(operand, /*index=*/{})) {
          BufferAllocation* allocation =
              assignment->GetMutableAllocation(operand_slice.index());
          if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
            VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
                    << " for: " << *hlo_buffer;
            return Status::OK();
          }
        }
      }
    }
  }

  // Find the smallest allocation that can be reused, iterating from the end
  // of allocation_indices (smallest) to the beginning (largest).
  for (int allocation_index = allocation_indices->size() - 1;
       allocation_index >= 0; allocation_index--) {
    BufferAllocation* allocation = assignment->GetMutableAllocation(
        allocation_indices->at(allocation_index));
    if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
      VLOG(3) << "Reusing allocation #" << allocation->index()
              << " for: " << *hlo_buffer;
      return Status::OK();
    }
  }

  if (!assignment->HasAllocation(*hlo_buffer) &&
      !assignment->alias_analysis().BufferLivesOut(*hlo_buffer)) {
    bool all_computations_have_sequential_order = true;
    for (const HloValue* hlo_value : hlo_buffer->values()) {
      HloComputation* computation = hlo_value->instruction()->parent();
      const bool has_sequential_order =
          assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
      all_computations_have_sequential_order &= has_sequential_order;
    }

    if (all_computations_have_sequential_order) {
      for (const HloValue* hlo_value : hlo_buffer->values()) {
        HloComputation* computation = hlo_value->instruction()->parent();
        // There is a sequential instruction ordering, so we delay assignment
        // of temp buffers until after the loop. We do this right before we
        // decide to create a new allocation, to ensure we've exhausted all
        // the buffer re-use cases above.
        //
        // Entry parameters and thread local buffers were already handled
        // earlier in this loop iteration. See
        // BufferAllocation::IsPreallocatedTempBuffer for the definition of
        // temp buffers.
        (*buffers_to_assign_sequentially)[computation].insert(hlo_value);
        VLOG(3) << "Delaying assignment of temp buffer: " << *hlo_value;
      }
      return Status::OK();
    }
  }

  if (!assignment->HasAllocation(*hlo_buffer)) {
    BufferAllocation* allocation =
        assignment->NewAllocation(*hlo_buffer, buffer_size);
    allocation_indices->push_back(allocation->index());
    VLOG(3) << "New allocation #" << allocation->index()
            << " for: " << *hlo_buffer;
  }

  TF_RET_CHECK(assignment->HasAllocation(*hlo_buffer));
  return Status::OK();
}

Status BufferAssigner::AssignBuffersForComputations(
    const std::vector<const HloComputation*>& computations,
    bool is_thread_local,
    absl::flat_hash_map<const HloComputation*,
                        absl::flat_hash_set<const HloValue*>>*
        buffers_to_assign_sequentially,
    BufferAssignment* assignment) {
  if (computations.empty()) {
    return Status::OK();
  }
  std::vector<const HloBuffer*> sorted_buffers;

  // First assign the preset allocations.
  absl::flat_hash_set<const HloBuffer*> preset_assigned_buffers;

  TF_RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment));

  const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();

  for (const HloBuffer& buffer : alias_analysis.buffers()) {
    // Skip if the buffer is already assigned since it had a preset allocation.
    if (preset_assigned_buffers.find(&buffer) !=
        preset_assigned_buffers.end()) {
      VLOG(3) << "Skip allocation for buffer: " << buffer;
      continue;
    }
    TF_RET_CHECK(!buffer.values().empty());
    const HloComputation* comp = buffer.values()[0]->instruction()->parent();
    if (absl::c_linear_search(computations, comp)) {
      sorted_buffers.push_back(&buffer);
    }
  }

  // Generate a post-order sort of the instructions, used below for ordering
  // the HloBuffers.
  flat_hash_map<const HloInstruction*, int> post_order_position;
  int position = 0;
  std::vector<const HloComputation*> reverse_post_order_computations;
  std::unique_ptr<CallGraph> call_graph =
      CallGraph::Build(computations[0]->parent());
  TF_RETURN_IF_ERROR(call_graph->VisitNodes([&](const CallGraphNode& node) {
    if (absl::c_linear_search(computations, node.computation())) {
      reverse_post_order_computations.push_back(node.computation());
    }
    return Status::OK();
  }));
  absl::c_reverse(reverse_post_order_computations);
  for (auto* computation : reverse_post_order_computations) {
    for (auto* instruction : computation->MakeInstructionPostOrder()) {
      post_order_position.emplace(instruction, position);
      position++;
    }
  }

  HloSchedule schedule(&assignment->module());

  for (const HloComputation* computation : computations) {
    const HloInstructionSequence* instruction_sequence =
        assignment->hlo_ordering().SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
      // Every sequential computation must get an entry in the
      // buffers_to_assign_sequentially map, even if we end up with an empty
      // set of buffers. This ensures we can correctly determine whether to
      // run whole-module heap simulation.
      buffers_to_assign_sequentially->emplace(computation,
                                              flat_hash_set<const HloValue*>());

      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

  absl::c_sort(
      sorted_buffers, [&post_order_position, &alias_analysis, assignment](
                          const HloBuffer* a, const HloBuffer* b) {
        // Primary sort is by decreasing buffer size.
        const int64 a_size = assignment->HloBufferSize(*a);
        const int64 b_size = assignment->HloBufferSize(*b);
        if (a_size != b_size) {
          return a_size > b_size;  // use ">" for decreasing size.
        }

        const bool a_live_out = alias_analysis.BufferLivesOut(*a);
        const bool b_live_out = alias_analysis.BufferLivesOut(*b);
        if (a_live_out != b_live_out) {
          return a_live_out;
        }
        auto compare = [&post_order_position](const HloValue* value1,
                                              const HloValue* value2) {
          return post_order_position.at(value1->instruction()) <
                 post_order_position.at(value2->instruction());
        };
        const HloValue* a_min = *absl::c_min_element(a->values(), compare);
        const HloValue* b_min = *absl::c_min_element(b->values(), compare);
        return compare(a_min, b_min);
      });
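
  // Illustrative example (hypothetical buffers): the comparator above orders
  // a 512-byte buffer before a 256-byte one; between two 256-byte buffers the
  // live-out one comes first; and between two 256-byte live-out buffers, the
  // one whose earliest defining instruction appears first in post order wins.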

  std::vector<BufferAllocation::Index> allocation_indices;

  for (const HloBuffer* buffer : sorted_buffers) {
    VLOG(3) << "=================================================";
    VLOG(3) << "Assigning buffer for " << *buffer;
    TF_RETURN_IF_ERROR(AssignSingleHloBuffer(buffer, is_thread_local,
                                             buffers_to_assign_sequentially,
                                             &allocation_indices, assignment));
  }
  return Status::OK();
}

flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>,
              LogicalBuffer::Color::Hasher>
BufferAssigner::SplitBuffersByColor(
    const flat_hash_set<const HloValue*>& buffers) {
  flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>,
                LogicalBuffer::Color::Hasher>
      color_map;
  for (auto buffer : buffers) {
    color_map[buffer->color()].insert(buffer);
  }
  return color_map;
}

Status BufferAssigner::AssignPresetBuffers(
    absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
    BufferAssignment* assignment) {
  if (!preset_assignments_) {
    return Status::OK();
  }

  // Create an allocation for each preset color.
  absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*,
                      LogicalBuffer::Color::Hasher>
      preset_allocations;
  for (auto& color_and_info : preset_assignments_->assignment_informations()) {
    LogicalBuffer::Color color(color_and_info.first);
    auto inserted = preset_allocations.emplace(
        color,
        assignment->NewEmptyAllocation(color_and_info.second.size, color));
    BufferAllocation* inserted_allocation = inserted.first->second;
    inserted_allocation->AddHeapTrace(
        color_and_info.second.heap_simulator_trace);
    VLOG(3) << "Created preset buffer allocation "
            << inserted_allocation->index()
            << ", color: " << inserted_allocation->color()
            << ", size: " << inserted_allocation->size();
  }

  const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
  const HloDataflowAnalysis& dataflow_analysis =
      alias_analysis.dataflow_analysis();

  for (auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    const HloValue& value = dataflow_analysis.GetUniqueValueAt(
        position.instruction, position.index);
    VLOG(3) << "Preset allocation for value: " << value.ToShortString();
    const HeapSimulator::Chunk& chunk = position_and_chunk.second;
    auto preset_allocations_iter = preset_allocations.find(value.color());
    CHECK(preset_allocations_iter != preset_allocations.end())
        << "No preset value allocation for color " << value.color() << " for "
        << value.ToShortString() << " found.";
    preset_allocations_iter->second->AddAssignment(value, chunk.offset,
                                                   chunk.size);

    const HloBuffer& buffer = alias_analysis.GetBufferContainingValue(value);
    assigned_buffers->insert(&buffer);
  }

  // Upon consumption of the preset assignments, clear them so that if this
  // method is called again, it does not assign the same buffers multiple
  // times.
  preset_assignments_ = {};

  return Status::OK();
}
1396
Status BufferAssigner::AssignBuffersWithSequentialOrdering(
    const flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>&
        buffers_to_assign_sequentially,
    bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
  // Run the sequence of instructions through the heap simulator. Rather than
  // committing to a single heuristic, the heap algorithm below tries both the
  // spatial and the temporal variant of global decreasing-size best-fit and
  // keeps whichever result uses less memory.
  const HloOrdering& hlo_ordering = assignment->hlo_ordering();

  // Returns a heap algorithm that chooses the best result from several
  // algorithms.
  auto get_heap_algorithm = [&](int64 alignment) {
    auto algorithms =
        absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
    algorithms->push_back(absl::make_unique<GlobalDecreasingSizeBestFitHeap>(
        alignment, GlobalDecreasingSizeBestFitHeap::kSpatial));
    algorithms->push_back(absl::make_unique<GlobalDecreasingSizeBestFitHeap>(
        alignment, GlobalDecreasingSizeBestFitHeap::kTemporal));
    return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms));
  };
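  // ChooseBestHeapAlgorithm runs each candidate and keeps the result with the
  // smallest simulated heap. For instance (hypothetical sizes), if the
  // spatial variant packs a color's buffers into 96 bytes but the temporal
  // variant fits them into 80 bytes, the 80-byte result is the one assigned.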

  if (run_whole_module_heap_simulation) {
    // Run the heap simulation over the whole module. This reduces memory
    // usage, since buffers for kCall, kWhile, and kConditional
    // sub-computations are only live for the duration of their calling
    // instructions.
    VLOG(1) << "Running whole-module heap simulation";
    HloSchedule schedule(&assignment->module());
    flat_hash_set<const HloValue*> all_buffers_to_assign;
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      schedule.set_sequence(computation, *instruction_sequence);
      all_buffers_to_assign.insert(buffers_to_assign.begin(),
                                   buffers_to_assign.end());
    }
    auto color_map = SplitBuffersByColor(all_buffers_to_assign);
    for (auto& single_colored_set : color_map) {
      auto color = single_colored_set.first;
      VLOG(2) << "Simulating heap for color " << color;
      int64 alignment = assignment->color_alignment_(color);
      HeapSimulator::Options options;
      options.alloc_constants = allocate_buffers_for_constants_;
      options.buffers_to_assign = &single_colored_set.second;

      TF_ASSIGN_OR_RETURN(
          HeapSimulator::Result result,
          HeapSimulator::Run(
              get_heap_algorithm(alignment), assignment->module(), schedule,
              assignment->alias_analysis(), assignment->buffer_size_, options));
      AssignBuffersFromHeapSimulator(result, assignment,
                                     single_colored_set.first);
    }
  } else {
    // Run the heap simulation on a per-computation basis. Buffers for
    // sub-computations are assigned disjoint BufferAllocations, assuming the
    // worst-case that they may all be live concurrently.
    VLOG(1) << "Running per-computation heap simulation";
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      auto color_map = SplitBuffersByColor(buffers_to_assign);
      for (auto& single_colored_set : color_map) {
        auto color = single_colored_set.first;
        VLOG(2) << "Simulating heap for color " << color;
        int64 alignment = assignment->color_alignment_(color);
        HeapSimulator::Options options;
        options.buffers_to_assign = &single_colored_set.second;
        TF_ASSIGN_OR_RETURN(
            HeapSimulator::Result result,
            HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
                               *instruction_sequence,
                               assignment->alias_analysis(),
                               assignment->buffer_size_, options));
        AssignBuffersFromHeapSimulator(result, assignment,
                                       single_colored_set.first);
      }
    }
  }
  return Status::OK();
}

namespace {
// Computes and returns the set of logical buffers live at the point of
// maximal liveness in the given heap trace. LogicalBuffers are (stably)
// sorted by id.
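//
// Worked example (hypothetical trace): for the event sequence
//   ALLOC a (8 bytes), ALLOC b (4 bytes), FREE a, ALLOC c (8 bytes)
// the running live size is 8, 12, 4, 12; the maximum (12) is first reached
// just after ALLOC b, so the returned peak set is {a, b}.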
std::vector<const HloValue*> ComputePeakMemoryLogicalBuffers(
    const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
  // Create a map from BufferValue::Id to HloValue* for the values assigned to
  // this allocation, along with their sizes.
  absl::flat_hash_map<BufferValue::Id, const HloValue*> id_to_value;
  absl::flat_hash_map<const HloValue*, int64> buffer_sizes;
  for (const auto& pair : allocation.assigned_buffers()) {
    const HloValue* value = pair.first;
    const BufferAllocation::OffsetSize& offset_size = pair.second;
    id_to_value[value->id()] = value;
    buffer_sizes[value] = offset_size.size;
  }
  VLOG(1) << "Compute peak memory logical buffers";

  // Returns how much the given event increases the total size of live
  // buffers. Can be negative.
  auto memory_delta = [&id_to_value, &buffer_sizes](
                          const HeapSimulatorTrace::Event& event) -> int64 {
    const HloValue* buffer = id_to_value.at(event.buffer_id());
    const int64 buffer_size = buffer_sizes.at(buffer);
    if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
        event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
      return buffer_size;
    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
      return -1 * buffer_size;
    }
    LOG(FATAL) << "Unknown event kind: " << event.kind();
  };

  // First compute the size of the maximal live set.
  int64 max_live_size = 0;
  int64 live_size = 0;
  for (const auto& event : heap_trace.events()) {
    live_size += memory_delta(event);
    if (max_live_size < live_size) {
      max_live_size = live_size;
    }
  }

  // Next gather the set of logical buffers live at the earliest point of
  // maximal live set size.
  absl::flat_hash_set<const HloValue*> live_values;
  live_size = 0;
  for (const auto& event : heap_trace.events()) {
    const HloValue* value = id_to_value.at(event.buffer_id());
    if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
        event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
      InsertOrDie(&live_values, value);
    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
      CHECK(ContainsKey(live_values, value));
      live_values.erase(value);
    }
    live_size += memory_delta(event);

    if (live_size == max_live_size) {
      break;
    }
  }
  CHECK_EQ(live_size, max_live_size);

  std::vector<const HloValue*> live_values_vector;
  live_values_vector.insert(live_values_vector.end(), live_values.begin(),
                            live_values.end());

  // Stably sort the live buffers by id.
  absl::c_sort(live_values_vector, [](const HloValue* a, const HloValue* b) {
    return a->id() < b->id();
  });
  VLOG(4) << "Peak memory buffers:";
  for (auto value : live_values_vector) {
    VLOG(4) << "  " << value->ToString();
  }
  return live_values_vector;
}

}  // namespace

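// Creates one BufferAllocation sized to the simulated heap, records each
// value at the offset chosen for it by the simulator, and folds the
// simulator's fragmentation estimate into the assignment's summary stats.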
void BufferAssigner::AssignBuffersFromHeapSimulator(
    const HeapSimulator::Result& result, BufferAssignment* assignment,
    BufferValue::Color color) {
  if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
    assignment->stats_.preallocated_temp_fragmentation_bytes =
        result.fragmentation_size;
  } else {
    assignment->stats_.preallocated_temp_fragmentation_bytes +=
        result.fragmentation_size;
  }
  VLOG(1) << "Result size from heap simulator: " << result.heap_size;

  BufferAllocation* allocation =
      assignment->NewEmptyAllocation(result.heap_size, color);
  for (const auto& buffer_chunk : result.chunk_map) {
    const HloValue& value = *buffer_chunk.first;
    const HeapSimulator::Chunk& chunk = buffer_chunk.second;
    assignment->AddAssignment(allocation, value, chunk.offset, chunk.size);
  }
  allocation->peak_buffers_ =
      ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);

  VLOG(1) << "Ran heap simulation for allocation: ";
  XLA_VLOG_LINES(2, allocation->ToString());

  allocation->AddHeapTrace(result.debug_trace);
}
1593
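// Sketch of a call site (illustrative only; the size and alignment callbacks
// and the constants below are hypothetical, and in practice this is reached
// via BufferAssigner's public entry point rather than called directly):
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<BufferAssignment> assignment,
//       assigner.CreateAssignment(
//           module, absl::make_unique<SequentialHloOrdering>(schedule),
//           /*buffer_size=*/
//           [](const BufferValue& value) {
//             return ShapeUtil::ByteSizeOf(value.shape(), /*pointer_size=*/8);
//           },
//           /*color_alignment=*/
//           [](LogicalBuffer::Color) { return int64{64}; },
//           /*can_share_buffer=*/nullptr));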
StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer));

  // Set up a schedule for each computation.
  HloSchedule schedule(module);
  for (const HloComputation* computation : module->computations()) {
    const HloInstructionSequence* instruction_sequence =
        hlo_ordering->SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order) {
      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, *alias_analysis,
                                        module->entry_computation(),
                                        /*module_scoped_analysis=*/true));

  VLOG(1) << "Assigning buffers to module " << module->name();
  XLA_VLOG_LINES(3, module->ToString());
  XLA_VLOG_LINES(3, alias_analysis->ToString());
  XLA_VLOG_LINES(3, alias_analysis->dataflow_analysis().ToString());
  VLOG(1) << "Number of buffers to assign: "
          << alias_analysis->buffers().size();

  // Can't use absl::make_unique because the BufferAssignment constructor is
  // private.
  std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(alias_analysis),
      std::move(hlo_live_range)));

  TF_RETURN_IF_ERROR(
      colorer_(&assignment->alias_analysis(), assignment->hlo_ordering()));
  VLOG(3) << "After coloring:";
  XLA_VLOG_LINES(3,
                 assignment->alias_analysis().dataflow_analysis().ToString());
  TF_RETURN_IF_ERROR(MergeInplaceOpBuffers(assignment.get()));

  std::vector<const HloComputation*> thread_local_computations;
  std::vector<const HloComputation*> global_computations;
  TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
      module, &thread_local_computations, &global_computations));

  // First assign buffers for global computations. Temporary buffers for
  // sequential computations are collected in
  // 'buffers_to_assign_sequentially'.
  flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>
      buffers_to_assign_sequentially;
  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      global_computations,
      /*is_thread_local=*/false, &buffers_to_assign_sequentially,
      assignment.get()));
  // Assign buffers with sequential ordering, if any. If all global
  // computations are sequential, we can run heap simulation on the whole
  // module, which reduces memory usage.
  const bool run_whole_module_heap_simulation =
      buffers_to_assign_sequentially.size() == global_computations.size();
  VLOG(2) << "Running whole module heap simulation: "
          << run_whole_module_heap_simulation;
  TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
      buffers_to_assign_sequentially, run_whole_module_heap_simulation,
      assignment.get()));

  // Now assign buffers for thread-local computations. All LogicalBuffers get
  // their own BufferAllocation. Fusion computations are skipped: their
  // buffers are assigned as part of the fusion instruction itself.
  std::vector<const HloComputation*> thread_local_computations_no_fusion;
  for (auto* computation : thread_local_computations) {
    TF_RET_CHECK(computation != module->entry_computation());
    if (computation->IsFusionComputation()) {
      continue;
    }
    thread_local_computations_no_fusion.push_back(computation);
  }

  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      thread_local_computations_no_fusion,
      /*is_thread_local=*/true,
      /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));

  // Mark all buffers which may be live out of the entry computation as
  // "liveout".
  for (const HloBuffer* buffer :
       assignment->alias_analysis().LiveOutBuffers()) {
    VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
    if (assignment->HasAllocation(*buffer)) {
      BufferAllocation* alloc =
          assignment->GetMutableAssignedAllocation(*buffer);
      alloc->set_maybe_live_out(true);
      VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
    }
  }

  // Combine allocations of temporary buffers into one big BufferAllocation.
  // This can only be performed after all buffers have been assigned, and
  // after maybe_live_out is marked, since it is used to determine whether an
  // allocation contains temporary buffers or not.
  assignment->CombineTempAllocations();

  XLA_VLOG_LINES(2, assignment->ToString());
  TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
  XLA_VLOG_LINES(1, assignment->GetStats().ToString());
  VLOG(1) << "Buffer assignment done.";
  return std::move(assignment);
}

}  // namespace xla