1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
17
18 #include <ostream>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/logging.h"
35
36 namespace xla {
37
ToString() const38 string BufferAlias::ToString() const {
39 return absl::StrCat("BufferAlias(", instruction_->name(), "[",
40 absl::StrJoin(index_, ","), "])");
41 }
42
operator <<(std::ostream & out,const BufferAlias & buffer_alias)43 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
44 out << buffer_alias.ToString();
45 return out;
46 }
47
IsAmbiguous() const48 bool PointsToSet::IsAmbiguous() const {
49 bool ambiguous = false;
50 ForEachElement(
51 [&ambiguous](const ShapeIndex& /*index*/, const BufferList& points_to) {
52 ambiguous |= points_to.size() > 1;
53 });
54 return ambiguous;
55 }
56
IsDistinct() const57 bool PointsToSet::IsDistinct() const {
58 bool distinct = true;
59 absl::flat_hash_set<const LogicalBuffer*> all_points_to;
60 ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) {
61 for (auto& buffer : points_to) {
62 if (all_points_to.contains(buffer)) {
63 distinct = false;
64 }
65 all_points_to.insert(buffer);
66 }
67 });
68 return distinct;
69 }
70
size() const71 size_t PointsToSet::size() const {
72 // Because pointed-to elements may be duplicated we have to create a flattened
73 // set and return the size.
74 return CreateFlattenedSet().size();
75 }
76
CreateFlattenedSet() const77 PointsToSet::BufferSet PointsToSet::CreateFlattenedSet() const {
78 BufferSet flat_set;
79 ForEachElement(
80 [&flat_set](const ShapeIndex& /*index*/, const BufferList& buffers) {
81 flat_set.insert(buffers.begin(), buffers.end());
82 });
83 return flat_set;
84 }
85
ContainsBuffer(const LogicalBuffer & buffer) const86 bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
87 bool found = false;
88 ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
89 const BufferList& pointed_to_buffers) {
90 if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
91 found = true;
92 }
93 });
94 return found;
95 }
96
ContainsBufferAtIndex(const LogicalBuffer & buffer,const ShapeIndex & index) const97 bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
98 const ShapeIndex& index) const {
99 const auto& pointed_to_buffers = element(index);
100 return absl::c_linear_search(pointed_to_buffers, &buffer);
101 }
102
AddPointedToBuffer(const LogicalBuffer & buffer,const ShapeIndex & index)103 void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
104 const ShapeIndex& index) {
105 if (ContainsBufferAtIndex(buffer, index)) {
106 return;
107 }
108 mutable_element(index)->push_back(&buffer);
109 }
110
tuple_sources(const ShapeIndex & index) const111 const PointsToSet::SourceSet& PointsToSet::tuple_sources(
112 const ShapeIndex& index) const {
113 return tree_.element(index).tuple_sources;
114 }
115
add_tuple_source(const ShapeIndex & index,HloInstruction * tuple)116 void PointsToSet::add_tuple_source(const ShapeIndex& index,
117 HloInstruction* tuple) {
118 tree_.mutable_element(index)->tuple_sources.insert(tuple);
119 }
120
121 namespace {
122 // Gather fusion instructions from 'instruction' into 'fusion_instructions'.
GatherFusionInstructions(HloInstruction * instruction,std::vector<HloInstruction * > * fusion_instructions)123 void GatherFusionInstructions(
124 HloInstruction* instruction,
125 std::vector<HloInstruction*>* fusion_instructions) {
126 CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
127 for (auto* fused : instruction->fused_instructions()) {
128 if (fused->opcode() == HloOpcode::kFusion) {
129 GatherFusionInstructions(fused, fusion_instructions);
130 }
131 }
132 fusion_instructions->push_back(instruction);
133 }
134
135 } // namespace
136
137 /* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
Run(const HloModule * module)138 TuplePointsToAnalysis::Run(const HloModule* module) {
139 auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module);
140 std::unique_ptr<TuplePointsToAnalysis> analysis(new TuplePointsToAnalysis(
141 module, logical_buffer_analysis.ConsumeValueOrDie()));
142 TF_RETURN_IF_ERROR(analysis->Analyze());
143 return std::move(analysis);
144 }
145
Analyze()146 Status TuplePointsToAnalysis::Analyze() {
147 per_instruction_.clear();
148 per_instruction_.reserve(module_->instruction_count());
149
150 logical_buffer_aliases_.clear();
151 logical_buffer_aliases_.resize(
152 logical_buffer_analysis_->num_logical_buffers());
153
154 std::vector<HloInstruction*> fusion_instructions;
155 for (auto* computation : module_->MakeNonfusionComputations()) {
156 TF_RETURN_IF_ERROR(computation->Accept(this));
157 TF_RETURN_IF_ERROR(
158 PopulateDefinedBuffersAndAliases(computation->instructions()));
159 for (auto* instruction : computation->instructions()) {
160 if (instruction->opcode() == HloOpcode::kFusion) {
161 GatherFusionInstructions(instruction, &fusion_instructions);
162 }
163 }
164 }
165 // Run points-to analysis on fusion instructions in 'computation'.
166 for (auto* instruction : fusion_instructions) {
167 TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
168 TF_RETURN_IF_ERROR(
169 PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
170 }
171
172 XLA_VLOG_LINES(3, ToString());
173
174 return Status::OK();
175 }
176
177 Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(const decltype(
178 std::declval<HloComputation>().instructions())& instructions) {
179 for (auto* instruction : instructions) {
180 PerInstruction* pi = PerInst(instruction);
181 TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction(
182 instruction, &pi->instruction_defined_buffers));
183
184 const PointsToSet& points_to_set = GetPointsToSet(instruction);
185 points_to_set.ForEachElement(
186 [this, &instruction](
187 const ShapeIndex& index,
__anon1079e89b0602( const ShapeIndex& index, const PointsToSet::BufferList& pointed_to_buffers) 188 const PointsToSet::BufferList& pointed_to_buffers) {
189 for (const LogicalBuffer* buffer : pointed_to_buffers) {
190 logical_buffer_aliases_[buffer->id()].emplace_back(instruction,
191 index);
192 }
193 });
194 }
195 return Status::OK();
196 }
197
DefaultAction(HloInstruction * hlo_instruction)198 Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
199 // Create trivial points-to set for instruction. Each points-to set at index i
200 // contains a single element LogicalBuffer(hlo_instruction, i). This indicates
201 // that this instruction is the source of all buffers in its own output.
202 PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction);
203 points_to_set.ForEachMutableElement(
204 [this, hlo_instruction](const ShapeIndex& index,
205 PointsToSet::BufferList* buffers) {
206 buffers->push_back(
207 &logical_buffer_analysis_->GetBuffer(hlo_instruction, index));
208 });
209
210 if (hlo_instruction->shape().IsTuple()) {
211 // If the hlo instruction is a tuple-shaped, then trivially the instruction
212 // itself is the source of the tuple.
213 points_to_set.add_tuple_source({}, hlo_instruction);
214 }
215
216 return Status::OK();
217 }
218
HandleGetTupleElement(HloInstruction * get_tuple_element)219 Status TuplePointsToAnalysis::HandleGetTupleElement(
220 HloInstruction* get_tuple_element) {
221 // GetTupleElement forwards a pointer to a particular element of the tuple
222 // operand.
223 int64 element_index = get_tuple_element->tuple_index();
224
225 PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element);
226 const PointsToSet& operand_points_to_set =
227 *PerInst(get_tuple_element->operand(0))->points_to_set;
228
229 // Copy the points-to set (and tuple sources) at index {element_index} of the
230 // operand to the points-to set for this GetTupleElement instruction.
231 points_to_set.ForEachMutableElement(
232 [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) {
233 // Construct an index into the operand by prepending element_index to
234 // the index for the GetTupleElement instruction's points-to set.
235 ShapeIndex src_index;
236 src_index.push_back(element_index);
237 for (auto element : target_index) {
238 src_index.push_back(element);
239 }
240
241 *points_to = operand_points_to_set.element(src_index);
242 for (HloInstruction* tuple :
243 operand_points_to_set.tuple_sources(src_index)) {
244 points_to_set.add_tuple_source(target_index, tuple);
245 }
246 });
247
248 return Status::OK();
249 }
250
HandleCopy(HloInstruction * copy)251 Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) {
252 // A kCopy instruction performs a shallow copy of the operand. The top-level
253 // buffer (index={}) is newly created, but all other buffers (in the case of a
254 // tuple shape) come from the operand
255 PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0));
256 points_to_set.mutable_element(/*index=*/{})->clear();
257 points_to_set.AddPointedToBuffer(
258 logical_buffer_analysis_->GetBuffer(copy, /*index=*/{}),
259 /*index=*/{});
260
261 return Status::OK();
262 }
263
HandleBitcast(HloInstruction * bitcast)264 Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
265 // A kBitcast instruction aliases its operand. That is, the buffer of its
266 // result *is* the buffer of its operand, so just copy the operands points-to
267 // set.
268 CreateCopiedPointsToSet(bitcast, bitcast->operand(0));
269 return Status::OK();
270 }
271
HandleDomain(HloInstruction * domain)272 Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) {
273 // A kDomain instruction aliases its operand. That is, the buffer of its
274 // result *is* the buffer of its operand, so just copy the operands points-to
275 // set.
276 CreateCopiedPointsToSet(domain, domain->operand(0));
277 return Status::OK();
278 }
279
HandleAddDependency(HloInstruction * add_dependency)280 Status TuplePointsToAnalysis::HandleAddDependency(
281 HloInstruction* add_dependency) {
282 // AddDependency just forwards the value of its zero-th operand.
283 CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0));
284 return Status::OK();
285 }
286
HandleRecvDone(HloInstruction * recv_done)287 Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
288 // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
289 // output. The other indices ({} and {1}) define their own buffers.
290 PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
291 points_to_set.AddPointedToBuffer(
292 logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}),
293 /*index=*/{});
294 points_to_set.AddPointedToBuffer(
295 logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}),
296 /*index=*/{1});
297
298 const PointsToSet& operand_points_to_set =
299 GetPointsToSet(recv_done->operand(0));
300
301 // Recursively copy the points to set of the operand tuple {0} to the output
302 // element {0}.
303 points_to_set.ForEachMutableElement(
304 [&points_to_set, &operand_points_to_set](
305 const ShapeIndex& index, PointsToSet::BufferList* buffers) {
306 if (index.empty() || index[0] != 0) {
307 return;
308 }
309 *buffers = operand_points_to_set.element(index);
310 for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) {
311 points_to_set.add_tuple_source(index, tuple_source);
312 }
313 });
314 return Status::OK();
315 }
316
HandleSend(HloInstruction * send)317 Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
318 // Send creates a tuple of {aliased operand, U32 context, token}.
319 PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
320
321 // Creates the points to set for the tuple and its element at {1}.
322 auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
323 top_buffer->push_back(
324 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
325 points_to_set.add_tuple_source({}, send);
326
327 auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
328 context_buffer->push_back(
329 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
330
331 auto token_buffer = points_to_set.mutable_element(ShapeIndex({2}));
332 token_buffer->push_back(
333 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2})));
334
335 // Recursively copy the points to set of the operand to output tuple {0}.
336 const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
337 operand_points_to_set.ForEachElement(
338 [&points_to_set, &operand_points_to_set](
339 const ShapeIndex& src_index,
340 const PointsToSet::BufferList& points_to) {
341 ShapeIndex target_index({0});
342 for (auto element : src_index) {
343 target_index.push_back(element);
344 }
345 *points_to_set.mutable_element(target_index) = points_to;
346
347 for (HloInstruction* tuple :
348 operand_points_to_set.tuple_sources(src_index)) {
349 points_to_set.add_tuple_source(target_index, tuple);
350 }
351 });
352
353 return Status::OK();
354 }
355
HandleTuple(HloInstruction * tuple)356 Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
357 absl::Span<HloInstruction* const> operands(tuple->operands());
358 PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
359 points_to_set.AddPointedToBuffer(
360 logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}),
361 /*index=*/{});
362
363 // A tuple contains references to all input operands and transitively any
364 // references in those operands.
365 for (int64 i = 0; i < operands.size(); ++i) {
366 const PointsToSet& operand_points_to_set =
367 *PerInst(operands[i])->points_to_set;
368
369 // Copy the points-to set (and tuple sources) of the operand into the
370 // respective subtree of the tuple instructions points-to set.
371 operand_points_to_set.ForEachElement(
372 [&points_to_set, &operand_points_to_set, i](
373 const ShapeIndex& src_index,
374 const PointsToSet::BufferList& points_to) {
375 ShapeIndex target_index;
376 target_index.push_back(i);
377 for (auto element : src_index) {
378 target_index.push_back(element);
379 }
380
381 *points_to_set.mutable_element(target_index) = points_to;
382
383 for (HloInstruction* tuple :
384 operand_points_to_set.tuple_sources(src_index)) {
385 points_to_set.add_tuple_source(target_index, tuple);
386 }
387 });
388 }
389
390 points_to_set.add_tuple_source({}, tuple);
391
392 return Status::OK();
393 }
394
HandleTupleSelect(HloInstruction * tuple_select)395 Status TuplePointsToAnalysis::HandleTupleSelect(HloInstruction* tuple_select) {
396 // Select allocates a new buffer and then shallow copies the on_true or
397 // on_false buffer into this new buffer. Which side is chosen cannot be
398 // determined statically so conservatively set the points-to set to the union
399 // of these on_true and on_false operands.
400 //
401 // First create a copy of the on_true points-to set (and tuple sources), then
402 // add in elements of the on_false points-to set (tuple sources).
403 auto on_true = tuple_select->operand(1);
404 auto on_false = tuple_select->operand(2);
405 PointsToSet& points_to_set = CreateCopiedPointsToSet(tuple_select, on_true);
406 const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set;
407 points_to_set.ForEachMutableElement(
408 [&](const ShapeIndex& index, PointsToSet::BufferList* buffers) {
409 for (const LogicalBuffer* false_buffer :
410 false_points_to_set.element(index)) {
411 points_to_set.AddPointedToBuffer(*false_buffer, index);
412 }
413
414 for (HloInstruction* tuple : false_points_to_set.tuple_sources(index)) {
415 points_to_set.add_tuple_source(index, tuple);
416 }
417 });
418
419 // Select creates a new (top-level) buffer to store its result, so its
420 // respective element in the points-to set should contain only itself.
421 points_to_set.mutable_element({})->clear();
422 points_to_set.AddPointedToBuffer(
423 logical_buffer_analysis_->GetBuffer(tuple_select, /*index=*/{}),
424 /*index=*/{});
425 return Status::OK();
426 }
427
GetPointsToSet(const HloInstruction * hlo_instruction) const428 const PointsToSet& TuplePointsToAnalysis::GetPointsToSet(
429 const HloInstruction* hlo_instruction) const {
430 return *PerInst(hlo_instruction)->points_to_set;
431 }
432
CreateEmptyPointsToSet(const HloInstruction * instruction)433 PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
434 const HloInstruction* instruction) {
435 PerInstruction* pi = PerInst(instruction);
436 CHECK(pi->points_to_set == nullptr)
437 << "instruction should not have been present in the map.";
438 auto set = absl::make_unique<PointsToSet>(&instruction->shape());
439 pi->points_to_set = std::move(set);
440 // Return *set using the iterator returned by emplace.
441 return *pi->points_to_set;
442 }
443
InstructionDefinesBufferAtIndex(const HloInstruction * instruction,const ShapeIndex & index) const444 bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
445 const HloInstruction* instruction, const ShapeIndex& index) const {
446 const auto& buffers = GetPointsToSet(instruction).element(index);
447 return (buffers.size() == 1 && buffers[0]->instruction() == instruction);
448 }
449
VerifyBuffer(const LogicalBuffer & buffer) const450 Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
451 if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) {
452 return FailedPrecondition(
453 "LogicalBuffer %s is ill-defined: instruction %s does not define a "
454 "buffer at that index",
455 buffer.ToString(), buffer.instruction()->name());
456 }
457
458 if (buffer.id() < 0 ||
459 buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) {
460 return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d",
461 buffer.ToString(), buffer.id());
462 }
463 if (GetBuffer(buffer.id()).instruction() != buffer.instruction() ||
464 GetBuffer(buffer.id()).index() != buffer.index()) {
465 return FailedPrecondition(
466 "LogicalBuffer %s is ill-defined: buffer with same id differs: %s",
467 buffer.ToString(), GetBuffer(buffer.id()).ToString());
468 }
469
470 return Status::OK();
471 }
472
GetBuffer(LogicalBuffer::Id id) const473 const LogicalBuffer& TuplePointsToAnalysis::GetBuffer(
474 LogicalBuffer::Id id) const {
475 CHECK_GE(id, 0);
476 CHECK_LT(id, logical_buffer_analysis_->num_logical_buffers());
477 return logical_buffer_analysis_->GetBuffer(id);
478 }
479
GetBufferDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const480 StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
481 const HloInstruction* instruction, const ShapeIndex& index) const {
482 const auto& buffers = GetPointsToSet(instruction).element(index);
483 if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
484 return FailedPrecondition(
485 "instruction %s does not define buffer at index {%s}",
486 instruction->name(), absl::StrJoin(index, ","));
487 }
488 return buffers[0];
489 }
490
491 const TuplePointsToAnalysis::BufferAliasVector&
GetBufferAliases(const LogicalBuffer & buffer) const492 TuplePointsToAnalysis::GetBufferAliases(const LogicalBuffer& buffer) const {
493 return logical_buffer_aliases_.at(buffer.id());
494 }
495
496 const TuplePointsToAnalysis::BufferDefinitionVector&
GetBuffersDefinedByInstruction(const HloInstruction * instruction) const497 TuplePointsToAnalysis::GetBuffersDefinedByInstruction(
498 const HloInstruction* instruction) const {
499 return PerInst(instruction)->instruction_defined_buffers;
500 }
501
GatherBuffersDefinedByInstruction(const HloInstruction * instruction,TuplePointsToAnalysis::BufferDefinitionVector * buffers)502 Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
503 const HloInstruction* instruction,
504 TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
505 GetPointsToSet(instruction)
506 .ForEachElement([buffers, instruction](
507 const ShapeIndex& index,
508 const PointsToSet::BufferList& source_buffers) {
509 // Add buffers which 'instruction' is the source of.
510 CHECK(!source_buffers.empty());
511 if (source_buffers.size() == 1 &&
512 source_buffers[0]->instruction() == instruction) {
513 // If this instruction is the source of this buffer the
514 // indices must match.
515 DCHECK(source_buffers[0]->index() == index);
516 buffers->push_back(source_buffers[0]);
517 } else {
518 // If the points-to set includes more than one buffer then
519 // necessarily this instruction did not produce the
520 // buffer.
521 for (const LogicalBuffer* source_buffer : source_buffers) {
522 DCHECK(source_buffer->instruction() != instruction);
523 }
524 }
525 });
526 return Status::OK();
527 }
528
CreateCopiedPointsToSet(const HloInstruction * instruction,const HloInstruction * src)529 PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
530 const HloInstruction* instruction, const HloInstruction* src) {
531 // PointsToSet doesn't have a copy constructor so copy over element-by-element
532 // from src PointsToSet.
533 PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
534 const PointsToSet& src_points_to_set = GetPointsToSet(src);
535 dst_points_to_set.ForEachMutableElement(
536 [&dst_points_to_set, &src_points_to_set](
537 const ShapeIndex& index, PointsToSet::BufferList* buffers) {
538 *buffers = src_points_to_set.element(index);
539 for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
540 dst_points_to_set.add_tuple_source(index, tuple_source);
541 }
542 });
543 return *PerInst(instruction)->points_to_set;
544 }
545
ToString() const546 string TuplePointsToAnalysis::ToString() const {
547 string output =
548 absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name());
549 for (const auto* computation : module_->MakeNonfusionComputations()) {
550 const char* entry =
551 computation == module_->entry_computation() ? "entry " : "";
552 absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n");
553 for (const HloInstruction* instruction :
554 computation->MakeInstructionPostOrder()) {
555 InstructionToString(instruction, &output);
556 if (instruction->opcode() == HloOpcode::kFusion) {
557 for (auto* fused : instruction->fused_instructions()) {
558 InstructionToString(fused, &output);
559 }
560 }
561 }
562 }
563
564 absl::StrAppend(&output, "LogicalBuffers:\n");
565 for (const auto& b : logical_buffer_analysis_->logical_buffers()) {
566 absl::StrAppend(&output, " buffer ", b->ToString(), ":\n");
567 for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) {
568 absl::StrAppend(&output, " alias ", alias.ToString(), "\n");
569 }
570 }
571 return output;
572 }
573
InstructionToString(const HloInstruction * instruction,string * output) const574 void TuplePointsToAnalysis::InstructionToString(
575 const HloInstruction* instruction, string* output) const {
576 const string prefix = instruction->IsFused() ? " " : "";
577 absl::StrAppend(output, prefix, " instruction ",
578 instruction->ToShortString(), ":\n");
579 const PointsToSet& points_to_set = GetPointsToSet(instruction);
580 points_to_set.ForEachElement([&prefix, &output](
581 const ShapeIndex& index,
582 const PointsToSet::BufferList& points_to) {
583 absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ",
584 absl::StrJoin(points_to, ", ",
585 [](string* out, const LogicalBuffer* source) {
586 out->append(source->ToString());
587 }),
588 "\n");
589 });
590 }
591
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const592 bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
593 const HloInstruction* operand, const ShapeIndex& index,
594 const HloInstruction* user) const {
595 CHECK(user->IsUserOf(operand))
596 << "user: " << user->ToString() << " operand: " << operand->ToString();
597 if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
598 // GetTupleElement instructions only access the top-level buffer of their
599 // operand.
600 return true;
601 } else if (user->opcode() == HloOpcode::kFusion &&
602 user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
603 // Find fusion parameter associated with 'operand'.
604 auto it = absl::c_find_if(
605 user->fused_parameters(), [&](HloInstruction* fused_param) {
606 return user->operand(fused_param->parameter_number()) == operand;
607 });
608 CHECK(it != user->fused_parameters().end());
609 // Iterate through all users of all buffer aliases of the buffer in the
610 // points-to set of fusion parameter at 'index'.
611 // Return false if any uses are detected at 'index', returns true otherwise.
612 const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie();
613 for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
614 for (HloInstruction* alias_user : alias.instruction()->users()) {
615 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
616 alias_user)) {
617 continue;
618 }
619 // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
620 return false;
621 }
622 }
623 // Return true: found no uses of 'operand' at 'index' in 'user'.
624 return true;
625 }
626 return false;
627 }
628
629 // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
630 // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
631 // where 'user' is a user of an alias of 'instruction' at 'index', and
632 // 'operand_index' is the operand index at which the alias appears in the
633 // operand list of 'user'.
634 std::vector<std::pair<HloInstruction*, int64>>
GetAllUsesOfInstructionAtIndex(HloInstruction * instruction,const ShapeIndex & index) const635 TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex(
636 HloInstruction* instruction, const ShapeIndex& index) const {
637 std::vector<std::pair<HloInstruction*, int64>> uses;
638 const PointsToSet::BufferList& points_to =
639 GetPointsToSet(instruction).element(index);
640 for (const LogicalBuffer* buffer : points_to) {
641 for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
642 for (HloInstruction* alias_user : alias.instruction()->users()) {
643 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
644 alias_user)) {
645 continue;
646 }
647 for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
648 uses.emplace_back(alias_user, op_idx);
649 }
650 }
651 }
652 }
653 return uses;
654 }
655
656 // Returns true if there is exactly one use of 'operand' at 'operand_index'
657 // in 'fusion.fused_instructions', where the singleton use is the fused
658 // root at operand index 'use_operand_index'. Returns false otherwise.
659 //
660 // REQUIRES: 'fusion' opcode is a kFusion instruction.
HasUniqueFusedUseOfOperandAt(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * fusion,const int64 use_operand_index) const661 bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
662 HloInstruction* operand, const ShapeIndex& operand_index,
663 HloInstruction* fusion, const int64 use_operand_index) const {
664 CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
665 // Check that 'operand' is unique in the operand list of 'fusion'.
666 if (fusion->OperandIndices(operand).size() > 1) {
667 return false;
668 }
669 // Find fusion parameter associated with 'operand'.
670 const auto& fused_params = fusion->fused_parameters();
671 auto fused_param_it =
672 absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
673 return fusion->operand(fused_param->parameter_number()) == operand;
674 });
675 if (fused_param_it == fused_params.end()) {
676 return false;
677 }
678 auto* fused_param = *fused_param_it;
679 // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'.
680 auto fused_param_uses =
681 GetAllUsesOfInstructionAtIndex(fused_param, operand_index);
682 // Return true iff there is exactly one use of 'operand' at 'index', and
683 // this singleton use is the fused root (at index in 'use_operand_indices').
684 return fused_param_uses.size() == 1 &&
685 fused_param_uses[0].first == fusion->fused_expression_root() &&
686 fused_param_uses[0].second == use_operand_index;
687 }
688
689 // User and operand can share buffers iff both instructions emit the same shape
690 // and layout, and 'user' meets one of the following qualifications:
691 //
692 // (1) Is element-wise. Or...
693 // (2) Is a loop fusion instruction where the only use of 'operand' at 'index'
694 // in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
695 // at operand 0. Or...
696 // (3) Is a kDot -> kAdd output fusion instruction where the only use of
697 // 'operand' at 'index' in the set 'user.fused_instructions' is a kAdd fused
698 // root at operand 0 or 1. Or...
699 // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index
700 // 0.
701 // (5) The 'user' of 'operand' is Sort, and it is the only user.
702 // (6) The 'user' of 'operand' is TriangularSolve, it is the second operand,
703 // and it is the only user.
704 //
705 // (2) and (3) can only be determined if points-to analysis is available.
CanShareOperandBufferWithUser(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * user,const ShapeIndex & user_index) const706 bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
707 HloInstruction* operand, const ShapeIndex& operand_index,
708 HloInstruction* user, const ShapeIndex& user_index) const {
709 CHECK(user->IsUserOf(operand))
710 << "user: " << user->ToString() << " operand: " << operand->ToString();
711 const Shape& operand_subshape =
712 ShapeUtil::GetSubshape(operand->shape(), operand_index);
713 const Shape& user_subshape =
714 ShapeUtil::GetSubshape(user->shape(), user_index);
715 // Check that operand and user emit the same shape and layout.
716 if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
717 return false;
718 }
719 if (user->opcode() == HloOpcode::kFusion) {
720 if (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
721 user->fusion_kind() == HloInstruction::FusionKind::kInput) {
722 if (user->fused_expression_root()->opcode() ==
723 HloOpcode::kDynamicUpdateSlice) {
724 // Loop fusion with kDynamicUpdateSlice fused root.
725 //
726 // Returns true iff there is exactly one use of 'operand' at shape index
727 // 'operand_index', and this singleton use is the fused root at operand
728 // index 0.
729 return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0);
730 } else {
731 HloInstruction* fusion_param =
732 user->fused_parameter(user->operand_index(operand));
733 return HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
734 fusion_param);
735 }
736 } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
737 user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
738 // Output fusion with kAdd fused root.
739
740 // Check if one operand of kAdd fused root is kDot or kConvolution.
741 auto* add = user->fused_expression_root();
742 auto add_operand_it =
743 absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
744 return operand->opcode() == HloOpcode::kConvolution ||
745 operand->opcode() == HloOpcode::kDot;
746 });
747 if (add_operand_it == add->operands().end()) {
748 return false;
749 }
750 auto* matched_add_operand = *add_operand_it;
751 // Calculate operand index of 'add' operand which was not matched above.
752 const int64 other_add_operand_index =
753 matched_add_operand == add->operand(0) ? 1 : 0;
754 // Returns true iff there is exactly one use of 'operand' at shape index
755 // 'operand_index', and this singleton use is the fused root (at operand
756 // index 'other_add_operand_index').
757 return HasUniqueFusedUseOfOperandAt(operand, operand_index, user,
758 other_add_operand_index);
759 }
760 }
761 if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
762 user->opcode() == HloOpcode::kScatter ||
763 user->opcode() == HloOpcode::kWhile) {
764 // We eliminated other users in BufferLiveness::live_range_strictly_before,
765 // so here we just need to check that the use is at operand index 0.
766 std::vector<int64> operand_indices = user->OperandIndices(operand);
767 return operand_indices.size() == 1 && operand_indices[0] == 0;
768 }
769 if (user->opcode() == HloOpcode::kSort) {
770 // Only valid if there are no other users.
771 if (operand->users().size() != 1) {
772 return false;
773 }
774 // If we only sort keys, the output of sort is not a tuple, so we can always
775 // share the buffer.
776 if (user->operand_count() == 1) {
777 return true;
778 }
779 CHECK(!user_index.empty());
780 // Only share with the right tuple element buffer.
781 std::vector<int64> operand_indices = user->OperandIndices(operand);
782 return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
783 }
784 if (user->opcode() == HloOpcode::kTriangularSolve) {
785 // Only valid if there are no other users.
786 if (operand->users().size() != 1) {
787 return false;
788 }
789 std::vector<int64> operand_indices = user->OperandIndices(operand);
790 return operand_indices.size() == 1 && operand_indices[0] == 1;
791 }
792 if (user->opcode() == HloOpcode::kCall) {
793 // TODO(b/62548313): Remove when buffer assignment is module scoped and
794 // does not assign buffers to calls.
795 // Find called computation parameter associated with 'operand'.
796 const std::vector<int64> operand_indices = user->OperandIndices(operand);
797 if (operand_indices.size() > 1) {
798 return false;
799 }
800 CHECK_EQ(1, operand_indices.size());
801 auto* param = user->to_apply()->parameter_instruction(operand_indices[0]);
802 // Get all uses of 'operand' at 'index' in called computation.
803 auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index);
804
805 // Return true iff:
806 // *) There exists exactly one use of 'operand' in called computation.
807 // *) The unique use is by the root instruction of called computation.
808 // (Note: we check the root of the called computation, because the
809 // root result buffer is required to alias with the Call result buffer).
810 // *) The root instruction of the called computation is element-wise on
811 // 'operand'.
812 auto* callee_root = user->to_apply()->root_instruction();
813 return param_uses.size() == 1 && param_uses[0].first == callee_root &&
814 callee_root->IsElementwiseOnOperand(param_uses[0].second);
815 }
816 // Loop fusions that contain transposing copies won't reach here as they have
817 // different layouts, which fails the check in the beginning of this function.
818 //
819 // Multi-output fusion will fail the check here as tuples are not considered
820 // an elementwise operation.
821 return user->IsElementwiseOnOperand(user->operand_index(operand));
822 }
823
824 } // namespace xla
825