1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
17
18 #include <ostream>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_format.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/map_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
30 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/logging.h"
38
39 namespace xla {
40
ToString() const41 string BufferAlias::ToString() const {
42 return absl::StrCat("BufferAlias(", instruction_->name(), "[",
43 absl::StrJoin(index_, ","), "])");
44 }
45
operator <<(std::ostream & out,const BufferAlias & buffer_alias)46 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
47 out << buffer_alias.ToString();
48 return out;
49 }
50
IsAmbiguous() const51 bool PointsToSet::IsAmbiguous() const {
52 bool ambiguous = false;
53 ForEachElement(
54 [&ambiguous](const ShapeIndex& /*index*/, const BufferList& points_to) {
55 ambiguous |= points_to.size() > 1;
56 });
57 return ambiguous;
58 }
59
IsDistinct() const60 bool PointsToSet::IsDistinct() const {
61 bool distinct = true;
62 absl::flat_hash_set<const LogicalBuffer*> all_points_to;
63 ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) {
64 for (auto& buffer : points_to) {
65 if (all_points_to.contains(buffer)) {
66 distinct = false;
67 }
68 all_points_to.insert(buffer);
69 }
70 });
71 return distinct;
72 }
73
size() const74 size_t PointsToSet::size() const {
75 // Because pointed-to elements may be duplicated we have to create a flattened
76 // set and return the size.
77 return CreateFlattenedSet().size();
78 }
79
CreateFlattenedSet() const80 PointsToSet::BufferSet PointsToSet::CreateFlattenedSet() const {
81 BufferSet flat_set;
82 ForEachElement(
83 [&flat_set](const ShapeIndex& /*index*/, const BufferList& buffers) {
84 flat_set.insert(buffers.begin(), buffers.end());
85 });
86 return flat_set;
87 }
88
ContainsBuffer(const LogicalBuffer & buffer) const89 bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
90 bool found = false;
91 ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
92 const BufferList& pointed_to_buffers) {
93 if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
94 found = true;
95 }
96 });
97 return found;
98 }
99
ContainsBufferAtIndex(const LogicalBuffer & buffer,const ShapeIndex & index) const100 bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
101 const ShapeIndex& index) const {
102 const auto& pointed_to_buffers = element(index);
103 return absl::c_linear_search(pointed_to_buffers, &buffer);
104 }
105
AddPointedToBuffer(const LogicalBuffer & buffer,const ShapeIndex & index)106 void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
107 const ShapeIndex& index) {
108 if (ContainsBufferAtIndex(buffer, index)) {
109 return;
110 }
111 mutable_element(index)->push_back(&buffer);
112 }
113
tuple_sources(const ShapeIndex & index) const114 const PointsToSet::SourceSet& PointsToSet::tuple_sources(
115 const ShapeIndex& index) const {
116 return tree_.element(index).tuple_sources;
117 }
118
add_tuple_source(const ShapeIndex & index,HloInstruction * tuple)119 void PointsToSet::add_tuple_source(const ShapeIndex& index,
120 HloInstruction* tuple) {
121 tree_.mutable_element(index)->tuple_sources.insert(tuple);
122 }
123
124 namespace {
125 // Gather fusion instructions from 'instruction' into 'fusion_instructions'.
GatherFusionInstructions(HloInstruction * instruction,std::vector<HloInstruction * > * fusion_instructions)126 void GatherFusionInstructions(
127 HloInstruction* instruction,
128 std::vector<HloInstruction*>* fusion_instructions) {
129 CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
130 for (auto* fused : instruction->fused_instructions()) {
131 if (fused->opcode() == HloOpcode::kFusion) {
132 GatherFusionInstructions(fused, fusion_instructions);
133 }
134 }
135 fusion_instructions->push_back(instruction);
136 }
137
138 } // namespace
139
140 /* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
Run(const HloModule * module)141 TuplePointsToAnalysis::Run(const HloModule* module) {
142 auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module);
143 std::unique_ptr<TuplePointsToAnalysis> analysis(new TuplePointsToAnalysis(
144 module, logical_buffer_analysis.ConsumeValueOrDie()));
145 TF_RETURN_IF_ERROR(analysis->Analyze());
146 return std::move(analysis);
147 }
148
Analyze()149 Status TuplePointsToAnalysis::Analyze() {
150 per_instruction_.clear();
151 per_instruction_.reserve(module_->instruction_count());
152
153 logical_buffer_aliases_.clear();
154 logical_buffer_aliases_.resize(
155 logical_buffer_analysis_->num_logical_buffers());
156
157 std::vector<HloInstruction*> fusion_instructions;
158 for (auto* computation : module_->MakeNonfusionComputations()) {
159 TF_RETURN_IF_ERROR(computation->Accept(this));
160 TF_RETURN_IF_ERROR(
161 PopulateDefinedBuffersAndAliases(computation->instructions()));
162 for (auto* instruction : computation->instructions()) {
163 if (instruction->opcode() == HloOpcode::kFusion) {
164 GatherFusionInstructions(instruction, &fusion_instructions);
165 }
166 }
167 }
168 // Run points-to analysis on fusion instructions in 'computation'.
169 for (auto* instruction : fusion_instructions) {
170 TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
171 TF_RETURN_IF_ERROR(
172 PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
173 }
174
175 XLA_VLOG_LINES(3, ToString());
176
177 return Status::OK();
178 }
179
180 Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(const decltype(
181 std::declval<HloComputation>().instructions())& instructions) {
182 for (auto* instruction : instructions) {
183 PerInstruction* pi = PerInst(instruction);
184 TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction(
185 instruction, &pi->instruction_defined_buffers));
186
187 const PointsToSet& points_to_set = GetPointsToSet(instruction);
188 points_to_set.ForEachElement(
189 [this, &instruction](
190 const ShapeIndex& index,
__anonc906b6790602( const ShapeIndex& index, const PointsToSet::BufferList& pointed_to_buffers) 191 const PointsToSet::BufferList& pointed_to_buffers) {
192 for (const LogicalBuffer* buffer : pointed_to_buffers) {
193 logical_buffer_aliases_[buffer->id()].emplace_back(instruction,
194 index);
195 }
196 });
197 }
198 return Status::OK();
199 }
200
DefaultAction(HloInstruction * hlo_instruction)201 Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
202 // Create trivial points-to set for instruction. Each points-to set at index i
203 // contains a single element LogicalBuffer(hlo_instruction, i). This indicates
204 // that this instruction is the source of all buffers in its own output.
205 PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction);
206 points_to_set.ForEachMutableElement(
207 [this, hlo_instruction](const ShapeIndex& index,
208 PointsToSet::BufferList* buffers) {
209 buffers->push_back(
210 &logical_buffer_analysis_->GetBuffer(hlo_instruction, index));
211 });
212
213 if (hlo_instruction->shape().IsTuple()) {
214 // If the hlo instruction is a tuple-shaped, then trivially the instruction
215 // itself is the source of the tuple.
216 points_to_set.add_tuple_source({}, hlo_instruction);
217 }
218
219 return Status::OK();
220 }
221
HandleGetTupleElement(HloInstruction * get_tuple_element)222 Status TuplePointsToAnalysis::HandleGetTupleElement(
223 HloInstruction* get_tuple_element) {
224 // GetTupleElement forwards a pointer to a particular element of the tuple
225 // operand.
226 int64 element_index = get_tuple_element->tuple_index();
227
228 PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element);
229 const PointsToSet& operand_points_to_set =
230 *PerInst(get_tuple_element->operand(0))->points_to_set;
231
232 // Copy the points-to set (and tuple sources) at index {element_index} of the
233 // operand to the points-to set for this GetTupleElement instruction.
234 points_to_set.ForEachMutableElement(
235 [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) {
236 // Construct an index into the operand by prepending element_index to
237 // the index for the GetTupleElement instruction's points-to set.
238 ShapeIndex src_index;
239 src_index.push_back(element_index);
240 for (auto element : target_index) {
241 src_index.push_back(element);
242 }
243
244 *points_to = operand_points_to_set.element(src_index);
245 for (HloInstruction* tuple :
246 operand_points_to_set.tuple_sources(src_index)) {
247 points_to_set.add_tuple_source(target_index, tuple);
248 }
249 });
250
251 return Status::OK();
252 }
253
HandleCopy(HloInstruction * copy)254 Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) {
255 // A kCopy instruction performs a shallow copy of the operand. The top-level
256 // buffer (index={}) is newly created, but all other buffers (in the case of a
257 // tuple shape) come from the operand
258 PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0));
259 points_to_set.mutable_element(/*index=*/{})->clear();
260 points_to_set.AddPointedToBuffer(
261 logical_buffer_analysis_->GetBuffer(copy, /*index=*/{}),
262 /*index=*/{});
263
264 return Status::OK();
265 }
266
HandleBitcast(HloInstruction * bitcast)267 Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
268 // A kBitcast instruction aliases its operand. That is, the buffer of its
269 // result *is* the buffer of its operand, so just copy the operands points-to
270 // set.
271 CreateCopiedPointsToSet(bitcast, bitcast->operand(0));
272 return Status::OK();
273 }
274
HandleDomain(HloInstruction * domain)275 Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) {
276 // A kDomain instruction aliases its operand. That is, the buffer of its
277 // result *is* the buffer of its operand, so just copy the operands points-to
278 // set.
279 CreateCopiedPointsToSet(domain, domain->operand(0));
280 return Status::OK();
281 }
282
HandleAddDependency(HloInstruction * add_dependency)283 Status TuplePointsToAnalysis::HandleAddDependency(
284 HloInstruction* add_dependency) {
285 // AddDependency just forwards the value of its zero-th operand.
286 CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0));
287 return Status::OK();
288 }
289
HandleRecvDone(HloInstruction * recv_done)290 Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
291 // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
292 // output. The other indices ({} and {1}) define their own buffers.
293 PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
294 points_to_set.AddPointedToBuffer(
295 logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}),
296 /*index=*/{});
297 points_to_set.AddPointedToBuffer(
298 logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}),
299 /*index=*/{1});
300
301 const PointsToSet& operand_points_to_set =
302 GetPointsToSet(recv_done->operand(0));
303
304 // Recursively copy the points to set of the operand tuple {0} to the output
305 // element {0}.
306 points_to_set.ForEachMutableElement(
307 [&points_to_set, &operand_points_to_set](
308 const ShapeIndex& index, PointsToSet::BufferList* buffers) {
309 if (index.empty() || index[0] != 0) {
310 return;
311 }
312 *buffers = operand_points_to_set.element(index);
313 for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) {
314 points_to_set.add_tuple_source(index, tuple_source);
315 }
316 });
317 return Status::OK();
318 }
319
HandleCopyStart(HloInstruction * copy_start)320 Status TuplePointsToAnalysis::HandleCopyStart(HloInstruction* copy_start) {
321 // CopyStart forwards its aliased operand to {1}.
322 PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_start);
323 const PointsToSet& operand_points_to_set =
324 GetPointsToSet(copy_start->operand(0));
325
326 points_to_set.ForEachMutableElement(
327 [&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) {
328 if (target_index == ShapeIndex({1})) {
329 *buffers = operand_points_to_set.element(/*index=*/{});
330 } else {
331 buffers->push_back(
332 &logical_buffer_analysis_->GetBuffer(copy_start, target_index));
333 }
334 });
335
336 for (HloInstruction* tuple :
337 operand_points_to_set.tuple_sources(/*index=*/{})) {
338 points_to_set.add_tuple_source(/*index=*/{1}, tuple);
339 }
340
341 return Status::OK();
342 }
343
HandleCopyDone(HloInstruction * copy_done)344 Status TuplePointsToAnalysis::HandleCopyDone(HloInstruction* copy_done) {
345 // CopyDone forwards its aliased operand.
346 PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_done);
347 const PointsToSet& operand_points_to_set =
348 GetPointsToSet(copy_done->operand(0));
349 operand_points_to_set.ForEachElement(
350 [&points_to_set, &operand_points_to_set](
351 const ShapeIndex& src_index,
352 const PointsToSet::BufferList& points_to) {
353 if (src_index == ShapeIndex({0})) {
354 const ShapeIndex target_index = {};
355 *points_to_set.mutable_element(target_index) = points_to;
356
357 for (HloInstruction* tuple :
358 operand_points_to_set.tuple_sources(src_index)) {
359 points_to_set.add_tuple_source(target_index, tuple);
360 }
361 }
362 });
363
364 return Status::OK();
365 }
366
HandleSend(HloInstruction * send)367 Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
368 // Send creates a tuple of {aliased operand, U32 context, token}.
369 PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
370
371 // Creates the points to set for the tuple and its element at {1}.
372 auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
373 top_buffer->push_back(
374 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
375 points_to_set.add_tuple_source({}, send);
376
377 auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
378 context_buffer->push_back(
379 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
380
381 auto token_buffer = points_to_set.mutable_element(ShapeIndex({2}));
382 token_buffer->push_back(
383 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2})));
384
385 // Recursively copy the points to set of the operand to output tuple {0}.
386 const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
387 operand_points_to_set.ForEachElement(
388 [&points_to_set, &operand_points_to_set](
389 const ShapeIndex& src_index,
390 const PointsToSet::BufferList& points_to) {
391 ShapeIndex target_index({0});
392 for (auto element : src_index) {
393 target_index.push_back(element);
394 }
395 *points_to_set.mutable_element(target_index) = points_to;
396
397 for (HloInstruction* tuple :
398 operand_points_to_set.tuple_sources(src_index)) {
399 points_to_set.add_tuple_source(target_index, tuple);
400 }
401 });
402
403 return Status::OK();
404 }
405
HandleTuple(HloInstruction * tuple)406 Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
407 absl::Span<HloInstruction* const> operands(tuple->operands());
408 PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
409 points_to_set.AddPointedToBuffer(
410 logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}),
411 /*index=*/{});
412
413 // A tuple contains references to all input operands and transitively any
414 // references in those operands.
415 for (int64 i = 0; i < operands.size(); ++i) {
416 const PointsToSet& operand_points_to_set =
417 *PerInst(operands[i])->points_to_set;
418
419 // Copy the points-to set (and tuple sources) of the operand into the
420 // respective subtree of the tuple instructions points-to set.
421 operand_points_to_set.ForEachElement(
422 [&points_to_set, &operand_points_to_set, i](
423 const ShapeIndex& src_index,
424 const PointsToSet::BufferList& points_to) {
425 ShapeIndex target_index;
426 target_index.push_back(i);
427 for (auto element : src_index) {
428 target_index.push_back(element);
429 }
430
431 *points_to_set.mutable_element(target_index) = points_to;
432
433 for (HloInstruction* tuple :
434 operand_points_to_set.tuple_sources(src_index)) {
435 points_to_set.add_tuple_source(target_index, tuple);
436 }
437 });
438 }
439
440 points_to_set.add_tuple_source({}, tuple);
441
442 return Status::OK();
443 }
444
HandleTupleSelect(HloInstruction * tuple_select)445 Status TuplePointsToAnalysis::HandleTupleSelect(HloInstruction* tuple_select) {
446 // Select allocates a new buffer and then shallow copies the on_true or
447 // on_false buffer into this new buffer. Which side is chosen cannot be
448 // determined statically so conservatively set the points-to set to the union
449 // of these on_true and on_false operands.
450 //
451 // First create a copy of the on_true points-to set (and tuple sources), then
452 // add in elements of the on_false points-to set (tuple sources).
453 auto on_true = tuple_select->operand(1);
454 auto on_false = tuple_select->operand(2);
455 PointsToSet& points_to_set = CreateCopiedPointsToSet(tuple_select, on_true);
456 const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set;
457 points_to_set.ForEachMutableElement(
458 [&](const ShapeIndex& index, PointsToSet::BufferList* buffers) {
459 for (const LogicalBuffer* false_buffer :
460 false_points_to_set.element(index)) {
461 points_to_set.AddPointedToBuffer(*false_buffer, index);
462 }
463
464 for (HloInstruction* tuple : false_points_to_set.tuple_sources(index)) {
465 points_to_set.add_tuple_source(index, tuple);
466 }
467 });
468
469 // Select creates a new (top-level) buffer to store its result, so its
470 // respective element in the points-to set should contain only itself.
471 points_to_set.mutable_element({})->clear();
472 points_to_set.AddPointedToBuffer(
473 logical_buffer_analysis_->GetBuffer(tuple_select, /*index=*/{}),
474 /*index=*/{});
475 return Status::OK();
476 }
477
HandleCustomCall(HloInstruction * custom_call)478 Status TuplePointsToAnalysis::HandleCustomCall(HloInstruction* custom_call) {
479 auto ccall = Cast<HloCustomCallInstruction>(custom_call);
480 PointsToSet& points_to_set = CreateEmptyPointsToSet(custom_call);
481 absl::flat_hash_map<ShapeIndex, std::pair<int64, ShapeIndex>> aliased_outputs;
482 for (const auto& pair : ccall->output_to_operand_aliasing()) {
483 aliased_outputs.emplace(pair.first, pair.second);
484 }
485 points_to_set.ForEachMutableElement([&](const ShapeIndex& index,
486 PointsToSet::BufferList* buffers) {
487 auto it = aliased_outputs.find(index);
488 if (it == aliased_outputs.end()) {
489 points_to_set.AddPointedToBuffer(
490 logical_buffer_analysis_->GetBuffer(custom_call, index), index);
491 } else {
492 const PointsToSet& input_set =
493 *PerInst(ccall->operand(it->second.first))->points_to_set;
494 for (const LogicalBuffer* input_buffer :
495 input_set.element(it->second.second)) {
496 points_to_set.AddPointedToBuffer(*input_buffer, index);
497 }
498
499 for (HloInstruction* tuple : input_set.tuple_sources(it->second.second)) {
500 points_to_set.add_tuple_source(index, tuple);
501 }
502 }
503 });
504 points_to_set.add_tuple_source({}, custom_call);
505 return Status::OK();
506 }
507
GetPointsToSet(const HloInstruction * hlo_instruction) const508 const PointsToSet& TuplePointsToAnalysis::GetPointsToSet(
509 const HloInstruction* hlo_instruction) const {
510 return *PerInst(hlo_instruction)->points_to_set;
511 }
512
CreateEmptyPointsToSet(const HloInstruction * instruction)513 PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
514 const HloInstruction* instruction) {
515 PerInstruction* pi = PerInst(instruction);
516 CHECK(pi->points_to_set == nullptr)
517 << "instruction should not have been present in the map.";
518 auto set = absl::make_unique<PointsToSet>(&instruction->shape());
519 pi->points_to_set = std::move(set);
520 // Return *set using the iterator returned by emplace.
521 return *pi->points_to_set;
522 }
523
InstructionDefinesBufferAtIndex(const HloInstruction * instruction,const ShapeIndex & index) const524 bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
525 const HloInstruction* instruction, const ShapeIndex& index) const {
526 const auto& buffers = GetPointsToSet(instruction).element(index);
527 return (buffers.size() == 1 && buffers[0]->instruction() == instruction);
528 }
529
VerifyBuffer(const LogicalBuffer & buffer) const530 Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
531 if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) {
532 return FailedPrecondition(
533 "LogicalBuffer %s is ill-defined: instruction %s does not define a "
534 "buffer at that index",
535 buffer.ToString(), buffer.instruction()->name());
536 }
537
538 if (buffer.id() < 0 ||
539 buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) {
540 return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d",
541 buffer.ToString(), buffer.id());
542 }
543 if (GetBuffer(buffer.id()).instruction() != buffer.instruction() ||
544 GetBuffer(buffer.id()).index() != buffer.index()) {
545 return FailedPrecondition(
546 "LogicalBuffer %s is ill-defined: buffer with same id differs: %s",
547 buffer.ToString(), GetBuffer(buffer.id()).ToString());
548 }
549
550 return Status::OK();
551 }
552
GetBuffer(LogicalBuffer::Id id) const553 const LogicalBuffer& TuplePointsToAnalysis::GetBuffer(
554 LogicalBuffer::Id id) const {
555 CHECK_GE(id, 0);
556 CHECK_LT(id, logical_buffer_analysis_->num_logical_buffers());
557 return logical_buffer_analysis_->GetBuffer(id);
558 }
559
GetBufferDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const560 StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
561 const HloInstruction* instruction, const ShapeIndex& index) const {
562 const auto& buffers = GetPointsToSet(instruction).element(index);
563 if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
564 return FailedPrecondition(
565 "instruction %s does not define buffer at index {%s}",
566 instruction->name(), absl::StrJoin(index, ","));
567 }
568 return buffers[0];
569 }
570
571 const TuplePointsToAnalysis::BufferAliasVector&
GetBufferAliases(const LogicalBuffer & buffer) const572 TuplePointsToAnalysis::GetBufferAliases(const LogicalBuffer& buffer) const {
573 return logical_buffer_aliases_.at(buffer.id());
574 }
575
576 const TuplePointsToAnalysis::BufferDefinitionVector&
GetBuffersDefinedByInstruction(const HloInstruction * instruction) const577 TuplePointsToAnalysis::GetBuffersDefinedByInstruction(
578 const HloInstruction* instruction) const {
579 return PerInst(instruction)->instruction_defined_buffers;
580 }
581
GatherBuffersDefinedByInstruction(const HloInstruction * instruction,TuplePointsToAnalysis::BufferDefinitionVector * buffers)582 Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
583 const HloInstruction* instruction,
584 TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
585 GetPointsToSet(instruction)
586 .ForEachElement([buffers, instruction](
587 const ShapeIndex& index,
588 const PointsToSet::BufferList& source_buffers) {
589 // Add buffers which 'instruction' is the source of.
590 CHECK(!source_buffers.empty());
591 if (source_buffers.size() == 1 &&
592 source_buffers[0]->instruction() == instruction) {
593 // If this instruction is the source of this buffer the
594 // indices must match.
595 DCHECK(source_buffers[0]->index() == index);
596 buffers->push_back(source_buffers[0]);
597 } else {
598 // If the points-to set includes more than one buffer then
599 // necessarily this instruction did not produce the
600 // buffer.
601 for (const LogicalBuffer* source_buffer : source_buffers) {
602 DCHECK(source_buffer->instruction() != instruction);
603 }
604 }
605 });
606 return Status::OK();
607 }
608
CreateCopiedPointsToSet(const HloInstruction * instruction,const HloInstruction * src)609 PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
610 const HloInstruction* instruction, const HloInstruction* src) {
611 // PointsToSet doesn't have a copy constructor so copy over element-by-element
612 // from src PointsToSet.
613 PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
614 const PointsToSet& src_points_to_set = GetPointsToSet(src);
615 dst_points_to_set.ForEachMutableElement(
616 [&dst_points_to_set, &src_points_to_set](
617 const ShapeIndex& index, PointsToSet::BufferList* buffers) {
618 *buffers = src_points_to_set.element(index);
619 for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
620 dst_points_to_set.add_tuple_source(index, tuple_source);
621 }
622 });
623 return *PerInst(instruction)->points_to_set;
624 }
625
ToString() const626 string TuplePointsToAnalysis::ToString() const {
627 string output =
628 absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name());
629 for (const auto* computation : module_->MakeNonfusionComputations()) {
630 const char* entry =
631 computation == module_->entry_computation() ? "entry " : "";
632 absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n");
633 for (const HloInstruction* instruction :
634 computation->MakeInstructionPostOrder()) {
635 InstructionToString(instruction, &output);
636 if (instruction->opcode() == HloOpcode::kFusion) {
637 for (auto* fused : instruction->fused_instructions()) {
638 InstructionToString(fused, &output);
639 }
640 }
641 }
642 }
643
644 absl::StrAppend(&output, "LogicalBuffers:\n");
645 for (const auto& b : logical_buffer_analysis_->logical_buffers()) {
646 absl::StrAppend(&output, " buffer ", b->ToString(), ":\n");
647 for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) {
648 absl::StrAppend(&output, " alias ", alias.ToString(), "\n");
649 }
650 }
651 return output;
652 }
653
InstructionToString(const HloInstruction * instruction,string * output) const654 void TuplePointsToAnalysis::InstructionToString(
655 const HloInstruction* instruction, string* output) const {
656 const string prefix = instruction->IsFused() ? " " : "";
657 absl::StrAppend(output, prefix, " instruction ",
658 instruction->ToShortString(), ":\n");
659 const PointsToSet& points_to_set = GetPointsToSet(instruction);
660 points_to_set.ForEachElement([&prefix, &output](
661 const ShapeIndex& index,
662 const PointsToSet::BufferList& points_to) {
663 absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ",
664 absl::StrJoin(points_to, ", ",
665 [](string* out, const LogicalBuffer* source) {
666 out->append(source->ToString());
667 }),
668 "\n");
669 });
670 }
671
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const672 bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
673 const HloInstruction* operand, const ShapeIndex& index,
674 const HloInstruction* user) const {
675 CHECK(user->IsUserOf(operand))
676 << "user: " << user->ToString() << " operand: " << operand->ToString();
677 if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
678 // GetTupleElement instructions only access the top-level buffer of their
679 // operand.
680 return true;
681 } else if (user->IsLoopFusion()) {
682 // Find fusion parameter associated with 'operand'.
683 auto it = absl::c_find_if(
684 user->fused_parameters(), [&](HloInstruction* fused_param) {
685 return user->operand(fused_param->parameter_number()) == operand;
686 });
687 CHECK(it != user->fused_parameters().end());
688 // Iterate through all users of all buffer aliases of the buffer in the
689 // points-to set of fusion parameter at 'index'.
690 // Return false if any uses are detected at 'index', returns true otherwise.
691 const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie();
692 for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
693 for (HloInstruction* alias_user : alias.instruction()->users()) {
694 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
695 alias_user)) {
696 continue;
697 }
698 // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
699 return false;
700 }
701 }
702 // Return true: found no uses of 'operand' at 'index' in 'user'.
703 return true;
704 }
705 return false;
706 }
707
708 // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
709 // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
710 // where 'user' is a user of an alias of 'instruction' at 'index', and
711 // 'operand_index' is the operand index at which the alias appears in the
712 // operand list of 'user'.
713 std::vector<std::pair<HloInstruction*, int64>>
GetAllUsesOfInstructionAtIndex(HloInstruction * instruction,const ShapeIndex & index) const714 TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex(
715 HloInstruction* instruction, const ShapeIndex& index) const {
716 std::vector<std::pair<HloInstruction*, int64>> uses;
717 const PointsToSet::BufferList& points_to =
718 GetPointsToSet(instruction).element(index);
719 for (const LogicalBuffer* buffer : points_to) {
720 for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
721 for (HloInstruction* alias_user : alias.instruction()->users()) {
722 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
723 alias_user)) {
724 continue;
725 }
726 for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
727 uses.emplace_back(alias_user, op_idx);
728 }
729 }
730 }
731 }
732 return uses;
733 }
734
735 // Returns true if there is exactly one use of 'operand' at 'operand_index'
736 // in 'fusion.fused_instructions', where the singleton use is the fused
737 // root at operand index 'use_operand_index'. Returns false otherwise.
738 //
739 // REQUIRES: 'fusion' opcode is a kFusion instruction.
HasUniqueFusedUseOfOperandAt(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * fusion,const int64 use_operand_index) const740 bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
741 HloInstruction* operand, const ShapeIndex& operand_index,
742 HloInstruction* fusion, const int64 use_operand_index) const {
743 CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
744 // Check that 'operand' is unique in the operand list of 'fusion'.
745 if (fusion->OperandIndices(operand).size() > 1) {
746 return false;
747 }
748 // Find fusion parameter associated with 'operand'.
749 const auto& fused_params = fusion->fused_parameters();
750 auto fused_param_it =
751 absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
752 return fusion->operand(fused_param->parameter_number()) == operand;
753 });
754 if (fused_param_it == fused_params.end()) {
755 return false;
756 }
757 auto* fused_param = *fused_param_it;
758 // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'.
759 auto fused_param_uses =
760 GetAllUsesOfInstructionAtIndex(fused_param, operand_index);
761 // Return true iff there is exactly one use of 'operand' at 'index', and
762 // this singleton use is the fused root (at index in 'use_operand_indices').
763 return fused_param_uses.size() == 1 &&
764 fused_param_uses[0].first == fusion->fused_expression_root() &&
765 fused_param_uses[0].second == use_operand_index;
766 }
767 } // namespace xla
768