/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Defines the data returned by the XLA buffer assignment packages.

#include "tensorflow/compiler/xla/service/buffer_assignment.h"

#include <algorithm>
#include <deque>
#include <numeric>
#include <ostream>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"

namespace xla {
namespace {

using absl::flat_hash_map;
using absl::flat_hash_set;
using absl::StrAppend;
using absl::StrAppendFormat;
using ::tensorflow::strings::HumanReadableNumBytes;

// Given the interference map of a graph (the list of interfering node indices
// for each node), perform graph coloring such that interfering nodes are
// assigned to different colors. Returns the assigned color of the nodes, where
// the colors are represented as integer values [0, color_count).
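//
// A small worked example: with interference_map = {{1, 2}, {0}, {0}}, node 0
// has the most neighbors and is colored first (color 0); nodes 1 and 2 each
// see only color 0 taken at their neighbor, so both receive color 1, yielding
// colors {0, 1, 1}. Greedy coloring in this order uses at most
// max_degree + 1 colors.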
std::vector<int64> ColorInterferenceGraph(
    const std::vector<std::vector<int64>>& interference_map) {
  const int64 node_count = interference_map.size();

  // Sort the nodes such that we assign nodes with more interference first.
  // This relies on the common heuristic of assigning the most constrained node
  // first, but it would be good to investigate other ordering heuristics too.
  std::vector<int64> nodes(node_count);
  std::iota(nodes.begin(), nodes.end(), 0);
  absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) {
    return interference_map[i].size() > interference_map[j].size();
  });

  const int64 kColorUnassigned = -1;
  std::vector<int64> assigned_colors(node_count, kColorUnassigned);
  for (int64 node : nodes) {
    // Mark the colors that are already assigned to the neighbors.
    std::vector<bool> available_colors(node_count, true);
    for (int64 neighbor : interference_map[node]) {
      int64 color = assigned_colors[neighbor];
      if (color != kColorUnassigned) {
        available_colors[color] = false;
      }
    }

    // Find the first color that is not yet assigned to a neighbor. Every node
    // has fewer than node_count neighbors, so a free color always exists.
    int64 color = kColorUnassigned;
    for (color = 0; color < available_colors.size(); ++color) {
      if (available_colors[color]) {
        break;
      }
    }
    CHECK_NE(color, kColorUnassigned);
    assigned_colors[node] = color;
  }
  return assigned_colors;
}

// Returns true if the HLO buffer is read-only, i.e., if it holds a constant,
// or if it holds an entry parameter that is not aliased with a module output.
bool HloBufferIsReadOnly(const HloBuffer& buffer) {
  for (const HloValue* value : buffer.values()) {
    const HloInstruction* instruction = value->instruction();
    if (instruction->opcode() == HloOpcode::kConstant) {
      return true;
    }
    const HloModule* module = instruction->parent()->parent();
    const bool is_entry_parameter =
        instruction->opcode() == HloOpcode::kParameter &&
        instruction->parent() == module->entry_computation();

    if (is_entry_parameter) {
      bool parameter_has_alias =
          module->input_output_alias_config().ParameterHasAlias(
              instruction->parameter_number(), value->index());
      // If the parameter doesn't have an alias, it must be read-only.
      if (!parameter_has_alias) {
        return true;
      }
    }
  }
  return false;
}

}  // namespace

Status GatherComputationsByAllocationType(
    const HloModule* module,
    std::vector<const HloComputation*>* thread_local_computations,
    std::vector<const HloComputation*>* global_computations) {
  // Create a worklist of computations paired with whether the allocation must
  // be thread-local.
  std::deque<std::pair<const HloComputation*, bool>> worklist;
  worklist.push_back(std::make_pair(module->entry_computation(),
                                    /*is_thread_local*/ false));

  // Sets for quickly checking membership. Computations are returned in vectors
  // for stable iteration.
  flat_hash_set<const HloComputation*> thread_local_set;
  flat_hash_set<const HloComputation*> global_set;

  while (!worklist.empty()) {
    auto worklist_front = worklist.front();
    worklist.pop_front();
    const HloComputation* computation = worklist_front.first;
    bool is_thread_local = worklist_front.second;
    bool in_thread_local_set = thread_local_set.contains(computation);
    bool in_global_set = global_set.contains(computation);

    // If the computation has already been added to the respective set, then
    // there is nothing to do.
    if ((is_thread_local && in_thread_local_set) ||
        (!is_thread_local && in_global_set)) {
      continue;
    }

    // If the computation has already been added to the other set, this is an
    // error condition: the global call to the computation (e.g., while/call)
    // may return a reference to one of the thread-local buffers to the calling
    // computation, which would become a dangling reference when the
    // thread-local buffer is deallocated with the call return.
    if ((is_thread_local && in_global_set) ||
        (!is_thread_local && in_thread_local_set)) {
      return InvalidArgument(
          "computation %s has conflicting allocation requirements (global "
          "and thread-local)",
          computation->name());
    }

    if (is_thread_local) {
      thread_local_set.insert(computation);
    } else {
      global_set.insert(computation);
    }

    for (auto* instruction : computation->instructions()) {
      for (HloComputation* subcomputation :
           instruction->called_computations()) {
        switch (instruction->opcode()) {
          case HloOpcode::kCall:
          case HloOpcode::kConditional:
          case HloOpcode::kWhile:
            // Call, conditional, and while must be called from a computation
            // with global allocations, as they may return references to
            // buffers inside the called computation which cannot be
            // thread-local.
            if (is_thread_local) {
              return InvalidArgument(
                  "computation %s cannot contain call/while op because it "
                  "requires thread-local buffer allocations",
                  computation->name());
            }
            worklist.push_back(std::make_pair(subcomputation,
                                              false));  // Not thread local.
            break;
          case HloOpcode::kAllReduce:
          case HloOpcode::kMap:
          case HloOpcode::kReduce:
          case HloOpcode::kReduceWindow:
          case HloOpcode::kScatter:
          case HloOpcode::kSelectAndScatter:
          case HloOpcode::kSort:
          case HloOpcode::kFusion:
            // Map/reduce etc. computations are always thread-local.
            worklist.push_back(std::make_pair(subcomputation,
                                              true));  // Thread local.
            break;
          default:
            return InternalError("Unexpected calling opcode: %s",
                                 HloOpcodeString(instruction->opcode()));
        }
      }
    }
  }

  // Add the computations to the vectors in post order.
  for (auto* computation : module->MakeComputationPostOrder()) {
    if (thread_local_set.contains(computation)) {
      thread_local_computations->push_back(computation);
    } else if (global_set.contains(computation)) {
      global_computations->push_back(computation);
    }
    // If the computation is not reachable from the entry computation, then it
    // will not appear in either thread_local_set or global_set. We don't
    // bother assigning buffers for these.
  }
  return Status::OK();
}

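// Produces, e.g., "{index:3, offset:128, size:64}" (illustrative values) for
// a slice at offset 128 of size 64 within allocation 3.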
string BufferAllocation::Slice::ToString() const {
  return absl::StrCat("{index:", index(), ", offset:", offset_,
                      ", size:", size_, "}");
}

BufferAllocation::Slice BufferAllocation::GetSlice(
    const HloValue& buffer) const {
  const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
  return Slice(this, os.offset, os.size);
}

void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset,
                                     int64 size) {
  VLOG(4) << "Adding the following buffer to allocation #" << index()
          << absl::StrFormat(" (size=%d, offset=%d) %s", size, offset,
                             buffer.ToShortString());
  CHECK(!assigned_buffers_.contains(&buffer))
      << "LogicalBuffer " << buffer << " already assigned to allocation "
      << index_;
  CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
                          << " offset out of range";
  CHECK_LE(offset + size, size_)
      << "LogicalBuffer " << buffer
      << " size out of range at offset: " << offset << " with size: " << size;
  CHECK_EQ(buffer.color(), color())
      << "Buffer color " << buffer.color() << " for buffer " << buffer
      << " does not match allocation color " << color() << ".";
  OffsetSize offset_size;
  offset_size.offset = offset;
  offset_size.size = size;
  assigned_buffers_.emplace(&buffer, offset_size);
  // For debugging purposes, store the assigned memory space in the
  // instruction's layout.
  for (HloPosition position : buffer.positions()) {
    Shape* shape = ShapeUtil::GetMutableSubshape(
        position.instruction->mutable_shape(), position.index);
    if (shape->has_layout()) {
      shape->mutable_layout()->set_memory_space(buffer.color());
    }
  }
}

BufferAllocationProto BufferAllocation::ToProto() const {
  BufferAllocationProto proto;
  proto.set_index(index_);
  proto.set_size(size_);
  proto.set_is_thread_local(is_thread_local_);
  proto.set_is_tuple(is_tuple_);
  proto.set_color(color_);
  if (is_entry_computation_parameter_) {
    proto.set_is_entry_computation_parameter(true);
    for (int64 idx : param_shape_index()) {
      proto.add_parameter_shape_index(idx);
    }
    proto.set_parameter_number(parameter_number_);
  }
  proto.set_is_constant(is_constant_);
  proto.set_maybe_live_out(maybe_live_out_);
  for (const auto& buffer_offset_size : assigned_buffers_) {
    BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
    proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
    proto_assigned->set_offset(buffer_offset_size.second.offset);
    proto_assigned->set_size(buffer_offset_size.second.size);
  }
  absl::c_sort(*proto.mutable_assigned(),
               [](const BufferAllocationProto::Assigned& assign1,
                  const BufferAllocationProto::Assigned& assign2) {
                 return assign1.logical_buffer_id() <
                        assign2.logical_buffer_id();
               });
  return proto;
}

static bool CompareHloValuesById(const HloValue* a, const HloValue* b) {
  return a->id() < b->id();
}

// Returns the parameter instruction corresponding to the allocation, or
// nullptr.
static const HloInstruction* GetEntryParameterInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    const HloInstruction* instr = value->instruction();
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parent() == instr->parent()->parent()->entry_computation()) {
      return instr;
    }
  }
  return nullptr;
}

// Returns the root module output instruction corresponding to the allocation,
// or nullptr.
static const HloInstruction* GetOutputInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    for (const HloPosition& position : value->positions()) {
      const HloInstruction* instr = position.instruction;
      if (position.index.empty() &&
          instr->parent()->root_instruction() == instr &&
          instr->parent()->IsEntryComputation()) {
        return instr;
      }
    }
  }
  return nullptr;
}

string BufferAllocation::ToString() const {
  string output;
  StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
  if (color() != 0) {
    StrAppend(&output, ", color ", color());
  }
  if (is_entry_computation_parameter()) {
    const HloInstruction* param = GetEntryParameterInstruction(*this);
    CHECK(param);
    StrAppend(&output, ", parameter ", parameter_number(), ", shape |",
              param->shape().ToString(/*print_layout=*/false),
              "| at ShapeIndex ", param_shape_index().ToString());
  }
  if (const HloInstruction* instr = GetOutputInstruction(*this)) {
    StrAppend(&output, ", output shape is |",
              instr->shape().ToString(/*print_layout=*/false), "|");
  }
  if (is_constant()) {
    StrAppend(&output, ", constant");
  }
  if (is_thread_local()) {
    StrAppend(&output, ", thread-local");
  }
  if (maybe_live_out()) {
    StrAppend(&output, ", maybe-live-out");
  }
  if (IsPreallocatedTempBuffer()) {
    StrAppend(&output, ", preallocated-temp");
  }
  StrAppend(&output, ":\n");
  // Dump the assigned buffers ordered by id.
  std::vector<const HloValue*> sorted_buffers;
  for (const auto& buffer_offset_size : assigned_buffers_) {
    sorted_buffers.push_back(buffer_offset_size.first);
  }
  absl::c_sort(sorted_buffers, &CompareHloValuesById);
  for (const HloValue* buffer : sorted_buffers) {
    const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
    StrAppend(&output,
              absl::StrFormat(
                  " value: %s (size=%d,offset=%d): %s\n",
                  buffer->ToShortString(), offset_size.size, offset_size.offset,
                  ShapeUtil::HumanStringWithLayout(buffer->shape())));
  }
  return output;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
  out << buffer.ToString();
  return out;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
  out << s.ToString();
  return out;
}

bool BufferAssignment::HasAllocation(const HloValue& value) const {
  return allocation_index_for_value_.contains(&value);
}

bool BufferAssignment::HasAllocation(const HloBuffer& buffer) const {
  return allocation_index_for_value_.contains(buffer.values()[0]);
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloValue& value) const {
  CHECK(HasAllocation(value));
  return GetAllocation(allocation_index_for_value_.at(&value));
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloBuffer& hlo_buffer) const {
  return GetAssignedAllocation(*hlo_buffer.values()[0]);
}

BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
    const HloBuffer& buffer) {
  return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
}

std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  std::set<BufferAllocation::Slice> result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (HasAllocation(*value)) {
      result.insert(GetAssignedAllocation(*value).GetSlice(*value));
    }
  }
  return result;
}

const BufferAllocation& BufferAssignment::GetAllocation(
    BufferAllocation::Index index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, allocations_.size());
  return allocations_[index];
}

const BufferAllocation* BufferAssignment::GetInstructionAllocation(
    const HloInstruction* hlo, const ShapeIndex& shape_index) const {
  const HloValue* value =
      dataflow_analysis().GetValueSet(hlo, shape_index).values()[0];

  if (!HasAllocation(*value)) {
    return nullptr;
  }

  const BufferAllocation& instruction_allocation =
      GetAssignedAllocation(*value);
  return &instruction_allocation;
}

BufferAllocation* BufferAssignment::GetMutableAllocation(
    BufferAllocation::Index index) {
  return const_cast<BufferAllocation*>(&GetAllocation(index));
}

bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
                                       const ShapeIndex& index) const {
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (allocation_index_for_value_.contains(value)) {
      return true;
    }
  }
  return false;
}

bool BufferAssignment::HasTopLevelAllocation(
    const HloInstruction* instruction) const {
  return HasAllocationAt(instruction, /*index=*/{});
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
          << index << "]";
  BufferAllocation::Slice result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    VLOG(3) << "Examining value " << *value;
    if (HasAllocation(*value)) {
      VLOG(3) << "Has allocation";
      const BufferAllocation::Slice slice =
          GetAssignedAllocation(*value).GetSlice(*value);
      if (result.allocation() == nullptr) {
        result = slice;
      } else if (result != slice) {
        return FailedPrecondition(
            "BufferAllocation::Slice for instruction %s at index %s cannot "
            "be determined at compile-time.",
            instruction->name(), index.ToString());
      }
    } else {
      VLOG(3) << "No allocation";
    }
  }
  if (result.allocation() == nullptr) {
    return FailedPrecondition(
        "BufferAllocation::Slice not assigned for instruction %s at index %s",
        instruction->name(), index.ToString());
  }
  return result;
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
    const HloInstruction* instruction) const {
  return GetUniqueSlice(instruction, /*index=*/{});
}

bool BufferAssignment::SharesSliceAtIndex(
    const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
    const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
  return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
         GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
}

bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
                                          const HloInstruction* hlo_b) const {
  using SliceSet = flat_hash_set<BufferAllocation::Slice>;
  // Gets the slices of all of instr's subshapes. If any subshape doesn't have
  // an assigned slice, returns the empty set.
  auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
    SliceSet slices;
    Status status = ShapeUtil::ForEachSubshapeWithStatus(
        instr->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          auto shape_slices = GetAllSlices(instr, index);
          if (shape_slices.empty()) {
            return InvalidArgument("No slices assigned to part of instr.");
          }
          slices.insert(shape_slices.begin(), shape_slices.end());
          return Status::OK();
        });
    if (!status.ok()) {
      return {};
    }
    return slices;
  };

  SliceSet slices_a = collect_slices(hlo_a);
  SliceSet slices_b = collect_slices(hlo_b);
  // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
  // didn't return the empty set) for both HLOs, and the two resulting sets of
  // slices are disjoint.
  return !slices_a.empty() && !slices_b.empty() &&
         absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
           return slices_b.contains(slice);
         });
}

StatusOr<BufferAllocation::Slice>
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
  return GetUniqueTopLevelSlice(
      module_->entry_computation()->root_instruction());
}

BufferAllocation* BufferAssignment::NewEmptyAllocation(
    int64 size, LogicalBuffer::Color color) {
  BufferAllocation::Index index = allocations_.size();
  allocations_.emplace_back(index, size, color);
  BufferAllocation* allocation = &allocations_.back();
  return allocation;
}

BufferAllocation* BufferAssignment::NewAllocation(const HloBuffer& buffer,
                                                  int64 size) {
  BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
  AddAssignment(allocation, buffer, /*offset=*/0, size);
  allocation->peak_buffers_.push_back(buffer.values()[0]);
  return allocation;
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloBuffer& buffer, int64 offset,
                                     int64 size) {
  CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
      << "Non-reusable allocation already assigned a buffer: "
      << allocation->ToString();

  for (const HloValue* buffer_value : buffer.values()) {
    CHECK(!allocation_index_for_value_.contains(buffer_value))
        << "BufferValue " << buffer_value << " already has an allocation.";
    allocation->AddAssignment(*buffer_value, offset, size);
    allocation_index_for_value_[buffer_value] = allocation->index();
  }

  if (alias_analysis().BufferLivesOut(buffer)) {
    VLOG(3) << "HloBuffer lives out: " << buffer.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloValue& value, int64 offset,
                                     int64 size) {
  allocation->AddAssignment(value, offset, size);
  allocation_index_for_value_[&value] = allocation->index();
  const HloValue& hlo_value =
      *CHECK_NOTNULL(dynamic_cast<const HloValue*>(&value));
  if (alias_analysis().ValueLivesOut(hlo_value)) {
    VLOG(3) << "HloValue lives out: " << hlo_value.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

// Combines allocations of temporary buffers of the same color into one big
// BufferAllocation.
void BufferAssignment::CombineTempAllocations() {
  VLOG(1) << "CombineTempAllocations()";
  // Stores the combined allocations.
  std::deque<BufferAllocation> combined_allocations;
  // Holds the pointer to a combined allocation of each color, if any.
  flat_hash_map<BufferValue::Color, BufferAllocation*> combined_allocation_map;

  // Move all temp allocations into a single run at the end of the allocations
  // vector.
  const auto first_temp_it =
      std::partition(allocations_.begin(), allocations_.end(),
                     [](const BufferAllocation& allocation) {
                       return !allocation.IsPreallocatedTempBuffer();
                     });

  // Walk over the run of temp allocations, collecting the allocations
  // belonging to the same color.
  if (first_temp_it != allocations_.end()) {
    for (auto it = first_temp_it; it != allocations_.end(); ++it) {
      BufferAllocation& temp_allocation = *it;
      BufferValue::Color color = temp_allocation.color();
      auto combined_it = combined_allocation_map.find(color);
      if (combined_it == combined_allocation_map.end()) {
        // We have found the first temp allocation of this color. Collect
        // the other temp allocations of the same color into it subject to the
        // size constraint.
        VLOG(1) << "Combined temp allocation for color " << color
                << " is: " << temp_allocation;
        combined_allocations.emplace_back(temp_allocation);
        combined_allocation_map.emplace(color, &combined_allocations.back());
        continue;
      }
      if (combined_it->second->size() + it->size() >=
          multiheap_size_constraint_per_heap_) {
        // We cannot put more into the current combined_it. So, appoint a new
        // combined_it.
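        // For example (illustrative numbers): with a 16 GiB per-heap
        // constraint, a combined allocation already holding 15 GiB cannot
        // absorb a 2 GiB temp allocation, so a fresh combined allocation is
        // appointed for this color.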
        VLOG(1) << "Due to size constraint, reset temp allocation for color "
                << color << " to: " << temp_allocation;
        combined_allocations.emplace_back(temp_allocation);
        combined_allocation_map.emplace(color, &combined_allocations.back());
        continue;
      }

      BufferAllocation* combined_allocation = combined_it->second;
      VLOG(1) << "Combined allocation absorbing temp allocation: "
              << temp_allocation;

      // Each temp allocation is placed end-to-end, accounting for alignment.
      // The offset of each buffer in the combined allocation is computed from
      // the base offset of the allocation.
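      // For instance, if the combined allocation currently holds 100 bytes
      // and the color's alignment is 64, base is rounded up to 128, so a
      // buffer at offset 16 in the absorbed allocation lands at offset 144.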
      int64 alignment = color_alignment_(color);
      const int64 base =
          RoundUpToNearest(combined_allocation->size(), alignment);
      combined_allocation->set_size(base + temp_allocation.size());
      for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) {
        const HloValue* value = buffer_offset_size.first;
        const int64 offset = buffer_offset_size.second.offset;
        const int64 size = buffer_offset_size.second.size;
        combined_allocation->AddAssignment(*value, base + offset, size);
      }
      if (!temp_allocation.HeapTraces().empty()) {
        CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
        combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
      }

      combined_allocation->peak_buffers_.insert(
          combined_allocation->peak_buffers_.end(),
          temp_allocation.peak_buffers_.begin(),
          temp_allocation.peak_buffers_.end());
    }
    // Replace all existing temporary allocations with the new combined
    // allocations.
    allocations_.erase(first_temp_it, allocations_.end());
    for (BufferAllocation& combined : combined_allocations) {
      temp_allocation_total_size_ += combined.size();
      allocations_.push_back(std::move(combined));
    }
  }

  // Update allocation indices to their new positions.
  allocation_index_for_value_.erase(allocation_index_for_value_.begin(),
                                    allocation_index_for_value_.end());
  for (size_t index = 0; index < allocations_.size(); ++index) {
    BufferAllocation* allocation = &allocations_[index];
    allocation->set_index(index);
    for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
      const HloValue* value = buffer_offset_size.first;
      allocation_index_for_value_[value] = index;
    }
  }
}

Status BufferAssignment::ComputeSummaryStats() {
  for (auto& allocation : Allocations()) {
    if (allocation.is_entry_computation_parameter()) {
      stats_.parameter_allocation_count++;
      stats_.parameter_allocation_bytes += allocation.size();
    }
    if (allocation.is_constant()) {
      stats_.constant_allocation_count++;
      stats_.constant_allocation_bytes += allocation.size();
    }
    if (allocation.maybe_live_out()) {
      stats_.maybe_live_out_allocation_count++;
      stats_.maybe_live_out_allocation_bytes += allocation.size();
    }
    if (allocation.IsPreallocatedTempBuffer()) {
      stats_.preallocated_temp_allocation_count++;
      stats_.preallocated_temp_allocation_bytes += allocation.size();
    }
    stats_.total_allocation_count++;
    stats_.total_allocation_bytes += allocation.size();
  }

  // Only compute total fragmentation if all computations have schedules.
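  // The fragmentation metric is the gap between the bytes actually allocated
  // and the heap simulator's lower bound; e.g., 120 MiB allocated against a
  // 100 MiB minimum reports 20 MiB of fragmentation (illustrative numbers).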
  HloSchedule schedule(module_);
  bool schedule_complete = true;
  for (const auto& computation : module_->computations()) {
    if (!computation->IsFusionComputation()) {
      const HloInstructionSequence* sequence =
          hlo_ordering().SequentialOrder(*computation);
      if (sequence == nullptr) {
        schedule_complete = false;
      } else {
        schedule.set_sequence(computation, *sequence);
      }
    }
  }
  if (schedule_complete) {
    TF_RETURN_IF_ERROR(schedule.Verify());
    TF_ASSIGN_OR_RETURN(
        const int64 min_size,
        HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
    stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
  }

  return Status::OK();
}

string BufferAssignment::Stats::ToString() const {
  string s;
  StrAppendFormat(&s, "BufferAssignment stats:\n");
  StrAppendFormat(&s, " parameter allocation: %10s\n",
                  HumanReadableNumBytes(parameter_allocation_bytes));
  StrAppendFormat(&s, " constant allocation: %10s\n",
                  HumanReadableNumBytes(constant_allocation_bytes));
  StrAppendFormat(&s, " maybe_live_out allocation: %10s\n",
                  HumanReadableNumBytes(maybe_live_out_allocation_bytes));
  StrAppendFormat(&s, " preallocated temp allocation: %10s\n",
                  HumanReadableNumBytes(preallocated_temp_allocation_bytes));
  if (preallocated_temp_fragmentation_bytes >= 0) {
    const double percent = 100. * preallocated_temp_fragmentation_bytes /
                           preallocated_temp_allocation_bytes;
    StrAppendFormat(
        &s, " preallocated temp fragmentation: %10s (%.2f%%)\n",
        HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
  }
  StrAppendFormat(&s, " total allocation: %10s\n",
                  HumanReadableNumBytes(total_allocation_bytes));
  if (total_fragmentation_bytes >= 0) {
    const double percent =
        100. * total_fragmentation_bytes / total_allocation_bytes;
    StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n",
                    HumanReadableNumBytes(total_fragmentation_bytes), percent);
  }
  return s;
}

string BufferAssignment::ToString() const {
  string output;
  absl::StrAppend(&output, "BufferAssignment:\n");
  std::vector<const HloValue*> used_values;
  int64 total_size = 0;
  for (auto& allocation : allocations_) {
    total_size += allocation.size();
    absl::StrAppend(&output, allocation.ToString());
    for (const auto& p : allocation.assigned_buffers()) {
      used_values.push_back(p.first);
    }
  }
  absl::StrAppend(&output, "\nTotal bytes used: ", total_size, "\n");
  absl::StrAppend(&output, "\nUsed values:\n");
  absl::c_sort(used_values, &CompareHloValuesById);
  for (const HloValue* value : used_values) {
    absl::StrAppend(&output, value->ToString());
  }
  return output;
}

string BufferAssignment::BufferInfoString() const {
  string binfo;
  // Columns in buffer information:
  //   buffer_id: int. This value can be used to match the allocation in
  //     allocation information.
  //   buffer_name: string.
  //   offset: int. Starting position of the buffer in the memory space.
  //   size: int. Size of the buffer in bytes.
  //   definition_time: int. Position in the schedule where the buffer starts
  //     being live (inclusive).
  //   end_time: int. Position in the schedule where the buffer stops being
  //     live (exclusive).
  //   num_uses: int. Number of uses of the buffer.
  //   use_times: string. Semicolon-separated list of the schedule positions
  //     of the buffer's uses.
  //   use_names: string. Semicolon-separated list of string representations
  //     of the uses.
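  // An illustrative row (hypothetical values):
  //   12,"v12",128,1024,5,9,2,"6;8","use_a;use_b"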
  // Append the column names.
  absl::StrAppend(&binfo,
                  "buffer_id,buffer_name,offset,size,"
                  "definition_time,end_time,num_uses,use_times,use_names\n");
  const HloLiveRange& live_ranges = hlo_live_range();
  const auto& instruction_schedule = live_ranges.instruction_schedule();
  const auto& buffer_live_ranges = live_ranges.buffer_live_ranges();
  // Sort the buffers by id.
  std::vector<std::pair<const HloValue*, BufferAllocation::OffsetSize>> buffers;
  for (const BufferAllocation& allocation : allocations_) {
    absl::c_copy(allocation.assigned_buffers(), std::back_inserter(buffers));
  }
  absl::c_sort(
      buffers,
      [](const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b1,
         const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b2) {
        return b1.first->id() < b2.first->id();
      });
  for (const auto& buffer_pair : buffers) {
    const HloValue& buffer = *buffer_pair.first;
    const BufferAllocation::OffsetSize& offset_size = buffer_pair.second;
    if (!buffer_live_ranges.contains(&buffer)) {
      continue;
    }
    // Order the uses by their position in the schedule.
    std::vector<std::pair<int64, std::string>> uses;
    uses.reserve(buffer.uses().size());
    for (const HloUse& use : buffer.uses()) {
      uses.emplace_back(instruction_schedule.at(use.instruction),
                        use.ToString());
    }
    absl::c_sort(uses);
    std::vector<int64> use_positions;
    std::vector<std::string> use_names;
    use_positions.reserve(uses.size());
    use_names.reserve(uses.size());
    for (const auto& use : uses) {
      use_positions.push_back(use.first);
      use_names.push_back(use.second);
    }
    const int64 definition_time =
        instruction_schedule.at(buffer.defining_position().instruction);
    const int64 end_t = buffer_live_ranges.at(&buffer).end;
    absl::StrAppend(&binfo, buffer.id(), ",");
    absl::StrAppend(&binfo, "\"", buffer.ToShortString(), "\",");
    absl::StrAppend(&binfo, offset_size.offset, ",");
    absl::StrAppend(&binfo, offset_size.size, ",");
    absl::StrAppend(&binfo, definition_time, ",");
    absl::StrAppend(&binfo, end_t, ",");
    absl::StrAppend(&binfo, use_positions.size(), ",");
    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_positions, ";"), "\",");
    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_names, ";"), "\"");
    absl::StrAppend(&binfo, "\n");
  }
  return binfo;
}

BufferAssignmentProto BufferAssignment::ToProto() const {
  BufferAssignmentProto proto;
  // NOTE: DataflowAnalysis state is serialized here in BufferAssignment,
  // because we need to do the HasAllocation check for each buffer. Otherwise
  // the buffer_size_ call might fail for some backends.
  const HloDataflowAnalysis& dataflow = this->dataflow_analysis();
  for (BufferValue::Id id = 0; id < dataflow.values().size(); id++) {
    auto& value = dataflow.values().at(id);
    if (HasAllocation(*value)) {
      LogicalBufferProto proto_buffer = value->ToProto(buffer_size_);
      proto.add_logical_buffers()->Swap(&proto_buffer);

      // Fill buffer aliases.
      for (const HloValue* alias :
           alias_analysis().GetBufferContainingValue(*value).values()) {
        if (alias->instruction() == value->instruction() &&
            alias->index() == value->index()) {
          continue;  // Skip self-aliases.
        }
        BufferAssignmentProto::BufferAlias* proto_alias =
            proto.add_buffer_aliases();
        LogicalBufferProto::Location proto_alias_location =
            BufferValue::ToLocationProto(*alias->instruction(), alias->index());
        proto_alias->set_source_buffer_id(value->id());
        proto_alias->mutable_location()->Swap(&proto_alias_location);
      }
    }
  }
  for (const BufferAllocation& allocation : Allocations()) {
    BufferAllocationProto proto_allocation = allocation.ToProto();
    proto.add_buffer_allocations()->Swap(&proto_allocation);
    for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
      *proto.add_heap_simulator_traces() = heap_trace;
    }
  }
  return proto;
}

/* static */
StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer,
    const absl::flat_hash_set<HloOpcode>& must_not_live_out,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer,
    std::unique_ptr<PresetAssignments> preset_assignments) {
  BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
                          must_not_live_out, std::move(preset_assignments));
  return assigner.CreateAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(can_share_buffer));
}

bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1,
                                         const HloValue* buffer2,
                                         BufferAssignment* assignment) {
  CHECK(assignment->hlo_live_range().total_order_scheduled());
  const HloLiveRange& hlo_live_range = assignment->hlo_live_range();

  const auto& buffer_live_ranges = hlo_live_range.buffer_live_ranges();

  CHECK(buffer_live_ranges.contains(buffer1))
      << "Buffer doesn't have a proper live range: " << buffer1;

  CHECK(buffer_live_ranges.contains(buffer2))
      << "Buffer doesn't have a proper live range: " << buffer2;

  // Check if a user value can share the same buffer as its operand.
  auto can_share_as_operand = [&assignment](const HloValue* user_value,
                                            const HloValue* operand_value) {
    return user_value->instruction()->IsUserOf(operand_value->instruction()) &&
           assignment->dataflow_analysis().CanShareOperandBufferWithUser(
               operand_value->instruction(), operand_value->index(),
               user_value->instruction(), user_value->index()) &&
           user_value->instruction()->opcode() != HloOpcode::kCopy;
  };
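  // For example, an elementwise user can typically share its operand's
  // buffer, so back-to-back live ranges may coexist in one allocation; a
  // kCopy user is excluded because a copy must not alias its operand.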

  auto live_range_1 = buffer_live_ranges.at(buffer1);
  auto live_range_2 = buffer_live_ranges.at(buffer2);

  if (!(live_range_1.start > live_range_2.end ||
        live_range_2.start > live_range_1.end)) {
    if (live_range_1.end == live_range_2.start) {
      auto operand_value = buffer1;
      auto user_value = buffer2;
      if (!can_share_as_operand(user_value, operand_value)) {
        VLOG(4) << "End of live range of " << buffer1->ToShortString()
                << " is equal to the start of live range of "
                << buffer2->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else if (live_range_2.end == live_range_1.start) {
      auto operand_value = buffer2;
      auto user_value = buffer1;
      if (!can_share_as_operand(user_value, operand_value)) {
        VLOG(4) << "End of live range of " << buffer2->ToShortString()
                << " is equal to the start of live range of "
                << buffer1->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else {
      VLOG(4) << "Can't assign: assignee " << *buffer1 << " may interfere with "
              << *buffer2;
      VLOG(4) << "assigned_buffer.start: " << live_range_1.start;
      VLOG(4) << "assigned_buffer.end: " << live_range_1.end;
      VLOG(4) << "live_range_2.start: " << live_range_2.start;
      VLOG(4) << "live_range_2.end: " << live_range_2.end;
      return true;
    }
  }
  return false;
}

bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
                                       const HloBuffer& hlo_buffer,
                                       BufferAssignment* assignment) {
  CHECK(!assignment->HasAllocation(hlo_buffer))
      << "buffer " << hlo_buffer << " already has an allocation assigned.";

  VLOG(4) << "Trying to assign " << hlo_buffer << " size "
          << assignment->HloBufferSize(hlo_buffer)
          << " to allocation: " << *allocation;

  if (hlo_buffer.color() != allocation->color()) {
    VLOG(4) << "Can't assign: buffer has color " << hlo_buffer.color()
            << " and allocation has color " << allocation->color() << ".";
    return false;
  }

  if (assignment->HloBufferSize(hlo_buffer) > allocation->size()) {
    VLOG(4) << "Can't assign: buffer is larger than allocation ("
            << assignment->HloBufferSize(hlo_buffer) << " > "
            << allocation->size() << ")";
    return false;
  }

  if (allocation->is_readonly()) {
    VLOG(4) << "Can't assign: allocation is readonly";
    return false;
  }

  if (!must_not_live_out_.empty()) {
    if (allocation->maybe_live_out()) {
      // If the allocation may be live out, it cannot contain any node with an
      // opcode from the "must_not_live_out_" set.
      for (const HloValue* value : hlo_buffer.values()) {
        if (must_not_live_out_.count(value->instruction()->opcode()) > 0) {
          VLOG(4) << "Can't assign: " << value->instruction()->ToString()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
    // The above check is not enough: an allocation could currently be not
    // live out yet already contain an instruction with an opcode from the
    // "must_not_live_out_" set; assigning a live-out buffer to it would make
    // the allocation live out while still containing such an instruction.
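    // E.g., with must_not_live_out_ = {kCustomCall} (hypothetical contents),
    // a live-out buffer must not join an allocation already holding a
    // custom-call result, even if that allocation is not yet live out.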
    if (assignment->alias_analysis().BufferLivesOut(hlo_buffer)) {
      for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
        if (must_not_live_out_.count(
                buffer_offset_size.first->instruction()->opcode()) > 0) {
          VLOG(4) << "Can't assign: " << buffer_offset_size.first->instruction()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
  }

  if (!allocation->is_reusable()) {
    VLOG(4) << "Can't assign: allocation is not reusable";
    return false;
  }

  for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
    // Pairwise compare.
    const HloValue& assigned_buffer =
        *CHECK_NOTNULL(dynamic_cast<const HloValue*>(buffer_offset_size.first));
    for (const HloValue* new_value : hlo_buffer.values()) {
      if (assignment->hlo_live_range().total_order_scheduled()) {
        if (LiveRangeInterferes(new_value, &assigned_buffer, assignment)) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " live range interferes with "
                  << new_value->ToShortString();
          return false;
        }
      } else if (assignment->hlo_ordering().MayInterfere(
                     assigned_buffer, *new_value,
                     assignment->dataflow_analysis())) {
        // Fall back to partial-order-based interference detection (slower)
        // when we don't have a total-order scheduled module.
        VLOG(4) << "Can't assign: assignee " << assigned_buffer
                << " may interfere with " << new_value->ToShortString();
        return false;
      }

      for (const HloPosition& assigned_buffer_position :
           assigned_buffer.positions()) {
        // Copy instructions don't share a buffer with their input operand.
        if (new_value->instruction()->IsUserOf(
                assigned_buffer_position.instruction) &&
            new_value->instruction()->opcode() == HloOpcode::kCopy) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " is used at copy instruction "
                  << new_value->ToShortString();
          return false;
        }
      }
    }
  }

  // If the buffer is live out of the computation, then it should only be
  // assigned an allocation which exactly fits the result, to avoid wasting
  // memory (result buffers can have arbitrary lifetimes).
  if (assignment->alias_analysis().BufferLivesOut(hlo_buffer) &&
      allocation->size() != assignment->HloBufferSize(hlo_buffer)) {
    VLOG(4) << "Can't assign: buffer " << hlo_buffer
            << " is live out and its size differs from the allocation's";
    return false;
  }

  assignment->AddAssignment(allocation, hlo_buffer, /*offset=*/0,
                            assignment->HloBufferSize(hlo_buffer));
  return true;
}

Status BufferAssigner::AssignSingleHloBuffer(
    const HloBuffer* hlo_buffer, bool is_thread_local,
    absl::flat_hash_map<const HloComputation*,
                        absl::flat_hash_set<const HloValue*>>*
        buffers_to_assign_sequentially,
    std::vector<BufferAllocation::Index>* allocation_indices,
    BufferAssignment* assignment) {
  const int64 buffer_size = assignment->HloBufferSize(*hlo_buffer);
  for (const HloValue* value : hlo_buffer->values()) {
    if (value->instruction()->opcode() == HloOpcode::kConstant) {
      if (allocate_buffers_for_constants_) {
        BufferAllocation* allocation =
            assignment->NewAllocation(*hlo_buffer, buffer_size);
        allocation->set_constant(true);
        VLOG(3) << "New allocation #" << allocation->index()
                << " for constant " << *hlo_buffer << " value ptr: " << value;
      } else {
        VLOG(3) << "Not allocating buffer for constant";
      }
      return Status::OK();
    }

    const HloInstruction* instruction = value->instruction();
    const bool is_entry_parameter =
        instruction->opcode() == HloOpcode::kParameter &&
        instruction->parent() ==
            instruction->parent()->parent()->entry_computation();

    if (is_entry_parameter) {
      bool parameter_has_alias =
          assignment->module().input_output_alias_config().ParameterHasAlias(
              instruction->parameter_number(), value->index());
      // If the HLO buffer is part of an external parameter, create a new
      // allocation and set its parameter number. Parameters of non-entry
      // computations do not need special allocations because they live inside
      // callers.
      BufferAllocation* allocation =
          assignment->NewAllocation(*hlo_buffer, buffer_size);

      allocation->set_entry_computation_parameter(
          instruction->parameter_number(), value->index(), parameter_has_alias);
      if (parameter_has_alias) {
        allocation_indices->push_back(allocation->index());
      }
      VLOG(3) << "New allocation #" << allocation->index()
              << " marked as entry computation parameter: " << *hlo_buffer;
      return Status::OK();
    }
  }

  if (is_thread_local) {
    BufferAllocation* allocation =
        assignment->NewAllocation(*hlo_buffer, buffer_size);
    allocation->set_is_thread_local(true);
    VLOG(3) << "New allocation #" << allocation->index()
            << " for thread-local: " << *hlo_buffer;
    return Status::OK();
  }

  for (const HloValue* value : hlo_buffer->values()) {
    if (value->shape().IsTuple()) {
      BufferAllocation* allocation =
          assignment->NewAllocation(*hlo_buffer, buffer_size);
      allocation->set_is_tuple(true);
      VLOG(3) << "New allocation #" << allocation->index()
              << " for tuple-shaped buffer: " << *hlo_buffer;
      return Status::OK();
    }

    if (value->IsTopLevel() && !value->IsTuple()) {
      const HloInstruction* instruction = value->instruction();
      for (auto* operand : instruction->operands()) {
        for (const auto& operand_slice :
             assignment->GetAllSlices(operand, /*index=*/{})) {
          BufferAllocation* allocation =
              assignment->GetMutableAllocation(operand_slice.index());
          if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
            VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
                    << " for: " << *hlo_buffer;
            return Status::OK();
          }
        }
      }
    }
  }

  // Find the smallest buffer which can be reused, iterating from the end of
  // allocation_indices (smallest) to the beginning (largest).
  for (int allocation_index = allocation_indices->size() - 1;
       allocation_index >= 0; allocation_index--) {
    BufferAllocation* allocation = assignment->GetMutableAllocation(
        allocation_indices->at(allocation_index));
    if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
      VLOG(3) << "Reusing allocation #" << allocation->index()
              << " for: " << *hlo_buffer;
      return Status::OK();
    }
  }

  if (!assignment->HasAllocation(*hlo_buffer) &&
      !assignment->alias_analysis().BufferLivesOut(*hlo_buffer)) {
    bool all_computations_have_sequential_order = true;
    for (const HloValue* hlo_value : hlo_buffer->values()) {
      HloComputation* computation = hlo_value->instruction()->parent();
      const bool has_sequential_order =
          assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
      all_computations_have_sequential_order &= has_sequential_order;
    }

    if (all_computations_have_sequential_order) {
      for (const HloValue* hlo_value : hlo_buffer->values()) {
        HloComputation* computation = hlo_value->instruction()->parent();
        // There is a sequential instruction ordering, so we delay assignment
        // of temp buffers until after the loop. We do this right before we
        // decide to create a new allocation, to ensure we've exhausted all
        // the buffer re-use cases above.
        //
        // Entry parameters and thread local buffers were already handled
        // earlier in this loop iteration. See
        // BufferAllocation::IsPreallocatedTempBuffer for the definition of
        // temp buffers.
        (*buffers_to_assign_sequentially)[computation].insert(hlo_value);
        VLOG(3) << "Delaying assignment of temp buffer: " << *hlo_value;
      }
      return Status::OK();
    }
  }

  if (!assignment->HasAllocation(*hlo_buffer)) {
    BufferAllocation* allocation =
        assignment->NewAllocation(*hlo_buffer, buffer_size);
    allocation_indices->push_back(allocation->index());
    VLOG(3) << "New allocation #" << allocation->index()
            << " for: " << *hlo_buffer;
  }

  TF_RET_CHECK(assignment->HasAllocation(*hlo_buffer));
  return Status::OK();
}

Status BufferAssigner::AssignBuffersForComputations(
    const std::vector<const HloComputation*>& computations,
    bool is_thread_local,
    absl::flat_hash_map<const HloComputation*,
                        absl::flat_hash_set<const HloValue*>>*
        buffers_to_assign_sequentially,
    BufferAssignment* assignment) {
  if (computations.empty()) {
    return Status::OK();
  }
  std::vector<const HloBuffer*> sorted_buffers;

  // First assign the preset allocations.
  absl::flat_hash_set<const HloBuffer*> preset_assigned_buffers;

  TF_RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment));

  const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();

  for (const HloBuffer& buffer : alias_analysis.buffers()) {
    // Skip if the buffer is already assigned since it had a preset allocation.
    if (preset_assigned_buffers.find(&buffer) !=
        preset_assigned_buffers.end()) {
      VLOG(3) << "Skip allocation for buffer: " << buffer;
      continue;
    }
    TF_RET_CHECK(!buffer.values().empty());
    const HloComputation* comp = buffer.values()[0]->instruction()->parent();
    if (absl::c_linear_search(computations, comp)) {
      sorted_buffers.push_back(&buffer);
    }
  }

  // Generate a post order sort of instructions for use in sorting the
  // HloBuffers.
  flat_hash_map<const HloInstruction*, int> post_order_position;
  int position = 0;
  std::vector<const HloComputation*> reverse_post_order_computations;
  std::unique_ptr<CallGraph> call_graph =
      CallGraph::Build(computations[0]->parent());
  TF_RETURN_IF_ERROR(call_graph->VisitNodes([&](const CallGraphNode& node) {
    if (absl::c_linear_search(computations, node.computation())) {
      reverse_post_order_computations.push_back(node.computation());
    }
    return Status::OK();
  }));
  absl::c_reverse(reverse_post_order_computations);
  for (auto* computation : reverse_post_order_computations) {
    for (auto* instruction : computation->MakeInstructionPostOrder()) {
      post_order_position.emplace(instruction, position);
      position++;
    }
  }

  HloSchedule schedule(&assignment->module());

  for (const HloComputation* computation : computations) {
    const HloInstructionSequence* instruction_sequence =
        assignment->hlo_ordering().SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
      // Every sequential computation must get an entry in the
      // buffers_to_assign_sequentially map, even if we end up with an empty
      // set of buffers. This ensures we can correctly determine whether to
      // run whole-module heap simulation.
      buffers_to_assign_sequentially->emplace(computation,
                                              flat_hash_set<const HloValue*>());

      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

  absl::c_sort(
      sorted_buffers, [&post_order_position, &alias_analysis, assignment](
                          const HloBuffer* a, const HloBuffer* b) {
        // Primary sort is by decreasing buffer size.
        const int64 a_size = assignment->HloBufferSize(*a);
        const int64 b_size = assignment->HloBufferSize(*b);
        if (a_size != b_size) {
          return a_size > b_size;  // Use ">" for decreasing size.
        }

        // Among equal-sized buffers, live-out buffers come first.
        const bool a_live_out = alias_analysis.BufferLivesOut(*a);
        const bool b_live_out = alias_analysis.BufferLivesOut(*b);
        if (a_live_out != b_live_out) {
          return a_live_out;
        }
        // Finally, break ties by the earliest post-order position of the
        // buffers' defining instructions.
        auto compare = [&post_order_position](const HloValue* value1,
                                              const HloValue* value2) {
          return post_order_position.at(value1->instruction()) <
                 post_order_position.at(value2->instruction());
        };
        const HloValue* a_min = *absl::c_min_element(a->values(), compare);
        const HloValue* b_min = *absl::c_min_element(b->values(), compare);
        return compare(a_min, b_min);
      });

  std::vector<BufferAllocation::Index> allocation_indices;

  for (const HloBuffer* buffer : sorted_buffers) {
    VLOG(3) << "=================================================";
    VLOG(3) << "Assigning buffer for " << *buffer;
    TF_RETURN_IF_ERROR(AssignSingleHloBuffer(buffer, is_thread_local,
                                             buffers_to_assign_sequentially,
                                             &allocation_indices, assignment));
  }
  return Status::OK();
}
1345
flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>>
BufferAssigner::SplitBuffersByColor(
    const flat_hash_set<const HloValue*>& buffers) {
  flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>> color_map;
  for (auto buffer : buffers) {
    color_map[buffer->color()].insert(buffer);
  }
  return color_map;
}

Status BufferAssigner::AssignPresetBuffers(
    absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
    BufferAssignment* assignment) {
  if (!preset_assignments_) {
    return Status::OK();
  }

  // Create an allocation for each preset color.
  absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*>
      preset_allocations;
  for (auto& color_and_info : preset_assignments_->assignment_informations()) {
    LogicalBuffer::Color color(color_and_info.first);
    auto inserted = preset_allocations.emplace(
        color,
        assignment->NewEmptyAllocation(color_and_info.second.size, color));
    BufferAllocation* inserted_allocation = inserted.first->second;
    inserted_allocation->AddHeapTrace(
        color_and_info.second.heap_simulator_trace);
    VLOG(3) << "Created preset buffer allocation "
            << inserted_allocation->index()
            << ", color: " << inserted_allocation->color()
            << ", size: " << inserted_allocation->size();
  }

  const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();

  for (auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& defining_position = position_and_chunk.first;
    const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt(
        defining_position.instruction, defining_position.index);
    for (const HloValue* value : buffer.values()) {
      VLOG(3) << "Preset allocation for value: " << value->ToShortString();
      const HeapSimulator::Chunk& chunk = position_and_chunk.second;
      auto preset_allocations_iter = preset_allocations.find(value->color());
      CHECK(preset_allocations_iter != preset_allocations.end())
          << "No preset value allocation for color " << value->color()
          << " for " << value->ToShortString() << " found.";
      preset_allocations_iter->second->AddAssignment(*value, chunk.offset,
                                                     chunk.size);
    }

    assigned_buffers->insert(&buffer);
  }

  // The preset assignments have now been consumed; clear them so that a
  // subsequent call to this method does not assign the same buffers again.
  preset_assignments_ = {};

  return Status::OK();
}

Status BufferAssigner::AssignBuffersWithSequentialOrdering(
    const flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>&
        buffers_to_assign_sequentially,
    bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
  // Run the sequence of instructions through the heap simulator. The
  // heuristic that seems to give the best results is lazy-best-fit, with all
  // runs of alloc / free calls sorted in decreasing size order.
  const HloOrdering& hlo_ordering = assignment->hlo_ordering();

  // Returns a heap algorithm that chooses the best result from several
  // algorithms.
  auto get_heap_algorithm = [&](int64 alignment) {
    auto algorithms = absl::make_unique<
        std::vector<std::unique_ptr<HeapAlgorithm<HloValue>>>>();
    algorithms->push_back(
        absl::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
            assignment->multiheap_size_constraint_per_heap(), alignment,
            GlobalDecreasingSizeBestFitHeap<HloValue>::kSpatial));
    algorithms->push_back(
        absl::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
            assignment->multiheap_size_constraint_per_heap(), alignment,
            GlobalDecreasingSizeBestFitHeap<HloValue>::kTemporal));
    return absl::make_unique<ChooseBestHeapAlgorithm<HloValue>>(
        std::move(algorithms));
  };
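  // kSpatial sorts candidate buffers by decreasing size, while kTemporal
  // sorts by decreasing live-range length; ChooseBestHeapAlgorithm runs
  // every candidate and keeps the result with the smallest total heap size.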

  if (run_whole_module_heap_simulation) {
    // Run the heap simulation over the whole module. This reduces memory
    // usage, since buffers for kCall, kWhile, and kConditional
    // sub-computations are only live for the duration of their calling
    // instructions.
    VLOG(1) << "Running whole-module heap simulation";
    HloSchedule schedule(&assignment->module());
    flat_hash_set<const HloValue*> all_buffers_to_assign;
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      schedule.set_sequence(computation, *instruction_sequence);
      all_buffers_to_assign.insert(buffers_to_assign.begin(),
                                   buffers_to_assign.end());
    }
    auto color_map = SplitBuffersByColor(all_buffers_to_assign);
    for (auto& single_colored_set : color_map) {
      auto color = single_colored_set.first;
      VLOG(2) << "Simulating heap for color " << color;
      int64 alignment = assignment->color_alignment_(color);
      HeapSimulator::Options options;
      options.alloc_constants = allocate_buffers_for_constants_;
      options.buffers_to_assign = &single_colored_set.second;

      TF_ASSIGN_OR_RETURN(
          HeapSimulator::Result<HloValue> result,
          HeapSimulator::Run(
              get_heap_algorithm(alignment), assignment->module(), schedule,
              assignment->alias_analysis(), assignment->buffer_size_, options));
      AssignBuffersFromHeapSimulator(result, assignment,
                                     single_colored_set.first);
    }
  } else {
    // Run the heap-simulation on a per-computation basis. Buffers for
    // sub-computations are assigned disjoint BufferAllocations, assuming the
    // worst-case that they may all be live concurrently.
    VLOG(1) << "Running per-computation heap simulation";
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      auto color_map = SplitBuffersByColor(buffers_to_assign);
      for (auto& single_colored_set : color_map) {
        auto color = single_colored_set.first;
        VLOG(2) << "Simulating heap for color " << color;
        int64 alignment = assignment->color_alignment_(color);
        HeapSimulator::Options options;
        options.buffers_to_assign = &single_colored_set.second;
        TF_ASSIGN_OR_RETURN(
            HeapSimulator::Result<HloValue> result,
            HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
                               *instruction_sequence,
                               assignment->alias_analysis(),
                               assignment->buffer_size_, options));
        AssignBuffersFromHeapSimulator(result, assignment,
                                       single_colored_set.first);
      }
    }
  }
  return Status::OK();
}

namespace {
// Computes and returns the set of logical buffers live at the point of
// maximal liveness in the given heap trace. LogicalBuffers are (stably)
// sorted by id.
std::vector<const HloValue*> ComputePeakMemoryLogicalBuffers(
    const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
  // Create a map from buffer id to HloValue, and a map from HloValue to its
  // assigned size, for the values in this allocation.
  absl::flat_hash_map<BufferValue::Id, const HloValue*> id_to_value;
  absl::flat_hash_map<const HloValue*, int64> buffer_sizes;
  for (const auto& pair : allocation.assigned_buffers()) {
    const HloValue* value = pair.first;
    const BufferAllocation::OffsetSize& offset_size = pair.second;
    id_to_value[value->id()] = value;
    buffer_sizes[value] = offset_size.size;
  }
  VLOG(1) << "Compute peak memory logical buffers";

  // Returns how much the given event increases the total size of live
  // buffers. Can be negative.
  auto memory_delta = [&id_to_value, &buffer_sizes](
                          const HeapSimulatorTrace::Event& event) -> int64 {
    const HloValue* buffer = id_to_value.at(event.buffer_id());
    const int64 buffer_size = buffer_sizes.at(buffer);
    if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
        event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
      return buffer_size;
    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
      return -1 * buffer_size;
    }
    LOG(FATAL) << "Unknown event kind: " << event.kind();
  };
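  // The peak is computed in two passes over the trace. Illustrative example:
  // the trace ALLOC a(8), ALLOC b(4), FREE a, ALLOC c(2) yields running live
  // sizes 8, 12, 4, 6, so the first pass records max_live_size = 12 and the
  // second pass stops right after ALLOC b, returning {a, b} as the peak set.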

  // First compute the size of the maximal live set.
  int64 max_live_size = 0;
  int64 live_size = 0;
  for (const auto& event : heap_trace.events()) {
    if (!id_to_value.contains(event.buffer_id())) {
      // Skip as the buffer associated with this trace event is not placed
      // into this allocation. This can happen when size constraints are
      // given to the heap simulator.
      continue;
    }
    live_size += memory_delta(event);
    if (max_live_size < live_size) {
      max_live_size = live_size;
    }
  }

  // Next gather the set of logical buffers live at the earliest point of
  // maximal live set size.
  absl::flat_hash_set<const HloValue*> live_values;
  live_size = 0;
  for (const auto& event : heap_trace.events()) {
    if (!id_to_value.contains(event.buffer_id())) {
      // Skip as the buffer associated with this trace event is not placed
      // into this allocation. This can happen when size constraints are
      // given to the heap simulator.
      continue;
    }
    const HloValue* value = id_to_value.at(event.buffer_id());
    if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
        event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
      InsertOrDie(&live_values, value);
    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
      CHECK(ContainsKey(live_values, value));
      live_values.erase(value);
    }
    live_size += memory_delta(event);

    if (live_size == max_live_size) {
      break;
    }
  }
  CHECK_EQ(live_size, max_live_size);

  std::vector<const HloValue*> live_values_vector;
  live_values_vector.insert(live_values_vector.end(), live_values.begin(),
                            live_values.end());

  // Stably sort the live buffers by id.
  absl::c_sort(live_values_vector, [](const HloValue* a, const HloValue* b) {
    return a->id() < b->id();
  });
  VLOG(4) << "Peak memory buffers:";
  for (auto value : live_values_vector) {
    VLOG(4) << "  " << value->ToString();
  }
  return live_values_vector;
}

}  // namespace

void BufferAssigner::AssignBuffersFromHeapSimulator(
    const HeapSimulator::Result<HloValue>& result, BufferAssignment* assignment,
    BufferValue::Color color) {
  // A value of -1 means the fragmentation statistic has not been initialized
  // yet; the first simulator result sets it and later results accumulate.
  if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
    assignment->stats_.preallocated_temp_fragmentation_bytes =
        result.fragmentation_size;
  } else {
    assignment->stats_.preallocated_temp_fragmentation_bytes +=
        result.fragmentation_size;
  }
  VLOG(1) << "Result size from heap simulator: " << result.heap_size;

  // Iterate through heap_results. For each heap_result, create a new
  // allocation in `assignment`.
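  // Multiple heap_results arise when a multiheap size constraint is in
  // effect and the constrained heap algorithm splits the values across
  // several heaps; each heap gets its own BufferAllocation.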
  for (const HeapSimulator::HeapResult<HloValue>& heap_result :
       result.heap_results) {
    BufferAllocation* allocation =
        assignment->NewEmptyAllocation(heap_result.heap_size, color);
    for (const auto& buffer_chunk : heap_result.chunk_map) {
      const HloValue& value = *buffer_chunk.first;
      const HeapSimulator::Chunk& chunk = buffer_chunk.second;
      assignment->AddAssignment(allocation, value, chunk.offset, chunk.size);
    }
    allocation->peak_buffers_ =
        ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);

    XLA_VLOG_LINES(2, allocation->ToString());

    allocation->AddHeapTrace(result.debug_trace);
  }
}

StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer));

  // Set up a schedule for each computation.
  HloSchedule schedule(module);
  for (const HloComputation* computation : module->computations()) {
    const HloInstructionSequence* instruction_sequence =
        hlo_ordering->SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order) {
      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, *alias_analysis,
                                        module->entry_computation(), true));
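  // The trailing `true` above is understood here to request module-scoped
  // live range analysis (parameter name assumed from HloLiveRange::Run's
  // signature), so live ranges can span computation boundaries.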

  VLOG(1) << "Assigning buffers to module " << module->name();
  XLA_VLOG_LINES(3, module->ToString());
  XLA_VLOG_LINES(3, alias_analysis->ToString());
  XLA_VLOG_LINES(3, alias_analysis->dataflow_analysis().ToString());
  VLOG(1) << "Number of buffers to assign: "
          << alias_analysis->buffers().size();

  // Can't use absl::make_unique because BufferAssignment constructor is
  // private.
  std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(alias_analysis),
      std::move(hlo_live_range)));

  TF_RETURN_IF_ERROR(
      colorer_(&assignment->alias_analysis(), assignment->hlo_ordering()));
  VLOG(3) << "After coloring:";
  XLA_VLOG_LINES(3,
                 assignment->alias_analysis().dataflow_analysis().ToString());

  std::vector<const HloComputation*> thread_local_computations;
  std::vector<const HloComputation*> global_computations;
  TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
      module, &thread_local_computations, &global_computations));

  // First assign buffers for global computations. Temporary buffers for
  // sequential computations are collected in
  // 'buffers_to_assign_sequentially'.
  flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>
      buffers_to_assign_sequentially;
  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      global_computations,
      /*is_thread_local=*/false, &buffers_to_assign_sequentially,
      assignment.get()));
  // Assign buffers with sequential ordering, if any. If all global
  // computations are sequential, we can run heap simulation on the whole
  // module, which reduces memory usage.
  const bool run_whole_module_heap_simulation =
      buffers_to_assign_sequentially.size() == global_computations.size();
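  // Every sequential global computation received an entry in
  // buffers_to_assign_sequentially (possibly empty), so the sizes are equal
  // exactly when every global computation has a sequential order.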
  VLOG(2) << "Running whole module heap simulation: "
          << run_whole_module_heap_simulation;
  const int32 multiheap_size_constraint_per_heap =
      module->config().debug_options().xla_multiheap_size_constraint_per_heap();
  VLOG(2) << "Multiheap per heap size limit: "
          << multiheap_size_constraint_per_heap;
  TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
      buffers_to_assign_sequentially, run_whole_module_heap_simulation,
      assignment.get()));

  // Now assign buffers for thread-local computations. All LogicalBuffers get
  // their own BufferAllocation. Fusion computations are excluded below.
  std::vector<const HloComputation*> thread_local_computations_no_fusion;
  for (auto* computation : thread_local_computations) {
    TF_RET_CHECK(computation != module->entry_computation());
    if (computation->IsFusionComputation()) {
      continue;
    }
    thread_local_computations_no_fusion.push_back(computation);
  }

  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      thread_local_computations_no_fusion,
      /*is_thread_local=*/true,
      /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));

  // Mark all buffers which may be live out of the entry computation as
  // "liveout".
  for (const HloBuffer* buffer :
       assignment->alias_analysis().LiveOutBuffers()) {
    VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
    if (assignment->HasAllocation(*buffer)) {
      BufferAllocation* alloc =
          assignment->GetMutableAssignedAllocation(*buffer);
      alloc->set_maybe_live_out(true);
      VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
    }
  }

  // Combines allocations of temporary buffers into big BufferAllocations
  // subject to the buffer allocation size constraint. This can only be
  // performed after all buffers have been assigned, and after maybe_live_out
  // is marked, since it is used to determine whether an allocation contains
  // temporary buffers or not.
  assignment->CombineTempAllocations();

  XLA_VLOG_LINES(2, assignment->ToString());
  TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
  XLA_VLOG_LINES(1, assignment->GetStats().ToString());
  VLOG(1) << "Buffer assignment done.";
  return std::move(assignment);
}

}  // namespace xla