1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17
18 #include <algorithm>
19 #include <queue>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/types/optional.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
36 #include "tensorflow/compiler/xla/service/hlo_module.h"
37 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
38 #include "tensorflow/compiler/xla/service/hlo_value.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/platform/logging.h"
45
46 namespace xla {
47 namespace {
48 // CalculatePostOrderSchedule traverses a module and assigns an ordinal to
49 // each instruction based on the postorder dependency.
50 int64 CalculatePostOrderScheduleHelper(
51 const HloComputation* comp, int64_t start_ordinal,
52 absl::flat_hash_map<HloInstruction*, int64>* ordinal_map) {
53 int64_t ordinal = start_ordinal;
54 for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
55 if (instruction->opcode() == HloOpcode::kCall ||
56 instruction->opcode() == HloOpcode::kConditional) {
57 for (const HloComputation* called_computation :
58 instruction->called_computations()) {
59 ordinal = CalculatePostOrderScheduleHelper(called_computation, ordinal,
60 ordinal_map);
61 }
62 }
63 if (instruction->opcode() == HloOpcode::kWhile) {
64 ordinal = CalculatePostOrderScheduleHelper(instruction->while_condition(),
65 ordinal, ordinal_map);
66 ordinal = CalculatePostOrderScheduleHelper(instruction->while_body(),
67 ordinal, ordinal_map);
68 }
69     // It's possible that in some unit tests the computation graph is not
70     // flattened (meaning we could have multiple callers for one computation).
71     // In that case the ordinal_map will see the instruction multiple times.
72     // We consider that case to be OK as it only shows up in unit tests.
73 ordinal_map->insert({instruction, ordinal++});
74 }
75 return ordinal;
76 }
77
78 absl::flat_hash_map<HloInstruction*, int64> CalculatePostOrderSchedule(
79 const HloModule& module) {
80 absl::flat_hash_map<HloInstruction*, int64> map;
81 CalculatePostOrderScheduleHelper(module.entry_computation(), 0, &map);
82 return map;
83 }
84
85 } // namespace
86 using absl::StrAppend;
87 using absl::StrCat;
88
89 HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
90 bool bitcast_defines_value,
91 const CanShareBuffer& can_share_buffer)
92 : module_(module),
93 ssa_form_(ssa_form),
94 bitcast_defines_value_(bitcast_defines_value),
95 call_graph_(CallGraph::Build(&module)),
96 can_share_buffer_(can_share_buffer) {}
97
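// Returns true if every transitive user of `inst` is either elementwise on the
// operand position through which it uses the value or a kTuple instruction.
// The check below is a simple depth-first traversal of the use graph.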
98 bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
99 const HloInstruction* inst) {
100 absl::flat_hash_set<const HloInstruction*> visited;
101 absl::InlinedVector<const HloInstruction*, 4> stack;
102 stack.push_back(inst);
103 while (!stack.empty()) {
104 const HloInstruction* current = stack.back();
105 stack.pop_back();
106 visited.insert(current);
107 for (const HloInstruction* user : current->users()) {
108     // Found a user that is non-elementwise on the current instruction.
109 for (const int64_t use_index : user->OperandIndices(current)) {
110 if (!user->IsElementwiseOnOperand(use_index) &&
111 user->opcode() != HloOpcode::kTuple) {
112 return false;
113 }
114 }
115 if (!visited.contains(user)) {
116 stack.push_back(user);
117 }
118 }
119 }
120 return true;
121 }
122
123 namespace {
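// Returns true if `instr` is a slice over a single dimension with unit stride.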
124 bool Is1dSliceWithoutStrides(const HloInstruction* instr) {
125 return instr->opcode() == HloOpcode::kSlice &&
126 1 == instr->slice_starts().size() &&
127 1 == instr->slice_limits().size() &&
128 1 == instr->slice_strides().size() &&
129 1 == instr->slice_strides().at(0);
130 }
131
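// Returns true if `unnested_hlo` is an input fusion whose root is a tuple
// consisting entirely of 1-dimensional, stride-1 slices.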
132 bool IsSliceInputFusion(const HloInstruction& unnested_hlo) {
133 if (!unnested_hlo.IsInputFusion()) {
134 return false;
135 }
136 const HloInstruction* root = unnested_hlo.fused_expression_root();
137 if (root->opcode() != HloOpcode::kTuple) {
138 return false;
139 }
140 return absl::c_all_of(root->operands(), [](const HloInstruction* instr) {
141 return Is1dSliceWithoutStrides(instr);
142 });
143 }
144
145 struct ConcatUsageInfo {
146 // Pointer to a previously seen concat. nullptr if no previously seen concat.
147 const HloInstruction* prev_concat;
148 // The opnd id of the seen concat.
149 int64 concat_opnd_idx;
150 // The slice that recovers the opnd in the concat outputs.
151 const HloInstruction* slice_to_recover_opnd;
152 };
153
154 // Returns an optional concat usage info to denote whether the concat is used in
155 // an elementwise manner. A concat followed by slices is considered effectively
156 // elementwise if the slices, taken together, are an inverse of the concat.
157 absl::optional<ConcatUsageInfo> ConcatIsEffectivelyElementwise(
158 const HloInstruction& concat, const HloInstruction& operand,
159 const ConcatUsageInfo& info) {
160   // First, check if this concat matches the pattern below. Also, we check
161   // that the slices, taken together, are in effect an inverse of the concat.
162 //
163 // Concat
164 // | |
165 // v v
166 // Slice Slice
167 //
168 std::vector<HloInstruction*> users = concat.users();
169 if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) {
170 // Limit our supported cases to 1 dimensional slices.
171 return absl::optional<ConcatUsageInfo>();
172 }
173 // Verify that each operand to the concat is reversed by a slice.
174 if (users.size() != concat.operand_count() ||
175 concat.operand_count() != concat.unique_operands().size()) {
176 return absl::optional<ConcatUsageInfo>();
177 }
178 absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) {
179 return a->slice_starts().at(0) < b->slice_starts().at(0);
180 });
181 int64_t prev_limit = 0;
182 for (int64_t i = 0; i < users.size(); ++i) {
183 const HloInstruction* u = users[i];
184 int64_t slice_size = u->slice_limits().at(0) - u->slice_starts().at(0);
185 if (u->slice_starts().at(0) != prev_limit ||
186 slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) {
187 return absl::optional<ConcatUsageInfo>();
188 }
189 prev_limit = u->slice_limits().at(0);
190 }
191
192 // If we have seen other concats, make sure they are identical. Multiple
193 // concats exist because horizontal fusion inserts one concat for each output
194 // of the fusion candidates. Check that all concats and operand ids are the
195 // same to know that the "transitive use closure" will be computed in the same
196 // iteration space.
197 int64_t operand_idx = concat.operand_index(&operand);
198 if (info.prev_concat != nullptr) {
199 bool is_concat_identical = info.prev_concat->Identical(
200 concat,
201 /*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
202 // Operands don't need to be the same.
203 return true;
204 });
205 if (!is_concat_identical || info.concat_opnd_idx != operand_idx) {
206 return absl::optional<ConcatUsageInfo>();
207 }
208 }
209
210 const HloInstruction* slice_to_recover_opnd = users.at(operand_idx);
211 return absl::optional<ConcatUsageInfo>(
212 ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd});
213 }
214
215 // Returns whether we can prove the transitive uses of `param` are in effect
216 // elementwise. In other words, we prove that the "transitive use closure" will
217 // all be computed in the same iteration space without any reorder of elements.
218 // In addition, we check that the "transitive use closure" includes the output
219 // in the `root_tuple`.
220 // Theoretically, we could prove more patterns, but our primary use case is
221 // SliceInputFusion.
222 bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param,
223 const HloInstruction* root_tuple,
224 const ShapeIndex& out_shape_idx) {
225 CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
226 CHECK_EQ(out_shape_idx.size(), 1);
227 absl::flat_hash_set<const HloInstruction*> visited;
228 absl::InlinedVector<const HloInstruction*, 4> stack;
229 stack.push_back(param);
230 ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr};
231 bool is_output_reachable = false;
232 while (!stack.empty()) {
233 const HloInstruction* current = stack.back();
234 stack.pop_back();
235 visited.insert(current);
236 for (const HloInstruction* user : current->users()) {
237 VLOG(3) << "Visiting: " << user->ToString();
238 switch (user->opcode()) {
239 case HloOpcode::kTuple:
240 if (user == root_tuple &&
241 current == root_tuple->operand(out_shape_idx.back())) {
242 // We need to know if the output is reachable by the `param` to make
243 // sure that they will be computed in the same iteration space.
244 is_output_reachable = true;
245 }
246 break;
247 case HloOpcode::kReshape:
248 if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) {
249 return false;
250 }
251 break;
252 case HloOpcode::kConcatenate: {
253 absl::optional<ConcatUsageInfo> optional_concat_info =
254 ConcatIsEffectivelyElementwise(*user, *current,
255 concat_usage_info);
256 if (!optional_concat_info) {
257 return false;
258 }
259 concat_usage_info = *optional_concat_info;
260 // Early continue as we only want to traverse through the slice that
261 // recovers the operand. It is guaranteed that the operand to the
262 // concat and the slice have the same iteration space. Insert the
263 // slice instead of the concat.
264 CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd));
265 stack.push_back(concat_usage_info.slice_to_recover_opnd);
266 continue;
267 }
268 default:
269 for (const int64_t use_index : user->OperandIndices(current)) {
270 if (!user->IsElementwiseOnOperand(use_index)) {
271 // Found a user that is non-elementwise on the current
272 // instruction.
273 return false;
274 }
275 }
276 if (!LayoutUtil::Equal(current->shape().layout(),
277 user->shape().layout())) {
278 // Make sure the layout is not changed by the elementwise op.
279 return false;
280 }
281 break;
282 } // end of switch
283 if (!visited.contains(user)) {
284 stack.push_back(user);
285 }
286 }
287 }
288 return is_output_reachable;
289 }
290 } // namespace
291
292 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
293 const ShapeIndex& index) const {
294 const HloValueSet& value_set = GetValueSet(instruction, index);
295 if (value_set.values().size() != 1) {
296 return false;
297 }
298 return value_set.GetUniqueValue().defining_instruction() == instruction;
299 }
300
301 const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
302 const HloInstruction* instruction, const ShapeIndex& index) const {
303 CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
304 return GetUniqueValueAt(instruction, index);
305 }
306
307 HloValue& HloDataflowAnalysis::GetValueDefinedAt(
308 const HloInstruction* instruction, const ShapeIndex& index) {
309 CHECK(ValueIsDefinedAt(instruction, index));
310 return GetUniqueValueAt(instruction, index);
311 }
312
313 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
314 const ShapeIndex& index,
315 bool is_phi) {
316 const int64_t value_id = next_value_id_++;
317 auto emplaced = values_.emplace(
318 std::piecewise_construct, std::forward_as_tuple(value_id),
319 std::forward_as_tuple(value_id, instruction, index, is_phi));
320 CHECK(emplaced.second);
321
322 VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
323
324 return &emplaced.first->second;
325 }
326
327 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
328 HloValue& value = values_.at(value_id);
329 VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
330
331 value_ids_to_delete_.push_back(value_id);
332 }
333
334 void HloDataflowAnalysis::DeleteMarkedValues() {
335 // Use a set to prevent deleting an id twice.
336 absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
337 value_ids_to_delete_.end());
338 #ifndef NDEBUG
339 // Verify that no marked-for-deletion values are in any of the value sets.
340 for (const auto& pair : value_sets_) {
341 const HloInstruction* instruction = pair.first;
342 const InstructionValueSet& instruction_value_set = pair.second;
343 for (const auto& index_value_set : instruction_value_set) {
344 const HloValueSet& value_set = index_value_set.second;
345 for (const HloValue* value : value_set.values()) {
346 DCHECK(!ContainsKey(id_set, value->id()))
347 << "Value " << value->ToShortString()
348 << " marked for deletion, but still exists in value set for "
349 "instruction "
350 << instruction->name();
351 }
352 }
353 }
354 #endif
355
356 for (HloValue::Id value_id : id_set) {
357 values_.erase(value_id);
358 }
359 value_ids_to_delete_.clear();
360 }
361
362 string HloDataflowAnalysis::ToString() const {
363 string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
364 StrAppend(&out, " Instruction value sets:\n");
365 for (const HloComputation* computation : module_.computations()) {
366 for (const HloInstruction* instruction : computation->instructions()) {
367 StrAppend(&out, "Instruction: \n ", instruction->name(), ":\n");
368 if (instruction->shape().IsTuple()) {
369 GetInstructionValueSet(instruction)
370 .ForEachElement([this, &instruction, &out](
371 const ShapeIndex& index,
372 const HloValueSet& value_set) {
373 StrAppend(&out, " tuple index ", index.ToString(), ":\n");
374 for (const HloValue* value : value_set.values()) {
375 StrAppend(&out, " ", value->ToShortString(),
376 ValueIsDefinedAt(instruction, index) ? " (def)" : "",
377 "\n");
378 }
379 });
380 } else {
381 const HloValueSet& top_level_value_set =
382 GetValueSet(instruction, /*index=*/{});
383 for (const HloValue* value : top_level_value_set.values()) {
384 StrAppend(&out, " ", value->ToShortString(),
385 ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
386 }
387 }
388 }
389 }
390 StrAppend(&out, " HloValues:\n");
391 for (const HloValue* value : values()) {
392 StrAppend(&out, value->ToString(/*indent=*/4));
393 }
394 return out;
395 }
396
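// Merges the given input value sets into the value set of `instruction`. In
// SSA form, a phi value is created (or updated) at each index where more than
// one distinct value flows in. Returns true if any value set changed.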
397 bool HloDataflowAnalysis::Phi(
398 HloInstruction* instruction,
399 absl::Span<const InstructionValueSet* const> inputs) {
400 CHECK(ssa_form_);
401 VLOG(4) << "Phi(" << instruction->name() << ")";
402 VLOG(5) << "instruction value set = "
403 << GetInstructionValueSet(instruction).ToString();
404 for (const InstructionValueSet* input : inputs) {
405 VLOG(5) << "input value set = " << input->ToString();
406 }
407
408 if (bitcast_defines_value_) {
409 absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
410 DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
411 });
412 } else {
413 const Shape& shape = instruction->shape();
414 PrimitiveType ty = shape.element_type();
415 bool is_array = shape.IsArray();
416 absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
417 DCHECK(ty == input->shape().element_type() &&
418 (!is_array || ShapeUtil::ElementsIn(shape) ==
419 ShapeUtil::ElementsIn(input->shape())));
420 });
421 }
422
423 bool changed = false;
424 for (auto& pair : GetInstructionValueSet(instruction)) {
425 const ShapeIndex& index = pair.first;
426 HloValueSet& value_set = pair.second;
427
428 // Positions with phi values should never have more than one value in the
429 // value set.
430 CHECK_LE(value_set.values().size(), 1);
431 const HloValue* current_value =
432 value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
433
434 // Construct a vector of value IDs of the inputs.
435 std::vector<HloValue::Id> input_value_ids;
436 for (const InstructionValueSet* input : inputs) {
437 for (const HloValue* value : input->element(index).values()) {
438 input_value_ids.push_back(value->id());
439 }
440 }
441
442 // Remove the existing phi value (if it exists). The phi can be its own
443 // input, for example, in while body parameters where the body passes
444 // through the parameter value.
445 bool current_value_defined_here =
446 (current_value != nullptr &&
447 current_value->defining_instruction() == instruction &&
448 current_value->defining_index() == index);
449
450 VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
451 if (input_value_ids.empty()) {
452       // A value set that has at least one element should never have its value
453       // set reduced to zero elements. During dataflow, value sets can only go
454       // from empty to non-empty, never the reverse.
455 CHECK_EQ(value_set.values().size(), 0)
456 << "Instruction " << instruction->name() << " at index " << index
457 << " previously had non-empty value set. Value set: " << value_set;
458 } else if (input_value_ids.size() == 1) {
459 // Only a single value reaches this point. There should be no phi, and
460 // this value set should contain this single value.
461 const HloValue& new_value = GetValue(input_value_ids[0]);
462 if (current_value == nullptr) {
463 value_set.Clear();
464 value_set.AddValue(&new_value);
465 changed = true;
466 } else if (current_value != &new_value) {
467 if (current_value_defined_here) {
468 // Remove the existing phi.
469 MarkValueForDeletion(current_value->id());
470 }
471 value_set.Clear();
472 value_set.AddValue(&new_value);
473 changed = true;
474 }
475 } else {
476 // Multiple distinct values reach this point. A phi value is
477 // necessary.
478 CHECK_GT(input_value_ids.size(), 1);
479 bool phi_defined_here =
480 current_value_defined_here && current_value->is_phi();
481 if (current_value == nullptr || !phi_defined_here) {
482 value_set.Clear();
483 value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
484
485 std::vector<HloValue*> inputs;
486 inputs.reserve(input_value_ids.size());
487 for (HloValue::Id id : input_value_ids) {
488 inputs.push_back(&GetValue(id));
489 }
490 // Register the phi into phi graph.
491 phi_graph_.RegisterPhi(*value_set.values()[0], inputs);
492 changed = true;
493 } else if (phi_defined_here) {
494 std::vector<HloValue*> new_inputs;
495 new_inputs.reserve(input_value_ids.size());
496 for (HloValue::Id id : input_value_ids) {
497 new_inputs.push_back(&GetValue(id));
498 }
499
500 if (!phi_graph_.InputsEqualTo(*current_value, new_inputs)) {
501 VLOG(1) << current_value->ToShortString() << " has new phi inputs: ";
502 // Update phi inputs.
503 phi_graph_.RegisterPhi(*current_value, new_inputs);
504 changed = true;
505 }
506 }
507 }
508 }
509 return changed;
510 }
511
512 const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
513 return values_.at(value_id);
514 }
515
516 HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
517 return values_.at(value_id);
518 }
519
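// Returns the union of the value sets at every index of `instruction`'s shape.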
520 HloValueSet HloDataflowAnalysis::GetFlattenedValueSet(
521 const HloInstruction* instruction) const {
522 HloValueSet value_set;
523
524 const InstructionValueSet& value_set_tree =
525 GetInstructionValueSet(instruction);
526
527 std::vector<const HloValueSet*> all_sets;
528 for (auto& pair : value_set_tree) {
529 const HloValueSet& value_set = pair.second;
530 all_sets.push_back(&value_set);
531 }
532 value_set.AssignUnionOf(all_sets);
533
534 return value_set;
535 }
536
537 const HloValueSet& HloDataflowAnalysis::GetValueSet(
538 const HloInstruction* instruction, const ShapeIndex& index) const {
539 return GetInstructionValueSet(instruction).element(index);
540 }
541
542 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
543 const ShapeIndex& index) {
544 return *GetInstructionValueSet(instruction).mutable_element(index);
545 }
546
547 const HloValueSet& HloDataflowAnalysis::GetValueSet(
548 const HloPosition& position) const {
549 return GetValueSet(position.instruction, position.index);
550 }
551
552 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
553 return GetValueSet(position.instruction, position.index);
554 }
555
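// A bitcast forwards its operand's values unless bitcast_defines_value_ is
// set, in which case it keeps the values it defined at initialization.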
556 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
557 CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
558 const InstructionValueSet& operand_set =
559 GetInstructionValueSet(bitcast->operand(0));
560 InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
561 if (!bitcast_defines_value_ && operand_set != bitcast_set) {
562 bitcast_set = operand_set;
563 return true;
564 }
565 return false;
566 }
567
568 bool HloDataflowAnalysis::UpdateSetDimensionSizeValueSet(
569 HloInstruction* set_dimension_size) {
570 CHECK_EQ(set_dimension_size->opcode(), HloOpcode::kSetDimensionSize);
571 const InstructionValueSet& operand_set =
572 GetInstructionValueSet(set_dimension_size->operand(0));
573 InstructionValueSet& set_dimension_size_set =
574 GetInstructionValueSet(set_dimension_size);
575 if (operand_set != set_dimension_size_set) {
576 set_dimension_size_set = operand_set;
577 return true;
578 }
579 return false;
580 }
581
582 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
583 CHECK_EQ(send->opcode(), HloOpcode::kSend);
584 bool changed = false;
585 // Send forwards the operand value to the output tuple at {0}.
586 for (auto& pair : GetInstructionValueSet(send->operand(0))) {
587 const ShapeIndex& operand_index = pair.first;
588 const HloValueSet& operand_value_set = pair.second;
589
590 ShapeIndex index = {0};
591 for (int64_t i : operand_index) {
592 index.push_back(i);
593 }
594
595 HloValueSet& value_set = GetValueSet(send, index);
596 if (value_set != operand_value_set) {
597 value_set = operand_value_set;
598 changed = true;
599 }
600 }
601 return changed;
602 }
603
604 bool HloDataflowAnalysis::UpdateCustomCallValueSet(
605 HloInstruction* custom_call) {
606 CHECK_EQ(custom_call->opcode(), HloOpcode::kCustomCall);
607 bool changed = false;
608 for (const auto& aliasing : Cast<HloCustomCallInstruction>(custom_call)
609 ->output_to_operand_aliasing()) {
610 const HloValueSet& operand_value_set = GetValueSet(
611 custom_call->operand(aliasing.second.first), aliasing.second.second);
612 HloValueSet& value_set = GetValueSet(custom_call, aliasing.first);
613 if (value_set != operand_value_set) {
614 value_set = operand_value_set;
615 changed = true;
616 }
617 }
618 return changed;
619 }
620
621 bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
622 CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
623 bool changed = false;
624 // CopyStart forwards the operand value to element {1} of its output.
625 const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0));
626 HloValueSet& value_set = GetValueSet(copy_start, {1});
627 if (value_set != operand_value_set) {
628 value_set = operand_value_set;
629 changed = true;
630 }
631 return changed;
632 }
633
634 bool HloDataflowAnalysis::UpdateCopyDoneValueSet(HloInstruction* copy_done) {
635 CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone);
636 bool changed = false;
637 // CopyDone forwards the operand value at {0} to element {} of its output.
638 const HloValueSet& operand_value_set =
639 GetValueSet(copy_done->operand(0), {0});
640 HloValueSet& value_set = GetValueSet(copy_done);
641 if (value_set != operand_value_set) {
642 value_set = operand_value_set;
643 changed = true;
644 }
645 return changed;
646 }
647
648 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
649 CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
650 bool changed = false;
651 // RecvDone forwards the operand value at {0} to element {0} of its output.
652 for (auto& pair : GetInstructionValueSet(recv_done)) {
653 ShapeIndex& index = pair.first;
654 HloValueSet& value_set = pair.second;
655
656 if (index.empty() || index[0] != 0) {
657 continue;
658 }
659
660 const HloValueSet& operand_value_set =
661 GetValueSet(recv_done->operand(0), index);
662 if (value_set != operand_value_set) {
663 value_set = operand_value_set;
664 changed = true;
665 }
666 }
667 return changed;
668 }
669
670 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
671 CHECK_EQ(call->opcode(), HloOpcode::kCall);
672 InstructionValueSet& value_set = GetInstructionValueSet(call);
673 InstructionValueSet& root_value_set =
674 GetInstructionValueSet(call->to_apply()->root_instruction());
675 if (value_set != root_value_set) {
676 value_set = root_value_set;
677 return true;
678 }
679 return false;
680 }
681
682 bool HloDataflowAnalysis::UpdateConditionalValueSet(
683 HloInstruction* conditional) {
684 CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
685 std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
686 for (int j = 0; j < conditional->branch_count(); ++j) {
687 inputs[j] = &GetInstructionValueSet(
688 conditional->branch_computation(j)->root_instruction());
689 }
690 if (ssa_form_) {
691 return Phi(conditional, inputs);
692 } else {
693 return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
694 }
695 }
696
697 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
698 CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
699 bool changed = false;
700 for (auto& pair : GetInstructionValueSet(copy)) {
701 const ShapeIndex& index = pair.first;
702 if (index.empty()) {
703       // kCopy shallow copies and thus defines the top-level value, so there is
704       // nothing to update.
705 continue;
706 }
707
708 HloValueSet& value_set = pair.second;
709 HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
710 if (value_set != operand_value_set) {
711 value_set = operand_value_set;
712 changed = true;
713 }
714 }
715 return changed;
716 }
717
718 bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
719 // Domain instructions just forward their operand. Given that domains can have
720 // a tuple operand, we iterate through its indexes, like for copies.
721   // Unlike copies, though, we also propagate the top-level value.
722 CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
723 bool changed = false;
724 for (auto& pair : GetInstructionValueSet(domain)) {
725 const ShapeIndex& index = pair.first;
726 HloValueSet& value_set = pair.second;
727 HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
728 if (value_set != operand_value_set) {
729 value_set = operand_value_set;
730 changed = true;
731 }
732 }
733 return changed;
734 }
735
736 bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
737 HloInstruction* add_dependency) {
738 // AddDependency just forwards the value of its zero-th operand.
739 CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
740 const InstructionValueSet& operand_set =
741 GetInstructionValueSet(add_dependency->operand(0));
742 InstructionValueSet& add_dependency_set =
743 GetInstructionValueSet(add_dependency);
744 if (operand_set != add_dependency_set) {
745 add_dependency_set = operand_set;
746 return true;
747 }
748 return false;
749 }
750
751 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
752 CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
753 bool changed = false;
754 // The GetTupleElement instruction forwards the values from the specified
755 // tuple element.
756 for (auto& pair : GetInstructionValueSet(gte)) {
757 const ShapeIndex& index = pair.first;
758 HloValueSet& value_set = pair.second;
759
760 // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
761 // with the tuple element number prefixed.
762 ShapeIndex operand_index = {gte->tuple_index()};
763 for (int64_t i : index) {
764 operand_index.push_back(i);
765 }
766
767 HloValueSet& operand_value_set =
768 GetValueSet(gte->operand(0), operand_index);
769 if (value_set != operand_value_set) {
770 value_set = operand_value_set;
771 changed = true;
772 }
773 }
774 return changed;
775 }
776
777 bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
778 CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
779 const CallGraphNode& call_graph_node =
780 call_graph_->GetNode(parameter->parent());
781
782 // Subcomputations called in a parallel context (eg, map) do not have dataflow
783 // from the caller operands.
784 if (call_graph_node.context() == CallContext::kParallel ||
785 call_graph_node.caller_callsites().empty()) {
786 return false;
787 }
788 CHECK_EQ(call_graph_node.context(), CallContext::kSequential);
789
790 std::vector<const InstructionValueSet*> inputs;
791 bool need_phi = false;
792 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
793 if (callsite.instruction()->opcode() == HloOpcode::kCall) {
794 // The operand values of a call instruction are forwarded to the
795 // respective parameter instruction of the subcomputation.
796 inputs.push_back(&GetInstructionValueSet(
797 callsite.instruction()->operand(parameter->parameter_number())));
798 } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
799 // In a while instruction, the while operand (ie, the init value) and the
800 // backedge are dataflow inputs to the parameter instruction. This is the
801 // case for parameters of both the body and condition computations.
802 CHECK_EQ(parameter->parameter_number(), 0);
803 inputs.push_back(
804 &GetInstructionValueSet(callsite.instruction()->operand(0)));
805       // If the parameter *is not* the root, the parameter's state will be
806       // updated by the root; otherwise, don't consider its current state
807       // (InstructionValueSet), as we are recomputing it.
808 if (parameter !=
809 callsite.instruction()->while_body()->root_instruction()) {
810 inputs.push_back(&GetInstructionValueSet(
811 callsite.instruction()->while_body()->root_instruction()));
812 }
813 need_phi = true;
814 } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
815 CHECK_EQ(parameter->parameter_number(), 0);
816 auto conditional = callsite.instruction();
817 // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
818 // operands 1 and onward are the arguments to the branch computations.
819 //
820 // If the parameter belongs to conditional's branch 0 computation, then
821 // operand 1 is forwarded to this parameter instruction. If the parameter
822 // belongs to conditional's branch 5 computation, then operand 6 is
823 // forwarded to this parameter instruction.
824 bool found_parent = false;
825 for (int j = 0; j < conditional->branch_count(); ++j) {
826 if (parameter->parent() == conditional->branch_computation(j)) {
827 inputs.push_back(
828 &GetInstructionValueSet(conditional->operand(j + 1)));
829 found_parent = true;
830 break;
831 }
832 }
833 CHECK(found_parent);
834 need_phi = true;
835 } else {
836 LOG(FATAL) << "CallContext::kSequential computations should only be "
837 "called from call, while, or conditional instructions";
838 }
839 }
840 if (ssa_form_ && need_phi) {
841 return Phi(parameter, inputs);
842 } else {
843 return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
844 }
845 }
846
847 bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) {
848 CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect);
849 // A phi value is not defined at a kTupleSelect instruction because
850   // kTupleSelect does not create a new value; rather, it forwards a value from
851   // its operands. This contrasts with the kWhile instruction, which does define
852   // a phi value and has in-place update semantics.
853 bool changed = false;
854 for (auto& pair : GetInstructionValueSet(select)) {
855 const ShapeIndex& index = pair.first;
856 if (index.empty()) {
857 // kTupleSelect copies (not forwards) the top-level value.
858 continue;
859 }
860 HloValueSet& value_set = pair.second;
861 changed |=
862 value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
863 &GetValueSet(select->operand(2), index)});
864 }
865 return changed;
866 }
867
868 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
869 CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
870 bool changed = false;
871 for (int64_t i = 0; i < tuple->operands().size(); ++i) {
872 // Copy the value set(s) of each operand into the respective position in the
873 // kTuple instruction's value sets.
874 for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
875 const ShapeIndex& operand_index = pair.first;
876 HloValueSet& operand_value_set = pair.second;
877
878 ShapeIndex index = {i};
879 for (int64_t op_index : operand_index) {
880 index.push_back(op_index);
881 }
882 HloValueSet& value_set = GetValueSet(tuple, index);
883
884 if (value_set != operand_value_set) {
885 value_set = operand_value_set;
886 changed = true;
887 }
888 }
889 }
890 return changed;
891 }
892
893 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
894 CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
895 const InstructionValueSet* const inputs[] = {
896 &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
897 &GetInstructionValueSet(xla_while->operand(0))};
898 if (ssa_form_) {
899 return Phi(xla_while, inputs);
900 } else {
901 return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
902 }
903 }
904
905 bool HloDataflowAnalysis::UpdateAllGatherStartValueSet(
906 HloInstruction* all_gather_start) {
907 CHECK_EQ(all_gather_start->opcode(), HloOpcode::kAllGatherStart);
908 bool changed = false;
909 // AllGatherStart forwards the operand values to element {0} of its output.
910 for (int64_t i = 0; i < all_gather_start->operand_count(); ++i) {
911 const HloValueSet& operand_value_set =
912 GetValueSet(all_gather_start->operand(i));
913
914 ShapeIndex output_index = {0};
915 if (all_gather_start->operand_count() > 1) {
916 output_index.push_back(i);
917 }
918
919 HloValueSet& value_set = GetValueSet(all_gather_start, output_index);
920 if (value_set != operand_value_set) {
921 value_set = operand_value_set;
922 changed = true;
923 }
924 }
925 return changed;
926 }
927
928 bool HloDataflowAnalysis::UpdateAllGatherDoneValueSet(
929 HloInstruction* all_gather_done) {
930 CHECK_EQ(all_gather_done->opcode(), HloOpcode::kAllGatherDone);
931 bool changed = false;
932 // AllGatherDone forwards the operand value at {1} to its output.
933 for (auto& pair : GetInstructionValueSet(all_gather_done)) {
934 const ShapeIndex& output_index = pair.first;
935 HloValueSet& value_set = pair.second;
936
937 ShapeIndex operand_index = {1};
938 for (int64_t i : output_index) {
939 operand_index.push_back(i);
940 }
941
942 const HloValueSet& operand_value_set =
943 GetValueSet(all_gather_done->operand(0), operand_index);
944 if (value_set != operand_value_set) {
945 value_set = operand_value_set;
946 changed = true;
947 }
948 }
949 return changed;
950 }
951
952 bool HloDataflowAnalysis::UpdateAllReduceStartValueSet(
953 HloInstruction* all_reduce_start) {
954 CHECK_EQ(all_reduce_start->opcode(), HloOpcode::kAllReduceStart);
955 bool changed = false;
956 // AllReduceStart forwards the operand values to element {0} of its output.
957 for (int64_t i = 0; i < all_reduce_start->operand_count(); ++i) {
958 const HloValueSet& operand_value_set =
959 GetValueSet(all_reduce_start->operand(i));
960
961 ShapeIndex output_index = {0};
962 if (all_reduce_start->operand_count() > 1) {
963 output_index.push_back(i);
964 }
965
966 HloValueSet& value_set = GetValueSet(all_reduce_start, output_index);
967 if (value_set != operand_value_set) {
968 value_set = operand_value_set;
969 changed = true;
970 }
971 }
972 return changed;
973 }
974
975 bool HloDataflowAnalysis::UpdateAllReduceDoneValueSet(
976 HloInstruction* all_reduce_done) {
977 CHECK_EQ(all_reduce_done->opcode(), HloOpcode::kAllReduceDone);
978 bool changed = false;
979 // AllReduceDone forwards the operand value at {1} to its output.
980 for (auto& pair : GetInstructionValueSet(all_reduce_done)) {
981 const ShapeIndex& output_index = pair.first;
982 HloValueSet& value_set = pair.second;
983
984 ShapeIndex operand_index = {1};
985 for (int64_t i : output_index) {
986 operand_index.push_back(i);
987 }
988
989 const HloValueSet& operand_value_set =
990 GetValueSet(all_reduce_done->operand(0), operand_index);
991 if (value_set != operand_value_set) {
992 value_set = operand_value_set;
993 changed = true;
994 }
995 }
996 return changed;
997 }
998
999 bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet(
1000 HloInstruction* collective_permute_start) {
1001 CHECK_EQ(collective_permute_start->opcode(),
1002 HloOpcode::kCollectivePermuteStart);
1003 bool changed = false;
1004 // CollectivePermuteStart forwards the operand value to element {0} of its
1005 // output.
1006 if (collective_permute_start->operand(0)->shape().IsTuple()) {
1007 for (int i = 0; i < ShapeUtil::TupleElementCount(
1008 collective_permute_start->operand(0)->shape());
1009 ++i) {
1010 const HloValueSet& operand_value_set =
1011 GetValueSet(collective_permute_start->operand(0), {i});
1012 HloValueSet& value_set = GetValueSet(collective_permute_start, {0, i});
1013 if (value_set != operand_value_set) {
1014 value_set = operand_value_set;
1015 changed = true;
1016 }
1017 }
1018 } else {
1019 const HloValueSet& operand_value_set =
1020 GetValueSet(collective_permute_start->operand(0));
1021 HloValueSet& value_set = GetValueSet(collective_permute_start, {0});
1022 if (value_set != operand_value_set) {
1023 value_set = operand_value_set;
1024 changed = true;
1025 }
1026 }
1027 return changed;
1028 }
1029
1030 bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet(
1031 HloInstruction* collective_permute_done) {
1032 CHECK_EQ(collective_permute_done->opcode(),
1033 HloOpcode::kCollectivePermuteDone);
1034 bool changed = false;
1035 // CollectivePermuteDone forwards the operand value at {1} to its output.
1036 if (collective_permute_done->shape().IsTuple()) {
1037 for (int i = 0;
1038 i < ShapeUtil::TupleElementCount(collective_permute_done->shape());
1039 ++i) {
1040 const HloValueSet& operand_value_set =
1041 GetValueSet(collective_permute_done->operand(0), {1, i});
1042 HloValueSet& value_set = GetValueSet(collective_permute_done, {i});
1043 if (value_set != operand_value_set) {
1044 value_set = operand_value_set;
1045 changed = true;
1046 }
1047 }
1048 } else {
1049 const HloValueSet& operand_value_set =
1050 GetValueSet(collective_permute_done->operand(0), {1});
1051 HloValueSet& value_set = GetValueSet(collective_permute_done);
1052 if (value_set != operand_value_set) {
1053 value_set = operand_value_set;
1054 changed = true;
1055 }
1056 }
1057 return changed;
1058 }
1059
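// Recomputes the value set of `instruction` from its operands by dispatching
// to the opcode-specific handlers above. Opcodes without a handler define all
// of the values in their output, so nothing needs to be forwarded.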
1060 bool HloDataflowAnalysis::UpdateInstructionValueSet(
1061 HloInstruction* instruction) {
1062 // Recompute from operands.
1063 switch (instruction->opcode()) {
1064 case HloOpcode::kAddDependency:
1065 return UpdateAddDependencyValueSet(instruction);
1066 case HloOpcode::kAllGatherStart:
1067 return UpdateAllGatherStartValueSet(instruction);
1068 case HloOpcode::kAllGatherDone:
1069 return UpdateAllGatherDoneValueSet(instruction);
1070 case HloOpcode::kBitcast:
1071 return UpdateBitcastValueSet(instruction);
1072 case HloOpcode::kCustomCall:
1073 return UpdateCustomCallValueSet(instruction);
1074 case HloOpcode::kSetDimensionSize:
1075 return UpdateSetDimensionSizeValueSet(instruction);
1076 case HloOpcode::kDomain:
1077 return UpdateDomainValueSet(instruction);
1078 case HloOpcode::kCopy:
1079 return UpdateCopyValueSet(instruction);
1080 case HloOpcode::kGetTupleElement:
1081 return UpdateGetTupleElementValueSet(instruction);
1082 case HloOpcode::kTupleSelect:
1083 return UpdateTupleSelectValueSet(instruction);
1084 case HloOpcode::kTuple:
1085 return UpdateTupleValueSet(instruction);
1086 case HloOpcode::kParameter:
1087 return UpdateParameterValueSet(instruction);
1088 case HloOpcode::kCall:
1089 return UpdateCallValueSet(instruction);
1090 case HloOpcode::kWhile:
1091 return UpdateWhileValueSet(instruction);
1092 case HloOpcode::kSend:
1093 return UpdateSendValueSet(instruction);
1094 case HloOpcode::kRecvDone:
1095 return UpdateRecvDoneValueSet(instruction);
1096 case HloOpcode::kCopyStart:
1097 return UpdateCopyStartValueSet(instruction);
1098 case HloOpcode::kCopyDone:
1099 return UpdateCopyDoneValueSet(instruction);
1100 case HloOpcode::kConditional:
1101 return UpdateConditionalValueSet(instruction);
1102 case HloOpcode::kAllReduceStart:
1103 return UpdateAllReduceStartValueSet(instruction);
1104 case HloOpcode::kAllReduceDone:
1105 return UpdateAllReduceDoneValueSet(instruction);
1106 case HloOpcode::kCollectivePermuteStart:
1107 return UpdateCollectivePermuteStartValueSet(instruction);
1108 case HloOpcode::kCollectivePermuteDone:
1109 return UpdateCollectivePermuteDoneValueSet(instruction);
1110 default:
1111 // Instruction does not forward HloValues (it defines all values in its
1112 // output). No update is necessary.
1113 return false;
1114 }
1115 }
1116
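// Worklist-driven fixed-point iteration. Work items are ordered by post-order
// position so that values tend to flow from producers to consumers in a single
// pass; whenever an instruction's value set changes, its users, the parameters
// of sequentially called computations, and (for roots) its callers are
// re-queued.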
1117 void HloDataflowAnalysis::Propagate() {
1118 using Work = std::pair<int64, HloInstruction*>;
1119 // Avoid duplicating work by preferring work items early in the post order
1120   // schedule. Intuitively, we start from entry parameters and propagate buffer
1121   // updates throughout the module only once.
1122 std::priority_queue<Work, std::vector<Work>, std::greater<Work>> worklist;
1123 absl::flat_hash_set<HloInstruction*> workset;
1124 auto priority_map = CalculatePostOrderSchedule(module_);
1125 auto add_to_worklist = [&priority_map, &worklist,
1126 &workset](HloInstruction* instruction) {
1127 if (workset.insert(instruction).second) {
1128 worklist.emplace(priority_map[instruction], instruction);
1129 }
1130 };
1131
1132 auto comps = module_.MakeComputationPostOrder();
1133 for (HloComputation* computation : comps) {
1134 for (HloInstruction* instruction :
1135 computation->MakeInstructionPostOrder()) {
1136 add_to_worklist(instruction);
1137 }
1138 }
1139 VLOG(1) << "SSA_FORM_: " << ssa_form_;
1140
1141 while (!worklist.empty()) {
1142 HloInstruction* instruction = worklist.top().second;
1143 auto add_to_worklist = [&](HloInstruction* todo) {
1144 if (workset.insert(todo).second) {
1145 VLOG(1) << " Adding todo : " << todo->name();
1146 worklist.emplace(priority_map[todo], todo);
1147 }
1148 };
1149 worklist.pop();
1150
1151 workset.erase(workset.find(instruction));
1152
1153 VLOG(3) << "Worklist top: " << instruction->name();
1154 VLOG(3) << ToString();
1155
1156 if (!UpdateInstructionValueSet(instruction)) {
1157 // No change to the instruction's value set.
1158 VLOG(4) << "No change.";
1159 continue;
1160 }
1161
1162 VLOG(4) << "New value set for " << instruction->name() << ": "
1163 << GetInstructionValueSet(instruction);
1164
1165 // Instruction value was updated. Add users to work list if we haven't
1166 // already.
1167 for (HloInstruction* user : instruction->users()) {
1168 add_to_worklist(user);
1169
1170 // If user sequentially calls a computation, then the respective
1171 // parameter(s) of the computation need to be updated.
1172 if (user->opcode() == HloOpcode::kConditional) {
1173         // If operand 0 is the use of the instruction, then no parameters need
1174         // to be updated, since that is the branch_index of the conditional.
1175         // If operand n+1 is the use of the instruction, then branch_computation
1176         // n's parameter needs to be updated.
1177 //
1178 // Note that the same instruction can be used in multiple branches'
1179 // operands.
1180 for (int j = 0; j < user->branch_count(); ++j) {
1181 if (user->operand(j + 1) == instruction) {
1182 add_to_worklist(
1183 user->branch_computation(j)->parameter_instruction(0));
1184 }
1185 }
1186 } else {
1187 for (HloComputation* called_computation : user->called_computations()) {
1188 const CallGraphNode& call_graph_node =
1189 call_graph_->GetNode(called_computation);
1190 if (call_graph_node.context() == CallContext::kSequential) {
1191 for (int64_t operand_number : user->OperandIndices(instruction)) {
1192 add_to_worklist(
1193 called_computation->parameter_instruction(operand_number));
1194 }
1195 }
1196 }
1197 }
1198 }
1199
1200 // If instruction is a root instruction, then propagate out to any calling
1201 // instruction and across any while backedge.
1202 if (instruction == instruction->parent()->root_instruction()) {
1203 const CallGraphNode& call_graph_node =
1204 call_graph_->GetNode(instruction->parent());
1205 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
1206 if (callsite.instruction()->opcode() == HloOpcode::kCall ||
1207 callsite.instruction()->opcode() == HloOpcode::kConditional) {
1208 add_to_worklist(callsite.instruction());
1209 } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
1210 // Add the while itself, and the body and condition parameters.
1211 add_to_worklist(callsite.instruction());
1212 add_to_worklist(
1213 callsite.instruction()->while_body()->parameter_instruction(0));
1214 add_to_worklist(
1215 callsite.instruction()->while_condition()->parameter_instruction(
1216 0));
1217 }
1218 }
1219 }
1220 }
1221 }
1222
1223 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
1224 const HloInstruction* instruction) const {
1225 return value_sets_.at(instruction);
1226 }
1227
1228 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
1229 const HloInstruction* instruction) {
1230 return value_sets_.at(instruction);
1231 }
1232
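// Seeds every instruction with an empty InstructionValueSet and creates an
// HloValue at each shape index where the instruction defines a value (as
// opposed to forwarding a value produced by an operand or a called
// computation).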
1233 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
1234 for (const HloComputation* computation : module_.computations()) {
1235 const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
1236 for (HloInstruction* instruction : computation->instructions()) {
1237 // Create an empty shape tree.
1238 value_sets_.emplace(std::piecewise_construct,
1239 std::forward_as_tuple(instruction),
1240 std::forward_as_tuple(instruction->shape()));
1241
1242 // For each sub-shape of the instruction shape, add a new HloValue to its
1243 // HloValueSet.
1244 auto define_all_values = [this, &instruction]() {
1245 for (auto& pair : GetInstructionValueSet(instruction)) {
1246 const ShapeIndex& index = pair.first;
1247 HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
1248 GetValueSet(instruction, index).AddValue(value);
1249 }
1250 };
1251
1252 // Add a new HloValue to the HloValueSet corresponding to the given index
1253 // of the instruction shape.
1254 auto define_value_at = [this, &instruction](const ShapeIndex& index) {
1255 HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
1256 GetValueSet(instruction, index).AddValue(value);
1257 };
1258
1259 switch (instruction->opcode()) {
1260 case HloOpcode::kBitcast:
1261 if (bitcast_defines_value_) {
1262 define_all_values();
1263 }
1264 break;
1265 case HloOpcode::kSetDimensionSize:
1266 case HloOpcode::kAddDependency:
1267 case HloOpcode::kWhile:
1268 case HloOpcode::kCall:
1269 case HloOpcode::kConditional:
1270 case HloOpcode::kGetTupleElement:
1271 case HloOpcode::kDomain:
1272 // These instructions define no values. The values in their output
1273 // flow from their operands or from cross computation dataflow.
1274 break;
1275 case HloOpcode::kParameter:
1276 if (call_graph_node.context() == CallContext::kBoth) {
1277 // We do not support a subcomputation that is called from both a
1278 // parallel and sequential context. In this case, the parameter
1279 // would both define a value and propagate a value from its
1280 // caller. This limitation is not really a problem because the call
1281 // graph is typically flattened.
1282 return Unimplemented(
1283 "Computation %s is called in both a parallel (eg, kMap) and "
1284 "sequential (eg, kCall) context",
1285 computation->name());
1286 }
1287 if (call_graph_node.caller_callsites().empty() ||
1288 call_graph_node.context() == CallContext::kParallel) {
1289 // Parameters of computations called in a parallel context (eg, map
1290 // and reduce) as well as parameters of dead computations define all
1291 // values in their output. Otherwise the values of the parameter
1292 // come from the caller (eg, operands to the kCall instruction).
1293 define_all_values();
1294 }
1295 break;
1296 case HloOpcode::kCopy:
1297 case HloOpcode::kTupleSelect:
1298 case HloOpcode::kTuple:
1299 // These instructions only define their top-level values. Any other
1300 // values flow from their operands.
1301 define_value_at(/*index=*/{});
1302 break;
1303 case HloOpcode::kCopyStart:
1304 // CopyStart produces a tuple of {destination buffer, aliased operand,
1305 // U32 context}.
1306 define_value_at(/*index=*/{});
1307 define_value_at(/*index=*/{0});
1308 define_value_at(/*index=*/{2});
1309 break;
1310 case HloOpcode::kCopyDone:
1311 // CopyDone consumes a tuple produced by CopyStart and produces an
1312 // element. Its output aliases its input tuple element {0}.
1313 break;
1314 case HloOpcode::kAllGatherStart:
1315 // AllGatherStart produces a tuple of
1316 // {aliased operand, destination buffer}.
1317 define_value_at(/*index=*/{});
1318 define_value_at(/*index=*/{1});
1319 break;
1320 case HloOpcode::kAllGatherDone:
1321 // AllGatherDone's output aliases its input tuple element {1}.
1322 if (instruction->shape().IsTuple()) {
1323 define_value_at(/*index=*/{});
1324 }
1325 break;
1326 case HloOpcode::kAllReduceStart:
1327 // AllReduceStart produces a tuple of
1328 // {aliased operand, destination buffer}.
1329 define_value_at(/*index=*/{});
1330 define_value_at(/*index=*/{1});
1331 if (instruction->operand_count() > 1) {
1332 for (int64_t i = 0; i < instruction->operand_count(); ++i) {
1333 define_value_at(/*index=*/{1, i});
1334 }
1335 }
1336 break;
1337 case HloOpcode::kAllReduceDone:
1338 // AllReduceDone's output aliases its input tuple element {1}.
1339 break;
1340 case HloOpcode::kCollectivePermuteStart:
1341 // CollectivePermuteStart produces a tuple of
1342 // {aliased operand, destination buffer, U32 context, U32 context}.
1343 define_value_at(/*index=*/{});
1344 define_value_at(/*index=*/{1});
1345 define_value_at(/*index=*/{2});
1346 define_value_at(/*index=*/{3});
1347 if (instruction->operand_count() > 1) {
1348 CHECK_EQ(instruction->operand_count(), 4);
1349 if (instruction->operand(1)->shape().IsTuple()) {
1350 for (int i = 0; i < ShapeUtil::TupleElementCount(
1351 instruction->operand(1)->shape());
1352 ++i) {
1353 define_value_at(/*index=*/{1, i});
1354 }
1355 }
1356 define_value_at(/*index=*/{4});
1357 }
1358 break;
1359 case HloOpcode::kCollectivePermuteDone:
1360 // CollectivePermuteDone's output aliases its input tuple element {1}.
1361 if (instruction->shape().IsTuple()) {
1362 define_value_at(/*index=*/{});
1363 }
1364 break;
1365 case HloOpcode::kRecvDone:
1366 // RecvDone produces a two-element tuple. Element zero aliases its
1367 // input tuple element {0}; element one is a token.
1368 define_value_at(/*index=*/{});
1369 define_value_at(/*index=*/{1});
1370 break;
1371 case HloOpcode::kSend:
1372 // Send produces a tuple of {aliased operand, U32 context, token},
1373 // therefore only defines the top-level tuple and the tuple elements
1374 // at {1} and {2}.
1375 define_value_at(/*index=*/{});
1376 define_value_at(/*index=*/{1});
1377 define_value_at(/*index=*/{2});
1378 break;
1379 case HloOpcode::kCustomCall: {
1380 absl::flat_hash_set<ShapeIndex> aliasing_indices;
1381 for (const auto& aliasing :
1382 Cast<HloCustomCallInstruction>(instruction)
1383 ->output_to_operand_aliasing()) {
1384 aliasing_indices.insert(aliasing.first);
1385 }
1386 ShapeUtil::ForEachSubshape(
1387 instruction->shape(),
1388 [&](const Shape& /*subshape*/, const ShapeIndex& index) {
1389 if (!aliasing_indices.contains(index)) {
1390 define_value_at(index);
1391 }
1392 });
1393 break;
1394 }
1395 default:
1396 define_all_values();
1397 break;
1398 }
1399 }
1400 }
1401
1402 return Status::OK();
1403 }
1404
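// Runs the phi-graph optimization and, for any phi value that the graph
// resolves to a different value, rewrites the corresponding value set and
// marks the obsolete phi for deletion.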
1405 void HloDataflowAnalysis::OptimizePhiValues() {
1406 // Only applicable to SSA form where phis are defined.
1407 if (!ssa_form_) {
1408 return;
1409 }
1410
1411 VLOG(1) << "Before phi graph optimization";
1412 XLA_VLOG_LINES(1, phi_graph_.ToString());
1413 phi_graph_.Optimize();
1414 VLOG(1) << "After phi graph optimization";
1415 XLA_VLOG_LINES(1, phi_graph_.ToString());
1416
1417 for (const HloComputation* computation : module_.computations()) {
1418 for (HloInstruction* instruction : computation->instructions()) {
1419 InstructionValueSet& instruction_value_set =
1420 GetInstructionValueSet(instruction);
1421 VLOG(1) << "inst: " << instruction->name();
1422 VLOG(1) << instruction_value_set.ToString();
1423 instruction_value_set.ForEachMutableElement(
1424 [&](const xla::ShapeIndex& index, HloValueSet* value_set) {
1425 auto values = value_set->values();
1426 if (!(values.size() == 1 && values[0]->is_phi())) {
1427 return;
1428 }
1429 HloValue::Id phi_id = values[0]->id();
1430 HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id);
1431 if (new_id != phi_id) {
1432 VLOG(1) << "Replacing " << values[0]->ToString() << " with "
1433 << GetValue(new_id).ToString();
1434 value_set->Clear();
1435 const HloValue& new_value = GetValue(new_id);
1436 value_set->AddValue(&new_value);
1437 MarkValueForDeletion(phi_id);
1438 }
1439 });
1440 }
1441 }
1442 }
1443
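// Runs the analysis end to end: initialize per-instruction value sets,
// propagate them to a fixed point, optimize phi values, delete values marked
// for deletion, and record every non-defining position of each remaining
// value before verifying the result.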
/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
    const HloModule& module, bool ssa_form, bool bitcast_defines_value,
    const CanShareBuffer& can_share_buffer) {
  VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
  XLA_VLOG_LINES(2, module.ToString());

  auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
      module, ssa_form, bitcast_defines_value, can_share_buffer));

  TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
  dataflow_analysis->Propagate();
  dataflow_analysis->OptimizePhiValues();

  // Delete all values marked for deletion.
  dataflow_analysis->DeleteMarkedValues();

  // Gather and set all non-definition positions of all values. Value deletion
  // is rare, so just use a vector indexed by Value::Id rather than a map from
  // Value::Id to positions. There should be very few holes in the vector, and
  // lookup is faster.
  std::vector<std::vector<HloPosition>> value_positions(
      dataflow_analysis->next_value_id_);
  for (const HloComputation* computation : module.computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      for (const auto& pair :
           dataflow_analysis->GetInstructionValueSet(instruction)) {
        const ShapeIndex& index = pair.first;
        const HloValueSet& value_set = pair.second;
        for (const HloValue* value : value_set.values()) {
          if (value->defining_instruction() != instruction) {
            value_positions[value->id()].push_back(
                HloPosition{instruction, index});
          }
        }
      }
    }
  }
  for (auto& pair : dataflow_analysis->values_) {
    HloValue::Id value_id = pair.first;
    HloValue& value = pair.second;
    value.SetPositionsAndComputeUses(value_positions[value_id]);
  }

  // Construct vector of values.
  dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
  for (auto& pair : dataflow_analysis->values_) {
    dataflow_analysis->values_vector_.push_back(&pair.second);
  }
  absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);

  TF_DCHECK_OK(dataflow_analysis->Verify());

  XLA_VLOG_LINES(1, dataflow_analysis->ToString());

  return std::move(dataflow_analysis);
}

Status HloDataflowAnalysis::Verify() const {
  // Verify each HloValue appears in the value sets that the value's positions()
  // indicate.
  for (const HloValue* value : values()) {
    for (const HloPosition& position : value->positions()) {
      const HloValueSet& value_set = GetValueSet(position);
      TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
          << "Value set at position " << position << " does not contain value "
          << value->ToShortString();
    }
  }

  // For each value in each value set, verify that the value set's position
  // appears in the value's positions().
  for (const auto& computation : module_.computations()) {
    for (const auto& instruction : computation->instructions()) {
      for (const auto& pair : GetInstructionValueSet(instruction)) {
        const ShapeIndex& index = pair.first;
        const HloValueSet& value_set = pair.second;
        const HloPosition position{instruction, index};
        for (const HloValue* value : value_set.values()) {
          TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
              << "Value set at position " << position
              << " unexpectedly contains value " << value->ToShortString();
        }
      }
    }
  }

  return Status::OK();
}

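// Returns true if no HloValue defined at 'operand' / 'index' is read by
// 'user'. For loop fusions, the operand is considered unused if the value
// defined at the corresponding fused parameter has no uses inside the fused
// computation.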
bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
    const HloInstruction* operand, const ShapeIndex& index,
    const HloInstruction* user) const {
  // Return true if no value at 'operand' and 'index' is used at 'user'.
  for (const HloValue* value : GetValueSet(operand, index).values()) {
    for (const HloUse& use : value->uses()) {
      if (use.instruction == user) {
        if (user->IsLoopFusion()) {
          HloInstruction* fusion_param =
              user->fused_parameter(use.operand_number);
          const HloValue& value =
              GetValueDefinedAt(fusion_param, use.operand_index);
          return value.uses().empty();
        }
        return false;
      }
    }
  }
  return true;
}

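// Opcodes that update their first operand in place; GetInPlaceInputOutputPairs
// pairs that operand with the instruction's output buffer.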
/*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) {
  return opcode == HloOpcode::kDynamicUpdateSlice ||
         opcode == HloOpcode::kScatter;
}

/*static*/ bool HloDataflowAnalysis::IsAsynchronousOperationStart(
    HloOpcode opcode) {
  return opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv ||
         opcode == HloOpcode::kCopyStart ||
         opcode == HloOpcode::kAllReduceStart ||
         opcode == HloOpcode::kAllGatherStart ||
         opcode == HloOpcode::kCollectivePermuteStart;
}

/*static*/ bool HloDataflowAnalysis::IsAsynchronousOperationDone(
    HloOpcode opcode) {
  return opcode == HloOpcode::kSendDone || opcode == HloOpcode::kRecvDone ||
         opcode == HloOpcode::kCopyDone ||
         opcode == HloOpcode::kAllReduceDone ||
         opcode == HloOpcode::kAllGatherDone ||
         opcode == HloOpcode::kCollectivePermuteDone;
}

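// Returns the (operand use, output shape index) pairs that must share a
// buffer. For example, dynamic-update-slice pairs its operand 0 with its root
// output {}, a 4-operand collective-permute(-start) pairs operand 1 with the
// corresponding output element(s), and fusions are walked from the fused root
// to find in-place ops that are fed directly by a fusion parameter (possibly
// through get-tuple-element).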
/*static*/ std::vector<std::pair<HloUse, ShapeIndex>>
HloDataflowAnalysis::GetInPlaceInputOutputPairs(HloInstruction* instruction) {
  if (IsInPlaceOperation(instruction->opcode())) {
    return {{HloUse{instruction, 0, {}}, {}}};
  } else if (instruction->opcode() == HloOpcode::kCollectivePermute &&
             instruction->operands().size() == 4) {
    if (instruction->operand(1)->shape().IsTuple()) {
      std::vector<std::pair<HloUse, ShapeIndex>> in_place_pairs(
          {{HloUse{instruction, 1, {}}, {}}});
      for (int i = 0; i < instruction->operand(1)->shape().tuple_shapes_size();
           i++) {
        in_place_pairs.push_back({HloUse{instruction, 1, {i}}, {i}});
      }
      return in_place_pairs;
    } else {
      return {{HloUse{instruction, 1, {}}, {}}};
    }
  } else if (instruction->opcode() == HloOpcode::kCollectivePermuteStart &&
             instruction->operands().size() == 4) {
    if (instruction->operand(1)->shape().IsTuple()) {
      std::vector<std::pair<HloUse, ShapeIndex>> in_place_pairs(
          {{HloUse{instruction, 1, {}}, {1}}});
      for (int i = 0; i < instruction->operand(1)->shape().tuple_shapes_size();
           i++) {
        in_place_pairs.push_back({HloUse{instruction, 1, {i}}, {1, i}});
      }
      return in_place_pairs;
    } else {
      return {{HloUse{instruction, 1, {}}, {1}}};
    }
  } else if (instruction->opcode() != HloOpcode::kFusion) {
    return {};
  }

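  // For fusions: walk each leaf of the output shape down from the fused root
  // through kTuple instructions to the HLO producing that leaf. If that HLO is
  // an in-place op whose target operand is (a get-tuple-element chain on) a
  // fusion parameter, pair the corresponding fusion operand with the output
  // leaf.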
  std::vector<std::pair<HloUse, ShapeIndex>> input_output_pairs;
  for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) {
    const HloInstruction* hlo_generating_output =
        instruction->fused_expression_root();
    for (int64_t i = 0; i < indexed_shape.index.size(); ++i) {
      if (hlo_generating_output->opcode() == HloOpcode::kTuple) {
        hlo_generating_output =
            hlo_generating_output->operand(indexed_shape.index[i]);
      } else {
        CHECK_EQ(i, indexed_shape.index.size() - 1);
      }
    }

    if (IsInPlaceOperation(hlo_generating_output->opcode())) {
      ShapeIndex operand_index;
      const HloInstruction* fusion_parameter =
          hlo_generating_output->operand(0);
      while (fusion_parameter->opcode() == HloOpcode::kGetTupleElement) {
        operand_index.push_front(fusion_parameter->tuple_index());
        fusion_parameter = fusion_parameter->operand(0);
      }

      if (fusion_parameter->opcode() == HloOpcode::kParameter) {
        input_output_pairs.emplace_back(
            HloUse{instruction, fusion_parameter->parameter_number(),
                   operand_index},
            indexed_shape.index);
      }
    }
  }
  return input_output_pairs;
}

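// Returns true if 'user' can safely write its result at 'user_index' into the
// buffer that holds 'operand' at 'operand_index'. The checks below proceed
// from general to specific: slice-input fusions, shape/layout equality,
// must-alias in-place pairs, an optional backend-provided hint
// (can_share_buffer_), opcode-specific rules, and a final elementwise-use
// test.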
bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
    HloInstruction* operand, const ShapeIndex& operand_index,
    HloInstruction* user, const ShapeIndex& user_index) const {
  CHECK(user->IsUserOf(operand))
      << "user: " << user->ToString() << " operand: " << operand->ToString();
  if (operand->opcode() == HloOpcode::kConstant) {
    return false;
  }

  const Shape& operand_subshape =
      ShapeUtil::GetSubshape(operand->shape(), operand_index);
  const Shape& user_subshape =
      ShapeUtil::GetSubshape(user->shape(), user_index);
  if (IsSliceInputFusion(*user)) {
    HloInstruction* fusion_param =
        user->fused_parameter(user->operand_index(operand));
    // We don't require the same dimensions, only the same element count and
    // element type, to ensure the buffer sizes match.
    return operand_subshape.IsArray() && user_subshape.IsArray() &&
           ShapeUtil::ElementsIn(operand_subshape) ==
               ShapeUtil::ElementsIn(user_subshape) &&
           ShapeUtil::SameElementType(operand_subshape, user_subshape) &&
           AreTransitiveUsesEffectivelyElementwise(
               fusion_param, user->fused_expression_root(), user_index);
  }

  // Check that operand and user emit the same shape and layout.
  if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
    return false;
  }

  // The must-alias relationship holds for in-place operations (DUS and DUS
  // fusions) regardless of the backend.
  for (const auto& operand_and_output_index :
       GetInPlaceInputOutputPairs(user)) {
    if (operand_and_output_index.second != user_index) {
      continue;
    }
    for (const HloUse& use : GetUniqueValueAt(operand, operand_index).uses()) {
      if (use == operand_and_output_index.first) {
        return true;
      }
    }
  }

  if (can_share_buffer_ != nullptr) {
    if (absl::optional<bool> hint =
            can_share_buffer_(user, operand, user_index)) {
      return *hint;
    }
  }

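  // Fusion users: loop and input fusions may share the buffer when every
  // transitive use of the fused parameter is elementwise (or a tuple); output
  // fusions with an add root may share the buffer feeding the add's
  // non-dot/convolution operand, provided that is the parameter's only use.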
  if (user->opcode() == HloOpcode::kFusion) {
    HloInstruction* fusion_param =
        user->fused_parameter(user->operand_index(operand));
    const HloValue& fusion_param_value =
        GetValueDefinedAt(fusion_param, operand_index);

    if (user->IsLoopFusion() || user->IsInputFusion()) {
      return AreTransitiveUsesElementwiseOrTuple(fusion_param);
    }

    if (user->IsOutputFusion() &&
        user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
      // Output fusion with kAdd fused root.

      // Check if one operand of kAdd fused root is kDot or kConvolution.
      auto* add = user->fused_expression_root();
      auto add_operand_it =
          absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
            return operand->opcode() == HloOpcode::kConvolution ||
                   operand->opcode() == HloOpcode::kDot;
          });
      if (add_operand_it == add->operands().end()) {
        return false;
      }
      auto* matched_add_operand = *add_operand_it;
      // Calculate operand index of 'add' operand which was not matched above.
      const int64_t other_add_operand_index =
          matched_add_operand == add->operand(0) ? 1 : 0;
      // Returns true iff there is exactly one use of 'operand' at shape index
      // 'operand_index', and this singleton use is the fused root (at operand
      // index 'other_add_operand_index').
      if (fusion_param_value.uses().size() == 1) {
        const HloUse& use = fusion_param_value.uses()[0];
        return use.instruction == user->fused_expression_root() &&
               use.operand_number == other_add_operand_index;
      }
      return false;
    }
  }

  if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
      user->opcode() == HloOpcode::kScatter ||
      user->opcode() == HloOpcode::kTriangularSolve ||
      user->opcode() == HloOpcode::kWhile) {
    // We eliminated other users in HloOrdering::LiveRangeStrictlyBefore, so
    // here we just need to check that the use is at the right operand index.
    const auto operand_indices = user->OperandIndices(operand);
    int64_t operand_no = user->opcode() == HloOpcode::kTriangularSolve ? 1 : 0;
    return operand_indices.size() == 1 && operand_indices[0] == operand_no;
  }
  if (user->opcode() == HloOpcode::kSort) {
    // Only valid if there are no other users.
    if (operand->users().size() != 1) {
      return false;
    }
    // If we only sort keys, the output of sort is not a tuple, so we can always
    // share the buffer.
    if (user->operand_count() == 1) {
      return true;
    }
    CHECK(!user_index.empty());
    // Only share with the right tuple element buffer.
    const auto operand_indices = user->OperandIndices(operand);
    return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
  }
  if (user->opcode() == HloOpcode::kCall) {
    // Get all uses of value defined by 'operand' at 'operand_index'.
    const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
    // Return true iff:
    // *) There exist exactly two uses of 'operand'.
    // *) One use is by 'user' (caller).
    // *) One use is by the root instruction of the called computation (callee
    //    root). (Note: we check the root of the called computation, because
    //    the root result buffer is required to alias with the Call result
    //    buffer.)
    // *) The root instruction of the called computation is element-wise on
    //    'operand'.
    const bool found_caller_use =
        absl::c_find_if(uses, [user](const HloUse& use) {
          return use.instruction == user;
        }) != uses.end();
    auto* callee_root = user->to_apply()->root_instruction();
    const bool found_elementwise_callee_use =
        absl::c_find_if(uses, [callee_root](const HloUse& use) {
          return use.instruction == callee_root &&
                 callee_root->IsElementwiseOnOperand(use.operand_number);
        }) != uses.end();
    return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
  }

  // Loop fusions that contain transposing copies won't reach here: they have
  // different layouts, which fails the shape-equality check at the beginning
  // of this function.
  return user->IsElementwiseOnOperand(user->operand_index(operand));
}

}  // namespace xla