1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
17
18 #include <memory>
19 #include <ostream>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_format.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/map_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
30 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/logging.h"
38
39 namespace xla {
40
ToString() const41 std::string BufferAlias::ToString() const {
42 return absl::StrCat("BufferAlias(", instruction_->name(), "[",
43 absl::StrJoin(index_, ","), "])");
44 }
45
operator <<(std::ostream & out,const BufferAlias & buffer_alias)46 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
47 out << buffer_alias.ToString();
48 return out;
49 }
50
IsAmbiguous() const51 bool PointsToSet::IsAmbiguous() const {
52 bool ambiguous = false;
53 ForEachElement(
54 [&ambiguous](const ShapeIndex& /*index*/, const BufferList& points_to) {
55 ambiguous |= points_to.size() > 1;
56 });
57 return ambiguous;
58 }
59
IsDistinct() const60 bool PointsToSet::IsDistinct() const {
61 bool distinct = true;
62 absl::flat_hash_set<const LogicalBuffer*> all_points_to;
63 ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) {
64 for (auto& buffer : points_to) {
65 if (all_points_to.contains(buffer)) {
66 distinct = false;
67 }
68 all_points_to.insert(buffer);
69 }
70 });
71 return distinct;
72 }
73
size() const74 size_t PointsToSet::size() const {
75 // Because pointed-to elements may be duplicated we have to create a flattened
76 // set and return the size.
77 return CreateFlattenedSet().size();
78 }
79
CreateFlattenedSet() const80 PointsToSet::BufferSet PointsToSet::CreateFlattenedSet() const {
81 BufferSet flat_set;
82 ForEachElement(
83 [&flat_set](const ShapeIndex& /*index*/, const BufferList& buffers) {
84 flat_set.insert(buffers.begin(), buffers.end());
85 });
86 return flat_set;
87 }
88
ContainsBuffer(const LogicalBuffer & buffer) const89 bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
90 bool found = false;
91 ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
92 const BufferList& pointed_to_buffers) {
93 if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
94 found = true;
95 }
96 });
97 return found;
98 }
99
ContainsBufferAtIndex(const LogicalBuffer & buffer,const ShapeIndex & index) const100 bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
101 const ShapeIndex& index) const {
102 const auto& pointed_to_buffers = element(index);
103 return absl::c_linear_search(pointed_to_buffers, &buffer);
104 }
105
AddPointedToBuffer(const LogicalBuffer & buffer,const ShapeIndex & index)106 void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
107 const ShapeIndex& index) {
108 if (ContainsBufferAtIndex(buffer, index)) {
109 return;
110 }
111 mutable_element(index)->push_back(&buffer);
112 }
113
tuple_sources(const ShapeIndex & index) const114 const PointsToSet::SourceSet& PointsToSet::tuple_sources(
115 const ShapeIndex& index) const {
116 return tree_.element(index).tuple_sources;
117 }
118
add_tuple_source(const ShapeIndex & index,HloInstruction * tuple)119 void PointsToSet::add_tuple_source(const ShapeIndex& index,
120 HloInstruction* tuple) {
121 tree_.mutable_element(index)->tuple_sources.insert(tuple);
122 }
123
124 namespace {
125 // Gather fusion instructions from 'instruction' into 'fusion_instructions'.
GatherFusionInstructions(HloInstruction * instruction,std::vector<HloInstruction * > * fusion_instructions)126 void GatherFusionInstructions(
127 HloInstruction* instruction,
128 std::vector<HloInstruction*>* fusion_instructions) {
129 CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
130 for (auto* fused : instruction->fused_instructions()) {
131 if (fused->opcode() == HloOpcode::kFusion) {
132 GatherFusionInstructions(fused, fusion_instructions);
133 }
134 }
135 fusion_instructions->push_back(instruction);
136 }
137
138 } // namespace
139
140 /* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
Run(const HloModule * module)141 TuplePointsToAnalysis::Run(const HloModule* module) {
142 auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module);
143 std::unique_ptr<TuplePointsToAnalysis> analysis(new TuplePointsToAnalysis(
144 module, std::move(logical_buffer_analysis).value()));
145 TF_RETURN_IF_ERROR(analysis->Analyze());
146 return std::move(analysis);
147 }
148
Analyze()149 Status TuplePointsToAnalysis::Analyze() {
150 per_instruction_.clear();
151 per_instruction_.reserve(module_->instruction_count());
152
153 logical_buffer_aliases_.clear();
154 logical_buffer_aliases_.resize(
155 logical_buffer_analysis_->num_logical_buffers());
156
157 std::vector<HloInstruction*> fusion_instructions;
158 for (auto* computation : module_->MakeNonfusionComputations()) {
159 TF_RETURN_IF_ERROR(computation->Accept(this));
160 TF_RETURN_IF_ERROR(
161 PopulateDefinedBuffersAndAliases(computation->instructions()));
162 for (auto* instruction : computation->instructions()) {
163 if (instruction->opcode() == HloOpcode::kFusion) {
164 GatherFusionInstructions(instruction, &fusion_instructions);
165 }
166 }
167 }
168 // Run points-to analysis on fusion instructions in 'computation'.
169 for (auto* instruction : fusion_instructions) {
170 TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
171 TF_RETURN_IF_ERROR(
172 PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
173 }
174
175 XLA_VLOG_LINES(3, ToString());
176
177 return OkStatus();
178 }
179
180 Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(
181 const decltype(std::declval<HloComputation>()
182 .instructions())& instructions) {
183 for (auto* instruction : instructions) {
184 PerInstruction* pi = PerInst(instruction);
185 TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction(
186 instruction, &pi->instruction_defined_buffers));
187
188 const PointsToSet& points_to_set = GetPointsToSet(instruction);
189 points_to_set.ForEachElement(
190 [this, &instruction](
191 const ShapeIndex& index,
__anon819384570602( const ShapeIndex& index, const PointsToSet::BufferList& pointed_to_buffers) 192 const PointsToSet::BufferList& pointed_to_buffers) {
193 for (const LogicalBuffer* buffer : pointed_to_buffers) {
194 logical_buffer_aliases_[buffer->id()].emplace_back(instruction,
195 index);
196 }
197 });
198 }
199 return OkStatus();
200 }
201
DefaultAction(HloInstruction * hlo_instruction)202 Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
203 // Create trivial points-to set for instruction. Each points-to set at index i
204 // contains a single element LogicalBuffer(hlo_instruction, i). This indicates
205 // that this instruction is the source of all buffers in its own output.
206 PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction);
207 points_to_set.ForEachMutableElement(
208 [this, hlo_instruction](const ShapeIndex& index,
209 PointsToSet::BufferList* buffers) {
210 buffers->push_back(
211 &logical_buffer_analysis_->GetBuffer(hlo_instruction, index));
212 });
213
214 if (hlo_instruction->shape().IsTuple()) {
215 // If the hlo instruction is a tuple-shaped, then trivially the instruction
216 // itself is the source of the tuple.
217 points_to_set.add_tuple_source({}, hlo_instruction);
218 }
219
220 return OkStatus();
221 }
222
HandleGetTupleElement(HloInstruction * get_tuple_element)223 Status TuplePointsToAnalysis::HandleGetTupleElement(
224 HloInstruction* get_tuple_element) {
225 // GetTupleElement forwards a pointer to a particular element of the tuple
226 // operand.
227 int64_t element_index = get_tuple_element->tuple_index();
228
229 PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element);
230 const PointsToSet& operand_points_to_set =
231 *PerInst(get_tuple_element->operand(0))->points_to_set;
232
233 // Copy the points-to set (and tuple sources) at index {element_index} of the
234 // operand to the points-to set for this GetTupleElement instruction.
235 points_to_set.ForEachMutableElement(
236 [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) {
237 // Construct an index into the operand by prepending element_index to
238 // the index for the GetTupleElement instruction's points-to set.
239 ShapeIndex src_index;
240 src_index.push_back(element_index);
241 for (auto element : target_index) {
242 src_index.push_back(element);
243 }
244
245 *points_to = operand_points_to_set.element(src_index);
246 for (HloInstruction* tuple :
247 operand_points_to_set.tuple_sources(src_index)) {
248 points_to_set.add_tuple_source(target_index, tuple);
249 }
250 });
251
252 return OkStatus();
253 }
254
HandleCopy(HloInstruction * copy)255 Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) {
256 // A kCopy instruction performs a shallow copy of the operand. The top-level
257 // buffer (index={}) is newly created, but all other buffers (in the case of a
258 // tuple shape) come from the operand
259 PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0));
260 points_to_set.mutable_element(/*index=*/{})->clear();
261 points_to_set.AddPointedToBuffer(
262 logical_buffer_analysis_->GetBuffer(copy, /*index=*/{}),
263 /*index=*/{});
264
265 return OkStatus();
266 }
267
HandleBitcast(HloInstruction * bitcast)268 Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
269 // A kBitcast instruction aliases its operand. That is, the buffer of its
270 // result *is* the buffer of its operand, so just copy the operands points-to
271 // set.
272 CreateCopiedPointsToSet(bitcast, bitcast->operand(0));
273 return OkStatus();
274 }
275
HandleDomain(HloInstruction * domain)276 Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) {
277 // A kDomain instruction aliases its operand. That is, the buffer of its
278 // result *is* the buffer of its operand, so just copy the operands points-to
279 // set.
280 CreateCopiedPointsToSet(domain, domain->operand(0));
281 return OkStatus();
282 }
283
HandleAddDependency(HloInstruction * add_dependency)284 Status TuplePointsToAnalysis::HandleAddDependency(
285 HloInstruction* add_dependency) {
286 // AddDependency just forwards the value of its zero-th operand.
287 CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0));
288 return OkStatus();
289 }
290
HandleRecvDone(HloInstruction * recv_done)291 Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
292 // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
293 // output. The other indices ({} and {1}) define their own buffers.
294 PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
295 points_to_set.AddPointedToBuffer(
296 logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}),
297 /*index=*/{});
298 points_to_set.AddPointedToBuffer(
299 logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}),
300 /*index=*/{1});
301
302 const PointsToSet& operand_points_to_set =
303 GetPointsToSet(recv_done->operand(0));
304
305 // Recursively copy the points to set of the operand tuple {0} to the output
306 // element {0}.
307 points_to_set.ForEachMutableElement(
308 [&points_to_set, &operand_points_to_set](
309 const ShapeIndex& index, PointsToSet::BufferList* buffers) {
310 if (index.empty() || index[0] != 0) {
311 return;
312 }
313 *buffers = operand_points_to_set.element(index);
314 for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) {
315 points_to_set.add_tuple_source(index, tuple_source);
316 }
317 });
318 return OkStatus();
319 }
320
HandleAsyncStart(HloInstruction * async_start)321 Status TuplePointsToAnalysis::HandleAsyncStart(HloInstruction* async_start) {
322 // AsyncStart forwards its aliased operands to {0}.
323 PointsToSet& points_to_set = CreateEmptyPointsToSet(async_start);
324
325 points_to_set.ForEachMutableElement(
326 [&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) {
327 if (target_index.size() >= 2 && target_index.front() == 0) {
328 const PointsToSet& operand_points_to_set =
329 GetPointsToSet(async_start->operand(target_index.at(1)));
330 ShapeIndex source_index(target_index.begin() + 2, target_index.end());
331 *buffers = operand_points_to_set.element(source_index);
332 for (HloInstruction* tuple :
333 operand_points_to_set.tuple_sources(source_index)) {
334 points_to_set.add_tuple_source(target_index, tuple);
335 }
336 } else {
337 buffers->push_back(
338 &logical_buffer_analysis_->GetBuffer(async_start, target_index));
339 }
340 });
341
342 return OkStatus();
343 }
344
HandleAsyncUpdate(HloInstruction * async_update)345 Status TuplePointsToAnalysis::HandleAsyncUpdate(HloInstruction* async_update) {
346 // AsyncUpdate forwards its aliased operand to {}.
347 PointsToSet& points_to_set = CreateEmptyPointsToSet(async_update);
348 const PointsToSet& operand_points_to_set =
349 GetPointsToSet(async_update->operand(0));
350 CHECK_EQ(async_update->shape(), async_update->operand(0)->shape());
351
352 points_to_set.ForEachMutableElement([&](const ShapeIndex& index,
353 PointsToSet::BufferList* buffers) {
354 *buffers = operand_points_to_set.element(index);
355 for (HloInstruction* tuple : operand_points_to_set.tuple_sources(index)) {
356 points_to_set.add_tuple_source(index, tuple);
357 }
358 });
359
360 return OkStatus();
361 }
362
HandleAsyncDone(HloInstruction * async_done)363 Status TuplePointsToAnalysis::HandleAsyncDone(HloInstruction* async_done) {
364 // AsyncDone forwards its aliased operand.
365 PointsToSet& points_to_set = CreateEmptyPointsToSet(async_done);
366 const PointsToSet& operand_points_to_set =
367 GetPointsToSet(async_done->operand(0));
368 operand_points_to_set.ForEachElement(
369 [&points_to_set, &operand_points_to_set](
370 const ShapeIndex& src_index,
371 const PointsToSet::BufferList& points_to) {
372 if (!src_index.empty() && src_index.front() == 1) {
373 const ShapeIndex target_index(src_index.begin() + 1, src_index.end());
374 *points_to_set.mutable_element(target_index) = points_to;
375
376 for (HloInstruction* tuple :
377 operand_points_to_set.tuple_sources(src_index)) {
378 points_to_set.add_tuple_source(target_index, tuple);
379 }
380 }
381 });
382
383 return OkStatus();
384 }
385
HandleCopyStart(HloInstruction * copy_start)386 Status TuplePointsToAnalysis::HandleCopyStart(HloInstruction* copy_start) {
387 // CopyStart forwards its aliased operand to {1}.
388 PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_start);
389 const PointsToSet& operand_points_to_set =
390 GetPointsToSet(copy_start->operand(0));
391
392 points_to_set.ForEachMutableElement(
393 [&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) {
394 if (target_index == ShapeIndex({1})) {
395 *buffers = operand_points_to_set.element(/*index=*/{});
396 } else {
397 buffers->push_back(
398 &logical_buffer_analysis_->GetBuffer(copy_start, target_index));
399 }
400 });
401
402 for (HloInstruction* tuple :
403 operand_points_to_set.tuple_sources(/*index=*/{})) {
404 points_to_set.add_tuple_source(/*index=*/{1}, tuple);
405 }
406
407 return OkStatus();
408 }
409
HandleCopyDone(HloInstruction * copy_done)410 Status TuplePointsToAnalysis::HandleCopyDone(HloInstruction* copy_done) {
411 // CopyDone forwards its aliased operand.
412 PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_done);
413 const PointsToSet& operand_points_to_set =
414 GetPointsToSet(copy_done->operand(0));
415 operand_points_to_set.ForEachElement(
416 [&points_to_set, &operand_points_to_set](
417 const ShapeIndex& src_index,
418 const PointsToSet::BufferList& points_to) {
419 if (src_index == ShapeIndex({0})) {
420 const ShapeIndex target_index = {};
421 *points_to_set.mutable_element(target_index) = points_to;
422
423 for (HloInstruction* tuple :
424 operand_points_to_set.tuple_sources(src_index)) {
425 points_to_set.add_tuple_source(target_index, tuple);
426 }
427 }
428 });
429
430 return OkStatus();
431 }
432
HandleSend(HloInstruction * send)433 Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
434 // Send creates a tuple of {aliased operand, U32 context, token}.
435 PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
436
437 // Creates the points to set for the tuple and its element at {1}.
438 auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
439 top_buffer->push_back(
440 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
441 points_to_set.add_tuple_source({}, send);
442
443 auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
444 context_buffer->push_back(
445 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
446
447 auto token_buffer = points_to_set.mutable_element(ShapeIndex({2}));
448 token_buffer->push_back(
449 &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2})));
450
451 // Recursively copy the points to set of the operand to output tuple {0}.
452 const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
453 operand_points_to_set.ForEachElement(
454 [&points_to_set, &operand_points_to_set](
455 const ShapeIndex& src_index,
456 const PointsToSet::BufferList& points_to) {
457 ShapeIndex target_index({0});
458 for (auto element : src_index) {
459 target_index.push_back(element);
460 }
461 *points_to_set.mutable_element(target_index) = points_to;
462
463 for (HloInstruction* tuple :
464 operand_points_to_set.tuple_sources(src_index)) {
465 points_to_set.add_tuple_source(target_index, tuple);
466 }
467 });
468
469 return OkStatus();
470 }
471
HandleTuple(HloInstruction * tuple)472 Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
473 absl::Span<HloInstruction* const> operands(tuple->operands());
474 PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
475 points_to_set.AddPointedToBuffer(
476 logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}),
477 /*index=*/{});
478
479 // A tuple contains references to all input operands and transitively any
480 // references in those operands.
481 for (int64_t i = 0; i < operands.size(); ++i) {
482 const PointsToSet& operand_points_to_set =
483 *PerInst(operands[i])->points_to_set;
484
485 // Copy the points-to set (and tuple sources) of the operand into the
486 // respective subtree of the tuple instructions points-to set.
487 operand_points_to_set.ForEachElement(
488 [&points_to_set, &operand_points_to_set, i](
489 const ShapeIndex& src_index,
490 const PointsToSet::BufferList& points_to) {
491 ShapeIndex target_index;
492 target_index.push_back(i);
493 for (auto element : src_index) {
494 target_index.push_back(element);
495 }
496
497 *points_to_set.mutable_element(target_index) = points_to;
498
499 for (HloInstruction* tuple :
500 operand_points_to_set.tuple_sources(src_index)) {
501 points_to_set.add_tuple_source(target_index, tuple);
502 }
503 });
504 }
505
506 points_to_set.add_tuple_source({}, tuple);
507
508 return OkStatus();
509 }
510
HandleCustomCall(HloInstruction * custom_call)511 Status TuplePointsToAnalysis::HandleCustomCall(HloInstruction* custom_call) {
512 auto ccall = Cast<HloCustomCallInstruction>(custom_call);
513 PointsToSet& points_to_set = CreateEmptyPointsToSet(custom_call);
514 absl::flat_hash_map<ShapeIndex, std::pair<int64_t, ShapeIndex>>
515 aliased_outputs;
516 for (const auto& pair : ccall->output_to_operand_aliasing()) {
517 aliased_outputs.emplace(pair.first, pair.second);
518 }
519 points_to_set.ForEachMutableElement([&](const ShapeIndex& index,
520 PointsToSet::BufferList* buffers) {
521 auto it = aliased_outputs.find(index);
522 if (it == aliased_outputs.end()) {
523 points_to_set.AddPointedToBuffer(
524 logical_buffer_analysis_->GetBuffer(custom_call, index), index);
525 } else {
526 const PointsToSet& input_set =
527 *PerInst(ccall->operand(it->second.first))->points_to_set;
528 for (const LogicalBuffer* input_buffer :
529 input_set.element(it->second.second)) {
530 points_to_set.AddPointedToBuffer(*input_buffer, index);
531 }
532
533 for (HloInstruction* tuple : input_set.tuple_sources(it->second.second)) {
534 points_to_set.add_tuple_source(index, tuple);
535 }
536 }
537 });
538 points_to_set.add_tuple_source({}, custom_call);
539 return OkStatus();
540 }
541
HandleOptimizationBarrier(HloInstruction * barrier)542 Status TuplePointsToAnalysis::HandleOptimizationBarrier(
543 HloInstruction* barrier) {
544 // A kOptimizationBarrier instruction is a no-op.
545 CreateCopiedPointsToSet(barrier, barrier->operand(0));
546 return OkStatus();
547 }
548
GetPointsToSet(const HloInstruction * hlo_instruction) const549 const PointsToSet& TuplePointsToAnalysis::GetPointsToSet(
550 const HloInstruction* hlo_instruction) const {
551 return *PerInst(hlo_instruction)->points_to_set;
552 }
553
CreateEmptyPointsToSet(const HloInstruction * instruction)554 PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
555 const HloInstruction* instruction) {
556 PerInstruction* pi = PerInst(instruction);
557 CHECK(pi->points_to_set == nullptr)
558 << "instruction should not have been present in the map.";
559 auto set = std::make_unique<PointsToSet>(&instruction->shape());
560 pi->points_to_set = std::move(set);
561 // Return *set using the iterator returned by emplace.
562 return *pi->points_to_set;
563 }
564
InstructionDefinesBufferAtIndex(const HloInstruction * instruction,const ShapeIndex & index) const565 bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
566 const HloInstruction* instruction, const ShapeIndex& index) const {
567 const auto& buffers = GetPointsToSet(instruction).element(index);
568 return (buffers.size() == 1 && buffers[0]->instruction() == instruction);
569 }
570
VerifyBuffer(const LogicalBuffer & buffer) const571 Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
572 if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) {
573 return FailedPrecondition(
574 "LogicalBuffer %s is ill-defined: instruction %s does not define a "
575 "buffer at that index",
576 buffer.ToString(), buffer.instruction()->name());
577 }
578
579 if (buffer.id() < 0 ||
580 buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) {
581 return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d",
582 buffer.ToString(), buffer.id());
583 }
584 if (GetBuffer(buffer.id()).instruction() != buffer.instruction() ||
585 GetBuffer(buffer.id()).index() != buffer.index()) {
586 return FailedPrecondition(
587 "LogicalBuffer %s is ill-defined: buffer with same id differs: %s",
588 buffer.ToString(), GetBuffer(buffer.id()).ToString());
589 }
590
591 return OkStatus();
592 }
593
GetBuffer(LogicalBuffer::Id id) const594 const LogicalBuffer& TuplePointsToAnalysis::GetBuffer(
595 LogicalBuffer::Id id) const {
596 CHECK_GE(id, 0);
597 CHECK_LT(id, logical_buffer_analysis_->num_logical_buffers());
598 return logical_buffer_analysis_->GetBuffer(id);
599 }
600
GetBufferDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const601 StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
602 const HloInstruction* instruction, const ShapeIndex& index) const {
603 const auto& buffers = GetPointsToSet(instruction).element(index);
604 if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
605 return FailedPrecondition(
606 "instruction %s does not define buffer at index {%s}",
607 instruction->name(), absl::StrJoin(index, ","));
608 }
609 return buffers[0];
610 }
611
612 const TuplePointsToAnalysis::BufferAliasVector&
GetBufferAliases(const LogicalBuffer & buffer) const613 TuplePointsToAnalysis::GetBufferAliases(const LogicalBuffer& buffer) const {
614 return logical_buffer_aliases_.at(buffer.id());
615 }
616
617 const TuplePointsToAnalysis::BufferDefinitionVector&
GetBuffersDefinedByInstruction(const HloInstruction * instruction) const618 TuplePointsToAnalysis::GetBuffersDefinedByInstruction(
619 const HloInstruction* instruction) const {
620 return PerInst(instruction)->instruction_defined_buffers;
621 }
622
GatherBuffersDefinedByInstruction(const HloInstruction * instruction,TuplePointsToAnalysis::BufferDefinitionVector * buffers)623 Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
624 const HloInstruction* instruction,
625 TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
626 GetPointsToSet(instruction)
627 .ForEachElement([buffers, instruction](
628 const ShapeIndex& index,
629 const PointsToSet::BufferList& source_buffers) {
630 // Add buffers which 'instruction' is the source of.
631 CHECK(!source_buffers.empty());
632 if (source_buffers.size() == 1 &&
633 source_buffers[0]->instruction() == instruction) {
634 // If this instruction is the source of this buffer the
635 // indices must match.
636 DCHECK(source_buffers[0]->index() == index);
637 buffers->push_back(source_buffers[0]);
638 } else {
639 // If the points-to set includes more than one buffer then
640 // necessarily this instruction did not produce the
641 // buffer.
642 for (const LogicalBuffer* source_buffer : source_buffers) {
643 DCHECK(source_buffer->instruction() != instruction);
644 }
645 }
646 });
647 return OkStatus();
648 }
649
CreateCopiedPointsToSet(const HloInstruction * instruction,const HloInstruction * src)650 PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
651 const HloInstruction* instruction, const HloInstruction* src) {
652 // PointsToSet doesn't have a copy constructor so copy over element-by-element
653 // from src PointsToSet.
654 PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
655 const PointsToSet& src_points_to_set = GetPointsToSet(src);
656 dst_points_to_set.ForEachMutableElement(
657 [&dst_points_to_set, &src_points_to_set](
658 const ShapeIndex& index, PointsToSet::BufferList* buffers) {
659 *buffers = src_points_to_set.element(index);
660 for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
661 dst_points_to_set.add_tuple_source(index, tuple_source);
662 }
663 });
664 return *PerInst(instruction)->points_to_set;
665 }
666
ToString() const667 std::string TuplePointsToAnalysis::ToString() const {
668 std::string output =
669 absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name());
670 for (const auto* computation : module_->MakeNonfusionComputations()) {
671 const char* entry =
672 computation == module_->entry_computation() ? "entry " : "";
673 absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n");
674 for (const HloInstruction* instruction :
675 computation->MakeInstructionPostOrder()) {
676 InstructionToString(instruction, &output);
677 if (instruction->opcode() == HloOpcode::kFusion) {
678 for (auto* fused : instruction->fused_instructions()) {
679 InstructionToString(fused, &output);
680 }
681 }
682 }
683 }
684
685 absl::StrAppend(&output, "LogicalBuffers:\n");
686 for (const auto& b : logical_buffer_analysis_->logical_buffers()) {
687 absl::StrAppend(&output, " buffer ", b->ToString(), ":\n");
688 for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) {
689 absl::StrAppend(&output, " alias ", alias.ToString(), "\n");
690 }
691 }
692 return output;
693 }
694
InstructionToString(const HloInstruction * instruction,std::string * output) const695 void TuplePointsToAnalysis::InstructionToString(
696 const HloInstruction* instruction, std::string* output) const {
697 const std::string prefix = instruction->IsFused() ? " " : "";
698 absl::StrAppend(output, prefix, " instruction ",
699 instruction->ToShortString(), ":\n");
700 const PointsToSet& points_to_set = GetPointsToSet(instruction);
701 points_to_set.ForEachElement(
702 [&prefix, &output](const ShapeIndex& index,
703 const PointsToSet::BufferList& points_to) {
704 absl::StrAppend(
705 output, prefix, " {", absl::StrJoin(index, ","), "}: ",
706 absl::StrJoin(points_to, ", ",
707 [](std::string* out, const LogicalBuffer* source) {
708 out->append(source->ToString());
709 }),
710 "\n");
711 });
712 }
713
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const714 bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
715 const HloInstruction* operand, const ShapeIndex& index,
716 const HloInstruction* user) const {
717 CHECK(user->IsUserOf(operand))
718 << "user: " << user->ToString() << " operand: " << operand->ToString();
719 if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
720 // GetTupleElement instructions only access the top-level buffer of their
721 // operand.
722 return true;
723 } else if (user->IsLoopFusion()) {
724 // Find fusion parameter associated with 'operand'.
725 auto it = absl::c_find_if(
726 user->fused_parameters(), [&](HloInstruction* fused_param) {
727 return user->operand(fused_param->parameter_number()) == operand;
728 });
729 CHECK(it != user->fused_parameters().end());
730 // Iterate through all users of all buffer aliases of the buffer in the
731 // points-to set of fusion parameter at 'index'.
732 // Return false if any uses are detected at 'index', returns true otherwise.
733 const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie();
734 for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
735 for (HloInstruction* alias_user : alias.instruction()->users()) {
736 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
737 alias_user)) {
738 continue;
739 }
740 // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
741 return false;
742 }
743 }
744 // Return true: found no uses of 'operand' at 'index' in 'user'.
745 return true;
746 }
747 return false;
748 }
749
750 // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
751 // Each use in 'uses' is a pair (HloInstruction* user, int64_t operand_index)
752 // where 'user' is a user of an alias of 'instruction' at 'index', and
753 // 'operand_index' is the operand index at which the alias appears in the
754 // operand list of 'user'.
755 std::vector<std::pair<HloInstruction*, int64_t>>
GetAllUsesOfInstructionAtIndex(HloInstruction * instruction,const ShapeIndex & index) const756 TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex(
757 HloInstruction* instruction, const ShapeIndex& index) const {
758 std::vector<std::pair<HloInstruction*, int64_t>> uses;
759 const PointsToSet::BufferList& points_to =
760 GetPointsToSet(instruction).element(index);
761 for (const LogicalBuffer* buffer : points_to) {
762 for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
763 for (HloInstruction* alias_user : alias.instruction()->users()) {
764 if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
765 alias_user)) {
766 continue;
767 }
768 for (int64_t op_idx : alias_user->OperandIndices(alias.instruction())) {
769 uses.emplace_back(alias_user, op_idx);
770 }
771 }
772 }
773 }
774 return uses;
775 }
776
777 // Returns true if there is exactly one use of 'operand' at 'operand_index'
778 // in 'fusion.fused_instructions', where the singleton use is the fused
779 // root at operand index 'use_operand_index'. Returns false otherwise.
780 //
781 // REQUIRES: 'fusion' opcode is a kFusion instruction.
HasUniqueFusedUseOfOperandAt(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * fusion,const int64_t use_operand_index) const782 bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
783 HloInstruction* operand, const ShapeIndex& operand_index,
784 HloInstruction* fusion, const int64_t use_operand_index) const {
785 CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
786 // Check that 'operand' is unique in the operand list of 'fusion'.
787 if (fusion->OperandIndices(operand).size() > 1) {
788 return false;
789 }
790 // Find fusion parameter associated with 'operand'.
791 const auto& fused_params = fusion->fused_parameters();
792 auto fused_param_it =
793 absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
794 return fusion->operand(fused_param->parameter_number()) == operand;
795 });
796 if (fused_param_it == fused_params.end()) {
797 return false;
798 }
799 auto* fused_param = *fused_param_it;
800 // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'.
801 auto fused_param_uses =
802 GetAllUsesOfInstructionAtIndex(fused_param, operand_index);
803 // Return true iff there is exactly one use of 'operand' at 'index', and
804 // this singleton use is the fused root (at index in 'use_operand_indices').
805 return fused_param_uses.size() == 1 &&
806 fused_param_uses[0].first == fusion->fused_expression_root() &&
807 fused_param_uses[0].second == use_operand_index;
808 }
809 } // namespace xla
810