1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Defines the data returned by the XLA buffer assignment packages.
17
18 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
19
20 #include <algorithm>
21 #include <deque>
22 #include <ostream>
23 #include <utility>
24
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/memory/memory.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_format.h"
30 #include "tensorflow/compiler/xla/map_util.h"
31 #include "tensorflow/compiler/xla/service/buffer_value_containers.h"
32 #include "tensorflow/compiler/xla/service/heap_simulator.h"
33 #include "tensorflow/compiler/xla/service/hlo.pb.h"
34 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
35 #include "tensorflow/compiler/xla/shape_util.h"
36 #include "tensorflow/compiler/xla/status_macros.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/compiler/xla/util.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/hash/hash.h"
41 #include "tensorflow/core/lib/strings/numbers.h"
42
43 namespace xla {
44 namespace {
45
46 using absl::flat_hash_map;
47 using absl::flat_hash_set;
48 using absl::StrAppend;
49 using absl::StrAppendFormat;
50 using ::tensorflow::strings::HumanReadableNumBytes;
51
52 template <typename T>
ColocatedBufferSetsToString(const T & container,const char * title)53 string ColocatedBufferSetsToString(const T& container, const char* title) {
54 string result;
55 StrAppend(&result, title, "\n");
56 for (const auto& it : container) {
57 StrAppend(&result, "\t", it->ToString(), "\n");
58 }
59 return result;
60 }
61
62 // Checks that points-to set of 'instruction' is unambiguous and distinct
63 // (ensured by CopyInsertion), then adds the buffer from the points-to set at
64 // 'index' to 'colocated_set'.
AddBufferToColocatedSet(const HloInstruction * instruction,const ShapeIndex & index,const TuplePointsToAnalysis & points_to_analysis,std::vector<const LogicalBuffer * > * colocated_set)65 const LogicalBuffer* AddBufferToColocatedSet(
66 const HloInstruction* instruction, const ShapeIndex& index,
67 const TuplePointsToAnalysis& points_to_analysis,
68 std::vector<const LogicalBuffer*>* colocated_set) {
69 // CopyInsertion ensures root points-to set is unambiguous and distinct.
70 const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
71 DCHECK(!points_to.IsAmbiguous());
72 colocated_set->push_back(points_to.element(index)[0]);
73 return colocated_set->back();
74 }
75
76 // Given the interference map of a graph (the list of interfering node indices
77 // for each node), perform graph coloring such that interfering nodes are
78 // assigned to different colors. Returns the assigned color of the nodes, where
79 // the colors are represented as integer values [0, color_count).
ColorInterferenceGraph(const std::vector<std::vector<int64>> & interference_map)80 std::vector<int64> ColorInterferenceGraph(
81 const std::vector<std::vector<int64>>& interference_map) {
82 const int64 node_count = interference_map.size();
83
84 // Sort the nodes such that we assign nodes with more interference first. This
85 // relies on the common heuristic of assigning the most constrained node
86 // first, but it would be good to investigate other ordering heuristics too.
87 std::vector<int64> nodes(node_count);
88 std::iota(nodes.begin(), nodes.end(), 0);
89 absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) {
90 return interference_map[i].size() > interference_map[j].size();
91 });
92
93 const int64 kColorUnassigned = -1;
94 std::vector<int64> assigned_colors(node_count, kColorUnassigned);
95 for (int64 node : nodes) {
96 // Mark the colors that are already assigned to the neighbors.
97 std::vector<bool> available_colors(node_count, true);
98 for (int64 neighbor : interference_map[node]) {
99 int64 color = assigned_colors[neighbor];
100 if (color != kColorUnassigned) {
101 available_colors[color] = false;
102 }
103 }
104
105 // Find the color that is not yet assigned to the neighbors.
106 int64 color = kColorUnassigned;
107 for (color = 0; color < available_colors.size(); ++color) {
108 if (available_colors[color]) {
109 break;
110 }
111 }
112 CHECK_NE(color, kColorUnassigned);
113 assigned_colors[node] = color;
114 }
115 return assigned_colors;
116 }
117
118 } // namespace
119
GatherComputationsByAllocationType(const HloModule * module,std::vector<const HloComputation * > * thread_local_computations,std::vector<const HloComputation * > * global_computations)120 Status GatherComputationsByAllocationType(
121 const HloModule* module,
122 std::vector<const HloComputation*>* thread_local_computations,
123 std::vector<const HloComputation*>* global_computations) {
124 // Create a worklist of computations paired with whether the allocation must
125 // be thread-local.
126 std::deque<std::pair<const HloComputation*, bool>> worklist;
127 worklist.push_back(std::make_pair(module->entry_computation(),
128 /*is_thread_local*/ false));
129
130 // Sets for quickly checking membership. Computations are returned in vectors
131 // for stable iteration.
132 flat_hash_set<const HloComputation*> thread_local_set;
133 flat_hash_set<const HloComputation*> global_set;
134
135 while (!worklist.empty()) {
136 auto worklist_front = worklist.front();
137 worklist.pop_front();
138 const HloComputation* computation = worklist_front.first;
139 bool is_thread_local = worklist_front.second;
140 bool in_thread_local_set = thread_local_set.contains(computation);
141 bool in_global_set = global_set.contains(computation);
142
143 // If the computation has already been added to the respective set, then
144 // nothing to do.
145 if ((is_thread_local && in_thread_local_set) ||
146 (!is_thread_local && in_global_set)) {
147 continue;
148 }
149
150 // If the computation has already been added to the other set this is an
151 // error condition because the global call to the computation (eg,
152 // while/call) may return a reference to one of the thread-local buffers to
153 // the calling computation which will become a dangling reference when the
154 // thread-local is deallocated with the call return.
155 if ((is_thread_local && in_global_set) ||
156 (!is_thread_local && in_thread_local_set)) {
157 return InvalidArgument(
158 "computation %s has conflicting allocation requirements (global "
159 "and thread-local)",
160 computation->name());
161 }
162
163 if (is_thread_local) {
164 thread_local_set.insert(computation);
165 } else {
166 global_set.insert(computation);
167 }
168
169 for (auto* instruction : computation->instructions()) {
170 for (HloComputation* subcomputation :
171 instruction->called_computations()) {
172 switch (instruction->opcode()) {
173 case HloOpcode::kCall:
174 case HloOpcode::kConditional:
175 case HloOpcode::kWhile:
176 // Call and while must be called from a computation with global
177 // allocations as they may return references to buffers inside the
178 // called computation which cannot be thread-local.
179 if (is_thread_local) {
180 return InvalidArgument(
181 "computation %s cannot contain call/while op because it "
182 "requires thread-local buffer allocations",
183 computation->name());
184 }
185 worklist.push_back(std::make_pair(subcomputation,
186 false)); // Not thread local.
187 break;
188 case HloOpcode::kAllReduce:
189 case HloOpcode::kMap:
190 case HloOpcode::kReduce:
191 case HloOpcode::kReduceWindow:
192 case HloOpcode::kScatter:
193 case HloOpcode::kSelectAndScatter:
194 case HloOpcode::kSort:
195 case HloOpcode::kFusion:
196 // Map/reduce etc computations are always thread-local.
197 worklist.push_back(std::make_pair(subcomputation,
198 true)); // Thread local.
199 break;
200 default:
201 return InternalError("Unexpected calling opcode: %s",
202 HloOpcodeString(instruction->opcode()));
203 }
204 }
205 }
206 }
207
208 // Add the computations to the vectors in post order.
209 for (auto* computation : module->MakeComputationPostOrder()) {
210 if (thread_local_set.contains(computation)) {
211 thread_local_computations->push_back(computation);
212 } else if (global_set.contains(computation)) {
213 global_computations->push_back(computation);
214 }
215 // If the computation is not reachable from the entry computation, then it
216 // will not appear in either thread_local_set or global_set. We don't bother
217 // assigning buffers for these.
218 }
219 return Status::OK();
220 }
221
ToString() const222 string BufferAllocation::Slice::ToString() const {
223 return absl::StrCat("{index:", index(), ", offset:", offset_,
224 ", size:", size_, "}");
225 }
226
GetSlice(const LogicalBuffer & buffer) const227 BufferAllocation::Slice BufferAllocation::GetSlice(
228 const LogicalBuffer& buffer) const {
229 const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
230 return Slice(this, os.offset, os.size);
231 }
232
AddAssignment(const LogicalBuffer & buffer,int64 offset,int64 size)233 void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset,
234 int64 size) {
235 VLOG(4) << "Trying to add " << buffer << " to allocation #" << index();
236 CHECK(!assigned_buffers_.contains(&buffer))
237 << "LogicalBuffer " << buffer << " already assigned to allocation "
238 << index_;
239 CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
240 << " offset out of range";
241 CHECK_LE(offset + size, size_)
242 << "LogicalBuffer " << buffer
243 << " size out of range at offset: " << offset << " with size: " << size;
244 CHECK_EQ(buffer.color(), color())
245 << "Buffer color " << buffer.color() << " for buffer " << buffer
246 << " does not match allocation color " << color() << ".";
247 OffsetSize offset_size;
248 offset_size.offset = offset;
249 offset_size.size = size;
250 assigned_buffers_.emplace(&buffer, offset_size);
251 }
252
ToProto() const253 BufferAllocationProto BufferAllocation::ToProto() const {
254 BufferAllocationProto proto;
255 proto.set_index(index_);
256 proto.set_size(size_);
257 proto.set_is_thread_local(is_thread_local_);
258 proto.set_is_tuple(is_tuple_);
259 proto.set_color(color_.value());
260 if (is_entry_computation_parameter_) {
261 proto.set_is_entry_computation_parameter(true);
262 for (int64 idx : param_shape_index()) {
263 proto.add_parameter_shape_index(idx);
264 }
265 proto.set_parameter_number(parameter_number_);
266 }
267 proto.set_is_constant(is_constant_);
268 proto.set_maybe_live_out(maybe_live_out_);
269 for (const auto& buffer_offset_size : assigned_buffers_) {
270 BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
271 proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
272 proto_assigned->set_offset(buffer_offset_size.second.offset);
273 proto_assigned->set_size(buffer_offset_size.second.size);
274 }
275 absl::c_sort(*proto.mutable_assigned(),
276 [](const BufferAllocationProto::Assigned& assign1,
277 const BufferAllocationProto::Assigned& assign2) {
278 return assign1.logical_buffer_id() <
279 assign2.logical_buffer_id();
280 });
281 return proto;
282 }
283
ToString() const284 string BufferAllocation::ToString() const {
285 string output;
286 StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
287 if (color().value() != 0) {
288 StrAppend(&output, ", color ", color().value());
289 }
290 if (is_entry_computation_parameter()) {
291 StrAppend(&output, ", parameter ", parameter_number(), " at ShapeIndex ",
292 param_shape_index().ToString());
293 }
294 if (is_constant()) {
295 StrAppend(&output, ", constant");
296 }
297 if (is_thread_local()) {
298 StrAppend(&output, ", thread-local");
299 }
300 if (maybe_live_out()) {
301 StrAppend(&output, ", maybe-live-out");
302 }
303 if (IsPreallocatedTempBuffer()) {
304 StrAppend(&output, ", preallocated-temp");
305 }
306 StrAppend(&output, ":\n");
307 // Dump the assigned buffers ordered by id.
308 std::vector<const LogicalBuffer*> sorted_buffers;
309 for (const auto& buffer_offset_size : assigned_buffers_) {
310 sorted_buffers.push_back(buffer_offset_size.first);
311 }
312 absl::c_sort(sorted_buffers,
313 [](const LogicalBuffer* a, const LogicalBuffer* b) {
314 return a->id() < b->id();
315 });
316 for (const LogicalBuffer* buffer : sorted_buffers) {
317 const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
318 StrAppend(&output, absl::StrFormat(
319 " %s [%d,%d]: %s\n", buffer->ToString(),
320 offset_size.offset, offset_size.size,
321 ShapeUtil::HumanStringWithLayout(buffer->shape())));
322 }
323 return output;
324 }
325
operator <<(std::ostream & out,const BufferAllocation & buffer)326 std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
327 out << buffer.ToString();
328 return out;
329 }
330
operator <<(std::ostream & out,const BufferAllocation::Slice & s)331 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
332 out << s.ToString();
333 return out;
334 }
335
GetPointsToSet(const HloInstruction * instruction) const336 const PointsToSet& BufferAssignment::GetPointsToSet(
337 const HloInstruction* instruction) const {
338 return points_to_analysis().GetPointsToSet(instruction);
339 }
340
HasAllocation(const LogicalBuffer & buffer) const341 bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const {
342 TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
343 return allocation_index_for_buffer_.contains(&buffer);
344 }
345
GetAssignedAllocation(const LogicalBuffer & buffer) const346 const BufferAllocation& BufferAssignment::GetAssignedAllocation(
347 const LogicalBuffer& buffer) const {
348 CHECK(HasAllocation(buffer));
349 return GetAllocation(allocation_index_for_buffer_.at(&buffer));
350 }
351
GetMutableAssignedAllocation(const LogicalBuffer & buffer)352 BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
353 const LogicalBuffer& buffer) {
354 return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
355 }
356
GetAllSlices(const HloInstruction * instruction,const ShapeIndex & index) const357 std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
358 const HloInstruction* instruction, const ShapeIndex& index) const {
359 std::set<BufferAllocation::Slice> result;
360 for (const LogicalBuffer* buffer : GetSourceBuffers(instruction, index)) {
361 if (HasAllocation(*buffer)) {
362 result.insert(GetAssignedAllocation(*buffer).GetSlice(*buffer));
363 }
364 }
365 return result;
366 }
367
GetAllocation(BufferAllocation::Index index) const368 const BufferAllocation& BufferAssignment::GetAllocation(
369 BufferAllocation::Index index) const {
370 CHECK_GE(index, 0);
371 CHECK_LT(index, allocations_.size());
372 return allocations_[index];
373 }
374
GetInstructionAllocation(const HloInstruction * hlo,const ShapeIndex & shape_index) const375 const BufferAllocation* BufferAssignment::GetInstructionAllocation(
376 const HloInstruction* hlo, const ShapeIndex& shape_index) const {
377 const PointsToSet& points_to_set = points_to_analysis().GetPointsToSet(hlo);
378 const LogicalBuffer* buffer = points_to_set.element(shape_index)[0];
379
380 if (!HasAllocation(*buffer)) {
381 return nullptr;
382 }
383
384 const BufferAllocation& instruction_allocation =
385 GetAssignedAllocation(*buffer);
386 return &instruction_allocation;
387 }
388
GetMutableAllocation(BufferAllocation::Index index)389 BufferAllocation* BufferAssignment::GetMutableAllocation(
390 BufferAllocation::Index index) {
391 return const_cast<BufferAllocation*>(&GetAllocation(index));
392 }
393
HasAllocationAt(const HloInstruction * instruction,const ShapeIndex & index) const394 bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
395 const ShapeIndex& index) const {
396 for (const LogicalBuffer* buffer :
397 GetPointsToSet(instruction).element(index)) {
398 if (allocation_index_for_buffer_.contains(buffer)) {
399 return true;
400 }
401 }
402 return false;
403 }
404
HasTopLevelAllocation(const HloInstruction * instruction) const405 bool BufferAssignment::HasTopLevelAllocation(
406 const HloInstruction* instruction) const {
407 return HasAllocationAt(instruction, /*index=*/{});
408 }
409
GetUniqueSlice(const HloInstruction * instruction,const ShapeIndex & index) const410 StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
411 const HloInstruction* instruction, const ShapeIndex& index) const {
412 VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
413 << index << "]";
414 BufferAllocation::Slice result;
415 for (const LogicalBuffer* buffer :
416 GetPointsToSet(instruction).element(index)) {
417 VLOG(3) << "Examining buffer " << *buffer;
418 if (HasAllocation(*buffer)) {
419 VLOG(3) << "Has allocation";
420 const BufferAllocation::Slice slice =
421 GetAssignedAllocation(*buffer).GetSlice(*buffer);
422 if (result.allocation() == nullptr) {
423 result = slice;
424 } else if (result != slice) {
425 return FailedPrecondition(
426 "BufferAllocation::Slice for instruction %s at index %s cannot "
427 "be determined at compile-time.",
428 instruction->name(), index.ToString());
429 }
430 } else {
431 VLOG(3) << "No allocation";
432 }
433 }
434 if (result.allocation() == nullptr) {
435 return FailedPrecondition(
436 "BufferAllocation::Slice not assigned for instruction %s at index %s",
437 instruction->name(), index.ToString());
438 }
439 return result;
440 }
441
GetUniqueTopLevelSlice(const HloInstruction * instruction) const442 StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
443 const HloInstruction* instruction) const {
444 return GetUniqueSlice(instruction, /*index=*/{});
445 }
446
SharesSliceAtIndex(const HloInstruction * hlo_a,const ShapeIndex & shape_index_a,const HloInstruction * hlo_b,const ShapeIndex & shape_index_b) const447 bool BufferAssignment::SharesSliceAtIndex(
448 const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
449 const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
450 return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
451 GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
452 }
453
HaveDisjointSlices(const HloInstruction * hlo_a,const HloInstruction * hlo_b) const454 bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
455 const HloInstruction* hlo_b) const {
456 using SliceSet = flat_hash_set<BufferAllocation::Slice>;
457 // Gets the slices all of instr's subshapes. If any subshape doesn't have an
458 // assigned slice, returns the empty set.
459 auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
460 SliceSet slices;
461 Status status = ShapeUtil::ForEachSubshapeWithStatus(
462 instr->shape(),
463 [&](const Shape& /*subshape*/, const ShapeIndex& index) {
464 auto shape_slices = GetAllSlices(instr, index);
465 if (shape_slices.empty()) {
466 return InvalidArgument("No slices assigned to part of instr.");
467 }
468 slices.insert(shape_slices.begin(), shape_slices.end());
469 return Status::OK();
470 });
471 if (!status.ok()) {
472 return {};
473 }
474 return slices;
475 };
476
477 SliceSet slices_a = collect_slices(hlo_a);
478 SliceSet slices_b = collect_slices(hlo_b);
479 // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
480 // didn't return the empty set) for both HLOs, and the two resulting sets of
481 // slices are disjoint.
482 return !slices_a.empty() && !slices_b.empty() &&
483 absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
484 return slices_b.contains(slice);
485 });
486 }
487
488 StatusOr<BufferAllocation::Slice>
GetUniqueTopLevelOutputSlice() const489 BufferAssignment::GetUniqueTopLevelOutputSlice() const {
490 return GetUniqueTopLevelSlice(
491 module_->entry_computation()->root_instruction());
492 }
493
NewEmptyAllocation(int64 size,LogicalBuffer::Color color)494 BufferAllocation* BufferAssignment::NewEmptyAllocation(
495 int64 size, LogicalBuffer::Color color) {
496 BufferAllocation::Index index = allocations_.size();
497 allocations_.emplace_back(index, size, color);
498 BufferAllocation* allocation = &allocations_.back();
499 return allocation;
500 }
501
NewAllocation(const LogicalBuffer & buffer,int64 size)502 BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer,
503 int64 size) {
504 BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
505 AddAssignment(allocation, buffer, /*offset=*/0, size);
506 allocation->peak_buffers_.push_back(&buffer);
507 return allocation;
508 }
509
510 // Adds an instruction to the set assigned to the given buffer.
AddAssignment(BufferAllocation * allocation,const LogicalBuffer & buffer,int64 offset,int64 size)511 void BufferAssignment::AddAssignment(BufferAllocation* allocation,
512 const LogicalBuffer& buffer, int64 offset,
513 int64 size) {
514 CHECK(!allocation_index_for_buffer_.contains(&buffer))
515 << "LogicalBuffer " << buffer << " already has an allocation.";
516 CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
517 << "Non-reusable allocation already assigned a buffer: "
518 << allocation->ToString();
519
520 TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
521
522 allocation->AddAssignment(buffer, offset, size);
523 if (liveness().MaybeLiveOut(buffer)) {
524 allocation->set_maybe_live_out(true);
525 }
526 allocation_index_for_buffer_[&buffer] = allocation->index();
527 }
528
529 // Combines allocations of temporary buffers of the same color into one big
530 // BufferAllocation.
CombineTempAllocations()531 void BufferAssignment::CombineTempAllocations() {
532 VLOG(1) << "CombineTempAllocations()";
533 flat_hash_map<LogicalBuffer::Color, BufferAllocation,
534 LogicalBuffer::Color::Hasher>
535 combined_allocation_map;
536
537 // Move all temp allocations into a single run at the end of the allocations
538 // vector.
539 const auto first_temp_it =
540 std::partition(allocations_.begin(), allocations_.end(),
541 [](const BufferAllocation& allocation) {
542 return !allocation.IsPreallocatedTempBuffer();
543 });
544
545 // Walk over the run of temp allocations, collecting the allocations belonging
546 // to the same color.
547 if (first_temp_it != allocations_.end()) {
548 for (auto it = first_temp_it; it != allocations_.end(); ++it) {
549 const BufferAllocation& temp_allocation = *it;
550 LogicalBuffer::Color color = temp_allocation.color();
551 auto combined_it = combined_allocation_map.find(color);
552 if (combined_it == combined_allocation_map.end()) {
553 // We have found the first temp allocation of this color. Collect
554 // the other temp allocations of the same color into it.
555 VLOG(1) << "Combined temp allocation for color " << color
556 << " is: " << temp_allocation;
557 combined_allocation_map.emplace(color, temp_allocation);
558 continue;
559 }
560
561 auto* combined_allocation = &combined_it->second;
562 VLOG(1) << "Combined allocation absorbing temp allocation: "
563 << temp_allocation;
564
565 // Each temp allocation is placed end-to-end, accounting for alignment.
566 // The offset of each buffer in the combined allocation is computed from
567 // the base offset of the allocation.
568 int64 alignment = color_alignment_(color);
569 const int64 base =
570 RoundUpToNearest(combined_allocation->size(), alignment);
571 combined_allocation->set_size(base + temp_allocation.size());
572 for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) {
573 const LogicalBuffer* buffer = buffer_offset_size.first;
574 const int64 offset = buffer_offset_size.second.offset;
575 const int64 size = buffer_offset_size.second.size;
576 combined_allocation->AddAssignment(*buffer, base + offset, size);
577 }
578 if (!temp_allocation.HeapTraces().empty()) {
579 CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
580 combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
581 }
582 combined_allocation->peak_buffers_.insert(
583 combined_allocation->peak_buffers_.end(),
584 temp_allocation.peak_buffers_.begin(),
585 temp_allocation.peak_buffers_.end());
586 }
587 // Replace all existing temporary allocations with the new combined
588 // allocations.
589 allocations_.erase(first_temp_it, allocations_.end());
590 for (auto& combined : combined_allocation_map) {
591 allocations_.push_back(combined.second);
592 temp_allocation_total_size_ += combined.second.size();
593 }
594 }
595
596 // Update allocation indices to their new positions.
597 allocation_index_for_buffer_.erase(allocation_index_for_buffer_.begin(),
598 allocation_index_for_buffer_.end());
599 for (size_t index = 0; index < allocations_.size(); ++index) {
600 BufferAllocation* allocation = &allocations_[index];
601 allocation->set_index(index);
602 for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
603 const LogicalBuffer* buffer = buffer_offset_size.first;
604 allocation_index_for_buffer_[buffer] = index;
605 }
606 }
607 }
608
ComputeSummaryStats()609 Status BufferAssignment::ComputeSummaryStats() {
610 for (auto& allocation : Allocations()) {
611 if (allocation.is_entry_computation_parameter()) {
612 stats_.parameter_allocation_count++;
613 stats_.parameter_allocation_bytes += allocation.size();
614 }
615 if (allocation.is_constant()) {
616 stats_.constant_allocation_count++;
617 stats_.constant_allocation_bytes += allocation.size();
618 }
619 if (allocation.maybe_live_out()) {
620 stats_.maybe_live_out_allocation_count++;
621 stats_.maybe_live_out_allocation_bytes += allocation.size();
622 }
623 if (allocation.IsPreallocatedTempBuffer()) {
624 stats_.preallocated_temp_allocation_count++;
625 stats_.preallocated_temp_allocation_bytes += allocation.size();
626 }
627 stats_.total_allocation_count++;
628 stats_.total_allocation_bytes += allocation.size();
629 }
630
631 // Only compute total fragmentation if all computations have schedules.
632 HloSchedule schedule(module_);
633 bool schedule_complete = true;
634 for (const auto& computation : module_->computations()) {
635 if (!computation->IsFusionComputation()) {
636 const HloInstructionSequence* sequence =
637 liveness_->hlo_ordering().SequentialOrder(*computation);
638 if (sequence == nullptr) {
639 schedule_complete = false;
640 } else {
641 schedule.set_sequence(computation, *sequence);
642 }
643 }
644 }
645 if (schedule_complete) {
646 TF_RETURN_IF_ERROR(schedule.Verify());
647 TF_ASSIGN_OR_RETURN(
648 const int64 min_size,
649 HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
650 stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
651 }
652
653 return Status::OK();
654 }
655
ToString() const656 string BufferAssignment::Stats::ToString() const {
657 string s;
658 StrAppendFormat(&s, "BufferAssignment stats:\n");
659 StrAppendFormat(&s, " parameter allocation: %10s\n",
660 HumanReadableNumBytes(parameter_allocation_bytes));
661 StrAppendFormat(&s, " constant allocation: %10s\n",
662 HumanReadableNumBytes(constant_allocation_bytes));
663 StrAppendFormat(&s, " maybe_live_out allocation: %10s\n",
664 HumanReadableNumBytes(maybe_live_out_allocation_bytes));
665 StrAppendFormat(&s, " preallocated temp allocation: %10s\n",
666 HumanReadableNumBytes(preallocated_temp_allocation_bytes));
667 if (preallocated_temp_fragmentation_bytes >= 0) {
668 const double percent = 100. * preallocated_temp_fragmentation_bytes /
669 preallocated_temp_allocation_bytes;
670 StrAppendFormat(
671 &s, " preallocated temp fragmentation: %10s (%.2f%%)\n",
672 HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
673 }
674 StrAppendFormat(&s, " total allocation: %10s\n",
675 HumanReadableNumBytes(total_allocation_bytes));
676 if (total_fragmentation_bytes >= 0) {
677 const double percent =
678 100. * total_fragmentation_bytes / total_allocation_bytes;
679 StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n",
680 HumanReadableNumBytes(total_fragmentation_bytes), percent);
681 }
682 return s;
683 }
684
ToString() const685 string BufferAssignment::ToString() const {
686 string output;
687 absl::StrAppend(&output, "BufferAssignment:\n");
688 for (auto& allocation : allocations_) {
689 absl::StrAppend(&output, allocation.ToString());
690 }
691 return output;
692 }
693
ToProto() const694 BufferAssignmentProto BufferAssignment::ToProto() const {
695 BufferAssignmentProto proto;
696 // NOTE: TuplePointsToAnalysis state is serialized here in BufferAssigment,
697 // because we need to do the HasAllocation check for each buffer. Otherwise
698 // the buffer_size_ call might fail for some backends.
699 const TuplePointsToAnalysis& points_to_analysis =
700 liveness_->points_to_analysis();
701 for (LogicalBuffer::Id id = 0; id < points_to_analysis.num_logical_buffers();
702 id++) {
703 auto& buffer = points_to_analysis.logical_buffer(id);
704 if (HasAllocation(buffer)) {
705 LogicalBufferProto proto_buffer = buffer.ToProto(buffer_size_);
706 proto.add_logical_buffers()->Swap(&proto_buffer);
707
708 // Fill buffer aliases.
709 for (const BufferAlias& alias :
710 points_to_analysis.GetBufferAliases(buffer)) {
711 if (alias.instruction() == buffer.instruction() &&
712 alias.index() == buffer.index()) {
713 continue; // skip self-aliases
714 }
715 BufferAssignmentProto::BufferAlias* proto_alias =
716 proto.add_buffer_aliases();
717 LogicalBufferProto::Location proto_alias_location =
718 BufferValue::ToLocationProto(*alias.instruction(), alias.index());
719 proto_alias->set_source_buffer_id(buffer.id());
720 proto_alias->mutable_location()->Swap(&proto_alias_location);
721 }
722 }
723 }
724 for (const BufferAllocation& allocation : Allocations()) {
725 BufferAllocationProto proto_allocation = allocation.ToProto();
726 proto.add_buffer_allocations()->Swap(&proto_allocation);
727 for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
728 *proto.add_heap_simulator_traces() = heap_trace;
729 }
730 }
731 return proto;
732 }
733
734 /* static */
Run(const HloModule * module,std::unique_ptr<HloOrdering> hlo_ordering,LogicalBuffer::SizeFunction buffer_size,LogicalBuffer::AlignmentFunction color_alignment,bool allow_input_output_aliasing,bool allocate_buffers_for_constants,BufferLiveness::Colorer colorer,ReuseAllocationFunction reuse_checker)735 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
736 const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
737 LogicalBuffer::SizeFunction buffer_size,
738 LogicalBuffer::AlignmentFunction color_alignment,
739 bool allow_input_output_aliasing, bool allocate_buffers_for_constants,
740 BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker) {
741 BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
742 std::move(reuse_checker));
743 return assigner.CreateAssignment(module, std::move(hlo_ordering),
744 std::move(buffer_size),
745 std::move(color_alignment));
746 }
747
748 namespace {
749
750 // a and b are in different subcomputations. Check for the case
751 // where a is inside the while body, and b is outside, part of the same while's
752 // init-operand or while-result.
MayInterfereAcrossSubcomputations(BufferAssignment * assignment,const LogicalBuffer & a_buffer,const LogicalBuffer & b_buffer)753 bool MayInterfereAcrossSubcomputations(BufferAssignment* assignment,
754 const LogicalBuffer& a_buffer,
755 const LogicalBuffer& b_buffer) {
756 const CallGraph& call_graph =
757 assignment->liveness().hlo_ordering().call_graph();
758 const HloInstruction* a_ancestor;
759 const HloInstruction* b_ancestor;
760 std::tie(a_ancestor, b_ancestor) =
761 call_graph.NearestAncestorsInSameComputation(a_buffer.instruction(),
762 b_buffer.instruction());
763 if (a_ancestor == nullptr) {
764 // No common ancestor.
765 return true;
766 }
767 if (a_ancestor->opcode() == HloOpcode::kWhile &&
768 call_graph.InstructionIsNestedIn(a_buffer.instruction(),
769 a_ancestor->while_body())) {
770 const PointsToSet& init_set =
771 assignment->liveness().points_to_analysis().GetPointsToSet(
772 a_ancestor->operand(0));
773 if (init_set.ContainsBuffer(b_buffer)) {
774 VLOG(4) << "Can't interfere: " << a_buffer << " and " << b_buffer
775 << " (part of while-operand)";
776 return false;
777 }
778 const PointsToSet& while_set =
779 assignment->liveness().points_to_analysis().GetPointsToSet(a_ancestor);
780 if (while_set.ContainsBuffer(b_buffer)) {
781 VLOG(4) << "Can't interfere: " << a_buffer << " and " << b_buffer
782 << " (part of while)";
783 return false;
784 }
785 }
786 return true;
787 }
788
789 // Return true, if a and b can't possibly interfere (and therefore further
790 // checking for interference can be skipped). This function checks for special
791 // cases where copy insertion guarantees no interference, but the regular buffer
792 // liveness is too conservative:
793 //
794 // Operations inside a while-body can't interfere with operations outside the
795 // while op if their last use is at the while-loop itself as part of the
796 // while-init op, or the while-result. For ops that are live across a
797 // while-loop, copy insertion will already insert the necessary copies to avoid
798 // such interference.
799 //
800 // This allows sharing buffers in cases like this:
801 // init = {...}
802 // while (init):
803 // p = param(0)
804 // gte = get-tuple-element(p), index=i
805 // t1 = op1 (gte)
806 // t2 = op2 (t1)
807 // ROOT tuple = {..., t2, ...}
808 //
809 // where t1 and t2 can share the same buffer.
MaySkipInterferenceCheck(BufferAssignment * assignment,const LogicalBuffer & a_buffer,const LogicalBuffer & b_buffer)810 bool MaySkipInterferenceCheck(BufferAssignment* assignment,
811 const LogicalBuffer& a_buffer,
812 const LogicalBuffer& b_buffer) {
813 if (a_buffer.instruction()->parent() == b_buffer.instruction()->parent()) {
814 // Ops within the same computation are not handled here. Assume that they
815 // may interfere.
816 return false;
817 }
818 return !MayInterfereAcrossSubcomputations(assignment, a_buffer, b_buffer) ||
819 !MayInterfereAcrossSubcomputations(assignment, b_buffer, a_buffer);
820 }
821
822 } // namespace
823
MaybeAssignBuffer(BufferAllocation * allocation,const LogicalBuffer & buffer,BufferAssignment * assignment)824 bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
825 const LogicalBuffer& buffer,
826 BufferAssignment* assignment) {
827 const LogicalBuffer::SizeFunction& buffer_size = assignment->buffer_size_;
828
829 CHECK(!assignment->HasAllocation(buffer))
830 << "buffer " << buffer << " already has an allocation assigned.";
831
832 VLOG(4) << "Trying to assign " << buffer << " to allocation: " << *allocation;
833
834 if (buffer.color() != allocation->color()) {
835 VLOG(4) << "Can't assign: buffer has color" << buffer.color()
836 << " and allocation has color " << allocation->color() << ".";
837 return false;
838 }
839
840 if (buffer_size(buffer) > allocation->size()) {
841 VLOG(4) << "Can't assign: buffer is larger than allocation ("
842 << buffer_size(buffer) << " > " << allocation->size() << ")";
843 return false;
844 }
845
846 if (allocation->is_readonly()) {
847 VLOG(4) << "Can't assign: allocation is readonly";
848 return false;
849 }
850
851 if (reuse_checker_ != nullptr &&
852 !reuse_checker_(*assignment, *allocation, buffer)) {
853 VLOG(4) << "Can't assign: reuse_checker_(allocation, buffer) == false";
854 return false;
855 }
856
857 if (!allocation->is_reusable()) {
858 VLOG(4) << "Can't assign: allocation is not reusable";
859 return false;
860 }
861
862 for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
863 const LogicalBuffer& assigned_buffer = *buffer_offset_size.first;
864 if (MaySkipInterferenceCheck(assignment, buffer, assigned_buffer)) {
865 continue;
866 }
867 if (assignment->liveness().MayInterfere(assigned_buffer, buffer)) {
868 VLOG(4) << "Can't assign: assignee " << assigned_buffer
869 << " may interfere with " << buffer;
870 return false;
871 }
872 // Copy instruction don't share a buffer with their input operand.
873 if (buffer.instruction()->IsUserOf(assigned_buffer.instruction()) &&
874 buffer.instruction()->opcode() == HloOpcode::kCopy) {
875 VLOG(4) << "Can't assign: assignee " << assigned_buffer
876 << " is used at copy instruction " << buffer;
877 return false;
878 }
879 }
880
881 // If the buffer is live out of the computation then it should only be
882 // assigned a buffer which exactly fits the result to avoid wasting memory
883 // (result buffers can have arbitrary lifetimes).
884 if (assignment->liveness().MaybeLiveOut(buffer) &&
885 allocation->size() != buffer_size(buffer)) {
886 VLOG(4) << "Can't assign: buffer " << buffer
887 << "is live out and size not the same as allocation";
888 return false;
889 }
890
891 assignment->AddAssignment(allocation, buffer, /*offset=*/0,
892 buffer_size(buffer));
893 return true;
894 }
895
AssignBuffersForComputation(const HloComputation * computation,bool is_thread_local,const flat_hash_set<const LogicalBuffer * > & colocated_buffers,const flat_hash_set<BufferAllocation::Index> & colocated_allocations,flat_hash_map<const HloComputation *,flat_hash_set<const LogicalBuffer * >> * buffers_to_assign_sequentially,BufferAssignment * assignment)896 Status BufferAssigner::AssignBuffersForComputation(
897 const HloComputation* computation, bool is_thread_local,
898 const flat_hash_set<const LogicalBuffer*>& colocated_buffers,
899 const flat_hash_set<BufferAllocation::Index>& colocated_allocations,
900 flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>*
901 buffers_to_assign_sequentially,
902 BufferAssignment* assignment) {
903 // Buffers are sorted and assigned to BufferAllocations in decreasing order of
904 // size.
905 std::vector<const LogicalBuffer*> sorted_buffers;
906 for (auto* instruction : computation->instructions()) {
907 // Add all buffers which this instruction defines. Instruction which don't
908 // define buffers (eg, bitcast which just forwards a pointer) don't need
909 // any allocations.
910 for (const LogicalBuffer* buffer :
911 assignment->points_to_analysis().GetBuffersDefinedByInstruction(
912 instruction)) {
913 sorted_buffers.push_back(buffer);
914 }
915 }
916
917 // Generate a post order sort of instructions for sorting of the
918 // LogicalBuffers.
919 flat_hash_map<const HloInstruction*, int> post_order_position;
920 int position = 0;
921 for (auto* instruction : computation->MakeInstructionPostOrder()) {
922 post_order_position.emplace(instruction, position);
923 position++;
924 }
925
926 // If there is a sequential instruction ordering, we'll delay assignment of
927 // temp buffers until after the main assignment loop.
928 const BufferLiveness& liveness = assignment->liveness();
929 const bool has_sequential_order =
930 liveness.hlo_ordering().SequentialOrder(*computation) != nullptr;
931 if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
932 // Every sequential computation must get an entry in the
933 // buffers_to_assign_sequentially map, even if we end up with an empty set
934 // of buffers. This ensures we can correctly determine whether to run
935 // whole-module heap simulation.
936 buffers_to_assign_sequentially->emplace(
937 computation, flat_hash_set<const LogicalBuffer*>());
938 }
939
940 // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
941 // first for simplicity. This means any previously created BufferAllocation is
942 // necessarily large enough to hold the output of the current Buffer in
943 // consideration.
944 //
945 // As a secondary sorting criteria, if the instructions are sequentially
946 // ordered, we assign live-out buffers before others. Note that for sequential
947 // computations, we'll take temp buffers that can't re-use any allocations and
948 // assign them via a heap scheduler. By assigning live-out buffers first, we
949 // increase the odds that temp buffers can re-use an allocation.
950 //
951 // As a final tiebreaker use post order position of the HLO instruction which
952 // defines the buffer. This means an instruction will appear after its
953 // operands (assuming operands are the same/larger size) enabling the
954 // important reuse case where an elementwise instruction reuses one of its
955 // operand's buffer. This improves locality.
956 absl::c_sort(sorted_buffers,
957 [has_sequential_order, &liveness, &post_order_position,
958 assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
959 // Primary sort is by decreasing buffer size.
960 const int64 a_size = assignment->buffer_size_(*a);
961 const int64 b_size = assignment->buffer_size_(*b);
962 if (a_size != b_size) {
963 return a_size > b_size; // use ">" for decreasing size.
964 }
965 // Otherwise live out buffers come before others, if the
966 // instructions are sequentially ordered.
967 if (has_sequential_order) {
968 const bool a_live_out = liveness.MaybeLiveOut(*a);
969 const bool b_live_out = liveness.MaybeLiveOut(*b);
970 if (a_live_out != b_live_out) {
971 return a_live_out;
972 }
973 }
974 // Final tiebreaker is in instruction post order.
975 return post_order_position.at(a->instruction()) <
976 post_order_position.at(b->instruction());
977 });
978
979 // BufferAllocations are necessarily created in decreasing size order. Keep
980 // indices of previously created BufferAllocations in allocation_indices.
981 std::vector<BufferAllocation::Index> allocation_indices;
982 for (const LogicalBuffer* buffer : sorted_buffers) {
983 VLOG(3) << "Assigning allocation to: " << *buffer;
984 if (colocated_buffers.contains(buffer)) {
985 // Colocated buffers are currently assigned in an earlier pass.
986 VLOG(3) << "Skipping colocated buffer: " << *buffer;
987 continue;
988 }
989
990 TF_RET_CHECK(!assignment->HasAllocation(*buffer));
991
992 const HloInstruction* instruction = buffer->instruction();
993 const int64 buffer_size = assignment->buffer_size_(*buffer);
994
995 if (instruction->opcode() == HloOpcode::kConstant) {
996 if (allocate_buffers_for_constants_) {
997 BufferAllocation* allocation =
998 assignment->NewAllocation(*buffer, buffer_size);
999 allocation->set_constant(true);
1000 VLOG(3) << "New allocation #" << allocation->index() << " for constant "
1001 << *buffer;
1002 }
1003 continue;
1004 }
1005
1006 const bool is_entry_parameter =
1007 instruction->opcode() == HloOpcode::kParameter &&
1008 computation == computation->parent()->entry_computation();
1009 if (is_entry_parameter) {
1010 // If the LogicalBuffer is part of an external parameter, creates a new
1011 // allocation and sets its parameter number. Parameters of non-entry
1012 // computations do not need special allocations because they live inside
1013 // callers.
1014 BufferAllocation* allocation =
1015 assignment->NewAllocation(*buffer, buffer_size);
1016 bool parameter_has_alias =
1017 assignment->module().input_output_alias_config().ParameterHasAlias(
1018 instruction->parameter_number(), buffer->index());
1019 allocation->set_entry_computation_parameter(
1020 instruction->parameter_number(), buffer->index(),
1021 parameter_has_alias);
1022 VLOG(3) << "Mark allocation #" << allocation->index()
1023 << " as entry computation parameter: " << *buffer;
1024 continue;
1025 }
1026
1027 if (is_thread_local) {
1028 BufferAllocation* allocation =
1029 assignment->NewAllocation(*buffer, buffer_size);
1030 allocation->set_is_thread_local(true);
1031 VLOG(3) << "New allocation #" << allocation->index()
1032 << " for thread-local: " << *buffer;
1033 continue;
1034 }
1035
1036 if (buffer->shape().IsTuple()) {
1037 BufferAllocation* allocation =
1038 assignment->NewAllocation(*buffer, buffer_size);
1039 allocation->set_is_tuple(true);
1040 VLOG(3) << "New allocation #" << allocation->index()
1041 << " for tuple-shaped buffer: " << *buffer;
1042 continue;
1043 }
1044
1045 // First try to assign a LogicalBuffer to one of its operand allocations to
1046 // improve locality. This is only possible with elementwise operations
1047 // (checked in liveness analysis) which are necessarily top-level
1048 // array-shaped buffers.
1049 if (buffer->IsTopLevel() && !buffer->IsTuple()) {
1050 for (auto* operand : instruction->operands()) {
1051 bool assigned_operand = false;
1052 for (const auto& operand_slice :
1053 assignment->GetAllSlices(operand, /*index=*/{})) {
1054 BufferAllocation* allocation =
1055 assignment->GetMutableAllocation(operand_slice.index());
1056 if (!colocated_allocations.contains(allocation->index())) {
1057 // TODO(b/32491382) Colocated buffers are currently assigned in an
1058 // earlier pass, and so can break the "increasing allocation size"
1059 // invariant in this function (causing this CHECK to fail). However,
1060 // the call to MaybeAssignBuffer is safe as it returns false if
1061 // allocation.size < buffer.size.
1062 CHECK_GE(allocation->size(), buffer_size);
1063 }
1064 if (MaybeAssignBuffer(allocation, *buffer, assignment)) {
1065 VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
1066 << " for: " << *buffer;
1067 assigned_operand = true;
1068 break;
1069 }
1070 }
1071 if (assigned_operand) {
1072 break;
1073 }
1074 }
1075 }
1076
1077 if (!assignment->HasAllocation(*buffer)) {
1078 // Find the smallest buffer which can be reused iterating from end of
1079 // allocation_indices (smallest) to beginning (largest).
1080 for (int allocation_index = allocation_indices.size() - 1;
1081 allocation_index >= 0; allocation_index--) {
1082 BufferAllocation* allocation = assignment->GetMutableAllocation(
1083 allocation_indices[allocation_index]);
1084 // Instructions are iterated in increasing buffer size, so any
1085 // previously create allocation must be large enough to hold this
1086 // instruction's output (with the exception of colocated buffers).
1087 if (!colocated_allocations.contains(allocation->index())) {
1088 // TODO(b/32491382) Colocated buffers are currently assigned in an
1089 // earlier pass, and so can break the "increasing allocation size"
1090 // invariant in this function (causing this CHECK to fail). However,
1091 // the call to MaybeAssignBuffer is safe as it returns false if
1092 // allocation.size < buffer.size.
1093 CHECK_GE(allocation->size(), buffer_size);
1094 }
1095
1096 if (MaybeAssignBuffer(allocation, *buffer, assignment)) {
1097 VLOG(3) << "Reusing allocation #" << allocation->index()
1098 << " for: " << *buffer;
1099 break;
1100 }
1101 }
1102 }
1103
1104 if (!assignment->HasAllocation(*buffer) && has_sequential_order &&
1105 !liveness.MaybeLiveOut(*buffer)) {
1106 // There is a sequential instruction ordering, so we delay assignment of
1107 // temp buffers until after the loop. We do this right before we decide to
1108 // create a new allocation, to ensure we've exhausted all the buffer
1109 // re-use cases above.
1110 //
1111 // Entry parameters and thread local buffers were already handled earlier
1112 // in this loop iteration. See BufferAllocation::IsPreallocatedTempBuffer
1113 // for the definition of temp buffers.
1114 CHECK(!is_entry_parameter) << *buffer;
1115 CHECK(!is_thread_local) << *buffer;
1116 (*buffers_to_assign_sequentially)[computation].insert(buffer);
1117 VLOG(3) << "Delaying assignment of temp buffer: " << *buffer;
1118 continue;
1119 }
1120
1121 if (!assignment->HasAllocation(*buffer)) {
1122 BufferAllocation* allocation =
1123 assignment->NewAllocation(*buffer, buffer_size);
1124 allocation_indices.push_back(allocation->index());
1125 VLOG(3) << "New allocation #" << allocation->index()
1126 << " for: " << *buffer;
1127 }
1128 }
1129
1130 return Status::OK();
1131 }
1132
1133 flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
1134 LogicalBuffer::Color::Hasher>
SplitBuffersByColor(const flat_hash_set<const LogicalBuffer * > & buffers)1135 BufferAssigner::SplitBuffersByColor(
1136 const flat_hash_set<const LogicalBuffer*>& buffers) {
1137 flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
1138 LogicalBuffer::Color::Hasher>
1139 color_map;
1140 for (auto buffer : buffers) {
1141 color_map[buffer->color()].insert(buffer);
1142 }
1143 return color_map;
1144 }
1145
AssignBuffersWithSequentialOrdering(const flat_hash_map<const HloComputation *,flat_hash_set<const LogicalBuffer * >> & buffers_to_assign_sequentially,bool run_whole_module_heap_simulation,BufferAssignment * assignment)1146 Status BufferAssigner::AssignBuffersWithSequentialOrdering(
1147 const flat_hash_map<const HloComputation*,
1148 flat_hash_set<const LogicalBuffer*>>&
1149 buffers_to_assign_sequentially,
1150 bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
1151 // Run the sequence of instructions through the heap simulator. The heuristic
1152 // that seems to give the best results is lazy-best-fit, with all runs of
1153 // alloc / free calls sorted in decreasing size order.
1154 const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering();
1155
1156 // Returns a heap algorithm that chooses the best result from several
1157 // algorithms.
1158 auto get_heap_algorithm = [&](int64 alignment) {
1159 auto algorithms =
1160 absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
1161 algorithms->push_back(absl::make_unique<DecreasingSizeRunsHeap>(
1162 absl::make_unique<LazyBestFitHeap>(alignment)));
1163 algorithms->push_back(
1164 absl::make_unique<GlobalDecreasingSizeBestFitHeap>(alignment));
1165 return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms));
1166 };
1167
1168 if (run_whole_module_heap_simulation) {
1169 // Run the heap simulation over the whole module. This reduces memory usage,
1170 // since buffers for kCall, kWhile, and kConditional sub-computations are
1171 // only live for the duration of their calling instructions.
1172 VLOG(1) << "Running whole-module heap simulation";
1173 HloSchedule schedule(&assignment->module());
1174 flat_hash_set<const LogicalBuffer*> all_buffers_to_assign;
1175 for (const auto& pair : buffers_to_assign_sequentially) {
1176 const HloComputation* computation = pair.first;
1177 const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
1178 pair.second;
1179 const HloInstructionSequence* instruction_sequence =
1180 hlo_ordering.SequentialOrder(*computation);
1181 CHECK(instruction_sequence != nullptr) << computation->name();
1182 schedule.set_sequence(computation, *instruction_sequence);
1183 all_buffers_to_assign.insert(buffers_to_assign.begin(),
1184 buffers_to_assign.end());
1185 }
1186 auto color_map = SplitBuffersByColor(all_buffers_to_assign);
1187 for (auto& single_colored_set : color_map) {
1188 auto color = single_colored_set.first;
1189 VLOG(2) << "Simulating heap for color " << color;
1190 int64 alignment = assignment->color_alignment_(color);
1191 HeapSimulator::Options options;
1192 options.alloc_constants = allocate_buffers_for_constants_;
1193 BufferValueFlatSet buffer_value_set =
1194 ToBufferValueFlatSet(single_colored_set.second);
1195 options.buffers_to_assign = &buffer_value_set;
1196 TF_ASSIGN_OR_RETURN(
1197 const HeapSimulator::Result result,
1198 HeapSimulator::Run(get_heap_algorithm(alignment),
1199 assignment->module(), schedule,
1200 assignment->points_to_analysis(),
1201 assignment->buffer_size_, options));
1202 AssignBuffersFromHeapSimulator(result, assignment,
1203 single_colored_set.first);
1204 }
1205 } else {
1206 // Run the heap-simulation on a per-computation basis. Buffers for
1207 // sub-computations are assigned disjoint BufferAllocations, assuming the
1208 // worst-case that they may all be live concurrently.
1209 VLOG(1) << "Running per-computation heap simulation";
1210 for (const auto& pair : buffers_to_assign_sequentially) {
1211 const HloComputation* computation = pair.first;
1212 const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
1213 pair.second;
1214 const HloInstructionSequence* instruction_sequence =
1215 hlo_ordering.SequentialOrder(*computation);
1216 CHECK(instruction_sequence != nullptr) << computation->name();
1217 auto color_map = SplitBuffersByColor(buffers_to_assign);
1218 for (auto& single_colored_set : color_map) {
1219 auto color = single_colored_set.first;
1220 VLOG(2) << "Simulating heap for color " << color;
1221 int64 alignment = assignment->color_alignment_(color);
1222 HeapSimulator::Options options;
1223 BufferValueFlatSet buffer_value_set =
1224 ToBufferValueFlatSet(single_colored_set.second);
1225 options.buffers_to_assign = &buffer_value_set;
1226 TF_ASSIGN_OR_RETURN(
1227 const HeapSimulator::Result result,
1228 HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
1229 *instruction_sequence,
1230 assignment->points_to_analysis(),
1231 assignment->buffer_size_, options));
1232 AssignBuffersFromHeapSimulator(result, assignment,
1233 single_colored_set.first);
1234 }
1235 }
1236 }
1237 return Status::OK();
1238 }
1239
1240 namespace {
1241
1242 // Computes and returns the set of logical buffers live at the point of maximal
1243 // liveness in the given heap trace. LogicalBuffers are (stabily) sorted by id.
ComputePeakMemoryLogicalBuffers(const BufferAllocation & allocation,const HeapSimulatorTrace & heap_trace)1244 std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
1245 const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
1246 // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical
1247 // buffers in this allocation.
1248 absl::flat_hash_map<LogicalBuffer::Id, const LogicalBuffer*> id_to_buffer;
1249 absl::flat_hash_map<const LogicalBuffer*, int64> buffer_sizes;
1250 for (const auto& pair : allocation.assigned_buffers()) {
1251 const LogicalBuffer* buffer = pair.first;
1252 const BufferAllocation::OffsetSize& offset_size = pair.second;
1253 id_to_buffer[buffer->id()] = buffer;
1254 buffer_sizes[buffer] = offset_size.size;
1255 }
1256
1257 // Returns how much the given event increases the total size of live
1258 // buffers. Can be negative.
1259 auto memory_delta = [&id_to_buffer, &buffer_sizes](
1260 const HeapSimulatorTrace::Event& event) -> int64 {
1261 const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
1262 const int64 buffer_size = buffer_sizes.at(buffer);
1263 if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
1264 return buffer_size;
1265 } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
1266 // Sharing a buffer does not change the live set size for the purposes of
1267 // the heap simulator. Even though the shared-with buffer may be smaller,
1268 // the entire allocation remains live.
1269 return 0;
1270 } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
1271 return -1 * buffer_size;
1272 }
1273 LOG(FATAL) << "Unknown event kind: " << event.kind();
1274 };
1275
1276 // First compute the size of the maximal live set.
1277 int64 max_live_size = 0;
1278 int64 live_size = 0;
1279 for (const auto& event : heap_trace.events()) {
1280 live_size += memory_delta(event);
1281 if (max_live_size < live_size) {
1282 max_live_size = live_size;
1283 }
1284 }
1285
1286 // Next gather the set of logical buffers live at the earliest point of
1287 // maximal live set size.
1288 absl::flat_hash_set<const LogicalBuffer*> live_buffers;
1289 live_size = 0;
1290 for (const auto& event : heap_trace.events()) {
1291 const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
1292 if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
1293 InsertOrDie(&live_buffers, buffer);
1294 } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
1295 // Nothing to do.
1296 } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
1297 CHECK(ContainsKey(live_buffers, buffer));
1298 live_buffers.erase(buffer);
1299 }
1300
1301 live_size += memory_delta(event);
1302 if (live_size == max_live_size) {
1303 break;
1304 }
1305 }
1306 CHECK_EQ(live_size, max_live_size);
1307
1308 std::vector<const LogicalBuffer*> live_buffers_vector;
1309 live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(),
1310 live_buffers.end());
1311
1312 // Stabily sort the live buffers.
1313 absl::c_sort(live_buffers_vector,
1314 [](const LogicalBuffer* a, const LogicalBuffer* b) {
1315 return a->id() < b->id();
1316 });
1317 return live_buffers_vector;
1318 }
1319
1320 } // namespace
1321
AssignBuffersFromHeapSimulator(const HeapSimulator::Result & result,BufferAssignment * assignment,LogicalBuffer::Color color)1322 void BufferAssigner::AssignBuffersFromHeapSimulator(
1323 const HeapSimulator::Result& result, BufferAssignment* assignment,
1324 LogicalBuffer::Color color) {
1325 if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
1326 assignment->stats_.preallocated_temp_fragmentation_bytes =
1327 result.fragmentation_size;
1328 } else {
1329 assignment->stats_.preallocated_temp_fragmentation_bytes +=
1330 result.fragmentation_size;
1331 }
1332
1333 BufferAllocation* allocation =
1334 assignment->NewEmptyAllocation(result.heap_size, color);
1335 for (const auto& buffer_chunk : result.chunk_map) {
1336 // TODO(lauj) Remove this down_cast after downstream users of
1337 // BufferAllocation::assigned_buffers() are updated to use BufferValue.
1338 const LogicalBuffer& buffer =
1339 *CHECK_NOTNULL(dynamic_cast<const LogicalBuffer*>(buffer_chunk.first));
1340 const HeapSimulator::Chunk& chunk = buffer_chunk.second;
1341 assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size);
1342 }
1343 allocation->peak_buffers_ =
1344 ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);
1345
1346 VLOG(1) << "Ran heap simulation for allocation: " << allocation->ToString();
1347 allocation->AddHeapTrace(result.debug_trace);
1348 }
1349
1350 // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
1351 // the invariant that all sets in 'colocated_buffer_sets' are disjoint.
1352 //
1353 // A practical example of when this is necessary is a chain of kCall ops:
1354 // computation.entry
1355 // %a = call() -> computation.1
1356 // computation.1
1357 // %b = call() -> computation.2
1358 // computation.2
1359 // %c = parameter()
1360 // This yields the logical sets {%a,%b} {%b,%c} {%c}, which need to be merged
1361 // into a single set {%a,%b,%c}
AddSetToColocatedBufferSets(const std::vector<const LogicalBuffer * > & colocated_set,std::vector<ColocatedBufferSet> * colocated_buffer_sets)1362 void BufferAssigner::AddSetToColocatedBufferSets(
1363 const std::vector<const LogicalBuffer*>& colocated_set,
1364 std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
1365 if (colocated_set.empty()) {
1366 return;
1367 }
1368 VLOG(5) << ColocatedBufferSetsToString(colocated_set,
1369 "Adding colocated buffer set");
1370 // Find existing sets that overlap with at least one buffer from the
1371 // colocated_set. The resulting 'overlap_set_indices' will have at most
1372 // colocated_buffer_sets->size() entries, and will be in increasing order.
1373 std::vector<size_t> overlap_set_indices;
1374 for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) {
1375 for (const LogicalBuffer* buffer : colocated_set) {
1376 if ((*colocated_buffer_sets)[index].contains(buffer)) {
1377 VLOG(5) << "Found overlap with existing set on buffer "
1378 << buffer->ToString() << "\n"
1379 << ColocatedBufferSetsToString((*colocated_buffer_sets)[index],
1380 "Overlapping set");
1381 overlap_set_indices.push_back(index);
1382 break;
1383 }
1384 }
1385 }
1386
1387 // If there is no overlap with existing sets, create a new set.
1388 if (overlap_set_indices.empty()) {
1389 colocated_buffer_sets->emplace_back();
1390 colocated_buffer_sets->back().insert(colocated_set.begin(),
1391 colocated_set.end());
1392 VLOG(5) << "No overlap found, new group created";
1393 return;
1394 }
1395
1396 // Merge all overlap sets and the colocated set into the first overlap set.
1397 ColocatedBufferSet* first = &(*colocated_buffer_sets)[overlap_set_indices[0]];
1398 for (size_t index = 1; index < overlap_set_indices.size(); ++index) {
1399 const ColocatedBufferSet& overlap_set =
1400 (*colocated_buffer_sets)[overlap_set_indices[index]];
1401 first->insert(overlap_set.begin(), overlap_set.end());
1402 }
1403 first->insert(colocated_set.begin(), colocated_set.end());
1404 VLOG(5) << ColocatedBufferSetsToString(
1405 *first, "Result of the colocated buffer set merging");
1406
1407 // Remove overlap sets that we just merged. The offset accounts for the fact
1408 // that as elements are erased, the indices need to be adjusted. Keep in mind
1409 // that overlap_set_indices is in increasing order.
1410 for (size_t index = 1; index < overlap_set_indices.size(); ++index) {
1411 const size_t offset = overlap_set_indices[index] - index + 1;
1412 colocated_buffer_sets->erase(colocated_buffer_sets->begin() + offset);
1413 }
1414 }
1415
1416 std::vector<BufferAssigner::ColocatedBufferSet>
MergeColocatedBufferSets(const std::vector<ColocatedBufferSet> & colocated_buffer_sets,const BufferLiveness & buffer_liveness,const LogicalBuffer::SizeFunction & buffer_size)1417 BufferAssigner::MergeColocatedBufferSets(
1418 const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
1419 const BufferLiveness& buffer_liveness,
1420 const LogicalBuffer::SizeFunction& buffer_size) {
1421 VLOG(1) << "colocation sets count before coalescing:"
1422 << colocated_buffer_sets.size();
1423
1424 // Returns true if the given buffer is for the entry parameter.
1425 auto is_readonly_entry_parameter = [](const LogicalBuffer& buffer) {
1426 auto* instruction = buffer.instruction();
1427 auto* computation = instruction->parent();
1428 auto* module = computation->parent();
1429 return instruction->opcode() == HloOpcode::kParameter &&
1430 computation == module->entry_computation() &&
1431 !module->input_output_alias_config().ParameterHasAlias(
1432 instruction->parameter_number(), buffer.index());
1433 };
1434
1435 std::vector<bool> set_can_be_merged(colocated_buffer_sets.size(), true);
1436
1437 // Do not merge if one of the sets includes live outs, entry parameters or
1438 // constants.
1439 //
1440 // Buffer liveness does not report the correct live range for entry
1441 // parameter and live out buffers so we have to special case them here. On
1442 // backends that support constant buffer allocations, constant buffers are
1443 // assigned globals in readonly storage so we can't merge colocated buffer
1444 // sets containing constants with colocated buffer sets containing writing
1445 // instructions or other constants.
1446 //
1447 // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to
1448 // the caller of the executable so we can't write to entry parameters
1449 // either, and the argument for not merging constants also applies to entry
1450 // parameters.
1451 for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
1452 for (auto& buffer : colocated_buffer_sets[i]) {
1453 if (buffer_liveness.MaybeLiveOut(*buffer) ||
1454 is_readonly_entry_parameter(*buffer) ||
1455 buffer->instruction()->opcode() == HloOpcode::kConstant) {
1456 set_can_be_merged[i] = false;
1457 break;
1458 }
1459 }
1460 }
1461
1462 // Returns true if the two colocated buffer sets (specified by their indices
1463 // into the colocated_buffer_sets) can be merged into a single set.
1464 auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness,
1465 &buffer_size,
1466 &set_can_be_merged](int64 i, int64 j) {
1467 if (!set_can_be_merged[i] || !set_can_be_merged[j]) {
1468 return true;
1469 }
1470
1471 // Colocated sets satisfy the invariant that all buffers within a set have
1472 // the same size. That means we need to check whether the size is the same
1473 // between the two sets, but also that it's enough to look at just one
1474 // buffer within each set.
1475 if (buffer_size(**colocated_buffer_sets[i].begin()) !=
1476 buffer_size(**colocated_buffer_sets[j].begin())) {
1477 return true;
1478 }
1479
1480 // Do not merge if some pair of buffers interferes with each other.
1481 for (auto& buffer_a : colocated_buffer_sets[i]) {
1482 for (auto& buffer_b : colocated_buffer_sets[j]) {
1483 if (buffer_a->id() != buffer_b->id() &&
1484 buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) {
1485 return true;
1486 }
1487 }
1488 }
1489
1490 return false;
1491 };
1492
1493 // Build the interference map among the colocated buffer sets (nodes), by
1494 // adding an edge between any two nodes that cannot be merged into a single
1495 // colocated buffer set.
1496 std::vector<std::vector<int64>> interference_map(
1497 colocated_buffer_sets.size());
1498 for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
1499 for (int64 j = i + 1; j < colocated_buffer_sets.size(); ++j) {
1500 if (cannot_merge_buffer_sets(i, j)) {
1501 interference_map[i].push_back(j);
1502 interference_map[j].push_back(i);
1503 }
1504 }
1505 }
1506
1507 // Assign a color to each colocation set in colocated_buffer_sets, such that
1508 // the sets that can be merged are assigned with the same color.
1509 auto assigned_colors = ColorInterferenceGraph(interference_map);
1510
1511 // Merge the buffer sets with the same color.
1512 CHECK(!assigned_colors.empty());
1513 int64 num_sets =
1514 *std::max_element(assigned_colors.begin(), assigned_colors.end()) + 1;
1515 std::vector<ColocatedBufferSet> new_colocated_buffer_sets(num_sets);
1516 for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
1517 const auto& buffer_set = colocated_buffer_sets[i];
1518 new_colocated_buffer_sets[assigned_colors[i]].insert(buffer_set.begin(),
1519 buffer_set.end());
1520 }
1521
1522 VLOG(1) << "colocation sets count after coalescing:"
1523 << colocated_buffer_sets.size();
1524 return new_colocated_buffer_sets;
1525 }
1526
1527 // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
1528 // in the same allocation (currently just supports kWhile, kCall, and
1529 // kConditional and input output aliasing).
BuildColocatedBufferSets(const HloModule * module,const BufferLiveness & buffer_liveness,const LogicalBuffer::SizeFunction & buffer_size,std::vector<ColocatedBufferSet> * colocated_buffer_sets)1530 void BufferAssigner::BuildColocatedBufferSets(
1531 const HloModule* module, const BufferLiveness& buffer_liveness,
1532 const LogicalBuffer::SizeFunction& buffer_size,
1533 std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
1534 const TuplePointsToAnalysis& points_to_analysis =
1535 buffer_liveness.points_to_analysis();
1536
1537 // Set up colocated buffer set for input and output.
1538 VLOG(4) << "Input/Output Alias Config: ";
1539 VLOG(4) << module->input_output_alias_config();
1540 module->input_output_alias_config().ForEachAlias(
1541 [&](const ShapeIndex& output_index,
1542 const HloInputOutputAliasConfig::Alias& alias) {
1543 std::vector<const LogicalBuffer*> colocated_set;
1544 AddBufferToColocatedSet(module->entry_computation()->root_instruction(),
1545 output_index, points_to_analysis,
1546 &colocated_set);
1547 AddBufferToColocatedSet(
1548 module->entry_computation()->parameter_instruction(
1549 alias.parameter_number),
1550 alias.parameter_index, points_to_analysis, &colocated_set);
1551 AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
1552 });
1553
1554 for (const HloComputation* computation : module->MakeComputationPostOrder()) {
1555 if (computation->IsFusionComputation()) {
1556 continue;
1557 }
1558 for (const HloInstruction* instruction :
1559 computation->MakeInstructionPostOrder()) {
1560 const HloOpcode opcode = instruction->opcode();
1561 if (opcode == HloOpcode::kWhile) {
1562 const HloInstruction* while_hlo = instruction;
1563 ShapeUtil::ForEachSubshape(
1564 while_hlo->shape(),
1565 [this, while_hlo, &points_to_analysis, buffer_size,
1566 colocated_buffer_sets](const Shape& /*subshape*/,
1567 const ShapeIndex& index) {
1568 std::vector<const LogicalBuffer*> colocated_set;
1569 // Add while.init.
1570 AddBufferToColocatedSet(while_hlo->operand(0), index,
1571 points_to_analysis, &colocated_set);
1572 // Add while.result.
1573 AddBufferToColocatedSet(while_hlo, index, points_to_analysis,
1574 &colocated_set);
1575 // Add while.cond.parameter.
1576 AddBufferToColocatedSet(
1577 while_hlo->while_condition()->parameter_instruction(0), index,
1578 points_to_analysis, &colocated_set);
1579 // Add while.body.parameter.
1580 AddBufferToColocatedSet(
1581 while_hlo->while_body()->parameter_instruction(0), index,
1582 points_to_analysis, &colocated_set);
1583 // Add while.body.root.
1584 AddBufferToColocatedSet(
1585 while_hlo->while_body()->root_instruction(), index,
1586 points_to_analysis, &colocated_set);
1587 AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
1588 });
1589 } else if (opcode == HloOpcode::kCall) {
1590 const HloInstruction* call_hlo = instruction;
1591 const HloComputation* callee = call_hlo->to_apply();
1592 const HloInstruction* root_hlo = callee->root_instruction();
1593 for (int64 i = 0; i < call_hlo->operand_count(); i++) {
1594 const HloInstruction* call_param = callee->parameter_instruction(i);
1595 const HloInstruction* call_operand = call_hlo->operand(i);
1596 ShapeUtil::ForEachSubshape(
1597 call_operand->shape(),
1598 [&](const Shape& /*subshape*/, const ShapeIndex& index) {
1599 std::vector<const LogicalBuffer*> colocated_set;
1600 AddBufferToColocatedSet(call_param, index, points_to_analysis,
1601 &colocated_set);
1602 AddBufferToColocatedSet(call_operand, index, points_to_analysis,
1603 &colocated_set);
1604 AddSetToColocatedBufferSets(colocated_set,
1605 colocated_buffer_sets);
1606 });
1607 }
1608 ShapeUtil::ForEachSubshape(
1609 call_hlo->shape(),
1610 [this, call_hlo, root_hlo, &points_to_analysis,
1611 colocated_buffer_sets](const Shape& /*subshape*/,
1612 const ShapeIndex& index) {
1613 std::vector<const LogicalBuffer*> colocated_set;
1614 // Add call.result.
1615 AddBufferToColocatedSet(call_hlo, index, points_to_analysis,
1616 &colocated_set);
1617 // Add call.subcomputation.root.
1618 AddBufferToColocatedSet(root_hlo, index, points_to_analysis,
1619 &colocated_set);
1620 AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
1621 });
1622 } else if (opcode == HloOpcode::kConditional) {
1623 const HloInstruction* conditional = instruction;
1624 ShapeUtil::ForEachSubshape(
1625 conditional->shape(),
1626 [this, conditional, &points_to_analysis, colocated_buffer_sets](
1627 const Shape& /*subshape*/, const ShapeIndex& index) {
1628 std::vector<const LogicalBuffer*> colocated_set;
1629 // Add cond.result.
1630 AddBufferToColocatedSet(conditional, index, points_to_analysis,
1631 &colocated_set);
1632 for (int j = 0; j < conditional->branch_count(); ++j) {
1633 // Add each cond.branch_computation[j].root.
1634 AddBufferToColocatedSet(
1635 conditional->branch_computation(j)->root_instruction(),
1636 index, points_to_analysis, &colocated_set);
1637 }
1638 AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
1639 });
1640
1641 for (int j = 0; j < conditional->branch_count(); ++j) {
1642 // Add branch_operand[j] (which is operand[j+1]) and
1643 // cond.branch_computation[j].parameter(0) as a colocated
1644 // buffer set. Note that this has to be done for each subshape in the
1645 // branch_operand of the case.
1646 ShapeUtil::ForEachSubshape(
1647 conditional->operand(j + 1)->shape(),
1648 [this, j, conditional, &points_to_analysis,
1649 colocated_buffer_sets](const Shape& /*subshape*/,
1650 const ShapeIndex& index) {
1651 std::vector<const LogicalBuffer*> branch_set;
1652 // Add cond.operand[j+1].
1653 AddBufferToColocatedSet(conditional->operand(j + 1), index,
1654 points_to_analysis, &branch_set);
1655 // Add cond.branch_computation[j].parameter_instruction(0).
1656 AddBufferToColocatedSet(
1657 conditional->branch_computation(j)->parameter_instruction(
1658 0),
1659 index, points_to_analysis, &branch_set);
1660 AddSetToColocatedBufferSets(branch_set, colocated_buffer_sets);
1661 });
1662 }
1663 }
1664 }
1665 }
1666
1667 if (colocated_buffer_sets->empty()) {
1668 return;
1669 }
1670
1671 int64 i = 0;
1672 for (const auto& colocated_set : *colocated_buffer_sets) {
1673 VLOG(4) << "Colocated set " << i++ << ":";
1674 for (const auto& buffer : colocated_set) {
1675 VLOG(4) << " " << buffer->ToString();
1676 }
1677 }
1678 // Try to find more coalescing opportunities among the colocated buffer sets.
1679 //
1680 // TODO(b/32491382): We should be able to remove this by using the
1681 // module-level liveness analysis, which would let us directly detect buffer
1682 // sharing opportunities between the while instruction buffer and the buffers
1683 // from the predicate and body computation, as well as sharing across
1684 // different while instructions.
1685 std::vector<ColocatedBufferSet> new_colocated_buffer_sets =
1686 MergeColocatedBufferSets(*colocated_buffer_sets, buffer_liveness,
1687 buffer_size);
1688 std::swap(*colocated_buffer_sets, new_colocated_buffer_sets);
1689 }
1690
1691 // Assigns all colocated buffer sets in 'colocated_buffer_sets' to the same
1692 // allocation in 'assignment'.
AssignColocatedBufferSets(const std::vector<ColocatedBufferSet> & colocated_buffer_sets,BufferAssignment * assignment,flat_hash_set<const LogicalBuffer * > * colocated_buffers,flat_hash_set<BufferAllocation::Index> * colocated_allocations)1693 void BufferAssigner::AssignColocatedBufferSets(
1694 const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
1695 BufferAssignment* assignment,
1696 flat_hash_set<const LogicalBuffer*>* colocated_buffers,
1697 flat_hash_set<BufferAllocation::Index>* colocated_allocations) {
1698 for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) {
1699 BufferAllocation* allocation = nullptr;
1700 // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry
1701 // param in 'colocated_buffer_set'.
1702 int64 entry_parameter_number = -1;
1703 const ShapeIndex* entry_parameter_shape_idx = nullptr;
1704 bool is_constant = false;
1705 for (const LogicalBuffer* buffer : colocated_buffer_set) {
1706 const HloInstruction* instruction = buffer->instruction();
1707 const HloComputation* computation = instruction->parent();
1708 if (instruction->opcode() == HloOpcode::kParameter &&
1709 computation == computation->parent()->entry_computation()) {
1710 entry_parameter_number = instruction->parameter_number();
1711 entry_parameter_shape_idx = &buffer->index();
1712 } else if (instruction->opcode() == HloOpcode::kConstant) {
1713 is_constant = true;
1714 }
1715 }
1716
1717 CHECK(!is_constant || entry_parameter_number == -1)
1718 << "Copy insertion should have inserted copies to prevent this.";
1719
1720 for (const LogicalBuffer* buffer : colocated_buffer_set) {
1721 const int64 buffer_size = assignment->buffer_size_(*buffer);
1722 if (allocation == nullptr) {
1723 // TODO(b/32491382) Avoid current trivial solution of using new
1724 // allocations for each colocated buffer set. When liveness has
1725 // module-level scope, we can allow buffers to be shared across
1726 // computations (in some cases).
1727 allocation = assignment->NewAllocation(*buffer, buffer_size);
1728 if (is_constant) {
1729 allocation->set_constant(true);
1730 }
1731 colocated_allocations->insert(allocation->index());
1732 } else {
1733 CHECK_EQ(buffer_size, allocation->size())
1734 << "Buffer: " << *buffer << " size mismatch in colocated buffer "
1735 << "allocation: " << *allocation;
1736 assignment->AddAssignment(allocation, *buffer, /*offset=*/0,
1737 buffer_size);
1738 }
1739 colocated_buffers->insert(buffer);
1740 }
1741
1742 // If an allocation contains a parameter, set corresponding fields.
1743 if (entry_parameter_number >= 0) {
1744 bool parameter_has_alias =
1745 assignment->module().input_output_alias_config().ParameterHasAlias(
1746 entry_parameter_number, *entry_parameter_shape_idx);
1747 allocation->set_entry_computation_parameter(entry_parameter_number,
1748 *entry_parameter_shape_idx,
1749 parameter_has_alias);
1750 }
1751 }
1752 }
1753
CreateAssignment(const HloModule * module,std::unique_ptr<HloOrdering> hlo_ordering,LogicalBuffer::SizeFunction buffer_size,LogicalBuffer::AlignmentFunction color_alignment)1754 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
1755 const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
1756 LogicalBuffer::SizeFunction buffer_size,
1757 LogicalBuffer::AlignmentFunction color_alignment) {
1758 TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness,
1759 BufferLiveness::Run(module, std::move(hlo_ordering)));
1760
1761 VLOG(1) << "Assigning buffers to module " << module->name();
1762 XLA_VLOG_LINES(2, module->ToString());
1763 XLA_VLOG_LINES(3, liveness->ToString());
1764 XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString());
1765
1766 // Can't use absl::make_unique because BufferAssignment constructor is
1767 // private.
1768 std::unique_ptr<BufferAssignment> assignment(
1769 new BufferAssignment(module, std::move(liveness), std::move(buffer_size),
1770 std::move(color_alignment)));
1771
1772 // Assign buffers with the tightest constraints first (colocated buffer sets).
1773 // Once b/32491382 enables module-level liveness analysis, we may be able
1774 // to assign colocated buffers (or at least reuse their allocation for
1775 // buffers outside of the set) in AssignBuffersForComputation.
1776 flat_hash_set<const LogicalBuffer*> colocated_buffers;
1777 flat_hash_set<BufferAllocation::Index> colocated_allocations;
1778 std::vector<ColocatedBufferSet> colocated_buffer_sets;
1779 BuildColocatedBufferSets(module, assignment->liveness(),
1780 assignment->buffer_size_, &colocated_buffer_sets);
1781 TF_RETURN_IF_ERROR(colorer_(assignment->liveness()));
1782 VLOG(3) << "After coloring:";
1783 XLA_VLOG_LINES(3, assignment->points_to_analysis().ToString());
1784
1785 AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(),
1786 &colocated_buffers, &colocated_allocations);
1787
1788 std::vector<const HloComputation*> thread_local_computations;
1789 std::vector<const HloComputation*> global_computations;
1790 TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
1791 module, &thread_local_computations, &global_computations));
1792
1793 // First assign buffers for global computatations. Temporary buffers for
1794 // sequential computations are collected in 'buffers_to_assign_sequentially'.
1795 flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>
1796 buffers_to_assign_sequentially;
1797 for (auto* computation : global_computations) {
1798 TF_RETURN_IF_ERROR(AssignBuffersForComputation(
1799 computation,
1800 /*is_thread_local=*/false, colocated_buffers, colocated_allocations,
1801 &buffers_to_assign_sequentially, assignment.get()));
1802 }
1803 // Assign buffers with sequential ordering, if any. If all global computations
1804 // are sequential, we can run heap simuation on the whole module, which
1805 // reduces memory usage.
1806 const bool run_whole_module_heap_simulation =
1807 buffers_to_assign_sequentially.size() == global_computations.size();
1808 TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
1809 buffers_to_assign_sequentially, run_whole_module_heap_simulation,
1810 assignment.get()));
1811
1812 // Now assign buffers for thread-local computations. All LogicalBuffers get
1813 // their own BufferAllocation.
1814 for (auto* computation : thread_local_computations) {
1815 TF_RET_CHECK(computation != module->entry_computation());
1816 if (computation->IsFusionComputation()) {
1817 continue;
1818 }
1819 TF_RETURN_IF_ERROR(AssignBuffersForComputation(
1820 computation,
1821 /*is_thread_local=*/true, colocated_buffers, colocated_allocations,
1822 /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));
1823 }
1824
1825 // Mark all buffers which may be live out of the entry computation as
1826 // "liveout".
1827 for (const LogicalBuffer* buffer :
1828 assignment->liveness().maybe_live_out_buffers()) {
1829 VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
1830 if (assignment->HasAllocation(*buffer)) {
1831 BufferAllocation* alloc =
1832 assignment->GetMutableAssignedAllocation(*buffer);
1833 alloc->set_maybe_live_out(true);
1834 VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
1835 }
1836 }
1837
1838 // Combines allocations of temporary buffers into one big BufferAllocation.
1839 // This can only be performed after all buffers have been assigned, and after
1840 // maybe_live_out is marked, since it is used to determine whether an
1841 // allocation contains temporary buffers or not.
1842 assignment->CombineTempAllocations();
1843
1844 XLA_VLOG_LINES(2, assignment->ToString());
1845 TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
1846 XLA_VLOG_LINES(1, assignment->GetStats().ToString());
1847 return std::move(assignment);
1848 }
1849
1850 } // namespace xla
1851