1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <optional>
21 #include <queue>
22 #include <string>
23 #include <utility>
24 #include <vector>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/container/inlined_vector.h"
30 #include "absl/strings/str_cat.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
36 #include "tensorflow/compiler/xla/service/hlo_module.h"
37 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/platform/logging.h"
44
45 namespace xla {
46 namespace {
// CalculatePostOrderSchedule traverses a module and assigns an ordinal to each
// instruction based on the postorder dependency.
CalculatePostOrderScheduleHelper(const HloComputation * comp,int64_t start_ordinal,absl::flat_hash_map<HloInstruction *,int64_t> * ordinal_map)49 int64_t CalculatePostOrderScheduleHelper(
50 const HloComputation* comp, int64_t start_ordinal,
51 absl::flat_hash_map<HloInstruction*, int64_t>* ordinal_map) {
52 int64_t ordinal = start_ordinal;
53 for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
54 if (instruction->opcode() == HloOpcode::kCall ||
55 instruction->opcode() == HloOpcode::kConditional) {
56 for (const HloComputation* called_computation :
57 instruction->called_computations()) {
58 ordinal = CalculatePostOrderScheduleHelper(called_computation, ordinal,
59 ordinal_map);
60 }
61 }
62 if (instruction->opcode() == HloOpcode::kWhile) {
63 ordinal = CalculatePostOrderScheduleHelper(instruction->while_condition(),
64 ordinal, ordinal_map);
65 ordinal = CalculatePostOrderScheduleHelper(instruction->while_body(),
66 ordinal, ordinal_map);
67 }
68 // It's possible that in some unit tests the computation graph is not
69 // flatten (meaning we could have multiple callers for one computation). In
70 // that case the oridinal_map will see the instruction multiple times. We
71 // consider that case to be ok as it only shows up in unit tests.
72 ordinal_map->insert({instruction, ordinal++});
73 }
74 return ordinal;
75 }
76
CalculatePostOrderSchedule(const HloModule & module)77 absl::flat_hash_map<HloInstruction*, int64_t> CalculatePostOrderSchedule(
78 const HloModule& module) {
79 absl::flat_hash_map<HloInstruction*, int64_t> map;
80 CalculatePostOrderScheduleHelper(module.entry_computation(), 0, &map);
81 return map;
82 }
83
84 } // namespace
85 using absl::StrAppend;
86 using absl::StrCat;
87
// Records the analysis configuration and builds the module's call graph.
// `ssa_form` selects phi-based merging (see Phi()); `bitcast_defines_value`
// controls whether a kBitcast defines a new value or forwards its operand's
// (see UpdateBitcastValueSet()). The analysis itself runs elsewhere.
HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
                                         bool bitcast_defines_value,
                                         const CanShareBuffer& can_share_buffer)
    : module_(module),
      ssa_form_(ssa_form),
      bitcast_defines_value_(bitcast_defines_value),
      // Built eagerly, once, from the module.
      call_graph_(CallGraph::Build(&module)),
      can_share_buffer_(can_share_buffer) {}
