• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
17 
18 #include <ostream>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/str_format.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/map_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36 
37 namespace xla {
38 
ToString() const39 string BufferAlias::ToString() const {
40   return absl::StrCat("BufferAlias(", instruction_->name(), "[",
41                       absl::StrJoin(index_, ","), "])");
42 }
43 
operator <<(std::ostream & out,const BufferAlias & buffer_alias)44 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
45   out << buffer_alias.ToString();
46   return out;
47 }
48 
IsAmbiguous() const49 bool PointsToSet::IsAmbiguous() const {
50   bool ambiguous = false;
51   ForEachElement(
52       [&ambiguous](const ShapeIndex& /*index*/, const BufferList& points_to) {
53         ambiguous |= points_to.size() > 1;
54       });
55   return ambiguous;
56 }
57 
IsDistinct() const58 bool PointsToSet::IsDistinct() const {
59   bool distinct = true;
60   absl::flat_hash_set<const LogicalBuffer*> all_points_to;
61   ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) {
62     for (auto& buffer : points_to) {
63       if (all_points_to.contains(buffer)) {
64         distinct = false;
65       }
66       all_points_to.insert(buffer);
67     }
68   });
69   return distinct;
70 }
71 
size() const72 size_t PointsToSet::size() const {
73   // Because pointed-to elements may be duplicated we have to create a flattened
74   // set and return the size.
75   return CreateFlattenedSet().size();
76 }
77 
CreateFlattenedSet() const78 PointsToSet::BufferSet PointsToSet::CreateFlattenedSet() const {
79   BufferSet flat_set;
80   ForEachElement(
81       [&flat_set](const ShapeIndex& /*index*/, const BufferList& buffers) {
82         flat_set.insert(buffers.begin(), buffers.end());
83       });
84   return flat_set;
85 }
86 
ContainsBuffer(const LogicalBuffer & buffer) const87 bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
88   bool found = false;
89   ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
90                                    const BufferList& pointed_to_buffers) {
91     if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
92       found = true;
93     }
94   });
95   return found;
96 }
97 
ContainsBufferAtIndex(const LogicalBuffer & buffer,const ShapeIndex & index) const98 bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
99                                         const ShapeIndex& index) const {
100   const auto& pointed_to_buffers = element(index);
101   return absl::c_linear_search(pointed_to_buffers, &buffer);
102 }
103 
AddPointedToBuffer(const LogicalBuffer & buffer,const ShapeIndex & index)104 void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
105                                      const ShapeIndex& index) {
106   if (ContainsBufferAtIndex(buffer, index)) {
107     return;
108   }
109   mutable_element(index)->push_back(&buffer);
110 }
111 
tuple_sources(const ShapeIndex & index) const112 const PointsToSet::SourceSet& PointsToSet::tuple_sources(
113     const ShapeIndex& index) const {
114   return tree_.element(index).tuple_sources;
115 }
116 
add_tuple_source(const ShapeIndex & index,HloInstruction * tuple)117 void PointsToSet::add_tuple_source(const ShapeIndex& index,
118                                    HloInstruction* tuple) {
119   tree_.mutable_element(index)->tuple_sources.insert(tuple);
120 }
121 
122 namespace {
123 // Gather fusion instructions from 'instruction' into 'fusion_instructions'.
GatherFusionInstructions(HloInstruction * instruction,std::vector<HloInstruction * > * fusion_instructions)124 void GatherFusionInstructions(
125     HloInstruction* instruction,
126     std::vector<HloInstruction*>* fusion_instructions) {
127   CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
128   for (auto* fused : instruction->fused_instructions()) {
129     if (fused->opcode() == HloOpcode::kFusion) {
130       GatherFusionInstructions(fused, fusion_instructions);
131     }
132   }
133   fusion_instructions->push_back(instruction);
134 }
135 
136 }  // namespace
137 
138 /* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
Run(const HloModule * module)139 TuplePointsToAnalysis::Run(const HloModule* module) {
140   auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module);
141   std::unique_ptr<TuplePointsToAnalysis> analysis(new TuplePointsToAnalysis(
142       module, logical_buffer_analysis.ConsumeValueOrDie()));
143   TF_RETURN_IF_ERROR(analysis->Analyze());
144   return std::move(analysis);
145 }
146 
Analyze()147 Status TuplePointsToAnalysis::Analyze() {
148   per_instruction_.clear();
149   per_instruction_.reserve(module_->instruction_count());
150 
151   logical_buffer_aliases_.clear();
152   logical_buffer_aliases_.resize(
153       logical_buffer_analysis_->num_logical_buffers());
154 
155   std::vector<HloInstruction*> fusion_instructions;
156   for (auto* computation : module_->MakeNonfusionComputations()) {
157     TF_RETURN_IF_ERROR(computation->Accept(this));
158     TF_RETURN_IF_ERROR(
159         PopulateDefinedBuffersAndAliases(computation->instructions()));
160     for (auto* instruction : computation->instructions()) {
161       if (instruction->opcode() == HloOpcode::kFusion) {
162         GatherFusionInstructions(instruction, &fusion_instructions);
163       }
164     }
165   }
166   // Run points-to analysis on fusion instructions in 'computation'.
167   for (auto* instruction : fusion_instructions) {
168     TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
169     TF_RETURN_IF_ERROR(
170         PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
171   }
172 
173   XLA_VLOG_LINES(3, ToString());
174 
175   return Status::OK();
176 }
177 
178 Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(const decltype(
179     std::declval<HloComputation>().instructions())& instructions) {
180   for (auto* instruction : instructions) {
181     PerInstruction* pi = PerInst(instruction);
182     TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction(
183         instruction, &pi->instruction_defined_buffers));
184 
185     const PointsToSet& points_to_set = GetPointsToSet(instruction);
186     points_to_set.ForEachElement(
187         [this, &instruction](
188             const ShapeIndex& index,
__anon27932d9d0602( const ShapeIndex& index, const PointsToSet::BufferList& pointed_to_buffers) 189             const PointsToSet::BufferList& pointed_to_buffers) {
190           for (const LogicalBuffer* buffer : pointed_to_buffers) {
191             logical_buffer_aliases_[buffer->id()].emplace_back(instruction,
192                                                                index);
193           }
194         });
195   }
196   return Status::OK();
197 }
198 
DefaultAction(HloInstruction * hlo_instruction)199 Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
200   // Create trivial points-to set for instruction. Each points-to set at index i
201   // contains a single element LogicalBuffer(hlo_instruction, i). This indicates
202   // that this instruction is the source of all buffers in its own output.
203   PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction);
204   points_to_set.ForEachMutableElement(
205       [this, hlo_instruction](const ShapeIndex& index,
206                               PointsToSet::BufferList* buffers) {
207         buffers->push_back(
208             &logical_buffer_analysis_->GetBuffer(hlo_instruction, index));
209       });
210 
211   if (hlo_instruction->shape().IsTuple()) {
212     // If the hlo instruction is a tuple-shaped, then trivially the instruction
213     // itself is the source of the tuple.
214     points_to_set.add_tuple_source({}, hlo_instruction);
215   }
216 
217   return Status::OK();
218 }
219 
HandleGetTupleElement(HloInstruction * get_tuple_element)220 Status TuplePointsToAnalysis::HandleGetTupleElement(
221     HloInstruction* get_tuple_element) {
222   // GetTupleElement forwards a pointer to a particular element of the tuple
223   // operand.
224   int64 element_index = get_tuple_element->tuple_index();
225 
226   PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element);
227   const PointsToSet& operand_points_to_set =
228       *PerInst(get_tuple_element->operand(0))->points_to_set;
229 
230   // Copy the points-to set (and tuple sources) at index {element_index} of the
231   // operand to the points-to set for this GetTupleElement instruction.
232   points_to_set.ForEachMutableElement(
233       [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) {
234         // Construct an index into the operand by prepending element_index to
235         // the index for the GetTupleElement instruction's points-to set.
236         ShapeIndex src_index;
237         src_index.push_back(element_index);
238         for (auto element : target_index) {
239           src_index.push_back(element);
240         }
241 
242         *points_to = operand_points_to_set.element(src_index);
243         for (HloInstruction* tuple :
244              operand_points_to_set.tuple_sources(src_index)) {
245           points_to_set.add_tuple_source(target_index, tuple);
246         }
247       });
248 
249   return Status::OK();
250 }
251 
HandleCopy(HloInstruction * copy)252 Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) {
253   // A kCopy instruction performs a shallow copy of the operand. The top-level
254   // buffer (index={}) is newly created, but all other buffers (in the case of a
255   // tuple shape) come from the operand
256   PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0));
257   points_to_set.mutable_element(/*index=*/{})->clear();
258   points_to_set.AddPointedToBuffer(
259       logical_buffer_analysis_->GetBuffer(copy, /*index=*/{}),
260       /*index=*/{});
261 
262   return Status::OK();
263 }
264 
HandleBitcast(HloInstruction * bitcast)265 Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
266   // A kBitcast instruction aliases its operand. That is, the buffer of its
267   // result *is* the buffer of its operand, so just copy the operands points-to
268   // set.
269   CreateCopiedPointsToSet(bitcast, bitcast->operand(0));
270   return Status::OK();
271 }
272 
HandleDomain(HloInstruction * domain)273 Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) {
274   // A kDomain instruction aliases its operand. That is, the buffer of its
275   // result *is* the buffer of its operand, so just copy the operands points-to
276   // set.
277   CreateCopiedPointsToSet(domain, domain->operand(0));
278   return Status::OK();
279 }
280 
HandleAddDependency(HloInstruction * add_dependency)281 Status TuplePointsToAnalysis::HandleAddDependency(
282     HloInstruction* add_dependency) {
283   // AddDependency just forwards the value of its zero-th operand.
284   CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0));
285   return Status::OK();
286 }
287 
HandleRecvDone(HloInstruction * recv_done)288 Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
289   // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
290   // output. The other indices ({} and {1}) define their own buffers.
291   PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
292   points_to_set.AddPointedToBuffer(
293       logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}),
294       /*index=*/{});
295   points_to_set.AddPointedToBuffer(
296       logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}),
297       /*index=*/{1});
298 
299   const PointsToSet& operand_points_to_set =
300       GetPointsToSet(recv_done->operand(0));
301 
302   // Recursively copy the points to set of the operand tuple {0} to the output
303   // element {0}.
304   points_to_set.ForEachMutableElement(
305       [&points_to_set, &operand_points_to_set](
306           const ShapeIndex& index, PointsToSet::BufferList* buffers) {
307         if (index.empty() || index[0] != 0) {
308           return;
309         }
310         *buffers = operand_points_to_set.element(index);
311         for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) {
312           points_to_set.add_tuple_source(index, tuple_source);
313         }
314       });
315   return Status::OK();
316 }
317 
HandleCopyStart(HloInstruction * copy_start)318 Status TuplePointsToAnalysis::HandleCopyStart(HloInstruction* copy_start) {
319   // CopyStart forwards its aliased operand to {1}.
320   PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_start);
321   const PointsToSet& operand_points_to_set =
322       GetPointsToSet(copy_start->operand(0));
323 
324   points_to_set.ForEachMutableElement(
325       [&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) {
326         if (target_index == ShapeIndex({1})) {
327           *buffers = operand_points_to_set.element(/*index=*/{});
328         } else {
329           buffers->push_back(
330               &logical_buffer_analysis_->GetBuffer(copy_start, target_index));
331         }
332       });
333 
334   for (HloInstruction* tuple :
335        operand_points_to_set.tuple_sources(/*index=*/{})) {
336     points_to_set.add_tuple_source(/*index=*/{1}, tuple);
337   }
338 
339   return Status::OK();
340 }
341 
HandleCopyDone(HloInstruction * copy_done)342 Status TuplePointsToAnalysis::HandleCopyDone(HloInstruction* copy_done) {
343   // CopyDone forwards its aliased operand.
344   PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_done);
345   const PointsToSet& operand_points_to_set =
346       GetPointsToSet(copy_done->operand(0));
347   operand_points_to_set.ForEachElement(
348       [&points_to_set, &operand_points_to_set](
349           const ShapeIndex& src_index,
350           const PointsToSet::BufferList& points_to) {
351         if (src_index == ShapeIndex({0})) {
352           const ShapeIndex target_index = {};
353           *points_to_set.mutable_element(target_index) = points_to;
354 
355           for (HloInstruction* tuple :
356                operand_points_to_set.tuple_sources(src_index)) {
357             points_to_set.add_tuple_source(target_index, tuple);
358           }
359         }
360       });
361 
362   return Status::OK();
363 }
364 
HandleSend(HloInstruction * send)365 Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
366   // Send creates a tuple of {aliased operand, U32 context, token}.
367   PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
368 
369   // Creates the points to set for the tuple and its element at {1}.
370   auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
371   top_buffer->push_back(
372       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
373   points_to_set.add_tuple_source({}, send);
374 
375   auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
376   context_buffer->push_back(
377       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
378 
379   auto token_buffer = points_to_set.mutable_element(ShapeIndex({2}));
380   token_buffer->push_back(
381       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2})));
382 
383   // Recursively copy the points to set of the operand to output tuple {0}.
384   const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
385   operand_points_to_set.ForEachElement(
386       [&points_to_set, &operand_points_to_set](
387           const ShapeIndex& src_index,
388           const PointsToSet::BufferList& points_to) {
389         ShapeIndex target_index({0});
390         for (auto element : src_index) {
391           target_index.push_back(element);
392         }
393         *points_to_set.mutable_element(target_index) = points_to;
394 
395         for (HloInstruction* tuple :
396              operand_points_to_set.tuple_sources(src_index)) {
397           points_to_set.add_tuple_source(target_index, tuple);
398         }
399       });
400 
401   return Status::OK();
402 }
403 
HandleTuple(HloInstruction * tuple)404 Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
405   absl::Span<HloInstruction* const> operands(tuple->operands());
406   PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
407   points_to_set.AddPointedToBuffer(
408       logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}),
409       /*index=*/{});
410 
411   // A tuple contains references to all input operands and transitively any
412   // references in those operands.
413   for (int64 i = 0; i < operands.size(); ++i) {
414     const PointsToSet& operand_points_to_set =
415         *PerInst(operands[i])->points_to_set;
416 
417     // Copy the points-to set (and tuple sources) of the operand into the
418     // respective subtree of the tuple instructions points-to set.
419     operand_points_to_set.ForEachElement(
420         [&points_to_set, &operand_points_to_set, i](
421             const ShapeIndex& src_index,
422             const PointsToSet::BufferList& points_to) {
423           ShapeIndex target_index;
424           target_index.push_back(i);
425           for (auto element : src_index) {
426             target_index.push_back(element);
427           }
428 
429           *points_to_set.mutable_element(target_index) = points_to;
430 
431           for (HloInstruction* tuple :
432                operand_points_to_set.tuple_sources(src_index)) {
433             points_to_set.add_tuple_source(target_index, tuple);
434           }
435         });
436   }
437 
438   points_to_set.add_tuple_source({}, tuple);
439 
440   return Status::OK();
441 }
442 
HandleTupleSelect(HloInstruction * tuple_select)443 Status TuplePointsToAnalysis::HandleTupleSelect(HloInstruction* tuple_select) {
444   // Select allocates a new buffer and then shallow copies the on_true or
445   // on_false buffer into this new buffer. Which side is chosen cannot be
446   // determined statically so conservatively set the points-to set to the union
447   // of these on_true and on_false operands.
448   //
449   // First create a copy of the on_true points-to set (and tuple sources), then
450   // add in elements of the on_false points-to set (tuple sources).
451   auto on_true = tuple_select->operand(1);
452   auto on_false = tuple_select->operand(2);
453   PointsToSet& points_to_set = CreateCopiedPointsToSet(tuple_select, on_true);
454   const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set;
455   points_to_set.ForEachMutableElement(
456       [&](const ShapeIndex& index, PointsToSet::BufferList* buffers) {
457         for (const LogicalBuffer* false_buffer :
458              false_points_to_set.element(index)) {
459           points_to_set.AddPointedToBuffer(*false_buffer, index);
460         }
461 
462         for (HloInstruction* tuple : false_points_to_set.tuple_sources(index)) {
463           points_to_set.add_tuple_source(index, tuple);
464         }
465       });
466 
467   // Select creates a new (top-level) buffer to store its result, so its
468   // respective element in the points-to set should contain only itself.
469   points_to_set.mutable_element({})->clear();
470   points_to_set.AddPointedToBuffer(
471       logical_buffer_analysis_->GetBuffer(tuple_select, /*index=*/{}),
472       /*index=*/{});
473   return Status::OK();
474 }
475 
GetPointsToSet(const HloInstruction * hlo_instruction) const476 const PointsToSet& TuplePointsToAnalysis::GetPointsToSet(
477     const HloInstruction* hlo_instruction) const {
478   return *PerInst(hlo_instruction)->points_to_set;
479 }
480 
CreateEmptyPointsToSet(const HloInstruction * instruction)481 PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
482     const HloInstruction* instruction) {
483   PerInstruction* pi = PerInst(instruction);
484   CHECK(pi->points_to_set == nullptr)
485       << "instruction should not have been present in the map.";
486   auto set = absl::make_unique<PointsToSet>(&instruction->shape());
487   pi->points_to_set = std::move(set);
488   // Return *set using the iterator returned by emplace.
489   return *pi->points_to_set;
490 }
491 
InstructionDefinesBufferAtIndex(const HloInstruction * instruction,const ShapeIndex & index) const492 bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
493     const HloInstruction* instruction, const ShapeIndex& index) const {
494   const auto& buffers = GetPointsToSet(instruction).element(index);
495   return (buffers.size() == 1 && buffers[0]->instruction() == instruction);
496 }
497 
VerifyBuffer(const LogicalBuffer & buffer) const498 Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
499   if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) {
500     return FailedPrecondition(
501         "LogicalBuffer %s is ill-defined: instruction %s does not define a "
502         "buffer at that index",
503         buffer.ToString(), buffer.instruction()->name());
504   }
505 
506   if (buffer.id() < 0 ||
507       buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) {
508     return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d",
509                               buffer.ToString(), buffer.id());
510   }
511   if (GetBuffer(buffer.id()).instruction() != buffer.instruction() ||
512       GetBuffer(buffer.id()).index() != buffer.index()) {
513     return FailedPrecondition(
514         "LogicalBuffer %s is ill-defined: buffer with same id differs: %s",
515         buffer.ToString(), GetBuffer(buffer.id()).ToString());
516   }
517 
518   return Status::OK();
519 }
520 
GetBuffer(LogicalBuffer::Id id) const521 const LogicalBuffer& TuplePointsToAnalysis::GetBuffer(
522     LogicalBuffer::Id id) const {
523   CHECK_GE(id, 0);
524   CHECK_LT(id, logical_buffer_analysis_->num_logical_buffers());
525   return logical_buffer_analysis_->GetBuffer(id);
526 }
527 
GetBufferDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const528 StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
529     const HloInstruction* instruction, const ShapeIndex& index) const {
530   const auto& buffers = GetPointsToSet(instruction).element(index);
531   if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
532     return FailedPrecondition(
533         "instruction %s does not define buffer at index {%s}",
534         instruction->name(), absl::StrJoin(index, ","));
535   }
536   return buffers[0];
537 }
538 
539 const TuplePointsToAnalysis::BufferAliasVector&
GetBufferAliases(const LogicalBuffer & buffer) const540 TuplePointsToAnalysis::GetBufferAliases(const LogicalBuffer& buffer) const {
541   return logical_buffer_aliases_.at(buffer.id());
542 }
543 
544 const TuplePointsToAnalysis::BufferDefinitionVector&
GetBuffersDefinedByInstruction(const HloInstruction * instruction) const545 TuplePointsToAnalysis::GetBuffersDefinedByInstruction(
546     const HloInstruction* instruction) const {
547   return PerInst(instruction)->instruction_defined_buffers;
548 }
549 
GatherBuffersDefinedByInstruction(const HloInstruction * instruction,TuplePointsToAnalysis::BufferDefinitionVector * buffers)550 Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
551     const HloInstruction* instruction,
552     TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
553   GetPointsToSet(instruction)
554       .ForEachElement([buffers, instruction](
555                           const ShapeIndex& index,
556                           const PointsToSet::BufferList& source_buffers) {
557         // Add buffers which 'instruction' is the source of.
558         CHECK(!source_buffers.empty());
559         if (source_buffers.size() == 1 &&
560             source_buffers[0]->instruction() == instruction) {
561           // If this instruction is the source of this buffer the
562           // indices must match.
563           DCHECK(source_buffers[0]->index() == index);
564           buffers->push_back(source_buffers[0]);
565         } else {
566           // If the points-to set includes more than one buffer then
567           // necessarily this instruction did not produce the
568           // buffer.
569           for (const LogicalBuffer* source_buffer : source_buffers) {
570             DCHECK(source_buffer->instruction() != instruction);
571           }
572         }
573       });
574   return Status::OK();
575 }
576 
CreateCopiedPointsToSet(const HloInstruction * instruction,const HloInstruction * src)577 PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
578     const HloInstruction* instruction, const HloInstruction* src) {
579   // PointsToSet doesn't have a copy constructor so copy over element-by-element
580   // from src PointsToSet.
581   PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
582   const PointsToSet& src_points_to_set = GetPointsToSet(src);
583   dst_points_to_set.ForEachMutableElement(
584       [&dst_points_to_set, &src_points_to_set](
585           const ShapeIndex& index, PointsToSet::BufferList* buffers) {
586         *buffers = src_points_to_set.element(index);
587         for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
588           dst_points_to_set.add_tuple_source(index, tuple_source);
589         }
590       });
591   return *PerInst(instruction)->points_to_set;
592 }
593 
ToString() const594 string TuplePointsToAnalysis::ToString() const {
595   string output =
596       absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name());
597   for (const auto* computation : module_->MakeNonfusionComputations()) {
598     const char* entry =
599         computation == module_->entry_computation() ? "entry " : "";
600     absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n");
601     for (const HloInstruction* instruction :
602          computation->MakeInstructionPostOrder()) {
603       InstructionToString(instruction, &output);
604       if (instruction->opcode() == HloOpcode::kFusion) {
605         for (auto* fused : instruction->fused_instructions()) {
606           InstructionToString(fused, &output);
607         }
608       }
609     }
610   }
611 
612   absl::StrAppend(&output, "LogicalBuffers:\n");
613   for (const auto& b : logical_buffer_analysis_->logical_buffers()) {
614     absl::StrAppend(&output, "  buffer ", b->ToString(), ":\n");
615     for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) {
616       absl::StrAppend(&output, "    alias ", alias.ToString(), "\n");
617     }
618   }
619   return output;
620 }
621 
InstructionToString(const HloInstruction * instruction,string * output) const622 void TuplePointsToAnalysis::InstructionToString(
623     const HloInstruction* instruction, string* output) const {
624   const string prefix = instruction->IsFused() ? "    " : "";
625   absl::StrAppend(output, prefix, "  instruction ",
626                   instruction->ToShortString(), ":\n");
627   const PointsToSet& points_to_set = GetPointsToSet(instruction);
628   points_to_set.ForEachElement([&prefix, &output](
629                                    const ShapeIndex& index,
630                                    const PointsToSet::BufferList& points_to) {
631     absl::StrAppend(output, prefix, "    {", absl::StrJoin(index, ","), "}: ",
632                     absl::StrJoin(points_to, ", ",
633                                   [](string* out, const LogicalBuffer* source) {
634                                     out->append(source->ToString());
635                                   }),
636                     "\n");
637   });
638 }
639 
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const640 bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
641     const HloInstruction* operand, const ShapeIndex& index,
642     const HloInstruction* user) const {
643   CHECK(user->IsUserOf(operand))
644       << "user: " << user->ToString() << " operand: " << operand->ToString();
645   if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
646     // GetTupleElement instructions only access the top-level buffer of their
647     // operand.
648     return true;
649   } else if (user->IsLoopFusion()) {
650     // Find fusion parameter associated with 'operand'.
651     auto it = absl::c_find_if(
652         user->fused_parameters(), [&](HloInstruction* fused_param) {
653           return user->operand(fused_param->parameter_number()) == operand;
654         });
655     CHECK(it != user->fused_parameters().end());
656     // Iterate through all users of all buffer aliases of the buffer in the
657     // points-to set of fusion parameter at 'index'.
658     // Return false if any uses are detected at 'index', returns true otherwise.
659     const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie();
660     for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
661       for (HloInstruction* alias_user : alias.instruction()->users()) {
662         if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
663                                     alias_user)) {
664           continue;
665         }
666         // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
667         return false;
668       }
669     }
670     // Return true: found no uses of 'operand' at 'index' in 'user'.
671     return true;
672   }
673   return false;
674 }
675 
676 // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
677 // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
678 // where 'user' is a user of an alias of 'instruction' at 'index', and
679 // 'operand_index' is the operand index at which the alias appears in the
680 // operand list of 'user'.
681 std::vector<std::pair<HloInstruction*, int64>>
GetAllUsesOfInstructionAtIndex(HloInstruction * instruction,const ShapeIndex & index) const682 TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex(
683     HloInstruction* instruction, const ShapeIndex& index) const {
684   std::vector<std::pair<HloInstruction*, int64>> uses;
685   const PointsToSet::BufferList& points_to =
686       GetPointsToSet(instruction).element(index);
687   for (const LogicalBuffer* buffer : points_to) {
688     for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
689       for (HloInstruction* alias_user : alias.instruction()->users()) {
690         if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
691                                     alias_user)) {
692           continue;
693         }
694         for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
695           uses.emplace_back(alias_user, op_idx);
696         }
697       }
698     }
699   }
700   return uses;
701 }
702 
703 // Returns true if there is exactly one use of 'operand' at 'operand_index'
704 // in 'fusion.fused_instructions', where the singleton use is the fused
705 // root at operand index 'use_operand_index'. Returns false otherwise.
706 //
707 // REQUIRES: 'fusion' opcode is a kFusion instruction.
HasUniqueFusedUseOfOperandAt(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * fusion,const int64 use_operand_index) const708 bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
709     HloInstruction* operand, const ShapeIndex& operand_index,
710     HloInstruction* fusion, const int64 use_operand_index) const {
711   CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
712   // Check that 'operand' is unique in the operand list of 'fusion'.
713   if (fusion->OperandIndices(operand).size() > 1) {
714     return false;
715   }
716   // Find fusion parameter associated with 'operand'.
717   const auto& fused_params = fusion->fused_parameters();
718   auto fused_param_it =
719       absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
720         return fusion->operand(fused_param->parameter_number()) == operand;
721       });
722   if (fused_param_it == fused_params.end()) {
723     return false;
724   }
725   auto* fused_param = *fused_param_it;
726   // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'.
727   auto fused_param_uses =
728       GetAllUsesOfInstructionAtIndex(fused_param, operand_index);
729   // Return true iff there is exactly one use of 'operand' at 'index', and
730   // this singleton use is the fused root (at index in 'use_operand_indices').
731   return fused_param_uses.size() == 1 &&
732          fused_param_uses[0].first == fusion->fused_expression_root() &&
733          fused_param_uses[0].second == use_operand_index;
734 }
735 }  // namespace xla
736