1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17 
18 #include <algorithm>
19 #include <queue>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/types/optional.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
36 #include "tensorflow/compiler/xla/service/hlo_module.h"
37 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
38 #include "tensorflow/compiler/xla/service/hlo_value.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/platform/logging.h"
45 
46 namespace xla {
47 namespace {
48 // CalculatePostOrderSchedule traverses a module and assigns an ordinal to each
49 // instruction based on the postorder dependency.
50 int64 CalculatePostOrderScheduleHelper(
51     const HloComputation* comp, int64_t start_ordinal,
52     absl::flat_hash_map<HloInstruction*, int64>* ordinal_map) {
53   int64_t ordinal = start_ordinal;
54   for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
55     if (instruction->opcode() == HloOpcode::kCall ||
56         instruction->opcode() == HloOpcode::kConditional) {
57       for (const HloComputation* called_computation :
58            instruction->called_computations()) {
59         ordinal = CalculatePostOrderScheduleHelper(called_computation, ordinal,
60                                                    ordinal_map);
61       }
62     }
63     if (instruction->opcode() == HloOpcode::kWhile) {
64       ordinal = CalculatePostOrderScheduleHelper(instruction->while_condition(),
65                                                  ordinal, ordinal_map);
66       ordinal = CalculatePostOrderScheduleHelper(instruction->while_body(),
67                                                  ordinal, ordinal_map);
68     }
69     // It's possible that in some unit tests the computation graph is not
70     // flattened (meaning one computation could have multiple callers). In
71     // that case the ordinal_map will see the instruction multiple times. We
72     // consider that case to be ok as it only shows up in unit tests.
73     ordinal_map->insert({instruction, ordinal++});
74   }
75   return ordinal;
76 }
77 
78 absl::flat_hash_map<HloInstruction*, int64> CalculatePostOrderSchedule(
79     const HloModule& module) {
80   absl::flat_hash_map<HloInstruction*, int64> map;
81   CalculatePostOrderScheduleHelper(module.entry_computation(), 0, &map);
82   return map;
83 }
84 
85 }  // namespace
86 using absl::StrAppend;
87 using absl::StrCat;
88 
89 HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
90                                          bool bitcast_defines_value,
91                                          const CanShareBuffer& can_share_buffer)
92     : module_(module),
93       ssa_form_(ssa_form),
94       bitcast_defines_value_(bitcast_defines_value),
95       call_graph_(CallGraph::Build(&module)),
96       can_share_buffer_(can_share_buffer) {}
97 
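// Returns true if every transitive user of `inst` is either a kTuple or is
// elementwise on each operand position at which the value reaches it.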
98 bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
99     const HloInstruction* inst) {
100   absl::flat_hash_set<const HloInstruction*> visited;
101   absl::InlinedVector<const HloInstruction*, 4> stack;
102   stack.push_back(inst);
103   while (!stack.empty()) {
104     const HloInstruction* current = stack.back();
105     stack.pop_back();
106     visited.insert(current);
107     for (const HloInstruction* user : current->users()) {
108       // Found a user that is non-elementwise on the current instruction.
109       for (const int64_t use_index : user->OperandIndices(current)) {
110         if (!user->IsElementwiseOnOperand(use_index) &&
111             user->opcode() != HloOpcode::kTuple) {
112           return false;
113         }
114       }
115       if (!visited.contains(user)) {
116         stack.push_back(user);
117       }
118     }
119   }
120   return true;
121 }
122 
123 namespace {
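// Returns true if `instr` is a kSlice over a single dimension with unit
// stride.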
124 bool Is1dSliceWithoutStrides(const HloInstruction* instr) {
125   return instr->opcode() == HloOpcode::kSlice &&
126          1 == instr->slice_starts().size() &&
127          1 == instr->slice_limits().size() &&
128          1 == instr->slice_strides().size() &&
129          1 == instr->slice_strides().at(0);
130 }
131 
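// Returns true if `unnested_hlo` is an input fusion whose root is a tuple of
// 1-dimensional, unit-stride slices.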
132 bool IsSliceInputFusion(const HloInstruction& unnested_hlo) {
133   if (!unnested_hlo.IsInputFusion()) {
134     return false;
135   }
136   const HloInstruction* root = unnested_hlo.fused_expression_root();
137   if (root->opcode() != HloOpcode::kTuple) {
138     return false;
139   }
140   return absl::c_all_of(root->operands(), [](const HloInstruction* instr) {
141     return Is1dSliceWithoutStrides(instr);
142   });
143 }
144 
145 struct ConcatUsageInfo {
146   // Pointer to a previously seen concat. nullptr if no previously seen concat.
147   const HloInstruction* prev_concat;
148   // The opnd id of the seen concat.
149   int64 concat_opnd_idx;
150   // The slice that recovers the opnd in the concat outputs.
151   const HloInstruction* slice_to_recover_opnd;
152 };
153 
154 // Returns an optional concat usage info to denote whether the concat is used in
155 // an elementwise manner. A concat followed by slices is considered effectively
156 // elementwise if the slices together form a reverse function of the concat.
157 absl::optional<ConcatUsageInfo> ConcatIsEffectivelyElementwise(
158     const HloInstruction& concat, const HloInstruction& operand,
159     const ConcatUsageInfo& info) {
160   // First, check if this concat is in the pattern below. Also, we check
161   // that the slices together are in effect a reverse function of the concat.
162   //
163   //     Concat
164   //     |    |
165   //     v    v
166   //   Slice Slice
167   //
168   std::vector<HloInstruction*> users = concat.users();
169   if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) {
170     // Limit our supported cases to 1 dimensional slices.
171     return absl::optional<ConcatUsageInfo>();
172   }
173   // Verify that each operand to the concat is reversed by a slice.
174   if (users.size() != concat.operand_count() ||
175       concat.operand_count() != concat.unique_operands().size()) {
176     return absl::optional<ConcatUsageInfo>();
177   }
178   absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) {
179     return a->slice_starts().at(0) < b->slice_starts().at(0);
180   });
181   int64_t prev_limit = 0;
182   for (int64_t i = 0; i < users.size(); ++i) {
183     const HloInstruction* u = users[i];
184     int64_t slice_size = u->slice_limits().at(0) - u->slice_starts().at(0);
185     if (u->slice_starts().at(0) != prev_limit ||
186         slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) {
187       return absl::optional<ConcatUsageInfo>();
188     }
189     prev_limit = u->slice_limits().at(0);
190   }
191 
192   // If we have seen other concats, make sure they are identical. Multiple
193   // concats exist because horizontal fusion inserts one concat for each output
194   // of the fusion candidates. Check that all concats and operand ids are the
195   // same to know that the "transitive use closure" will be computed in the same
196   // iteration space.
197   int64_t operand_idx = concat.operand_index(&operand);
198   if (info.prev_concat != nullptr) {
199     bool is_concat_identical = info.prev_concat->Identical(
200         concat,
201         /*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
202           // Operands don't need to be the same.
203           return true;
204         });
205     if (!is_concat_identical || info.concat_opnd_idx != operand_idx) {
206       return absl::optional<ConcatUsageInfo>();
207     }
208   }
209 
210   const HloInstruction* slice_to_recover_opnd = users.at(operand_idx);
211   return absl::optional<ConcatUsageInfo>(
212       ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd});
213 }
214 
215 // Returns whether we can prove the transitive uses of `param` are in effect
216 // elementwise. In other words, we prove that the "transitive use closure" will
217 // all be computed in the same iteration space without any reordering of elements.
218 // In addition, we check that the "transitive use closure" includes the output
219 // in the `root_tuple`.
220 // Theoretically, we can prove more patterns, but our primary use case is
221 // SliceInputFusion.
222 bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param,
223                                              const HloInstruction* root_tuple,
224                                              const ShapeIndex& out_shape_idx) {
225   CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
226   CHECK_EQ(out_shape_idx.size(), 1);
227   absl::flat_hash_set<const HloInstruction*> visited;
228   absl::InlinedVector<const HloInstruction*, 4> stack;
229   stack.push_back(param);
230   ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr};
231   bool is_output_reachable = false;
232   while (!stack.empty()) {
233     const HloInstruction* current = stack.back();
234     stack.pop_back();
235     visited.insert(current);
236     for (const HloInstruction* user : current->users()) {
237       VLOG(3) << "Visiting: " << user->ToString();
238       switch (user->opcode()) {
239         case HloOpcode::kTuple:
240           if (user == root_tuple &&
241               current == root_tuple->operand(out_shape_idx.back())) {
242             // We need to know if the output is reachable from `param` to make
243             // sure that they will be computed in the same iteration space.
244             is_output_reachable = true;
245           }
246           break;
247         case HloOpcode::kReshape:
248           if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) {
249             return false;
250           }
251           break;
252         case HloOpcode::kConcatenate: {
253           absl::optional<ConcatUsageInfo> optional_concat_info =
254               ConcatIsEffectivelyElementwise(*user, *current,
255                                              concat_usage_info);
256           if (!optional_concat_info) {
257             return false;
258           }
259           concat_usage_info = *optional_concat_info;
260           // Early continue as we only want to traverse through the slice that
261           // recovers the operand. It is guaranteed that the operand to the
262           // concat and the slice have the same iteration space. Insert the
263           // slice instead of the concat.
264           CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd));
265           stack.push_back(concat_usage_info.slice_to_recover_opnd);
266           continue;
267         }
268         default:
269           for (const int64_t use_index : user->OperandIndices(current)) {
270             if (!user->IsElementwiseOnOperand(use_index)) {
271               // Found a user that is non-elementwise on the current
272               // instruction.
273               return false;
274             }
275           }
276           if (!LayoutUtil::Equal(current->shape().layout(),
277                                  user->shape().layout())) {
278             // Make sure the layout is not changed by the elementwise op.
279             return false;
280           }
281           break;
282       }  // end of switch
283       if (!visited.contains(user)) {
284         stack.push_back(user);
285       }
286     }
287   }
288   return is_output_reachable;
289 }
290 }  // namespace
291 
292 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
293                                            const ShapeIndex& index) const {
294   const HloValueSet& value_set = GetValueSet(instruction, index);
295   if (value_set.values().size() != 1) {
296     return false;
297   }
298   return value_set.GetUniqueValue().defining_instruction() == instruction;
299 }
300 
301 const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
302     const HloInstruction* instruction, const ShapeIndex& index) const {
303   CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
304   return GetUniqueValueAt(instruction, index);
305 }
306 
307 HloValue& HloDataflowAnalysis::GetValueDefinedAt(
308     const HloInstruction* instruction, const ShapeIndex& index) {
309   CHECK(ValueIsDefinedAt(instruction, index));
310   return GetUniqueValueAt(instruction, index);
311 }
312 
313 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
314                                            const ShapeIndex& index,
315                                            bool is_phi) {
316   const int64_t value_id = next_value_id_++;
317   auto emplaced = values_.emplace(
318       std::piecewise_construct, std::forward_as_tuple(value_id),
319       std::forward_as_tuple(value_id, instruction, index, is_phi));
320   CHECK(emplaced.second);
321 
322   VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
323 
324   return &emplaced.first->second;
325 }
326 
327 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
328   HloValue& value = values_.at(value_id);
329   VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
330 
331   value_ids_to_delete_.push_back(value_id);
332 }
333 
334 void HloDataflowAnalysis::DeleteMarkedValues() {
335   // Use a set to prevent deleting an id twice.
336   absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
337                                            value_ids_to_delete_.end());
338 #ifndef NDEBUG
339   // Verify that no marked-for-deletion values are in any of the value sets.
340   for (const auto& pair : value_sets_) {
341     const HloInstruction* instruction = pair.first;
342     const InstructionValueSet& instruction_value_set = pair.second;
343     for (const auto& index_value_set : instruction_value_set) {
344       const HloValueSet& value_set = index_value_set.second;
345       for (const HloValue* value : value_set.values()) {
346         DCHECK(!ContainsKey(id_set, value->id()))
347             << "Value " << value->ToShortString()
348             << " marked for deletion, but still exists in value set for "
349                "instruction "
350             << instruction->name();
351       }
352     }
353   }
354 #endif
355 
356   for (HloValue::Id value_id : id_set) {
357     values_.erase(value_id);
358   }
359   value_ids_to_delete_.clear();
360 }
361 
362 string HloDataflowAnalysis::ToString() const {
363   string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
364   StrAppend(&out, "  Instruction value sets:\n");
365   for (const HloComputation* computation : module_.computations()) {
366     for (const HloInstruction* instruction : computation->instructions()) {
367       StrAppend(&out, "Instruction: \n  ", instruction->name(), ":\n");
368       if (instruction->shape().IsTuple()) {
369         GetInstructionValueSet(instruction)
370             .ForEachElement([this, &instruction, &out](
371                                 const ShapeIndex& index,
372                                 const HloValueSet& value_set) {
373               StrAppend(&out, "      tuple index ", index.ToString(), ":\n");
374               for (const HloValue* value : value_set.values()) {
375                 StrAppend(&out, "        ", value->ToShortString(),
376                           ValueIsDefinedAt(instruction, index) ? " (def)" : "",
377                           "\n");
378               }
379             });
380       } else {
381         const HloValueSet& top_level_value_set =
382             GetValueSet(instruction, /*index=*/{});
383         for (const HloValue* value : top_level_value_set.values()) {
384           StrAppend(&out, "      ", value->ToShortString(),
385                     ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
386         }
387       }
388     }
389   }
390   StrAppend(&out, "  HloValues:\n");
391   for (const HloValue* value : values()) {
392     StrAppend(&out, value->ToString(/*indent=*/4));
393   }
394   return out;
395 }
396 
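// Merges the value sets in `inputs` into the value set of `instruction`,
// creating (or reusing) phi values where more than one distinct value meets.
// Only used in SSA form. Returns true if any value set changed.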
397 bool HloDataflowAnalysis::Phi(
398     HloInstruction* instruction,
399     absl::Span<const InstructionValueSet* const> inputs) {
400   CHECK(ssa_form_);
401   VLOG(4) << "Phi(" << instruction->name() << ")";
402   VLOG(5) << "instruction value set = "
403           << GetInstructionValueSet(instruction).ToString();
404   for (const InstructionValueSet* input : inputs) {
405     VLOG(5) << "input value set = " << input->ToString();
406   }
407 
408   if (bitcast_defines_value_) {
409     absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
410       DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
411     });
412   } else {
413     const Shape& shape = instruction->shape();
414     PrimitiveType ty = shape.element_type();
415     bool is_array = shape.IsArray();
416     absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
417       DCHECK(ty == input->shape().element_type() &&
418              (!is_array || ShapeUtil::ElementsIn(shape) ==
419                                ShapeUtil::ElementsIn(input->shape())));
420     });
421   }
422 
423   bool changed = false;
424   for (auto& pair : GetInstructionValueSet(instruction)) {
425     const ShapeIndex& index = pair.first;
426     HloValueSet& value_set = pair.second;
427 
428     // Positions with phi values should never have more than one value in the
429     // value set.
430     CHECK_LE(value_set.values().size(), 1);
431     const HloValue* current_value =
432         value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
433 
434     // Construct a vector of value IDs of the inputs.
435     std::vector<HloValue::Id> input_value_ids;
436     for (const InstructionValueSet* input : inputs) {
437       for (const HloValue* value : input->element(index).values()) {
438         input_value_ids.push_back(value->id());
439       }
440     }
441 
442     // Remove the existing phi value (if it exists). The phi can be its own
443     // input, for example, in while body parameters where the body passes
444     // through the parameter value.
445     bool current_value_defined_here =
446         (current_value != nullptr &&
447          current_value->defining_instruction() == instruction &&
448          current_value->defining_index() == index);
449 
450     VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
451     if (input_value_ids.empty()) {
452       // A value set which has at least one element should never have its value
453       // set reduced to zero elements. During dataflow, value sets can only go
454       // from empty to non-empty, not the reverse.
455       CHECK_EQ(value_set.values().size(), 0)
456           << "Instruction " << instruction->name() << " at index " << index
457           << " previously had non-empty value set. Value set: " << value_set;
458     } else if (input_value_ids.size() == 1) {
459       // Only a single value reaches this point. There should be no phi, and
460       // this value set should contain this single value.
461       const HloValue& new_value = GetValue(input_value_ids[0]);
462       if (current_value == nullptr) {
463         value_set.Clear();
464         value_set.AddValue(&new_value);
465         changed = true;
466       } else if (current_value != &new_value) {
467         if (current_value_defined_here) {
468           // Remove the existing phi.
469           MarkValueForDeletion(current_value->id());
470         }
471         value_set.Clear();
472         value_set.AddValue(&new_value);
473         changed = true;
474       }
475     } else {
476       // Multiple distinct values reach this point. A phi value is
477       // necessary.
478       CHECK_GT(input_value_ids.size(), 1);
479       bool phi_defined_here =
480           current_value_defined_here && current_value->is_phi();
481       if (current_value == nullptr || !phi_defined_here) {
482         value_set.Clear();
483         value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
484 
485         std::vector<HloValue*> inputs;
486         inputs.reserve(input_value_ids.size());
487         for (HloValue::Id id : input_value_ids) {
488           inputs.push_back(&GetValue(id));
489         }
490         // Register the phi into phi graph.
491         phi_graph_.RegisterPhi(*value_set.values()[0], inputs);
492         changed = true;
493       } else if (phi_defined_here) {
494         std::vector<HloValue*> new_inputs;
495         new_inputs.reserve(input_value_ids.size());
496         for (HloValue::Id id : input_value_ids) {
497           new_inputs.push_back(&GetValue(id));
498         }
499 
500         if (!phi_graph_.InputsEqualTo(*current_value, new_inputs)) {
501           VLOG(1) << current_value->ToShortString() << " has new phi inputs: ";
502           // Update phi inputs.
503           phi_graph_.RegisterPhi(*current_value, new_inputs);
504           changed = true;
505         }
506       }
507     }
508   }
509   return changed;
510 }
511 
512 const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
513   return values_.at(value_id);
514 }
515 
516 HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
517   return values_.at(value_id);
518 }
519 
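// Returns the union of the value sets at every ShapeIndex of `instruction`'s
// output shape.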
520 HloValueSet HloDataflowAnalysis::GetFlattenedValueSet(
521     const HloInstruction* instruction) const {
522   HloValueSet value_set;
523 
524   const InstructionValueSet& value_set_tree =
525       GetInstructionValueSet(instruction);
526 
527   std::vector<const HloValueSet*> all_sets;
528   for (auto& pair : value_set_tree) {
529     const HloValueSet& value_set = pair.second;
530     all_sets.push_back(&value_set);
531   }
532   value_set.AssignUnionOf(all_sets);
533 
534   return value_set;
535 }
536 
537 const HloValueSet& HloDataflowAnalysis::GetValueSet(
538     const HloInstruction* instruction, const ShapeIndex& index) const {
539   return GetInstructionValueSet(instruction).element(index);
540 }
541 
542 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
543                                               const ShapeIndex& index) {
544   return *GetInstructionValueSet(instruction).mutable_element(index);
545 }
546 
547 const HloValueSet& HloDataflowAnalysis::GetValueSet(
548     const HloPosition& position) const {
549   return GetValueSet(position.instruction, position.index);
550 }
551 
552 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
553   return GetValueSet(position.instruction, position.index);
554 }
555 
556 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
557   CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
558   const InstructionValueSet& operand_set =
559       GetInstructionValueSet(bitcast->operand(0));
560   InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
561   if (!bitcast_defines_value_ && operand_set != bitcast_set) {
562     bitcast_set = operand_set;
563     return true;
564   }
565   return false;
566 }
567 
568 bool HloDataflowAnalysis::UpdateSetDimensionSizeValueSet(
569     HloInstruction* set_dimension_size) {
570   CHECK_EQ(set_dimension_size->opcode(), HloOpcode::kSetDimensionSize);
571   const InstructionValueSet& operand_set =
572       GetInstructionValueSet(set_dimension_size->operand(0));
573   InstructionValueSet& set_dimension_size_set =
574       GetInstructionValueSet(set_dimension_size);
575   if (operand_set != set_dimension_size_set) {
576     set_dimension_size_set = operand_set;
577     return true;
578   }
579   return false;
580 }
581 
582 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
583   CHECK_EQ(send->opcode(), HloOpcode::kSend);
584   bool changed = false;
585   // Send forwards the operand value to the output tuple at {0}.
586   for (auto& pair : GetInstructionValueSet(send->operand(0))) {
587     const ShapeIndex& operand_index = pair.first;
588     const HloValueSet& operand_value_set = pair.second;
589 
590     ShapeIndex index = {0};
591     for (int64_t i : operand_index) {
592       index.push_back(i);
593     }
594 
595     HloValueSet& value_set = GetValueSet(send, index);
596     if (value_set != operand_value_set) {
597       value_set = operand_value_set;
598       changed = true;
599     }
600   }
601   return changed;
602 }
603 
604 bool HloDataflowAnalysis::UpdateCustomCallValueSet(
605     HloInstruction* custom_call) {
606   CHECK_EQ(custom_call->opcode(), HloOpcode::kCustomCall);
607   bool changed = false;
608   for (const auto& aliasing : Cast<HloCustomCallInstruction>(custom_call)
609                                   ->output_to_operand_aliasing()) {
610     const HloValueSet& operand_value_set = GetValueSet(
611         custom_call->operand(aliasing.second.first), aliasing.second.second);
612     HloValueSet& value_set = GetValueSet(custom_call, aliasing.first);
613     if (value_set != operand_value_set) {
614       value_set = operand_value_set;
615       changed = true;
616     }
617   }
618   return changed;
619 }
620 
621 bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
622   CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
623   bool changed = false;
624   // CopyStart forwards the operand value to element {1} of its output.
625   const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0));
626   HloValueSet& value_set = GetValueSet(copy_start, {1});
627   if (value_set != operand_value_set) {
628     value_set = operand_value_set;
629     changed = true;
630   }
631   return changed;
632 }
633 
634 bool HloDataflowAnalysis::UpdateCopyDoneValueSet(HloInstruction* copy_done) {
635   CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone);
636   bool changed = false;
637   // CopyDone forwards the operand value at {0} to element {} of its output.
638   const HloValueSet& operand_value_set =
639       GetValueSet(copy_done->operand(0), {0});
640   HloValueSet& value_set = GetValueSet(copy_done);
641   if (value_set != operand_value_set) {
642     value_set = operand_value_set;
643     changed = true;
644   }
645   return changed;
646 }
647 
648 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
649   CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
650   bool changed = false;
651   // RecvDone forwards the operand value at {0} to element {0} of its output.
652   for (auto& pair : GetInstructionValueSet(recv_done)) {
653     ShapeIndex& index = pair.first;
654     HloValueSet& value_set = pair.second;
655 
656     if (index.empty() || index[0] != 0) {
657       continue;
658     }
659 
660     const HloValueSet& operand_value_set =
661         GetValueSet(recv_done->operand(0), index);
662     if (value_set != operand_value_set) {
663       value_set = operand_value_set;
664       changed = true;
665     }
666   }
667   return changed;
668 }
669 
670 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
671   CHECK_EQ(call->opcode(), HloOpcode::kCall);
672   InstructionValueSet& value_set = GetInstructionValueSet(call);
673   InstructionValueSet& root_value_set =
674       GetInstructionValueSet(call->to_apply()->root_instruction());
675   if (value_set != root_value_set) {
676     value_set = root_value_set;
677     return true;
678   }
679   return false;
680 }
681 
682 bool HloDataflowAnalysis::UpdateConditionalValueSet(
683     HloInstruction* conditional) {
684   CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
685   std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
686   for (int j = 0; j < conditional->branch_count(); ++j) {
687     inputs[j] = &GetInstructionValueSet(
688         conditional->branch_computation(j)->root_instruction());
689   }
690   if (ssa_form_) {
691     return Phi(conditional, inputs);
692   } else {
693     return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
694   }
695 }
696 
697 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
698   CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
699   bool changed = false;
700   for (auto& pair : GetInstructionValueSet(copy)) {
701     const ShapeIndex& index = pair.first;
702     if (index.empty()) {
703       // kCopy shallow copies and thus defines the top-level value, so there is
704       // nothing to update.
705       continue;
706     }
707 
708     HloValueSet& value_set = pair.second;
709     HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
710     if (value_set != operand_value_set) {
711       value_set = operand_value_set;
712       changed = true;
713     }
714   }
715   return changed;
716 }
717 
718 bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
719   // Domain instructions just forward their operand. Given that domains can have
720   // a tuple operand, we iterate through its indexes, like for copies.
721   // Unlike copies, though, we also propagate the top-level value.
722   CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
723   bool changed = false;
724   for (auto& pair : GetInstructionValueSet(domain)) {
725     const ShapeIndex& index = pair.first;
726     HloValueSet& value_set = pair.second;
727     HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
728     if (value_set != operand_value_set) {
729       value_set = operand_value_set;
730       changed = true;
731     }
732   }
733   return changed;
734 }
735 
736 bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
737     HloInstruction* add_dependency) {
738   // AddDependency just forwards the value of its zero-th operand.
739   CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
740   const InstructionValueSet& operand_set =
741       GetInstructionValueSet(add_dependency->operand(0));
742   InstructionValueSet& add_dependency_set =
743       GetInstructionValueSet(add_dependency);
744   if (operand_set != add_dependency_set) {
745     add_dependency_set = operand_set;
746     return true;
747   }
748   return false;
749 }
750 
751 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
752   CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
753   bool changed = false;
754   // The GetTupleElement instruction forwards the values from the specified
755   // tuple element.
756   for (auto& pair : GetInstructionValueSet(gte)) {
757     const ShapeIndex& index = pair.first;
758     HloValueSet& value_set = pair.second;
759 
760     // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
761     // with the tuple element number prefixed.
762     ShapeIndex operand_index = {gte->tuple_index()};
763     for (int64_t i : index) {
764       operand_index.push_back(i);
765     }
766 
767     HloValueSet& operand_value_set =
768         GetValueSet(gte->operand(0), operand_index);
769     if (value_set != operand_value_set) {
770       value_set = operand_value_set;
771       changed = true;
772     }
773   }
774   return changed;
775 }
776 
777 bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
778   CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
779   const CallGraphNode& call_graph_node =
780       call_graph_->GetNode(parameter->parent());
781 
782   // Subcomputations called in a parallel context (eg, map) do not have dataflow
783   // from the caller operands.
784   if (call_graph_node.context() == CallContext::kParallel ||
785       call_graph_node.caller_callsites().empty()) {
786     return false;
787   }
788   CHECK_EQ(call_graph_node.context(), CallContext::kSequential);
789 
790   std::vector<const InstructionValueSet*> inputs;
791   bool need_phi = false;
792   for (const CallSite& callsite : call_graph_node.caller_callsites()) {
793     if (callsite.instruction()->opcode() == HloOpcode::kCall) {
794       // The operand values of a call instruction are forwarded to the
795       // respective parameter instruction of the subcomputation.
796       inputs.push_back(&GetInstructionValueSet(
797           callsite.instruction()->operand(parameter->parameter_number())));
798     } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
799       // In a while instruction, the while operand (ie, the init value) and the
800       // backedge are dataflow inputs to the parameter instruction. This is the
801       // case for parameters of both the body and condition computations.
802       CHECK_EQ(parameter->parameter_number(), 0);
803       inputs.push_back(
804           &GetInstructionValueSet(callsite.instruction()->operand(0)));
805       // If the parameter *is not* the root, the parameter's state would be
806       // updated by the root; otherwise, don't consider its current state
807       // (InstructionValueSet) as we are recomputing its current state.
808       if (parameter !=
809           callsite.instruction()->while_body()->root_instruction()) {
810         inputs.push_back(&GetInstructionValueSet(
811             callsite.instruction()->while_body()->root_instruction()));
812       }
813       need_phi = true;
814     } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
815       CHECK_EQ(parameter->parameter_number(), 0);
816       auto conditional = callsite.instruction();
817       // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
818       // operands 1 and onward are the arguments to the branch computations.
819       //
820       // If the parameter belongs to conditional's branch 0 computation, then
821       // operand 1 is forwarded to this parameter instruction. If the parameter
822       // belongs to conditional's branch 5 computation, then operand 6 is
823       // forwarded to this parameter instruction.
824       bool found_parent = false;
825       for (int j = 0; j < conditional->branch_count(); ++j) {
826         if (parameter->parent() == conditional->branch_computation(j)) {
827           inputs.push_back(
828               &GetInstructionValueSet(conditional->operand(j + 1)));
829           found_parent = true;
830           break;
831         }
832       }
833       CHECK(found_parent);
834       need_phi = true;
835     } else {
836       LOG(FATAL) << "CallContext::kSequential computations should only be "
837                     "called from call, while, or conditional instructions";
838     }
839   }
840   if (ssa_form_ && need_phi) {
841     return Phi(parameter, inputs);
842   } else {
843     return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
844   }
845 }
846 
847 bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) {
848   CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect);
849   // A phi value is not defined at a kTupleSelect instruction because
850   // kTupleSelect does not create a new value. Rather it forwards a value from
851   // its operands. This contrasts with the kWhile instruction (which does define a
852   // phi value) which has in-place update semantics.
853   bool changed = false;
854   for (auto& pair : GetInstructionValueSet(select)) {
855     const ShapeIndex& index = pair.first;
856     if (index.empty()) {
857       // kTupleSelect copies (not forwards) the top-level value.
858       continue;
859     }
860     HloValueSet& value_set = pair.second;
861     changed |=
862         value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
863                                  &GetValueSet(select->operand(2), index)});
864   }
865   return changed;
866 }
867 
868 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
869   CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
870   bool changed = false;
871   for (int64_t i = 0; i < tuple->operands().size(); ++i) {
872     // Copy the value set(s) of each operand into the respective position in the
873     // kTuple instruction's value sets.
874     for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
875       const ShapeIndex& operand_index = pair.first;
876       HloValueSet& operand_value_set = pair.second;
877 
878       ShapeIndex index = {i};
879       for (int64_t op_index : operand_index) {
880         index.push_back(op_index);
881       }
882       HloValueSet& value_set = GetValueSet(tuple, index);
883 
884       if (value_set != operand_value_set) {
885         value_set = operand_value_set;
886         changed = true;
887       }
888     }
889   }
890   return changed;
891 }
892 
893 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
894   CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
895   const InstructionValueSet* const inputs[] = {
896       &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
897       &GetInstructionValueSet(xla_while->operand(0))};
898   if (ssa_form_) {
899     return Phi(xla_while, inputs);
900   } else {
901     return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
902   }
903 }
904 
905 bool HloDataflowAnalysis::UpdateAllGatherStartValueSet(
906     HloInstruction* all_gather_start) {
907   CHECK_EQ(all_gather_start->opcode(), HloOpcode::kAllGatherStart);
908   bool changed = false;
909   // AllGatherStart forwards the operand values to element {0} of its output.
910   for (int64_t i = 0; i < all_gather_start->operand_count(); ++i) {
911     const HloValueSet& operand_value_set =
912         GetValueSet(all_gather_start->operand(i));
913 
914     ShapeIndex output_index = {0};
915     if (all_gather_start->operand_count() > 1) {
916       output_index.push_back(i);
917     }
918 
919     HloValueSet& value_set = GetValueSet(all_gather_start, output_index);
920     if (value_set != operand_value_set) {
921       value_set = operand_value_set;
922       changed = true;
923     }
924   }
925   return changed;
926 }
927 
928 bool HloDataflowAnalysis::UpdateAllGatherDoneValueSet(
929     HloInstruction* all_gather_done) {
930   CHECK_EQ(all_gather_done->opcode(), HloOpcode::kAllGatherDone);
931   bool changed = false;
932   // AllGatherDone forwards the operand value at {1} to its output.
933   for (auto& pair : GetInstructionValueSet(all_gather_done)) {
934     const ShapeIndex& output_index = pair.first;
935     HloValueSet& value_set = pair.second;
936 
937     ShapeIndex operand_index = {1};
938     for (int64_t i : output_index) {
939       operand_index.push_back(i);
940     }
941 
942     const HloValueSet& operand_value_set =
943         GetValueSet(all_gather_done->operand(0), operand_index);
944     if (value_set != operand_value_set) {
945       value_set = operand_value_set;
946       changed = true;
947     }
948   }
949   return changed;
950 }
951 
952 bool HloDataflowAnalysis::UpdateAllReduceStartValueSet(
953     HloInstruction* all_reduce_start) {
954   CHECK_EQ(all_reduce_start->opcode(), HloOpcode::kAllReduceStart);
955   bool changed = false;
956   // AllReduceStart forwards the operand values to element {0} of its output.
957   for (int64_t i = 0; i < all_reduce_start->operand_count(); ++i) {
958     const HloValueSet& operand_value_set =
959         GetValueSet(all_reduce_start->operand(i));
960 
961     ShapeIndex output_index = {0};
962     if (all_reduce_start->operand_count() > 1) {
963       output_index.push_back(i);
964     }
965 
966     HloValueSet& value_set = GetValueSet(all_reduce_start, output_index);
967     if (value_set != operand_value_set) {
968       value_set = operand_value_set;
969       changed = true;
970     }
971   }
972   return changed;
973 }
974 
975 bool HloDataflowAnalysis::UpdateAllReduceDoneValueSet(
976     HloInstruction* all_reduce_done) {
977   CHECK_EQ(all_reduce_done->opcode(), HloOpcode::kAllReduceDone);
978   bool changed = false;
979   // AllReduceDone forwards the operand value at {1} to its output.
980   for (auto& pair : GetInstructionValueSet(all_reduce_done)) {
981     const ShapeIndex& output_index = pair.first;
982     HloValueSet& value_set = pair.second;
983 
984     ShapeIndex operand_index = {1};
985     for (int64_t i : output_index) {
986       operand_index.push_back(i);
987     }
988 
989     const HloValueSet& operand_value_set =
990         GetValueSet(all_reduce_done->operand(0), operand_index);
991     if (value_set != operand_value_set) {
992       value_set = operand_value_set;
993       changed = true;
994     }
995   }
996   return changed;
997 }
998 
999 bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet(
1000     HloInstruction* collective_permute_start) {
1001   CHECK_EQ(collective_permute_start->opcode(),
1002            HloOpcode::kCollectivePermuteStart);
1003   bool changed = false;
1004   // CollectivePermuteStart forwards the operand value to element {0} of its
1005   // output.
1006   if (collective_permute_start->operand(0)->shape().IsTuple()) {
1007     for (int i = 0; i < ShapeUtil::TupleElementCount(
1008                             collective_permute_start->operand(0)->shape());
1009          ++i) {
1010       const HloValueSet& operand_value_set =
1011           GetValueSet(collective_permute_start->operand(0), {i});
1012       HloValueSet& value_set = GetValueSet(collective_permute_start, {0, i});
1013       if (value_set != operand_value_set) {
1014         value_set = operand_value_set;
1015         changed = true;
1016       }
1017     }
1018   } else {
1019     const HloValueSet& operand_value_set =
1020         GetValueSet(collective_permute_start->operand(0));
1021     HloValueSet& value_set = GetValueSet(collective_permute_start, {0});
1022     if (value_set != operand_value_set) {
1023       value_set = operand_value_set;
1024       changed = true;
1025     }
1026   }
1027   return changed;
1028 }
1029 
1030 bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet(
1031     HloInstruction* collective_permute_done) {
1032   CHECK_EQ(collective_permute_done->opcode(),
1033            HloOpcode::kCollectivePermuteDone);
1034   bool changed = false;
1035   // CollectivePermuteDone forwards the operand value at {1} to its output.
1036   if (collective_permute_done->shape().IsTuple()) {
1037     for (int i = 0;
1038          i < ShapeUtil::TupleElementCount(collective_permute_done->shape());
1039          ++i) {
1040       const HloValueSet& operand_value_set =
1041           GetValueSet(collective_permute_done->operand(0), {1, i});
1042       HloValueSet& value_set = GetValueSet(collective_permute_done, {i});
1043       if (value_set != operand_value_set) {
1044         value_set = operand_value_set;
1045         changed = true;
1046       }
1047     }
1048   } else {
1049     const HloValueSet& operand_value_set =
1050         GetValueSet(collective_permute_done->operand(0), {1});
1051     HloValueSet& value_set = GetValueSet(collective_permute_done);
1052     if (value_set != operand_value_set) {
1053       value_set = operand_value_set;
1054       changed = true;
1055     }
1056   }
1057   return changed;
1058 }
1059 
1060 bool HloDataflowAnalysis::UpdateInstructionValueSet(
1061     HloInstruction* instruction) {
1062   // Recompute from operands.
1063   switch (instruction->opcode()) {
1064     case HloOpcode::kAddDependency:
1065       return UpdateAddDependencyValueSet(instruction);
1066     case HloOpcode::kAllGatherStart:
1067       return UpdateAllGatherStartValueSet(instruction);
1068     case HloOpcode::kAllGatherDone:
1069       return UpdateAllGatherDoneValueSet(instruction);
1070     case HloOpcode::kBitcast:
1071       return UpdateBitcastValueSet(instruction);
1072     case HloOpcode::kCustomCall:
1073       return UpdateCustomCallValueSet(instruction);
1074     case HloOpcode::kSetDimensionSize:
1075       return UpdateSetDimensionSizeValueSet(instruction);
1076     case HloOpcode::kDomain:
1077       return UpdateDomainValueSet(instruction);
1078     case HloOpcode::kCopy:
1079       return UpdateCopyValueSet(instruction);
1080     case HloOpcode::kGetTupleElement:
1081       return UpdateGetTupleElementValueSet(instruction);
1082     case HloOpcode::kTupleSelect:
1083       return UpdateTupleSelectValueSet(instruction);
1084     case HloOpcode::kTuple:
1085       return UpdateTupleValueSet(instruction);
1086     case HloOpcode::kParameter:
1087       return UpdateParameterValueSet(instruction);
1088     case HloOpcode::kCall:
1089       return UpdateCallValueSet(instruction);
1090     case HloOpcode::kWhile:
1091       return UpdateWhileValueSet(instruction);
1092     case HloOpcode::kSend:
1093       return UpdateSendValueSet(instruction);
1094     case HloOpcode::kRecvDone:
1095       return UpdateRecvDoneValueSet(instruction);
1096     case HloOpcode::kCopyStart:
1097       return UpdateCopyStartValueSet(instruction);
1098     case HloOpcode::kCopyDone:
1099       return UpdateCopyDoneValueSet(instruction);
1100     case HloOpcode::kConditional:
1101       return UpdateConditionalValueSet(instruction);
1102     case HloOpcode::kAllReduceStart:
1103       return UpdateAllReduceStartValueSet(instruction);
1104     case HloOpcode::kAllReduceDone:
1105       return UpdateAllReduceDoneValueSet(instruction);
1106     case HloOpcode::kCollectivePermuteStart:
1107       return UpdateCollectivePermuteStartValueSet(instruction);
1108     case HloOpcode::kCollectivePermuteDone:
1109       return UpdateCollectivePermuteDoneValueSet(instruction);
1110     default:
1111       // Instruction does not forward HloValues (it defines all values in its
1112       // output). No update is necessary.
1113       return false;
1114   }
1115 }
1116 
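// Runs the dataflow analysis to a fixed point: instructions are drained from
// a worklist ordered by post-order ordinal, and users (plus any called
// computations' parameters and calling instructions) are re-queued whenever
// an instruction's value set changes.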
1117 void HloDataflowAnalysis::Propagate() {
1118   using Work = std::pair<int64, HloInstruction*>;
1119   // Avoid duplicating work by preferring work items early in the post order
1120   // schedule. Intuitively, we start from entry parameters and propagate buffer
1121   // updates throughout the module only once.
1122   std::priority_queue<Work, std::vector<Work>, std::greater<Work>> worklist;
1123   absl::flat_hash_set<HloInstruction*> workset;
1124   auto priority_map = CalculatePostOrderSchedule(module_);
1125   auto add_to_worklist = [&priority_map, &worklist,
1126                           &workset](HloInstruction* instruction) {
1127     if (workset.insert(instruction).second) {
1128       worklist.emplace(priority_map[instruction], instruction);
1129     }
1130   };
1131 
1132   auto comps = module_.MakeComputationPostOrder();
1133   for (HloComputation* computation : comps) {
1134     for (HloInstruction* instruction :
1135          computation->MakeInstructionPostOrder()) {
1136       add_to_worklist(instruction);
1137     }
1138   }
1139   VLOG(1) << "SSA_FORM_: " << ssa_form_;
1140 
1141   while (!worklist.empty()) {
1142     HloInstruction* instruction = worklist.top().second;
1143     auto add_to_worklist = [&](HloInstruction* todo) {
1144       if (workset.insert(todo).second) {
1145         VLOG(1) << "  Adding todo : " << todo->name();
1146         worklist.emplace(priority_map[todo], todo);
1147       }
1148     };
1149     worklist.pop();
1150 
1151     workset.erase(workset.find(instruction));
1152 
1153     VLOG(3) << "Worklist top: " << instruction->name();
1154     VLOG(3) << ToString();
1155 
1156     if (!UpdateInstructionValueSet(instruction)) {
1157       // No change to the instruction's value set.
1158       VLOG(4) << "No change.";
1159       continue;
1160     }
1161 
1162     VLOG(4) << "New value set for " << instruction->name() << ": "
1163             << GetInstructionValueSet(instruction);
1164 
1165     // Instruction value was updated. Add users to work list if we haven't
1166     // already.
1167     for (HloInstruction* user : instruction->users()) {
1168       add_to_worklist(user);
1169 
1170       // If user sequentially calls a computation, then the respective
1171       // parameter(s) of the computation need to be updated.
1172       if (user->opcode() == HloOpcode::kConditional) {
1173         // If operand 0 is the use of instruction, then no parameters need to be
1174         // updated, since that is the branch_index of the conditional.
1175         // If operand n+1 is the use of instruction, then the branch_computation
1176         // n's parameter needs to be updated.
1177         //
1178         // Note that the same instruction can be used in multiple branches'
1179         // operands.
1180         for (int j = 0; j < user->branch_count(); ++j) {
1181           if (user->operand(j + 1) == instruction) {
1182             add_to_worklist(
1183                 user->branch_computation(j)->parameter_instruction(0));
1184           }
1185         }
1186       } else {
1187         for (HloComputation* called_computation : user->called_computations()) {
1188           const CallGraphNode& call_graph_node =
1189               call_graph_->GetNode(called_computation);
1190           if (call_graph_node.context() == CallContext::kSequential) {
1191             for (int64_t operand_number : user->OperandIndices(instruction)) {
1192               add_to_worklist(
1193                   called_computation->parameter_instruction(operand_number));
1194             }
1195           }
1196         }
1197       }
1198     }
1199 
1200     // If instruction is a root instruction, then propagate out to any calling
1201     // instruction and across any while backedge.
1202     if (instruction == instruction->parent()->root_instruction()) {
1203       const CallGraphNode& call_graph_node =
1204           call_graph_->GetNode(instruction->parent());
1205       for (const CallSite& callsite : call_graph_node.caller_callsites()) {
1206         if (callsite.instruction()->opcode() == HloOpcode::kCall ||
1207             callsite.instruction()->opcode() == HloOpcode::kConditional) {
1208           add_to_worklist(callsite.instruction());
1209         } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
1210           // Add the while itself, and the body and condition parameters.
1211           add_to_worklist(callsite.instruction());
1212           add_to_worklist(
1213               callsite.instruction()->while_body()->parameter_instruction(0));
1214           add_to_worklist(
1215               callsite.instruction()->while_condition()->parameter_instruction(
1216                   0));
1217         }
1218       }
1219     }
1220   }
1221 }
1222 
1223 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
1224     const HloInstruction* instruction) const {
1225   return value_sets_.at(instruction);
1226 }
1227 
1228 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
1229     const HloInstruction* instruction) {
1230   return value_sets_.at(instruction);
1231 }
1232 
1233 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
1234   for (const HloComputation* computation : module_.computations()) {
1235     const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
1236     for (HloInstruction* instruction : computation->instructions()) {
1237       // Create an empty shape tree.
1238       value_sets_.emplace(std::piecewise_construct,
1239                           std::forward_as_tuple(instruction),
1240                           std::forward_as_tuple(instruction->shape()));
1241 
1242       // For each sub-shape of the instruction shape, add a new HloValue to its
1243       // HloValueSet.
1244       auto define_all_values = [this, &instruction]() {
1245         for (auto& pair : GetInstructionValueSet(instruction)) {
1246           const ShapeIndex& index = pair.first;
1247           HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
1248           GetValueSet(instruction, index).AddValue(value);
1249         }
1250       };
1251 
1252       // Add a new HloValue to the HloValueSet corresponding to the given index
1253       // of the instruction shape.
1254       auto define_value_at = [this, &instruction](const ShapeIndex& index) {
1255         HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
1256         GetValueSet(instruction, index).AddValue(value);
1257       };
1258 
1259       switch (instruction->opcode()) {
1260         case HloOpcode::kBitcast:
1261           if (bitcast_defines_value_) {
1262             define_all_values();
1263           }
1264           break;
1265         case HloOpcode::kSetDimensionSize:
1266         case HloOpcode::kAddDependency:
1267         case HloOpcode::kWhile:
1268         case HloOpcode::kCall:
1269         case HloOpcode::kConditional:
1270         case HloOpcode::kGetTupleElement:
1271         case HloOpcode::kDomain:
1272           // These instructions define no values. The values in their output
1273           // flow from their operands or from cross computation dataflow.
1274           break;
1275         case HloOpcode::kParameter:
1276           if (call_graph_node.context() == CallContext::kBoth) {
1277             // We do not support a subcomputation that is called from both a
1278             // parallel and sequential context. In this case, the parameter
1279             // would both define a value and propagate a value from its
1280             // caller. This limitation is not really a problem because the call
1281             // graph is typically flattened.
1282             return Unimplemented(
1283                 "Computation %s is called in both a parallel (eg, kMap) and "
1284                 "sequential (eg, kCall) context",
1285                 computation->name());
1286           }
1287           if (call_graph_node.caller_callsites().empty() ||
1288               call_graph_node.context() == CallContext::kParallel) {
1289             // Parameters of computations called in a parallel context (eg, map
1290             // and reduce) as well as parameters of dead computations define all
1291             // values in their output. Otherwise the values of the parameter
1292             // come from the caller (eg, operands to the kCall instruction).
1293             define_all_values();
1294           }
1295           break;
1296         case HloOpcode::kCopy:
1297         case HloOpcode::kTupleSelect:
1298         case HloOpcode::kTuple:
1299           // These instructions only define their top-level values. Any other
1300           // values flow from their operands.
1301           define_value_at(/*index=*/{});
1302           break;
1303         case HloOpcode::kCopyStart:
1304           // CopyStart produces a tuple of {destination buffer, aliased operand,
1305           // U32 context}.
1306           define_value_at(/*index=*/{});
1307           define_value_at(/*index=*/{0});
1308           define_value_at(/*index=*/{2});
1309           break;
1310         case HloOpcode::kCopyDone:
1311           // CopyDone consumes a tuple produced by CopyStart and produces an
1312           // element. Its output aliases its input tuple element {0}.
1313           break;
1314         case HloOpcode::kAllGatherStart:
1315           // AllGatherStart produces a tuple of
1316           // {aliased operand, destination buffer}.
1317           define_value_at(/*index=*/{});
1318           define_value_at(/*index=*/{1});
1319           break;
1320         case HloOpcode::kAllGatherDone:
1321           // AllGatherDone's output aliases its input tuple element {1}.
1322           if (instruction->shape().IsTuple()) {
1323             define_value_at(/*index=*/{});
1324           }
1325           break;
1326         case HloOpcode::kAllReduceStart:
1327           // AllReduceStart produces a tuple of
1328           // {aliased operand, destination buffer}.
1329           define_value_at(/*index=*/{});
1330           define_value_at(/*index=*/{1});
1331           if (instruction->operand_count() > 1) {
1332             for (int64_t i = 0; i < instruction->operand_count(); ++i) {
1333               define_value_at(/*index=*/{1, i});
1334             }
1335           }
1336           break;
1337         case HloOpcode::kAllReduceDone:
1338           // AllReduceDone's output aliases its input tuple element {1}.
1339           break;
1340         case HloOpcode::kCollectivePermuteStart:
1341           // CollectivePermuteStart produces a tuple of
1342           // {aliased operand, destination buffer, U32 context, U32 context}.
1343           define_value_at(/*index=*/{});
1344           define_value_at(/*index=*/{1});
1345           define_value_at(/*index=*/{2});
1346           define_value_at(/*index=*/{3});
1347           if (instruction->operand_count() > 1) {
1348             CHECK_EQ(instruction->operand_count(), 4);
1349             if (instruction->operand(1)->shape().IsTuple()) {
1350               for (int i = 0; i < ShapeUtil::TupleElementCount(
1351                                       instruction->operand(1)->shape());
1352                    ++i) {
1353                 define_value_at(/*index=*/{1, i});
1354               }
1355             }
1356             define_value_at(/*index=*/{4});
1357           }
1358           break;
1359         case HloOpcode::kCollectivePermuteDone:
1360           // CollectivePermuteDone's output aliases its input tuple element {1}.
1361           if (instruction->shape().IsTuple()) {
1362             define_value_at(/*index=*/{});
1363           }
1364           break;
1365         case HloOpcode::kRecvDone:
1366           // RecvDone produces a two-element tuple. Element zero aliases its
1367           // input tuple element {0}; element one is a token.
1368           define_value_at(/*index=*/{});
1369           define_value_at(/*index=*/{1});
1370           break;
1371         case HloOpcode::kSend:
1372           // Send produces a tuple of {aliased operand, U32 context, token},
1373           // therefore only defines the top-level tuple and the tuple elements
1374           // at {1} and {2}.
1375           define_value_at(/*index=*/{});
1376           define_value_at(/*index=*/{1});
1377           define_value_at(/*index=*/{2});
1378           break;
1379         case HloOpcode::kCustomCall: {
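          // A custom call may declare output-to-operand aliasing. Define fresh
          // values only at output indices that are not aliased to an operand;
          // aliased indices take their values from the corresponding operands.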
1380           absl::flat_hash_set<ShapeIndex> aliasing_indices;
1381           for (const auto& aliasing :
1382                Cast<HloCustomCallInstruction>(instruction)
1383                    ->output_to_operand_aliasing()) {
1384             aliasing_indices.insert(aliasing.first);
1385           }
1386           ShapeUtil::ForEachSubshape(
1387               instruction->shape(),
1388               [&](const Shape& /*subshape*/, const ShapeIndex& index) {
1389                 if (!aliasing_indices.contains(index)) {
1390                   define_value_at(index);
1391                 }
1392               });
1393           break;
1394         }
1395         default:
1396           define_all_values();
1397           break;
1398       }
1399     }
1400   }
1401 
1402   return Status::OK();
1403 }
1404 
1405 void HloDataflowAnalysis::OptimizePhiValues() {
1406   // Only applicable to SSA form where phis are defined.
1407   if (!ssa_form_) {
1408     return;
1409   }
1410 
1411   VLOG(1) << "Before phi graph optimization";
1412   XLA_VLOG_LINES(1, phi_graph_.ToString());
1413   phi_graph_.Optimize();
1414   VLOG(1) << "After phi graph optimization";
1415   XLA_VLOG_LINES(1, phi_graph_.ToString());
1416 
1417   for (const HloComputation* computation : module_.computations()) {
1418     for (HloInstruction* instruction : computation->instructions()) {
1419       InstructionValueSet& instruction_value_set =
1420           GetInstructionValueSet(instruction);
1421       VLOG(1) << "inst: " << instruction->name();
1422       VLOG(1) << instruction_value_set.ToString();
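      // If a value set holds a single phi value, ask the optimized phi graph
      // which value that phi resolves to; if it differs, replace the phi in
      // the set and mark the original phi value for deletion.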
1423       instruction_value_set.ForEachMutableElement(
1424           [&](const xla::ShapeIndex& index, HloValueSet* value_set) {
1425             auto values = value_set->values();
1426             if (!(values.size() == 1 && values[0]->is_phi())) {
1427               return;
1428             }
1429             HloValue::Id phi_id = values[0]->id();
1430             HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id);
1431             if (new_id != phi_id) {
1432               VLOG(1) << "Replacing " << values[0]->ToString() << " with "
1433                       << GetValue(new_id).ToString();
1434               value_set->Clear();
1435               const HloValue& new_value = GetValue(new_id);
1436               value_set->AddValue(&new_value);
1437               MarkValueForDeletion(phi_id);
1438             }
1439           });
1440     }
1441   }
1442 }
1443 
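// Example usage (illustrative sketch only; assumes the default arguments that
// hlo_dataflow_analysis.h declares for Run and an in-scope HloModule `module`):
//
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> analysis,
//                       HloDataflowAnalysis::Run(module));
//   HloInstruction* root = module.entry_computation()->root_instruction();
//   if (analysis->ValueIsDefinedAt(root, /*index=*/{})) {
//     VLOG(2) << analysis->GetValueDefinedAt(root, /*index=*/{}).ToString();
//   }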
1444 /* static */
1445 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
1446     const HloModule& module, bool ssa_form, bool bitcast_defines_value,
1447     const CanShareBuffer& can_share_buffer) {
1448   VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
1449   XLA_VLOG_LINES(2, module.ToString());
1450 
1451   auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
1452       module, ssa_form, bitcast_defines_value, can_share_buffer));
1453 
1454   TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
1455   dataflow_analysis->Propagate();
1456   dataflow_analysis->OptimizePhiValues();
1457 
1458   // Delete all values marked for deletion.
1459   dataflow_analysis->DeleteMarkedValues();
1460 
1461   // Gather and set all non-definition positions of all values. Value deletion
1462   // is rare, so just use a vector indexed by Value::Id rather than a map from
1463   // Value::Id to positions. There should be very few holes in the vector, and
1464   // lookup is faster.
1465   std::vector<std::vector<HloPosition>> value_positions(
1466       dataflow_analysis->next_value_id_);
1467   for (const HloComputation* computation : module.computations()) {
1468     for (HloInstruction* instruction : computation->instructions()) {
1469       for (const auto& pair :
1470            dataflow_analysis->GetInstructionValueSet(instruction)) {
1471         const ShapeIndex& index = pair.first;
1472         const HloValueSet& value_set = pair.second;
1473         for (const HloValue* value : value_set.values()) {
1474           if (value->defining_instruction() != instruction) {
1475             value_positions[value->id()].push_back(
1476                 HloPosition{instruction, index});
1477           }
1478         }
1479       }
1480     }
1481   }
1482   for (auto& pair : dataflow_analysis->values_) {
1483     HloValue::Id value_id = pair.first;
1484     HloValue& value = pair.second;
1485     value.SetPositionsAndComputeUses(value_positions[value_id]);
1486   }
1487 
1488   // Construct vector of values.
1489   dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
1490   for (auto& pair : dataflow_analysis->values_) {
1491     dataflow_analysis->values_vector_.push_back(&pair.second);
1492   }
1493   absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
1494 
1495   TF_DCHECK_OK(dataflow_analysis->Verify());
1496 
1497   XLA_VLOG_LINES(1, dataflow_analysis->ToString());
1498 
1499   return std::move(dataflow_analysis);
1500 }
1501 
1502 Status HloDataflowAnalysis::Verify() const {
1503   // Verify that each HloValue appears in the value sets indicated by the
1504   // value's positions().
1505   for (const HloValue* value : values()) {
1506     for (const HloPosition& position : value->positions()) {
1507       const HloValueSet& value_set = GetValueSet(position);
1508       TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
1509           << "Value set at position " << position << " does not contain value "
1510           << value->ToShortString();
1511     }
1512   }
1513 
1514   // For each value in each value set, verify that the value set's position
1515   // appears in the value's positions().
1516   for (const auto& computation : module_.computations()) {
1517     for (const auto& instruction : computation->instructions()) {
1518       for (const auto& pair : GetInstructionValueSet(instruction)) {
1519         const ShapeIndex& index = pair.first;
1520         const HloValueSet& value_set = pair.second;
1521         const HloPosition position{instruction, index};
1522         for (const HloValue* value : value_set.values()) {
1523           TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
1524               << "Value set at position " << position
1525               << " unexpectedly contains value " << value->ToShortString();
1526         }
1527       }
1528     }
1529   }
1530 
1531   return Status::OK();
1532 }
1533 
1534 bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
1535     const HloInstruction* operand, const ShapeIndex& index,
1536     const HloInstruction* user) const {
1537   // Return true if no value at 'operand' and 'index' is used at 'user'.
1538   for (const HloValue* value : GetValueSet(operand, index).values()) {
1539     for (const HloUse& use : value->uses()) {
1540       if (use.instruction == user) {
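        // For loop fusions, the use on the fusion instruction may be dead
        // inside the fused computation: the operand buffer is unused iff the
        // value defined at the corresponding fusion parameter has no uses.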
1541         if (user->IsLoopFusion()) {
1542           HloInstruction* fusion_param =
1543               user->fused_parameter(use.operand_number);
1544           const HloValue& value =
1545               GetValueDefinedAt(fusion_param, use.operand_index);
1546           return value.uses().empty();
1547         }
1548         return false;
1549       }
1550     }
1551   }
1552   return true;
1553 }
1554 
1555 /*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) {
1556   return opcode == HloOpcode::kDynamicUpdateSlice ||
1557          opcode == HloOpcode::kScatter;
1558 }
1559 
1560 /*static*/ bool HloDataflowAnalysis::IsAsynchronousOperationStart(
1561     HloOpcode opcode) {
1562   return opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv ||
1563          opcode == HloOpcode::kCopyStart ||
1564          opcode == HloOpcode::kAllReduceStart ||
1565          opcode == HloOpcode::kAllGatherStart ||
1566          opcode == HloOpcode::kCollectivePermuteStart;
1567 }
1568 
1569 /*static*/ bool HloDataflowAnalysis::IsAsynchronousOperationDone(
1570     HloOpcode opcode) {
1571   return opcode == HloOpcode::kSendDone || opcode == HloOpcode::kRecvDone ||
1572          opcode == HloOpcode::kCopyDone ||
1573          opcode == HloOpcode::kAllReduceDone ||
1574          opcode == HloOpcode::kAllGatherDone ||
1575          opcode == HloOpcode::kCollectivePermuteDone;
1576 }
1577 
1578 /*static*/ std::vector<std::pair<HloUse, ShapeIndex>>
1579 HloDataflowAnalysis::GetInPlaceInputOutputPairs(HloInstruction* instruction) {
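  // Each returned pair maps a use of an operand (instruction, operand number,
  // operand shape index) to the output shape index that must alias it. For an
  // in-place op such as kDynamicUpdateSlice, operand 0 at index {} aliases the
  // output at index {}.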
1580   if (IsInPlaceOperation(instruction->opcode())) {
1581     return {{HloUse{instruction, 0, {}}, {}}};
1582   } else if (instruction->opcode() == HloOpcode::kCollectivePermute &&
1583              instruction->operands().size() == 4) {
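    // A four-operand collective-permute writes into operand 1 (the output
    // buffer). If that operand is a tuple, each element {i} aliases output
    // element {i}; otherwise the whole operand aliases the whole output.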
1584     if (instruction->operand(1)->shape().IsTuple()) {
1585       std::vector<std::pair<HloUse, ShapeIndex>> in_place_pairs(
1586           {{HloUse{instruction, 1, {}}, {}}});
1587       for (int i = 0; i < instruction->operand(1)->shape().tuple_shapes_size();
1588            i++) {
1589         in_place_pairs.push_back({HloUse{instruction, 1, {i}}, {i}});
1590       }
1591       return in_place_pairs;
1592     } else {
1593       return {{HloUse{instruction, 1, {}}, {}}};
1594     }
1595   } else if (instruction->opcode() == HloOpcode::kCollectivePermuteStart &&
1596              instruction->operands().size() == 4) {
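    // kCollectivePermuteStart is analogous, except its result is a tuple whose
    // destination buffer sits at output index {1}, so the aliased output
    // indices are prefixed with {1}.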
1597     if (instruction->operand(1)->shape().IsTuple()) {
1598       std::vector<std::pair<HloUse, ShapeIndex>> in_place_pairs(
1599           {{HloUse{instruction, 1, {}}, {1}}});
1600       for (int i = 0; i < instruction->operand(1)->shape().tuple_shapes_size();
1601            i++) {
1602         in_place_pairs.push_back({HloUse{instruction, 1, {i}}, {1, i}});
1603       }
1604       return in_place_pairs;
1605     } else {
1606       return {{HloUse{instruction, 1, {}}, {1}}};
1607     }
1608   } else if (instruction->opcode() != HloOpcode::kFusion) {
1609     return {};
1610   }
1611 
1612   std::vector<std::pair<HloUse, ShapeIndex>> input_output_pairs;
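  // For fusions: walk each leaf of the fusion's output shape down through the
  // fused root's tuple structure. If the HLO producing that leaf is an
  // in-place op, trace its operand 0 back through any get-tuple-elements to a
  // fusion parameter and record the parameter/output aliasing pair.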
1613   for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) {
1614     const HloInstruction* hlo_generating_output =
1615         instruction->fused_expression_root();
1616     for (int64_t i = 0; i < indexed_shape.index.size(); ++i) {
1617       if (hlo_generating_output->opcode() == HloOpcode::kTuple) {
1618         hlo_generating_output =
1619             hlo_generating_output->operand(indexed_shape.index[i]);
1620       } else {
1621         CHECK_EQ(i, indexed_shape.index.size() - 1);
1622       }
1623     }
1624 
1625     if (IsInPlaceOperation(hlo_generating_output->opcode())) {
1626       ShapeIndex operand_index;
1627       const HloInstruction* fusion_parameter =
1628           hlo_generating_output->operand(0);
1629       while (fusion_parameter->opcode() == HloOpcode::kGetTupleElement) {
1630         operand_index.push_front(fusion_parameter->tuple_index());
1631         fusion_parameter = fusion_parameter->operand(0);
1632       }
1633 
1634       if (fusion_parameter->opcode() == HloOpcode::kParameter) {
1635         input_output_pairs.emplace_back(
1636             HloUse{instruction, fusion_parameter->parameter_number(),
1637                    operand_index},
1638             indexed_shape.index);
1639       }
1640     }
1641   }
1642   return input_output_pairs;
1643 }
1644 
1645 bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
1646     HloInstruction* operand, const ShapeIndex& operand_index,
1647     HloInstruction* user, const ShapeIndex& user_index) const {
1648   CHECK(user->IsUserOf(operand))
1649       << "user: " << user->ToString() << " operand: " << operand->ToString();
1650   if (operand->opcode() == HloOpcode::kConstant) {
1651     return false;
1652   }
1653 
1654   const Shape& operand_subshape =
1655       ShapeUtil::GetSubshape(operand->shape(), operand_index);
1656   const Shape& user_subshape =
1657       ShapeUtil::GetSubshape(user->shape(), user_index);
1658   if (IsSliceInputFusion(*user)) {
1659     HloInstruction* fusion_param =
1660         user->fused_parameter(user->operand_index(operand));
1661     // We don't require the same dimensions, only the same number of elements
1662     // and the same element type (so that the buffer sizes match).
1663     return operand_subshape.IsArray() && user_subshape.IsArray() &&
1664            ShapeUtil::ElementsIn(operand_subshape) ==
1665                ShapeUtil::ElementsIn(user_subshape) &&
1666            ShapeUtil::SameElementType(operand_subshape, user_subshape) &&
1667            AreTransitiveUsesEffectivelyElementwise(
1668                fusion_param, user->fused_expression_root(), user_index);
1669   }
1670 
1671   // Check that operand and user emit the same shape and layout.
1672   if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
1673     return false;
1674   }
1675 
1676   // In-place operations (DUS and DUS fusions) establish a must-alias
1677   // relationship, so return true for them regardless of the backend.
1678   for (const auto& operand_and_output_index :
1679        GetInPlaceInputOutputPairs(user)) {
1680     if (operand_and_output_index.second != user_index) {
1681       continue;
1682     }
1683     for (const HloUse& use : GetUniqueValueAt(operand, operand_index).uses()) {
1684       if (use == operand_and_output_index.first) {
1685         return true;
1686       }
1687     }
1688   }
1689 
1690   if (can_share_buffer_ != nullptr) {
1691     if (absl::optional<bool> hint =
1692             can_share_buffer_(user, operand, user_index)) {
1693       return *hint;
1694     }
1695   }
1696 
1697   if (user->opcode() == HloOpcode::kFusion) {
1698     HloInstruction* fusion_param =
1699         user->fused_parameter(user->operand_index(operand));
1700     const HloValue& fusion_param_value =
1701         GetValueDefinedAt(fusion_param, operand_index);
1702 
1703     if (user->IsLoopFusion() || user->IsInputFusion()) {
1704       return AreTransitiveUsesElementwiseOrTuple(fusion_param);
1705     }
1706 
1707     if (user->IsOutputFusion() &&
1708         user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
1709       // Output fusion with kAdd fused root.
1710 
1711       // Check if one operand of kAdd fused root is kDot or kConvolution.
1712       auto* add = user->fused_expression_root();
1713       auto add_operand_it =
1714           absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
1715             return operand->opcode() == HloOpcode::kConvolution ||
1716                    operand->opcode() == HloOpcode::kDot;
1717           });
1718       if (add_operand_it == add->operands().end()) {
1719         return false;
1720       }
1721       auto* matched_add_operand = *add_operand_it;
1722       // Calculate operand index of 'add' operand which was not matched above.
1723       const int64_t other_add_operand_index =
1724           matched_add_operand == add->operand(0) ? 1 : 0;
1725       // Returns true iff there is exactly one use of 'operand' at shape index
1726       // 'operand_index', and this singleton use is the fused root (at operand
1727       // index 'other_add_operand_index').
1728       if (fusion_param_value.uses().size() == 1) {
1729         const HloUse& use = fusion_param_value.uses()[0];
1730         return use.instruction == user->fused_expression_root() &&
1731                use.operand_number == other_add_operand_index;
1732       }
1733       return false;
1734     }
1735   }
1736 
1737   if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
1738       user->opcode() == HloOpcode::kScatter ||
1739       user->opcode() == HloOpcode::kTriangularSolve ||
1740       user->opcode() == HloOpcode::kWhile) {
1741     // We eliminated other users in HloOrdering::LiveRangeStrictlyBefore,
1742     // so here we just need to check that the use is at the right operand index.
1743     const auto operand_indices = user->OperandIndices(operand);
1744     int64_t operand_no = user->opcode() == HloOpcode::kTriangularSolve ? 1 : 0;
1745     return operand_indices.size() == 1 && operand_indices[0] == operand_no;
1746   }
1747   if (user->opcode() == HloOpcode::kSort) {
1748     // Only valid if there are no other users.
1749     if (operand->users().size() != 1) {
1750       return false;
1751     }
1752     // If we only sort keys, the output of sort is not a tuple, so we can always
1753     // share the buffer.
1754     if (user->operand_count() == 1) {
1755       return true;
1756     }
1757     CHECK(!user_index.empty());
1758     // Only share with the right tuple element buffer.
1759     const auto operand_indices = user->OperandIndices(operand);
1760     return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
1761   }
1762   if (user->opcode() == HloOpcode::kCall) {
1763     // Get all uses of value defined by 'operand' at 'operand_index'.
1764     const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
1765     // Return true iff:
1766     // *) There are exactly two uses of 'operand'.
1767     // *) One use is by 'user' (caller).
1768     // *) One use is by the root instruction of the called computation (callee root).
1769     //    (Note: we check the root of the called computation, because the
1770     //     root result buffer is required to alias with the Call result buffer).
1771     // *) The root instruction of the called computation is element-wise on
1772     //    'operand'.
1773     const bool found_caller_use =
1774         absl::c_find_if(uses, [user](const HloUse& use) {
1775           return use.instruction == user;
1776         }) != uses.end();
1777     auto* callee_root = user->to_apply()->root_instruction();
1778     const bool found_elementwise_callee_use =
1779         absl::c_find_if(uses, [callee_root](const HloUse& use) {
1780           return use.instruction == callee_root &&
1781                  callee_root->IsElementwiseOnOperand(use.operand_number);
1782         }) != uses.end();
1783     return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
1784   }
1785 
1786   // Loop fusions that contain transposing copies won't reach here: their layouts
1787   // differ, which fails the shape-equality check at the beginning of this function.
1788   return user->IsElementwiseOnOperand(user->operand_index(operand));
1789 }
1790 
1791 }  // namespace xla
1792