• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
17 
18 #include <ostream>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/compiler/xla/map_util.h"
28 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/logging.h"
35 
36 namespace xla {
37 
ToString() const38 string BufferAlias::ToString() const {
39   return absl::StrCat("BufferAlias(", instruction_->name(), "[",
40                       absl::StrJoin(index_, ","), "])");
41 }
42 
operator <<(std::ostream & out,const BufferAlias & buffer_alias)43 std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
44   out << buffer_alias.ToString();
45   return out;
46 }
47 
IsAmbiguous() const48 bool PointsToSet::IsAmbiguous() const {
49   bool ambiguous = false;
50   ForEachElement(
51       [&ambiguous](const ShapeIndex& /*index*/, const BufferList& points_to) {
52         ambiguous |= points_to.size() > 1;
53       });
54   return ambiguous;
55 }
56 
IsDistinct() const57 bool PointsToSet::IsDistinct() const {
58   bool distinct = true;
59   absl::flat_hash_set<const LogicalBuffer*> all_points_to;
60   ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) {
61     for (auto& buffer : points_to) {
62       if (all_points_to.contains(buffer)) {
63         distinct = false;
64       }
65       all_points_to.insert(buffer);
66     }
67   });
68   return distinct;
69 }
70 
size() const71 size_t PointsToSet::size() const {
72   // Because pointed-to elements may be duplicated we have to create a flattened
73   // set and return the size.
74   return CreateFlattenedSet().size();
75 }
76 
CreateFlattenedSet() const77 PointsToSet::BufferSet PointsToSet::CreateFlattenedSet() const {
78   BufferSet flat_set;
79   ForEachElement(
80       [&flat_set](const ShapeIndex& /*index*/, const BufferList& buffers) {
81         flat_set.insert(buffers.begin(), buffers.end());
82       });
83   return flat_set;
84 }
85 
ContainsBuffer(const LogicalBuffer & buffer) const86 bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
87   bool found = false;
88   ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
89                                    const BufferList& pointed_to_buffers) {
90     if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
91       found = true;
92     }
93   });
94   return found;
95 }
96 
ContainsBufferAtIndex(const LogicalBuffer & buffer,const ShapeIndex & index) const97 bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
98                                         const ShapeIndex& index) const {
99   const auto& pointed_to_buffers = element(index);
100   return absl::c_linear_search(pointed_to_buffers, &buffer);
101 }
102 
AddPointedToBuffer(const LogicalBuffer & buffer,const ShapeIndex & index)103 void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
104                                      const ShapeIndex& index) {
105   if (ContainsBufferAtIndex(buffer, index)) {
106     return;
107   }
108   mutable_element(index)->push_back(&buffer);
109 }
110 
tuple_sources(const ShapeIndex & index) const111 const PointsToSet::SourceSet& PointsToSet::tuple_sources(
112     const ShapeIndex& index) const {
113   return tree_.element(index).tuple_sources;
114 }
115 
add_tuple_source(const ShapeIndex & index,HloInstruction * tuple)116 void PointsToSet::add_tuple_source(const ShapeIndex& index,
117                                    HloInstruction* tuple) {
118   tree_.mutable_element(index)->tuple_sources.insert(tuple);
119 }
120 
121 namespace {
122 // Gather fusion instructions from 'instruction' into 'fusion_instructions'.
GatherFusionInstructions(HloInstruction * instruction,std::vector<HloInstruction * > * fusion_instructions)123 void GatherFusionInstructions(
124     HloInstruction* instruction,
125     std::vector<HloInstruction*>* fusion_instructions) {
126   CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
127   for (auto* fused : instruction->fused_instructions()) {
128     if (fused->opcode() == HloOpcode::kFusion) {
129       GatherFusionInstructions(fused, fusion_instructions);
130     }
131   }
132   fusion_instructions->push_back(instruction);
133 }
134 
135 }  // namespace
136 
137 /* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
Run(const HloModule * module)138 TuplePointsToAnalysis::Run(const HloModule* module) {
139   auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module);
140   std::unique_ptr<TuplePointsToAnalysis> analysis(new TuplePointsToAnalysis(
141       module, logical_buffer_analysis.ConsumeValueOrDie()));
142   TF_RETURN_IF_ERROR(analysis->Analyze());
143   return std::move(analysis);
144 }
145 
Analyze()146 Status TuplePointsToAnalysis::Analyze() {
147   per_instruction_.clear();
148   per_instruction_.reserve(module_->instruction_count());
149 
150   logical_buffer_aliases_.clear();
151   logical_buffer_aliases_.resize(
152       logical_buffer_analysis_->num_logical_buffers());
153 
154   std::vector<HloInstruction*> fusion_instructions;
155   for (auto* computation : module_->MakeNonfusionComputations()) {
156     TF_RETURN_IF_ERROR(computation->Accept(this));
157     TF_RETURN_IF_ERROR(
158         PopulateDefinedBuffersAndAliases(computation->instructions()));
159     for (auto* instruction : computation->instructions()) {
160       if (instruction->opcode() == HloOpcode::kFusion) {
161         GatherFusionInstructions(instruction, &fusion_instructions);
162       }
163     }
164   }
165   // Run points-to analysis on fusion instructions in 'computation'.
166   for (auto* instruction : fusion_instructions) {
167     TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
168     TF_RETURN_IF_ERROR(
169         PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
170   }
171 
172   XLA_VLOG_LINES(3, ToString());
173 
174   return Status::OK();
175 }
176 
177 Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(const decltype(
178     std::declval<HloComputation>().instructions())& instructions) {
179   for (auto* instruction : instructions) {
180     PerInstruction* pi = PerInst(instruction);
181     TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction(
182         instruction, &pi->instruction_defined_buffers));
183 
184     const PointsToSet& points_to_set = GetPointsToSet(instruction);
185     points_to_set.ForEachElement(
186         [this, &instruction](
187             const ShapeIndex& index,
__anon1079e89b0602( const ShapeIndex& index, const PointsToSet::BufferList& pointed_to_buffers) 188             const PointsToSet::BufferList& pointed_to_buffers) {
189           for (const LogicalBuffer* buffer : pointed_to_buffers) {
190             logical_buffer_aliases_[buffer->id()].emplace_back(instruction,
191                                                                index);
192           }
193         });
194   }
195   return Status::OK();
196 }
197 
DefaultAction(HloInstruction * hlo_instruction)198 Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
199   // Create trivial points-to set for instruction. Each points-to set at index i
200   // contains a single element LogicalBuffer(hlo_instruction, i). This indicates
201   // that this instruction is the source of all buffers in its own output.
202   PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction);
203   points_to_set.ForEachMutableElement(
204       [this, hlo_instruction](const ShapeIndex& index,
205                               PointsToSet::BufferList* buffers) {
206         buffers->push_back(
207             &logical_buffer_analysis_->GetBuffer(hlo_instruction, index));
208       });
209 
210   if (hlo_instruction->shape().IsTuple()) {
211     // If the hlo instruction is a tuple-shaped, then trivially the instruction
212     // itself is the source of the tuple.
213     points_to_set.add_tuple_source({}, hlo_instruction);
214   }
215 
216   return Status::OK();
217 }
218 
HandleGetTupleElement(HloInstruction * get_tuple_element)219 Status TuplePointsToAnalysis::HandleGetTupleElement(
220     HloInstruction* get_tuple_element) {
221   // GetTupleElement forwards a pointer to a particular element of the tuple
222   // operand.
223   int64 element_index = get_tuple_element->tuple_index();
224 
225   PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element);
226   const PointsToSet& operand_points_to_set =
227       *PerInst(get_tuple_element->operand(0))->points_to_set;
228 
229   // Copy the points-to set (and tuple sources) at index {element_index} of the
230   // operand to the points-to set for this GetTupleElement instruction.
231   points_to_set.ForEachMutableElement(
232       [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) {
233         // Construct an index into the operand by prepending element_index to
234         // the index for the GetTupleElement instruction's points-to set.
235         ShapeIndex src_index;
236         src_index.push_back(element_index);
237         for (auto element : target_index) {
238           src_index.push_back(element);
239         }
240 
241         *points_to = operand_points_to_set.element(src_index);
242         for (HloInstruction* tuple :
243              operand_points_to_set.tuple_sources(src_index)) {
244           points_to_set.add_tuple_source(target_index, tuple);
245         }
246       });
247 
248   return Status::OK();
249 }
250 
HandleCopy(HloInstruction * copy)251 Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) {
252   // A kCopy instruction performs a shallow copy of the operand. The top-level
253   // buffer (index={}) is newly created, but all other buffers (in the case of a
254   // tuple shape) come from the operand
255   PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0));
256   points_to_set.mutable_element(/*index=*/{})->clear();
257   points_to_set.AddPointedToBuffer(
258       logical_buffer_analysis_->GetBuffer(copy, /*index=*/{}),
259       /*index=*/{});
260 
261   return Status::OK();
262 }
263 
HandleBitcast(HloInstruction * bitcast)264 Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
265   // A kBitcast instruction aliases its operand. That is, the buffer of its
266   // result *is* the buffer of its operand, so just copy the operands points-to
267   // set.
268   CreateCopiedPointsToSet(bitcast, bitcast->operand(0));
269   return Status::OK();
270 }
271 
HandleDomain(HloInstruction * domain)272 Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) {
273   // A kDomain instruction aliases its operand. That is, the buffer of its
274   // result *is* the buffer of its operand, so just copy the operands points-to
275   // set.
276   CreateCopiedPointsToSet(domain, domain->operand(0));
277   return Status::OK();
278 }
279 
HandleAddDependency(HloInstruction * add_dependency)280 Status TuplePointsToAnalysis::HandleAddDependency(
281     HloInstruction* add_dependency) {
282   // AddDependency just forwards the value of its zero-th operand.
283   CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0));
284   return Status::OK();
285 }
286 
HandleRecvDone(HloInstruction * recv_done)287 Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
288   // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
289   // output. The other indices ({} and {1}) define their own buffers.
290   PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
291   points_to_set.AddPointedToBuffer(
292       logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}),
293       /*index=*/{});
294   points_to_set.AddPointedToBuffer(
295       logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}),
296       /*index=*/{1});
297 
298   const PointsToSet& operand_points_to_set =
299       GetPointsToSet(recv_done->operand(0));
300 
301   // Recursively copy the points to set of the operand tuple {0} to the output
302   // element {0}.
303   points_to_set.ForEachMutableElement(
304       [&points_to_set, &operand_points_to_set](
305           const ShapeIndex& index, PointsToSet::BufferList* buffers) {
306         if (index.empty() || index[0] != 0) {
307           return;
308         }
309         *buffers = operand_points_to_set.element(index);
310         for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) {
311           points_to_set.add_tuple_source(index, tuple_source);
312         }
313       });
314   return Status::OK();
315 }
316 
HandleSend(HloInstruction * send)317 Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
318   // Send creates a tuple of {aliased operand, U32 context, token}.
319   PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
320 
321   // Creates the points to set for the tuple and its element at {1}.
322   auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
323   top_buffer->push_back(
324       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
325   points_to_set.add_tuple_source({}, send);
326 
327   auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
328   context_buffer->push_back(
329       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
330 
331   auto token_buffer = points_to_set.mutable_element(ShapeIndex({2}));
332   token_buffer->push_back(
333       &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2})));
334 
335   // Recursively copy the points to set of the operand to output tuple {0}.
336   const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
337   operand_points_to_set.ForEachElement(
338       [&points_to_set, &operand_points_to_set](
339           const ShapeIndex& src_index,
340           const PointsToSet::BufferList& points_to) {
341         ShapeIndex target_index({0});
342         for (auto element : src_index) {
343           target_index.push_back(element);
344         }
345         *points_to_set.mutable_element(target_index) = points_to;
346 
347         for (HloInstruction* tuple :
348              operand_points_to_set.tuple_sources(src_index)) {
349           points_to_set.add_tuple_source(target_index, tuple);
350         }
351       });
352 
353   return Status::OK();
354 }
355 
HandleTuple(HloInstruction * tuple)356 Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
357   absl::Span<HloInstruction* const> operands(tuple->operands());
358   PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
359   points_to_set.AddPointedToBuffer(
360       logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}),
361       /*index=*/{});
362 
363   // A tuple contains references to all input operands and transitively any
364   // references in those operands.
365   for (int64 i = 0; i < operands.size(); ++i) {
366     const PointsToSet& operand_points_to_set =
367         *PerInst(operands[i])->points_to_set;
368 
369     // Copy the points-to set (and tuple sources) of the operand into the
370     // respective subtree of the tuple instructions points-to set.
371     operand_points_to_set.ForEachElement(
372         [&points_to_set, &operand_points_to_set, i](
373             const ShapeIndex& src_index,
374             const PointsToSet::BufferList& points_to) {
375           ShapeIndex target_index;
376           target_index.push_back(i);
377           for (auto element : src_index) {
378             target_index.push_back(element);
379           }
380 
381           *points_to_set.mutable_element(target_index) = points_to;
382 
383           for (HloInstruction* tuple :
384                operand_points_to_set.tuple_sources(src_index)) {
385             points_to_set.add_tuple_source(target_index, tuple);
386           }
387         });
388   }
389 
390   points_to_set.add_tuple_source({}, tuple);
391 
392   return Status::OK();
393 }
394 
HandleTupleSelect(HloInstruction * tuple_select)395 Status TuplePointsToAnalysis::HandleTupleSelect(HloInstruction* tuple_select) {
396   // Select allocates a new buffer and then shallow copies the on_true or
397   // on_false buffer into this new buffer. Which side is chosen cannot be
398   // determined statically so conservatively set the points-to set to the union
399   // of these on_true and on_false operands.
400   //
401   // First create a copy of the on_true points-to set (and tuple sources), then
402   // add in elements of the on_false points-to set (tuple sources).
403   auto on_true = tuple_select->operand(1);
404   auto on_false = tuple_select->operand(2);
405   PointsToSet& points_to_set = CreateCopiedPointsToSet(tuple_select, on_true);
406   const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set;
407   points_to_set.ForEachMutableElement(
408       [&](const ShapeIndex& index, PointsToSet::BufferList* buffers) {
409         for (const LogicalBuffer* false_buffer :
410              false_points_to_set.element(index)) {
411           points_to_set.AddPointedToBuffer(*false_buffer, index);
412         }
413 
414         for (HloInstruction* tuple : false_points_to_set.tuple_sources(index)) {
415           points_to_set.add_tuple_source(index, tuple);
416         }
417       });
418 
419   // Select creates a new (top-level) buffer to store its result, so its
420   // respective element in the points-to set should contain only itself.
421   points_to_set.mutable_element({})->clear();
422   points_to_set.AddPointedToBuffer(
423       logical_buffer_analysis_->GetBuffer(tuple_select, /*index=*/{}),
424       /*index=*/{});
425   return Status::OK();
426 }
427 
GetPointsToSet(const HloInstruction * hlo_instruction) const428 const PointsToSet& TuplePointsToAnalysis::GetPointsToSet(
429     const HloInstruction* hlo_instruction) const {
430   return *PerInst(hlo_instruction)->points_to_set;
431 }
432 
CreateEmptyPointsToSet(const HloInstruction * instruction)433 PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
434     const HloInstruction* instruction) {
435   PerInstruction* pi = PerInst(instruction);
436   CHECK(pi->points_to_set == nullptr)
437       << "instruction should not have been present in the map.";
438   auto set = absl::make_unique<PointsToSet>(&instruction->shape());
439   pi->points_to_set = std::move(set);
440   // Return *set using the iterator returned by emplace.
441   return *pi->points_to_set;
442 }
443 
InstructionDefinesBufferAtIndex(const HloInstruction * instruction,const ShapeIndex & index) const444 bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
445     const HloInstruction* instruction, const ShapeIndex& index) const {
446   const auto& buffers = GetPointsToSet(instruction).element(index);
447   return (buffers.size() == 1 && buffers[0]->instruction() == instruction);
448 }
449 
VerifyBuffer(const LogicalBuffer & buffer) const450 Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
451   if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) {
452     return FailedPrecondition(
453         "LogicalBuffer %s is ill-defined: instruction %s does not define a "
454         "buffer at that index",
455         buffer.ToString(), buffer.instruction()->name());
456   }
457 
458   if (buffer.id() < 0 ||
459       buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) {
460     return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d",
461                               buffer.ToString(), buffer.id());
462   }
463   if (GetBuffer(buffer.id()).instruction() != buffer.instruction() ||
464       GetBuffer(buffer.id()).index() != buffer.index()) {
465     return FailedPrecondition(
466         "LogicalBuffer %s is ill-defined: buffer with same id differs: %s",
467         buffer.ToString(), GetBuffer(buffer.id()).ToString());
468   }
469 
470   return Status::OK();
471 }
472 
GetBuffer(LogicalBuffer::Id id) const473 const LogicalBuffer& TuplePointsToAnalysis::GetBuffer(
474     LogicalBuffer::Id id) const {
475   CHECK_GE(id, 0);
476   CHECK_LT(id, logical_buffer_analysis_->num_logical_buffers());
477   return logical_buffer_analysis_->GetBuffer(id);
478 }
479 
GetBufferDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const480 StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
481     const HloInstruction* instruction, const ShapeIndex& index) const {
482   const auto& buffers = GetPointsToSet(instruction).element(index);
483   if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
484     return FailedPrecondition(
485         "instruction %s does not define buffer at index {%s}",
486         instruction->name(), absl::StrJoin(index, ","));
487   }
488   return buffers[0];
489 }
490 
491 const TuplePointsToAnalysis::BufferAliasVector&
GetBufferAliases(const LogicalBuffer & buffer) const492 TuplePointsToAnalysis::GetBufferAliases(const LogicalBuffer& buffer) const {
493   return logical_buffer_aliases_.at(buffer.id());
494 }
495 
496 const TuplePointsToAnalysis::BufferDefinitionVector&
GetBuffersDefinedByInstruction(const HloInstruction * instruction) const497 TuplePointsToAnalysis::GetBuffersDefinedByInstruction(
498     const HloInstruction* instruction) const {
499   return PerInst(instruction)->instruction_defined_buffers;
500 }
501 
GatherBuffersDefinedByInstruction(const HloInstruction * instruction,TuplePointsToAnalysis::BufferDefinitionVector * buffers)502 Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
503     const HloInstruction* instruction,
504     TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
505   GetPointsToSet(instruction)
506       .ForEachElement([buffers, instruction](
507                           const ShapeIndex& index,
508                           const PointsToSet::BufferList& source_buffers) {
509         // Add buffers which 'instruction' is the source of.
510         CHECK(!source_buffers.empty());
511         if (source_buffers.size() == 1 &&
512             source_buffers[0]->instruction() == instruction) {
513           // If this instruction is the source of this buffer the
514           // indices must match.
515           DCHECK(source_buffers[0]->index() == index);
516           buffers->push_back(source_buffers[0]);
517         } else {
518           // If the points-to set includes more than one buffer then
519           // necessarily this instruction did not produce the
520           // buffer.
521           for (const LogicalBuffer* source_buffer : source_buffers) {
522             DCHECK(source_buffer->instruction() != instruction);
523           }
524         }
525       });
526   return Status::OK();
527 }
528 
CreateCopiedPointsToSet(const HloInstruction * instruction,const HloInstruction * src)529 PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
530     const HloInstruction* instruction, const HloInstruction* src) {
531   // PointsToSet doesn't have a copy constructor so copy over element-by-element
532   // from src PointsToSet.
533   PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
534   const PointsToSet& src_points_to_set = GetPointsToSet(src);
535   dst_points_to_set.ForEachMutableElement(
536       [&dst_points_to_set, &src_points_to_set](
537           const ShapeIndex& index, PointsToSet::BufferList* buffers) {
538         *buffers = src_points_to_set.element(index);
539         for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
540           dst_points_to_set.add_tuple_source(index, tuple_source);
541         }
542       });
543   return *PerInst(instruction)->points_to_set;
544 }
545 
ToString() const546 string TuplePointsToAnalysis::ToString() const {
547   string output =
548       absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name());
549   for (const auto* computation : module_->MakeNonfusionComputations()) {
550     const char* entry =
551         computation == module_->entry_computation() ? "entry " : "";
552     absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n");
553     for (const HloInstruction* instruction :
554          computation->MakeInstructionPostOrder()) {
555       InstructionToString(instruction, &output);
556       if (instruction->opcode() == HloOpcode::kFusion) {
557         for (auto* fused : instruction->fused_instructions()) {
558           InstructionToString(fused, &output);
559         }
560       }
561     }
562   }
563 
564   absl::StrAppend(&output, "LogicalBuffers:\n");
565   for (const auto& b : logical_buffer_analysis_->logical_buffers()) {
566     absl::StrAppend(&output, "  buffer ", b->ToString(), ":\n");
567     for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) {
568       absl::StrAppend(&output, "    alias ", alias.ToString(), "\n");
569     }
570   }
571   return output;
572 }
573 
InstructionToString(const HloInstruction * instruction,string * output) const574 void TuplePointsToAnalysis::InstructionToString(
575     const HloInstruction* instruction, string* output) const {
576   const string prefix = instruction->IsFused() ? "    " : "";
577   absl::StrAppend(output, prefix, "  instruction ",
578                   instruction->ToShortString(), ":\n");
579   const PointsToSet& points_to_set = GetPointsToSet(instruction);
580   points_to_set.ForEachElement([&prefix, &output](
581                                    const ShapeIndex& index,
582                                    const PointsToSet::BufferList& points_to) {
583     absl::StrAppend(output, prefix, "    {", absl::StrJoin(index, ","), "}: ",
584                     absl::StrJoin(points_to, ", ",
585                                   [](string* out, const LogicalBuffer* source) {
586                                     out->append(source->ToString());
587                                   }),
588                     "\n");
589   });
590 }
591 
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const592 bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
593     const HloInstruction* operand, const ShapeIndex& index,
594     const HloInstruction* user) const {
595   CHECK(user->IsUserOf(operand))
596       << "user: " << user->ToString() << " operand: " << operand->ToString();
597   if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
598     // GetTupleElement instructions only access the top-level buffer of their
599     // operand.
600     return true;
601   } else if (user->opcode() == HloOpcode::kFusion &&
602              user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
603     // Find fusion parameter associated with 'operand'.
604     auto it = absl::c_find_if(
605         user->fused_parameters(), [&](HloInstruction* fused_param) {
606           return user->operand(fused_param->parameter_number()) == operand;
607         });
608     CHECK(it != user->fused_parameters().end());
609     // Iterate through all users of all buffer aliases of the buffer in the
610     // points-to set of fusion parameter at 'index'.
611     // Return false if any uses are detected at 'index', returns true otherwise.
612     const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie();
613     for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
614       for (HloInstruction* alias_user : alias.instruction()->users()) {
615         if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
616                                     alias_user)) {
617           continue;
618         }
619         // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'.
620         return false;
621       }
622     }
623     // Return true: found no uses of 'operand' at 'index' in 'user'.
624     return true;
625   }
626   return false;
627 }
628 
629 // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
630 // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index)
631 // where 'user' is a user of an alias of 'instruction' at 'index', and
632 // 'operand_index' is the operand index at which the alias appears in the
633 // operand list of 'user'.
634 std::vector<std::pair<HloInstruction*, int64>>
GetAllUsesOfInstructionAtIndex(HloInstruction * instruction,const ShapeIndex & index) const635 TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex(
636     HloInstruction* instruction, const ShapeIndex& index) const {
637   std::vector<std::pair<HloInstruction*, int64>> uses;
638   const PointsToSet::BufferList& points_to =
639       GetPointsToSet(instruction).element(index);
640   for (const LogicalBuffer* buffer : points_to) {
641     for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
642       for (HloInstruction* alias_user : alias.instruction()->users()) {
643         if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
644                                     alias_user)) {
645           continue;
646         }
647         for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
648           uses.emplace_back(alias_user, op_idx);
649         }
650       }
651     }
652   }
653   return uses;
654 }
655 
656 // Returns true if there is exactly one use of 'operand' at 'operand_index'
657 // in 'fusion.fused_instructions', where the singleton use is the fused
658 // root at operand index 'use_operand_index'. Returns false otherwise.
659 //
660 // REQUIRES: 'fusion' opcode is a kFusion instruction.
HasUniqueFusedUseOfOperandAt(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * fusion,const int64 use_operand_index) const661 bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
662     HloInstruction* operand, const ShapeIndex& operand_index,
663     HloInstruction* fusion, const int64 use_operand_index) const {
664   CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
665   // Check that 'operand' is unique in the operand list of 'fusion'.
666   if (fusion->OperandIndices(operand).size() > 1) {
667     return false;
668   }
669   // Find fusion parameter associated with 'operand'.
670   const auto& fused_params = fusion->fused_parameters();
671   auto fused_param_it =
672       absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
673         return fusion->operand(fused_param->parameter_number()) == operand;
674       });
675   if (fused_param_it == fused_params.end()) {
676     return false;
677   }
678   auto* fused_param = *fused_param_it;
679   // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'.
680   auto fused_param_uses =
681       GetAllUsesOfInstructionAtIndex(fused_param, operand_index);
682   // Return true iff there is exactly one use of 'operand' at 'index', and
683   // this singleton use is the fused root (at index in 'use_operand_indices').
684   return fused_param_uses.size() == 1 &&
685          fused_param_uses[0].first == fusion->fused_expression_root() &&
686          fused_param_uses[0].second == use_operand_index;
687 }
688 
689 // User and operand can share buffers iff both instructions emit the same shape
690 // and layout, and 'user' meets one of the following qualifications:
691 //
692 // (1) Is element-wise. Or...
693 // (2) Is a loop fusion instruction where the only use of 'operand' at 'index'
694 //     in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
695 //     at operand 0. Or...
696 // (3) Is a kDot -> kAdd output fusion instruction where the only use of
697 //     'operand' at 'index' in the set 'user.fused_instructions' is a kAdd fused
698 //     root at operand 0 or 1. Or...
699 // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index
700 //     0.
701 // (5) The 'user' of 'operand' is Sort, and it is the only user.
702 // (6) The 'user' of 'operand' is TriangularSolve, it is the second operand,
703 //     and it is the only user.
704 //
705 // (2) and (3) can only be determined if points-to analysis is available.
CanShareOperandBufferWithUser(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * user,const ShapeIndex & user_index) const706 bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
707     HloInstruction* operand, const ShapeIndex& operand_index,
708     HloInstruction* user, const ShapeIndex& user_index) const {
709   CHECK(user->IsUserOf(operand))
710       << "user: " << user->ToString() << " operand: " << operand->ToString();
711   const Shape& operand_subshape =
712       ShapeUtil::GetSubshape(operand->shape(), operand_index);
713   const Shape& user_subshape =
714       ShapeUtil::GetSubshape(user->shape(), user_index);
715   // Check that operand and user emit the same shape and layout.
716   if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
717     return false;
718   }
719   if (user->opcode() == HloOpcode::kFusion) {
720     if (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
721         user->fusion_kind() == HloInstruction::FusionKind::kInput) {
722       if (user->fused_expression_root()->opcode() ==
723           HloOpcode::kDynamicUpdateSlice) {
724         // Loop fusion with kDynamicUpdateSlice fused root.
725         //
726         // Returns true iff there is exactly one use of 'operand' at shape index
727         // 'operand_index', and this singleton use is the fused root at operand
728         // index 0.
729         return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0);
730       } else {
731         HloInstruction* fusion_param =
732             user->fused_parameter(user->operand_index(operand));
733         return HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
734             fusion_param);
735       }
736     } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
737                user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
738       // Output fusion with kAdd fused root.
739 
740       // Check if one operand of kAdd fused root is kDot or kConvolution.
741       auto* add = user->fused_expression_root();
742       auto add_operand_it =
743           absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
744             return operand->opcode() == HloOpcode::kConvolution ||
745                    operand->opcode() == HloOpcode::kDot;
746           });
747       if (add_operand_it == add->operands().end()) {
748         return false;
749       }
750       auto* matched_add_operand = *add_operand_it;
751       // Calculate operand index of 'add' operand which was not matched above.
752       const int64 other_add_operand_index =
753           matched_add_operand == add->operand(0) ? 1 : 0;
754       // Returns true iff there is exactly one use of 'operand' at shape index
755       // 'operand_index', and this singleton use is the fused root (at operand
756       // index 'other_add_operand_index').
757       return HasUniqueFusedUseOfOperandAt(operand, operand_index, user,
758                                           other_add_operand_index);
759     }
760   }
761   if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
762       user->opcode() == HloOpcode::kScatter ||
763       user->opcode() == HloOpcode::kWhile) {
764     // We eliminated other users in BufferLiveness::live_range_strictly_before,
765     // so here we just need to check that the use is at operand index 0.
766     std::vector<int64> operand_indices = user->OperandIndices(operand);
767     return operand_indices.size() == 1 && operand_indices[0] == 0;
768   }
769   if (user->opcode() == HloOpcode::kSort) {
770     // Only valid if there are no other users.
771     if (operand->users().size() != 1) {
772       return false;
773     }
774     // If we only sort keys, the output of sort is not a tuple, so we can always
775     // share the buffer.
776     if (user->operand_count() == 1) {
777       return true;
778     }
779     CHECK(!user_index.empty());
780     // Only share with the right tuple element buffer.
781     std::vector<int64> operand_indices = user->OperandIndices(operand);
782     return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
783   }
784   if (user->opcode() == HloOpcode::kTriangularSolve) {
785     // Only valid if there are no other users.
786     if (operand->users().size() != 1) {
787       return false;
788     }
789     std::vector<int64> operand_indices = user->OperandIndices(operand);
790     return operand_indices.size() == 1 && operand_indices[0] == 1;
791   }
792   if (user->opcode() == HloOpcode::kCall) {
793     // TODO(b/62548313): Remove when buffer assignment is module scoped and
794     // does not assign buffers to calls.
795     // Find called computation parameter associated with 'operand'.
796     const std::vector<int64> operand_indices = user->OperandIndices(operand);
797     if (operand_indices.size() > 1) {
798       return false;
799     }
800     CHECK_EQ(1, operand_indices.size());
801     auto* param = user->to_apply()->parameter_instruction(operand_indices[0]);
802     // Get all uses of 'operand' at 'index' in called computation.
803     auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index);
804 
805     // Return true iff:
806     // *) There exists exactly one use of 'operand' in called computation.
807     // *) The unique use is by the root instruction of called computation.
808     //    (Note: we check the root of the called computation, because the
809     //     root result buffer is required to alias with the Call result buffer).
810     // *) The root instruction of the called computation is element-wise on
811     //    'operand'.
812     auto* callee_root = user->to_apply()->root_instruction();
813     return param_uses.size() == 1 && param_uses[0].first == callee_root &&
814            callee_root->IsElementwiseOnOperand(param_uses[0].second);
815   }
816   // Loop fusions that contain transposing copies won't reach here as they have
817   // different layouts, which fails the check in the beginning of this function.
818   //
819   // Multi-output fusion will fail the check here as tuples are not considered
820   // an elementwise operation.
821   return user->IsElementwiseOnOperand(user->operand_index(operand));
822 }
823 
824 }  // namespace xla
825