• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <optional>
21 #include <queue>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/container/inlined_vector.h"
30 #include "absl/strings/str_cat.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
36 #include "tensorflow/compiler/xla/service/hlo_module.h"
37 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/platform/logging.h"
44 
45 namespace xla {
46 namespace {
// CalculatePostOrderSchedule traverses a module and assigns an ordinal to each
// instruction based on its postorder dependency.
CalculatePostOrderScheduleHelper(const HloComputation * comp,int64_t start_ordinal,absl::flat_hash_map<HloInstruction *,int64_t> * ordinal_map)49 int64_t CalculatePostOrderScheduleHelper(
50     const HloComputation* comp, int64_t start_ordinal,
51     absl::flat_hash_map<HloInstruction*, int64_t>* ordinal_map) {
52   int64_t ordinal = start_ordinal;
53   for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
54     if (instruction->opcode() == HloOpcode::kCall ||
55         instruction->opcode() == HloOpcode::kConditional) {
56       for (const HloComputation* called_computation :
57            instruction->called_computations()) {
58         ordinal = CalculatePostOrderScheduleHelper(called_computation, ordinal,
59                                                    ordinal_map);
60       }
61     }
62     if (instruction->opcode() == HloOpcode::kWhile) {
63       ordinal = CalculatePostOrderScheduleHelper(instruction->while_condition(),
64                                                  ordinal, ordinal_map);
65       ordinal = CalculatePostOrderScheduleHelper(instruction->while_body(),
66                                                  ordinal, ordinal_map);
67     }
68     // It's possible that in some unit tests the computation graph is not
69     // flatten (meaning we could have multiple callers for one computation). In
70     // that case the oridinal_map will see the instruction multiple times. We
71     // consider that case to be ok as it only shows up in unit tests.
72     ordinal_map->insert({instruction, ordinal++});
73   }
74   return ordinal;
75 }
76 
CalculatePostOrderSchedule(const HloModule & module)77 absl::flat_hash_map<HloInstruction*, int64_t> CalculatePostOrderSchedule(
78     const HloModule& module) {
79   absl::flat_hash_map<HloInstruction*, int64_t> map;
80   CalculatePostOrderScheduleHelper(module.entry_computation(), 0, &map);
81   return map;
82 }
83 
84 }  // namespace
85 using absl::StrAppend;
86 using absl::StrCat;
87 
// Constructs the analysis for `module`. `ssa_form` requests phi values at
// dataflow join points (see Phi()); `bitcast_defines_value` controls whether
// a kBitcast defines a new HloValue or forwards its operand's values;
// `can_share_buffer` is a backend-specific override for buffer-sharing
// decisions. The call graph is built eagerly here.
HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
                                         bool bitcast_defines_value,
                                         const CanShareBuffer& can_share_buffer)
    : module_(module),
      ssa_form_(ssa_form),
      bitcast_defines_value_(bitcast_defines_value),
      call_graph_(CallGraph::Build(&module)),
      can_share_buffer_(can_share_buffer) {}
