1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
17
18 #include <ostream>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_format.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/map_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36
37 namespace xla {
38
ToString() const39 string BufferAlias::ToString() const {
40 return absl::StrCat("BufferAlias(", instruction_->name(), "[",
41 absl::StrJoin(index_, ","), "])");
42 }
43
operator <<(std::ostream & out,const BufferAlias & buffer_alias)44 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
45 out << buffer_alias.ToString();
46 return out;
47 }
48
IsAmbiguous() const49 bool PointsToSet::IsAmbiguous() const {
50 bool ambiguous = false;
51 ForEachElement(
52 [&ambiguous](const ShapeIndex& /*index*/, const BufferList& points_to) {
53 ambiguous |= points_to.size() > 1;
54 });
55 return ambiguous;
56 }
57
IsDistinct() const58 bool PointsToSet::IsDistinct() const {
59 bool distinct = true;
60 absl::flat_hash_set<const LogicalBuffer*> all_points_to;
61 ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) {
62 for (auto& buffer : points_to) {
63 if (all_points_to.contains(buffer)) {
64 distinct = false;
65 }
66 all_points_to.insert(buffer);
67 }
68 });
69 return distinct;
70 }
71
size() const72 size_t PointsToSet::size() const {
73 // Because pointed-to elements may be duplicated we have to create a flattened
74 // set and return the size.
75 return CreateFlattenedSet().size();
76 }
77
CreateFlattenedSet() const78 PointsToSet::BufferSet PointsToSet::CreateFlattenedSet() const {
79 BufferSet flat_set;
80 ForEachElement(
81 [&flat_set](const ShapeIndex& /*index*/, const BufferList& buffers) {
82 flat_set.insert(buffers.begin(), buffers.end());
83 });
84 return flat_set;
85 }
86
ContainsBuffer(const LogicalBuffer & buffer) const87 bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
88 bool found = false;
89 ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
90 const BufferList& pointed_to_buffers) {
91 if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
92 found = true;
93 }
94 });
95 return found;
96 }
97
ContainsBufferAtIndex(const LogicalBuffer & buffer,const ShapeIndex & index) const98 bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
99 const ShapeIndex& index) const {
100 const auto& pointed_to_buffers = element(index);
101 return absl::c_linear_search(pointed_to_buffers, &buffer);
102 }
103
AddPointedToBuffer(const LogicalBuffer & buffer,const ShapeIndex & index)104 void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
105 const ShapeIndex& index) {
106 if (ContainsBufferAtIndex(buffer, index)) {
107 return;
108 }
109 mutable_element(index)->push_back(&buffer);
110 }
111
tuple_sources(const ShapeIndex & index) const112 const PointsToSet::SourceSet& PointsToSet::tuple_sources(
113 const ShapeIndex& index) const {
114 return tree_.element(index).tuple_sources;
115 }
116
add_tuple_source(const ShapeIndex & index,HloInstruction * tuple)117 void PointsToSet::add_tuple_source(const ShapeIndex& index,
118 HloInstruction* tuple) {
119 tree_.mutable_element(index)->tuple_sources.insert(tuple);
120 }
121
122 namespace {
123 // Gather fusion instructions from 'instruction' into 'fusion_instructions'.
GatherFusionInstructions(HloInstruction * instruction,std::vector<HloInstruction * > * fusion_instructions)124 void GatherFusionInstructions(
125 HloInstruction* instruction,
126 std::vector<HloInstruction*>* fusion_instructions) {
127 CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
128 for (auto* fused : instruction->fused_instructions()) {
129 if (fused->opcode() == HloOpcode::kFusion) {
130 GatherFusionInstructions(fused, fusion_instructions);
131 }
132 }
133 fusion_instructions->push_back(instruction);
134 }
135
136 } // namespace
137
138 /* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
Run(const HloModule * module)139 TuplePointsToAnalysis::Run(const HloModule* module) {
140 auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module);
141 std::unique_ptr<TuplePointsToAnalysis> analysis(new TuplePointsToAnalysis(
142 module, logical_buffer_analysis.ConsumeValueOrDie()));
143 TF_RETURN_IF_ERROR(analysis->Analyze());
144 return std::move(analysis);
145 }
146
Analyze()147 Status TuplePointsToAnalysis::Analyze() {
148 per_instruction_.clear();
149 per_instruction_.reserve(module_->instruction_count());
150
151 logical_buffer_aliases_.clear();
152 logical_buffer_aliases_.resize(
153 logical_buffer_analysis_->num_logical_buffers());
154
155 std::vector<HloInstruction*> fusion_instructions;
156 for (auto* computation : module_->MakeNonfusionComputations()) {
157 TF_RETURN_IF_ERROR(computation->Accept(this));
158 TF_RETURN_IF_ERROR(
159 PopulateDefinedBuffersAndAliases(computation->instructions()));
160 for (auto* instruction : computation->instructions()) {
161 if (instruction->opcode() == HloOpcode::kFusion) {
162 GatherFusionInstructions(instruction, &fusion_instructions);
163 }
164 }
165 }
166 // Run points-to analysis on fusion instructions in 'computation'.
167 for (auto* instruction : fusion_instructions) {
168 TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
169 TF_RETURN_IF_ERROR(
170 PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
171 }
172
173 XLA_VLOG_LINES(3, ToString());
174
175 return Status::OK();
176 }
177
178 Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(const decltype(
179 std::declval<HloComputation>().instructions())& instructions) {
180 for (auto* instruction : instructions) {
181 PerInstruction* pi = PerInst(instruction);
182 TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction(
183 instruction, &pi->instruction_defined_buffers));
184
185 const PointsToSet& points_to_set = GetPointsToSet(instruction);
186 points_to_set.ForEachElement(
187 [this, &instruction](
188 const ShapeIndex& index,
__anon27932d9d0602( const ShapeIndex& index, const PointsToSet::BufferList& pointed_to_buffers) 189 const PointsToSet::BufferList& pointed_to_buffers) {
190 for (const LogicalBuffer* buffer : pointed_to_buffers) {
191 logical_buffer_aliases_[buffer->id()].emplace_back(instruction,
192 index);
193 }
194 });
195 }
196 return Status::OK();
197 }
198
DefaultAction(HloInstruction * hlo_instruction)199 Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
200 // Create trivial points-to set for instruction. Each points-to set at index i
201 // contains a single element LogicalBuffer(hlo_instruction, i). This indicates
202 // that this instruction is the source of all buffers in its own output.
203 PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction);
204 points_to_set.ForEachMutableElement(
205 [this, hlo_instruction](const ShapeIndex& index,
206 PointsToSet::BufferList* buffers) {
207 buffers->push_back(
208 &logical_buffer_analysis_->GetBuffer(hlo_instruction, index));
209 });
210
211 if (hlo_instruction->shape().IsTuple()) {
212 // If the hlo instruction is a tuple-shaped, then trivially the instruction
213 // itself is the source of the tuple.
214 points_to_set.add_tuple_source({}, hlo_instruction);
215 }
216
217 return Status::OK();
218 }
219
HandleGetTupleElement(HloInstruction * get_tuple_element)220 Status TuplePointsToAnalysis::HandleGetTupleElement(
221 HloInstruction* get_tuple_element) {
222 // GetTupleElement forwards a pointer to a particular element of the tuple
223 // operand.
224 int64 element_index = get_tuple_element->tuple_index();
225
226 PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element);
227 const PointsToSet& operand_points_to_set =
228 *PerInst(get_tuple_element->operand(0))->points_to_set;
229
230 // Copy the points-to set (and tuple sources) at index {element_index} of the
231 // operand to the points-to set for this GetTupleElement instruction.
232 points_to_set.ForEachMutableElement(
233 [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) {
234 // Construct an index into the operand by prepending element_index to
235 // the index for the GetTupleElement instruction's points-to set.
236 ShapeIndex src_index;
237 src_index.push_back(element_index);
238 for (auto element : target_index) {
239 src_index.push_back(element);
240 }
241
242 *points_to = operand_points_to_set.element(src_index);
243 for (HloInstruction* tuple :
244 operand_points_to_set.tuple_sources(src_index)) {
245 points_to_set.add_tuple_source(target_index, tuple);
246 }
247 });
248
249 return Status::OK();
250 }
251
HandleCopy(HloInstruction * copy)252 Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) {
253 // A kCopy instruction performs a shallow copy of the operand. The top-level
254 // buffer (index={}) is newly created, but all other buffers (in the case of a
255 // tuple shape) come from the operand
256 PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0));
257 points_to_set.mutable_element(/*index=*/{})->clear();
258 points_to_set.AddPointedToBuffer(
259 logical_buffer_analysis_->GetBuffer(copy, /*index=*/{}),
260 /*index=*/{});
261
262 return Status::OK();
263 }
264
HandleBitcast(HloInstruction * bitcast)265 Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
266 // A kBitcast instruction aliases its operand. That is, the buffer of its
267 // result *is* the buffer of its operand, so just copy the operands points-to
268 // set.
269 CreateCopiedPointsToSet(bitcast, bitcast->operand(0));
270 return Status::OK();
271 }
272
HandleDomain(HloInstruction * domain)273 Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) {
274 // A kDomain instruction aliases its operand. That is, the buffer of its
275 // result *is* the buffer of its operand, so just copy the operands points-to
276 // set.
277 CreateCopiedPointsToSet(domain, domain->operand(0));
278 return Status::OK();
279 }
280
HandleAddDependency(HloInstruction * add_dependency)281 Status TuplePointsToAnalysis::HandleAddDependency(
282 HloInstruction* add_dependency) {
283 // AddDependency just forwards the value of its zero-th operand.
284 CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0));
285 return Status::OK();
286 }
287
HandleRecvDone(HloInstruction * recv_done)288 Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
289 // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
290 // output. The other indices ({} and {1}) define their own buffers.
291 PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
292 points_to_set.AddPointedToBuffer(
293 logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}),
294 /*index=*/{});
295 points_to_set.AddPointedToBuffer(
296 logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}),
297 /*index=*/{1});
298
299 const PointsToSet& operand_points_to_set =
300 GetPointsToSet(recv_done->operand(0));
301
302 // Recursively copy the points to set of the operand tuple {0} to the output
303 // element {0}.
304 points_to_set.ForEachMutableElement(
305 [&points_to_set, &operand_points_to_set](
306 const ShapeIndex& index, PointsToSet::BufferList* buffers) {
307 if (index.empty() || index[0] != 0) {
308 return;
309 }
310 *buffers = operand_points_to_set.element(index);
311 for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) {
312 points_to_set.add_tuple_source(index, tuple_source);
313 }
314 });
315 return Status::OK();
316 }
317
HandleCopyStart(HloInstruction * copy_start)318 Status TuplePointsToAnalysis::HandleCopyStart(HloInstruction* copy_start) {
319 // CopyStart forwards its aliased operand to {1}.
320 PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_start);
321 const PointsToSet& operand_points_to_set =
322 GetPointsToSet(copy_start->operand(0));
323
324 points_to_set.ForEachMutableElement(
325 [&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) {
326 if (target_index == ShapeIndex({1})) {
327 *buffers = operand_points_to_set.element(/*index=*/{});
328 } else {
329 buffers->push_back(
330 &logical_buffer_analysis_->GetBuffer(copy_start, target_index));
331 }
332 });
333
334 for (HloInstruction* tuple :
335 operand_points_to_set.tuple_sources(/*index=*/{})) {
336 points_to_set.add_tuple_source(/*index=*/{1}, tuple);
337 }
338
339 return Status::OK();
340 }
341
HandleCopyDone(HloInstruction * copy_done)342 Status TuplePointsToAnalysis::HandleCopyDone(HloInstruction* copy_done) {
343 // CopyDone forwards its aliased operand.
344 PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_done);
345 const PointsToSet& operand_points_to_set =
346 GetPointsToSet(copy_done->operand(0));
347 operand_points_to_set.ForEachElement(
348 [&points_to_set, &operand_points_to_set](
349 const ShapeIndex& src_index,
350 const PointsToSet::BufferList& points_to) {
351 if (src_index == ShapeIndex({0})) {
352 const ShapeIndex target_index = {};
353 *points_to_set.mutable_element(target_index) = points_to;
354
355 for (HloInstruction* tuple :
356 operand_points_to_set.tuple_sources(src_index)) {
357 points_to_set.add_tuple_source(target_index, tuple);
358 }
359 }
360 });
361
362 return Status::OK();
363 }
364
HandleSend(HloInstruction * send)365 Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
366 // Send creates a tuple of {aliased operand, U32 context, token}.
367 PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
368
369 // Creates the points to set for the tuple and its element at {1}.
370 auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
371 top_buffer->push_back(
372 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
373 points_to_set.add_tuple_source({}, send);
374
375 auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
376 context_buffer->push_back(
377 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
378
379 auto token_buffer = points_to_set.mutable_element(ShapeIndex({2}));
380 token_buffer->push_back(
381 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2})));
382
383 // Recursively copy the points to set of the operand to output tuple {0}.
384 const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
385 operand_points_to_set.ForEachElement(
386 [&points_to_set, &operand_points_to_set](
387 const ShapeIndex& src_index,
388 const PointsToSet::BufferList& points_to) {
389 ShapeIndex target_index({0});
390 for (auto element : src_index) {
391 target_index.push_back(element);
392 }
393 *points_to_set.mutable_element(target_index) = points_to;
394
395 for (HloInstruction* tuple :
396 operand_points_to_set.tuple_sources(src_index)) {
397 points_to_set.add_tuple_source(target_index, tuple);
398 }
399 });
400
401 return Status::OK();
402 }
403
HandleTuple(HloInstruction * tuple)404 Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
405 absl::Span<HloInstruction* const> operands(tuple->operands());
406 PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
407 points_to_set.AddPointedToBuffer(
408 logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}),
409 /*index=*/{});
410
411 // A tuple contains references to all input operands and transitively any
412 // references in those operands.
413 for (int64 i = 0; i < operands.size(); ++i) {
414 const PointsToSet& operand_points_to_set =
415 *PerInst(operands[i])->points_to_set;
416
417 // Copy the points-to set (and tuple sources) of the operand into the
418 // respective subtree of the tuple instructions points-to set.
419 operand_points_to_set.ForEachElement(
420 [&points_to_set, &operand_points_to_set, i](
421 const ShapeIndex& src_index,
422 const PointsToSet::BufferList& points_to) {
423 ShapeIndex target_index;
424 target_index.push_back(i);
425 for (auto element : src_index) {
426 target_index.push_back(element);
427 }
428
429 *points_to_set.mutable_element(target_index) = points_to;
430
431 for (HloInstruction* tuple :
432 operand_points_to_set.tuple_sources(src_index)) {
433 points_to_set.add_tuple_source(target_index, tuple);
434 }
435 });
436 }
437
438 points_to_set.add_tuple_source({}, tuple);
439
440 return Status::OK();
441 }
442
HandleTupleSelect(HloInstruction * tuple_select)443 Status TuplePointsToAnalysis::HandleTupleSelect(HloInstruction* tuple_select) {
444 // Select allocates a new buffer and then shallow copies the on_true or
445 // on_false buffer into this new buffer. Which side is chosen cannot be
446 // determined statically so conservatively set the points-to set to the union
447 // of these on_true and on_false operands.
448 //
449 // First create a copy of the on_true points-to set (and tuple sources), then
450 // add in elements of the on_false points-to set (tuple sources).
451 auto on_true = tuple_select->operand(1);
452 auto on_false = tuple_select->operand(2);
453 PointsToSet& points_to_set = CreateCopiedPointsToSet(tuple_select, on_true);
454 const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set;
455 points_to_set.ForEachMutableElement(
456 [&](const ShapeIndex& index, PointsToSet::BufferList* buffers) {
457 for (const LogicalBuffer* false_buffer :
458 false_points_to_set.element(index)) {
459 points_to_set.AddPointedToBuffer(*false_buffer, index);
460 }
461
462 for (HloInstruction* tuple : false_points_to_set.tuple_sources(index)) {
463 points_to_set.add_tuple_source(index, tuple);
464 }
465 });
466
467 // Select creates a new (top-level) buffer to store its result, so its
468 // respective element in the points-to set should contain only itself.
469 points_to_set.mutable_element({})->clear();
470 points_to_set.AddPointedToBuffer(
471 logical_buffer_analysis_->GetBuffer(tuple_select, /*index=*/{}),
472 /*index=*/{});
473 return Status::OK();
474 }
475
GetPointsToSet(const HloInstruction * hlo_instruction) const476 const PointsToSet& TuplePointsToAnalysis::GetPointsToSet(
477 const HloInstruction* hlo_instruction) const {
478 return *PerInst(hlo_instruction)->points_to_set;
479 }
480
CreateEmptyPointsToSet(const HloInstruction * instruction)481 PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
482 const HloInstruction* instruction) {
483 PerInstruction* pi = PerInst(instruction);
484 CHECK(pi->points_to_set == nullptr)
485 << "instruction should not have been present in the map.";
486 auto set = absl::make_unique<PointsToSet>(&instruction->shape());
487 pi->points_to_set = std::move(set);
488 // Return *set using the iterator returned by emplace.
489 return *pi->points_to_set;
490 }
491
InstructionDefinesBufferAtIndex(const HloInstruction * instruction,const ShapeIndex & index) const492 bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
493 const HloInstruction* instruction, const ShapeIndex& index) const {
494 const auto& buffers = GetPointsToSet(instruction).element(index);
495 return (buffers.size() == 1 && buffers[0]->instruction() == instruction);
496 }
497
VerifyBuffer(const LogicalBuffer & buffer) const498 Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
499 if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) {
500 return FailedPrecondition(
501 "LogicalBuffer %s is ill-defined: instruction %s does not define a "
502 "buffer at that index",
503 buffer.ToString(), buffer.instruction()->name());
504 }
505
506 if (buffer.id() < 0 ||
507 buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) {
508 return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d",
509 buffer.ToString(), buffer.id());
510 }
511 if (GetBuffer(buffer.id()).instruction() != buffer.instruction() ||
512 GetBuffer(buffer.id()).index() != buffer.index()) {
513 return FailedPrecondition(
514 "LogicalBuffer %s is ill-defined: buffer with same id differs: %s",
515 buffer.ToString(), GetBuffer(buffer.id()).ToString());
516 }
517
518 return Status::OK();
519 }
520
GetBuffer(LogicalBuffer::Id id) const521 const LogicalBuffer& TuplePointsToAnalysis::GetBuffer(
522 LogicalBuffer::Id id) const {
523 CHECK_GE(id, 0);
524 CHECK_LT(id, logical_buffer_analysis_->num_logical_buffers());
525 return logical_buffer_analysis_->GetBuffer(id);
526 }
527
GetBufferDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const528 StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
529 const HloInstruction* instruction, const ShapeIndex& index) const {
530 const auto& buffers = GetPointsToSet(instruction).element(index);
531 if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
532 return FailedPrecondition(
533 "instruction %s does not define buffer at index {%s}",
534 instruction->name(), absl::StrJoin(index, ","));
535 }
536 return buffers[0];
537 }
538
539 const TuplePointsToAnalysis::BufferAliasVector&
GetBufferAliases(const LogicalBuffer & buffer) const540 TuplePointsToAnalysis::GetBufferAliases(const LogicalBuffer& buffer) const {
541 return logical_buffer_aliases_.at(buffer.id());
542 }
543
544 const TuplePointsToAnalysis::BufferDefinitionVector&
GetBuffersDefinedByInstruction(const HloInstruction * instruction) const545 TuplePointsToAnalysis::GetBuffersDefinedByInstruction(
546 const HloInstruction* instruction) const {
547 return PerInst(instruction)->instruction_defined_buffers;
548 }
549
GatherBuffersDefinedByInstruction(const HloInstruction * instruction,TuplePointsToAnalysis::BufferDefinitionVector * buffers)550 Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
551 const HloInstruction* instruction,
552 TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
553 GetPointsToSet(instruction)
554 .ForEachElement([buffers, instruction](
555 const ShapeIndex& index,
556 const PointsToSet::BufferList& source_buffers) {
557 // Add buffers which 'instruction' is the source of.
558 CHECK(!source_buffers.empty());
559 if (source_buffers.size() == 1 &&
560 source_buffers[0]->instruction() == instruction) {
561 // If this instruction is the source of this buffer the
562 // indices must match.
563 DCHECK(source_buffers[0]->index() == index);
564 buffers->push_back(source_buffers[0]);
565 } else {
566 // If the points-to set includes more than one buffer then
567 // necessarily this instruction did not produce the
568 // buffer.
569 for (const LogicalBuffer* source_buffer : source_buffers) {
570 DCHECK(source_buffer->instruction() != instruction);
571 }
572 }
573 });
574 return Status::OK();
575 }
576
CreateCopiedPointsToSet(const HloInstruction * instruction,const HloInstruction * src)577 PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
578 const HloInstruction* instruction, const HloInstruction* src) {
579 // PointsToSet doesn't have a copy constructor so copy over element-by-element
580 // from src PointsToSet.
581 PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
582 const PointsToSet& src_points_to_set = GetPointsToSet(src);
583 dst_points_to_set.ForEachMutableElement(
584 [&dst_points_to_set, &src_points_to_set](
585 const ShapeIndex& index, PointsToSet::BufferList* buffers) {
586 *buffers = src_points_to_set.element(index);
587 for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
588 dst_points_to_set.add_tuple_source(index, tuple_source);
589 }
590 });
591 return *PerInst(instruction)->points_to_set;
592 }
593
ToString() const594 string TuplePointsToAnalysis::ToString() const {
595 string output =
596 absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name());
597 for (const auto* computation : module_->MakeNonfusionComputations()) {
598 const char* entry =
599 computation == module_->entry_computation() ? "entry " : "";
600 absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n");
601 for (const HloInstruction* instruction :
602 computation->MakeInstructionPostOrder()) {
603 InstructionToString(instruction, &output);
604 if (instruction->opcode() == HloOpcode::kFusion) {
605 for (auto* fused : instruction->fused_instructions()) {
606 InstructionToString(fused, &output);
607 }
608 }
609 }
610 }
611
612 absl::StrAppend(&output, "LogicalBuffers:\n");
613 for (const auto& b : logical_buffer_analysis_->logical_buffers()) {
614 absl::StrAppend(&output, " buffer ", b->ToString(), ":\n");
615 for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) {
616 absl::StrAppend(&output, " alias ", alias.ToString(), "\n");
617 }
618 }
619 return output;
620 }
621
InstructionToString(const HloInstruction * instruction,string * output) const622 void TuplePointsToAnalysis::InstructionToString(
623 const HloInstruction* instruction, string* output) const {
624 const string prefix = instruction->IsFused() ? " " : "";
625 absl::StrAppend(output, prefix, " instruction ",
626 instruction->ToShortString(), ":\n");
627 const PointsToSet& points_to_set = GetPointsToSet(instruction);
628 points_to_set.ForEachElement([&prefix, &output](
629 const ShapeIndex& index,
630 const PointsToSet::BufferList& points_to) {
631 absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ",
632 absl::StrJoin(points_to, ", ",
633 [](string* out, const LogicalBuffer* source) {
634 out->append(source->ToString());
635 }),
636 "\n");
637 });
638 }
639
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const640 bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
641 const HloInstruction* operand, const ShapeIndex& index,
642 const HloInstruction* user) const {
643 CHECK(user->IsUserOf(operand))
644 << "user: " << user->ToString() << " operand: " << operand->ToString();
645 if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
646 // GetTupleElement instructions only access the top-level buffer of their
647 // operand.
648 return true;
649 } else if (user->IsLoopFusion()) {
650 // Find fusion parameter associated with 'operand'.
651 auto it = absl::c_find_if(
652 user->fused_parameters(), [&](HloInstruction* fused_param) {
653 return user->operand(fused_param->parameter_number()) == operand;
654 });
655 CHECK(it != user->fused_parameters().end());
656 // Iterate through all users of all buffer aliases of the buffer in the
657 // points-to set of fusion parameter at 'index'.
658 // Return false if any uses are detected at 'index', returns true otherwise.
659 const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie();
660 for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
661 for (HloInstruction* alias_user : alias.instruction()->users()) {
662 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
663 alias_user)) {
664 continue;
665 }
666 // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
667 return false;
668 }
669 }
670 // Return true: found no uses of 'operand' at 'index' in 'user'.
671 return true;
672 }
673 return false;
674 }
675
676 // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
677 // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
678 // where 'user' is a user of an alias of 'instruction' at 'index', and
679 // 'operand_index' is the operand index at which the alias appears in the
680 // operand list of 'user'.
681 std::vector<std::pair<HloInstruction*, int64>>
GetAllUsesOfInstructionAtIndex(HloInstruction * instruction,const ShapeIndex & index) const682 TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex(
683 HloInstruction* instruction, const ShapeIndex& index) const {
684 std::vector<std::pair<HloInstruction*, int64>> uses;
685 const PointsToSet::BufferList& points_to =
686 GetPointsToSet(instruction).element(index);
687 for (const LogicalBuffer* buffer : points_to) {
688 for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
689 for (HloInstruction* alias_user : alias.instruction()->users()) {
690 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
691 alias_user)) {
692 continue;
693 }
694 for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
695 uses.emplace_back(alias_user, op_idx);
696 }
697 }
698 }
699 }
700 return uses;
701 }
702
703 // Returns true if there is exactly one use of 'operand' at 'operand_index'
704 // in 'fusion.fused_instructions', where the singleton use is the fused
705 // root at operand index 'use_operand_index'. Returns false otherwise.
706 //
707 // REQUIRES: 'fusion' opcode is a kFusion instruction.
HasUniqueFusedUseOfOperandAt(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * fusion,const int64 use_operand_index) const708 bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
709 HloInstruction* operand, const ShapeIndex& operand_index,
710 HloInstruction* fusion, const int64 use_operand_index) const {
711 CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
712 // Check that 'operand' is unique in the operand list of 'fusion'.
713 if (fusion->OperandIndices(operand).size() > 1) {
714 return false;
715 }
716 // Find fusion parameter associated with 'operand'.
717 const auto& fused_params = fusion->fused_parameters();
718 auto fused_param_it =
719 absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
720 return fusion->operand(fused_param->parameter_number()) == operand;
721 });
722 if (fused_param_it == fused_params.end()) {
723 return false;
724 }
725 auto* fused_param = *fused_param_it;
726 // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'.
727 auto fused_param_uses =
728 GetAllUsesOfInstructionAtIndex(fused_param, operand_index);
729 // Return true iff there is exactly one use of 'operand' at 'index', and
730 // this singleton use is the fused root (at index in 'use_operand_indices').
731 return fused_param_uses.size() == 1 &&
732 fused_param_uses[0].first == fusion->fused_expression_root() &&
733 fused_param_uses[0].second == use_operand_index;
734 }
735 } // namespace xla
736