1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_value.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36
37 namespace xla {
38
39 using absl::StrAppend;
40
41 // Data structure used to construct the alias analysis. Thrown away after alias
42 // analysis is complete. This data structure keeps track of which sets of
43 // HloValues must be in the same HloBuffer. This is maintained as a map from a
44 // buffer identifier (BufferNumber) to set of HLoValues.
45 //
46 // Initially each value is its own buffer. In MergeAliasedBuffers, sets of
47 // values which must share the same buffer are merged together. The end result
48 // is a partitioning of all HloValues into sets where each set needs its own
49 // HloBuffer. By performing this analysis without constructing HloBuffers on the
50 // fly, we can after-the-fact construct a vector of contiguously numbered
51 // HloBuffers after the buffer requirement has been determined.
52 class BufferValueMap {
53 public:
54 // A unique identifier for a set of colocated values which must share the same
55 // buffer. This is not necessarily the same as the HloBuffer::Id which will
56 // ultimately contain the values. The reason is that HloBuffer::Id's are
57 // contiguous, while BufferNumbers may not be. BufferNumbers may not be
58 // dense because buffers may be created and destroyed during the analysis
59 // construction process.
60 using BufferNumber = int64;
61
BufferValueMap(const HloModule * module,const HloDataflowAnalysis & dataflow)62 explicit BufferValueMap(const HloModule* module,
63 const HloDataflowAnalysis& dataflow)
64 : module_(module), dataflow_(dataflow) {
65 buffers_.reserve(dataflow_.values().size());
66 value_to_buffer_number_.reserve(dataflow_.values().size());
67 for (const HloValue* value : dataflow_.values()) {
68 BufferNumber buffer_number = next_buffer_number_++;
69 buffers_[buffer_number].insert(value);
70 value_to_buffer_number_[value] = buffer_number;
71 }
72 }
73
74 // Merge together sets of HloValues which must be in the same HloBuffer
75 // because of aliasing rules (eg, in-place kWhile instruction).
MergeAliasedBuffers()76 void MergeAliasedBuffers() {
77 for (const HloValue* value : dataflow_.values()) {
78 VLOG(3) << "Merging colocated values, value: " << value->ToShortString();
79
80 // Gather the set of buffers with aliasing rules (eg, kWhile) which this
81 // value must be contained in.
82 std::vector<BufferNumber> aliased_buffers = ComputeAliasedBuffers(*value);
83
84 BufferNumber current_buffer = value_to_buffer_number_.at(value);
85 if (aliased_buffers.empty()) {
86 // The buffer containing 'value' aliases no other buffers. If the buffer
87 // containing 'value' already only contains 'value', then no change is
88 // necessary. If the buffer containing 'value' does contain other
89 // values, then remove 'value' from the buffer and create a new buffer
90 // containing only 'value'
91 if (buffers_.at(current_buffer).size() == 1) {
92 CHECK_EQ(*buffers_.at(current_buffer).begin(), value);
93 } else {
94 MoveValueToNewBuffer(*value);
95 }
96 } else {
97 // If multiple buffers are aliased merge these buffers together into a
98 // single buffer (arbitrarily chosen as the first buffer in the vector).
99 if (aliased_buffers.size() > 1) {
100 for (int64 i = 1; i < aliased_buffers.size(); ++i) {
101 MergeBuffers(/*from=*/aliased_buffers[i],
102 /*to=*/aliased_buffers[0]);
103 }
104 }
105 BufferNumber new_buffer = aliased_buffers[0];
106 if (current_buffer != new_buffer) {
107 MoveValueToBuffer(*value, new_buffer);
108 }
109 }
110 }
111 }
112
113 // Compute and return a sorted vector of all BufferNumbers. Can be used to
114 // iterate through all buffers stabily.
ComputeSortedBufferNumbers() const115 std::vector<BufferNumber> ComputeSortedBufferNumbers() const {
116 std::vector<BufferNumber> buffer_numbers;
117 for (const auto& pair : buffers_) {
118 buffer_numbers.push_back(pair.first);
119 }
120 absl::c_sort(buffer_numbers);
121 return buffer_numbers;
122 }
123
124 // Return a set of all the values in the given buffer.
GetValuesInBuffer(BufferNumber buffer_number) const125 const absl::flat_hash_set<const HloValue*>& GetValuesInBuffer(
126 BufferNumber buffer_number) const {
127 return buffers_.at(buffer_number);
128 }
129
130 private:
131 // Create a new buffer.
NewBuffer(const HloValue & value)132 void NewBuffer(const HloValue& value) {
133 BufferNumber buffer_number = next_buffer_number_++;
134 buffers_[buffer_number].insert(&value);
135 value_to_buffer_number_[&value] = buffer_number;
136 }
137
138 // Move the given value into a new buffer containing only the value.
MoveValueToNewBuffer(const HloValue & value)139 void MoveValueToNewBuffer(const HloValue& value) {
140 BufferNumber new_buffer_number = next_buffer_number_++;
141 buffers_[new_buffer_number];
142 MoveValueToBuffer(value, new_buffer_number);
143 }
144
145 // Move the given value into the given buffer.
MoveValueToBuffer(const HloValue & value,BufferNumber buffer_number)146 void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
147 BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
148 absl::flat_hash_set<const HloValue*>& old_value_set =
149 buffers_.at(old_buffer_number);
150 old_value_set.erase(&value);
151 if (old_value_set.empty()) {
152 buffers_.erase(old_buffer_number);
153 }
154
155 buffers_.at(buffer_number).insert(&value);
156 value_to_buffer_number_.at(&value) = buffer_number;
157 }
158
159 // Merge the buffer 'from' into the buffer 'to'.
MergeBuffers(BufferNumber from,BufferNumber to)160 void MergeBuffers(BufferNumber from, BufferNumber to) {
161 auto& from_value_set = buffers_.at(from);
162 buffers_.at(to).insert(from_value_set.begin(), from_value_set.end());
163 // NOTE: using a union-find algorithm to hold the colocated values might be
164 // faster.
165 for (const HloValue* value : from_value_set) {
166 value_to_buffer_number_.at(value) = to;
167 }
168 buffers_.erase(from);
169 }
170
GetBufferForValue(const HloValue & value)171 BufferNumber GetBufferForValue(const HloValue& value) {
172 return value_to_buffer_number_.at(&value);
173 }
174
ComputeInputOutputAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)175 void ComputeInputOutputAliasedBuffers(
176 const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
177 // Get parameter value from an aliased_input object.
178 const auto get_parameter_value =
179 [this](const HloInputOutputAliasConfig::Alias& aliased_input)
180 -> const HloValue& {
181 return dataflow_.GetUniqueValueAt(
182 module_->entry_computation()->parameter_instruction(
183 aliased_input.parameter_number),
184 aliased_input.parameter_index);
185 };
186
187 // If the value shows up in a root instruction, alias it with parameter
188 // instruction.
189 for (const HloPosition& pos : value.positions()) {
190 if (pos.instruction == module_->entry_computation()->root_instruction()) {
191 ShapeIndex output_index = pos.index;
192
193 auto aliased_input =
194 module_->input_output_alias_config().GetAliasedParameter(
195 output_index);
196 if (aliased_input) {
197 aliased_buffers->push_back(
198 GetBufferForValue(get_parameter_value(*aliased_input)));
199 }
200 }
201 }
202
203 // If the value is parameter instruction itself, alias it with itself.
204 if (value.instruction()->opcode() == HloOpcode::kParameter &&
205 value.instruction()->parent() == module_->entry_computation()) {
206 aliased_buffers->push_back(GetBufferForValue(value));
207 }
208 }
209
ComputeWhileAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)210 void ComputeWhileAliasedBuffers(const HloValue& value,
211 std::vector<BufferNumber>* aliased_buffers) {
212 VLOG(3) << "Compute kWhile aliases";
213 // Value is init of a while (use is while).
214 for (const HloUse& use : value.uses()) {
215 if (use.instruction->opcode() == HloOpcode::kWhile) {
216 // Determine the while value that this shares a buffer with.
217 const HloValue& while_value =
218 dataflow_.GetUniqueValueAt(use.instruction, use.operand_index);
219 aliased_buffers->push_back(GetBufferForValue(while_value));
220 VLOG(3) << " value is init value to a while; must share buffer with "
221 "while value "
222 << while_value.ToShortString();
223 }
224 }
225 // Value is a parameter of a while body/condition.
226 if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
227 const HloComputation* computation =
228 value.defining_instruction()->parent();
229 const CallGraphNode& call_graph_node =
230 dataflow_.call_graph().GetNode(computation);
231 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
232 if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
233 // Call graph must have been flattened.
234 CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
235
236 const HloValue& while_value = dataflow_.GetUniqueValueAt(
237 callsite.instruction(), value.defining_index());
238 VLOG(3) << " value is parameter value of the body or condition of a "
239 "while; must share buffer with while value "
240 << while_value.ToShortString();
241 aliased_buffers->push_back(GetBufferForValue(while_value));
242 }
243 }
244 }
245 // Value is the root of a while body.
246 for (const HloPosition& position : value.positions()) {
247 const HloComputation* computation = position.instruction->parent();
248 const CallGraphNode& call_graph_node =
249 dataflow_.call_graph().GetNode(computation);
250 if (position.instruction == computation->root_instruction()) {
251 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
252 if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
253 callsite.instruction()->while_body() == computation) {
254 // Call graph must have been flattened.
255 CHECK_EQ(call_graph_node.caller_callsites().size(), 1)
256 << "Call graph must have been flattened.";
257
258 const HloValue& while_value = dataflow_.GetUniqueValueAt(
259 callsite.instruction(), position.index);
260 VLOG(3) << " value @ " << position << " is root of "
261 << callsite.instruction()->name()
262 << "; body root and while value root must share buffer "
263 "among them : "
264 << while_value.ToShortString();
265 aliased_buffers->push_back(GetBufferForValue(while_value));
266 }
267 }
268 }
269 }
270 // Value is the output of the while instruction itself.
271 if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
272 VLOG(3) << " value is output of a while instruction";
273 aliased_buffers->push_back(GetBufferForValue(value));
274 }
275 }
276
ComputeConditionalAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)277 void ComputeConditionalAliasedBuffers(
278 const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
279 VLOG(3) << "Compute kConditional aliases";
280 // Aliases the buffers of the true/false computations roots, with the one of
281 // the conditional.
282 for (const HloPosition& position : value.positions()) {
283 const HloComputation* computation = position.instruction->parent();
284 const CallGraphNode& call_graph_node =
285 dataflow_.call_graph().GetNode(computation);
286 if (position.instruction == computation->root_instruction()) {
287 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
288 if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
289 // Call graph must have been flattened.
290 CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
291
292 const HloValue& cond_value = dataflow_.GetUniqueValueAt(
293 callsite.instruction(), position.index);
294 VLOG(3)
295 << " value @ " << position << " is root of "
296 << callsite.instruction()->name()
297 << "; branch computation roots must share buffer among them : "
298 << cond_value.ToShortString();
299 aliased_buffers->push_back(GetBufferForValue(cond_value));
300 }
301 }
302 }
303 }
304 // Value is the output of the conditional instruction itself.
305 if (value.defining_instruction()->opcode() == HloOpcode::kConditional) {
306 VLOG(3) << " value is output of a conditional instruction";
307 aliased_buffers->push_back(GetBufferForValue(value));
308 }
309 }
310
ComputeInPlaceOperationAliasedBuffers(const HloValue & value,std::vector<BufferNumber> * aliased_buffers)311 void ComputeInPlaceOperationAliasedBuffers(
312 const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
313 VLOG(3) << "Compute aliases for in-place operations (e.g. "
314 "kDynamicUpdateSlice and kScatter)";
315 for (const HloPosition& position : value.positions()) {
316 HloInstruction* instruction = position.instruction;
317 for (const auto& operand_and_output_index :
318 HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
319 if (position.index == operand_and_output_index.second) {
320 const HloUse& operand = operand_and_output_index.first;
321 const HloValue& operand_value = dataflow_.GetUniqueValueAt(
322 instruction->operand(operand.operand_number),
323 operand.operand_index);
324 VLOG(3) << " operand value " << operand_value.ToShortString()
325 << " aliases.";
326 aliased_buffers->push_back(GetBufferForValue(operand_value));
327 }
328 }
329 }
330
331 for (const HloUse& use : value.uses()) {
332 for (const auto& operand_and_output_index :
333 HloDataflowAnalysis::GetInPlaceInputOutputPairs(use.instruction)) {
334 if (use == operand_and_output_index.first) {
335 const HloValue& use_value = dataflow_.GetUniqueValueAt(
336 use.instruction, operand_and_output_index.second);
337 VLOG(3) << " use value " << use_value.ToShortString() << " aliases.";
338 aliased_buffers->push_back(GetBufferForValue(use_value));
339 }
340 }
341 }
342 }
343
344 // Compute and return a vector of buffers that the given value must be
345 // contained in due to HLO aliasing rules.
ComputeAliasedBuffers(const HloValue & value)346 std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
347 for (const HloUse& use : value.uses()) {
348 VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
349 }
350 std::vector<BufferNumber> aliased_buffers;
351 ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
352 ComputeWhileAliasedBuffers(value, &aliased_buffers);
353 ComputeConditionalAliasedBuffers(value, &aliased_buffers);
354 ComputeInPlaceOperationAliasedBuffers(value, &aliased_buffers);
355 // Uniquify aliased buffers.
356 absl::c_sort(aliased_buffers);
357 aliased_buffers.erase(
358 std::unique(aliased_buffers.begin(), aliased_buffers.end()),
359 aliased_buffers.end());
360 return aliased_buffers;
361 }
362
363 const HloModule* module_ = nullptr;
364
365 // Dataflow analysis used to construct the buffer map.
366 const HloDataflowAnalysis& dataflow_;
367
368 // A map containing the set of values contained in each buffer.
369 absl::flat_hash_map<BufferNumber, absl::flat_hash_set<const HloValue*>>
370 buffers_;
371
372 // A map indicating which buffer each value is contained in.
373 absl::flat_hash_map<const HloValue*, BufferNumber> value_to_buffer_number_;
374
375 // The buffer number of the next buffer to be created.
376 BufferNumber next_buffer_number_ = 0;
377 };
378
HloAliasAnalysis(const HloModule * module)379 HloAliasAnalysis::HloAliasAnalysis(const HloModule* module) : module_(module) {}
380
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index) const381 const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
382 const HloInstruction* instruction, const ShapeIndex& index) const {
383 std::vector<const HloBuffer*> buffers = ComputeBuffersAt(instruction, index);
384 CHECK_EQ(buffers.size(), 1);
385 return *buffers[0];
386 }
387
GetUniqueBufferAt(const HloInstruction * instruction,const ShapeIndex & index)388 HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
389 const HloInstruction* instruction, const ShapeIndex& index) {
390 return GetBuffer(static_cast<const HloAliasAnalysis*>(this)
391 ->GetUniqueBufferAt(instruction, index)
392 .id());
393 }
394
ComputeBuffersAt(const HloInstruction * instruction,const ShapeIndex & index) const395 std::vector<const HloBuffer*> HloAliasAnalysis::ComputeBuffersAt(
396 const HloInstruction* instruction, const ShapeIndex& index) const {
397 std::vector<const HloBuffer*> buffers;
398 for (const HloValue* value :
399 dataflow_analysis_->GetValueSet(instruction, index).values()) {
400 buffers.push_back(&GetBufferContainingValue(*value));
401 }
402
403 // Sort and uniquify vector before returning.
404 absl::c_sort(buffers, HloBuffer::IdLessThan);
405 buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end());
406
407 return buffers;
408 }
409
InstructionBuffersAreAmbiguous(const HloInstruction * instruction) const410 bool HloAliasAnalysis::InstructionBuffersAreAmbiguous(
411 const HloInstruction* instruction) const {
412 for (const auto& pair :
413 dataflow_analysis_->GetInstructionValueSet(instruction)) {
414 const HloValueSet& value_set = pair.second;
415 const HloBuffer* buffer = nullptr;
416 for (const HloValue* value : value_set.values()) {
417 if (buffer == nullptr) {
418 buffer = &GetBufferContainingValue(*value);
419 } else if (buffer != &GetBufferContainingValue(*value)) {
420 return true;
421 }
422 }
423 }
424 return false;
425 }
426
InstructionBuffersAreDistinct(const HloInstruction * instruction) const427 bool HloAliasAnalysis::InstructionBuffersAreDistinct(
428 const HloInstruction* instruction) const {
429 absl::flat_hash_set<const HloBuffer*> buffers_seen;
430 for (const auto& pair :
431 dataflow_analysis_->GetInstructionValueSet(instruction)) {
432 const HloValueSet& value_set = pair.second;
433 if (value_set.values().size() == 1) {
434 if (!buffers_seen
435 .insert(&GetBufferContainingValue(value_set.GetUniqueValue()))
436 .second) {
437 return false;
438 }
439 } else {
440 // It's possible for multiple values at this index to have the same
441 // HloBuffer. This does not result in non-distinctness. To account for
442 // this case, add all of the buffers at this index after checking
443 // whether each buffer exists at an earlier index. This is a corner
444 // case, however, as the number of values at an index is almost always
445 // one.
446 std::vector<const HloBuffer*> buffers_at_this_index;
447 for (const HloValue* value : value_set.values()) {
448 const HloBuffer* buffer = &GetBufferContainingValue(*value);
449 if (ContainsKey(buffers_seen, buffer)) {
450 return false;
451 }
452 buffers_at_this_index.push_back(buffer);
453 }
454 buffers_seen.insert(buffers_at_this_index.begin(),
455 buffers_at_this_index.end());
456 }
457 }
458 return true;
459 }
460
Verify() const461 Status HloAliasAnalysis::Verify() const {
462 // Verify consistency between the value_to_buffer_ map and
463 // HloBuffer::values().
464 for (const auto& pair : value_to_buffer_) {
465 const HloValue* value = pair.first;
466 const HloBuffer& buffer = *pair.second;
467 TF_RET_CHECK(absl::c_linear_search(buffer.values(), value));
468 }
469
470 for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
471 const HloBuffer& buffer = buffers_[id];
472 TF_RET_CHECK(buffer.id() == id);
473
474 HloValue::Id last_value_id = -1;
475 for (const HloValue* value : buffer.values()) {
476 TF_RET_CHECK(GetBufferContainingValue(*value) == buffer);
477
478 // Also verify the values in HloBuffer are unique and sorted by id.
479 TF_RET_CHECK(value->id() > last_value_id);
480 last_value_id = value->id();
481 }
482 }
483
484 return Status::OK();
485 }
486
ToString() const487 string HloAliasAnalysis::ToString() const {
488 string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
489 StrAppend(&out, " Buffers at each position:\n");
490 for (const HloComputation* computation : module_->computations()) {
491 for (const HloInstruction* instruction : computation->instructions()) {
492 StrAppend(&out, " ", instruction->name(), ":\n");
493 if (instruction->shape().IsTuple()) {
494 ShapeUtil::ForEachSubshape(
495 instruction->shape(),
496 [&out, &instruction, this](const Shape&, const ShapeIndex& index) {
497 StrAppend(&out, " tuple index ", index.ToString(), ":\n");
498 for (const HloBuffer* buffer :
499 ComputeBuffersAt(instruction, index)) {
500 StrAppend(&out, " ", buffer->ToString(), "\n");
501 }
502 });
503 } else {
504 for (const HloBuffer* buffer :
505 ComputeBuffersAt(instruction, /*index=*/{})) {
506 StrAppend(&out, " ", buffer->ToString(), "\n");
507 }
508 }
509 }
510 }
511
512 StrAppend(&out, " Buffers:\n");
513 for (const HloBuffer& buffer : buffers()) {
514 StrAppend(&out, " ", buffer.ToString(), "\n");
515 StrAppend(&out, " positions:\n");
516 for (const HloPosition& position : buffer.ComputePositions()) {
517 StrAppend(&out, " ", position.ToString(), "\n");
518 }
519 }
520
521 return out;
522 }
523
524 /* static */
Run(const HloModule * module,const HloDataflowAnalysis::CanShareBuffer & can_share_buffer)525 StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
526 const HloModule* module,
527 const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) {
528 VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
529 XLA_VLOG_LINES(2, module->ToString());
530
531 auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module));
532 TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
533 HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
534 /*bitcast_defines_value=*/false,
535 can_share_buffer));
536
537 BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
538 buffer_map.MergeAliasedBuffers();
539
540 // Create a vector of HloBuffers, one for each set of values in the
541 // BufferValueMap. Create the HloBuffers as a vector of contiguously numbered
542 // buffers.
543 std::vector<BufferValueMap::BufferNumber> sorted_buffer_numbers =
544 buffer_map.ComputeSortedBufferNumbers();
545 alias_analysis->buffers_.reserve(sorted_buffer_numbers.size());
546 HloBuffer::Id next_id = 0;
547 for (BufferValueMap::BufferNumber buffer_number : sorted_buffer_numbers) {
548 auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
549 std::vector<const HloValue*> sorted_values(value_set.begin(),
550 value_set.end());
551 absl::c_sort(sorted_values, HloValue::IdLessThan);
552 alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
553 for (const HloValue* value : sorted_values) {
554 alias_analysis->value_to_buffer_[value] =
555 &alias_analysis->buffers_.back();
556 }
557 }
558
559 TF_DCHECK_OK(alias_analysis->Verify());
560
561 HloInstruction* root = module->entry_computation()->root_instruction();
562 ShapeUtil::ForEachSubshape(
563 root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
564 for (const HloBuffer* buffer :
565 alias_analysis->ComputeBuffersAt(root, index)) {
566 alias_analysis->live_out_buffers_.insert(buffer);
567 }
568 });
569
570 XLA_VLOG_LINES(2, alias_analysis->ToString());
571 return std::move(alias_analysis);
572 }
573
MergeBuffers(const HloBuffer & to,const HloBuffer & from)574 void HloAliasAnalysis::MergeBuffers(const HloBuffer& to,
575 const HloBuffer& from) {
576 CHECK(to.id() != from.id());
577 VLOG(2) << "Merge buffer: " << from.ToString() << " into :" << to.ToString();
578
579 CHECK(from.id() < buffers_.size());
580 CHECK(to.id() < buffers_.size());
581
582 // Merge the values of `to` and `from`, creates a new buffer with the
583 // merged values.
584 std::vector<const HloValue*> merged_values(to.values().begin(),
585 to.values().end());
586
587 merged_values.insert(merged_values.end(), from.values().begin(),
588 from.values().end());
589 absl::c_sort(merged_values, [](const HloValue* a, const HloValue* b) {
590 return a->id() < b->id();
591 });
592
593 buffers_[to.id()] = HloBuffer(to.id(), merged_values);
594 for (const HloValue* value : merged_values) {
595 // Update references of values.
596 value_to_buffer_[value] = &buffers_[to.id()];
597 }
598
599 if (live_out_buffers_.count(&from) > 0) {
600 // Update live out set to erase `from` and add `to`.
601 live_out_buffers_.erase(&from);
602 live_out_buffers_.insert(&buffers_[to.id()]);
603 }
604
605 int64 from_id = from.id();
606 if (from_id != buffers_.size() - 1) {
607 // Now `from` is invalid, move the last element of buffers to replace `from`
608 // and update references to the last element.
609 const HloBuffer& last_elem = buffers_.back();
610 buffers_[from.id()] = HloBuffer(from_id, last_elem.values());
611
612 if (live_out_buffers_.count(&last_elem) > 0) {
613 // Update live out set to redirect the last element to its new position.
614 live_out_buffers_.erase(&last_elem);
615 live_out_buffers_.insert(&buffers_[from_id]);
616 }
617
618 // Update references of values.
619 for (const HloValue* value : buffers_[from_id].values()) {
620 value_to_buffer_[value] = &buffers_[from_id];
621 }
622 }
623
624 // Remove the last element.
625 buffers_.pop_back();
626
627 CHECK(Verify().ok());
628 }
629
HasLiveRangeInterference(const HloOrdering & ordering) const630 bool HloAliasAnalysis::HasLiveRangeInterference(
631 const HloOrdering& ordering) const {
632 for (const HloBuffer& buffer : buffers()) {
633 CHECK(!buffer.values().empty());
634 if (buffer.values().front()->shape().IsToken()) {
635 // Tokens have no on-device representation and cannot interfere.
636 for (const HloValue* value : buffer.values()) {
637 // If one of the values is a token, all values must be a token.
638 DCHECK(value->shape().IsToken());
639 }
640 continue;
641 }
642
643 // Check that the values in the buffer are totally ordered with respect to
644 // 'ordering'. Begin by sorting the values with respect to 'ordering' with a
645 // tie-break using value ID. The tie-break is necessary because we need a
646 // strict weak order for std::sort.
647 std::vector<const HloValue*> values = buffer.values();
648 absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) {
649 if (ordering.IsDefinedBefore(*a, *b)) {
650 return true;
651 } else if (ordering.IsDefinedBefore(*b, *a)) {
652 return false;
653 } else {
654 return a->id() < b->id();
655 }
656 });
657
658 // Walk through the ordered vector of values. First verify that the values
659 // are totally ordered with respect to 'ordering', then check that no
660 // adjacent values have overlapping live ranges. Only adjacent values must
661 // be checked because of the property of live range interference. For
662 // example, if you have values A, B, and C (in program order) contained in
663 // a buffer and A interferes with C, then necessarily A also interferes
664 // with B. So to check interference you only need to check interference
665 // between A and B, and between B and C.
666 for (int i = 1; i < values.size(); ++i) {
667 if (!ordering.IsDefinedBefore(*values[i - 1], *values[i])) {
668 VLOG(1) << values[i - 1]->ToShortString() << " and "
669 << values[i]->ToShortString() << " are not ordered";
670 return true;
671 }
672 if (ordering.MayInterfere(*values[i - 1], *values[i],
673 dataflow_analysis())) {
674 VLOG(1) << "In buffer " << buffer.id() << " containing values:\n "
675 << absl::StrJoin(values, ", ",
676 [](string* out, const HloValue* value) {
677 StrAppend(out, value->ToShortString());
678 })
679
680 << "\nValue " << values[i - 1]->ToShortString()
681 << " may interfere with value " << values[i]->ToShortString();
682 return true;
683 }
684 }
685 }
686
687 return false;
688 }
689
690 } // namespace xla
691