/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Defines the data returned by the XLA buffer assignment packages.

#include "tensorflow/compiler/xla/service/buffer_assignment.h"

#include <algorithm>
#include <deque>
#include <numeric>
#include <ostream>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/numbers.h"

namespace xla {
namespace {

using absl::flat_hash_map;
using absl::flat_hash_set;
using absl::StrAppend;
using absl::StrAppendFormat;
using memory_space_assignment::PresetAssignments;
using ::tensorflow::strings::HumanReadableNumBytes;

// Given the interference map of a graph (the list of interfering node indices
// for each node), perform graph coloring such that interfering nodes are
// assigned to different colors. Returns the assigned color of the nodes, where
// the colors are represented as integer values [0, color_count).
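// For example (illustrative): given interference_map = {{1}, {0, 2}, {1}}
// (a path graph 0 - 1 - 2), node 1 is the most constrained and is colored
// first, receiving color 0; nodes 0 and 2 then each find color 0 taken by
// their neighbor and receive color 1, yielding assigned colors {1, 0, 1}.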
std::vector<int64> ColorInterferenceGraph(
    const std::vector<std::vector<int64>>& interference_map) {
  const int64_t node_count = interference_map.size();

  // Sort the nodes such that we assign nodes with more interference first.
  // This relies on the common heuristic of assigning the most constrained node
  // first, but it would be good to investigate other ordering heuristics too.
  std::vector<int64> nodes(node_count);
  std::iota(nodes.begin(), nodes.end(), 0);
  absl::c_sort(nodes, [&interference_map](const int64_t i, const int64_t j) {
    return interference_map[i].size() > interference_map[j].size();
  });

  const int64_t kColorUnassigned = -1;
  std::vector<int64> assigned_colors(node_count, kColorUnassigned);
  for (int64_t node : nodes) {
    // Mark the colors that are already assigned to the neighbors.
    std::vector<bool> available_colors(node_count, true);
    for (int64_t neighbor : interference_map[node]) {
      int64_t color = assigned_colors[neighbor];
      if (color != kColorUnassigned) {
        available_colors[color] = false;
      }
    }

    // Find the color that is not yet assigned to the neighbors.
    int64_t color = kColorUnassigned;
    for (color = 0; color < available_colors.size(); ++color) {
      if (available_colors[color]) {
        break;
      }
    }
    CHECK_NE(color, kColorUnassigned);
    assigned_colors[node] = color;
  }
  return assigned_colors;
}
// An HLO buffer is read-only if it holds a constant, or if it holds an entry
// parameter that is not aliased with an output.
bool HloBufferIsReadOnly(const HloBuffer& buffer) {
  for (const HloValue* value : buffer.values()) {
    const HloInstruction* instruction = value->instruction();
    if (instruction->opcode() == HloOpcode::kConstant) {
      return true;
    }
    const HloModule* module = instruction->parent()->parent();
    const bool is_entry_parameter =
        instruction->opcode() == HloOpcode::kParameter &&
        instruction->parent() == module->entry_computation();

    if (is_entry_parameter) {
      bool parameter_has_alias =
          module->input_output_alias_config().ParameterHasAlias(
              instruction->parameter_number(), value->index());
      // If the parameter doesn't have an alias, it must be read-only.
      if (!parameter_has_alias) {
        return true;
      }
    }
  }
  return false;
}

}  // namespace

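// Partitions the computations reachable from the entry computation into those
// whose buffers must be thread-local and those whose buffers are global. For
// example (illustrative): if the entry computation contains a kWhile, the
// while body and condition are global (their buffers may be returned to the
// caller), while the apply computation of a kReduce inside that body is
// thread-local. A computation reached in both roles is an error.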
Status GatherComputationsByAllocationType(
    const HloModule* module,
    std::vector<const HloComputation*>* thread_local_computations,
    std::vector<const HloComputation*>* global_computations) {
  // Create a worklist of computations paired with whether the allocation must
  // be thread-local.
  std::deque<std::pair<const HloComputation*, bool>> worklist;
  worklist.push_back(std::make_pair(module->entry_computation(),
                                    /*is_thread_local*/ false));

  // Sets for quickly checking membership. Computations are returned in vectors
  // for stable iteration.
  flat_hash_set<const HloComputation*> thread_local_set;
  flat_hash_set<const HloComputation*> global_set;

  while (!worklist.empty()) {
    auto worklist_front = worklist.front();
    worklist.pop_front();
    const HloComputation* computation = worklist_front.first;
    bool is_thread_local = worklist_front.second;
    bool in_thread_local_set = thread_local_set.contains(computation);
    bool in_global_set = global_set.contains(computation);

    // If the computation has already been added to the respective set, then
    // nothing to do.
    if ((is_thread_local && in_thread_local_set) ||
        (!is_thread_local && in_global_set)) {
      continue;
    }

    // If the computation has already been added to the other set, this is an
    // error: the global call to the computation (e.g., while/call) may return
    // a reference to one of its thread-local buffers to the calling
    // computation, and that reference would dangle once the thread-local
    // buffer is deallocated when the call returns.
    if ((is_thread_local && in_global_set) ||
        (!is_thread_local && in_thread_local_set)) {
      return InvalidArgument(
          "computation %s has conflicting allocation requirements (global "
          "and thread-local)",
          computation->name());
    }

    if (is_thread_local) {
      thread_local_set.insert(computation);
    } else {
      global_set.insert(computation);
    }

    for (auto* instruction : computation->instructions()) {
      for (HloComputation* subcomputation :
           instruction->called_computations()) {
        switch (instruction->opcode()) {
          case HloOpcode::kCall:
          case HloOpcode::kConditional:
          case HloOpcode::kWhile:
            // Call and while must be called from a computation with global
            // allocations as they may return references to buffers inside the
            // called computation which cannot be thread-local.
            if (is_thread_local) {
              return InvalidArgument(
                  "computation %s cannot contain call/while op because it "
                  "requires thread-local buffer allocations",
                  computation->name());
            }
            worklist.push_back(std::make_pair(subcomputation,
                                              false));  // Not thread local.
            break;
          case HloOpcode::kCustomCall:
          case HloOpcode::kAllReduce:
          case HloOpcode::kReduceScatter:
          case HloOpcode::kAllReduceStart:
          case HloOpcode::kMap:
          case HloOpcode::kReduce:
          case HloOpcode::kReduceWindow:
          case HloOpcode::kScatter:
          case HloOpcode::kSelectAndScatter:
          case HloOpcode::kSort:
          case HloOpcode::kFusion:
            // Map/reduce etc computations are always thread-local.
            worklist.push_back(std::make_pair(subcomputation,
                                              true));  // Thread local.
            break;
          default:
            return InternalError("Unexpected calling opcode: %s",
                                 HloOpcodeString(instruction->opcode()));
        }
      }
    }
  }

  // Add the computations to the vectors in post order.
  for (auto* computation : module->MakeComputationPostOrder()) {
    if (thread_local_set.contains(computation)) {
      thread_local_computations->push_back(computation);
    } else if (global_set.contains(computation)) {
      global_computations->push_back(computation);
    }
    // If the computation is not reachable from the entry computation, then it
    // will not appear in either thread_local_set or global_set. We don't
    // bother assigning buffers for these.
  }
  return Status::OK();
}

string BufferAllocation::Slice::ToString() const {
  return absl::StrCat("{index:", index(), ", offset:", offset_,
                      ", size:", size_, "}");
}

BufferAllocation::Slice BufferAllocation::GetSlice(
    const HloValue& buffer) const {
  const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
  return Slice(this, os.offset, os.size);
}

void BufferAllocation::AddAssignment(const HloValue& buffer, int64_t offset,
                                     int64_t size) {
  VLOG(4) << "Adding the following buffer to allocation #" << index()
          << absl::StrFormat(" (size=%d, offset=%d) %s", size, offset,
                             buffer.ToShortString());
  CHECK(!assigned_buffers_.contains(&buffer))
      << "LogicalBuffer " << buffer << " already assigned to allocation "
      << index_;
  CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
                          << " offset out of range";
  CHECK_LE(offset + size, size_)
      << "LogicalBuffer " << buffer
      << " size out of range at offset: " << offset << " with size: " << size;
  CHECK_EQ(buffer.color(), color())
      << "Buffer color " << buffer.color() << " for buffer " << buffer
      << " does not match allocation color " << color() << ".";
  OffsetSize offset_size;
  offset_size.offset = offset;
  offset_size.size = size;
  assigned_buffers_.emplace(&buffer, offset_size);
  // For debugging purposes, store the assigned memory space in the
  // instruction's layout.
  for (HloPosition position : buffer.positions()) {
    Shape* shape = ShapeUtil::GetMutableSubshape(
        position.instruction->mutable_shape(), position.index);
    if (shape->has_layout()) {
      shape->mutable_layout()->set_memory_space(buffer.color());
    }
  }
}

BufferAllocationProto BufferAllocation::ToProto() const {
  BufferAllocationProto proto;
  proto.set_index(index_);
  proto.set_size(size_);
  proto.set_is_thread_local(is_thread_local_);
  proto.set_is_tuple(is_tuple_);
  proto.set_color(color_);
  if (is_entry_computation_parameter_) {
    proto.set_is_entry_computation_parameter(true);
    for (int64_t idx : param_shape_index()) {
      proto.add_parameter_shape_index(idx);
    }
    proto.set_parameter_number(parameter_number_);
  }
  proto.set_is_constant(is_constant_);
  proto.set_maybe_live_out(maybe_live_out_);
  for (const auto& buffer_offset_size : assigned_buffers_) {
    BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
    proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
    proto_assigned->set_offset(buffer_offset_size.second.offset);
    proto_assigned->set_size(buffer_offset_size.second.size);
  }
  absl::c_sort(*proto.mutable_assigned(),
               [](const BufferAllocationProto::Assigned& assign1,
                  const BufferAllocationProto::Assigned& assign2) {
                 return assign1.logical_buffer_id() <
                        assign2.logical_buffer_id();
               });
  return proto;
}

static bool CompareHloValuesById(const HloValue* a, const HloValue* b) {
  return a->id() < b->id();
}

// Returns parameter instruction corresponding to the allocation or nullptr.
static const HloInstruction* GetEntryParameterInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    const HloInstruction* instr = value->instruction();
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parent() == instr->parent()->parent()->entry_computation()) {
      return instr;
    }
  }
  return nullptr;
}

// Returns root module output instruction corresponding to the allocation or
// nullptr.
static const HloInstruction* GetOutputInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    for (const HloPosition& position : value->positions()) {
      const HloInstruction* instr = position.instruction;
      if (position.index.empty() &&
          instr->parent()->root_instruction() == instr &&
          instr->parent()->IsEntryComputation()) {
        return instr;
      }
    }
  }
  return nullptr;
}

string BufferAllocation::ToString() const {
  string output;
  StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
  if (color() != 0) {
    StrAppend(&output, ", color ", color());
  }
  if (is_entry_computation_parameter()) {
    const HloInstruction* param = GetEntryParameterInstruction(*this);
    StrAppend(&output, ", parameter ", parameter_number(), ", shape |",
              param ? param->shape().ToString(/*print_layout=*/false)
                    : "<unknown shape>",
              "| at ShapeIndex ", param_shape_index().ToString());
  }
  if (const HloInstruction* instr = GetOutputInstruction(*this)) {
    StrAppend(&output, ", output shape is |",
              instr->shape().ToString(/*print_layout=*/false), "|");
  }
  if (is_constant()) {
    StrAppend(&output, ", constant");
  }
  if (is_thread_local()) {
    StrAppend(&output, ", thread-local");
  }
  if (maybe_live_out()) {
    StrAppend(&output, ", maybe-live-out");
  }
  if (IsPreallocatedTempBuffer()) {
    StrAppend(&output, ", preallocated-temp");
  }
  StrAppend(&output, ":\n");
  // Dump the assigned buffers ordered by id.
  std::vector<const HloValue*> sorted_buffers;
  for (const auto& buffer_offset_size : assigned_buffers_) {
    sorted_buffers.push_back(buffer_offset_size.first);
  }
  absl::c_sort(sorted_buffers, &CompareHloValuesById);
  for (const HloValue* buffer : sorted_buffers) {
    const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
    StrAppend(&output,
              absl::StrFormat(
                  " value: %s (size=%d,offset=%d): %s\n",
                  buffer->ToShortString(), offset_size.size,
                  offset_size.offset,
                  ShapeUtil::HumanStringWithLayout(buffer->shape())));
  }
  return output;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
  out << buffer.ToString();
  return out;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
  out << s.ToString();
  return out;
}

bool BufferAssignment::HasAllocation(const HloValue& value) const {
  return allocation_index_for_value_.contains(&value);
}

bool BufferAssignment::HasAllocation(const HloBuffer& buffer) const {
  return allocation_index_for_value_.contains(buffer.values()[0]);
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloValue& value) const {
  CHECK(HasAllocation(value));
  return GetAllocation(allocation_index_for_value_.at(&value));
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloBuffer& hlo_buffer) const {
  return GetAssignedAllocation(*hlo_buffer.values()[0]);
}

BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
    const HloBuffer& buffer) {
  return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
}

std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  std::set<BufferAllocation::Slice> result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (HasAllocation(*value)) {
      result.insert(GetAssignedAllocation(*value).GetSlice(*value));
    }
  }
  return result;
}

const BufferAllocation& BufferAssignment::GetAllocation(
    BufferAllocation::Index index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, allocations_.size());
  return allocations_[index];
}

const BufferAllocation* BufferAssignment::GetInstructionAllocation(
    const HloInstruction* hlo, const ShapeIndex& shape_index) const {
  const HloValue* value =
      dataflow_analysis().GetValueSet(hlo, shape_index).values()[0];

  if (!HasAllocation(*value)) {
    return nullptr;
  }

  const BufferAllocation& instruction_allocation =
      GetAssignedAllocation(*value);
  return &instruction_allocation;
}

BufferAllocation* BufferAssignment::GetMutableAllocation(
    BufferAllocation::Index index) {
  return const_cast<BufferAllocation*>(&GetAllocation(index));
}

bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
                                       const ShapeIndex& index) const {
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (allocation_index_for_value_.contains(value)) {
      return true;
    }
  }
  return false;
}

bool BufferAssignment::HasTopLevelAllocation(
    const HloInstruction* instruction) const {
  return HasAllocationAt(instruction, /*index=*/{});
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
          << index << "]";
  BufferAllocation::Slice result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    VLOG(3) << "Examining value " << *value;
    if (HasAllocation(*value)) {
      VLOG(3) << "Has allocation";
      const BufferAllocation::Slice slice =
          GetAssignedAllocation(*value).GetSlice(*value);
      if (result.allocation() == nullptr) {
        result = slice;
      } else if (result != slice) {
        return FailedPrecondition(
            "BufferAllocation::Slice for instruction %s at index %s cannot "
            "be determined at compile-time.",
            instruction->name(), index.ToString());
      }
    } else {
      VLOG(3) << "No allocation";
    }
  }
  if (result.allocation() == nullptr) {
    return FailedPrecondition(
        "BufferAllocation::Slice not assigned for instruction %s at index %s",
        instruction->name(), index.ToString());
  }
  return result;
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
    const HloInstruction* instruction) const {
  return GetUniqueSlice(instruction, /*index=*/{});
}

bool BufferAssignment::SharesSliceAtIndex(
    const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
    const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
  return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
         GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
}

bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
                                          const HloInstruction* hlo_b) const {
  using SliceSet = flat_hash_set<BufferAllocation::Slice>;
  // Gets the slices of all of instr's subshapes. If any subshape doesn't have
  // an assigned slice, returns the empty set.
  auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
    SliceSet slices;
    Status status = ShapeUtil::ForEachSubshapeWithStatus(
        instr->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          auto shape_slices = GetAllSlices(instr, index);
          if (shape_slices.empty()) {
            return InvalidArgument("No slices assigned to part of instr.");
          }
          slices.insert(shape_slices.begin(), shape_slices.end());
          return Status::OK();
        });
    if (!status.ok()) {
      return {};
    }
    return slices;
  };

  SliceSet slices_a = collect_slices(hlo_a);
  SliceSet slices_b = collect_slices(hlo_b);
  // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
  // didn't return the empty set) for both HLOs, and the two resulting sets of
  // slices are disjoint.
  return !slices_a.empty() && !slices_b.empty() &&
         absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
           return slices_b.contains(slice);
         });
}

StatusOr<BufferAllocation::Slice>
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
  return GetUniqueTopLevelSlice(
      module_->entry_computation()->root_instruction());
}

BufferAllocation* BufferAssignment::NewEmptyAllocation(
    int64_t size, LogicalBuffer::Color color) {
  BufferAllocation::Index index = allocations_.size();
  allocations_.emplace_back(index, size, color);
  BufferAllocation* allocation = &allocations_.back();
  return allocation;
}

BufferAllocation* BufferAssignment::NewAllocation(const HloBuffer& buffer,
                                                  int64_t size) {
  BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
  AddAssignment(allocation, buffer, /*offset=*/0, size);
  allocation->peak_buffers_.push_back(buffer.values()[0]);
  return allocation;
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloBuffer& buffer, int64_t offset,
                                     int64_t size) {
  CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
      << "Non-reusable allocation already assigned a buffer: "
      << allocation->ToString();

  for (const HloValue* buffer_value : buffer.values()) {
    CHECK(!allocation_index_for_value_.contains(buffer_value))
        << "BufferValue " << buffer_value << " already has an allocation.";
    allocation->AddAssignment(*buffer_value, offset, size);
    allocation_index_for_value_[buffer_value] = allocation->index();
  }

  if (alias_analysis().BufferLivesOut(buffer)) {
    VLOG(3) << "HloBuffer lives out: " << buffer.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloValue& value, int64_t offset,
                                     int64_t size) {
  allocation->AddAssignment(value, offset, size);
  allocation_index_for_value_[&value] = allocation->index();
  const HloValue& hlo_value =
      *CHECK_NOTNULL(dynamic_cast<const HloValue*>(&value));
  if (alias_analysis().ValueLivesOut(hlo_value)) {
    VLOG(3) << "HloValue lives out: " << hlo_value.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

// Combines allocations of temporary buffers of the same color into one big
// BufferAllocation.
void BufferAssignment::CombineTempAllocations() {
  VLOG(1) << "CombineTempAllocations()";
  // Stores the combined allocations.
  std::deque<BufferAllocation> combined_allocations;
  // Holds the pointer to a combined allocation of each color, if any.
  flat_hash_map<BufferValue::Color, BufferAllocation*> combined_allocation_map;

  // Move all temp allocations into a single run at the end of the allocations
  // vector.
  const auto first_temp_it =
      std::partition(allocations_.begin(), allocations_.end(),
                     [](const BufferAllocation& allocation) {
                       return !allocation.IsPreallocatedTempBuffer();
                     });

  // Walk over the run of temp allocations, collecting the allocations
  // belonging to the same color.
  if (first_temp_it != allocations_.end()) {
    for (auto it = first_temp_it; it != allocations_.end(); ++it) {
      BufferAllocation& temp_allocation = *it;
      BufferValue::Color color = temp_allocation.color();
      auto combined_it = combined_allocation_map.find(color);
      if (combined_it == combined_allocation_map.end()) {
        // We have found the first temp allocation of this color. Collect
        // the other temp allocations of the same color into it subject to the
        // size constraint.
        VLOG(1) << "Combined temp allocation for color " << color
                << " is: " << temp_allocation;
        combined_allocations.emplace_back(temp_allocation);
        combined_allocation_map.emplace(color, &combined_allocations.back());
        continue;
      }
      if (combined_it->second->size() + it->size() >=
          multiheap_size_constraint_per_heap_) {
        // We cannot put more into the current combined_it. So, appoint a new
        // combined_it.
        VLOG(1) << "Due to size constraint, reset temp allocation for color "
                << color << " to: " << temp_allocation;
        combined_allocations.emplace_back(temp_allocation);
        combined_allocation_map.emplace(color, &combined_allocations.back());
        continue;
      }

      BufferAllocation* combined_allocation = combined_it->second;
      VLOG(1) << "Combined allocation absorbing temp allocation: "
              << temp_allocation;

      // Each temp allocation is placed end-to-end, accounting for alignment.
      // The offset of each buffer in the combined allocation is computed from
      // the base offset of the allocation.
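      // For example (illustrative numbers): if the combined allocation is
      // currently 3000 bytes and the color's alignment is 256 bytes, the
      // absorbed allocation starts at base = RoundUpToNearest(3000, 256) =
      // 3072, so a buffer at offset 40 within it lands at combined offset
      // 3072 + 40 = 3112.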
      int64_t alignment = color_alignment_(color);
      const int64_t base =
          RoundUpToNearest(combined_allocation->size(), alignment);
      combined_allocation->set_size(base + temp_allocation.size());
      for (const auto& buffer_offset_size :
           temp_allocation.assigned_buffers_) {
        const HloValue* value = buffer_offset_size.first;
        const int64_t offset = buffer_offset_size.second.offset;
        const int64_t size = buffer_offset_size.second.size;
        combined_allocation->AddAssignment(*value, base + offset, size);
      }
      if (!temp_allocation.HeapTraces().empty()) {
        CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
        combined_allocation->AddHeapTrace(
            temp_allocation.HeapTraces().front());
      }

      combined_allocation->peak_buffers_.insert(
          combined_allocation->peak_buffers_.end(),
          temp_allocation.peak_buffers_.begin(),
          temp_allocation.peak_buffers_.end());
    }
    // Replace all existing temporary allocations with the new combined
    // allocations.
    allocations_.erase(first_temp_it, allocations_.end());
    for (BufferAllocation& combined : combined_allocations) {
      temp_allocation_total_size_ += combined.size();
      allocations_.push_back(std::move(combined));
    }
  }

  // Update allocation indices to their new positions.
  allocation_index_for_value_.erase(allocation_index_for_value_.begin(),
                                    allocation_index_for_value_.end());
  for (size_t index = 0; index < allocations_.size(); ++index) {
    BufferAllocation* allocation = &allocations_[index];
    allocation->set_index(index);
    for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
      const HloValue* value = buffer_offset_size.first;
      allocation_index_for_value_[value] = index;
    }
  }
}

Status BufferAssignment::ComputeSummaryStats() {
  for (auto& allocation : Allocations()) {
    if (allocation.is_entry_computation_parameter()) {
      stats_.parameter_allocation_count++;
      stats_.parameter_allocation_bytes += allocation.size();
    }
    if (allocation.is_constant()) {
      stats_.constant_allocation_count++;
      stats_.constant_allocation_bytes += allocation.size();
    }
    if (allocation.maybe_live_out()) {
      stats_.maybe_live_out_allocation_count++;
      stats_.maybe_live_out_allocation_bytes += allocation.size();
    }
    if (allocation.IsPreallocatedTempBuffer()) {
      stats_.preallocated_temp_allocation_count++;
      stats_.preallocated_temp_allocation_bytes += allocation.size();
    }
    stats_.total_allocation_count++;
    stats_.total_allocation_bytes += allocation.size();
  }

  // Only compute total fragmentation if all computations have schedules.
  HloSchedule schedule(module_);
  bool schedule_complete = true;
  for (const auto& computation : module_->computations()) {
    if (!computation->IsFusionComputation()) {
      const HloInstructionSequence* sequence =
          hlo_ordering().SequentialOrder(*computation);
      if (sequence == nullptr) {
        schedule_complete = false;
      } else {
        schedule.set_sequence(computation, *sequence);
      }
    }
  }
  if (schedule_complete) {
    TF_RETURN_IF_ERROR(schedule.Verify());
    TF_ASSIGN_OR_RETURN(
        const int64_t min_size,
        HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
    stats_.total_fragmentation_bytes =
        stats_.total_allocation_bytes - min_size;
  }

  return Status::OK();
}

string BufferAssignment::Stats::ToString() const {
  string s;
  StrAppendFormat(&s, "BufferAssignment stats:\n");
  StrAppendFormat(&s, " parameter allocation: %10s\n",
                  HumanReadableNumBytes(parameter_allocation_bytes));
  StrAppendFormat(&s, " constant allocation: %10s\n",
                  HumanReadableNumBytes(constant_allocation_bytes));
  StrAppendFormat(&s, " maybe_live_out allocation: %10s\n",
                  HumanReadableNumBytes(maybe_live_out_allocation_bytes));
  StrAppendFormat(&s, " preallocated temp allocation: %10s\n",
                  HumanReadableNumBytes(preallocated_temp_allocation_bytes));
  if (preallocated_temp_fragmentation_bytes >= 0) {
    const double percent = 100. * preallocated_temp_fragmentation_bytes /
                           preallocated_temp_allocation_bytes;
    StrAppendFormat(
        &s, " preallocated temp fragmentation: %10s (%.2f%%)\n",
        HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
  }
  StrAppendFormat(&s, " total allocation: %10s\n",
                  HumanReadableNumBytes(total_allocation_bytes));
  if (total_fragmentation_bytes >= 0) {
    const double percent =
        100. * total_fragmentation_bytes / total_allocation_bytes;
    StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n",
                    HumanReadableNumBytes(total_fragmentation_bytes), percent);
  }
  return s;
}

string BufferAssignment::ToString() const {
  string output;
  absl::StrAppend(&output, "BufferAssignment:\n");
  std::vector<const HloValue*> used_values;
  int64_t total_size = 0;
  for (auto& allocation : allocations_) {
    total_size += allocation.size();
    absl::StrAppend(&output, allocation.ToString());
    for (const auto& p : allocation.assigned_buffers()) {
      used_values.push_back(p.first);
    }
  }
  absl::StrAppend(&output, "\nTotal bytes used: ", total_size, " (",
                  HumanReadableNumBytes(total_size), ")\n");
  absl::StrAppend(&output, "\nUsed values:\n");
  absl::c_sort(used_values, &CompareHloValuesById);
  for (const HloValue* value : used_values) {
    absl::StrAppend(&output, value->ToString());
  }
  return output;
}

string BufferAssignment::BufferInfoString() const {
  string binfo;
  // Columns in buffer information:
  // buffer_id: int. This value can be used to match the allocation in
  // allocation information.
  // buffer_name: string.
  // offset: int. Starting position of the buffer in the memory space.
  // size: int. Size of the buffer in bytes.
  // definition_time: int. Position in the schedule where the buffer starts
  // being live (inclusive).
  // end_time: int. Position in the schedule where the buffer stops being live
  // (exclusive).
  // num_uses: int. Number of uses of the buffer.
  // use_times: string. Semicolon-separated list of the schedule positions at
  // which the buffer is used.
  // use_names: string. Semicolon-separated list of string representations of
  // the uses.
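  //
  // An illustrative row (hypothetical values):
  //   12,"<buffer name>",1024,256,7,9,1,"8","<use name>"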
  // Append the column names.
  absl::StrAppend(&binfo,
                  "buffer_id,buffer_name,offset,size,"
                  "definition_time,end_time,num_uses,use_times,use_names\n");
  const HloLiveRange& live_ranges = hlo_live_range();
  const auto& instruction_schedule = live_ranges.instruction_schedule();
  const auto& buffer_live_ranges = live_ranges.buffer_live_ranges();
  // Sort the buffers by Id.
  std::vector<std::pair<const HloValue*, BufferAllocation::OffsetSize>> buffers;
  for (const BufferAllocation& allocation : allocations_) {
    absl::c_copy(allocation.assigned_buffers(), std::back_inserter(buffers));
  }
  absl::c_sort(
      buffers,
      [](const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b1,
         const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b2) {
        return b1.first->id() < b2.first->id();
      });
  for (const auto& buffer_pair : buffers) {
    const HloValue& buffer = *buffer_pair.first;
    const BufferAllocation::OffsetSize& offset_size = buffer_pair.second;
    if (!buffer_live_ranges.contains(&buffer)) {
      continue;
    }
    // Ordering uses by their use position.
    std::vector<std::pair<int64, std::string>> uses;
    uses.reserve(buffer.uses().size());
    for (const HloUse& use : buffer.uses()) {
      uses.emplace_back(instruction_schedule.at(use.instruction),
                        use.ToString());
    }
    absl::c_sort(uses);
    std::vector<int64> use_positions;
    std::vector<std::string> use_names;
    use_positions.reserve(uses.size());
    use_names.reserve(uses.size());
    for (const auto& use : uses) {
      use_positions.push_back(use.first);
      use_names.push_back(use.second);
    }
    const int64_t definition_time =
        instruction_schedule.at(buffer.defining_position().instruction);
    const int64_t end_t = buffer_live_ranges.at(&buffer).end;
    absl::StrAppend(&binfo, buffer.id(), ",");
    absl::StrAppend(&binfo, "\"", buffer.ToShortString(), "\",");
    absl::StrAppend(&binfo, offset_size.offset, ",");
    absl::StrAppend(&binfo, offset_size.size, ",");
    absl::StrAppend(&binfo, definition_time, ",");
    absl::StrAppend(&binfo, end_t, ",");
    absl::StrAppend(&binfo, use_positions.size(), ",");
    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_positions, ";"), "\",");
    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_names, ";"), "\"");
    absl::StrAppend(&binfo, "\n");
  }
  return binfo;
}

BufferAssignmentProto BufferAssignment::ToProto() const {
  BufferAssignmentProto proto;
  // NOTE: DataflowAnalysis state is serialized here in BufferAssignment,
  // because we need to do the HasAllocation check for each buffer. Otherwise
  // the buffer_size_ call might fail for some backends.
  const HloDataflowAnalysis& dataflow = this->dataflow_analysis();
  for (BufferValue::Id id = 0; id < dataflow.values().size(); id++) {
    auto& value = dataflow.values().at(id);
    if (HasAllocation(*value)) {
      LogicalBufferProto proto_buffer = value->ToProto(buffer_size_);
      proto.add_logical_buffers()->Swap(&proto_buffer);

      // Fill buffer aliases.
      for (const HloValue* alias :
           alias_analysis().GetBufferContainingValue(*value).values()) {
        if (alias->instruction() == value->instruction() &&
            alias->index() == value->index()) {
          continue;  // skip self-aliases
        }
        BufferAssignmentProto::BufferAlias* proto_alias =
            proto.add_buffer_aliases();
        LogicalBufferProto::Location proto_alias_location =
            BufferValue::ToLocationProto(*alias->instruction(),
                                         alias->index());
        proto_alias->set_source_buffer_id(value->id());
        proto_alias->mutable_location()->Swap(&proto_alias_location);
      }
    }
  }
  for (const BufferAllocation& allocation : Allocations()) {
    BufferAllocationProto proto_allocation = allocation.ToProto();
    proto.add_buffer_allocations()->Swap(&proto_allocation);
    for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
      *proto.add_heap_simulator_traces() = heap_trace;
    }
  }
  return proto;
}

/* static */
StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer,
    const absl::flat_hash_set<HloOpcode>& must_not_live_out,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer,
    std::unique_ptr<PresetAssignments> preset_assignments) {
  BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
                          must_not_live_out, std::move(preset_assignments));
  return assigner.CreateAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(can_share_buffer));
}

bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1,
                                         const HloValue* buffer2,
                                         BufferAssignment* assignment) {
  CHECK((assignment->hlo_live_range().total_order_scheduled()));
  const HloLiveRange& hlo_live_range = assignment->hlo_live_range();

  const auto& buffer_live_ranges = hlo_live_range.buffer_live_ranges();

  CHECK(buffer_live_ranges.contains(buffer1))
      << "Buffer doesn't have a proper live range: " << buffer1;

  CHECK(buffer_live_ranges.contains(buffer2))
      << "Buffer doesn't have a proper live range: " << buffer2;

  // Check if a user value can share the same buffer as its operand.
  auto can_share_as_operand = [&assignment, &buffer_live_ranges](
                                  const HloValue* user_value,
                                  const HloValue* operand_value) {
    // An HLO value can span multiple instructions during its lifetime. We
    // only look at the instruction at the end of the operand's live range
    // and check whether the user can share that buffer.
    HloPosition operand_end_position =
        buffer_live_ranges.at(operand_value).end_position;
    return user_value->instruction()->IsUserOf(
               operand_end_position.instruction) &&
           assignment->dataflow_analysis().CanShareOperandBufferWithUser(
               operand_end_position.instruction, operand_end_position.index,
               user_value->instruction(), user_value->index()) &&
           user_value->instruction()->opcode() != HloOpcode::kCopy;
  };

  auto live_range_1 = buffer_live_ranges.at(buffer1);
  auto live_range_2 = buffer_live_ranges.at(buffer2);

  if (!(live_range_1.start > live_range_2.end ||
        live_range_2.start > live_range_1.end)) {
    if (live_range_1.end == live_range_2.start) {
      auto operand_value = buffer1;
      auto user_value = buffer2;
      if (!can_share_as_operand(user_value, operand_value)) {
        VLOG(4) << "End of live range of " << buffer1->ToShortString()
                << " is equal to the start of live range of "
                << buffer2->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else if (live_range_2.end == live_range_1.start) {
      auto operand_value = buffer2;
      auto user_value = buffer1;
      if (!can_share_as_operand(user_value, operand_value)) {
        VLOG(4) << "End of live range of " << buffer2->ToShortString()
                << " is equal to the start of live range of "
                << buffer1->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else {
      VLOG(4) << "Can't assign: assignee " << *buffer1
              << " may interfere with " << *buffer2;
      VLOG(4) << "assigned_buffer.start: " << live_range_1.start;
      VLOG(4) << "assigned_buffer.end: " << live_range_1.end;
      VLOG(4) << "live_range_2.start: " << live_range_2.start;
      VLOG(4) << "live_range_2.end: " << live_range_2.end;
      return true;
    }
  }
  return false;
}

bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
                                       const HloBuffer& hlo_buffer,
                                       BufferAssignment* assignment) {
  CHECK(!assignment->HasAllocation(hlo_buffer))
      << "buffer " << hlo_buffer << " already has an allocation assigned.";

  VLOG(4) << "Trying to assign " << hlo_buffer << " size "
          << assignment->HloBufferSize(hlo_buffer)
          << " to allocation: " << *allocation;

  if (hlo_buffer.color() != allocation->color()) {
    VLOG(4) << "Can't assign: buffer has color " << hlo_buffer.color()
            << " and allocation has color " << allocation->color() << ".";
    return false;
  }

  if (assignment->HloBufferSize(hlo_buffer) > allocation->size()) {
    VLOG(4) << "Can't assign: buffer is larger than allocation ("
            << assignment->HloBufferSize(hlo_buffer) << " > "
            << allocation->size() << ")";
    return false;
  }

  if (allocation->is_readonly()) {
    VLOG(4) << "Can't assign: allocation is readonly";
    return false;
  }

  if (!must_not_live_out_.empty()) {
    if (allocation->maybe_live_out()) {
      // If a buffer may live out, the allocation cannot contain any node from
      // the "must_not_live_out_" set.
      for (const HloValue* value : hlo_buffer.values()) {
        if (must_not_live_out_.count(value->instruction()->opcode()) > 0) {
          VLOG(4) << "Can't assign: " << value->instruction()->ToString()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
    // The above check is not enough: an allocation could currently be not
    // live out yet already contain an instruction with an opcode from the
    // "must_not_live_out_" set; assigning a live-out buffer to it would make
    // the allocation live out while still containing such an instruction.
    if (assignment->alias_analysis().BufferLivesOut(hlo_buffer)) {
      for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
        if (must_not_live_out_.count(
                buffer_offset_size.first->instruction()->opcode()) > 0) {
          VLOG(4) << "Can't assign: "
                  << buffer_offset_size.first->instruction()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
  }

  if (!allocation->is_reusable()) {
    VLOG(4) << "Can't assign: allocation is not reusable";
    return false;
  }

  for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
    // Pairwise compare.
    const HloValue& assigned_buffer = *CHECK_NOTNULL(
        dynamic_cast<const HloValue*>(buffer_offset_size.first));
    for (const HloValue* new_value : hlo_buffer.values()) {
      if (assignment->hlo_live_range().total_order_scheduled()) {
        if (LiveRangeInterferes(new_value, &assigned_buffer, assignment)) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " live range interferes with "
                  << new_value->ToShortString();
          return false;
        }
      } else if (assignment->hlo_ordering().MayInterfere(
                     assigned_buffer, *new_value,
                     assignment->dataflow_analysis())) {
        // Fallback to partial order based interference detection (slower) when
        // we don't have a total order scheduled module.
        VLOG(4) << "Can't assign: assignee " << assigned_buffer
                << " may interfere with " << new_value->ToShortString();
        return false;
      }

      for (const HloPosition& assigned_buffer_position :
           assigned_buffer.positions()) {
        // Copy instructions don't share a buffer with their input operand.
        if (new_value->instruction()->IsUserOf(
                assigned_buffer_position.instruction) &&
            new_value->instruction()->opcode() == HloOpcode::kCopy) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " is used at copy instruction "
                  << new_value->ToShortString();
          return false;
        }
      }
    }
  }

  // If the buffer is live out of the computation then it should only be
  // assigned a buffer which exactly fits the result to avoid wasting memory
  // (result buffers can have arbitrary lifetimes).
  if (assignment->alias_analysis().BufferLivesOut(hlo_buffer) &&
      allocation->size() != assignment->HloBufferSize(hlo_buffer)) {
    VLOG(4) << "Can't assign: buffer " << hlo_buffer
            << " is live out and size not the same as allocation";
    return false;
  }

  assignment->AddAssignment(allocation, hlo_buffer, /*offset=*/0,
                            assignment->HloBufferSize(hlo_buffer));
  return true;
}
1110
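// Assigns a single HloBuffer to an allocation. Special cases (constants,
// entry parameters, thread-local buffers, tuple-shaped buffers) get their own
// allocations; otherwise we try to reuse an operand's allocation, then the
// smallest reusable existing allocation. Temp buffers in sequentially ordered
// computations are deferred to buffers_to_assign_sequentially for heap
// simulation, and only as a last resort is a fresh allocation created.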
Status BufferAssigner::AssignSingleHloBuffer(
    const HloBuffer* hlo_buffer, bool is_thread_local,
    absl::flat_hash_map<const HloComputation*,
                        absl::flat_hash_set<const HloValue*>>*
        buffers_to_assign_sequentially,
    std::vector<BufferAllocation::Index>* allocation_indices,
    BufferAssignment* assignment) {
  const int64_t buffer_size = assignment->HloBufferSize(*hlo_buffer);
  for (const HloValue* value : hlo_buffer->values()) {
    if (value->instruction()->opcode() == HloOpcode::kConstant) {
      if (allocate_buffers_for_constants_) {
        BufferAllocation* allocation =
            assignment->NewAllocation(*hlo_buffer, buffer_size);
        allocation->set_constant(true);
        VLOG(3) << "New allocation #" << allocation->index()
                << " for constant " << *hlo_buffer << " value ptr: " << value;
      }
      VLOG(3) << "Not allocating buffer for constant";
      return Status::OK();
    }

    const HloInstruction* instruction = value->instruction();
    const bool is_entry_parameter =
        instruction->opcode() == HloOpcode::kParameter &&
        instruction->parent() ==
            instruction->parent()->parent()->entry_computation();

    if (is_entry_parameter) {
      bool parameter_has_alias =
          assignment->module().input_output_alias_config().ParameterHasAlias(
              instruction->parameter_number(), value->index());
      // If the HLO buffer is part of an external parameter, create a new
      // allocation and set its parameter number. Parameters of non-entry
      // computations do not need special allocations because they live inside
      // callers.
      BufferAllocation* allocation =
          assignment->NewAllocation(*hlo_buffer, buffer_size);

      allocation->set_entry_computation_parameter(
          instruction->parameter_number(), value->index(),
          parameter_has_alias);
      if (parameter_has_alias) {
        allocation_indices->push_back(allocation->index());
      }
      VLOG(3) << "New allocation #" << allocation->index()
              << " marked as entry computation parameter: " << *hlo_buffer;
      return Status::OK();
    }
  }

  if (is_thread_local) {
    BufferAllocation* allocation =
        assignment->NewAllocation(*hlo_buffer, buffer_size);
    allocation->set_is_thread_local(true);
    VLOG(3) << "New allocation #" << allocation->index()
            << " for thread-local: " << *hlo_buffer;
    return Status::OK();
  }

  for (const HloValue* value : hlo_buffer->values()) {
    if (value->shape().IsTuple()) {
      BufferAllocation* allocation =
          assignment->NewAllocation(*hlo_buffer, buffer_size);
      allocation->set_is_tuple(true);
      VLOG(3) << "New allocation #" << allocation->index()
              << " for tuple-shaped buffer: " << *hlo_buffer;
      return Status::OK();
    }

    if (value->IsTopLevel() && !value->IsTuple()) {
      const HloInstruction* instruction = value->instruction();
      for (auto* operand : instruction->operands()) {
        for (const auto& operand_slice :
             assignment->GetAllSlices(operand, /*index=*/{})) {
          BufferAllocation* allocation =
              assignment->GetMutableAllocation(operand_slice.index());
          if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
            VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
                    << " for: " << *hlo_buffer;
            return Status::OK();
          }
        }
      }
    }
  }

  // Find the smallest reusable buffer by iterating from the end of
  // allocation_indices (smallest allocations, since buffers are assigned in
  // order of decreasing size) to the beginning (largest).
1198 for (int allocation_index = allocation_indices->size() - 1;
1199 allocation_index >= 0; allocation_index--) {
1200 BufferAllocation* allocation = assignment->GetMutableAllocation(
1201 allocation_indices->at(allocation_index));
1202 if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
1203 VLOG(3) << "Reusing allocation #" << allocation->index()
1204 << " for: " << *hlo_buffer;
1205 return Status::OK();
1206 }
1207 }
1208
1209 if (!assignment->HasAllocation(*hlo_buffer) &&
1210 !assignment->alias_analysis().BufferLivesOut(*hlo_buffer)) {
1211 bool all_computations_have_sequential_order = true;
1212 for (const HloValue* hlo_value : hlo_buffer->values()) {
1213 HloComputation* computation = hlo_value->instruction()->parent();
1214 const bool has_sequential_order =
1215 assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
1216 all_computations_have_sequential_order &= has_sequential_order;
1217 }
1218
1219 if (all_computations_have_sequential_order) {
1220 for (const HloValue* hlo_value : hlo_buffer->values()) {
1221 HloComputation* computation = hlo_value->instruction()->parent();
1222 // There is a sequential instruction ordering, so we delay assignment
1223 // of temp buffers until after the loop. We do this right before we
1224 // decide to create a new allocation, to ensure we've exhausted all
1225 // the buffer re-use cases above.
1226 //
1227 // Entry parameters and thread local buffers were already handled
1228 // earlier in this loop iteration. See
1229 // BufferAllocation::IsPreallocatedTempBuffer for the definition of
1230 // temp buffers.
1231 (*buffers_to_assign_sequentially)[computation].insert(hlo_value);
1232 VLOG(3) << "Delaying assignment of temp buffer: " << *hlo_value;
1233 }
1234 return Status::OK();
1235 }
1236 }
1237
1238 if (!assignment->HasAllocation(*hlo_buffer)) {
1239 BufferAllocation* allocation =
1240 assignment->NewAllocation(*hlo_buffer, buffer_size);
1241 allocation_indices->push_back(allocation->index());
1242 VLOG(3) << "New allocation #" << allocation->index()
1243 << " for: " << *hlo_buffer;
1244 }
1245
1246 TF_RET_CHECK(assignment->HasAllocation(*hlo_buffer));
1247 return Status::OK();
1248 }
1249
AssignBuffersForComputations(const std::vector<const HloComputation * > & computations,bool is_thread_local,absl::flat_hash_map<const HloComputation *,absl::flat_hash_set<const HloValue * >> * buffers_to_assign_sequentially,BufferAssignment * assignment)1250 Status BufferAssigner::AssignBuffersForComputations(
1251 const std::vector<const HloComputation*>& computations,
1252 bool is_thread_local,
1253 absl::flat_hash_map<const HloComputation*,
1254 absl::flat_hash_set<const HloValue*>>*
1255 buffers_to_assign_sequentially,
1256 BufferAssignment* assignment) {
1257 if (computations.empty()) {
1258 return Status::OK();
1259 }
1260 std::vector<const HloBuffer*> sorted_buffers;
1261
1262 // First assign the preset allocations.
1263 absl::flat_hash_set<const HloBuffer*> preset_assigned_buffers;
1264
1265 TF_RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment));
1266
1267 const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
1268
1269 for (const HloBuffer& buffer : alias_analysis.buffers()) {
1270 // Skip if the buffer is already assigned since it had a preset allocation.
1271 if (preset_assigned_buffers.find(&buffer) !=
1272 preset_assigned_buffers.end()) {
1273 VLOG(3) << "Skip allocation for buffer: " << buffer;
1274 continue;
1275 }
1276 TF_RET_CHECK(!buffer.values().empty());
1277 const HloComputation* comp = buffer.values()[0]->instruction()->parent();
1278 if (absl::c_linear_search(computations, comp)) {
1279 sorted_buffers.push_back(&buffer);
1280 }
1281 }
1282
1283 // Generate a post order sort of instructions for sorting of the
1284 // HloBuffers.
1285 flat_hash_map<const HloInstruction*, int> post_order_position;
1286 int position = 0;
1287 std::vector<const HloComputation*> reverse_post_order_computations;
1288 std::unique_ptr<CallGraph> call_graph =
1289 CallGraph::Build(computations[0]->parent());
1290 TF_RETURN_IF_ERROR(call_graph->VisitNodes([&](const CallGraphNode& node) {
1291 if (absl::c_linear_search(computations, node.computation())) {
1292 reverse_post_order_computations.push_back(node.computation());
1293 }
1294 return Status::OK();
1295 }));
1296 absl::c_reverse(reverse_post_order_computations);
1297 for (auto* computation : reverse_post_order_computations) {
1298 for (auto* instruction : computation->MakeInstructionPostOrder()) {
1299 post_order_position.emplace(instruction, position);
1300 position++;
1301 }
1302 }

  HloSchedule schedule(&assignment->module());

  for (const HloComputation* computation : computations) {
    const HloInstructionSequence* instruction_sequence =
        assignment->hlo_ordering().SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
      // Every sequential computation must get an entry in the
      // buffers_to_assign_sequentially map, even if we end up with an empty
      // set of buffers. This ensures we can correctly determine whether to
      // run whole-module heap simulation.
      buffers_to_assign_sequentially->emplace(computation,
                                              flat_hash_set<const HloValue*>());

      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

  absl::c_sort(
      sorted_buffers, [&post_order_position, &alias_analysis, assignment](
                          const HloBuffer* a, const HloBuffer* b) {
        // Primary sort is by decreasing buffer size.
        const int64_t a_size = assignment->HloBufferSize(*a);
        const int64_t b_size = assignment->HloBufferSize(*b);
        if (a_size != b_size) {
          return a_size > b_size;  // use ">" for decreasing size.
        }

        const bool a_live_out = alias_analysis.BufferLivesOut(*a);
        const bool b_live_out = alias_analysis.BufferLivesOut(*b);
        if (a_live_out != b_live_out) {
          return a_live_out;
        }
        auto compare = [&post_order_position](const HloValue* value1,
                                              const HloValue* value2) {
          return post_order_position.at(value1->instruction()) <
                 post_order_position.at(value2->instruction());
        };
        const HloValue* a_min = *absl::c_min_element(a->values(), compare);
        const HloValue* b_min = *absl::c_min_element(b->values(), compare);
        return compare(a_min, b_min);
      });
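  // For example, buffers X (1024 bytes), Y (1024 bytes, live out) and
  // Z (2048 bytes) sort as Z, Y, X: size first, then live-out status, then
  // the earliest defining instruction in the post-order numbering.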

  std::vector<BufferAllocation::Index> allocation_indices;

  for (const HloBuffer* buffer : sorted_buffers) {
    VLOG(3) << "=================================================";
    VLOG(3) << "Assigning buffer for " << *buffer;
    TF_RETURN_IF_ERROR(AssignSingleHloBuffer(buffer, is_thread_local,
                                             buffers_to_assign_sequentially,
                                             &allocation_indices, assignment));
  }
  return Status::OK();
}

flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>>
BufferAssigner::SplitBuffersByColor(
    const flat_hash_set<const HloValue*>& buffers) {
  flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>> color_map;
  for (auto buffer : buffers) {
    color_map[buffer->color()].insert(buffer);
  }
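  // e.g. values {v0: red, v1: blue, v2: red} map to
  // {red: {v0, v2}, blue: {v1}}.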
  return color_map;
}

Status BufferAssigner::AssignPresetBuffers(
    absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
    BufferAssignment* assignment) {
  if (!preset_assignments_) {
    return Status::OK();
  }

  // Create an allocation for each preset color.
  absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*>
      preset_allocations;
  for (auto& color_and_info : preset_assignments_->assignment_informations()) {
    LogicalBuffer::Color color(color_and_info.first);
    auto inserted = preset_allocations.emplace(
        color,
        assignment->NewEmptyAllocation(color_and_info.second.size, color));
    BufferAllocation* inserted_allocation = inserted.first->second;
    inserted_allocation->AddHeapTrace(
        color_and_info.second.heap_simulator_trace);
    VLOG(3) << "Created preset buffer allocation "
            << inserted_allocation->index()
            << ", color: " << inserted_allocation->color()
            << ", size: " << inserted_allocation->size();
  }

  const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();

  for (auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& defining_position = position_and_chunk.first;
    const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt(
        defining_position.instruction, defining_position.index);
    for (const HloValue* value : buffer.values()) {
      VLOG(3) << "Preset allocation for value: " << value->ToShortString();
      const HeapSimulator::Chunk& chunk = position_and_chunk.second;
      auto preset_allocations_iter = preset_allocations.find(value->color());
      CHECK(preset_allocations_iter != preset_allocations.end())
          << "No preset value allocation for color " << value->color()
          << " for " << value->ToShortString() << " found.";
      preset_allocations_iter->second->AddAssignment(*value, chunk.offset,
                                                     chunk.size);
    }

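    // Every aliased value now shares this chunk; record the buffer so the
    // general assignment pass above skips it.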
    assigned_buffers->insert(&buffer);
  }

  // The preset assignments have now been consumed; clear them so that a
  // second call to this method does not assign the same buffers again.
  preset_assignments_ = {};

  return Status::OK();
}

Status BufferAssigner::AssignBuffersWithSequentialOrdering(
    const flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>&
        buffers_to_assign_sequentially,
    bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
  // Run the sequence of instructions through the heap simulator. The
  // heuristic that seems to give the best results is lazy-best-fit, with all
  // runs of alloc / free calls sorted in decreasing size order.
  const HloOrdering& hlo_ordering = assignment->hlo_ordering();

  // Returns a heap algorithm that chooses the best result from several
  // algorithms.
  auto get_heap_algorithm = [&](int64_t alignment) {
    auto algorithms = absl::make_unique<
        std::vector<std::unique_ptr<HeapAlgorithm<HloValue>>>>();
    algorithms->push_back(
        absl::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
            assignment->multiheap_size_constraint_per_heap(), alignment,
            GlobalDecreasingSizeBestFitHeap<HloValue>::kSpatial));
    algorithms->push_back(
        absl::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
            assignment->multiheap_size_constraint_per_heap(), alignment,
            GlobalDecreasingSizeBestFitHeap<HloValue>::kTemporal));
    return absl::make_unique<ChooseBestHeapAlgorithm<HloValue>>(
        std::move(algorithms));
  };
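  // ChooseBestHeapAlgorithm runs each candidate (here, the spatial and
  // temporal best-fit variants) on the same trace and keeps the packing with
  // the smallest resulting heap size, trading compile time for memory.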

  if (run_whole_module_heap_simulation) {
    // Run the heap simulation over the whole module. This reduces memory
    // usage, since buffers for kCall, kWhile, and kConditional
    // sub-computations are only live for the duration of their calling
    // instructions.
    VLOG(1) << "Running whole-module heap simulation";
    HloSchedule schedule(&assignment->module());
    flat_hash_set<const HloValue*> all_buffers_to_assign;
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      schedule.set_sequence(computation, *instruction_sequence);
      all_buffers_to_assign.insert(buffers_to_assign.begin(),
                                   buffers_to_assign.end());
    }
    auto color_map = SplitBuffersByColor(all_buffers_to_assign);
    for (auto& single_colored_set : color_map) {
      auto color = single_colored_set.first;
      VLOG(2) << "Simulating heap for color " << color;
      int64_t alignment = assignment->color_alignment_(color);
      HeapSimulator::Options options;
      options.alloc_constants = allocate_buffers_for_constants_;
      options.buffers_to_assign = &single_colored_set.second;

      TF_ASSIGN_OR_RETURN(
          HeapSimulator::Result<HloValue> result,
          HeapSimulator::Run(
              get_heap_algorithm(alignment), assignment->module(), schedule,
              assignment->alias_analysis(), assignment->buffer_size_, options));
      AssignBuffersFromHeapSimulator(result, assignment,
                                     single_colored_set.first);
    }
  } else {
    // Run the heap-simulation on a per-computation basis. Buffers for
    // sub-computations are assigned disjoint BufferAllocations, assuming the
    // worst-case that they may all be live concurrently.
    VLOG(1) << "Running per-computation heap simulation";
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      auto color_map = SplitBuffersByColor(buffers_to_assign);
      for (auto& single_colored_set : color_map) {
        auto color = single_colored_set.first;
        VLOG(2) << "Simulating heap for color " << color;
        int64_t alignment = assignment->color_alignment_(color);
        HeapSimulator::Options options;
        options.buffers_to_assign = &single_colored_set.second;
        TF_ASSIGN_OR_RETURN(
            HeapSimulator::Result<HloValue> result,
            HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
                               *instruction_sequence,
                               assignment->alias_analysis(),
                               assignment->buffer_size_, options));
        AssignBuffersFromHeapSimulator(result, assignment,
                                       single_colored_set.first);
      }
    }
  }
  return Status::OK();
}

namespace {
// Computes and returns the set of logical buffers live at the point of
// maximal liveness in the given heap trace. LogicalBuffers are (stably)
// sorted by id.
std::vector<const HloValue*> ComputePeakMemoryLogicalBuffers(
    const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
  // Create a map from BufferValue::Id to HloValue* for the values in this
  // allocation, along with each value's size.
  absl::flat_hash_map<BufferValue::Id, const HloValue*> id_to_value;
  absl::flat_hash_map<const HloValue*, int64> buffer_sizes;
  for (const auto& pair : allocation.assigned_buffers()) {
    const HloValue* value = pair.first;
    const BufferAllocation::OffsetSize& offset_size = pair.second;
    id_to_value[value->id()] = value;
    buffer_sizes[value] = offset_size.size;
  }
  VLOG(1) << "Compute peak memory logical buffers";

  // Returns how much the given event increases the total size of live
  // buffers. Can be negative.
  auto memory_delta = [&id_to_value, &buffer_sizes](
                          const HeapSimulatorTrace::Event& event) -> int64 {
    const HloValue* buffer = id_to_value.at(event.buffer_id());
    const int64_t buffer_size = buffer_sizes.at(buffer);
    if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
        event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
      return buffer_size;
    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
      return -1 * buffer_size;
    }
    LOG(FATAL) << "Unknown event kind: " << event.kind();
  };
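  // Example: the trace ALLOC(a, 8) ALLOC(b, 4) FREE(a) produces deltas
  // +8, +4, -8, so the running live sizes are 8, 12, 4 and the peak is 12.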

  // First compute the size of the maximal live set.
  int64_t max_live_size = 0;
  int64_t live_size = 0;
  for (const auto& event : heap_trace.events()) {
    if (!id_to_value.contains(event.buffer_id())) {
      // Skip as the buffer associated with this trace event is not placed into
      // this allocation. This can happen when size constraints are given to the
      // heap simulator.
      continue;
    }
    live_size += memory_delta(event);
    if (max_live_size < live_size) {
      max_live_size = live_size;
    }
  }

  // Next gather the set of logical buffers live at the earliest point of
  // maximal live set size.
  absl::flat_hash_set<const HloValue*> live_values;
  live_size = 0;
  for (const auto& event : heap_trace.events()) {
    if (!id_to_value.contains(event.buffer_id())) {
      // Skip as the buffer associated with this trace event is not placed into
      // this allocation. This can happen when size constraints are given to the
      // heap simulator.
      continue;
    }
    const HloValue* value = id_to_value.at(event.buffer_id());
    if (event.kind() == HeapSimulatorTrace::Event::ALLOC ||
        event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
      InsertOrDie(&live_values, value);
    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
      CHECK(ContainsKey(live_values, value));
      live_values.erase(value);
    }
    live_size += memory_delta(event);

    if (live_size == max_live_size) {
      break;
    }
  }
  CHECK_EQ(live_size, max_live_size);

  std::vector<const HloValue*> live_values_vector;
  live_values_vector.insert(live_values_vector.end(), live_values.begin(),
                            live_values.end());

  // Stably sort the live buffers by id.
  absl::c_sort(live_values_vector, [](const HloValue* a, const HloValue* b) {
    return a->id() < b->id();
  });
  VLOG(4) << "Peak memory buffers:";
  for (auto value : live_values_vector) {
    VLOG(4) << " " << value->ToString();
  }
  return live_values_vector;
}

}  // namespace

void BufferAssigner::AssignBuffersFromHeapSimulator(
    const HeapSimulator::Result<HloValue>& result, BufferAssignment* assignment,
    BufferValue::Color color) {
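  // A value of -1 means the fragmentation statistic is still unset: the first
  // simulator result replaces the sentinel, later results accumulate.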
  if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
    assignment->stats_.preallocated_temp_fragmentation_bytes =
        result.fragmentation_size;
  } else {
    assignment->stats_.preallocated_temp_fragmentation_bytes +=
        result.fragmentation_size;
  }
  VLOG(1) << "Result size from heap simulator: " << result.heap_size;

  // Iterate through heap_results. For each heap_result, create a new
  // allocation in `assignment`.
  for (const HeapSimulator::HeapResult<HloValue>& heap_result :
       result.heap_results) {
    BufferAllocation* allocation =
        assignment->NewEmptyAllocation(heap_result.heap_size, color);
    for (const auto& buffer_chunk : heap_result.chunk_map) {
      const HloValue& value = *buffer_chunk.first;
      const HeapSimulator::Chunk& chunk = buffer_chunk.second;
      assignment->AddAssignment(allocation, value, chunk.offset, chunk.size);
    }
    allocation->peak_buffers_ =
        ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);

    XLA_VLOG_LINES(2, allocation->ToString());

    allocation->AddHeapTrace(result.debug_trace);
  }
}

StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer));

  // Set up a schedule for each computation.
  HloSchedule schedule(module);
  for (const HloComputation* computation : module->computations()) {
    const HloInstructionSequence* instruction_sequence =
        hlo_ordering->SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order) {
      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

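  // The trailing `true` is presumably HloLiveRange::Run's
  // module_scoped_analysis flag, extending the analysis across computation
  // boundaries.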
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, *alias_analysis,
                                        module->entry_computation(), true));

  VLOG(1) << "Assigning buffers to module " << module->name();
  XLA_VLOG_LINES(3, module->ToString());
  XLA_VLOG_LINES(3, alias_analysis->ToString());
  XLA_VLOG_LINES(3, alias_analysis->dataflow_analysis().ToString());
  VLOG(1) << "Number of buffers to assign: "
          << alias_analysis->buffers().size();

  // Can't use absl::make_unique because BufferAssignment constructor is
  // private.
  std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(alias_analysis),
      std::move(hlo_live_range)));

  TF_RETURN_IF_ERROR(
      colorer_(&assignment->alias_analysis(), assignment->hlo_ordering()));
  VLOG(3) << "After coloring:";
  XLA_VLOG_LINES(3,
                 assignment->alias_analysis().dataflow_analysis().ToString());

  std::vector<const HloComputation*> thread_local_computations;
  std::vector<const HloComputation*> global_computations;
  TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
      module, &thread_local_computations, &global_computations));

  // First assign buffers for global computations. Temporary buffers for
  // sequential computations are collected in
  // 'buffers_to_assign_sequentially'.
  flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>
      buffers_to_assign_sequentially;
  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      global_computations,
      /*is_thread_local=*/false, &buffers_to_assign_sequentially,
      assignment.get()));
  // Assign buffers with sequential ordering, if any. If all global
  // computations are sequential, we can run heap simulation on the whole
  // module, which reduces memory usage.
  const bool run_whole_module_heap_simulation =
      buffers_to_assign_sequentially.size() == global_computations.size();
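  // Every sequential global computation has a map entry (even an empty one),
  // so equal sizes mean every global computation has a sequential order.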
  VLOG(2) << "Running whole module heap simulation: "
          << run_whole_module_heap_simulation;
  const int32_t multiheap_size_constraint_per_heap =
      module->config().debug_options().xla_multiheap_size_constraint_per_heap();
  VLOG(2) << "Multiheap per heap size limit: "
          << multiheap_size_constraint_per_heap;
  TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
      buffers_to_assign_sequentially, run_whole_module_heap_simulation,
      assignment.get()));

  // Now assign buffers for thread-local computations. All LogicalBuffers get
  // their own BufferAllocation. Fusion computations are skipped since they
  // are emitted as part of their fusion instruction.
  std::vector<const HloComputation*> thread_local_computations_no_fusion;
  for (auto* computation : thread_local_computations) {
    TF_RET_CHECK(computation != module->entry_computation());
    if (computation->IsFusionComputation()) {
      continue;
    }
    thread_local_computations_no_fusion.push_back(computation);
  }

  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      thread_local_computations_no_fusion,
      /*is_thread_local=*/true,
      /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));

  // Mark all buffers which may be live out of the entry computation as
  // "liveout".
  for (const HloBuffer* buffer :
       assignment->alias_analysis().LiveOutBuffers()) {
    VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
    if (assignment->HasAllocation(*buffer)) {
      BufferAllocation* alloc =
          assignment->GetMutableAssignedAllocation(*buffer);
      alloc->set_maybe_live_out(true);
      VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
    }
  }

  // Combines allocations of temporary buffers into big BufferAllocations
  // subject to the buffer allocation size constraint. This can only be
  // performed after all buffers have been assigned, and after maybe_live_out
  // is marked, since it is used to determine whether an allocation contains
  // temporary buffers or not.
  assignment->CombineTempAllocations();

  XLA_VLOG_LINES(2, assignment->ToString());
  TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
  XLA_VLOG_LINES(1, assignment->GetStats().ToString());
  VLOG(1) << "Buffer assignment done.";
  return std::move(assignment);
}

}  // namespace xla