96 
AreTransitiveUsesElementwiseOrTuple(const HloInstruction * inst)97 bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
98     const HloInstruction* inst) {
99   absl::flat_hash_set<const HloInstruction*> visited;
100   absl::InlinedVector<const HloInstruction*, 4> stack;
101   stack.push_back(inst);
102   while (!stack.empty()) {
103     const HloInstruction* current = stack.back();
104     stack.pop_back();
105     visited.insert(current);
106     for (const HloInstruction* user : current->users()) {
107       // Found a user that is non-elementwise on current instruction.
108       for (const int64_t use_index : user->OperandIndices(current)) {
109         if (!user->IsElementwiseOnOperand(use_index) &&
110             user->opcode() != HloOpcode::kTuple) {
111           return false;
112         }
113       }
114       if (!visited.contains(user)) {
115         stack.push_back(user);
116       }
117     }
118   }
119   return true;
120 }
121 
122 namespace {
Is1dSliceWithoutStrides(const HloInstruction * instr)123 bool Is1dSliceWithoutStrides(const HloInstruction* instr) {
124   return instr->opcode() == HloOpcode::kSlice &&
125          1 == instr->slice_starts().size() &&
126          1 == instr->slice_limits().size() &&
127          1 == instr->slice_strides().size() &&
128          1 == instr->slice_strides().at(0);
129 }
130 
IsSliceInputFusion(const HloInstruction & unnested_hlo)131 bool IsSliceInputFusion(const HloInstruction& unnested_hlo) {
132   if (!unnested_hlo.IsInputFusion()) {
133     return false;
134   }
135   const HloInstruction* root = unnested_hlo.fused_expression_root();
136   if (root->opcode() != HloOpcode::kTuple) {
137     return false;
138   }
139   return absl::c_all_of(root->operands(), [](const HloInstruction* instr) {
140     return Is1dSliceWithoutStrides(instr);
141   });
142 }
143 
// Bookkeeping for the concat-followed-by-slices pattern seen while walking a
// use chain; produced and consumed by ConcatIsEffectivelyElementwise.
struct ConcatUsageInfo {
  // Pointer to a previously seen concat. nullptr if no previously seen concat.
  const HloInstruction* prev_concat;
  // The operand index of the seen concat through which the walk entered it.
  int64_t concat_opnd_idx;
  // The slice that recovers the operand in the concat's outputs.
  const HloInstruction* slice_to_recover_opnd;
};
152 
// Returns an optional concat usage info to denote whether the concat is used
// in an elementwise manner. A concat followed by slices is considered
// effectively elementwise if the slices, taken together, form a reverse
// function of the concat.
std::optional<ConcatUsageInfo> ConcatIsEffectivelyElementwise(
    const HloInstruction& concat, const HloInstruction& operand,
    const ConcatUsageInfo& info) {
  // First, check if this concat is in the below pattern. Also, we check
  // that the slices, taken together, are in effect a reverse function of the
  // concat.
  //
  //     Concat
  //     |    |
  //     v    v
  //   Slice Slice
  //
  std::vector<HloInstruction*> users = concat.users();
  if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) {
    // Limit our supported cases to 1 dimensional slices.
    return std::optional<ConcatUsageInfo>();
  }
  // Verify that each operand to the concat is reversed by a slice: the slice
  // count must match the operand count, and operands must be distinct
  // (duplicated operands would break the positional matching below).
  if (users.size() != concat.operand_count() ||
      concat.operand_count() != concat.unique_operands().size()) {
    return std::optional<ConcatUsageInfo>();
  }
  // Sort slices by start offset so they can be matched positionally against
  // the concat operands.
  absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) {
    return a->slice_starts().at(0) < b->slice_starts().at(0);
  });
  // The i-th slice must start exactly where the previous one ended and cover
  // exactly the element count of the i-th concat operand.
  int64_t prev_limit = 0;
  for (int64_t i = 0; i < users.size(); ++i) {
    const HloInstruction* u = users[i];
    int64_t slice_size = u->slice_limits().at(0) - u->slice_starts().at(0);
    if (u->slice_starts().at(0) != prev_limit ||
        slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) {
      return std::optional<ConcatUsageInfo>();
    }
    prev_limit = u->slice_limits().at(0);
  }

  // If we have seen other concats, make sure they are identical. Multiple
  // concats exist because horizontal fusion inserts one concat for each output
  // of the fusion candidates. Check that all concats and operand ids are the
  // same to know that the "transitive use closure" will be computed in the
  // same iteration space.
  int64_t operand_idx = concat.operand_index(&operand);
  if (info.prev_concat != nullptr) {
    bool is_concat_identical = info.prev_concat->Identical(
        concat,
        /*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
          // Operands don't need to be the same.
          return true;
        });
    if (!is_concat_identical || info.concat_opnd_idx != operand_idx) {
      return std::optional<ConcatUsageInfo>();
    }
  }

  // The slice matched (positionally) to this operand is the one the caller
  // should continue traversing through.
  const HloInstruction* slice_to_recover_opnd = users.at(operand_idx);
  return std::optional<ConcatUsageInfo>(
      ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd});
}
213 
// Returns whether we can prove the transitive uses of `param` are in effect
// elementwise. In other words, we prove that the "transitive use closure" will
// all be computed in the same iteration space without any reorder of elements.
// In addition, we check that the "transitive use closure" includes the output
// in the `root_tuple`.
// Theoretically, we can prove more patterns but our primary use case is
// SliceInputFusion.
bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param,
                                             const HloInstruction* root_tuple,
                                             const ShapeIndex& out_shape_idx) {
  CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
  CHECK_EQ(out_shape_idx.size(), 1);
  absl::flat_hash_set<const HloInstruction*> visited;
  absl::InlinedVector<const HloInstruction*, 4> stack;
  stack.push_back(param);
  // Tracks the concat/slice pattern across the traversal; starts empty.
  ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr};
  bool is_output_reachable = false;
  while (!stack.empty()) {
    const HloInstruction* current = stack.back();
    stack.pop_back();
    visited.insert(current);
    for (const HloInstruction* user : current->users()) {
      VLOG(3) << "Visiting: " << user->ToString();
      switch (user->opcode()) {
        case HloOpcode::kTuple:
          if (user == root_tuple &&
              current == root_tuple->operand(out_shape_idx.back())) {
            // We need to know if the output is reachable by the `param` to make
            // sure that they will be computed in the same iteration space.
            is_output_reachable = true;
          }
          break;
        case HloOpcode::kReshape:
          // Only bitcast-equivalent reshapes preserve the iteration space.
          if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) {
            return false;
          }
          break;
        case HloOpcode::kConcatenate: {
          std::optional<ConcatUsageInfo> optional_concat_info =
              ConcatIsEffectivelyElementwise(*user, *current,
                                             concat_usage_info);
          if (!optional_concat_info) {
            return false;
          }
          concat_usage_info = *optional_concat_info;
          // Early continue as we only want to traverse through the slice that
          // recovers the operand. It is guaranteed that the operand to the
          // concat and the slice have the same iteration space. Insert the
          // slice instead of the concat.
          CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd));
          stack.push_back(concat_usage_info.slice_to_recover_opnd);
          continue;
        }
        default:
          for (const int64_t use_index : user->OperandIndices(current)) {
            if (!user->IsElementwiseOnOperand(use_index)) {
              // Found a user that is non-elementwise on the current
              // instruction.
              return false;
            }
          }
          if (!LayoutUtil::Equal(current->shape().layout(),
                                 user->shape().layout())) {
            // Make sure the layout is not changed by the elementwise op.
            return false;
          }
          break;
      }  // end of switch
      if (!visited.contains(user)) {
        stack.push_back(user);
      }
    }
  }
  // Elementwise-ness alone is not enough: the walk must also have reached the
  // requested output of `root_tuple`.
  return is_output_reachable;
}
289 }  // namespace
290 
ValueIsDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const291 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
292                                            const ShapeIndex& index) const {
293   const HloValueSet& value_set = GetValueSet(instruction, index);
294   if (value_set.values().size() != 1) {
295     return false;
296   }
297   return value_set.GetUniqueValue().defining_instruction() == instruction;
298 }
299 
// Returns the unique HloValue defined at (instruction, index). CHECK-fails if
// no value is defined there; callers should verify with ValueIsDefinedAt.
const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
  return GetUniqueValueAt(instruction, index);
}
305 
// Mutable overload of the above; same precondition (the value must be defined
// at this position).
HloValue& HloDataflowAnalysis::GetValueDefinedAt(
    const HloInstruction* instruction, const ShapeIndex& index) {
  CHECK(ValueIsDefinedAt(instruction, index));
  return GetUniqueValueAt(instruction, index);
}
311 
NewHloValue(HloInstruction * instruction,const ShapeIndex & index,bool is_phi)312 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
313                                            const ShapeIndex& index,
314                                            bool is_phi) {
315   const int64_t value_id = next_value_id_++;
316   auto result =
317       values_.insert({value_id, std::make_unique<HloValue>(
318                                     value_id, instruction, index, is_phi)});
319   CHECK(result.second);
320 
321   VLOG(4) << "NewHloValue = " << result.first->second->ToShortString();
322 
323   return result.first->second.get();
324 }
325 
MarkValueForDeletion(HloValue::Id value_id)326 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
327   const HloValue& value = *values_.at(value_id);
328   VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
329 
330   value_ids_to_delete_.push_back(value_id);
331 }
332 
DeleteMarkedValues()333 void HloDataflowAnalysis::DeleteMarkedValues() {
334   // Use a set to prevent deleting an id twice.
335   absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
336                                            value_ids_to_delete_.end());
337 #ifndef NDEBUG
338   // Verify that no marked-for-deletion values are in any of the value sets.
339   for (const auto& pair : value_sets_) {
340     const HloInstruction* instruction = pair.first;
341     const InstructionValueSet& instruction_value_set = *pair.second;
342     for (const auto& index_value_set : instruction_value_set) {
343       const HloValueSet& value_set = index_value_set.second;
344       for (const HloValue* value : value_set.values()) {
345         DCHECK(!ContainsKey(id_set, value->id()))
346             << "Value " << value->ToShortString()
347             << " marked for deletion, but still exists in value set for "
348                "instruction "
349             << instruction->name();
350       }
351     }
352   }
353 #endif
354 
355   for (HloValue::Id value_id : id_set) {
356     values_.erase(value_id);
357   }
358   value_ids_to_delete_.clear();
359 }
360 
// Returns a human-readable dump of the analysis: for each instruction, the
// value set at every shape index (with " (def)" marking positions where the
// value is defined), followed by a listing of all HloValues.
std::string HloDataflowAnalysis::ToString() const {
  std::string out =
      StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
  StrAppend(&out, "  Instruction value sets:\n");
  for (const HloComputation* computation : module_.computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      StrAppend(&out, "Instruction: \n  ", instruction->name(), ":\n");
      if (instruction->shape().IsTuple()) {
        // Tuple-shaped output: one section per tuple index.
        GetInstructionValueSet(instruction)
            .ForEachElement([this, &instruction, &out](
                                const ShapeIndex& index,
                                const HloValueSet& value_set) {
              StrAppend(&out, "      tuple index ", index.ToString(), ":\n");
              for (const HloValue* value : value_set.values()) {
                StrAppend(&out, "        ", value->ToShortString(),
                          ValueIsDefinedAt(instruction, index) ? " (def)" : "",
                          "\n");
              }
            });
      } else {
        // Non-tuple output: only the top-level (empty) shape index exists.
        const HloValueSet& top_level_value_set =
            GetValueSet(instruction, /*index=*/{});
        for (const HloValue* value : top_level_value_set.values()) {
          StrAppend(&out, "      ", value->ToShortString(),
                    ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
        }
      }
    }
  }
  StrAppend(&out, "  HloValues:\n");
  for (const HloValue* value : values()) {
    StrAppend(&out, value->ToString(/*indent=*/4));
  }
  return out;
}
396 
// Merges the value sets of `inputs` into `instruction`'s value set,
// introducing (or updating) phi values wherever multiple distinct values
// join. Only valid in SSA form. Returns true if any value set changed.
bool HloDataflowAnalysis::Phi(
    HloInstruction* instruction,
    absl::Span<const InstructionValueSet* const> inputs) {
  CHECK(ssa_form_);
  VLOG(4) << "Phi(" << instruction->name() << ")";
  VLOG(5) << "instruction value set = "
          << GetInstructionValueSet(instruction).ToString();
  for (const InstructionValueSet* input : inputs) {
    VLOG(5) << "input value set = " << input->ToString();
  }

  if (bitcast_defines_value_) {
    // Inputs must be shape-compatible with the instruction.
    absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
      DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
    });
  } else {
    // When bitcasts forward values, shapes can differ in layout/dimensions;
    // only require matching element type and (for arrays) element count.
    const Shape& shape = instruction->shape();
    PrimitiveType ty = shape.element_type();
    bool is_array = shape.IsArray();
    absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
      DCHECK(ty == input->shape().element_type() &&
             (!is_array || ShapeUtil::ElementsIn(shape) ==
                               ShapeUtil::ElementsIn(input->shape())));
    });
  }

  bool changed = false;
  for (auto& pair : GetInstructionValueSet(instruction)) {
    const ShapeIndex& index = pair.first;
    HloValueSet& value_set = pair.second;

    // Positions with phi values should never have more than one value in the
    // value set.
    CHECK_LE(value_set.values().size(), 1);
    const HloValue* current_value =
        value_set.values().size() == 1 ? value_set.values()[0] : nullptr;

    // Construct a vector of value IDs of the inputs.
    std::vector<HloValue::Id> input_value_ids;
    for (const InstructionValueSet* input : inputs) {
      for (const HloValue* value : input->element(index).values()) {
        input_value_ids.push_back(value->id());
      }
    }

    // Remove the existing phi value (if it exists). The phi can be its own
    // input, for example, in while body parameters where the body passes
    // through the parameter value.
    bool current_value_defined_here =
        (current_value != nullptr &&
         current_value->defining_instruction() == instruction &&
         current_value->defining_index() == index);

    VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
    if (input_value_ids.empty()) {
      // A value set which has at least one element should never have its value
      // set reduced to zero elements. During dataflow value sets only can go
      // from empty to non-empty, not the reverse.
      CHECK_EQ(value_set.values().size(), 0)
          << "Instruction " << instruction->name() << " at index " << index
          << " previously had non-empty value set. Value set: " << value_set;
    } else if (input_value_ids.size() == 1) {
      // Only a single value reaches this point. There should be no phi, and
      // this value set should contain this single value.
      const HloValue& new_value = GetValue(input_value_ids[0]);
      if (current_value == nullptr) {
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      } else if (current_value != &new_value) {
        if (current_value_defined_here) {
          // Remove the existing phi.
          MarkValueForDeletion(current_value->id());
        }
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      }
    } else {
      // Multiple distinct values reach this point. A phi value is
      // necessary.
      CHECK_GT(input_value_ids.size(), 1);
      bool phi_defined_here =
          current_value_defined_here && current_value->is_phi();
      if (current_value == nullptr || !phi_defined_here) {
        // No phi exists here yet: create one and record its inputs.
        value_set.Clear();
        value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));

        std::vector<HloValue*> inputs;
        inputs.reserve(input_value_ids.size());
        for (HloValue::Id id : input_value_ids) {
          inputs.push_back(&GetValue(id));
        }
        // Register the phi into phi graph.
        phi_graph_.RegisterPhi(*value_set.values()[0], inputs);
        changed = true;
      } else if (phi_defined_here) {
        // A phi already exists; refresh its inputs only if they changed.
        std::vector<HloValue::Id> input_value_ids_for_inputs;
        std::vector<HloValue*> new_inputs;
        new_inputs.reserve(input_value_ids.size());
        for (HloValue::Id id : input_value_ids) {
          new_inputs.push_back(&GetValue(id));
        }

        if (!phi_graph_.InputsEqualTo(*current_value, new_inputs)) {
          VLOG(1) << current_value->ToShortString() << " has new phi inputs: ";
          // Update phi inputs.
          phi_graph_.RegisterPhi(*current_value, new_inputs);
          changed = true;
        }
      }
    }
  }
  return changed;
}
511 
// Returns the HloValue with the given id; the id must exist in values_.
const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
  return *values_.at(value_id);
}
515 
// Mutable overload of the above; the id must exist in values_.
HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
  return *values_.at(value_id);
}
519 
GetFlattenedValueSet(const HloInstruction * instruction) const520 HloValueSet HloDataflowAnalysis::GetFlattenedValueSet(
521     const HloInstruction* instruction) const {
522   HloValueSet value_set;
523 
524   const InstructionValueSet& value_set_tree =
525       GetInstructionValueSet(instruction);
526 
527   std::vector<const HloValueSet*> all_sets;
528   for (auto& pair : value_set_tree) {
529     const HloValueSet& value_set = pair.second;
530     all_sets.push_back(&value_set);
531   }
532   value_set.AssignUnionOf(all_sets);
533 
534   return value_set;
535 }
536 
// Returns the value set at (instruction, index), read-only.
const HloValueSet& HloDataflowAnalysis::GetValueSet(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  return GetInstructionValueSet(instruction).element(index);
}
541 
// Returns the value set at (instruction, index), mutable.
HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
                                              const ShapeIndex& index) {
  return *GetInstructionValueSet(instruction).mutable_element(index);
}
546 
// Convenience overload: looks up the value set for an HloPosition.
const HloValueSet& HloDataflowAnalysis::GetValueSet(
    const HloPosition& position) const {
  return GetValueSet(position.instruction, position.index);
}
551 
// Mutable convenience overload: looks up the value set for an HloPosition.
HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
  return GetValueSet(position.instruction, position.index);
}
555 
UpdateBitcastValueSet(HloInstruction * bitcast)556 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
557   CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
558   const InstructionValueSet& operand_set =
559       GetInstructionValueSet(bitcast->operand(0));
560   InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
561   if (!bitcast_defines_value_ && operand_set != bitcast_set) {
562     bitcast_set = operand_set;
563     return true;
564   }
565   return false;
566 }
567 
UpdateSetDimensionSizeValueSet(HloInstruction * set_dimension_size)568 bool HloDataflowAnalysis::UpdateSetDimensionSizeValueSet(
569     HloInstruction* set_dimension_size) {
570   CHECK_EQ(set_dimension_size->opcode(), HloOpcode::kSetDimensionSize);
571   const InstructionValueSet& operand_set =
572       GetInstructionValueSet(set_dimension_size->operand(0));
573   InstructionValueSet& set_dimension_size_set =
574       GetInstructionValueSet(set_dimension_size);
575   if (operand_set != set_dimension_size_set) {
576     set_dimension_size_set = operand_set;
577     return true;
578   }
579   return false;
580 }
581 
UpdateSendValueSet(HloInstruction * send)582 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
583   CHECK_EQ(send->opcode(), HloOpcode::kSend);
584   bool changed = false;
585   // Send forwards the operand value to the output tuple at {0}.
586   for (auto& pair : GetInstructionValueSet(send->operand(0))) {
587     const ShapeIndex& operand_index = pair.first;
588     const HloValueSet& operand_value_set = pair.second;
589 
590     ShapeIndex index = {0};
591     for (int64_t i : operand_index) {
592       index.push_back(i);
593     }
594 
595     HloValueSet& value_set = GetValueSet(send, index);
596     if (value_set != operand_value_set) {
597       value_set = operand_value_set;
598       changed = true;
599     }
600   }
601   return changed;
602 }
603 
// Recomputes the value set of a kAsyncStart: operand i's values flow to
// output position {0, i, ...}, and the wrapped computation root's values flow
// to output position {1, ...}. Returns true if any value set changed.
bool HloDataflowAnalysis::UpdateAsyncStartValueSet(
    HloInstruction* async_start) {
  CHECK_EQ(async_start->opcode(), HloOpcode::kAsyncStart);
  bool changed = false;
  // AsyncStart forwards the operand values to element {0} of its output.
  for (int64_t i = 0; i < async_start->operand_count(); ++i) {
    const HloInstruction* operand = async_start->operand(i);
    ShapeUtil::ForEachSubshape(
        operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
          // Only array leaves carry values; skip interior tuple nodes.
          if (!subshape.IsArray()) {
            return;
          }
          const HloValueSet& operand_value_set = GetValueSet(operand, index);

          // Output position for this leaf: {0, i} followed by the operand's
          // own sub-index.
          ShapeIndex output_index = {0, i};
          output_index.insert(output_index.end(), index.begin(), index.end());

          HloValueSet& value_set = GetValueSet(async_start, output_index);
          if (value_set != operand_value_set) {
            value_set = operand_value_set;
            changed = true;
          }
        });
  }
  // AsyncStart forwards the async wrapped computation root values to element
  // {1} of its output.
  HloInstruction* root =
      async_start->async_wrapped_computation()->root_instruction();
  ShapeUtil::ForEachSubshape(
      root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (!subshape.IsArray()) {
          return;
        }
        const HloValueSet& root_value_set = GetValueSet(root, index);

        // Output position: {1} followed by the root's own sub-index.
        ShapeIndex output_index = {1};
        output_index.insert(output_index.end(), index.begin(), index.end());

        HloValueSet& value_set = GetValueSet(async_start, output_index);
        if (value_set != root_value_set) {
          value_set = root_value_set;
          changed = true;
        }
      });
  return changed;
}
650 
UpdateAsyncUpdateValueSet(HloInstruction * async_update)651 bool HloDataflowAnalysis::UpdateAsyncUpdateValueSet(
652     HloInstruction* async_update) {
653   CHECK_EQ(async_update->opcode(), HloOpcode::kAsyncUpdate);
654   CHECK_EQ(async_update->shape(), async_update->operand(0)->shape());
655   bool changed = false;
656   HloInstruction* root =
657       async_update->async_wrapped_computation()->root_instruction();
658   // AsyncUpdate forwards all of the operand values to corresponding elements of
659   // its output.
660   ShapeUtil::ForEachSubshape(
661       async_update->operand(0)->shape(),
662       [&](const Shape& subshape, const ShapeIndex& index) {
663         if (!subshape.IsArray()) {
664           return;
665         }
666         const HloValueSet& operand_value_set =
667             GetValueSet(async_update->operand(0), index);
668 
669         HloValueSet& value_set = GetValueSet(async_update, index);
670         CHECK_GE(index.size(), 0);
671         if (index[0] != 1) {
672           if (value_set != operand_value_set) {
673             value_set = operand_value_set;
674             changed = true;
675           }
676         } else {
677           // If this subshape is an output (index {1}), we need to create the
678           // union with the async wrapped computation root.
679           ShapeIndex root_index(index.begin() + 1, index.end());
680           const HloValueSet& root_value_set = GetValueSet(root, root_index);
681           changed |=
682               value_set.AssignUnionOf({&operand_value_set, &root_value_set});
683         }
684       });
685   return changed;
686 }
687 
// Recomputes the value set of a kAsyncDone: for each array leaf under operand
// tuple index {1}, the output position (with the leading {1} stripped) takes
// the union of the operand's values and the wrapped computation root's
// values. Returns true if any value set changed.
bool HloDataflowAnalysis::UpdateAsyncDoneValueSet(HloInstruction* async_done) {
  CHECK_EQ(async_done->opcode(), HloOpcode::kAsyncDone);
  bool changed = false;
  HloInstruction* root =
      async_done->async_wrapped_computation()->root_instruction();
  // AsyncDone creates a union of the operand values at {1} and the async
  // wrapped computation root to element {} of its output.
  ShapeUtil::ForEachSubshape(
      async_done->operand(0)->shape(),
      [&](const Shape& subshape, const ShapeIndex& index) {
        // Skip non-array nodes and everything outside tuple index {1}.
        if (!subshape.IsArray() || index.front() != 1) {
          return;
        }
        const HloValueSet& operand_value_set =
            GetValueSet(async_done->operand(0), index);

        // The output position is the operand position minus the leading {1}.
        ShapeIndex output_index(index.begin() + 1, index.end());
        HloValueSet& value_set = GetValueSet(async_done, output_index);
        const HloValueSet& root_value_set = GetValueSet(root, output_index);
        changed |=
            value_set.AssignUnionOf({&operand_value_set, &root_value_set});
      });
  return changed;
}
712 
UpdateCopyStartValueSet(HloInstruction * copy_start)713 bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
714   CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
715   bool changed = false;
716   // CopyStart forwards the operand value to element {1} of its output.
717   const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0));
718   HloValueSet& value_set = GetValueSet(copy_start, {1});
719   if (value_set != operand_value_set) {
720     value_set = operand_value_set;
721     changed = true;
722   }
723   return changed;
724 }
725 
UpdateCopyDoneValueSet(HloInstruction * copy_done)726 bool HloDataflowAnalysis::UpdateCopyDoneValueSet(HloInstruction* copy_done) {
727   CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone);
728   bool changed = false;
729   // CopyDone forwards the operand value at {0} to element {} of its output.
730   const HloValueSet& operand_value_set =
731       GetValueSet(copy_done->operand(0), {0});
732   HloValueSet& value_set = GetValueSet(copy_done);
733   if (value_set != operand_value_set) {
734     value_set = operand_value_set;
735     changed = true;
736   }
737   return changed;
738 }
739 
UpdateRecvDoneValueSet(HloInstruction * recv_done)740 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
741   CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
742   bool changed = false;
743   // RecvDone forwards the operand value at {0} to element {0} of its output.
744   for (auto& pair : GetInstructionValueSet(recv_done)) {
745     ShapeIndex& index = pair.first;
746     HloValueSet& value_set = pair.second;
747 
748     if (index.empty() || index[0] != 0) {
749       continue;
750     }
751 
752     const HloValueSet& operand_value_set =
753         GetValueSet(recv_done->operand(0), index);
754     if (value_set != operand_value_set) {
755       value_set = operand_value_set;
756       changed = true;
757     }
758   }
759   return changed;
760 }
761 
UpdateCallValueSet(HloInstruction * call)762 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
763   CHECK_EQ(call->opcode(), HloOpcode::kCall);
764   InstructionValueSet& value_set = GetInstructionValueSet(call);
765   InstructionValueSet& root_value_set =
766       GetInstructionValueSet(call->to_apply()->root_instruction());
767   if (value_set != root_value_set) {
768     value_set = root_value_set;
769     return true;
770   }
771   return false;
772 }
773 
UpdateConditionalValueSet(HloInstruction * conditional)774 bool HloDataflowAnalysis::UpdateConditionalValueSet(
775     HloInstruction* conditional) {
776   CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
777   std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
778   for (int j = 0; j < conditional->branch_count(); ++j) {
779     inputs[j] = &GetInstructionValueSet(
780         conditional->branch_computation(j)->root_instruction());
781   }
782   if (ssa_form_) {
783     return Phi(conditional, inputs);
784   } else {
785     return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
786   }
787 }
788 
UpdateCopyValueSet(HloInstruction * copy)789 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
790   CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
791   bool changed = false;
792   for (auto& pair : GetInstructionValueSet(copy)) {
793     const ShapeIndex& index = pair.first;
794     if (index.empty()) {
795       // kCopy shallow copies and thus defines the top-level value so nothing to
796       // update.
797       continue;
798     }
799 
800     HloValueSet& value_set = pair.second;
801     HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
802     if (value_set != operand_value_set) {
803       value_set = operand_value_set;
804       changed = true;
805     }
806   }
807   return changed;
808 }
809 
UpdateOptimizationBarrierValueSet(HloInstruction * barrier)810 bool HloDataflowAnalysis::UpdateOptimizationBarrierValueSet(
811     HloInstruction* barrier) {
812   // Optimization Barriers just forward their operand. Given that barriers can
813   // have a tuple operand, we iterate through its indexes, like for copies.
814   // Unlike copies though we also propagate the top-level value.
815   CHECK_EQ(barrier->opcode(), HloOpcode::kOptimizationBarrier);
816   bool changed = false;
817   for (auto& pair : GetInstructionValueSet(barrier)) {
818     const ShapeIndex& index = pair.first;
819     HloValueSet& value_set = pair.second;
820     HloValueSet& operand_value_set = GetValueSet(barrier->operand(0), index);
821     if (value_set != operand_value_set) {
822       value_set = operand_value_set;
823       changed = true;
824     }
825   }
826   return changed;
827 }
828 
UpdateDomainValueSet(HloInstruction * domain)829 bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
830   // Domain instructions just forward their operand. Given that domains can have
831   // a tuple operand, we iterate through its indexes, like for copies.
832   // Unlike copies though we also propagate the top-level value.
833   CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
834   bool changed = false;
835   for (auto& pair : GetInstructionValueSet(domain)) {
836     const ShapeIndex& index = pair.first;
837     HloValueSet& value_set = pair.second;
838     HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
839     if (value_set != operand_value_set) {
840       value_set = operand_value_set;
841       changed = true;
842     }
843   }
844   return changed;
845 }
846 
UpdateAddDependencyValueSet(HloInstruction * add_dependency)847 bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
848     HloInstruction* add_dependency) {
849   // AddDependency just forwards the value of its zero-th operand.
850   CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
851   const InstructionValueSet& operand_set =
852       GetInstructionValueSet(add_dependency->operand(0));
853   InstructionValueSet& add_dependency_set =
854       GetInstructionValueSet(add_dependency);
855   if (operand_set != add_dependency_set) {
856     add_dependency_set = operand_set;
857     return true;
858   }
859   return false;
860 }
861 
UpdateGetTupleElementValueSet(HloInstruction * gte)862 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
863   CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
864   bool changed = false;
865   // The GetTupleElement instruction forwards the values from the specified
866   // tuple element.
867   for (auto& pair : GetInstructionValueSet(gte)) {
868     const ShapeIndex& index = pair.first;
869     HloValueSet& value_set = pair.second;
870 
871     // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
872     // with the tuple element number prefixed.
873     ShapeIndex operand_index = {gte->tuple_index()};
874     for (int64_t i : index) {
875       operand_index.push_back(i);
876     }
877 
878     HloValueSet& operand_value_set =
879         GetValueSet(gte->operand(0), operand_index);
880     if (value_set != operand_value_set) {
881       value_set = operand_value_set;
882       changed = true;
883     }
884   }
885   return changed;
886 }
887 
// Recomputes the value set of a kParameter instruction from the operands of
// the instructions that call its enclosing computation. Returns true if any
// of the parameter's value sets changed.
bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
  CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
  const CallGraphNode& call_graph_node =
      call_graph_->GetNode(parameter->parent());

  // Subcomputations called in a parallel context (eg, map) do not have dataflow
  // from the caller operands.
  if (call_graph_node.context() == CallContext::kEmbedded ||
      call_graph_node.caller_callsites().empty()) {
    return false;
  }
  CHECK_EQ(call_graph_node.context(), CallContext::kControlFlow);

  // Value sets flowing into this parameter, one or more entries per caller.
  std::vector<const InstructionValueSet*> inputs;
  // True when the inputs can disagree across executions (while/conditional),
  // in which case SSA form needs a phi rather than a plain union.
  bool need_phi = false;
  for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    const HloOpcode& opcode = callsite.instruction()->opcode();
    if (opcode == HloOpcode::kCall) {
      // The operand values of a call instruction are forwarded to the
      // respective parameter instruction of the subcomputation.
      inputs.push_back(&GetInstructionValueSet(
          callsite.instruction()->operand(parameter->parameter_number())));
    } else if (opcode == HloOpcode::kWhile) {
      // In a while instruction, the while operand (ie, the init value) and the
      // backedge are dataflow inputs to the parameter instruction. This is the
      // case for parameters of both the body and condition computations.
      CHECK_EQ(parameter->parameter_number(), 0);
      inputs.push_back(
          &GetInstructionValueSet(callsite.instruction()->operand(0)));
      // If the parameter *is not* the root, parameter state would be
      // updated by the root, otherwise don't consider it's current state
      // (InstructionValueSet) as we are recomputing its current state.
      if (parameter !=
          callsite.instruction()->while_body()->root_instruction()) {
        inputs.push_back(&GetInstructionValueSet(
            callsite.instruction()->while_body()->root_instruction()));
      }
      need_phi = true;
    } else if (opcode == HloOpcode::kConditional) {
      CHECK_EQ(parameter->parameter_number(), 0);
      auto conditional = callsite.instruction();
      // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
      // operands 1 and onward are the arguments to the branch computations.
      //
      // If the parameter belongs to conditional's branch 0 computation, then
      // operand 1 is forwarded to this parameter instruction. If the parameter
      // belongs to conditional's branch 5 computation, then operand 6 is
      // forwarded to this parameter instruction.
      bool found_parent = false;
      for (int j = 0; j < conditional->branch_count(); ++j) {
        if (parameter->parent() == conditional->branch_computation(j)) {
          inputs.push_back(
              &GetInstructionValueSet(conditional->operand(j + 1)));
          found_parent = true;
          break;
        }
      }
      CHECK(found_parent);
      need_phi = true;
    } else if (opcode == HloOpcode::kAsyncStart) {
      // Async-start forwards its caller operands to the wrapped computation's
      // parameters, just like kCall above.
      inputs.push_back(&GetInstructionValueSet(
          callsite.instruction()->operand(parameter->parameter_number())));
    } else if (opcode == HloOpcode::kAsyncUpdate ||
               opcode == HloOpcode::kAsyncDone) {
      // Async-update/done carry the forwarded operands at element {0} of their
      // operand tuple, so this parameter's values are read from index
      // {0, parameter_number}. Note the early return: this caller kind
      // bypasses the phi/union handling at the bottom of the function.
      return GetInstructionValueSet(parameter).AssignUnionOf(
          GetInstructionValueSet(callsite.instruction()->operand(0)),
          {0, parameter->parameter_number()});
    } else {
      LOG(FATAL) << "CallContext::kSequential computations should only be "
                    "called from call, while, or conditional instructions";
    }
  }
  if (ssa_form_ && need_phi) {
    return Phi(parameter, inputs);
  } else {
    return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
  }
}
966 
UpdateTupleValueSet(HloInstruction * tuple)967 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
968   CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
969   bool changed = false;
970   for (int64_t i = 0; i < tuple->operands().size(); ++i) {
971     // Copy the value set(s) of each operand into the respective position in the
972     // kTuple instruction's value sets.
973     for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
974       const ShapeIndex& operand_index = pair.first;
975       HloValueSet& operand_value_set = pair.second;
976 
977       ShapeIndex index = {i};
978       for (int64_t op_index : operand_index) {
979         index.push_back(op_index);
980       }
981       HloValueSet& value_set = GetValueSet(tuple, index);
982 
983       if (value_set != operand_value_set) {
984         value_set = operand_value_set;
985         changed = true;
986       }
987     }
988   }
989   return changed;
990 }
991 
UpdateWhileValueSet(HloInstruction * xla_while)992 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
993   CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
994   const InstructionValueSet* const inputs[] = {
995       &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
996       &GetInstructionValueSet(xla_while->operand(0))};
997   if (ssa_form_) {
998     return Phi(xla_while, inputs);
999   } else {
1000     return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
1001   }
1002 }
1003 
UpdateAllGatherStartValueSet(HloInstruction * all_gather_start)1004 bool HloDataflowAnalysis::UpdateAllGatherStartValueSet(
1005     HloInstruction* all_gather_start) {
1006   CHECK_EQ(all_gather_start->opcode(), HloOpcode::kAllGatherStart);
1007   bool changed = false;
1008   // AllGatherStart forwards the operand values to element {0} of its output.
1009   for (int64_t i = 0; i < all_gather_start->operand_count(); ++i) {
1010     const HloValueSet& operand_value_set =
1011         GetValueSet(all_gather_start->operand(i));
1012 
1013     ShapeIndex output_index = {0};
1014     if (all_gather_start->operand_count() > 1) {
1015       output_index.push_back(i);
1016     }
1017 
1018     HloValueSet& value_set = GetValueSet(all_gather_start, output_index);
1019     if (value_set != operand_value_set) {
1020       value_set = operand_value_set;
1021       changed = true;
1022     }
1023   }
1024   return changed;
1025 }
1026 
UpdateAllGatherDoneValueSet(HloInstruction * all_gather_done)1027 bool HloDataflowAnalysis::UpdateAllGatherDoneValueSet(
1028     HloInstruction* all_gather_done) {
1029   CHECK_EQ(all_gather_done->opcode(), HloOpcode::kAllGatherDone);
1030   bool changed = false;
1031   // AllGatherDone forwards the operand value at {1} to its output.
1032   for (auto& pair : GetInstructionValueSet(all_gather_done)) {
1033     const ShapeIndex& output_index = pair.first;
1034     HloValueSet& value_set = pair.second;
1035 
1036     ShapeIndex operand_index = {1};
1037     for (int64_t i : output_index) {
1038       operand_index.push_back(i);
1039     }
1040 
1041     const HloValueSet& operand_value_set =
1042         GetValueSet(all_gather_done->operand(0), operand_index);
1043     if (value_set != operand_value_set) {
1044       value_set = operand_value_set;
1045       changed = true;
1046     }
1047   }
1048   return changed;
1049 }
1050 
UpdateAllReduceDoneValueSet(HloInstruction * all_reduce_done)1051 bool HloDataflowAnalysis::UpdateAllReduceDoneValueSet(
1052     HloInstruction* all_reduce_done) {
1053   CHECK_EQ(all_reduce_done->opcode(), HloOpcode::kAllReduceDone);
1054   bool changed = false;
1055   // AllReduceDone forwards its only operand.
1056   for (auto& pair : GetInstructionValueSet(all_reduce_done)) {
1057     const ShapeIndex& output_index = pair.first;
1058     HloValueSet& value_set = pair.second;
1059 
1060     ShapeIndex operand_index = {};
1061     for (int64_t i : output_index) {
1062       operand_index.push_back(i);
1063     }
1064 
1065     const HloValueSet& operand_value_set =
1066         GetValueSet(all_reduce_done->operand(0), operand_index);
1067     if (value_set != operand_value_set) {
1068       value_set = operand_value_set;
1069       changed = true;
1070     }
1071   }
1072   return changed;
1073 }
1074 
UpdateCollectivePermuteStartValueSet(HloInstruction * collective_permute_start)1075 bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet(
1076     HloInstruction* collective_permute_start) {
1077   CHECK_EQ(collective_permute_start->opcode(),
1078            HloOpcode::kCollectivePermuteStart);
1079   bool changed = false;
1080   // CollectivePermuteStart forwards the operand value to element {0} of its
1081   // output.
1082   if (collective_permute_start->operand(0)->shape().IsTuple()) {
1083     for (int i = 0; i < ShapeUtil::TupleElementCount(
1084                             collective_permute_start->operand(0)->shape());
1085          ++i) {
1086       const HloValueSet& operand_value_set =
1087           GetValueSet(collective_permute_start->operand(0), {i});
1088       HloValueSet& value_set = GetValueSet(collective_permute_start, {0, i});
1089       if (value_set != operand_value_set) {
1090         value_set = operand_value_set;
1091         changed = true;
1092       }
1093     }
1094   } else {
1095     const HloValueSet& operand_value_set =
1096         GetValueSet(collective_permute_start->operand(0));
1097     HloValueSet& value_set = GetValueSet(collective_permute_start, {0});
1098     if (value_set != operand_value_set) {
1099       value_set = operand_value_set;
1100       changed = true;
1101     }
1102   }
1103   return changed;
1104 }
1105 
UpdateCollectivePermuteDoneValueSet(HloInstruction * collective_permute_done)1106 bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet(
1107     HloInstruction* collective_permute_done) {
1108   CHECK_EQ(collective_permute_done->opcode(),
1109            HloOpcode::kCollectivePermuteDone);
1110   bool changed = false;
1111   // CollectivePermuteDone forwards the operand value at {1} to its output.
1112   if (collective_permute_done->shape().IsTuple()) {
1113     for (int i = 0;
1114          i < ShapeUtil::TupleElementCount(collective_permute_done->shape());
1115          ++i) {
1116       const HloValueSet& operand_value_set =
1117           GetValueSet(collective_permute_done->operand(0), {1, i});
1118       HloValueSet& value_set = GetValueSet(collective_permute_done, {i});
1119       if (value_set != operand_value_set) {
1120         value_set = operand_value_set;
1121         changed = true;
1122       }
1123     }
1124   } else {
1125     const HloValueSet& operand_value_set =
1126         GetValueSet(collective_permute_done->operand(0), {1});
1127     HloValueSet& value_set = GetValueSet(collective_permute_done);
1128     if (value_set != operand_value_set) {
1129       value_set = operand_value_set;
1130       changed = true;
1131     }
1132   }
1133   return changed;
1134 }
1135 
// Recomputes the value sets of `instruction` from its inputs by dispatching
// to the opcode-specific handler. Returns true if any value set changed.
// Opcodes without a handler define all values in their own output, so there
// is nothing to forward and false is returned.
bool HloDataflowAnalysis::UpdateInstructionValueSet(
    HloInstruction* instruction) {
  // Recompute from operands.
  switch (instruction->opcode()) {
    case HloOpcode::kAddDependency:
      return UpdateAddDependencyValueSet(instruction);
    case HloOpcode::kAllGatherStart:
      return UpdateAllGatherStartValueSet(instruction);
    case HloOpcode::kAllGatherDone:
      return UpdateAllGatherDoneValueSet(instruction);
    case HloOpcode::kAsyncStart:
      return UpdateAsyncStartValueSet(instruction);
    case HloOpcode::kAsyncUpdate:
      return UpdateAsyncUpdateValueSet(instruction);
    case HloOpcode::kAsyncDone:
      return UpdateAsyncDoneValueSet(instruction);
    case HloOpcode::kBitcast:
      return UpdateBitcastValueSet(instruction);
    case HloOpcode::kSetDimensionSize:
      return UpdateSetDimensionSizeValueSet(instruction);
    case HloOpcode::kDomain:
      return UpdateDomainValueSet(instruction);
    case HloOpcode::kCopy:
      return UpdateCopyValueSet(instruction);
    case HloOpcode::kGetTupleElement:
      return UpdateGetTupleElementValueSet(instruction);
    case HloOpcode::kTuple:
      return UpdateTupleValueSet(instruction);
    case HloOpcode::kParameter:
      return UpdateParameterValueSet(instruction);
    case HloOpcode::kCall:
      return UpdateCallValueSet(instruction);
    case HloOpcode::kWhile:
      return UpdateWhileValueSet(instruction);
    case HloOpcode::kSend:
      return UpdateSendValueSet(instruction);
    case HloOpcode::kRecvDone:
      return UpdateRecvDoneValueSet(instruction);
    case HloOpcode::kCopyStart:
      return UpdateCopyStartValueSet(instruction);
    case HloOpcode::kCopyDone:
      return UpdateCopyDoneValueSet(instruction);
    case HloOpcode::kConditional:
      return UpdateConditionalValueSet(instruction);
    case HloOpcode::kAllReduceDone:
      return UpdateAllReduceDoneValueSet(instruction);
    case HloOpcode::kCollectivePermuteStart:
      return UpdateCollectivePermuteStartValueSet(instruction);
    case HloOpcode::kCollectivePermuteDone:
      return UpdateCollectivePermuteDoneValueSet(instruction);
    case HloOpcode::kOptimizationBarrier:
      return UpdateOptimizationBarrierValueSet(instruction);
    default:
      // Instruction does not forward HloValues (it defines all values in its
      // output). No update is necessary.
      return false;
  }
}
1194 
// Fixed-point worklist loop: repeatedly recomputes instruction value sets
// until no more changes occur, enqueueing the dependents of any instruction
// whose value set changed.
void HloDataflowAnalysis::Propagate() {
  using Work = std::pair<int64_t, HloInstruction*>;
  // Avoid duplicating work by preferring work items early in the post order
  // schedule. Intuitively, we start from entry parameters and propagate buffers
  // updates throughout the module only once.
  std::priority_queue<Work, std::vector<Work>, std::greater<Work>> worklist;
  // Membership set mirroring `worklist` so each instruction is enqueued at
  // most once at a time.
  absl::flat_hash_set<HloInstruction*> workset;
  auto priority_map = CalculatePostOrderSchedule(module_);
  auto add_to_worklist = [&priority_map, &worklist,
                          &workset](HloInstruction* instruction) {
    if (workset.insert(instruction).second) {
      worklist.emplace(priority_map[instruction], instruction);
    }
  };

  // Seed the worklist with every instruction in the module.
  auto comps = module_.MakeComputationPostOrder();
  for (HloComputation* computation : comps) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      add_to_worklist(instruction);
    }
  }
  VLOG(1) << "SSA_FORM_: " << ssa_form_;

  while (!worklist.empty()) {
    HloInstruction* instruction = worklist.top().second;
    // NOTE: this lambda intentionally shadows the seeding lambda above; it
    // does the same enqueue-if-absent but with extra logging.
    auto add_to_worklist = [&](HloInstruction* todo) {
      if (workset.insert(todo).second) {
        VLOG(1) << "  Adding todo : " << todo->name();
        worklist.emplace(priority_map[todo], todo);
      }
    };
    worklist.pop();

    // Remove from the membership set so the instruction can be re-enqueued if
    // one of its inputs changes later in the iteration.
    workset.erase(workset.find(instruction));

    VLOG(3) << "Worklist top: " << instruction->name();
    XLA_VLOG_LINES(3, ToString());

    if (!UpdateInstructionValueSet(instruction)) {
      // No change to the instruction's value set.
      VLOG(4) << "No change.";
      continue;
    }

    VLOG(4) << "New value set for " << instruction->name() << ": "
            << GetInstructionValueSet(instruction);

    // Instruction value was updated. Add users to work list if we haven't
    // already.
    for (HloInstruction* user : instruction->users()) {
      add_to_worklist(user);

      // If user sequentially calls a computation, then the respective
      // parameter(s) of the computation need to be updated.
      if (user->opcode() == HloOpcode::kConditional) {
        // If operand 0 is the use of instruction, then no parameters need to be
        // updated, since that is the branch_index of the conditional.
        // If operand n+1 is the use of instruction, then the branch_computation
        // n's parameter need to be updated.
        //
        // Note that the same instruction can be used in multiple branches'
        // operands.
        for (int j = 0; j < user->branch_count(); ++j) {
          if (user->operand(j + 1) == instruction) {
            add_to_worklist(
                user->branch_computation(j)->parameter_instruction(0));
          }
        }
      } else if (user->opcode() == HloOpcode::kAsyncUpdate ||
                 user->opcode() == HloOpcode::kAsyncDone) {
        // For async update and async done, we cannot distinguish which
        // parameter needs to be updated so add all to the worklist.
        for (int64_t parameter_number = 0;
             parameter_number <
             user->async_wrapped_computation()->num_parameters();
             ++parameter_number) {
          add_to_worklist(
              user->async_wrapped_computation()->parameter_instruction(
                  parameter_number));
        }
      } else {
        // For other callers (eg, kCall/kWhile), update the parameter matching
        // each operand position at which `instruction` is passed.
        for (HloComputation* called_computation : user->called_computations()) {
          const CallGraphNode& call_graph_node =
              call_graph_->GetNode(called_computation);
          if (call_graph_node.context() == CallContext::kControlFlow) {
            for (int64_t operand_number : user->OperandIndices(instruction)) {
              add_to_worklist(
                  called_computation->parameter_instruction(operand_number));
            }
          }
        }
      }
    }

    // If instruction is a root instruction, then propagate out to any calling
    // instruction and across any while backedge.
    if (instruction == instruction->parent()->root_instruction()) {
      const CallGraphNode& call_graph_node =
          call_graph_->GetNode(instruction->parent());
      for (const CallSite& callsite : call_graph_node.caller_callsites()) {
        if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
          // Add the while itself, and the body and condition parameters.
          add_to_worklist(callsite.instruction());
          add_to_worklist(
              callsite.instruction()->while_body()->parameter_instruction(0));
          add_to_worklist(
              callsite.instruction()->while_condition()->parameter_instruction(
                  0));
        } else if (call_graph_node.context() == CallContext::kControlFlow) {
          add_to_worklist(callsite.instruction());
        }
      }
    }
  }
}
1311 
// Returns the value-set tree for `instruction` (read-only overload). The
// instruction must already have an entry in value_sets_, which is populated
// by InitializeInstructionValueSets().
const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
    const HloInstruction* instruction) const {
  return *value_sets_.at(instruction);
}
1316 
// Mutable overload: used by the Update*ValueSet routines to modify value
// sets in place during propagation. The instruction must already have an
// entry in value_sets_.
InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
    const HloInstruction* instruction) {
  return *value_sets_.at(instruction);
}
1321 
InitializeInstructionValueSets()1322 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
1323   for (const HloComputation* computation : module_.MakeComputationSorted()) {
1324     const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
1325     for (HloInstruction* instruction :
1326          computation->MakeInstructionPostOrder()) {
1327       // Create an empty shape tree.
1328       value_sets_.insert({instruction, std::make_unique<InstructionValueSet>(
1329                                            instruction->shape())});
1330 
1331       // For each sub-shape of the instruction shape, add a new HloValue to its
1332       // HloValueSet. should_define may be provided to define a subset of
1333       // values.
1334       auto define_all_values =
1335           [this,
1336            &instruction](std::function<bool(const ShapeIndex&)> should_define =
1337                              [](const ShapeIndex&) { return true; }) {
1338             for (auto& pair : GetInstructionValueSet(instruction)) {
1339               const ShapeIndex& index = pair.first;
1340               if (should_define(index)) {
1341                 HloValue* value =
1342                     NewHloValue(instruction, index, /*is_phi=*/false);
1343                 GetValueSet(instruction, index).AddValue(value);
1344               }
1345             }
1346           };
1347 
1348       // Add a new HloValue to the HloValueSet corresponding to the given index
1349       // of the instruction shape.
1350       auto define_value_at = [this, &instruction](const ShapeIndex& index) {
1351         HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
1352         GetValueSet(instruction, index).AddValue(value);
1353       };
1354 
1355       switch (instruction->opcode()) {
1356         case HloOpcode::kBitcast:
1357           if (bitcast_defines_value_) {
1358             define_all_values();
1359           }
1360           break;
1361         case HloOpcode::kSetDimensionSize:
1362         case HloOpcode::kAddDependency:
1363         case HloOpcode::kWhile:
1364         case HloOpcode::kCall:
1365         case HloOpcode::kConditional:
1366         case HloOpcode::kGetTupleElement:
1367         case HloOpcode::kDomain:
1368         case HloOpcode::kOptimizationBarrier:
1369           // These instructions define no values. The values in their output
1370           // flow from their operands or from cross computation dataflow.
1371           break;
1372         case HloOpcode::kParameter:
1373           if (call_graph_node.context() == CallContext::kBoth) {
1374             // We do not support a subcomputation that is called from both a
1375             // parallel and sequential context. In this case, the parameter
1376             // would both define a value and propagate a value from its
1377             // caller. This limitation is not really a problem because the call
1378             // graph is typically flattened.
1379             return Unimplemented(
1380                 "Computation %s is called in both a parallel (eg, kMap) and "
1381                 "sequential (eg, kCall) context",
1382                 computation->name());
1383           }
1384           if (call_graph_node.caller_callsites().empty() ||
1385               call_graph_node.context() == CallContext::kEmbedded) {
1386             // Parameters of computations called in a parallel context (eg, map
1387             // and reduce) as well as parameters of dead computations define all
1388             // values in their output. Otherwise the values of the parameter
1389             // come from the caller (eg, operands to the kCall instruction).
1390             define_all_values();
1391           }
1392           break;
1393         case HloOpcode::kCopy:
1394         case HloOpcode::kTuple:
1395           // These instructions only define their top-level values. Any other
1396           // values flow from their operands.
1397           define_value_at(/*index=*/{});
1398           break;
1399         case HloOpcode::kAsyncStart:
1400           // AsyncStart produces a tuple of {{aliased operands}, {destination},
1401           // contexts}. It defines all of the tuple-shaped values and the
1402           // contexts.
1403           define_all_values([&](const ShapeIndex& index) {
1404             return ShapeUtil::GetSubshape(instruction->shape(), index)
1405                        .IsTuple() ||
1406                    index.front() > 1;
1407           });
1408           break;
1409         case HloOpcode::kAsyncUpdate:
1410           // AsyncUpdate produces a tuple of {{aliased operands}, {destination},
1411           // contexts} where all of the array-typed values alias with the
1412           // operand. So, only tuple-shaped values are defined by AsyncUpdate.
1413           define_all_values([&](const ShapeIndex& index) {
1414             return ShapeUtil::GetSubshape(instruction->shape(), index)
1415                 .IsTuple();
1416           });
1417           break;
1418         case HloOpcode::kAsyncDone:
1419           // AsyncDone's output aliases its output.
1420           break;
1421         case HloOpcode::kCopyStart:
1422           // CopyStart produces a tuple of {destination buffer, aliased operand,
1423           // U32 context}.
1424           define_value_at(/*index=*/{});
1425           define_value_at(/*index=*/{0});
1426           define_value_at(/*index=*/{2});
1427           break;
1428         case HloOpcode::kCopyDone:
1429           // CopyDone consumes a tuple produced by CopyStart and produces an
1430           // element. Its output aliases its input tuple element {0}.
1431           break;
1432         case HloOpcode::kAllGatherStart:
1433           // AllGatherStart produces a tuple of
1434           // {aliased operand, destination buffer}.
1435           define_value_at(/*index=*/{});
1436           define_value_at(/*index=*/{1});
1437           break;
1438         case HloOpcode::kAllGatherDone:
1439           // AllGatherDone's output aliases its input tuple element {1}.
1440           if (instruction->shape().IsTuple()) {
1441             define_value_at(/*index=*/{});
1442           }
1443           break;
1444         case HloOpcode::kAllReduceDone:
1445           // AllReduceDone's output aliases its input.
1446           break;
1447         case HloOpcode::kCollectivePermuteStart:
1448           // CollectivePermuteStart produces a tuple of
1449           // {aliased operand, destination buffer, U32 context, U32 context}.
1450           define_value_at(/*index=*/{});
1451           define_value_at(/*index=*/{1});
1452           define_value_at(/*index=*/{2});
1453           define_value_at(/*index=*/{3});
1454           if (instruction->operand_count() > 1) {
1455             CHECK_EQ(instruction->operand_count(), 4);
1456             if (instruction->operand(1)->shape().IsTuple()) {
1457               for (int i = 0; i < ShapeUtil::TupleElementCount(
1458                                       instruction->operand(1)->shape());
1459                    ++i) {
1460                 define_value_at(/*index=*/{1, i});
1461               }
1462             }
1463           }
1464           break;
1465         case HloOpcode::kCollectivePermuteDone:
1466           // CollectivePermuteDone's output aliases its input tuple element {1}.
1467           if (instruction->shape().IsTuple()) {
1468             define_value_at(/*index=*/{});
1469           }
1470           break;
1471         case HloOpcode::kRecvDone:
1472           // RecvDone produces a two-element tuple. Element zero aliases its
1473           // input tuple element {0}; element one is a token.
1474           define_value_at(/*index=*/{});
1475           define_value_at(/*index=*/{1});
1476           break;
1477         case HloOpcode::kSend:
1478           // Send produces a tuple of {aliased operand, U32 context, token},
1479           // therefore only defines the top-level tuple and the tuple elements
1480           // at {1} and {2}.
1481           define_value_at(/*index=*/{});
1482           define_value_at(/*index=*/{1});
1483           define_value_at(/*index=*/{2});
1484           break;
1485         default:
1486           define_all_values();
1487           break;
1488       }
1489     }
1490   }
1491 
1492   return OkStatus();
1493 }
1494 
OptimizePhiValues()1495 void HloDataflowAnalysis::OptimizePhiValues() {
1496   // Only applicable to SSA form where phis are defined.
1497   if (!ssa_form_) {
1498     return;
1499   }
1500 
1501   VLOG(1) << "Before phi graph optimization";
1502   XLA_VLOG_LINES(1, phi_graph_.ToString());
1503   phi_graph_.Optimize();
1504   VLOG(1) << "After phi graph optimization";
1505   XLA_VLOG_LINES(1, phi_graph_.ToString());
1506 
1507   for (const HloComputation* computation : module_.computations()) {
1508     for (HloInstruction* instruction : computation->instructions()) {
1509       InstructionValueSet& instruction_value_set =
1510           GetInstructionValueSet(instruction);
1511       VLOG(1) << "inst: " << instruction->name();
1512       VLOG(1) << instruction_value_set.ToString();
1513       instruction_value_set.ForEachMutableElement(
1514           [&](const xla::ShapeIndex& index, HloValueSet* value_set) {
1515             auto values = value_set->values();
1516             if (!(values.size() == 1 && values[0]->is_phi())) {
1517               return;
1518             }
1519             HloValue::Id phi_id = values[0]->id();
1520             HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id);
1521             if (new_id != phi_id) {
1522               VLOG(1) << "Replacing " << values[0]->ToString() << " with "
1523                       << GetValue(new_id).ToString();
1524               value_set->Clear();
1525               const HloValue& new_value = GetValue(new_id);
1526               value_set->AddValue(&new_value);
1527               MarkValueForDeletion(phi_id);
1528             }
1529           });
1530     }
1531   }
1532 }
1533 
1534 /* static */
Run(const HloModule & module,bool ssa_form,bool bitcast_defines_value,const CanShareBuffer & can_share_buffer)1535 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
1536     const HloModule& module, bool ssa_form, bool bitcast_defines_value,
1537     const CanShareBuffer& can_share_buffer) {
1538   VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
1539   XLA_VLOG_LINES(2, module.ToString());
1540 
1541   auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
1542       module, ssa_form, bitcast_defines_value, can_share_buffer));
1543 
1544   TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
1545   dataflow_analysis->Propagate();
1546   dataflow_analysis->OptimizePhiValues();
1547 
1548   // Delete all values marked for deletion.
1549   dataflow_analysis->DeleteMarkedValues();
1550 
1551   // Gather and set all non-definition positions of all values. Value deletion
1552   // is rare, so just use a vector indexed by Value::Id rather than a map from
1553   // Value::Id to positions. There should be very few holes in the vector, and
1554   // lookup is faster.
1555   std::vector<std::vector<HloPosition>> value_positions(
1556       dataflow_analysis->next_value_id_);
1557   for (const HloComputation* computation : module.computations()) {
1558     for (HloInstruction* instruction : computation->instructions()) {
1559       for (const auto& pair :
1560            dataflow_analysis->GetInstructionValueSet(instruction)) {
1561         const ShapeIndex& index = pair.first;
1562         const HloValueSet& value_set = pair.second;
1563         for (const HloValue* value : value_set.values()) {
1564           if (value->defining_instruction() != instruction) {
1565             value_positions[value->id()].push_back(
1566                 HloPosition{instruction, index});
1567           }
1568         }
1569       }
1570     }
1571   }
1572   for (auto& pair : dataflow_analysis->values_) {
1573     HloValue::Id value_id = pair.first;
1574     HloValue& value = *pair.second;
1575     value.SetPositions(value_positions[value_id]);
1576   }
1577 
1578   // Construct vector of values.
1579   dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
1580   for (const auto& pair : dataflow_analysis->values_) {
1581     dataflow_analysis->values_vector_.push_back(pair.second.get());
1582   }
1583   absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
1584 
1585   TF_DCHECK_OK(dataflow_analysis->Verify());
1586 
1587   XLA_VLOG_LINES(1, dataflow_analysis->ToString());
1588 
1589   return std::move(dataflow_analysis);
1590 }
1591 
Verify() const1592 Status HloDataflowAnalysis::Verify() const {
1593   // Verify each HloValue appears in the value sets that the value's positions()
1594   // indicate.
1595   for (const HloValue* value : values()) {
1596     for (const HloPosition& position : value->positions()) {
1597       const HloValueSet& value_set = GetValueSet(position);
1598       TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
1599           << "Value set at position " << position << " does not contain value "
1600           << value->ToShortString();
1601     }
1602   }
1603 
1604   // For each value in each value set, verify that the value set's position
1605   // appears in the value's positions().
1606   for (const auto& computation : module_.computations()) {
1607     for (const auto& instruction : computation->instructions()) {
1608       for (const auto& pair : GetInstructionValueSet(instruction)) {
1609         const ShapeIndex& index = pair.first;
1610         const HloValueSet& value_set = pair.second;
1611         const HloPosition position{instruction, index};
1612         for (const HloValue* value : value_set.values()) {
1613           TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
1614               << "Value set at position " << position
1615               << " unexpectedly contains value " << value->ToShortString();
1616         }
1617       }
1618     }
1619   }
1620 
1621   return OkStatus();
1622 }
1623 
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const1624 bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
1625     const HloInstruction* operand, const ShapeIndex& index,
1626     const HloInstruction* user) const {
1627   // Return false if no value at 'operand' and 'index' is used at 'user'.
1628   for (const HloValue* value : GetValueSet(operand, index).values()) {
1629     for (const HloUse& use : value->GetUses()) {
1630       if (use.instruction == user) {
1631         if (user->IsLoopFusion()) {
1632           HloInstruction* fusion_param =
1633               user->fused_parameter(use.operand_number);
1634           const HloValue& value =
1635               GetValueDefinedAt(fusion_param, use.operand_index);
1636           return value.GetUses().empty();
1637         }
1638         return false;
1639       }
1640     }
1641   }
1642   return true;
1643 }
1644 
IsInPlaceOperation(HloOpcode opcode)1645 /*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) {
1646   return opcode == HloOpcode::kDynamicUpdateSlice ||
1647          opcode == HloOpcode::kScatter;
1648 }
1649 
IsAsynchronousOperationStart(HloOpcode opcode)1650 /*static*/ bool HloDataflowAnalysis::IsAsynchronousOperationStart(
1651     HloOpcode opcode) {
1652   return opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv ||
1653          opcode == HloOpcode::kCopyStart ||
1654          opcode == HloOpcode::kAllReduceStart ||
1655          opcode == HloOpcode::kAllGatherStart ||
1656          opcode == HloOpcode::kCollectivePermuteStart ||
1657          opcode == HloOpcode::kAsyncStart;
1658 }
1659 
IsAsynchronousOperationDone(HloOpcode opcode)1660 /*static*/ bool HloDataflowAnalysis::IsAsynchronousOperationDone(
1661     HloOpcode opcode) {
1662   return opcode == HloOpcode::kSendDone || opcode == HloOpcode::kRecvDone ||
1663          opcode == HloOpcode::kCopyDone ||
1664          opcode == HloOpcode::kAllReduceDone ||
1665          opcode == HloOpcode::kAllGatherDone ||
1666          opcode == HloOpcode::kCollectivePermuteDone ||
1667          opcode == HloOpcode::kAsyncDone;
1668 }
1669 
1670 namespace {
1671 
1672 // Removes layers of tuple indirection introduced via 'tuple' and
1673 // 'get-tuple-element' instructions to more directly identify the source of the
1674 // given HLO value (identified by the given `ShapeIndex` into the output of the
1675 // given `HloInstruction`).
1676 //
1677 // e.g. for the following:
1678 //    %x = some-op(...)
1679 //    %foo = get-tuple-element(%x), index=0
1680 //    %bar = tuple(%y, %foo)
1681 //
1682 // ... FollowTupleIndirection(%bar, {1}) == {%x, {0}} (output 1 of 'bar' comes
1683 // from output 0 of %x).
1684 //
1685 // Note that all 'tuple' instructions are followed before all
1686 // 'get-tuple-element' instructions are followed. This is because it is assumed
1687 // that tupling a value and then extracting it from the tuple again will not
1688 // occur in properly-optimized IR.
FollowTupleIndirection(const HloInstruction * instruction,ShapeIndex operand_index)1689 std::pair<const HloInstruction*, ShapeIndex> FollowTupleIndirection(
1690     const HloInstruction* instruction, ShapeIndex operand_index) {
1691   while (instruction->opcode() == HloOpcode::kTuple && !operand_index.empty()) {
1692     instruction = instruction->operand(operand_index.front());
1693     operand_index.pop_front();
1694   }
1695   while (instruction->opcode() == HloOpcode::kGetTupleElement) {
1696     operand_index.push_front(instruction->tuple_index());
1697     instruction = instruction->operand(0);
1698   }
1699 
1700   return {instruction, operand_index};
1701 }
1702 
1703 // Returns in-place input/output pairs for the given fusion instruction,
1704 // according to the aliasing rules for the corresponding fusion computation.
1705 //
1706 // `instruction` must be a fusion instruction.
1707 std::vector<std::pair<HloOperandIndex, ShapeIndex>>
GetFusionInstructionInPlaceInputOutputPairs(const HloInstruction * instruction)1708 GetFusionInstructionInPlaceInputOutputPairs(const HloInstruction* instruction) {
1709   std::vector<std::pair<HloOperandIndex, ShapeIndex>>
1710       in_place_input_output_pairs;
1711   // Each of these leaves represents one array output of the fusion that might
1712   // be aliased with one of the fusion computation's array inputs (both could be
1713   // nested arbitrarily deep inside tuples).
1714   for (const auto& fusion_output_array_shape :
1715        ShapeUtil::GetLeafShapes(instruction->shape())) {
1716     // Start from the root instruction of the fusion computation and follow
1717     // tuple indirection backwards to find the "output source", i.e. the
1718     // instruction that is the original source of the array output in question.
1719     // If there is no such indirection the "output source" will just be the
1720     // fusion root instruction itself.
1721     const HloInstruction* output_source_instruction =
1722         instruction->fused_expression_root();
1723     ShapeIndex output_source_index = fusion_output_array_shape.index;
1724     std::tie(output_source_instruction, output_source_index) =
1725         FollowTupleIndirection(output_source_instruction, output_source_index);
1726 
1727     // The aliasing rules of the "output source" instruction determine the
1728     // aliasing rules for the entire fusion. If we can connect (following tuple
1729     // indirection) the input of an "in-place" pair to one of the fusion's
1730     // inputs, and the output of this "in-place" pair to the fusion output
1731     // in question, then this fusion input and output must alias.
1732     auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(
1733         output_source_instruction);
1734     ShapeIndex in_place_input_index;
1735     const HloInstruction* in_place_input_source = nullptr;
1736 
1737     for (const auto& output_source_in_place_pair : in_place_pairs) {
1738       const HloOperandIndex& input = output_source_in_place_pair.first;
1739       const ShapeIndex& output_index = output_source_in_place_pair.second;
1740       if (output_index == output_source_index) {
1741         // It is not possible for the same output to alias multiple inputs.
1742         CHECK(in_place_input_source == nullptr);
1743         in_place_input_source =
1744             output_source_instruction->operand(input.operand_number);
1745         in_place_input_index = input.operand_index;
1746       }
1747     }
1748 
1749     if (in_place_input_source) {
1750       // Follow tuple indirection backwards from the instruction input to try to
1751       // find a fusion parameter. If found, that parameter aliases the current
1752       // output. If not, the current output aliases no input.
1753       std::tie(in_place_input_source, in_place_input_index) =
1754           FollowTupleIndirection(in_place_input_source, in_place_input_index);
1755 
1756       if (in_place_input_source->opcode() == HloOpcode::kParameter) {
1757         in_place_input_output_pairs.emplace_back(
1758             HloOperandIndex{in_place_input_source->parameter_number(),
1759                             in_place_input_index},
1760             fusion_output_array_shape.index);
1761       }
1762     }
1763   }
1764   return in_place_input_output_pairs;
1765 }
1766 
1767 }  // namespace
1768 
1769 /*static*/ std::vector<std::pair<HloOperandIndex, ShapeIndex>>
GetInPlaceInputOutputPairs(const HloInstruction * instruction)1770 HloDataflowAnalysis::GetInPlaceInputOutputPairs(
1771     const HloInstruction* instruction) {
1772   if (IsInPlaceOperation(instruction->opcode())) {
1773     const HloScatterInstruction* scatter =
1774         DynCast<HloScatterInstruction>(instruction);
1775     if (scatter && scatter->scatter_operand_count() > 1) {
1776       std::vector<std::pair<HloOperandIndex, ShapeIndex>> pairs;
1777       pairs.reserve(scatter->scatter_operand_count());
1778       for (int i = 0, n = scatter->scatter_operand_count(); i < n; ++i) {
1779         pairs.emplace_back(HloOperandIndex{i, {}}, ShapeIndex{i});
1780       }
1781       return pairs;
1782     }
1783     return {{HloOperandIndex{0, {}}, {}}};
1784   } else if (instruction->opcode() == HloOpcode::kCollectivePermute &&
1785              instruction->operands().size() == 4) {
1786     if (instruction->operand(1)->shape().IsTuple()) {
1787       std::vector<std::pair<HloOperandIndex, ShapeIndex>> in_place_pairs(
1788           {{HloOperandIndex{1, {}}, {}}});
1789       for (int i = 0; i < instruction->operand(1)->shape().tuple_shapes_size();
1790            i++) {
1791         in_place_pairs.push_back({HloOperandIndex{1, {i}}, {i}});
1792       }
1793       return in_place_pairs;
1794     } else {
1795       return {{HloOperandIndex{1, {}}, {}}};
1796     }
1797   } else if (instruction->opcode() == HloOpcode::kCollectivePermuteStart &&
1798              instruction->operands().size() == 4) {
1799     if (instruction->operand(1)->shape().IsTuple()) {
1800       std::vector<std::pair<HloOperandIndex, ShapeIndex>> in_place_pairs(
1801           {{HloOperandIndex{1, {}}, {1}}});
1802       for (int i = 0; i < instruction->operand(1)->shape().tuple_shapes_size();
1803            i++) {
1804         in_place_pairs.push_back({HloOperandIndex{1, {i}}, {1, i}});
1805       }
1806       return in_place_pairs;
1807     } else {
1808       return {{HloOperandIndex{1, {}}, {1}}};
1809     }
1810   } else if (instruction->opcode() == HloOpcode::kCustomCall) {
1811     // Custom Calls previously assumed that aliased operands were
1812     // forwarded, but now supports modifiction semantics.
1813     const auto& aliasing_pairs = Cast<HloCustomCallInstruction>(instruction)
1814                                      ->output_to_operand_aliasing();
1815     std::vector<std::pair<HloOperandIndex, ShapeIndex>> in_place_pairs;
1816     in_place_pairs.reserve(aliasing_pairs.size());
1817     for (const auto& pair : aliasing_pairs) {
1818       ShapeIndex output_shape_index = pair.first;
1819       int64_t operand_index = pair.second.first;
1820       ShapeIndex operand_shape_index = pair.second.second;
1821       in_place_pairs.push_back(
1822           {HloOperandIndex{operand_index, {operand_shape_index}},
1823            output_shape_index});
1824     }
1825     return in_place_pairs;
1826   } else if (instruction->opcode() == HloOpcode::kAllReduceStart) {
1827     if (instruction->operands().size() == 1) {
1828       return {{HloOperandIndex{0, {}}, {}}};
1829     }
1830     std::vector<std::pair<HloOperandIndex, ShapeIndex>> in_place_pairs;
1831     in_place_pairs.reserve(instruction->operands().size());
1832     for (int i = 0; i < instruction->operands().size(); i++) {
1833       in_place_pairs.push_back({HloOperandIndex{i, {}}, {i}});
1834     }
1835     return in_place_pairs;
1836   } else if (instruction->opcode() == HloOpcode::kFusion) {
1837     return GetFusionInstructionInPlaceInputOutputPairs(instruction);
1838   }
1839 
1840   return {};
1841 }
1842 
// Returns true if the buffer holding the value of 'operand' at
// 'operand_index' may be shared with (reused as) the output buffer of 'user'
// at 'user_index'. The ordering of checks matters: the must-alias in-place
// check and the backend hint both run before the generic shapes-equal
// bailout, so a backend hint can permit sharing even for differing shapes.
bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
    HloInstruction* operand, const ShapeIndex& operand_index,
    HloInstruction* user, const ShapeIndex& user_index) const {
  CHECK(user->IsUserOf(operand))
      << "user: " << user->ToString() << " operand: " << operand->ToString();
  // Constant buffers are never shared with an op's output.
  if (operand->opcode() == HloOpcode::kConstant) {
    return false;
  }

  const Shape& operand_subshape =
      ShapeUtil::GetSubshape(operand->shape(), operand_index);
  const Shape& user_subshape =
      ShapeUtil::GetSubshape(user->shape(), user_index);
  if (IsSliceInputFusion(*user)) {
    HloInstruction* fusion_param =
        user->fused_parameter(user->operand_index(operand));
    // We don't require the same dimensions but only the same number of elements
    // and type (to make sure the same buffer size).
    return operand_subshape.IsArray() && user_subshape.IsArray() &&
           ShapeUtil::ElementsIn(operand_subshape) ==
               ShapeUtil::ElementsIn(user_subshape) &&
           ShapeUtil::SameElementType(operand_subshape, user_subshape) &&
           AreTransitiveUsesEffectivelyElementwise(
               fusion_param, user->fused_expression_root(), user_index);
  }

  auto shapes_equal = ShapeUtil::Equal(operand_subshape, user_subshape);
  // Check that operand and user emit the same shape and layout.
  if (shapes_equal) {
    // Must-alias relationship returns true for in-place operations (DUS and DUS
    // fusions), regardless of the backend.
    for (const auto& operand_and_output_index :
         GetInPlaceInputOutputPairs(user)) {
      if (operand_and_output_index.second != user_index) {
        continue;
      }
      // The in-place output matches; sharing holds if 'operand' is actually
      // used as the corresponding in-place input of 'user'.
      for (const HloUse& use :
           GetUniqueValueAt(operand, operand_index).GetUses()) {
        if (use == HloUse{user, operand_and_output_index.first.operand_number,
                          operand_and_output_index.first.operand_index}) {
          return true;
        }
      }
    }
  }

  // A backend-specific callback, when supplied and conclusive, overrides all
  // the generic opcode-based rules below.
  if (can_share_buffer_ != nullptr) {
    if (std::optional<bool> hint =
            can_share_buffer_(user, operand, user_index)) {
      return *hint;
    }
  }

  // Beyond this point sharing requires identical shape and layout.
  if (!shapes_equal) {
    return false;
  }

  if (user->opcode() == HloOpcode::kFusion) {
    HloInstruction* fusion_param =
        user->fused_parameter(user->operand_index(operand));
    const HloValue& fusion_param_value =
        GetValueDefinedAt(fusion_param, operand_index);

    if (user->IsLoopFusion() || user->IsInputFusion()) {
      return AreTransitiveUsesElementwiseOrTuple(fusion_param);
    }

    if (user->IsOutputFusion() &&
        user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
      // Output fusion with kAdd fused root.

      // Check if one operand of kAdd fused root is kDot or kConvolution.
      auto* add = user->fused_expression_root();
      auto add_operand_it =
          absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
            return operand->opcode() == HloOpcode::kConvolution ||
                   operand->opcode() == HloOpcode::kDot;
          });
      if (add_operand_it == add->operands().end()) {
        return false;
      }
      auto* matched_add_operand = *add_operand_it;
      // Calculate operand index of 'add' operand which was not matched above.
      const int64_t other_add_operand_index =
          matched_add_operand == add->operand(0) ? 1 : 0;
      // Returns true iff there is exactly one use of 'operand' at shape index
      // 'operand_index', and this singleton use is the fused root (at operand
      // index 'other_add_operand_index').
      if (fusion_param_value.GetUses().size() == 1) {
        const HloUse& use = fusion_param_value.GetUses()[0];
        return use.instruction == user->fused_expression_root() &&
               use.operand_number == other_add_operand_index;
      }
      return false;
    }
  }

  // There is nothing inherently wrong with while and conditional ops to have
  // input/output buffers to alias with each other, even when the indices are
  // different in the while case. It is a problem when this aliasing causes HLO
  // ops inside these while or conditional to have input/output buffer aliasing
  // that isn't allowed. So allow while and conditional to share buffers with
  // operands and we will discover any problematic sharing when we explore the
  // ops inside these computations.
  if (user->opcode() == HloOpcode::kWhile ||
      user->opcode() == HloOpcode::kConditional) {
    return true;
  }

  if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
      user->opcode() == HloOpcode::kScatter ||
      user->opcode() == HloOpcode::kTriangularSolve) {
    // We eliminated other users in HloOrdering::LiveRangeStrictlyBefore
    // so here we just need to check that the use is at the right operand index.
    const auto operand_indices = user->OperandIndices(operand);
    // TriangularSolve updates operand 1 in place; DUS/Scatter update operand 0.
    int64_t operand_no = user->opcode() == HloOpcode::kTriangularSolve ? 1 : 0;
    return operand_indices.size() == 1 && operand_indices[0] == operand_no;
  }
  if (user->opcode() == HloOpcode::kSort) {
    // Only valid if there are no other users.
    if (operand->users().size() != 1) {
      return false;
    }
    // If we only sort keys, the output of sort is not a tuple, so we can always
    // share the buffer.
    if (user->operand_count() == 1) {
      return true;
    }
    CHECK(!user_index.empty());
    // Only share with the right tuple element buffer.
    const auto operand_indices = user->OperandIndices(operand);
    return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
  }
  if (user->opcode() == HloOpcode::kCall) {
    // Get all uses of value defined by 'operand' at 'operand_index'.
    auto uses = GetValueDefinedAt(operand, operand_index).GetUses();
    // Return true iff:
    // *) There exists two uses of 'operand'.
    // *) One use is by 'user' (caller).
    // *) One use is by root instruction of called computation (callee root).
    //    (Note: we check the root of the called computation, because the
    //     root result buffer is required to alias with the Call result buffer).
    // *) The root instruction of the called computation is element-wise on
    //    'operand'.
    const bool found_caller_use =
        absl::c_find_if(uses, [user](const HloUse& use) {
          return use.instruction == user;
        }) != uses.end();
    auto* callee_root = user->to_apply()->root_instruction();
    const bool found_elementwise_callee_use =
        absl::c_find_if(uses, [callee_root](const HloUse& use) {
          return use.instruction == callee_root &&
                 callee_root->IsElementwiseOnOperand(use.operand_number);
        }) != uses.end();
    return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
  }

  // Loop fusions that contain transposing copies won't reach here as they have
  // different layouts, which fails the check in the beginning of this function.
  return user->IsElementwiseOnOperand(user->operand_index(operand));
}
2004 
2005 }  // namespace xla
2006