96
AreTransitiveUsesElementwiseOrTuple(const HloInstruction * inst)97 bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
98 const HloInstruction* inst) {
99 absl::flat_hash_set<const HloInstruction*> visited;
100 absl::InlinedVector<const HloInstruction*, 4> stack;
101 stack.push_back(inst);
102 while (!stack.empty()) {
103 const HloInstruction* current = stack.back();
104 stack.pop_back();
105 visited.insert(current);
106 for (const HloInstruction* user : current->users()) {
107 // Found a user that is non-elementwise on current instruction.
108 for (const int64_t use_index : user->OperandIndices(current)) {
109 if (!user->IsElementwiseOnOperand(use_index) &&
110 user->opcode() != HloOpcode::kTuple) {
111 return false;
112 }
113 }
114 if (!visited.contains(user)) {
115 stack.push_back(user);
116 }
117 }
118 }
119 return true;
120 }
121
122 namespace {
Is1dSliceWithoutStrides(const HloInstruction * instr)123 bool Is1dSliceWithoutStrides(const HloInstruction* instr) {
124 return instr->opcode() == HloOpcode::kSlice &&
125 1 == instr->slice_starts().size() &&
126 1 == instr->slice_limits().size() &&
127 1 == instr->slice_strides().size() &&
128 1 == instr->slice_strides().at(0);
129 }
130
IsSliceInputFusion(const HloInstruction & unnested_hlo)131 bool IsSliceInputFusion(const HloInstruction& unnested_hlo) {
132 if (!unnested_hlo.IsInputFusion()) {
133 return false;
134 }
135 const HloInstruction* root = unnested_hlo.fused_expression_root();
136 if (root->opcode() != HloOpcode::kTuple) {
137 return false;
138 }
139 return absl::c_all_of(root->operands(), [](const HloInstruction* instr) {
140 return Is1dSliceWithoutStrides(instr);
141 });
142 }
143
// Traversal state used while proving that a concatenate (and the slices that
// undo it) is used in an effectively elementwise manner. Threaded through
// ConcatIsEffectivelyElementwise() so that multiple concats seen during one
// traversal can be checked for consistency.
struct ConcatUsageInfo {
  // Pointer to a previously seen concat. nullptr if no previously seen concat.
  const HloInstruction* prev_concat;
  // The operand index of the previously seen concat through which the
  // traversal entered it.
  int64_t concat_opnd_idx;
  // The slice that recovers the opnd in the concat outputs.
  const HloInstruction* slice_to_recover_opnd;
};
152
// Returns an optional concat usage info to denote whether the concat is used
// in an elementwise manner. A concat followed by slices is considered
// effectively elementwise if the slices together form the inverse of the
// concat, i.e. they exactly recover its operands.
ConcatIsEffectivelyElementwise(const HloInstruction & concat,const HloInstruction & operand,const ConcatUsageInfo & info)156 std::optional<ConcatUsageInfo> ConcatIsEffectivelyElementwise(
157 const HloInstruction& concat, const HloInstruction& operand,
158 const ConcatUsageInfo& info) {
159 // First, check if this concat is in the below pattern. Also, we check
160 // that the slices combinedly are in effect a reverse function of the concat.
161 //
162 // Concat
163 // | |
164 // v v
165 // Slice Slice
166 //
167 std::vector<HloInstruction*> users = concat.users();
168 if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) {
169 // Limit our supported cases to 1 dimensional slices.
170 return std::optional<ConcatUsageInfo>();
171 }
172 // Verify that each operand to the concat is reversed by a slice.
173 if (users.size() != concat.operand_count() ||
174 concat.operand_count() != concat.unique_operands().size()) {
175 return std::optional<ConcatUsageInfo>();
176 }
177 absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) {
178 return a->slice_starts().at(0) < b->slice_starts().at(0);
179 });
180 int64_t prev_limit = 0;
181 for (int64_t i = 0; i < users.size(); ++i) {
182 const HloInstruction* u = users[i];
183 int64_t slice_size = u->slice_limits().at(0) - u->slice_starts().at(0);
184 if (u->slice_starts().at(0) != prev_limit ||
185 slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) {
186 return std::optional<ConcatUsageInfo>();
187 }
188 prev_limit = u->slice_limits().at(0);
189 }
190
191 // If we have seen other concats, make sure they are identical. Multiple
192 // concats exist because horizontal fusion inserts one concat for each output
193 // of the fusion candidates. Check that all concats and operand ids are the
194 // same to know that the "transitive use closure" will be computed in the same
195 // iteration space.
196 int64_t operand_idx = concat.operand_index(&operand);
197 if (info.prev_concat != nullptr) {
198 bool is_concat_identical = info.prev_concat->Identical(
199 concat,
200 /*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
201 // Operands don't need to be the same.
202 return true;
203 });
204 if (!is_concat_identical || info.concat_opnd_idx != operand_idx) {
205 return std::optional<ConcatUsageInfo>();
206 }
207 }
208
209 const HloInstruction* slice_to_recover_opnd = users.at(operand_idx);
210 return std::optional<ConcatUsageInfo>(
211 ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd});
212 }
213
214 // Returns whether we can prove the transitive uses of `param` are in effect
215 // elementwise. In other words, we prove that the "transitive use closure" will
216 // all be computed in the same iteration space without any reorder of elements.
217 // In addition, we check that the "transitive use closure" includes the output
218 // in the `root_tuple`.
// Theoretically, we can prove more patterns, but our primary use case is
// SliceInputFusion.
bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param,
                                             const HloInstruction* root_tuple,
                                             const ShapeIndex& out_shape_idx) {
  CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
  CHECK_EQ(out_shape_idx.size(), 1);
  // Iterative DFS over the transitive users of `param`.
  absl::flat_hash_set<const HloInstruction*> visited;
  absl::InlinedVector<const HloInstruction*, 4> stack;
  stack.push_back(param);
  // Seeded empty: no concat has been seen yet.
  ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr};
  bool is_output_reachable = false;
  while (!stack.empty()) {
    const HloInstruction* current = stack.back();
    stack.pop_back();
    visited.insert(current);
    for (const HloInstruction* user : current->users()) {
      VLOG(3) << "Visiting: " << user->ToString();
      switch (user->opcode()) {
        case HloOpcode::kTuple:
          if (user == root_tuple &&
              current == root_tuple->operand(out_shape_idx.back())) {
            // We need to know if the output is reachable by the `param` to make
            // sure that they will be computed in the same iteration space.
            is_output_reachable = true;
          }
          break;
        case HloOpcode::kReshape:
          // A reshape is acceptable only if it is a bitcast (no data
          // movement).
          if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) {
            return false;
          }
          break;
        case HloOpcode::kConcatenate: {
          // A concat is acceptable only in the concat-then-slices pattern
          // checked by ConcatIsEffectivelyElementwise; the returned info is
          // threaded through so later concats can be checked for consistency.
          std::optional<ConcatUsageInfo> optional_concat_info =
              ConcatIsEffectivelyElementwise(*user, *current,
                                             concat_usage_info);
          if (!optional_concat_info) {
            return false;
          }
          concat_usage_info = *optional_concat_info;
          // Early continue as we only want to traverse through the slice that
          // recovers the operand. It is guaranteed that the operand to the
          // concat and the slice have the same iteration space. Insert the
          // slice instead of the concat.
          CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd));
          stack.push_back(concat_usage_info.slice_to_recover_opnd);
          continue;
        }
        default:
          // Any other opcode must be elementwise on every operand position
          // through which `current` feeds it.
          for (const int64_t use_index : user->OperandIndices(current)) {
            if (!user->IsElementwiseOnOperand(use_index)) {
              // Found a user that is non-elementwise on the current
              // instruction.
              return false;
            }
          }
          if (!LayoutUtil::Equal(current->shape().layout(),
                                 user->shape().layout())) {
            // Make sure the layout is not changed by the elementwise op.
            return false;
          }
          break;
      }  // end of switch
      if (!visited.contains(user)) {
        stack.push_back(user);
      }
    }
  }
  // True only if the traversal reached `root_tuple` at the requested output.
  return is_output_reachable;
}
289 } // namespace
290
ValueIsDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const291 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
292 const ShapeIndex& index) const {
293 const HloValueSet& value_set = GetValueSet(instruction, index);
294 if (value_set.values().size() != 1) {
295 return false;
296 }
297 return value_set.GetUniqueValue().defining_instruction() == instruction;
298 }
299
// Returns the value defined by `instruction` at `index`. It is a CHECK
// failure to call this at a position where the instruction does not define
// the value.
const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
  return GetUniqueValueAt(instruction, index);
}
305
GetValueDefinedAt(const HloInstruction * instruction,const ShapeIndex & index)306 HloValue& HloDataflowAnalysis::GetValueDefinedAt(
307 const HloInstruction* instruction, const ShapeIndex& index) {
308 CHECK(ValueIsDefinedAt(instruction, index));
309 return GetUniqueValueAt(instruction, index);
310 }
311
NewHloValue(HloInstruction * instruction,const ShapeIndex & index,bool is_phi)312 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
313 const ShapeIndex& index,
314 bool is_phi) {
315 const int64_t value_id = next_value_id_++;
316 auto result =
317 values_.insert({value_id, std::make_unique<HloValue>(
318 value_id, instruction, index, is_phi)});
319 CHECK(result.second);
320
321 VLOG(4) << "NewHloValue = " << result.first->second->ToShortString();
322
323 return result.first->second.get();
324 }
325
MarkValueForDeletion(HloValue::Id value_id)326 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
327 const HloValue& value = *values_.at(value_id);
328 VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
329
330 value_ids_to_delete_.push_back(value_id);
331 }
332
DeleteMarkedValues()333 void HloDataflowAnalysis::DeleteMarkedValues() {
334 // Use a set to prevent deleting an id twice.
335 absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
336 value_ids_to_delete_.end());
337 #ifndef NDEBUG
338 // Verify that no marked-for-deletion values are in any of the value sets.
339 for (const auto& pair : value_sets_) {
340 const HloInstruction* instruction = pair.first;
341 const InstructionValueSet& instruction_value_set = *pair.second;
342 for (const auto& index_value_set : instruction_value_set) {
343 const HloValueSet& value_set = index_value_set.second;
344 for (const HloValue* value : value_set.values()) {
345 DCHECK(!ContainsKey(id_set, value->id()))
346 << "Value " << value->ToShortString()
347 << " marked for deletion, but still exists in value set for "
348 "instruction "
349 << instruction->name();
350 }
351 }
352 }
353 #endif
354
355 for (HloValue::Id value_id : id_set) {
356 values_.erase(value_id);
357 }
358 value_ids_to_delete_.clear();
359 }
360
ToString() const361 std::string HloDataflowAnalysis::ToString() const {
362 std::string out =
363 StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
364 StrAppend(&out, " Instruction value sets:\n");
365 for (const HloComputation* computation : module_.computations()) {
366 for (const HloInstruction* instruction : computation->instructions()) {
367 StrAppend(&out, "Instruction: \n ", instruction->name(), ":\n");
368 if (instruction->shape().IsTuple()) {
369 GetInstructionValueSet(instruction)
370 .ForEachElement([this, &instruction, &out](
371 const ShapeIndex& index,
372 const HloValueSet& value_set) {
373 StrAppend(&out, " tuple index ", index.ToString(), ":\n");
374 for (const HloValue* value : value_set.values()) {
375 StrAppend(&out, " ", value->ToShortString(),
376 ValueIsDefinedAt(instruction, index) ? " (def)" : "",
377 "\n");
378 }
379 });
380 } else {
381 const HloValueSet& top_level_value_set =
382 GetValueSet(instruction, /*index=*/{});
383 for (const HloValue* value : top_level_value_set.values()) {
384 StrAppend(&out, " ", value->ToShortString(),
385 ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
386 }
387 }
388 }
389 }
390 StrAppend(&out, " HloValues:\n");
391 for (const HloValue* value : values()) {
392 StrAppend(&out, value->ToString(/*indent=*/4));
393 }
394 return out;
395 }
396
// Updates the value set of `instruction` as a phi (merge) of `inputs`, which
// are the value sets of its predecessors (e.g. conditional branch roots).
// Only called in SSA form. Returns true if any value set changed.
//
// For each shape index the resulting set holds either (a) nothing, if no
// input value has reached it yet, (b) the single input value, if all inputs
// agree, or (c) a freshly created phi value registered in phi_graph_.
bool HloDataflowAnalysis::Phi(
    HloInstruction* instruction,
    absl::Span<const InstructionValueSet* const> inputs) {
  CHECK(ssa_form_);
  VLOG(4) << "Phi(" << instruction->name() << ")";
  VLOG(5) << "instruction value set = "
          << GetInstructionValueSet(instruction).ToString();
  for (const InstructionValueSet* input : inputs) {
    VLOG(5) << "input value set = " << input->ToString();
  }

  // Debug-only shape checks: with bitcast_defines_value_ the inputs must be
  // shape-compatible; otherwise only element type and element count need to
  // match (bitcasts may reshape).
  if (bitcast_defines_value_) {
    absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
      DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
    });
  } else {
    const Shape& shape = instruction->shape();
    PrimitiveType ty = shape.element_type();
    bool is_array = shape.IsArray();
    absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
      DCHECK(ty == input->shape().element_type() &&
             (!is_array || ShapeUtil::ElementsIn(shape) ==
                               ShapeUtil::ElementsIn(input->shape())));
    });
  }

  bool changed = false;
  for (auto& pair : GetInstructionValueSet(instruction)) {
    const ShapeIndex& index = pair.first;
    HloValueSet& value_set = pair.second;

    // Positions with phi values should never have more than one value in the
    // value set.
    CHECK_LE(value_set.values().size(), 1);
    const HloValue* current_value =
        value_set.values().size() == 1 ? value_set.values()[0] : nullptr;

    // Construct a vector of value IDs of the inputs.
    std::vector<HloValue::Id> input_value_ids;
    for (const InstructionValueSet* input : inputs) {
      for (const HloValue* value : input->element(index).values()) {
        input_value_ids.push_back(value->id());
      }
    }

    // Remove the existing phi value (if it exists). The phi can be its own
    // input, for example, in while body parameters where the body passes
    // through the parameter value.
    bool current_value_defined_here =
        (current_value != nullptr &&
         current_value->defining_instruction() == instruction &&
         current_value->defining_index() == index);

    VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
    if (input_value_ids.empty()) {
      // A value set which has at least one element should never have its value
      // set reduced to zero elements. During dataflow value sets only can go
      // from empty to non-empty, not the reverse.
      CHECK_EQ(value_set.values().size(), 0)
          << "Instruction " << instruction->name() << " at index " << index
          << " previously had non-empty value set. Value set: " << value_set;
    } else if (input_value_ids.size() == 1) {
      // Only a single value reaches this point. There should be no phi, and
      // this value set should contain this single value.
      const HloValue& new_value = GetValue(input_value_ids[0]);
      if (current_value == nullptr) {
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      } else if (current_value != &new_value) {
        if (current_value_defined_here) {
          // Remove the existing phi.
          MarkValueForDeletion(current_value->id());
        }
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      }
    } else {
      // Multiple distinct values reach this point. A phi value is
      // necessary.
      CHECK_GT(input_value_ids.size(), 1);
      bool phi_defined_here =
          current_value_defined_here && current_value->is_phi();
      if (current_value == nullptr || !phi_defined_here) {
        // No phi exists at this position yet: create one and register its
        // inputs in the phi graph.
        value_set.Clear();
        value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));

        // NOTE: this local `inputs` shadows the function parameter; it holds
        // the phi's input values rather than InstructionValueSets.
        std::vector<HloValue*> inputs;
        inputs.reserve(input_value_ids.size());
        for (HloValue::Id id : input_value_ids) {
          inputs.push_back(&GetValue(id));
        }
        // Register the phi into phi graph.
        phi_graph_.RegisterPhi(*value_set.values()[0], inputs);
        changed = true;
      } else if (phi_defined_here) {
        // A phi already exists here; refresh its inputs if they changed.
        std::vector<HloValue*> new_inputs;
        new_inputs.reserve(input_value_ids.size());
        for (HloValue::Id id : input_value_ids) {
          new_inputs.push_back(&GetValue(id));
        }

        if (!phi_graph_.InputsEqualTo(*current_value, new_inputs)) {
          VLOG(1) << current_value->ToShortString() << " has new phi inputs: ";
          // Update phi inputs.
          phi_graph_.RegisterPhi(*current_value, new_inputs);
          changed = true;
        }
      }
    }
  }
  return changed;
}
511
GetValue(HloValue::Id value_id) const512 const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
513 return *values_.at(value_id);
514 }
515
GetValue(HloValue::Id value_id)516 HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
517 return *values_.at(value_id);
518 }
519
GetFlattenedValueSet(const HloInstruction * instruction) const520 HloValueSet HloDataflowAnalysis::GetFlattenedValueSet(
521 const HloInstruction* instruction) const {
522 HloValueSet value_set;
523
524 const InstructionValueSet& value_set_tree =
525 GetInstructionValueSet(instruction);
526
527 std::vector<const HloValueSet*> all_sets;
528 for (auto& pair : value_set_tree) {
529 const HloValueSet& value_set = pair.second;
530 all_sets.push_back(&value_set);
531 }
532 value_set.AssignUnionOf(all_sets);
533
534 return value_set;
535 }
536
GetValueSet(const HloInstruction * instruction,const ShapeIndex & index) const537 const HloValueSet& HloDataflowAnalysis::GetValueSet(
538 const HloInstruction* instruction, const ShapeIndex& index) const {
539 return GetInstructionValueSet(instruction).element(index);
540 }
541
GetValueSet(const HloInstruction * instruction,const ShapeIndex & index)542 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
543 const ShapeIndex& index) {
544 return *GetInstructionValueSet(instruction).mutable_element(index);
545 }
546
GetValueSet(const HloPosition & position) const547 const HloValueSet& HloDataflowAnalysis::GetValueSet(
548 const HloPosition& position) const {
549 return GetValueSet(position.instruction, position.index);
550 }
551
GetValueSet(const HloPosition & position)552 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
553 return GetValueSet(position.instruction, position.index);
554 }
555
UpdateBitcastValueSet(HloInstruction * bitcast)556 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
557 CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
558 const InstructionValueSet& operand_set =
559 GetInstructionValueSet(bitcast->operand(0));
560 InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
561 if (!bitcast_defines_value_ && operand_set != bitcast_set) {
562 bitcast_set = operand_set;
563 return true;
564 }
565 return false;
566 }
567
UpdateSetDimensionSizeValueSet(HloInstruction * set_dimension_size)568 bool HloDataflowAnalysis::UpdateSetDimensionSizeValueSet(
569 HloInstruction* set_dimension_size) {
570 CHECK_EQ(set_dimension_size->opcode(), HloOpcode::kSetDimensionSize);
571 const InstructionValueSet& operand_set =
572 GetInstructionValueSet(set_dimension_size->operand(0));
573 InstructionValueSet& set_dimension_size_set =
574 GetInstructionValueSet(set_dimension_size);
575 if (operand_set != set_dimension_size_set) {
576 set_dimension_size_set = operand_set;
577 return true;
578 }
579 return false;
580 }
581
UpdateSendValueSet(HloInstruction * send)582 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
583 CHECK_EQ(send->opcode(), HloOpcode::kSend);
584 bool changed = false;
585 // Send forwards the operand value to the output tuple at {0}.
586 for (auto& pair : GetInstructionValueSet(send->operand(0))) {
587 const ShapeIndex& operand_index = pair.first;
588 const HloValueSet& operand_value_set = pair.second;
589
590 ShapeIndex index = {0};
591 for (int64_t i : operand_index) {
592 index.push_back(i);
593 }
594
595 HloValueSet& value_set = GetValueSet(send, index);
596 if (value_set != operand_value_set) {
597 value_set = operand_value_set;
598 changed = true;
599 }
600 }
601 return changed;
602 }
603
UpdateAsyncStartValueSet(HloInstruction * async_start)604 bool HloDataflowAnalysis::UpdateAsyncStartValueSet(
605 HloInstruction* async_start) {
606 CHECK_EQ(async_start->opcode(), HloOpcode::kAsyncStart);
607 bool changed = false;
608 // AsyncStart forwards the operand values to element {0} of its output.
609 for (int64_t i = 0; i < async_start->operand_count(); ++i) {
610 const HloInstruction* operand = async_start->operand(i);
611 ShapeUtil::ForEachSubshape(
612 operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
613 if (!subshape.IsArray()) {
614 return;
615 }
616 const HloValueSet& operand_value_set = GetValueSet(operand, index);
617
618 ShapeIndex output_index = {0, i};
619 output_index.insert(output_index.end(), index.begin(), index.end());
620
621 HloValueSet& value_set = GetValueSet(async_start, output_index);
622 if (value_set != operand_value_set) {
623 value_set = operand_value_set;
624 changed = true;
625 }
626 });
627 }
628 // AsyncStart forwards the async wrapped computation root values to element
629 // {1} of its output.
630 HloInstruction* root =
631 async_start->async_wrapped_computation()->root_instruction();
632 ShapeUtil::ForEachSubshape(
633 root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
634 if (!subshape.IsArray()) {
635 return;
636 }
637 const HloValueSet& root_value_set = GetValueSet(root, index);
638
639 ShapeIndex output_index = {1};
640 output_index.insert(output_index.end(), index.begin(), index.end());
641
642 HloValueSet& value_set = GetValueSet(async_start, output_index);
643 if (value_set != root_value_set) {
644 value_set = root_value_set;
645 changed = true;
646 }
647 });
648 return changed;
649 }
650
UpdateAsyncUpdateValueSet(HloInstruction * async_update)651 bool HloDataflowAnalysis::UpdateAsyncUpdateValueSet(
652 HloInstruction* async_update) {
653 CHECK_EQ(async_update->opcode(), HloOpcode::kAsyncUpdate);
654 CHECK_EQ(async_update->shape(), async_update->operand(0)->shape());
655 bool changed = false;
656 HloInstruction* root =
657 async_update->async_wrapped_computation()->root_instruction();
658 // AsyncUpdate forwards all of the operand values to corresponding elements of
659 // its output.
660 ShapeUtil::ForEachSubshape(
661 async_update->operand(0)->shape(),
662 [&](const Shape& subshape, const ShapeIndex& index) {
663 if (!subshape.IsArray()) {
664 return;
665 }
666 const HloValueSet& operand_value_set =
667 GetValueSet(async_update->operand(0), index);
668
669 HloValueSet& value_set = GetValueSet(async_update, index);
670 CHECK_GE(index.size(), 0);
671 if (index[0] != 1) {
672 if (value_set != operand_value_set) {
673 value_set = operand_value_set;
674 changed = true;
675 }
676 } else {
677 // If this subshape is an output (index {1}), we need to create the
678 // union with the async wrapped computation root.
679 ShapeIndex root_index(index.begin() + 1, index.end());
680 const HloValueSet& root_value_set = GetValueSet(root, root_index);
681 changed |=
682 value_set.AssignUnionOf({&operand_value_set, &root_value_set});
683 }
684 });
685 return changed;
686 }
687
UpdateAsyncDoneValueSet(HloInstruction * async_done)688 bool HloDataflowAnalysis::UpdateAsyncDoneValueSet(HloInstruction* async_done) {
689 CHECK_EQ(async_done->opcode(), HloOpcode::kAsyncDone);
690 bool changed = false;
691 HloInstruction* root =
692 async_done->async_wrapped_computation()->root_instruction();
693 // AsyncDone creates a union of the operand values at {1} and the async
694 // wrapped computation root to element {} of its output.
695 ShapeUtil::ForEachSubshape(
696 async_done->operand(0)->shape(),
697 [&](const Shape& subshape, const ShapeIndex& index) {
698 if (!subshape.IsArray() || index.front() != 1) {
699 return;
700 }
701 const HloValueSet& operand_value_set =
702 GetValueSet(async_done->operand(0), index);
703
704 ShapeIndex output_index(index.begin() + 1, index.end());
705 HloValueSet& value_set = GetValueSet(async_done, output_index);
706 const HloValueSet& root_value_set = GetValueSet(root, output_index);
707 changed |=
708 value_set.AssignUnionOf({&operand_value_set, &root_value_set});
709 });
710 return changed;
711 }
712
UpdateCopyStartValueSet(HloInstruction * copy_start)713 bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
714 CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
715 bool changed = false;
716 // CopyStart forwards the operand value to element {1} of its output.
717 const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0));
718 HloValueSet& value_set = GetValueSet(copy_start, {1});
719 if (value_set != operand_value_set) {
720 value_set = operand_value_set;
721 changed = true;
722 }
723 return changed;
724 }
725
UpdateCopyDoneValueSet(HloInstruction * copy_done)726 bool HloDataflowAnalysis::UpdateCopyDoneValueSet(HloInstruction* copy_done) {
727 CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone);
728 bool changed = false;
729 // CopyDone forwards the operand value at {0} to element {} of its output.
730 const HloValueSet& operand_value_set =
731 GetValueSet(copy_done->operand(0), {0});
732 HloValueSet& value_set = GetValueSet(copy_done);
733 if (value_set != operand_value_set) {
734 value_set = operand_value_set;
735 changed = true;
736 }
737 return changed;
738 }
739
UpdateRecvDoneValueSet(HloInstruction * recv_done)740 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
741 CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
742 bool changed = false;
743 // RecvDone forwards the operand value at {0} to element {0} of its output.
744 for (auto& pair : GetInstructionValueSet(recv_done)) {
745 ShapeIndex& index = pair.first;
746 HloValueSet& value_set = pair.second;
747
748 if (index.empty() || index[0] != 0) {
749 continue;
750 }
751
752 const HloValueSet& operand_value_set =
753 GetValueSet(recv_done->operand(0), index);
754 if (value_set != operand_value_set) {
755 value_set = operand_value_set;
756 changed = true;
757 }
758 }
759 return changed;
760 }
761
UpdateCallValueSet(HloInstruction * call)762 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
763 CHECK_EQ(call->opcode(), HloOpcode::kCall);
764 InstructionValueSet& value_set = GetInstructionValueSet(call);
765 InstructionValueSet& root_value_set =
766 GetInstructionValueSet(call->to_apply()->root_instruction());
767 if (value_set != root_value_set) {
768 value_set = root_value_set;
769 return true;
770 }
771 return false;
772 }
773
UpdateConditionalValueSet(HloInstruction * conditional)774 bool HloDataflowAnalysis::UpdateConditionalValueSet(
775 HloInstruction* conditional) {
776 CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
777 std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
778 for (int j = 0; j < conditional->branch_count(); ++j) {
779 inputs[j] = &GetInstructionValueSet(
780 conditional->branch_computation(j)->root_instruction());
781 }
782 if (ssa_form_) {
783 return Phi(conditional, inputs);
784 } else {
785 return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
786 }
787 }
788
UpdateCopyValueSet(HloInstruction * copy)789 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
790 CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
791 bool changed = false;
792 for (auto& pair : GetInstructionValueSet(copy)) {
793 const ShapeIndex& index = pair.first;
794 if (index.empty()) {
795 // kCopy shallow copies and thus defines the top-level value so nothing to
796 // update.
797 continue;
798 }
799
800 HloValueSet& value_set = pair.second;
801 HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
802 if (value_set != operand_value_set) {
803 value_set = operand_value_set;
804 changed = true;
805 }
806 }
807 return changed;
808 }
809
UpdateOptimizationBarrierValueSet(HloInstruction * barrier)810 bool HloDataflowAnalysis::UpdateOptimizationBarrierValueSet(
811 HloInstruction* barrier) {
812 // Optimization Barriers just forward their operand. Given that barriers can
813 // have a tuple operand, we iterate through its indexes, like for copies.
814 // Unlike copies though we also propagate the top-level value.
815 CHECK_EQ(barrier->opcode(), HloOpcode::kOptimizationBarrier);
816 bool changed = false;
817 for (auto& pair : GetInstructionValueSet(barrier)) {
818 const ShapeIndex& index = pair.first;
819 HloValueSet& value_set = pair.second;
820 HloValueSet& operand_value_set = GetValueSet(barrier->operand(0), index);
821 if (value_set != operand_value_set) {
822 value_set = operand_value_set;
823 changed = true;
824 }
825 }
826 return changed;
827 }
828
UpdateDomainValueSet(HloInstruction * domain)829 bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
830 // Domain instructions just forward their operand. Given that domains can have
831 // a tuple operand, we iterate through its indexes, like for copies.
832 // Unlike copies though we also propagate the top-level value.
833 CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
834 bool changed = false;
835 for (auto& pair : GetInstructionValueSet(domain)) {
836 const ShapeIndex& index = pair.first;
837 HloValueSet& value_set = pair.second;
838 HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
839 if (value_set != operand_value_set) {
840 value_set = operand_value_set;
841 changed = true;
842 }
843 }
844 return changed;
845 }
846
UpdateAddDependencyValueSet(HloInstruction * add_dependency)847 bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
848 HloInstruction* add_dependency) {
849 // AddDependency just forwards the value of its zero-th operand.
850 CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
851 const InstructionValueSet& operand_set =
852 GetInstructionValueSet(add_dependency->operand(0));
853 InstructionValueSet& add_dependency_set =
854 GetInstructionValueSet(add_dependency);
855 if (operand_set != add_dependency_set) {
856 add_dependency_set = operand_set;
857 return true;
858 }
859 return false;
860 }
861
UpdateGetTupleElementValueSet(HloInstruction * gte)862 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
863 CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
864 bool changed = false;
865 // The GetTupleElement instruction forwards the values from the specified
866 // tuple element.
867 for (auto& pair : GetInstructionValueSet(gte)) {
868 const ShapeIndex& index = pair.first;
869 HloValueSet& value_set = pair.second;
870
871 // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
872 // with the tuple element number prefixed.
873 ShapeIndex operand_index = {gte->tuple_index()};
874 for (int64_t i : index) {
875 operand_index.push_back(i);
876 }
877
878 HloValueSet& operand_value_set =
879 GetValueSet(gte->operand(0), operand_index);
880 if (value_set != operand_value_set) {
881 value_set = operand_value_set;
882 changed = true;
883 }
884 }
885 return changed;
886 }
887
// Recomputes the value set of a kParameter instruction from the operands of
// its callers (kCall, kWhile, kConditional, or async instructions). Returns
// true if the parameter's value set changed.
bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
  CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
  const CallGraphNode& call_graph_node =
      call_graph_->GetNode(parameter->parent());

  // Subcomputations called in a parallel context (eg, map) do not have dataflow
  // from the caller operands. Dead computations (no callsites) likewise get no
  // caller dataflow.
  if (call_graph_node.context() == CallContext::kEmbedded ||
      call_graph_node.caller_callsites().empty()) {
    return false;
  }
  CHECK_EQ(call_graph_node.context(), CallContext::kControlFlow);

  // One input value set per dataflow edge into this parameter; a phi is only
  // needed when control flow can merge distinct values (while/conditional).
  std::vector<const InstructionValueSet*> inputs;
  bool need_phi = false;
  for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    const HloOpcode& opcode = callsite.instruction()->opcode();
    if (opcode == HloOpcode::kCall) {
      // The operand values of a call instruction are forwarded to the
      // respective parameter instruction of the subcomputation.
      inputs.push_back(&GetInstructionValueSet(
          callsite.instruction()->operand(parameter->parameter_number())));
    } else if (opcode == HloOpcode::kWhile) {
      // In a while instruction, the while operand (ie, the init value) and the
      // backedge are dataflow inputs to the parameter instruction. This is the
      // case for parameters of both the body and condition computations.
      CHECK_EQ(parameter->parameter_number(), 0);
      inputs.push_back(
          &GetInstructionValueSet(callsite.instruction()->operand(0)));
      // If the parameter *is not* the root, parameter state would be
      // updated by the root, otherwise don't consider it's current state
      // (InstructionValueSet) as we are recomputing its current state.
      if (parameter !=
          callsite.instruction()->while_body()->root_instruction()) {
        inputs.push_back(&GetInstructionValueSet(
            callsite.instruction()->while_body()->root_instruction()));
      }
      need_phi = true;
    } else if (opcode == HloOpcode::kConditional) {
      CHECK_EQ(parameter->parameter_number(), 0);
      auto conditional = callsite.instruction();
      // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
      // operands 1 and onward are the arguments to the branch computations.
      //
      // If the parameter belongs to conditional's branch 0 computation, then
      // operand 1 is forwarded to this parameter instruction. If the parameter
      // belongs to conditional's branch 5 computation, then operand 6 is
      // forwarded to this parameter instruction.
      bool found_parent = false;
      for (int j = 0; j < conditional->branch_count(); ++j) {
        if (parameter->parent() == conditional->branch_computation(j)) {
          inputs.push_back(
              &GetInstructionValueSet(conditional->operand(j + 1)));
          found_parent = true;
          break;
        }
      }
      CHECK(found_parent);
      need_phi = true;
    } else if (opcode == HloOpcode::kAsyncStart) {
      // Async-start forwards its operands to the wrapped computation's
      // parameters, just like kCall.
      inputs.push_back(&GetInstructionValueSet(
          callsite.instruction()->operand(parameter->parameter_number())));
    } else if (opcode == HloOpcode::kAsyncUpdate ||
               opcode == HloOpcode::kAsyncDone) {
      // Async-update/done read the in-flight async value: the parameter's
      // values come from element {0, parameter_number} of the operand tuple,
      // so short-circuit here rather than accumulating into 'inputs'.
      return GetInstructionValueSet(parameter).AssignUnionOf(
          GetInstructionValueSet(callsite.instruction()->operand(0)),
          {0, parameter->parameter_number()});
    } else {
      LOG(FATAL) << "CallContext::kSequential computations should only be "
                    "called from call, while, or conditional instructions";
    }
  }
  if (ssa_form_ && need_phi) {
    return Phi(parameter, inputs);
  } else {
    return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
  }
}
966
UpdateTupleValueSet(HloInstruction * tuple)967 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
968 CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
969 bool changed = false;
970 for (int64_t i = 0; i < tuple->operands().size(); ++i) {
971 // Copy the value set(s) of each operand into the respective position in the
972 // kTuple instruction's value sets.
973 for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
974 const ShapeIndex& operand_index = pair.first;
975 HloValueSet& operand_value_set = pair.second;
976
977 ShapeIndex index = {i};
978 for (int64_t op_index : operand_index) {
979 index.push_back(op_index);
980 }
981 HloValueSet& value_set = GetValueSet(tuple, index);
982
983 if (value_set != operand_value_set) {
984 value_set = operand_value_set;
985 changed = true;
986 }
987 }
988 }
989 return changed;
990 }
991
UpdateWhileValueSet(HloInstruction * xla_while)992 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
993 CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
994 const InstructionValueSet* const inputs[] = {
995 &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
996 &GetInstructionValueSet(xla_while->operand(0))};
997 if (ssa_form_) {
998 return Phi(xla_while, inputs);
999 } else {
1000 return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
1001 }
1002 }
1003
UpdateAllGatherStartValueSet(HloInstruction * all_gather_start)1004 bool HloDataflowAnalysis::UpdateAllGatherStartValueSet(
1005 HloInstruction* all_gather_start) {
1006 CHECK_EQ(all_gather_start->opcode(), HloOpcode::kAllGatherStart);
1007 bool changed = false;
1008 // AllGatherStart forwards the operand values to element {0} of its output.
1009 for (int64_t i = 0; i < all_gather_start->operand_count(); ++i) {
1010 const HloValueSet& operand_value_set =
1011 GetValueSet(all_gather_start->operand(i));
1012
1013 ShapeIndex output_index = {0};
1014 if (all_gather_start->operand_count() > 1) {
1015 output_index.push_back(i);
1016 }
1017
1018 HloValueSet& value_set = GetValueSet(all_gather_start, output_index);
1019 if (value_set != operand_value_set) {
1020 value_set = operand_value_set;
1021 changed = true;
1022 }
1023 }
1024 return changed;
1025 }
1026
UpdateAllGatherDoneValueSet(HloInstruction * all_gather_done)1027 bool HloDataflowAnalysis::UpdateAllGatherDoneValueSet(
1028 HloInstruction* all_gather_done) {
1029 CHECK_EQ(all_gather_done->opcode(), HloOpcode::kAllGatherDone);
1030 bool changed = false;
1031 // AllGatherDone forwards the operand value at {1} to its output.
1032 for (auto& pair : GetInstructionValueSet(all_gather_done)) {
1033 const ShapeIndex& output_index = pair.first;
1034 HloValueSet& value_set = pair.second;
1035
1036 ShapeIndex operand_index = {1};
1037 for (int64_t i : output_index) {
1038 operand_index.push_back(i);
1039 }
1040
1041 const HloValueSet& operand_value_set =
1042 GetValueSet(all_gather_done->operand(0), operand_index);
1043 if (value_set != operand_value_set) {
1044 value_set = operand_value_set;
1045 changed = true;
1046 }
1047 }
1048 return changed;
1049 }
1050
UpdateAllReduceDoneValueSet(HloInstruction * all_reduce_done)1051 bool HloDataflowAnalysis::UpdateAllReduceDoneValueSet(
1052 HloInstruction* all_reduce_done) {
1053 CHECK_EQ(all_reduce_done->opcode(), HloOpcode::kAllReduceDone);
1054 bool changed = false;
1055 // AllReduceDone forwards its only operand.
1056 for (auto& pair : GetInstructionValueSet(all_reduce_done)) {
1057 const ShapeIndex& output_index = pair.first;
1058 HloValueSet& value_set = pair.second;
1059
1060 ShapeIndex operand_index = {};
1061 for (int64_t i : output_index) {
1062 operand_index.push_back(i);
1063 }
1064
1065 const HloValueSet& operand_value_set =
1066 GetValueSet(all_reduce_done->operand(0), operand_index);
1067 if (value_set != operand_value_set) {
1068 value_set = operand_value_set;
1069 changed = true;
1070 }
1071 }
1072 return changed;
1073 }
1074
UpdateCollectivePermuteStartValueSet(HloInstruction * collective_permute_start)1075 bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet(
1076 HloInstruction* collective_permute_start) {
1077 CHECK_EQ(collective_permute_start->opcode(),
1078 HloOpcode::kCollectivePermuteStart);
1079 bool changed = false;
1080 // CollectivePermuteStart forwards the operand value to element {0} of its
1081 // output.
1082 if (collective_permute_start->operand(0)->shape().IsTuple()) {
1083 for (int i = 0; i < ShapeUtil::TupleElementCount(
1084 collective_permute_start->operand(0)->shape());
1085 ++i) {
1086 const HloValueSet& operand_value_set =
1087 GetValueSet(collective_permute_start->operand(0), {i});
1088 HloValueSet& value_set = GetValueSet(collective_permute_start, {0, i});
1089 if (value_set != operand_value_set) {
1090 value_set = operand_value_set;
1091 changed = true;
1092 }
1093 }
1094 } else {
1095 const HloValueSet& operand_value_set =
1096 GetValueSet(collective_permute_start->operand(0));
1097 HloValueSet& value_set = GetValueSet(collective_permute_start, {0});
1098 if (value_set != operand_value_set) {
1099 value_set = operand_value_set;
1100 changed = true;
1101 }
1102 }
1103 return changed;
1104 }
1105
UpdateCollectivePermuteDoneValueSet(HloInstruction * collective_permute_done)1106 bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet(
1107 HloInstruction* collective_permute_done) {
1108 CHECK_EQ(collective_permute_done->opcode(),
1109 HloOpcode::kCollectivePermuteDone);
1110 bool changed = false;
1111 // CollectivePermuteDone forwards the operand value at {1} to its output.
1112 if (collective_permute_done->shape().IsTuple()) {
1113 for (int i = 0;
1114 i < ShapeUtil::TupleElementCount(collective_permute_done->shape());
1115 ++i) {
1116 const HloValueSet& operand_value_set =
1117 GetValueSet(collective_permute_done->operand(0), {1, i});
1118 HloValueSet& value_set = GetValueSet(collective_permute_done, {i});
1119 if (value_set != operand_value_set) {
1120 value_set = operand_value_set;
1121 changed = true;
1122 }
1123 }
1124 } else {
1125 const HloValueSet& operand_value_set =
1126 GetValueSet(collective_permute_done->operand(0), {1});
1127 HloValueSet& value_set = GetValueSet(collective_permute_done);
1128 if (value_set != operand_value_set) {
1129 value_set = operand_value_set;
1130 changed = true;
1131 }
1132 }
1133 return changed;
1134 }
1135
// Recomputes the value set of 'instruction' from its operands (and, for
// calling instructions, from the called computations) by dispatching to the
// opcode-specific Update* helper. Returns true if any value set changed.
// Opcodes without a case define all of their own output values, so there is
// nothing to forward.
bool HloDataflowAnalysis::UpdateInstructionValueSet(
    HloInstruction* instruction) {
  // Recompute from operands.
  switch (instruction->opcode()) {
    case HloOpcode::kAddDependency:
      return UpdateAddDependencyValueSet(instruction);
    case HloOpcode::kAllGatherStart:
      return UpdateAllGatherStartValueSet(instruction);
    case HloOpcode::kAllGatherDone:
      return UpdateAllGatherDoneValueSet(instruction);
    case HloOpcode::kAsyncStart:
      return UpdateAsyncStartValueSet(instruction);
    case HloOpcode::kAsyncUpdate:
      return UpdateAsyncUpdateValueSet(instruction);
    case HloOpcode::kAsyncDone:
      return UpdateAsyncDoneValueSet(instruction);
    case HloOpcode::kBitcast:
      return UpdateBitcastValueSet(instruction);
    case HloOpcode::kSetDimensionSize:
      return UpdateSetDimensionSizeValueSet(instruction);
    case HloOpcode::kDomain:
      return UpdateDomainValueSet(instruction);
    case HloOpcode::kCopy:
      return UpdateCopyValueSet(instruction);
    case HloOpcode::kGetTupleElement:
      return UpdateGetTupleElementValueSet(instruction);
    case HloOpcode::kTuple:
      return UpdateTupleValueSet(instruction);
    case HloOpcode::kParameter:
      return UpdateParameterValueSet(instruction);
    case HloOpcode::kCall:
      return UpdateCallValueSet(instruction);
    case HloOpcode::kWhile:
      return UpdateWhileValueSet(instruction);
    case HloOpcode::kSend:
      return UpdateSendValueSet(instruction);
    case HloOpcode::kRecvDone:
      return UpdateRecvDoneValueSet(instruction);
    case HloOpcode::kCopyStart:
      return UpdateCopyStartValueSet(instruction);
    case HloOpcode::kCopyDone:
      return UpdateCopyDoneValueSet(instruction);
    case HloOpcode::kConditional:
      return UpdateConditionalValueSet(instruction);
    case HloOpcode::kAllReduceDone:
      return UpdateAllReduceDoneValueSet(instruction);
    case HloOpcode::kCollectivePermuteStart:
      return UpdateCollectivePermuteStartValueSet(instruction);
    case HloOpcode::kCollectivePermuteDone:
      return UpdateCollectivePermuteDoneValueSet(instruction);
    case HloOpcode::kOptimizationBarrier:
      return UpdateOptimizationBarrierValueSet(instruction);
    default:
      // Instruction does not forward HloValues (it defines all values in its
      // output). No update is necessary.
      return false;
  }
}
1194
// Runs the dataflow fixed-point: repeatedly recomputes instruction value sets
// until nothing changes. Work items are processed in ascending post-order
// ordinal so producers tend to be visited before their consumers, minimizing
// recomputation.
void HloDataflowAnalysis::Propagate() {
  using Work = std::pair<int64_t, HloInstruction*>;
  // Avoid duplicating work by preferring work items early in the post order
  // schedule. Intuitively, we start from entry parameters and propagate buffers
  // updates throughout the module only once.
  // std::greater turns the priority_queue into a min-heap on the ordinal.
  std::priority_queue<Work, std::vector<Work>, std::greater<Work>> worklist;
  // Set of instructions currently enqueued, to keep the worklist duplicate
  // free.
  absl::flat_hash_set<HloInstruction*> workset;
  auto priority_map = CalculatePostOrderSchedule(module_);
  auto add_to_worklist = [&priority_map, &worklist,
                          &workset](HloInstruction* instruction) {
    if (workset.insert(instruction).second) {
      worklist.emplace(priority_map[instruction], instruction);
    }
  };

  // Seed the worklist with every instruction in the module.
  auto comps = module_.MakeComputationPostOrder();
  for (HloComputation* computation : comps) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      add_to_worklist(instruction);
    }
  }
  VLOG(1) << "SSA_FORM_: " << ssa_form_;

  while (!worklist.empty()) {
    HloInstruction* instruction = worklist.top().second;
    // Intentionally shadows the seeding lambda above; captures everything by
    // reference for use in the propagation steps below.
    auto add_to_worklist = [&](HloInstruction* todo) {
      if (workset.insert(todo).second) {
        VLOG(1) << " Adding todo : " << todo->name();
        worklist.emplace(priority_map[todo], todo);
      }
    };
    worklist.pop();

    workset.erase(workset.find(instruction));

    VLOG(3) << "Worklist top: " << instruction->name();
    XLA_VLOG_LINES(3, ToString());

    if (!UpdateInstructionValueSet(instruction)) {
      // No change to the instruction's value set.
      VLOG(4) << "No change.";
      continue;
    }

    VLOG(4) << "New value set for " << instruction->name() << ": "
            << GetInstructionValueSet(instruction);

    // Instruction value was updated. Add users to work list if we haven't
    // already.
    for (HloInstruction* user : instruction->users()) {
      add_to_worklist(user);

      // If user sequentially calls a computation, then the respective
      // parameter(s) of the computation need to be updated.
      if (user->opcode() == HloOpcode::kConditional) {
        // If operand 0 is the use of instruction, then no parameters need to be
        // updated, since that is the branch_index of the conditional.
        // If operand n+1 is the use of instruction, then the branch_computation
        // n's parameter need to be updated.
        //
        // Note that the same instruction can be used in multiple branches'
        // operands.
        for (int j = 0; j < user->branch_count(); ++j) {
          if (user->operand(j + 1) == instruction) {
            add_to_worklist(
                user->branch_computation(j)->parameter_instruction(0));
          }
        }
      } else if (user->opcode() == HloOpcode::kAsyncUpdate ||
                 user->opcode() == HloOpcode::kAsyncDone) {
        // For async update and async done, we cannot distinguish which
        // parameter needs to be updated so add all to the worklist.
        for (int64_t parameter_number = 0;
             parameter_number <
             user->async_wrapped_computation()->num_parameters();
             ++parameter_number) {
          add_to_worklist(
              user->async_wrapped_computation()->parameter_instruction(
                  parameter_number));
        }
      } else {
        // Generic case: any control-flow callee parameter fed by this
        // instruction must be re-evaluated.
        for (HloComputation* called_computation : user->called_computations()) {
          const CallGraphNode& call_graph_node =
              call_graph_->GetNode(called_computation);
          if (call_graph_node.context() == CallContext::kControlFlow) {
            for (int64_t operand_number : user->OperandIndices(instruction)) {
              add_to_worklist(
                  called_computation->parameter_instruction(operand_number));
            }
          }
        }
      }
    }

    // If instruction is a root instruction, then propagate out to any calling
    // instruction and across any while backedge.
    if (instruction == instruction->parent()->root_instruction()) {
      const CallGraphNode& call_graph_node =
          call_graph_->GetNode(instruction->parent());
      for (const CallSite& callsite : call_graph_node.caller_callsites()) {
        if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
          // Add the while itself, and the body and condition parameters.
          add_to_worklist(callsite.instruction());
          add_to_worklist(
              callsite.instruction()->while_body()->parameter_instruction(0));
          add_to_worklist(
              callsite.instruction()->while_condition()->parameter_instruction(
                  0));
        } else if (call_graph_node.context() == CallContext::kControlFlow) {
          add_to_worklist(callsite.instruction());
        }
      }
    }
  }
}
1311
GetInstructionValueSet(const HloInstruction * instruction) const1312 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
1313 const HloInstruction* instruction) const {
1314 return *value_sets_.at(instruction);
1315 }
1316
GetInstructionValueSet(const HloInstruction * instruction)1317 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
1318 const HloInstruction* instruction) {
1319 return *value_sets_.at(instruction);
1320 }
1321
// Creates an (initially empty) InstructionValueSet for every instruction in
// the module and adds the HloValues that each instruction *defines* itself.
// Values that flow in from operands or across computations are filled in
// later by Propagate(). Returns an error for unsupported call-graph shapes.
Status HloDataflowAnalysis::InitializeInstructionValueSets() {
  for (const HloComputation* computation : module_.MakeComputationSorted()) {
    const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      // Create an empty shape tree.
      value_sets_.insert({instruction, std::make_unique<InstructionValueSet>(
                                           instruction->shape())});

      // For each sub-shape of the instruction shape, add a new HloValue to its
      // HloValueSet. should_define may be provided to define a subset of
      // values.
      auto define_all_values =
          [this,
           &instruction](std::function<bool(const ShapeIndex&)> should_define =
                             [](const ShapeIndex&) { return true; }) {
            for (auto& pair : GetInstructionValueSet(instruction)) {
              const ShapeIndex& index = pair.first;
              if (should_define(index)) {
                HloValue* value =
                    NewHloValue(instruction, index, /*is_phi=*/false);
                GetValueSet(instruction, index).AddValue(value);
              }
            }
          };

      // Add a new HloValue to the HloValueSet corresponding to the given index
      // of the instruction shape.
      auto define_value_at = [this, &instruction](const ShapeIndex& index) {
        HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
        GetValueSet(instruction, index).AddValue(value);
      };

      // Per-opcode rules for which output shape indices are *defined* by the
      // instruction (versus aliased/forwarded from operands).
      switch (instruction->opcode()) {
        case HloOpcode::kBitcast:
          if (bitcast_defines_value_) {
            define_all_values();
          }
          break;
        case HloOpcode::kSetDimensionSize:
        case HloOpcode::kAddDependency:
        case HloOpcode::kWhile:
        case HloOpcode::kCall:
        case HloOpcode::kConditional:
        case HloOpcode::kGetTupleElement:
        case HloOpcode::kDomain:
        case HloOpcode::kOptimizationBarrier:
          // These instructions define no values. The values in their output
          // flow from their operands or from cross computation dataflow.
          break;
        case HloOpcode::kParameter:
          if (call_graph_node.context() == CallContext::kBoth) {
            // We do not support a subcomputation that is called from both a
            // parallel and sequential context. In this case, the parameter
            // would both define a value and propagate a value from its
            // caller. This limitation is not really a problem because the call
            // graph is typically flattened.
            return Unimplemented(
                "Computation %s is called in both a parallel (eg, kMap) and "
                "sequential (eg, kCall) context",
                computation->name());
          }
          if (call_graph_node.caller_callsites().empty() ||
              call_graph_node.context() == CallContext::kEmbedded) {
            // Parameters of computations called in a parallel context (eg, map
            // and reduce) as well as parameters of dead computations define all
            // values in their output. Otherwise the values of the parameter
            // come from the caller (eg, operands to the kCall instruction).
            define_all_values();
          }
          break;
        case HloOpcode::kCopy:
        case HloOpcode::kTuple:
          // These instructions only define their top-level values. Any other
          // values flow from their operands.
          define_value_at(/*index=*/{});
          break;
        case HloOpcode::kAsyncStart:
          // AsyncStart produces a tuple of {{aliased operands}, {destination},
          // contexts}. It defines all of the tuple-shaped values and the
          // contexts.
          define_all_values([&](const ShapeIndex& index) {
            return ShapeUtil::GetSubshape(instruction->shape(), index)
                       .IsTuple() ||
                   index.front() > 1;
          });
          break;
        case HloOpcode::kAsyncUpdate:
          // AsyncUpdate produces a tuple of {{aliased operands}, {destination},
          // contexts} where all of the array-typed values alias with the
          // operand. So, only tuple-shaped values are defined by AsyncUpdate.
          define_all_values([&](const ShapeIndex& index) {
            return ShapeUtil::GetSubshape(instruction->shape(), index)
                .IsTuple();
          });
          break;
        case HloOpcode::kAsyncDone:
          // AsyncDone's output aliases its output.
          break;
        case HloOpcode::kCopyStart:
          // CopyStart produces a tuple of {destination buffer, aliased operand,
          // U32 context}.
          define_value_at(/*index=*/{});
          define_value_at(/*index=*/{0});
          define_value_at(/*index=*/{2});
          break;
        case HloOpcode::kCopyDone:
          // CopyDone consumes a tuple produced by CopyStart and produces an
          // element. Its output aliases its input tuple element {0}.
          break;
        case HloOpcode::kAllGatherStart:
          // AllGatherStart produces a tuple of
          // {aliased operand, destination buffer}.
          define_value_at(/*index=*/{});
          define_value_at(/*index=*/{1});
          break;
        case HloOpcode::kAllGatherDone:
          // AllGatherDone's output aliases its input tuple element {1}.
          if (instruction->shape().IsTuple()) {
            define_value_at(/*index=*/{});
          }
          break;
        case HloOpcode::kAllReduceDone:
          // AllReduceDone's output aliases its input.
          break;
        case HloOpcode::kCollectivePermuteStart:
          // CollectivePermuteStart produces a tuple of
          // {aliased operand, destination buffer, U32 context, U32 context}.
          define_value_at(/*index=*/{});
          define_value_at(/*index=*/{1});
          define_value_at(/*index=*/{2});
          define_value_at(/*index=*/{3});
          if (instruction->operand_count() > 1) {
            CHECK_EQ(instruction->operand_count(), 4);
            if (instruction->operand(1)->shape().IsTuple()) {
              for (int i = 0; i < ShapeUtil::TupleElementCount(
                                      instruction->operand(1)->shape());
                   ++i) {
                define_value_at(/*index=*/{1, i});
              }
            }
          }
          break;
        case HloOpcode::kCollectivePermuteDone:
          // CollectivePermuteDone's output aliases its input tuple element {1}.
          if (instruction->shape().IsTuple()) {
            define_value_at(/*index=*/{});
          }
          break;
        case HloOpcode::kRecvDone:
          // RecvDone produces a two-element tuple. Element zero aliases its
          // input tuple element {0}; element one is a token.
          define_value_at(/*index=*/{});
          define_value_at(/*index=*/{1});
          break;
        case HloOpcode::kSend:
          // Send produces a tuple of {aliased operand, U32 context, token},
          // therefore only defines the top-level tuple and the tuple elements
          // at {1} and {2}.
          define_value_at(/*index=*/{});
          define_value_at(/*index=*/{1});
          define_value_at(/*index=*/{2});
          break;
        default:
          // Everything else defines every value in its output.
          define_all_values();
          break;
      }
    }
  }

  return OkStatus();
}
1494
// Simplifies degenerate phi values after propagation: runs the phi-graph
// optimization, then for every value set holding exactly one phi value,
// replaces that phi with the value the phi graph resolved it to, and marks
// the dead phi for deletion.
void HloDataflowAnalysis::OptimizePhiValues() {
  // Only applicable to SSA form where phis are defined.
  if (!ssa_form_) {
    return;
  }

  VLOG(1) << "Before phi graph optimization";
  XLA_VLOG_LINES(1, phi_graph_.ToString());
  phi_graph_.Optimize();
  VLOG(1) << "After phi graph optimization";
  XLA_VLOG_LINES(1, phi_graph_.ToString());

  for (const HloComputation* computation : module_.computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      InstructionValueSet& instruction_value_set =
          GetInstructionValueSet(instruction);
      VLOG(1) << "inst: " << instruction->name();
      VLOG(1) << instruction_value_set.ToString();
      instruction_value_set.ForEachMutableElement(
          [&](const xla::ShapeIndex& index, HloValueSet* value_set) {
            // Only singleton phi value sets can be rewritten.
            auto values = value_set->values();
            if (!(values.size() == 1 && values[0]->is_phi())) {
              return;
            }
            HloValue::Id phi_id = values[0]->id();
            HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id);
            if (new_id != phi_id) {
              VLOG(1) << "Replacing " << values[0]->ToString() << " with "
                      << GetValue(new_id).ToString();
              value_set->Clear();
              const HloValue& new_value = GetValue(new_id);
              value_set->AddValue(&new_value);
              // The phi is now unreferenced here; actual deletion happens in
              // DeleteMarkedValues().
              MarkValueForDeletion(phi_id);
            }
          });
    }
  }
}
1533
/* static */
// Factory entry point: builds the analysis, runs the dataflow fixed point,
// optimizes phis, and finalizes value positions. Returns the completed
// analysis or an error from initialization.
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
    const HloModule& module, bool ssa_form, bool bitcast_defines_value,
    const CanShareBuffer& can_share_buffer) {
  VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
  XLA_VLOG_LINES(2, module.ToString());

  auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
      module, ssa_form, bitcast_defines_value, can_share_buffer));

  TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
  dataflow_analysis->Propagate();
  dataflow_analysis->OptimizePhiValues();

  // Delete all values marked for deletion.
  dataflow_analysis->DeleteMarkedValues();

  // Gather and set all non-definition positions of all values. Value deletion
  // is rare, so just use a vector indexed by Value::Id rather than a map from
  // Value::Id to positions. There should be very few holes in the vector, and
  // lookup is faster.
  std::vector<std::vector<HloPosition>> value_positions(
      dataflow_analysis->next_value_id_);
  for (const HloComputation* computation : module.computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      for (const auto& pair :
           dataflow_analysis->GetInstructionValueSet(instruction)) {
        const ShapeIndex& index = pair.first;
        const HloValueSet& value_set = pair.second;
        for (const HloValue* value : value_set.values()) {
          // The defining position is recorded on the value already; collect
          // only the positions where the value merely appears.
          if (value->defining_instruction() != instruction) {
            value_positions[value->id()].push_back(
                HloPosition{instruction, index});
          }
        }
      }
    }
  }
  for (auto& pair : dataflow_analysis->values_) {
    HloValue::Id value_id = pair.first;
    HloValue& value = *pair.second;
    value.SetPositions(value_positions[value_id]);
  }

  // Construct vector of values.
  dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
  for (const auto& pair : dataflow_analysis->values_) {
    dataflow_analysis->values_vector_.push_back(pair.second.get());
  }
  absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);

  TF_DCHECK_OK(dataflow_analysis->Verify());

  XLA_VLOG_LINES(1, dataflow_analysis->ToString());

  return std::move(dataflow_analysis);
}
1591
Verify() const1592 Status HloDataflowAnalysis::Verify() const {
1593 // Verify each HloValue appears in the value sets that the value's positions()
1594 // indicate.
1595 for (const HloValue* value : values()) {
1596 for (const HloPosition& position : value->positions()) {
1597 const HloValueSet& value_set = GetValueSet(position);
1598 TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
1599 << "Value set at position " << position << " does not contain value "
1600 << value->ToShortString();
1601 }
1602 }
1603
1604 // For each value in each value set, verify that the value set's position
1605 // appears in the value's positions().
1606 for (const auto& computation : module_.computations()) {
1607 for (const auto& instruction : computation->instructions()) {
1608 for (const auto& pair : GetInstructionValueSet(instruction)) {
1609 const ShapeIndex& index = pair.first;
1610 const HloValueSet& value_set = pair.second;
1611 const HloPosition position{instruction, index};
1612 for (const HloValue* value : value_set.values()) {
1613 TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
1614 << "Value set at position " << position
1615 << " unexpectedly contains value " << value->ToShortString();
1616 }
1617 }
1618 }
1619 }
1620
1621 return OkStatus();
1622 }
1623
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const1624 bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
1625 const HloInstruction* operand, const ShapeIndex& index,
1626 const HloInstruction* user) const {
1627 // Return false if no value at 'operand' and 'index' is used at 'user'.
1628 for (const HloValue* value : GetValueSet(operand, index).values()) {
1629 for (const HloUse& use : value->GetUses()) {
1630 if (use.instruction == user) {
1631 if (user->IsLoopFusion()) {
1632 HloInstruction* fusion_param =
1633 user->fused_parameter(use.operand_number);
1634 const HloValue& value =
1635 GetValueDefinedAt(fusion_param, use.operand_index);
1636 return value.GetUses().empty();
1637 }
1638 return false;
1639 }
1640 }
1641 }
1642 return true;
1643 }
1644
IsInPlaceOperation(HloOpcode opcode)1645 /*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) {
1646 return opcode == HloOpcode::kDynamicUpdateSlice ||
1647 opcode == HloOpcode::kScatter;
1648 }
1649
IsAsynchronousOperationStart(HloOpcode opcode)1650 /*static*/ bool HloDataflowAnalysis::IsAsynchronousOperationStart(
1651 HloOpcode opcode) {
1652 return opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv ||
1653 opcode == HloOpcode::kCopyStart ||
1654 opcode == HloOpcode::kAllReduceStart ||
1655 opcode == HloOpcode::kAllGatherStart ||
1656 opcode == HloOpcode::kCollectivePermuteStart ||
1657 opcode == HloOpcode::kAsyncStart;
1658 }
1659
IsAsynchronousOperationDone(HloOpcode opcode)1660 /*static*/ bool HloDataflowAnalysis::IsAsynchronousOperationDone(
1661 HloOpcode opcode) {
1662 return opcode == HloOpcode::kSendDone || opcode == HloOpcode::kRecvDone ||
1663 opcode == HloOpcode::kCopyDone ||
1664 opcode == HloOpcode::kAllReduceDone ||
1665 opcode == HloOpcode::kAllGatherDone ||
1666 opcode == HloOpcode::kCollectivePermuteDone ||
1667 opcode == HloOpcode::kAsyncDone;
1668 }
1669
1670 namespace {
1671
1672 // Removes layers of tuple indirection introduced via 'tuple' and
1673 // 'get-tuple-element' instructions to more directly identify the source of the
1674 // given HLO value (identified by the given `ShapeIndex` into the output of the
1675 // given `HloInstruction`).
1676 //
1677 // e.g. for the following:
1678 // %x = some-op(...)
1679 // %foo = get-tuple-element(%x), index=0
1680 // %bar = tuple(%y, %foo)
1681 //
1682 // ... FollowTupleIndirection(%bar, {1}) == {%x, {0}} (output 1 of 'bar' comes
1683 // from output 0 of %x).
1684 //
1685 // Note that all 'tuple' instructions are followed before all
1686 // 'get-tuple-element' instructions are followed. This is because it is assumed
1687 // that tupling a value and then extracting it from the tuple again will not
1688 // occur in properly-optimized IR.
FollowTupleIndirection(const HloInstruction * instruction,ShapeIndex operand_index)1689 std::pair<const HloInstruction*, ShapeIndex> FollowTupleIndirection(
1690 const HloInstruction* instruction, ShapeIndex operand_index) {
1691 while (instruction->opcode() == HloOpcode::kTuple && !operand_index.empty()) {
1692 instruction = instruction->operand(operand_index.front());
1693 operand_index.pop_front();
1694 }
1695 while (instruction->opcode() == HloOpcode::kGetTupleElement) {
1696 operand_index.push_front(instruction->tuple_index());
1697 instruction = instruction->operand(0);
1698 }
1699
1700 return {instruction, operand_index};
1701 }
1702
1703 // Returns in-place input/output pairs for the given fusion instruction,
1704 // according to the aliasing rules for the corresponding fusion computation.
1705 //
1706 // `instruction` must be a fusion instruction.
std::vector<std::pair<HloOperandIndex, ShapeIndex>>
GetFusionInstructionInPlaceInputOutputPairs(const HloInstruction* instruction) {
  std::vector<std::pair<HloOperandIndex, ShapeIndex>>
      in_place_input_output_pairs;
  // Each of these leaves represents one array output of the fusion that might
  // be aliased with one of the fusion computation's array inputs (both could be
  // nested arbitrarily deep inside tuples).
  for (const auto& fusion_output_array_shape :
       ShapeUtil::GetLeafShapes(instruction->shape())) {
    // Start from the root instruction of the fusion computation and follow
    // tuple indirection backwards to find the "output source", i.e. the
    // instruction that is the original source of the array output in question.
    // If there is no such indirection the "output source" will just be the
    // fusion root instruction itself.
    const HloInstruction* output_source_instruction =
        instruction->fused_expression_root();
    ShapeIndex output_source_index = fusion_output_array_shape.index;
    std::tie(output_source_instruction, output_source_index) =
        FollowTupleIndirection(output_source_instruction, output_source_index);

    // The aliasing rules of the "output source" instruction determine the
    // aliasing rules for the entire fusion. If we can connect (following tuple
    // indirection) the input of an "in-place" pair to one of the fusion's
    // inputs, and the output of this "in-place" pair to the fusion output
    // in question, then this fusion input and output must alias.
    auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(
        output_source_instruction);
    ShapeIndex in_place_input_index;
    const HloInstruction* in_place_input_source = nullptr;

    // Scan the output source's own in-place pairs for one whose output index
    // matches this fusion output's source index.
    for (const auto& output_source_in_place_pair : in_place_pairs) {
      const HloOperandIndex& input = output_source_in_place_pair.first;
      const ShapeIndex& output_index = output_source_in_place_pair.second;
      if (output_index == output_source_index) {
        // It is not possible for the same output to alias multiple inputs.
        CHECK(in_place_input_source == nullptr);
        in_place_input_source =
            output_source_instruction->operand(input.operand_number);
        in_place_input_index = input.operand_index;
      }
    }

    if (in_place_input_source) {
      // Follow tuple indirection backwards from the instruction input to try to
      // find a fusion parameter. If found, that parameter aliases the current
      // output. If not, the current output aliases no input.
      std::tie(in_place_input_source, in_place_input_index) =
          FollowTupleIndirection(in_place_input_source, in_place_input_index);

      // Only a direct fusion parameter connects the in-place write back to a
      // fusion operand; any other source stays internal to the fusion.
      if (in_place_input_source->opcode() == HloOpcode::kParameter) {
        in_place_input_output_pairs.emplace_back(
            HloOperandIndex{in_place_input_source->parameter_number(),
                            in_place_input_index},
            fusion_output_array_shape.index);
      }
    }
  }
  return in_place_input_output_pairs;
}
1766
1767 } // namespace
1768
/*static*/ std::vector<std::pair<HloOperandIndex, ShapeIndex>>
HloDataflowAnalysis::GetInPlaceInputOutputPairs(
    const HloInstruction* instruction) {
  // Returns (operand index, output shape index) pairs describing which parts
  // of 'instruction's output are written in place over an operand buffer.
  if (IsInPlaceOperation(instruction->opcode())) {
    // Variadic scatter produces an N-tuple whose i-th element aliases
    // operand i; plain DUS/scatter alias operand 0 wholesale.
    const HloScatterInstruction* scatter =
        DynCast<HloScatterInstruction>(instruction);
    if (scatter && scatter->scatter_operand_count() > 1) {
      std::vector<std::pair<HloOperandIndex, ShapeIndex>> pairs;
      pairs.reserve(scatter->scatter_operand_count());
      for (int i = 0, n = scatter->scatter_operand_count(); i < n; ++i) {
        pairs.emplace_back(HloOperandIndex{i, {}}, ShapeIndex{i});
      }
      return pairs;
    }
    return {{HloOperandIndex{0, {}}, {}}};
  } else if (instruction->opcode() == HloOpcode::kCollectivePermute &&
             instruction->operands().size() == 4) {
    // 4-operand collective-permute carries an explicit destination buffer as
    // operand 1; the output is produced in place in that buffer (per-element
    // for a tuple-shaped buffer).
    if (instruction->operand(1)->shape().IsTuple()) {
      std::vector<std::pair<HloOperandIndex, ShapeIndex>> in_place_pairs(
          {{HloOperandIndex{1, {}}, {}}});
      for (int i = 0; i < instruction->operand(1)->shape().tuple_shapes_size();
           i++) {
        in_place_pairs.push_back({HloOperandIndex{1, {i}}, {i}});
      }
      return in_place_pairs;
    } else {
      return {{HloOperandIndex{1, {}}, {}}};
    }
  } else if (instruction->opcode() == HloOpcode::kCollectivePermuteStart &&
             instruction->operands().size() == 4) {
    // Same as above, but the start op wraps its result in a larger tuple, so
    // the destination buffer aliases element {1} of the start op's output.
    if (instruction->operand(1)->shape().IsTuple()) {
      std::vector<std::pair<HloOperandIndex, ShapeIndex>> in_place_pairs(
          {{HloOperandIndex{1, {}}, {1}}});
      for (int i = 0; i < instruction->operand(1)->shape().tuple_shapes_size();
           i++) {
        in_place_pairs.push_back({HloOperandIndex{1, {i}}, {1, i}});
      }
      return in_place_pairs;
    } else {
      return {{HloOperandIndex{1, {}}, {1}}};
    }
  } else if (instruction->opcode() == HloOpcode::kCustomCall) {
    // Custom Calls previously assumed that aliased operands were
    // forwarded, but now supports modification semantics.
    // The aliasing pairs come straight from the instruction's declared
    // output-to-operand aliasing; translate each into our pair format.
    const auto& aliasing_pairs = Cast<HloCustomCallInstruction>(instruction)
                                     ->output_to_operand_aliasing();
    std::vector<std::pair<HloOperandIndex, ShapeIndex>> in_place_pairs;
    in_place_pairs.reserve(aliasing_pairs.size());
    for (const auto& pair : aliasing_pairs) {
      ShapeIndex output_shape_index = pair.first;
      int64_t operand_index = pair.second.first;
      ShapeIndex operand_shape_index = pair.second.second;
      in_place_pairs.push_back(
          {HloOperandIndex{operand_index, {operand_shape_index}},
           output_shape_index});
    }
    return in_place_pairs;
  } else if (instruction->opcode() == HloOpcode::kAllReduceStart) {
    // Single-operand all-reduce-start aliases its output with the operand;
    // variadic all-reduce-start produces a tuple whose i-th element aliases
    // operand i.
    if (instruction->operands().size() == 1) {
      return {{HloOperandIndex{0, {}}, {}}};
    }
    std::vector<std::pair<HloOperandIndex, ShapeIndex>> in_place_pairs;
    in_place_pairs.reserve(instruction->operands().size());
    for (int i = 0; i < instruction->operands().size(); i++) {
      in_place_pairs.push_back({HloOperandIndex{i, {}}, {i}});
    }
    return in_place_pairs;
  } else if (instruction->opcode() == HloOpcode::kFusion) {
    // Fusion aliasing is derived from the fused computation's root.
    return GetFusionInstructionInPlaceInputOutputPairs(instruction);
  }

  // All other opcodes have no in-place input/output relationship.
  return {};
}
1842
// Decides whether 'user' may write its output (at 'user_index') into the
// buffer backing 'operand' (at 'operand_index'). Returns true only when
// sharing is provably safe under the rules below; a backend-provided hook
// (can_share_buffer_) may override the generic logic in either direction.
bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
    HloInstruction* operand, const ShapeIndex& operand_index,
    HloInstruction* user, const ShapeIndex& user_index) const {
  CHECK(user->IsUserOf(operand))
      << "user: " << user->ToString() << " operand: " << operand->ToString();
  // Constants live in read-only buffers and can never back a writable output.
  if (operand->opcode() == HloOpcode::kConstant) {
    return false;
  }

  const Shape& operand_subshape =
      ShapeUtil::GetSubshape(operand->shape(), operand_index);
  const Shape& user_subshape =
      ShapeUtil::GetSubshape(user->shape(), user_index);
  if (IsSliceInputFusion(*user)) {
    HloInstruction* fusion_param =
        user->fused_parameter(user->operand_index(operand));
    // We don't require the same dimensions but only the same number of elements
    // and type (to make sure the same buffer size).
    return operand_subshape.IsArray() && user_subshape.IsArray() &&
           ShapeUtil::ElementsIn(operand_subshape) ==
               ShapeUtil::ElementsIn(user_subshape) &&
           ShapeUtil::SameElementType(operand_subshape, user_subshape) &&
           AreTransitiveUsesEffectivelyElementwise(
               fusion_param, user->fused_expression_root(), user_index);
  }

  auto shapes_equal = ShapeUtil::Equal(operand_subshape, user_subshape);
  // Check that operand and user emit the same shape and layout.
  if (shapes_equal) {
    // Must-alias relationship returns true for in-place operations (DUS and DUS
    // fusions), regardless of the backend.
    for (const auto& operand_and_output_index :
         GetInPlaceInputOutputPairs(user)) {
      if (operand_and_output_index.second != user_index) {
        continue;
      }
      for (const HloUse& use :
           GetUniqueValueAt(operand, operand_index).GetUses()) {
        if (use == HloUse{user, operand_and_output_index.first.operand_number,
                          operand_and_output_index.first.operand_index}) {
          return true;
        }
      }
    }
  }

  // Consult the backend hook next: it runs after the must-alias check above
  // but before the generic shape bail-out, so it can allow sharing even for
  // unequal shapes.
  if (can_share_buffer_ != nullptr) {
    if (std::optional<bool> hint =
            can_share_buffer_(user, operand, user_index)) {
      return *hint;
    }
  }

  if (!shapes_equal) {
    return false;
  }

  if (user->opcode() == HloOpcode::kFusion) {
    HloInstruction* fusion_param =
        user->fused_parameter(user->operand_index(operand));
    const HloValue& fusion_param_value =
        GetValueDefinedAt(fusion_param, operand_index);

    // Loop/input fusions may share iff every transitive use of the fused
    // parameter is elementwise (or mere tuple plumbing).
    if (user->IsLoopFusion() || user->IsInputFusion()) {
      return AreTransitiveUsesElementwiseOrTuple(fusion_param);
    }

    if (user->IsOutputFusion() &&
        user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
      // Output fusion with kAdd fused root.

      // Check if one operand of kAdd fused root is kDot or kConvolution.
      auto* add = user->fused_expression_root();
      auto add_operand_it =
          absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
            return operand->opcode() == HloOpcode::kConvolution ||
                   operand->opcode() == HloOpcode::kDot;
          });
      if (add_operand_it == add->operands().end()) {
        return false;
      }
      auto* matched_add_operand = *add_operand_it;
      // Calculate operand index of 'add' operand which was not matched above.
      const int64_t other_add_operand_index =
          matched_add_operand == add->operand(0) ? 1 : 0;
      // Returns true iff there is exactly one use of 'operand' at shape index
      // 'operand_index', and this singleton use is the fused root (at operand
      // index 'other_add_operand_index').
      if (fusion_param_value.GetUses().size() == 1) {
        const HloUse& use = fusion_param_value.GetUses()[0];
        return use.instruction == user->fused_expression_root() &&
               use.operand_number == other_add_operand_index;
      }
      return false;
    }
  }

  // There is nothing inherently wrong with while and conditional ops to have
  // input/output buffers to alias with each other, even when the indices are
  // different in the while case. It is a problem when this aliasing causes HLO
  // ops inside these while or conditional to have input/output buffer aliasing
  // that isn't allowed. So allow while and conditional to share buffers with
  // operands and we will discover any problematic sharing when we explore the
  // ops inside these computations.
  if (user->opcode() == HloOpcode::kWhile ||
      user->opcode() == HloOpcode::kConditional) {
    return true;
  }

  if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
      user->opcode() == HloOpcode::kScatter ||
      user->opcode() == HloOpcode::kTriangularSolve) {
    // We eliminated other users in HloOrdering::LiveRangeStrictlyBefore
    // so here we just need to check that the use is at the right operand index.
    // (TriangularSolve updates operand 1 in place; DUS/scatter update
    // operand 0.)
    const auto operand_indices = user->OperandIndices(operand);
    int64_t operand_no = user->opcode() == HloOpcode::kTriangularSolve ? 1 : 0;
    return operand_indices.size() == 1 && operand_indices[0] == operand_no;
  }
  if (user->opcode() == HloOpcode::kSort) {
    // Only valid if there are no other users.
    if (operand->users().size() != 1) {
      return false;
    }
    // If we only sort keys, the output of sort is not a tuple, so we can always
    // share the buffer.
    if (user->operand_count() == 1) {
      return true;
    }
    CHECK(!user_index.empty());
    // Only share with the right tuple element buffer.
    const auto operand_indices = user->OperandIndices(operand);
    return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
  }
  if (user->opcode() == HloOpcode::kCall) {
    // Get all uses of value defined by 'operand' at 'operand_index'.
    auto uses = GetValueDefinedAt(operand, operand_index).GetUses();
    // Return true iff:
    // *) There exists two uses of 'operand'.
    // *) One use is by 'user' (caller).
    // *) One use is by root instruction of called computation (callee root).
    //    (Note: we check the root of the called computation, because the
    //     root result buffer is required to alias with the Call result buffer).
    // *) The root instruction of the called computation is element-wise on
    //    'operand'.
    const bool found_caller_use =
        absl::c_find_if(uses, [user](const HloUse& use) {
          return use.instruction == user;
        }) != uses.end();
    auto* callee_root = user->to_apply()->root_instruction();
    const bool found_elementwise_callee_use =
        absl::c_find_if(uses, [callee_root](const HloUse& use) {
          return use.instruction == callee_root &&
                 callee_root->IsElementwiseOnOperand(use.operand_number);
        }) != uses.end();
    return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
  }

  // Loop fusions that contain transposing copies won't reach here as they have
  // different layouts, which fails the check in the beginning of this function.
  return user->IsElementwiseOnOperand(user->operand_index(operand));
}
2004
2005 } // namespace xla
2006