• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"

#include <memory>
#include <ostream>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
38 
39 namespace xla {
40 
ToString() const41 std::string BufferAlias::ToString() const {
42   return absl::StrCat("BufferAlias(", instruction_->name(), "[",
43                       absl::StrJoin(index_, ","), "])");
44 }
45 
operator <<(std::ostream & out,const BufferAlias & buffer_alias)46 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
47   out << buffer_alias.ToString();
48   return out;
49 }
50 
IsAmbiguous() const51 bool PointsToSet::IsAmbiguous() const {
52   bool ambiguous = false;
53   ForEachElement(
54       [&ambiguous](const ShapeIndex& /*index*/, const BufferList& points_to) {
55         ambiguous |= points_to.size() > 1;
56       });
57   return ambiguous;
58 }
59 
IsDistinct() const60 bool PointsToSet::IsDistinct() const {
61   bool distinct = true;
62   absl::flat_hash_set<const LogicalBuffer*> all_points_to;
63   ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) {
64     for (auto& buffer : points_to) {
65       if (all_points_to.contains(buffer)) {
66         distinct = false;
67       }
68       all_points_to.insert(buffer);
69     }
70   });
71   return distinct;
72 }
73 
size() const74 size_t PointsToSet::size() const {
75   // Because pointed-to elements may be duplicated we have to create a flattened
76   // set and return the size.
77   return CreateFlattenedSet().size();
78 }
79 
CreateFlattenedSet() const80 PointsToSet::BufferSet PointsToSet::CreateFlattenedSet() const {
81   BufferSet flat_set;
82   ForEachElement(
83       [&flat_set](const ShapeIndex& /*index*/, const BufferList& buffers) {
84         flat_set.insert(buffers.begin(), buffers.end());
85       });
86   return flat_set;
87 }
88 
ContainsBuffer(const LogicalBuffer & buffer) const89 bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
90   bool found = false;
91   ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
92                                    const BufferList& pointed_to_buffers) {
93     if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
94       found = true;
95     }
96   });
97   return found;
98 }
99 
ContainsBufferAtIndex(const LogicalBuffer & buffer,const ShapeIndex & index) const100 bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
101                                         const ShapeIndex& index) const {
102   const auto& pointed_to_buffers = element(index);
103   return absl::c_linear_search(pointed_to_buffers, &buffer);
104 }
105 
AddPointedToBuffer(const LogicalBuffer & buffer,const ShapeIndex & index)106 void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
107                                      const ShapeIndex& index) {
108   if (ContainsBufferAtIndex(buffer, index)) {
109     return;
110   }
111   mutable_element(index)->push_back(&buffer);
112 }
113 
tuple_sources(const ShapeIndex & index) const114 const PointsToSet::SourceSet& PointsToSet::tuple_sources(
115     const ShapeIndex& index) const {
116   return tree_.element(index).tuple_sources;
117 }
118 
add_tuple_source(const ShapeIndex & index,HloInstruction * tuple)119 void PointsToSet::add_tuple_source(const ShapeIndex& index,
120                                    HloInstruction* tuple) {
121   tree_.mutable_element(index)->tuple_sources.insert(tuple);
122 }
123 
124 namespace {
125 // Gather fusion instructions from 'instruction' into 'fusion_instructions'.
GatherFusionInstructions(HloInstruction * instruction,std::vector<HloInstruction * > * fusion_instructions)126 void GatherFusionInstructions(
127     HloInstruction* instruction,
128     std::vector<HloInstruction*>* fusion_instructions) {
129   CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
130   for (auto* fused : instruction->fused_instructions()) {
131     if (fused->opcode() == HloOpcode::kFusion) {
132       GatherFusionInstructions(fused, fusion_instructions);
133     }
134   }
135   fusion_instructions->push_back(instruction);
136 }
137 
138 }  // namespace
139 
140 /* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
Run(const HloModule * module)141 TuplePointsToAnalysis::Run(const HloModule* module) {
142   auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module);
143   std::unique_ptr<TuplePointsToAnalysis> analysis(new TuplePointsToAnalysis(
144       module, std::move(logical_buffer_analysis).value()));
145   TF_RETURN_IF_ERROR(analysis->Analyze());
146   return std::move(analysis);
147 }
148 
Analyze()149 Status TuplePointsToAnalysis::Analyze() {
150   per_instruction_.clear();
151   per_instruction_.reserve(module_->instruction_count());
152 
153   logical_buffer_aliases_.clear();
154   logical_buffer_aliases_.resize(
155       logical_buffer_analysis_->num_logical_buffers());
156 
157   std::vector<HloInstruction*> fusion_instructions;
158   for (auto* computation : module_->MakeNonfusionComputations()) {
159     TF_RETURN_IF_ERROR(computation->Accept(this));
160     TF_RETURN_IF_ERROR(
161         PopulateDefinedBuffersAndAliases(computation->instructions()));
162     for (auto* instruction : computation->instructions()) {
163       if (instruction->opcode() == HloOpcode::kFusion) {
164         GatherFusionInstructions(instruction, &fusion_instructions);
165       }
166     }
167   }
168   // Run points-to analysis on fusion instructions in 'computation'.
169   for (auto* instruction : fusion_instructions) {
170     TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
171     TF_RETURN_IF_ERROR(
172         PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
173   }
174 
175   XLA_VLOG_LINES(3, ToString());
176 
177   return OkStatus();
178 }
179 
180 Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(
181     const decltype(std::declval<HloComputation>()
182                        .instructions())& instructions) {
183   for (auto* instruction : instructions) {
184     PerInstruction* pi = PerInst(instruction);
185     TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction(
186         instruction, &pi->instruction_defined_buffers));
187 
188     const PointsToSet& points_to_set = GetPointsToSet(instruction);
189     points_to_set.ForEachElement(
190         [this, &instruction](
191             const ShapeIndex& index,
__anon819384570602( const ShapeIndex& index, const PointsToSet::BufferList& pointed_to_buffers) 192             const PointsToSet::BufferList& pointed_to_buffers) {
193           for (const LogicalBuffer* buffer : pointed_to_buffers) {
194             logical_buffer_aliases_[buffer->id()].emplace_back(instruction,
195                                                                index);
196           }
197         });
198   }
199   return OkStatus();
200 }
201 
DefaultAction(HloInstruction * hlo_instruction)202 Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
203   // Create trivial points-to set for instruction. Each points-to set at index i
204   // contains a single element LogicalBuffer(hlo_instruction, i). This indicates
205   // that this instruction is the source of all buffers in its own output.
206   PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction);
207   points_to_set.ForEachMutableElement(
208       [this, hlo_instruction](const ShapeIndex& index,
209                               PointsToSet::BufferList* buffers) {
210         buffers->push_back(
211             &logical_buffer_analysis_->GetBuffer(hlo_instruction, index));
212       });
213 
214   if (hlo_instruction->shape().IsTuple()) {
215     // If the hlo instruction is a tuple-shaped, then trivially the instruction
216     // itself is the source of the tuple.
217     points_to_set.add_tuple_source({}, hlo_instruction);
218   }
219 
220   return OkStatus();
221 }
222 
HandleGetTupleElement(HloInstruction * get_tuple_element)223 Status TuplePointsToAnalysis::HandleGetTupleElement(
224     HloInstruction* get_tuple_element) {
225   // GetTupleElement forwards a pointer to a particular element of the tuple
226   // operand.
227   int64_t element_index = get_tuple_element->tuple_index();
228 
229   PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element);
230   const PointsToSet& operand_points_to_set =
231       *PerInst(get_tuple_element->operand(0))->points_to_set;
232 
233   // Copy the points-to set (and tuple sources) at index {element_index} of the
234   // operand to the points-to set for this GetTupleElement instruction.
235   points_to_set.ForEachMutableElement(
236       [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) {
237         // Construct an index into the operand by prepending element_index to
238         // the index for the GetTupleElement instruction's points-to set.
239         ShapeIndex src_index;
240         src_index.push_back(element_index);
241         for (auto element : target_index) {
242           src_index.push_back(element);
243         }
244 
245         *points_to = operand_points_to_set.element(src_index);
246         for (HloInstruction* tuple :
247              operand_points_to_set.tuple_sources(src_index)) {
248           points_to_set.add_tuple_source(target_index, tuple);
249         }
250       });
251 
252   return OkStatus();
253 }
254 
HandleCopy(HloInstruction * copy)255 Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) {
256   // A kCopy instruction performs a shallow copy of the operand. The top-level
257   // buffer (index={}) is newly created, but all other buffers (in the case of a
258   // tuple shape) come from the operand
259   PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0));
260   points_to_set.mutable_element(/*index=*/{})->clear();
261   points_to_set.AddPointedToBuffer(
262       logical_buffer_analysis_->GetBuffer(copy, /*index=*/{}),
263       /*index=*/{});
264 
265   return OkStatus();
266 }
267 
HandleBitcast(HloInstruction * bitcast)268 Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
269   // A kBitcast instruction aliases its operand. That is, the buffer of its
270   // result *is* the buffer of its operand, so just copy the operands points-to
271   // set.
272   CreateCopiedPointsToSet(bitcast, bitcast->operand(0));
273   return OkStatus();
274 }
275 
HandleDomain(HloInstruction * domain)276 Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) {
277   // A kDomain instruction aliases its operand. That is, the buffer of its
278   // result *is* the buffer of its operand, so just copy the operands points-to
279   // set.
280   CreateCopiedPointsToSet(domain, domain->operand(0));
281   return OkStatus();
282 }
283 
HandleAddDependency(HloInstruction * add_dependency)284 Status TuplePointsToAnalysis::HandleAddDependency(
285     HloInstruction* add_dependency) {
286   // AddDependency just forwards the value of its zero-th operand.
287   CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0));
288   return OkStatus();
289 }
290 
HandleRecvDone(HloInstruction * recv_done)291 Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
292   // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
293   // output. The other indices ({} and {1}) define their own buffers.
294   PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
295   points_to_set.AddPointedToBuffer(
296       logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}),
297       /*index=*/{});
298   points_to_set.AddPointedToBuffer(
299       logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}),
300       /*index=*/{1});
301 
302   const PointsToSet& operand_points_to_set =
303       GetPointsToSet(recv_done->operand(0));
304 
305   // Recursively copy the points to set of the operand tuple {0} to the output
306   // element {0}.
307   points_to_set.ForEachMutableElement(
308       [&points_to_set, &operand_points_to_set](
309           const ShapeIndex& index, PointsToSet::BufferList* buffers) {
310         if (index.empty() || index[0] != 0) {
311           return;
312         }
313         *buffers = operand_points_to_set.element(index);
314         for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) {
315           points_to_set.add_tuple_source(index, tuple_source);
316         }
317       });
318   return OkStatus();
319 }
320 
HandleAsyncStart(HloInstruction * async_start)321 Status TuplePointsToAnalysis::HandleAsyncStart(HloInstruction* async_start) {
322   // AsyncStart forwards its aliased operands to {0}.
323   PointsToSet& points_to_set = CreateEmptyPointsToSet(async_start);
324 
325   points_to_set.ForEachMutableElement(
326       [&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) {
327         if (target_index.size() >= 2 && target_index.front() == 0) {
328           const PointsToSet& operand_points_to_set =
329               GetPointsToSet(async_start->operand(target_index.at(1)));
330           ShapeIndex source_index(target_index.begin() + 2, target_index.end());
331           *buffers = operand_points_to_set.element(source_index);
332           for (HloInstruction* tuple :
333                operand_points_to_set.tuple_sources(source_index)) {
334             points_to_set.add_tuple_source(target_index, tuple);
335           }
336         } else {
337           buffers->push_back(
338               &logical_buffer_analysis_->GetBuffer(async_start, target_index));
339         }
340       });
341 
342   return OkStatus();
343 }
344 
HandleAsyncUpdate(HloInstruction * async_update)345 Status TuplePointsToAnalysis::HandleAsyncUpdate(HloInstruction* async_update) {
346   // AsyncUpdate forwards its aliased operand to {}.
347   PointsToSet& points_to_set = CreateEmptyPointsToSet(async_update);
348   const PointsToSet& operand_points_to_set =
349       GetPointsToSet(async_update->operand(0));
350   CHECK_EQ(async_update->shape(), async_update->operand(0)->shape());
351 
352   points_to_set.ForEachMutableElement([&](const ShapeIndex& index,
353                                           PointsToSet::BufferList* buffers) {
354     *buffers = operand_points_to_set.element(index);
355     for (HloInstruction* tuple : operand_points_to_set.tuple_sources(index)) {
356       points_to_set.add_tuple_source(index, tuple);
357     }
358   });
359 
360   return OkStatus();
361 }
362 
HandleAsyncDone(HloInstruction * async_done)363 Status TuplePointsToAnalysis::HandleAsyncDone(HloInstruction* async_done) {
364   // AsyncDone forwards its aliased operand.
365   PointsToSet& points_to_set = CreateEmptyPointsToSet(async_done);
366   const PointsToSet& operand_points_to_set =
367       GetPointsToSet(async_done->operand(0));
368   operand_points_to_set.ForEachElement(
369       [&points_to_set, &operand_points_to_set](
370           const ShapeIndex& src_index,
371           const PointsToSet::BufferList& points_to) {
372         if (!src_index.empty() && src_index.front() == 1) {
373           const ShapeIndex target_index(src_index.begin() + 1, src_index.end());
374           *points_to_set.mutable_element(target_index) = points_to;
375 
376           for (HloInstruction* tuple :
377                operand_points_to_set.tuple_sources(src_index)) {
378             points_to_set.add_tuple_source(target_index, tuple);
379           }
380         }
381       });
382 
383   return OkStatus();
384 }
385 
HandleCopyStart(HloInstruction * copy_start)386 Status TuplePointsToAnalysis::HandleCopyStart(HloInstruction* copy_start) {
387   // CopyStart forwards its aliased operand to {1}.
388   PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_start);
389   const PointsToSet& operand_points_to_set =
390       GetPointsToSet(copy_start->operand(0));
391 
392   points_to_set.ForEachMutableElement(
393       [&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) {
394         if (target_index == ShapeIndex({1})) {
395           *buffers = operand_points_to_set.element(/*index=*/{});
396         } else {
397           buffers->push_back(
398               &logical_buffer_analysis_->GetBuffer(copy_start, target_index));
399         }
400       });
401 
402   for (HloInstruction* tuple :
403        operand_points_to_set.tuple_sources(/*index=*/{})) {
404     points_to_set.add_tuple_source(/*index=*/{1}, tuple);
405   }
406 
407   return OkStatus();
408 }
409 
HandleCopyDone(HloInstruction * copy_done)410 Status TuplePointsToAnalysis::HandleCopyDone(HloInstruction* copy_done) {
411   // CopyDone forwards its aliased operand.
412   PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_done);
413   const PointsToSet& operand_points_to_set =
414       GetPointsToSet(copy_done->operand(0));
415   operand_points_to_set.ForEachElement(
416       [&points_to_set, &operand_points_to_set](
417           const ShapeIndex& src_index,
418           const PointsToSet::BufferList& points_to) {
419         if (src_index == ShapeIndex({0})) {
420           const ShapeIndex target_index = {};
421           *points_to_set.mutable_element(target_index) = points_to;
422 
423           for (HloInstruction* tuple :
424                operand_points_to_set.tuple_sources(src_index)) {
425             points_to_set.add_tuple_source(target_index, tuple);
426           }
427         }
428       });
429 
430   return OkStatus();
431 }
432 
HandleSend(HloInstruction * send)433 Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
434   // Send creates a tuple of {aliased operand, U32 context, token}.
435   PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
436 
437   // Creates the points to set for the tuple and its element at {1}.
438   auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
439   top_buffer->push_back(
440       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
441   points_to_set.add_tuple_source({}, send);
442 
443   auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
444   context_buffer->push_back(
445       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
446 
447   auto token_buffer = points_to_set.mutable_element(ShapeIndex({2}));
448   token_buffer->push_back(
449       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2})));
450 
451   // Recursively copy the points to set of the operand to output tuple {0}.
452   const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
453   operand_points_to_set.ForEachElement(
454       [&points_to_set, &operand_points_to_set](
455           const ShapeIndex& src_index,
456           const PointsToSet::BufferList& points_to) {
457         ShapeIndex target_index({0});
458         for (auto element : src_index) {
459           target_index.push_back(element);
460         }
461         *points_to_set.mutable_element(target_index) = points_to;
462 
463         for (HloInstruction* tuple :
464              operand_points_to_set.tuple_sources(src_index)) {
465           points_to_set.add_tuple_source(target_index, tuple);
466         }
467       });
468 
469   return OkStatus();
470 }
471 
HandleTuple(HloInstruction * tuple)472 Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
473   absl::Span<HloInstruction* const> operands(tuple->operands());
474   PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
475   points_to_set.AddPointedToBuffer(
476       logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}),
477       /*index=*/{});
478 
479   // A tuple contains references to all input operands and transitively any
480   // references in those operands.
481   for (int64_t i = 0; i < operands.size(); ++i) {
482     const PointsToSet& operand_points_to_set =
483         *PerInst(operands[i])->points_to_set;
484 
485     // Copy the points-to set (and tuple sources) of the operand into the
486     // respective subtree of the tuple instructions points-to set.
487     operand_points_to_set.ForEachElement(
488         [&points_to_set, &operand_points_to_set, i](
489             const ShapeIndex& src_index,
490             const PointsToSet::BufferList& points_to) {
491           ShapeIndex target_index;
492           target_index.push_back(i);
493           for (auto element : src_index) {
494             target_index.push_back(element);
495           }
496 
497           *points_to_set.mutable_element(target_index) = points_to;
498 
499           for (HloInstruction* tuple :
500                operand_points_to_set.tuple_sources(src_index)) {
501             points_to_set.add_tuple_source(target_index, tuple);
502           }
503         });
504   }
505 
506   points_to_set.add_tuple_source({}, tuple);
507 
508   return OkStatus();
509 }
510 
HandleCustomCall(HloInstruction * custom_call)511 Status TuplePointsToAnalysis::HandleCustomCall(HloInstruction* custom_call) {
512   auto ccall = Cast<HloCustomCallInstruction>(custom_call);
513   PointsToSet& points_to_set = CreateEmptyPointsToSet(custom_call);
514   absl::flat_hash_map<ShapeIndex, std::pair<int64_t, ShapeIndex>>
515       aliased_outputs;
516   for (const auto& pair : ccall->output_to_operand_aliasing()) {
517     aliased_outputs.emplace(pair.first, pair.second);
518   }
519   points_to_set.ForEachMutableElement([&](const ShapeIndex& index,
520                                           PointsToSet::BufferList* buffers) {
521     auto it = aliased_outputs.find(index);
522     if (it == aliased_outputs.end()) {
523       points_to_set.AddPointedToBuffer(
524           logical_buffer_analysis_->GetBuffer(custom_call, index), index);
525     } else {
526       const PointsToSet& input_set =
527           *PerInst(ccall->operand(it->second.first))->points_to_set;
528       for (const LogicalBuffer* input_buffer :
529            input_set.element(it->second.second)) {
530         points_to_set.AddPointedToBuffer(*input_buffer, index);
531       }
532 
533       for (HloInstruction* tuple : input_set.tuple_sources(it->second.second)) {
534         points_to_set.add_tuple_source(index, tuple);
535       }
536     }
537   });
538   points_to_set.add_tuple_source({}, custom_call);
539   return OkStatus();
540 }
541 
HandleOptimizationBarrier(HloInstruction * barrier)542 Status TuplePointsToAnalysis::HandleOptimizationBarrier(
543     HloInstruction* barrier) {
544   // A kOptimizationBarrier instruction is a no-op.
545   CreateCopiedPointsToSet(barrier, barrier->operand(0));
546   return OkStatus();
547 }
548 
GetPointsToSet(const HloInstruction * hlo_instruction) const549 const PointsToSet& TuplePointsToAnalysis::GetPointsToSet(
550     const HloInstruction* hlo_instruction) const {
551   return *PerInst(hlo_instruction)->points_to_set;
552 }
553 
CreateEmptyPointsToSet(const HloInstruction * instruction)554 PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
555     const HloInstruction* instruction) {
556   PerInstruction* pi = PerInst(instruction);
557   CHECK(pi->points_to_set == nullptr)
558       << "instruction should not have been present in the map.";
559   auto set = std::make_unique<PointsToSet>(&instruction->shape());
560   pi->points_to_set = std::move(set);
561   // Return *set using the iterator returned by emplace.
562   return *pi->points_to_set;
563 }
564 
InstructionDefinesBufferAtIndex(const HloInstruction * instruction,const ShapeIndex & index) const565 bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
566     const HloInstruction* instruction, const ShapeIndex& index) const {
567   const auto& buffers = GetPointsToSet(instruction).element(index);
568   return (buffers.size() == 1 && buffers[0]->instruction() == instruction);
569 }
570 
VerifyBuffer(const LogicalBuffer & buffer) const571 Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
572   if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) {
573     return FailedPrecondition(
574         "LogicalBuffer %s is ill-defined: instruction %s does not define a "
575         "buffer at that index",
576         buffer.ToString(), buffer.instruction()->name());
577   }
578 
579   if (buffer.id() < 0 ||
580       buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) {
581     return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d",
582                               buffer.ToString(), buffer.id());
583   }
584   if (GetBuffer(buffer.id()).instruction() != buffer.instruction() ||
585       GetBuffer(buffer.id()).index() != buffer.index()) {
586     return FailedPrecondition(
587         "LogicalBuffer %s is ill-defined: buffer with same id differs: %s",
588         buffer.ToString(), GetBuffer(buffer.id()).ToString());
589   }
590 
591   return OkStatus();
592 }
593 
GetBuffer(LogicalBuffer::Id id) const594 const LogicalBuffer& TuplePointsToAnalysis::GetBuffer(
595     LogicalBuffer::Id id) const {
596   CHECK_GE(id, 0);
597   CHECK_LT(id, logical_buffer_analysis_->num_logical_buffers());
598   return logical_buffer_analysis_->GetBuffer(id);
599 }
600 
GetBufferDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const601 StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
602     const HloInstruction* instruction, const ShapeIndex& index) const {
603   const auto& buffers = GetPointsToSet(instruction).element(index);
604   if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
605     return FailedPrecondition(
606         "instruction %s does not define buffer at index {%s}",
607         instruction->name(), absl::StrJoin(index, ","));
608   }
609   return buffers[0];
610 }
611 
612 const TuplePointsToAnalysis::BufferAliasVector&
GetBufferAliases(const LogicalBuffer & buffer) const613 TuplePointsToAnalysis::GetBufferAliases(const LogicalBuffer& buffer) const {
614   return logical_buffer_aliases_.at(buffer.id());
615 }
616 
617 const TuplePointsToAnalysis::BufferDefinitionVector&
GetBuffersDefinedByInstruction(const HloInstruction * instruction) const618 TuplePointsToAnalysis::GetBuffersDefinedByInstruction(
619     const HloInstruction* instruction) const {
620   return PerInst(instruction)->instruction_defined_buffers;
621 }
622 
GatherBuffersDefinedByInstruction(const HloInstruction * instruction,TuplePointsToAnalysis::BufferDefinitionVector * buffers)623 Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
624     const HloInstruction* instruction,
625     TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
626   GetPointsToSet(instruction)
627       .ForEachElement([buffers, instruction](
628                           const ShapeIndex& index,
629                           const PointsToSet::BufferList& source_buffers) {
630         // Add buffers which 'instruction' is the source of.
631         CHECK(!source_buffers.empty());
632         if (source_buffers.size() == 1 &&
633             source_buffers[0]->instruction() == instruction) {
634           // If this instruction is the source of this buffer the
635           // indices must match.
636           DCHECK(source_buffers[0]->index() == index);
637           buffers->push_back(source_buffers[0]);
638         } else {
639           // If the points-to set includes more than one buffer then
640           // necessarily this instruction did not produce the
641           // buffer.
642           for (const LogicalBuffer* source_buffer : source_buffers) {
643             DCHECK(source_buffer->instruction() != instruction);
644           }
645         }
646       });
647   return OkStatus();
648 }
649 
CreateCopiedPointsToSet(const HloInstruction * instruction,const HloInstruction * src)650 PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
651     const HloInstruction* instruction, const HloInstruction* src) {
652   // PointsToSet doesn't have a copy constructor so copy over element-by-element
653   // from src PointsToSet.
654   PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
655   const PointsToSet& src_points_to_set = GetPointsToSet(src);
656   dst_points_to_set.ForEachMutableElement(
657       [&dst_points_to_set, &src_points_to_set](
658           const ShapeIndex& index, PointsToSet::BufferList* buffers) {
659         *buffers = src_points_to_set.element(index);
660         for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
661           dst_points_to_set.add_tuple_source(index, tuple_source);
662         }
663       });
664   return *PerInst(instruction)->points_to_set;
665 }
666 
ToString() const667 std::string TuplePointsToAnalysis::ToString() const {
668   std::string output =
669       absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name());
670   for (const auto* computation : module_->MakeNonfusionComputations()) {
671     const char* entry =
672         computation == module_->entry_computation() ? "entry " : "";
673     absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n");
674     for (const HloInstruction* instruction :
675          computation->MakeInstructionPostOrder()) {
676       InstructionToString(instruction, &output);
677       if (instruction->opcode() == HloOpcode::kFusion) {
678         for (auto* fused : instruction->fused_instructions()) {
679           InstructionToString(fused, &output);
680         }
681       }
682     }
683   }
684 
685   absl::StrAppend(&output, "LogicalBuffers:\n");
686   for (const auto& b : logical_buffer_analysis_->logical_buffers()) {
687     absl::StrAppend(&output, "  buffer ", b->ToString(), ":\n");
688     for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) {
689       absl::StrAppend(&output, "    alias ", alias.ToString(), "\n");
690     }
691   }
692   return output;
693 }
694 
InstructionToString(const HloInstruction * instruction,std::string * output) const695 void TuplePointsToAnalysis::InstructionToString(
696     const HloInstruction* instruction, std::string* output) const {
697   const std::string prefix = instruction->IsFused() ? "    " : "";
698   absl::StrAppend(output, prefix, "  instruction ",
699                   instruction->ToShortString(), ":\n");
700   const PointsToSet& points_to_set = GetPointsToSet(instruction);
701   points_to_set.ForEachElement(
702       [&prefix, &output](const ShapeIndex& index,
703                          const PointsToSet::BufferList& points_to) {
704         absl::StrAppend(
705             output, prefix, "    {", absl::StrJoin(index, ","), "}: ",
706             absl::StrJoin(points_to, ", ",
707                           [](std::string* out, const LogicalBuffer* source) {
708                             out->append(source->ToString());
709                           }),
710             "\n");
711       });
712 }
713 
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const714 bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
715     const HloInstruction* operand, const ShapeIndex& index,
716     const HloInstruction* user) const {
717   CHECK(user->IsUserOf(operand))
718       << "user: " << user->ToString() << " operand: " << operand->ToString();
719   if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
720     // GetTupleElement instructions only access the top-level buffer of their
721     // operand.
722     return true;
723   } else if (user->IsLoopFusion()) {
724     // Find fusion parameter associated with 'operand'.
725     auto it = absl::c_find_if(
726         user->fused_parameters(), [&](HloInstruction* fused_param) {
727           return user->operand(fused_param->parameter_number()) == operand;
728         });
729     CHECK(it != user->fused_parameters().end());
730     // Iterate through all users of all buffer aliases of the buffer in the
731     // points-to set of fusion parameter at 'index'.
732     // Return false if any uses are detected at 'index', returns true otherwise.
733     const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie();
734     for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
735       for (HloInstruction* alias_user : alias.instruction()->users()) {
736         if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
737                                     alias_user)) {
738           continue;
739         }
740         // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
741         return false;
742       }
743     }
744     // Return true: found no uses of 'operand' at 'index' in 'user'.
745     return true;
746   }
747   return false;
748 }
749 
750 // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
751 // Each use in 'uses' is a pair (HloInstruction* user, int64_t operand_index)
752 // where 'user' is a user of an alias of 'instruction' at 'index', and
753 // 'operand_index' is the operand index at which the alias appears in the
754 // operand list of 'user'.
755 std::vector<std::pair<HloInstruction*, int64_t>>
GetAllUsesOfInstructionAtIndex(HloInstruction * instruction,const ShapeIndex & index) const756 TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex(
757     HloInstruction* instruction, const ShapeIndex& index) const {
758   std::vector<std::pair<HloInstruction*, int64_t>> uses;
759   const PointsToSet::BufferList& points_to =
760       GetPointsToSet(instruction).element(index);
761   for (const LogicalBuffer* buffer : points_to) {
762     for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
763       for (HloInstruction* alias_user : alias.instruction()->users()) {
764         if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
765                                     alias_user)) {
766           continue;
767         }
768         for (int64_t op_idx : alias_user->OperandIndices(alias.instruction())) {
769           uses.emplace_back(alias_user, op_idx);
770         }
771       }
772     }
773   }
774   return uses;
775 }
776 
777 // Returns true if there is exactly one use of 'operand' at 'operand_index'
778 // in 'fusion.fused_instructions', where the singleton use is the fused
779 // root at operand index 'use_operand_index'. Returns false otherwise.
780 //
781 // REQUIRES: 'fusion' opcode is a kFusion instruction.
HasUniqueFusedUseOfOperandAt(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * fusion,const int64_t use_operand_index) const782 bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
783     HloInstruction* operand, const ShapeIndex& operand_index,
784     HloInstruction* fusion, const int64_t use_operand_index) const {
785   CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
786   // Check that 'operand' is unique in the operand list of 'fusion'.
787   if (fusion->OperandIndices(operand).size() > 1) {
788     return false;
789   }
790   // Find fusion parameter associated with 'operand'.
791   const auto& fused_params = fusion->fused_parameters();
792   auto fused_param_it =
793       absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
794         return fusion->operand(fused_param->parameter_number()) == operand;
795       });
796   if (fused_param_it == fused_params.end()) {
797     return false;
798   }
799   auto* fused_param = *fused_param_it;
800   // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'.
801   auto fused_param_uses =
802       GetAllUsesOfInstructionAtIndex(fused_param, operand_index);
803   // Return true iff there is exactly one use of 'operand' at 'index', and
804   // this singleton use is the fused root (at index in 'use_operand_indices').
805   return fused_param_uses.size() == 1 &&
806          fused_param_uses[0].first == fusion->fused_expression_root() &&
807          fused_param_uses[0].second == use_operand_index;
808 }
809 }  // namespace xla
810