• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17 
18 #include <algorithm>
19 #include <queue>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/compiler/xla/map_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/logging.h"
38 
39 namespace xla {
40 
41 using absl::StrAppend;
42 using absl::StrCat;
43 
HloDataflowAnalysis(const HloModule & module,bool ssa_form,bool bitcast_defines_value,const CanShareBuffer & can_share_buffer)44 HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
45                                          bool bitcast_defines_value,
46                                          const CanShareBuffer& can_share_buffer)
47     : module_(module),
48       ssa_form_(ssa_form),
49       bitcast_defines_value_(bitcast_defines_value),
50       call_graph_(CallGraph::Build(&module)),
51       can_share_buffer_(can_share_buffer) {}
52 
AreTransitiveUsesElementwiseOrTuple(const HloInstruction * inst)53 bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
54     const HloInstruction* inst) {
55   absl::flat_hash_set<const HloInstruction*> visited;
56   absl::InlinedVector<const HloInstruction*, 4> stack;
57   stack.push_back(inst);
58   while (!stack.empty()) {
59     const HloInstruction* current = stack.back();
60     stack.pop_back();
61     visited.insert(current);
62     for (const HloInstruction* user : current->users()) {
63       // Found a user that is non-elementwise on current instruction.
64       for (const int64 use_index : user->OperandIndices(current)) {
65         if (!user->IsElementwiseOnOperand(use_index) &&
66             user->opcode() != HloOpcode::kTuple) {
67           return false;
68         }
69       }
70       if (!visited.contains(user)) {
71         stack.push_back(user);
72       }
73     }
74   }
75   return true;
76 }
77 
ValueIsDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const78 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
79                                            const ShapeIndex& index) const {
80   const HloValueSet& value_set = GetValueSet(instruction, index);
81   if (value_set.values().size() != 1) {
82     return false;
83   }
84   return value_set.GetUniqueValue().defining_instruction() == instruction;
85 }
86 
GetValueDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const87 const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
88     const HloInstruction* instruction, const ShapeIndex& index) const {
89   CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
90   return GetUniqueValueAt(instruction, index);
91 }
92 
GetValueDefinedAt(const HloInstruction * instruction,const ShapeIndex & index)93 HloValue& HloDataflowAnalysis::GetValueDefinedAt(
94     const HloInstruction* instruction, const ShapeIndex& index) {
95   CHECK(ValueIsDefinedAt(instruction, index));
96   return GetUniqueValueAt(instruction, index);
97 }
98 
NewHloValue(HloInstruction * instruction,const ShapeIndex & index,bool is_phi)99 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
100                                            const ShapeIndex& index,
101                                            bool is_phi) {
102   const int64 value_id = next_value_id_++;
103   auto emplaced = values_.emplace(
104       std::piecewise_construct, std::forward_as_tuple(value_id),
105       std::forward_as_tuple(value_id, instruction, index, is_phi));
106   CHECK(emplaced.second);
107 
108   VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
109 
110   return &emplaced.first->second;
111 }
112 
MarkValueForDeletion(HloValue::Id value_id)113 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
114   HloValue& value = values_.at(value_id);
115   VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
116 
117   value_ids_to_delete_.push_back(value_id);
118 }
119 
DeleteMarkedValues()120 void HloDataflowAnalysis::DeleteMarkedValues() {
121 #ifndef NDEBUG
122   // Verify that no marked-for-deletion values are in any of the value sets.
123   absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
124                                            value_ids_to_delete_.end());
125   for (const auto& pair : value_sets_) {
126     const HloInstruction* instruction = pair.first;
127     const InstructionValueSet& instruction_value_set = pair.second;
128     for (const auto& index_value_set : instruction_value_set) {
129       const HloValueSet& value_set = index_value_set.second;
130       for (const HloValue* value : value_set.values()) {
131         DCHECK(!ContainsKey(id_set, value->id()))
132             << "Value " << value->ToShortString()
133             << " marked for deletion, but still exists in value set for "
134                "instruction "
135             << instruction->name();
136       }
137     }
138   }
139 #endif
140 
141   for (HloValue::Id value_id : value_ids_to_delete_) {
142     values_.erase(value_id);
143   }
144   value_ids_to_delete_.clear();
145 }
146 
ToString() const147 string HloDataflowAnalysis::ToString() const {
148   string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
149   StrAppend(&out, "  Instruction value sets:\n");
150   for (const HloComputation* computation : module_.computations()) {
151     for (const HloInstruction* instruction : computation->instructions()) {
152       StrAppend(&out, "Instruction: \n  ", instruction->name(), ":\n");
153       if (instruction->shape().IsTuple()) {
154         GetInstructionValueSet(instruction)
155             .ForEachElement([this, &instruction, &out](
156                                 const ShapeIndex& index,
157                                 const HloValueSet& value_set) {
158               StrAppend(&out, "      tuple index ", index.ToString(), ":\n");
159               for (const HloValue* value : value_set.values()) {
160                 StrAppend(&out, "        ", value->ToShortString(),
161                           ValueIsDefinedAt(instruction, index) ? " (def)" : "",
162                           "\n");
163               }
164             });
165       } else {
166         const HloValueSet& top_level_value_set =
167             GetValueSet(instruction, /*index=*/{});
168         for (const HloValue* value : top_level_value_set.values()) {
169           StrAppend(&out, "      ", value->ToShortString(),
170                     ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
171         }
172       }
173     }
174   }
175   StrAppend(&out, "  HloValues:\n");
176   for (const HloValue* value : values()) {
177     StrAppend(&out, value->ToString(/*indent=*/4));
178   }
179   return out;
180 }
181 
Phi(HloInstruction * instruction,absl::Span<const InstructionValueSet * const> inputs)182 bool HloDataflowAnalysis::Phi(
183     HloInstruction* instruction,
184     absl::Span<const InstructionValueSet* const> inputs) {
185   CHECK(ssa_form_);
186   VLOG(4) << "Phi(" << instruction->name() << ")";
187   VLOG(5) << "instruction value set = "
188           << GetInstructionValueSet(instruction).ToString();
189   for (const InstructionValueSet* input : inputs) {
190     VLOG(5) << "input value set = " << input->ToString();
191   }
192 
193   if (bitcast_defines_value_) {
194     absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
195       DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
196     });
197   } else {
198     const Shape& shape = instruction->shape();
199     PrimitiveType ty = shape.element_type();
200     bool is_array = shape.IsArray();
201     absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
202       DCHECK(ty == input->shape().element_type() &&
203              (!is_array || ShapeUtil::ElementsIn(shape) ==
204                                ShapeUtil::ElementsIn(input->shape())));
205     });
206   }
207 
208   bool changed = false;
209   for (auto& pair : GetInstructionValueSet(instruction)) {
210     const ShapeIndex& index = pair.first;
211     HloValueSet& value_set = pair.second;
212 
213     // Positions with phi values should never have more than one value in the
214     // value set.
215     CHECK_LE(value_set.values().size(), 1);
216     const HloValue* current_value =
217         value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
218 
219     // Construct a vector of unique value IDs of the inputs.
220     // Don't add value ids where the input is equal to the definition.
221     std::vector<HloValue::Id> input_value_ids;
222     for (const InstructionValueSet* input : inputs) {
223       for (const HloValue* value : input->element(index).values()) {
224         if (value->defining_instruction() == instruction &&
225             value->defining_index() == index) {
226           continue;
227         }
228         input_value_ids.push_back(value->id());
229       }
230     }
231     absl::c_sort(input_value_ids);
232     input_value_ids.erase(
233         std::unique(input_value_ids.begin(), input_value_ids.end()),
234         input_value_ids.end());
235 
236     // Remove the existing phi value (if it exists). The phi can be its own
237     // input, for example, in while body parameters where the body passes
238     // through the parameter value.
239     bool current_value_defined_here =
240         (current_value != nullptr &&
241          current_value->defining_instruction() == instruction &&
242          current_value->defining_index() == index);
243     if (current_value_defined_here) {
244       VLOG(5) << "current_value_defined_here: " << current_value->ToString();
245       CHECK(current_value->is_phi());
246       auto it = absl::c_find(input_value_ids, current_value->id());
247       if (it != input_value_ids.end()) {
248         input_value_ids.erase(it);
249       }
250     }
251     VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
252     if (input_value_ids.empty()) {
253       // A value set which has at least one element should never have its value
254       // set reduced to zero elements. During dataflow value sets only can go
255       // from empty to non-empty, not the reverse.
256       CHECK_EQ(value_set.values().size(), 0)
257           << "Instruction " << instruction->name() << " at index " << index
258           << " previously had non-empty value set. Value set: " << value_set;
259     } else if (input_value_ids.size() == 1) {
260       // Only a single value reaches this point. There should be no phi, and
261       // this value set should contain this single value.
262       const HloValue& new_value = GetValue(input_value_ids[0]);
263       if (current_value == nullptr) {
264         value_set.Clear();
265         value_set.AddValue(&new_value);
266         changed = true;
267       } else if (current_value != &new_value) {
268         if (current_value_defined_here) {
269           // Remove the existing phi.
270           MarkValueForDeletion(current_value->id());
271         }
272         value_set.Clear();
273         value_set.AddValue(&new_value);
274         changed = true;
275       }
276     } else {
277       // Multiple distinct values reach this point. A phi value is
278       // necessary.
279       CHECK_GT(input_value_ids.size(), 1);
280       if (current_value == nullptr ||
281           !(current_value->is_phi() && current_value_defined_here)) {
282         value_set.Clear();
283         value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
284         changed = true;
285       }
286     }
287   }
288   return changed;
289 }
290 
GetValue(HloValue::Id value_id) const291 const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
292   return values_.at(value_id);
293 }
294 
GetValue(HloValue::Id value_id)295 HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
296   return values_.at(value_id);
297 }
298 
GetFlattenedValueSet(const HloInstruction * instruction) const299 HloValueSet HloDataflowAnalysis::GetFlattenedValueSet(
300     const HloInstruction* instruction) const {
301   HloValueSet value_set;
302 
303   const InstructionValueSet& value_set_tree =
304       GetInstructionValueSet(instruction);
305 
306   std::vector<const HloValueSet*> all_sets;
307   for (auto& pair : value_set_tree) {
308     const HloValueSet& value_set = pair.second;
309     all_sets.push_back(&value_set);
310   }
311   value_set.AssignUnionOf(all_sets);
312 
313   return value_set;
314 }
315 
GetValueSet(const HloInstruction * instruction,const ShapeIndex & index) const316 const HloValueSet& HloDataflowAnalysis::GetValueSet(
317     const HloInstruction* instruction, const ShapeIndex& index) const {
318   return GetInstructionValueSet(instruction).element(index);
319 }
320 
GetValueSet(const HloInstruction * instruction,const ShapeIndex & index)321 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
322                                               const ShapeIndex& index) {
323   return *GetInstructionValueSet(instruction).mutable_element(index);
324 }
325 
GetValueSet(const HloPosition & position) const326 const HloValueSet& HloDataflowAnalysis::GetValueSet(
327     const HloPosition& position) const {
328   return GetValueSet(position.instruction, position.index);
329 }
330 
GetValueSet(const HloPosition & position)331 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
332   return GetValueSet(position.instruction, position.index);
333 }
334 
UpdateBitcastValueSet(HloInstruction * bitcast)335 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
336   CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
337   const InstructionValueSet& operand_set =
338       GetInstructionValueSet(bitcast->operand(0));
339   InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
340   if (!bitcast_defines_value_ && operand_set != bitcast_set) {
341     bitcast_set = operand_set;
342     return true;
343   }
344   return false;
345 }
346 
UpdateSetDimensionSizeValueSet(HloInstruction * set_dimension_size)347 bool HloDataflowAnalysis::UpdateSetDimensionSizeValueSet(
348     HloInstruction* set_dimension_size) {
349   CHECK_EQ(set_dimension_size->opcode(), HloOpcode::kSetDimensionSize);
350   const InstructionValueSet& operand_set =
351       GetInstructionValueSet(set_dimension_size->operand(0));
352   InstructionValueSet& set_dimension_size_set =
353       GetInstructionValueSet(set_dimension_size);
354   if (operand_set != set_dimension_size_set) {
355     set_dimension_size_set = operand_set;
356     return true;
357   }
358   return false;
359 }
360 
UpdateSendValueSet(HloInstruction * send)361 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
362   CHECK_EQ(send->opcode(), HloOpcode::kSend);
363   bool changed = false;
364   // Send forwards the operand value to the output tuple at {0}.
365   for (auto& pair : GetInstructionValueSet(send->operand(0))) {
366     const ShapeIndex& operand_index = pair.first;
367     const HloValueSet& operand_value_set = pair.second;
368 
369     ShapeIndex index = {0};
370     for (int64 i : operand_index) {
371       index.push_back(i);
372     }
373 
374     HloValueSet& value_set = GetValueSet(send, index);
375     if (value_set != operand_value_set) {
376       value_set = operand_value_set;
377       changed = true;
378     }
379   }
380   return changed;
381 }
382 
UpdateCopyStartValueSet(HloInstruction * copy_start)383 bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
384   CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
385   bool changed = false;
386   // CopyStart forwards the operand value to element {1} of its output.
387   const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0));
388   HloValueSet& value_set = GetValueSet(copy_start, {1});
389   if (value_set != operand_value_set) {
390     value_set = operand_value_set;
391     changed = true;
392   }
393   return changed;
394 }
395 
UpdateCopyDoneValueSet(HloInstruction * copy_done)396 bool HloDataflowAnalysis::UpdateCopyDoneValueSet(HloInstruction* copy_done) {
397   CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone);
398   bool changed = false;
399   // CopyDone forwards the operand value at {0} to element {} of its output.
400   const HloValueSet& operand_value_set =
401       GetValueSet(copy_done->operand(0), {0});
402   HloValueSet& value_set = GetValueSet(copy_done);
403   if (value_set != operand_value_set) {
404     value_set = operand_value_set;
405     changed = true;
406   }
407   return changed;
408 }
409 
UpdateRecvDoneValueSet(HloInstruction * recv_done)410 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
411   CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
412   bool changed = false;
413   // RecvDone forwards the operand value at {0} to element {0} of its output.
414   for (auto& pair : GetInstructionValueSet(recv_done)) {
415     ShapeIndex& index = pair.first;
416     HloValueSet& value_set = pair.second;
417 
418     if (index.empty() || index[0] != 0) {
419       continue;
420     }
421 
422     const HloValueSet& operand_value_set =
423         GetValueSet(recv_done->operand(0), index);
424     if (value_set != operand_value_set) {
425       value_set = operand_value_set;
426       changed = true;
427     }
428   }
429   return changed;
430 }
431 
UpdateCallValueSet(HloInstruction * call)432 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
433   CHECK_EQ(call->opcode(), HloOpcode::kCall);
434   InstructionValueSet& value_set = GetInstructionValueSet(call);
435   InstructionValueSet& root_value_set =
436       GetInstructionValueSet(call->to_apply()->root_instruction());
437   if (value_set != root_value_set) {
438     value_set = root_value_set;
439     return true;
440   }
441   return false;
442 }
443 
UpdateConditionalValueSet(HloInstruction * conditional)444 bool HloDataflowAnalysis::UpdateConditionalValueSet(
445     HloInstruction* conditional) {
446   CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
447   std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
448   for (int j = 0; j < conditional->branch_count(); ++j) {
449     inputs[j] = &GetInstructionValueSet(
450         conditional->branch_computation(j)->root_instruction());
451   }
452   if (ssa_form_) {
453     return Phi(conditional, inputs);
454   } else {
455     return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
456   }
457 }
458 
UpdateCopyValueSet(HloInstruction * copy)459 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
460   CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
461   bool changed = false;
462   for (auto& pair : GetInstructionValueSet(copy)) {
463     const ShapeIndex& index = pair.first;
464     if (index.empty()) {
465       // kCopy shallow copies and thus defines the top-level value so nothing to
466       // update.
467       continue;
468     }
469 
470     HloValueSet& value_set = pair.second;
471     HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
472     if (value_set != operand_value_set) {
473       value_set = operand_value_set;
474       changed = true;
475     }
476   }
477   return changed;
478 }
479 
UpdateDomainValueSet(HloInstruction * domain)480 bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
481   // Domain instructions just forward their operand. Given that domains can have
482   // a tuple operand, we iterate through its indexes, like for copies.
483   // Unlike copies though we also propagate the top-level value.
484   CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
485   bool changed = false;
486   for (auto& pair : GetInstructionValueSet(domain)) {
487     const ShapeIndex& index = pair.first;
488     HloValueSet& value_set = pair.second;
489     HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
490     if (value_set != operand_value_set) {
491       value_set = operand_value_set;
492       changed = true;
493     }
494   }
495   return changed;
496 }
497 
UpdateAddDependencyValueSet(HloInstruction * add_dependency)498 bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
499     HloInstruction* add_dependency) {
500   // AddDependency just forwards the value of its zero-th operand.
501   CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
502   const InstructionValueSet& operand_set =
503       GetInstructionValueSet(add_dependency->operand(0));
504   InstructionValueSet& add_dependency_set =
505       GetInstructionValueSet(add_dependency);
506   if (operand_set != add_dependency_set) {
507     add_dependency_set = operand_set;
508     return true;
509   }
510   return false;
511 }
512 
UpdateGetTupleElementValueSet(HloInstruction * gte)513 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
514   CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
515   bool changed = false;
516   // The GetTupleElement instruction forwards the values from the specified
517   // tuple element.
518   for (auto& pair : GetInstructionValueSet(gte)) {
519     const ShapeIndex& index = pair.first;
520     HloValueSet& value_set = pair.second;
521 
522     // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
523     // with the tuple element number prefixed.
524     ShapeIndex operand_index = {gte->tuple_index()};
525     for (int64 i : index) {
526       operand_index.push_back(i);
527     }
528 
529     HloValueSet& operand_value_set =
530         GetValueSet(gte->operand(0), operand_index);
531     if (value_set != operand_value_set) {
532       value_set = operand_value_set;
533       changed = true;
534     }
535   }
536   return changed;
537 }
538 
UpdateParameterValueSet(HloInstruction * parameter)539 bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
540   CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
541   const CallGraphNode& call_graph_node =
542       call_graph_->GetNode(parameter->parent());
543 
544   // Subcomputations called in a parallel context (eg, map) do not have dataflow
545   // from the caller operands.
546   if (call_graph_node.context() == CallContext::kParallel ||
547       call_graph_node.caller_callsites().empty()) {
548     return false;
549   }
550   CHECK_EQ(call_graph_node.context(), CallContext::kSequential);
551 
552   std::vector<const InstructionValueSet*> inputs;
553   bool need_phi = false;
554   for (const CallSite& callsite : call_graph_node.caller_callsites()) {
555     if (callsite.instruction()->opcode() == HloOpcode::kCall) {
556       // The operand values of a call instruction are forwarded to the
557       // respective parameter instruction of the subcomputation.
558       inputs.push_back(&GetInstructionValueSet(
559           callsite.instruction()->operand(parameter->parameter_number())));
560     } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
561       // In a while instruction, the while operand (ie, the init value) and the
562       // backedge are dataflow inputs to the parameter instruction. This is the
563       // case for parameters of both the body and condition computations.
564       CHECK_EQ(parameter->parameter_number(), 0);
565       inputs.push_back(
566           &GetInstructionValueSet(callsite.instruction()->operand(0)));
567       // If the parameter *is* the root, then don't consider it's current state
568       // (InstructionValueSet) as we are recomputing its current
569       // state. Otherwise, the parameter state would never be updated.
570       if (parameter !=
571           callsite.instruction()->while_body()->root_instruction()) {
572         inputs.push_back(&GetInstructionValueSet(
573             callsite.instruction()->while_body()->root_instruction()));
574       }
575       need_phi = true;
576     } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
577       CHECK_EQ(parameter->parameter_number(), 0);
578       auto conditional = callsite.instruction();
579       // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
580       // operands 1 and onward are the arguments to the branch computations.
581       //
582       // If the parameter belongs to conditional's branch 0 computation, then
583       // operand 1 is forwarded to this parameter instruction. If the parameter
584       // belongs to conditional's branch 5 computation, then operand 6 is
585       // forwarded to this parameter instruction.
586       bool found_parent = false;
587       for (int j = 0; j < conditional->branch_count(); ++j) {
588         if (parameter->parent() == conditional->branch_computation(j)) {
589           inputs.push_back(
590               &GetInstructionValueSet(conditional->operand(j + 1)));
591           found_parent = true;
592           break;
593         }
594       }
595       CHECK(found_parent);
596       need_phi = true;
597     } else {
598       LOG(FATAL) << "CallContext::kSequential computations should only be "
599                     "called from call, while, or conditional instructions";
600     }
601   }
602 
603   if (ssa_form_ && need_phi) {
604     return Phi(parameter, inputs);
605   } else {
606     return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
607   }
608 }
609 
UpdateTupleSelectValueSet(HloInstruction * select)610 bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) {
611   CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect);
612   // A phi value is not defined at a kTupleSelect instruction because
613   // kTupleSelect does not create a new value. Rather it forwards a value from
614   // its operands. This contrasts with kWhile instruction (which does define a
615   // phi value) which has in-place update semantics.
616   bool changed = false;
617   for (auto& pair : GetInstructionValueSet(select)) {
618     const ShapeIndex& index = pair.first;
619     if (index.empty()) {
620       // kTupleSelect copies (not forwards) the top-level value.
621       continue;
622     }
623     HloValueSet& value_set = pair.second;
624     changed |=
625         value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
626                                  &GetValueSet(select->operand(2), index)});
627   }
628   return changed;
629 }
630 
UpdateTupleValueSet(HloInstruction * tuple)631 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
632   CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
633   bool changed = false;
634   for (int64 i = 0; i < tuple->operands().size(); ++i) {
635     // Copy the value set(s) of each operand into the respective position in the
636     // kTuple instruction's value sets.
637     for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
638       const ShapeIndex& operand_index = pair.first;
639       HloValueSet& operand_value_set = pair.second;
640 
641       ShapeIndex index = {i};
642       for (int64 op_index : operand_index) {
643         index.push_back(op_index);
644       }
645       HloValueSet& value_set = GetValueSet(tuple, index);
646 
647       if (value_set != operand_value_set) {
648         value_set = operand_value_set;
649         changed = true;
650       }
651     }
652   }
653   return changed;
654 }
655 
UpdateWhileValueSet(HloInstruction * xla_while)656 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
657   CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
658   const InstructionValueSet* const inputs[] = {
659       &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
660       &GetInstructionValueSet(xla_while->operand(0))};
661   if (ssa_form_) {
662     return Phi(xla_while, inputs);
663   } else {
664     return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
665   }
666 }
667 
UpdateInstructionValueSet(HloInstruction * instruction)668 bool HloDataflowAnalysis::UpdateInstructionValueSet(
669     HloInstruction* instruction) {
670   // Recompute from operands.
671   switch (instruction->opcode()) {
672     case HloOpcode::kAddDependency:
673       return UpdateAddDependencyValueSet(instruction);
674     case HloOpcode::kBitcast:
675       return UpdateBitcastValueSet(instruction);
676     case HloOpcode::kSetDimensionSize:
677       return UpdateSetDimensionSizeValueSet(instruction);
678     case HloOpcode::kDomain:
679       return UpdateDomainValueSet(instruction);
680     case HloOpcode::kCopy:
681       return UpdateCopyValueSet(instruction);
682     case HloOpcode::kGetTupleElement:
683       return UpdateGetTupleElementValueSet(instruction);
684     case HloOpcode::kTupleSelect:
685       return UpdateTupleSelectValueSet(instruction);
686     case HloOpcode::kTuple:
687       return UpdateTupleValueSet(instruction);
688     case HloOpcode::kParameter:
689       return UpdateParameterValueSet(instruction);
690     case HloOpcode::kCall:
691       return UpdateCallValueSet(instruction);
692     case HloOpcode::kWhile:
693       return UpdateWhileValueSet(instruction);
694     case HloOpcode::kSend:
695       return UpdateSendValueSet(instruction);
696     case HloOpcode::kRecvDone:
697       return UpdateRecvDoneValueSet(instruction);
698     case HloOpcode::kCopyStart:
699       return UpdateCopyStartValueSet(instruction);
700     case HloOpcode::kCopyDone:
701       return UpdateCopyDoneValueSet(instruction);
702     case HloOpcode::kConditional:
703       return UpdateConditionalValueSet(instruction);
704     default:
705       // Instruction does not forward HloValues (it defines all values in its
706       // output). No update is necessary.
707       return false;
708   }
709 }
710 
Propagate()711 void HloDataflowAnalysis::Propagate() {
712   std::queue<HloInstruction*> worklist;
713   absl::flat_hash_set<HloInstruction*> workset;
714   auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
715     if (workset.insert(instruction).second) {
716       worklist.push(instruction);
717     }
718   };
719 
720   for (HloComputation* computation : module_.computations()) {
721     for (HloInstruction* instruction : computation->instructions()) {
722       add_to_worklist(instruction);
723     }
724   }
725 
726   while (!worklist.empty()) {
727     HloInstruction* instruction = worklist.front();
728     worklist.pop();
729     workset.erase(workset.find(instruction));
730 
731     VLOG(3) << "Worklist top: " << instruction->name();
732     VLOG(3) << ToString();
733 
734     if (!UpdateInstructionValueSet(instruction)) {
735       // No change to the instruction's value set.
736       VLOG(4) << "No change.";
737       continue;
738     }
739 
740     VLOG(4) << "New value set for " << instruction->name() << ": "
741             << GetInstructionValueSet(instruction);
742 
743     // Instruction value was updated. Add users to work list if we haven't
744     // already.
745     for (HloInstruction* user : instruction->users()) {
746       add_to_worklist(user);
747 
748       // If user sequentially calls a computation, then the respective
749       // parameter(s) of the computation need to be updated.
750       if (user->opcode() == HloOpcode::kConditional) {
751         // If operand 0 is the use of instruction, then no parameters need to be
752         // updated, since that is the branch_index of the conditional.
753         // If operand n+1 is the use of instruction, then the branch_computation
754         // n's parameter need to be updated.
755         //
756         // Note that the same instruction can be used in multiple branches'
757         // operands.
758         for (int j = 0; j < user->branch_count(); ++j) {
759           if (user->operand(j + 1) == instruction) {
760             add_to_worklist(
761                 user->branch_computation(j)->parameter_instruction(0));
762           }
763         }
764       } else {
765         for (HloComputation* called_computation : user->called_computations()) {
766           const CallGraphNode& call_graph_node =
767               call_graph_->GetNode(called_computation);
768           if (call_graph_node.context() == CallContext::kSequential) {
769             for (int64 operand_number : user->OperandIndices(instruction)) {
770               add_to_worklist(
771                   called_computation->parameter_instruction(operand_number));
772             }
773           }
774         }
775       }
776     }
777 
778     // If instruction is a root instruction, then propagate out to any calling
779     // instruction and across any while backedge.
780     if (instruction == instruction->parent()->root_instruction()) {
781       const CallGraphNode& call_graph_node =
782           call_graph_->GetNode(instruction->parent());
783       for (const CallSite& callsite : call_graph_node.caller_callsites()) {
784         if (callsite.instruction()->opcode() == HloOpcode::kCall ||
785             callsite.instruction()->opcode() == HloOpcode::kConditional) {
786           add_to_worklist(callsite.instruction());
787         } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
788           // Add the while itself, and the body and condition parameters.
789           add_to_worklist(callsite.instruction());
790           add_to_worklist(
791               callsite.instruction()->while_body()->parameter_instruction(0));
792           add_to_worklist(
793               callsite.instruction()->while_condition()->parameter_instruction(
794                   0));
795         }
796       }
797     }
798   }
799 }
800 
GetInstructionValueSet(const HloInstruction * instruction) const801 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
802     const HloInstruction* instruction) const {
803   return value_sets_.at(instruction);
804 }
805 
GetInstructionValueSet(const HloInstruction * instruction)806 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
807     const HloInstruction* instruction) {
808   return value_sets_.at(instruction);
809 }
810 
InitializeInstructionValueSets()811 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
812   for (const HloComputation* computation : module_.computations()) {
813     const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
814     for (HloInstruction* instruction : computation->instructions()) {
815       // Create an empty shape tree.
816       value_sets_.emplace(std::piecewise_construct,
817                           std::forward_as_tuple(instruction),
818                           std::forward_as_tuple(instruction->shape()));
819 
820       // For each sub-shape of the instruction shape, add a new HloValue to its
821       // HloValueSet.
822       auto define_all_values = [this, &instruction]() {
823         for (auto& pair : GetInstructionValueSet(instruction)) {
824           const ShapeIndex& index = pair.first;
825           HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
826           GetValueSet(instruction, index).AddValue(value);
827         }
828       };
829 
830       // Add a new HloValue to the HloValueSet corresponding to the given index
831       // of the instruction shape.
832       auto define_value_at = [this, &instruction](const ShapeIndex& index) {
833         HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
834         GetValueSet(instruction, index).AddValue(value);
835       };
836 
837       switch (instruction->opcode()) {
838         case HloOpcode::kBitcast:
839           if (bitcast_defines_value_) {
840             define_all_values();
841           }
842           break;
843         case HloOpcode::kSetDimensionSize:
844         case HloOpcode::kAddDependency:
845         case HloOpcode::kWhile:
846         case HloOpcode::kCall:
847         case HloOpcode::kConditional:
848         case HloOpcode::kGetTupleElement:
849         case HloOpcode::kDomain:
850           // These instructions define no values. The values in their output
851           // flow from their operands or from cross computation dataflow.
852           break;
853         case HloOpcode::kParameter:
854           if (call_graph_node.context() == CallContext::kBoth) {
855             // We do not support a subcomputation that is called from both a
856             // parallel and sequential context. In this case, the parameter
857             // would both define a value and propagate a value from its
858             // caller. This limitation is not really a problem because the call
859             // graph is typically flattened.
860             return Unimplemented(
861                 "Computation %s is called in both a parallel (eg, kMap) and "
862                 "sequential (eg, kCall) context",
863                 computation->name());
864           }
865           if (call_graph_node.caller_callsites().empty() ||
866               call_graph_node.context() == CallContext::kParallel) {
867             // Parameters of computations called in a parallel context (eg, map
868             // and reduce) as well as parameters of dead computations define all
869             // values in their output. Otherwise the values of the parameter
870             // come from the caller (eg, operands to the kCall instruction).
871             define_all_values();
872           }
873           break;
874         case HloOpcode::kCopy:
875         case HloOpcode::kTupleSelect:
876         case HloOpcode::kTuple:
877           // These instructions only define their top-level values. Any other
878           // values flow from their operands.
879           define_value_at(/*index=*/{});
880           break;
881         case HloOpcode::kCopyStart:
882           // CopyStart produces a tuple of {destination buffer, aliased operand,
883           // U32 context}.
884           define_value_at(/*index=*/{});
885           define_value_at(/*index=*/{0});
886           define_value_at(/*index=*/{2});
887           break;
888         case HloOpcode::kCopyDone:
889           // CopyDone consumes a tuple produced by CopyStart and produces an
890           // element. Its output aliases its input tuple element {0}.
891           break;
892         case HloOpcode::kRecvDone:
893           // RecvDone produces a two-element tuple. Element zero aliases its
894           // input tuple element {0}; element one is a token.
895           define_value_at(/*index=*/{});
896           define_value_at(/*index=*/{1});
897           break;
898         case HloOpcode::kSend:
899           // Send produces a tuple of {aliased operand, U32 context, token},
900           // therefore only defines the top-level tuple and the tuple elements
901           // at {1} and {2}.
902           define_value_at(/*index=*/{});
903           define_value_at(/*index=*/{1});
904           define_value_at(/*index=*/{2});
905           break;
906         default:
907           define_all_values();
908           break;
909       }
910     }
911   }
912 
913   return Status::OK();
914 }
915 
916 /* static */
Run(const HloModule & module,bool ssa_form,bool bitcast_defines_value,const CanShareBuffer & can_share_buffer)917 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
918     const HloModule& module, bool ssa_form, bool bitcast_defines_value,
919     const CanShareBuffer& can_share_buffer) {
920   VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
921   XLA_VLOG_LINES(2, module.ToString());
922 
923   auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
924       module, ssa_form, bitcast_defines_value, can_share_buffer));
925 
926   TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
927   dataflow_analysis->Propagate();
928 
929   // Delete all values marked for deletion.
930   dataflow_analysis->DeleteMarkedValues();
931 
932   // Gather and set all non-definition positions of all values. Value deletion
933   // is rare, so just use a vector indexed by Value::Id rather than a map from
934   // Value::Id to positions. There should be very few holes in the vector, and
935   // lookup is faster.
936   std::vector<std::vector<HloPosition>> value_positions(
937       dataflow_analysis->next_value_id_);
938   for (const HloComputation* computation : module.computations()) {
939     for (HloInstruction* instruction : computation->instructions()) {
940       for (const auto& pair :
941            dataflow_analysis->GetInstructionValueSet(instruction)) {
942         const ShapeIndex& index = pair.first;
943         const HloValueSet& value_set = pair.second;
944         for (const HloValue* value : value_set.values()) {
945           if (value->defining_instruction() != instruction) {
946             value_positions[value->id()].push_back(
947                 HloPosition{instruction, index});
948           }
949         }
950       }
951     }
952   }
953   for (auto& pair : dataflow_analysis->values_) {
954     HloValue::Id value_id = pair.first;
955     HloValue& value = pair.second;
956     value.SetPositionsAndComputeUses(value_positions[value_id]);
957   }
958 
959   // Construct vector of values.
960   dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
961   for (auto& pair : dataflow_analysis->values_) {
962     dataflow_analysis->values_vector_.push_back(&pair.second);
963   }
964   absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
965 
966   TF_DCHECK_OK(dataflow_analysis->Verify());
967 
968   XLA_VLOG_LINES(1, dataflow_analysis->ToString());
969 
970   return std::move(dataflow_analysis);
971 }
972 
Verify() const973 Status HloDataflowAnalysis::Verify() const {
974   // Verify each HloValue appears in the value sets that the value's positions()
975   // indicate.
976   for (const HloValue* value : values()) {
977     for (const HloPosition& position : value->positions()) {
978       const HloValueSet& value_set = GetValueSet(position);
979       TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
980           << "Value set at position " << position << " does not contain value "
981           << value->ToShortString();
982     }
983   }
984 
985   // For each value in each value set, verify that the value set's position
986   // appears in the value's positions().
987   for (const auto& computation : module_.computations()) {
988     for (const auto& instruction : computation->instructions()) {
989       for (const auto& pair : GetInstructionValueSet(instruction)) {
990         const ShapeIndex& index = pair.first;
991         const HloValueSet& value_set = pair.second;
992         const HloPosition position{instruction, index};
993         for (const HloValue* value : value_set.values()) {
994           TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
995               << "Value set at position " << position
996               << " unexpectedly contains value " << value->ToShortString();
997         }
998       }
999     }
1000   }
1001 
1002   return Status::OK();
1003 }
1004 
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const1005 bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
1006     const HloInstruction* operand, const ShapeIndex& index,
1007     const HloInstruction* user) const {
1008   // Return false if no value at 'operand' and 'index' is used at 'user'.
1009   for (const HloValue* value : GetValueSet(operand, index).values()) {
1010     for (const HloUse& use : value->uses()) {
1011       if (use.instruction == user) {
1012         if (user->IsLoopFusion()) {
1013           HloInstruction* fusion_param =
1014               user->fused_parameter(use.operand_number);
1015           const HloValue& value =
1016               GetValueDefinedAt(fusion_param, use.operand_index);
1017           return value.uses().empty();
1018         }
1019         return false;
1020       }
1021     }
1022   }
1023   return true;
1024 }
1025 
1026 // Given a fusion whose root is a dynamic-update-slice op, determines whether
1027 // the fusion's output buffer can be shared with the buffer of fusion_param,
1028 // which must be a fused parameter of the fusion.
1029 //
1030 // Preconditions:
1031 //
1032 //  - fusion's root is a dynamic-update-slice op.
1033 //  - fusion_param is a parameter within the fusion.
1034 //
1035 // fusion_param may point to a subelement of the actual parameter instruction if
1036 // the param is a tuple; i.e. fusion_param->index() need not be the empty list.
1037 //
1038 // Returns true if:
1039 //
1040 //  * fusion_param is used by the root of dynamic-update-slice as the "base" of
1041 //    the update, i.e. the thing being updated, AND
1042 //  * all other uses of fusion_param are dynamic-slices that slice the same
1043 //    indices as are overwritten in the dynamic-update-slice.
1044 //
1045 // In the case that there are no other uses of fusion_param (last bullet point
1046 // is vacuously true) it's easy to see why an in-place DUS is safe; this is just
1047 // the "natural" implementation of DUS.  If there are other users, in-place DUS
1048 // is safe on the assumption that the thread which writes element i of the
1049 // output will be the only one to read element i of fusion_param (via the
1050 // dynamic-slice ops).
CanDoInPlaceDynamicUpdateSlice(HloInstruction * fusion,const HloValue & fusion_param_value)1051 static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion,
1052                                            const HloValue& fusion_param_value) {
1053   auto* root =
1054       Cast<HloDynamicUpdateSliceInstruction>(fusion->fused_expression_root());
1055   auto* fusion_param = fusion_param_value.instruction();
1056   CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter);
1057   CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation());
1058 
1059   // fusion_param must be used by the root as the "base" of the
1060   // dynamic-update-slice.  The natural way to check this would be
1061   //
1062   //   `if (root->operand(0) != fusion_param)`
1063   //
1064   // but we also have to handle the case where the fusion parameter is
1065   // tuple-shaped and we're considering just one element of that tuple, i.e.
1066   // fusion_param.index() != {}.
1067   if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) {
1068         return use.instruction == root;
1069       }) != 1) {
1070     return false;
1071   }
1072 
1073   // All other uses of fusion_param must be dynamic-slices that slice the same
1074   // indices as are overwritten by the dynamic-update-slice.
1075   for (const HloUse& use : fusion_param_value.uses()) {
1076     auto* user = use.instruction;
1077     if (user == root) {
1078       continue;
1079     }
1080 
1081     // Check that `user` is a dynamic-slice op and has the same slice indices as
1082     // `root`.
1083     auto* ds = DynCast<HloDynamicSliceInstruction>(user);
1084     if (!ds || ds->index_operands() != root->index_operands()) {
1085       return false;
1086     }
1087   }
1088   return true;
1089 }
1090 
CanShareOperandBufferWithUser(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * user,const ShapeIndex & user_index) const1091 bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
1092     HloInstruction* operand, const ShapeIndex& operand_index,
1093     HloInstruction* user, const ShapeIndex& user_index) const {
1094   CHECK(user->IsUserOf(operand))
1095       << "user: " << user->ToString() << " operand: " << operand->ToString();
1096   const Shape& operand_subshape =
1097       ShapeUtil::GetSubshape(operand->shape(), operand_index);
1098   const Shape& user_subshape =
1099       ShapeUtil::GetSubshape(user->shape(), user_index);
1100 
1101   // Check that operand and user emit the same shape and layout.
1102   if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
1103     return false;
1104   }
1105 
1106   if (user->opcode() == HloOpcode::kFusion) {
1107     // Get the parameter associated with 'operand';
1108     HloInstruction* fusion_param =
1109         user->fused_parameter(user->operand_index(operand));
1110 
1111     const HloValue& fusion_param_value =
1112         GetValueDefinedAt(fusion_param, operand_index);
1113 
1114     // TODO(b/80315712): This code is in a bit of a weird intermediate state
1115     // at the moment. The in-place DUS check really needs to be common to all
1116     // backends, so it runs first. Then we run the backend-specific check if
1117     // provided, or go through the target-independent check if not.
1118     // Unfortunately, the notionally "target-independent" path actually contains
1119     // some target-specific code, so we can't run all of it *in addition* to the
1120     // target-specific function, like the interface documentation says.
1121     if (user->fused_expression_root()->opcode() ==
1122         HloOpcode::kDynamicUpdateSlice) {
1123       return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value);
1124     }
1125   }
1126 
1127   if (can_share_buffer_ != nullptr) {
1128     if (absl::optional<bool> hint =
1129             can_share_buffer_(user, operand, user_index)) {
1130       return *hint;
1131     }
1132   }
1133 
1134   if (user->opcode() == HloOpcode::kFusion) {
1135     HloInstruction* fusion_param =
1136         user->fused_parameter(user->operand_index(operand));
1137     const HloValue& fusion_param_value =
1138         GetValueDefinedAt(fusion_param, operand_index);
1139 
1140     if (user->IsLoopFusion() || user->IsInputFusion()) {
1141       return AreTransitiveUsesElementwiseOrTuple(fusion_param);
1142     }
1143 
1144     if (user->IsOutputFusion() &&
1145         user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
1146       // Output fusion with kAdd fused root.
1147 
1148       // Check if one operand of kAdd fused root is kDot or kConvolution.
1149       auto* add = user->fused_expression_root();
1150       auto add_operand_it =
1151           absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
1152             return operand->opcode() == HloOpcode::kConvolution ||
1153                    operand->opcode() == HloOpcode::kDot;
1154           });
1155       if (add_operand_it == add->operands().end()) {
1156         return false;
1157       }
1158       auto* matched_add_operand = *add_operand_it;
1159       // Calculate operand index of 'add' operand which was not matched above.
1160       const int64 other_add_operand_index =
1161           matched_add_operand == add->operand(0) ? 1 : 0;
1162       // Returns true iff there is exactly one use of 'operand' at shape index
1163       // 'operand_index', and this singleton use is the fused root (at operand
1164       // index 'other_add_operand_index').
1165       if (fusion_param_value.uses().size() == 1) {
1166         const HloUse& use = fusion_param_value.uses()[0];
1167         return use.instruction == user->fused_expression_root() &&
1168                use.operand_number == other_add_operand_index;
1169       }
1170       return false;
1171     }
1172   }
1173 
1174   if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
1175       user->opcode() == HloOpcode::kScatter ||
1176       user->opcode() == HloOpcode::kTriangularSolve ||
1177       user->opcode() == HloOpcode::kWhile) {
1178     // We eliminated other users in HloOrdering::LiveRangeStrictlyBefore
1179     // so here we just need to check that the use is at the right operand index.
1180     const auto operand_indices = user->OperandIndices(operand);
1181     int64 operand_no = user->opcode() == HloOpcode::kTriangularSolve ? 1 : 0;
1182     return operand_indices.size() == 1 && operand_indices[0] == operand_no;
1183   }
1184   if (user->opcode() == HloOpcode::kSort) {
1185     // Only valid if there are no other users.
1186     if (operand->users().size() != 1) {
1187       return false;
1188     }
1189     // If we only sort keys, the output of sort is not a tuple, so we can always
1190     // share the buffer.
1191     if (user->operand_count() == 1) {
1192       return true;
1193     }
1194     CHECK(!user_index.empty());
1195     // Only share with the right tuple element buffer.
1196     const auto operand_indices = user->OperandIndices(operand);
1197     return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
1198   }
1199   if (user->opcode() == HloOpcode::kCall) {
1200     // Get all uses of value defined by 'operand' at 'operand_index'.
1201     const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
1202     // Return true iff:
1203     // *) There exists two uses of 'operand'.
1204     // *) One use is by 'user' (caller).
1205     // *) One use is by root instruction of called computation (callee root).
1206     //    (Note: we check the root of the called computation, because the
1207     //     root result buffer is required to alias with the Call result buffer).
1208     // *) The root instruction of the called computation is element-wise on
1209     //    'operand'.
1210     const bool found_caller_use =
1211         absl::c_find_if(uses, [user](const HloUse& use) {
1212           return use.instruction == user;
1213         }) != uses.end();
1214     auto* callee_root = user->to_apply()->root_instruction();
1215     const bool found_elementwise_callee_use =
1216         absl::c_find_if(uses, [callee_root](const HloUse& use) {
1217           return use.instruction == callee_root &&
1218                  callee_root->IsElementwiseOnOperand(use.operand_number);
1219         }) != uses.end();
1220     return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
1221   }
1222 
1223   // Loop fusions that contain transposing copies won't reach here as they have
1224   // different layouts, which fails the check in the beginning of this function.
1225   return user->IsElementwiseOnOperand(user->operand_index(operand));
1226 }
1227 
1228 }  // namespace xla
1229