1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17
18 #include <algorithm>
19 #include <queue>
20 #include <vector>
21
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/compiler/xla/map_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/logging.h"
38
39 namespace xla {
40
41 using absl::StrAppend;
42 using absl::StrCat;
43
// Constructs an analysis object for `module`. The constructor only records its
// arguments and builds the call graph of the module; no value sets are created
// or propagated here.
HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
                                         bool bitcast_defines_value,
                                         const CanShareBuffer& can_share_buffer)
    : module_(module),
      ssa_form_(ssa_form),
      bitcast_defines_value_(bitcast_defines_value),
      // The call graph is built from the same module that is stored above.
      call_graph_(CallGraph::Build(&module)),
      can_share_buffer_(can_share_buffer) {}
52
AreTransitiveUsesElementwiseOrTuple(const HloInstruction * inst)53 bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
54 const HloInstruction* inst) {
55 absl::flat_hash_set<const HloInstruction*> visited;
56 absl::InlinedVector<const HloInstruction*, 4> stack;
57 stack.push_back(inst);
58 while (!stack.empty()) {
59 const HloInstruction* current = stack.back();
60 stack.pop_back();
61 visited.insert(current);
62 for (const HloInstruction* user : current->users()) {
63 // Found a user that is non-elementwise on current instruction.
64 for (const int64 use_index : user->OperandIndices(current)) {
65 if (!user->IsElementwiseOnOperand(use_index) &&
66 user->opcode() != HloOpcode::kTuple) {
67 return false;
68 }
69 }
70 if (!visited.contains(user)) {
71 stack.push_back(user);
72 }
73 }
74 }
75 return true;
76 }
77
ValueIsDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const78 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
79 const ShapeIndex& index) const {
80 const HloValueSet& value_set = GetValueSet(instruction, index);
81 if (value_set.values().size() != 1) {
82 return false;
83 }
84 return value_set.GetUniqueValue().defining_instruction() == instruction;
85 }
86
// Returns the value at (`instruction`, `index`). CHECK-fails, logging the
// instruction's string form, when ValueIsDefinedAt() is false for that
// position.
const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
  return GetUniqueValueAt(instruction, index);
}
92
GetValueDefinedAt(const HloInstruction * instruction,const ShapeIndex & index)93 HloValue& HloDataflowAnalysis::GetValueDefinedAt(
94 const HloInstruction* instruction, const ShapeIndex& index) {
95 CHECK(ValueIsDefinedAt(instruction, index));
96 return GetUniqueValueAt(instruction, index);
97 }
98
NewHloValue(HloInstruction * instruction,const ShapeIndex & index,bool is_phi)99 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
100 const ShapeIndex& index,
101 bool is_phi) {
102 const int64 value_id = next_value_id_++;
103 auto emplaced = values_.emplace(
104 std::piecewise_construct, std::forward_as_tuple(value_id),
105 std::forward_as_tuple(value_id, instruction, index, is_phi));
106 CHECK(emplaced.second);
107
108 VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
109
110 return &emplaced.first->second;
111 }
112
MarkValueForDeletion(HloValue::Id value_id)113 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
114 HloValue& value = values_.at(value_id);
115 VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
116
117 value_ids_to_delete_.push_back(value_id);
118 }
119
// Erases from `values_` every id queued in `value_ids_to_delete_` and then
// empties the queue.
void HloDataflowAnalysis::DeleteMarkedValues() {
#ifndef NDEBUG
  // Verify that no marked-for-deletion values are in any of the value sets.
  // This scan walks every value set of every instruction and is compiled only
  // when NDEBUG is not defined.
  absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
                                           value_ids_to_delete_.end());
  for (const auto& pair : value_sets_) {
    const HloInstruction* instruction = pair.first;
    const InstructionValueSet& instruction_value_set = pair.second;
    for (const auto& index_value_set : instruction_value_set) {
      const HloValueSet& value_set = index_value_set.second;
      for (const HloValue* value : value_set.values()) {
        DCHECK(!ContainsKey(id_set, value->id()))
            << "Value " << value->ToShortString()
            << " marked for deletion, but still exists in value set for "
               "instruction "
            << instruction->name();
      }
    }
  }
#endif

  // Drop the values themselves, then reset the pending list.
  for (HloValue::Id value_id : value_ids_to_delete_) {
    values_.erase(value_id);
  }
  value_ids_to_delete_.clear();
}
146
ToString() const147 string HloDataflowAnalysis::ToString() const {
148 string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
149 StrAppend(&out, " Instruction value sets:\n");
150 for (const HloComputation* computation : module_.computations()) {
151 for (const HloInstruction* instruction : computation->instructions()) {
152 StrAppend(&out, "Instruction: \n ", instruction->name(), ":\n");
153 if (instruction->shape().IsTuple()) {
154 GetInstructionValueSet(instruction)
155 .ForEachElement([this, &instruction, &out](
156 const ShapeIndex& index,
157 const HloValueSet& value_set) {
158 StrAppend(&out, " tuple index ", index.ToString(), ":\n");
159 for (const HloValue* value : value_set.values()) {
160 StrAppend(&out, " ", value->ToShortString(),
161 ValueIsDefinedAt(instruction, index) ? " (def)" : "",
162 "\n");
163 }
164 });
165 } else {
166 const HloValueSet& top_level_value_set =
167 GetValueSet(instruction, /*index=*/{});
168 for (const HloValue* value : top_level_value_set.values()) {
169 StrAppend(&out, " ", value->ToShortString(),
170 ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
171 }
172 }
173 }
174 }
175 StrAppend(&out, " HloValues:\n");
176 for (const HloValue* value : values()) {
177 StrAppend(&out, value->ToString(/*indent=*/4));
178 }
179 return out;
180 }
181
// Recomputes the value set of `instruction` at every shape index from the
// value sets at the same index in `inputs`:
//   * no input values       -> the value set must already be empty;
//   * one distinct value    -> the value set becomes that single value, and an
//                              existing phi defined here is marked for
//                              deletion;
//   * several distinct ones -> the value set becomes a phi value defined at
//                              this position (created if not already there).
// Values defined at the position being updated are not counted as inputs.
// Requires SSA form (CHECKed). Returns true if any value set was modified.
bool HloDataflowAnalysis::Phi(
    HloInstruction* instruction,
    absl::Span<const InstructionValueSet* const> inputs) {
  CHECK(ssa_form_);
  VLOG(4) << "Phi(" << instruction->name() << ")";
  VLOG(5) << "instruction value set = "
          << GetInstructionValueSet(instruction).ToString();
  for (const InstructionValueSet* input : inputs) {
    VLOG(5) << "input value set = " << input->ToString();
  }

  // Debug-only (DCHECK) validation of the input shapes against the shape of
  // `instruction`. When bitcasts do not define values, only the element type
  // and, for arrays, the element count are compared.
  if (bitcast_defines_value_) {
    absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
      DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
    });
  } else {
    const Shape& shape = instruction->shape();
    PrimitiveType ty = shape.element_type();
    bool is_array = shape.IsArray();
    absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
      DCHECK(ty == input->shape().element_type() &&
             (!is_array || ShapeUtil::ElementsIn(shape) ==
                               ShapeUtil::ElementsIn(input->shape())));
    });
  }

  bool changed = false;
  for (auto& pair : GetInstructionValueSet(instruction)) {
    const ShapeIndex& index = pair.first;
    HloValueSet& value_set = pair.second;

    // Positions with phi values should never have more than one value in the
    // value set.
    CHECK_LE(value_set.values().size(), 1);
    const HloValue* current_value =
        value_set.values().size() == 1 ? value_set.values()[0] : nullptr;

    // Construct a vector of unique value IDs of the inputs.
    // Don't add value ids where the input is equal to the definition.
    std::vector<HloValue::Id> input_value_ids;
    for (const InstructionValueSet* input : inputs) {
      for (const HloValue* value : input->element(index).values()) {
        if (value->defining_instruction() == instruction &&
            value->defining_index() == index) {
          continue;
        }
        input_value_ids.push_back(value->id());
      }
    }
    // Sort so that std::unique can drop duplicate ids.
    absl::c_sort(input_value_ids);
    input_value_ids.erase(
        std::unique(input_value_ids.begin(), input_value_ids.end()),
        input_value_ids.end());

    // Remove the existing phi value (if it exists). The phi can be its own
    // input, for example, in while body parameters where the body passes
    // through the parameter value.
    bool current_value_defined_here =
        (current_value != nullptr &&
         current_value->defining_instruction() == instruction &&
         current_value->defining_index() == index);
    if (current_value_defined_here) {
      VLOG(5) << "current_value_defined_here: " << current_value->ToString();
      CHECK(current_value->is_phi());
      auto it = absl::c_find(input_value_ids, current_value->id());
      if (it != input_value_ids.end()) {
        input_value_ids.erase(it);
      }
    }
    VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
    if (input_value_ids.empty()) {
      // A value set which has at least one element should never have its value
      // set reduced to zero elements. During dataflow value sets only can go
      // from empty to non-empty, not the reverse.
      CHECK_EQ(value_set.values().size(), 0)
          << "Instruction " << instruction->name() << " at index " << index
          << " previously had non-empty value set. Value set: " << value_set;
    } else if (input_value_ids.size() == 1) {
      // Only a single value reaches this point. There should be no phi, and
      // this value set should contain this single value.
      const HloValue& new_value = GetValue(input_value_ids[0]);
      if (current_value == nullptr) {
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      } else if (current_value != &new_value) {
        if (current_value_defined_here) {
          // Remove the existing phi.
          MarkValueForDeletion(current_value->id());
        }
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      }
    } else {
      // Multiple distinct values reach this point. A phi value is
      // necessary.
      CHECK_GT(input_value_ids.size(), 1);
      // A phi already defined at this position is kept as is; otherwise the
      // value set is replaced by a newly created phi.
      if (current_value == nullptr ||
          !(current_value->is_phi() && current_value_defined_here)) {
        value_set.Clear();
        value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
        changed = true;
      }
    }
  }
  return changed;
}
290
GetValue(HloValue::Id value_id) const291 const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
292 return values_.at(value_id);
293 }
294
GetValue(HloValue::Id value_id)295 HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
296 return values_.at(value_id);
297 }
298
GetFlattenedValueSet(const HloInstruction * instruction) const299 HloValueSet HloDataflowAnalysis::GetFlattenedValueSet(
300 const HloInstruction* instruction) const {
301 HloValueSet value_set;
302
303 const InstructionValueSet& value_set_tree =
304 GetInstructionValueSet(instruction);
305
306 std::vector<const HloValueSet*> all_sets;
307 for (auto& pair : value_set_tree) {
308 const HloValueSet& value_set = pair.second;
309 all_sets.push_back(&value_set);
310 }
311 value_set.AssignUnionOf(all_sets);
312
313 return value_set;
314 }
315
GetValueSet(const HloInstruction * instruction,const ShapeIndex & index) const316 const HloValueSet& HloDataflowAnalysis::GetValueSet(
317 const HloInstruction* instruction, const ShapeIndex& index) const {
318 return GetInstructionValueSet(instruction).element(index);
319 }
320
GetValueSet(const HloInstruction * instruction,const ShapeIndex & index)321 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
322 const ShapeIndex& index) {
323 return *GetInstructionValueSet(instruction).mutable_element(index);
324 }
325
GetValueSet(const HloPosition & position) const326 const HloValueSet& HloDataflowAnalysis::GetValueSet(
327 const HloPosition& position) const {
328 return GetValueSet(position.instruction, position.index);
329 }
330
GetValueSet(const HloPosition & position)331 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
332 return GetValueSet(position.instruction, position.index);
333 }
334
UpdateBitcastValueSet(HloInstruction * bitcast)335 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
336 CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
337 const InstructionValueSet& operand_set =
338 GetInstructionValueSet(bitcast->operand(0));
339 InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
340 if (!bitcast_defines_value_ && operand_set != bitcast_set) {
341 bitcast_set = operand_set;
342 return true;
343 }
344 return false;
345 }
346
UpdateSetDimensionSizeValueSet(HloInstruction * set_dimension_size)347 bool HloDataflowAnalysis::UpdateSetDimensionSizeValueSet(
348 HloInstruction* set_dimension_size) {
349 CHECK_EQ(set_dimension_size->opcode(), HloOpcode::kSetDimensionSize);
350 const InstructionValueSet& operand_set =
351 GetInstructionValueSet(set_dimension_size->operand(0));
352 InstructionValueSet& set_dimension_size_set =
353 GetInstructionValueSet(set_dimension_size);
354 if (operand_set != set_dimension_size_set) {
355 set_dimension_size_set = operand_set;
356 return true;
357 }
358 return false;
359 }
360
UpdateSendValueSet(HloInstruction * send)361 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
362 CHECK_EQ(send->opcode(), HloOpcode::kSend);
363 bool changed = false;
364 // Send forwards the operand value to the output tuple at {0}.
365 for (auto& pair : GetInstructionValueSet(send->operand(0))) {
366 const ShapeIndex& operand_index = pair.first;
367 const HloValueSet& operand_value_set = pair.second;
368
369 ShapeIndex index = {0};
370 for (int64 i : operand_index) {
371 index.push_back(i);
372 }
373
374 HloValueSet& value_set = GetValueSet(send, index);
375 if (value_set != operand_value_set) {
376 value_set = operand_value_set;
377 changed = true;
378 }
379 }
380 return changed;
381 }
382
UpdateCopyStartValueSet(HloInstruction * copy_start)383 bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
384 CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
385 bool changed = false;
386 // CopyStart forwards the operand value to element {1} of its output.
387 const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0));
388 HloValueSet& value_set = GetValueSet(copy_start, {1});
389 if (value_set != operand_value_set) {
390 value_set = operand_value_set;
391 changed = true;
392 }
393 return changed;
394 }
395
UpdateCopyDoneValueSet(HloInstruction * copy_done)396 bool HloDataflowAnalysis::UpdateCopyDoneValueSet(HloInstruction* copy_done) {
397 CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone);
398 bool changed = false;
399 // CopyDone forwards the operand value at {0} to element {} of its output.
400 const HloValueSet& operand_value_set =
401 GetValueSet(copy_done->operand(0), {0});
402 HloValueSet& value_set = GetValueSet(copy_done);
403 if (value_set != operand_value_set) {
404 value_set = operand_value_set;
405 changed = true;
406 }
407 return changed;
408 }
409
UpdateRecvDoneValueSet(HloInstruction * recv_done)410 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
411 CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
412 bool changed = false;
413 // RecvDone forwards the operand value at {0} to element {0} of its output.
414 for (auto& pair : GetInstructionValueSet(recv_done)) {
415 ShapeIndex& index = pair.first;
416 HloValueSet& value_set = pair.second;
417
418 if (index.empty() || index[0] != 0) {
419 continue;
420 }
421
422 const HloValueSet& operand_value_set =
423 GetValueSet(recv_done->operand(0), index);
424 if (value_set != operand_value_set) {
425 value_set = operand_value_set;
426 changed = true;
427 }
428 }
429 return changed;
430 }
431
UpdateCallValueSet(HloInstruction * call)432 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
433 CHECK_EQ(call->opcode(), HloOpcode::kCall);
434 InstructionValueSet& value_set = GetInstructionValueSet(call);
435 InstructionValueSet& root_value_set =
436 GetInstructionValueSet(call->to_apply()->root_instruction());
437 if (value_set != root_value_set) {
438 value_set = root_value_set;
439 return true;
440 }
441 return false;
442 }
443
UpdateConditionalValueSet(HloInstruction * conditional)444 bool HloDataflowAnalysis::UpdateConditionalValueSet(
445 HloInstruction* conditional) {
446 CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
447 std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
448 for (int j = 0; j < conditional->branch_count(); ++j) {
449 inputs[j] = &GetInstructionValueSet(
450 conditional->branch_computation(j)->root_instruction());
451 }
452 if (ssa_form_) {
453 return Phi(conditional, inputs);
454 } else {
455 return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
456 }
457 }
458
UpdateCopyValueSet(HloInstruction * copy)459 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
460 CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
461 bool changed = false;
462 for (auto& pair : GetInstructionValueSet(copy)) {
463 const ShapeIndex& index = pair.first;
464 if (index.empty()) {
465 // kCopy shallow copies and thus defines the top-level value so nothing to
466 // update.
467 continue;
468 }
469
470 HloValueSet& value_set = pair.second;
471 HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
472 if (value_set != operand_value_set) {
473 value_set = operand_value_set;
474 changed = true;
475 }
476 }
477 return changed;
478 }
479
UpdateDomainValueSet(HloInstruction * domain)480 bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
481 // Domain instructions just forward their operand. Given that domains can have
482 // a tuple operand, we iterate through its indexes, like for copies.
483 // Unlike copies though we also propagate the top-level value.
484 CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
485 bool changed = false;
486 for (auto& pair : GetInstructionValueSet(domain)) {
487 const ShapeIndex& index = pair.first;
488 HloValueSet& value_set = pair.second;
489 HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
490 if (value_set != operand_value_set) {
491 value_set = operand_value_set;
492 changed = true;
493 }
494 }
495 return changed;
496 }
497
UpdateAddDependencyValueSet(HloInstruction * add_dependency)498 bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
499 HloInstruction* add_dependency) {
500 // AddDependency just forwards the value of its zero-th operand.
501 CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
502 const InstructionValueSet& operand_set =
503 GetInstructionValueSet(add_dependency->operand(0));
504 InstructionValueSet& add_dependency_set =
505 GetInstructionValueSet(add_dependency);
506 if (operand_set != add_dependency_set) {
507 add_dependency_set = operand_set;
508 return true;
509 }
510 return false;
511 }
512
UpdateGetTupleElementValueSet(HloInstruction * gte)513 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
514 CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
515 bool changed = false;
516 // The GetTupleElement instruction forwards the values from the specified
517 // tuple element.
518 for (auto& pair : GetInstructionValueSet(gte)) {
519 const ShapeIndex& index = pair.first;
520 HloValueSet& value_set = pair.second;
521
522 // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
523 // with the tuple element number prefixed.
524 ShapeIndex operand_index = {gte->tuple_index()};
525 for (int64 i : index) {
526 operand_index.push_back(i);
527 }
528
529 HloValueSet& operand_value_set =
530 GetValueSet(gte->operand(0), operand_index);
531 if (value_set != operand_value_set) {
532 value_set = operand_value_set;
533 changed = true;
534 }
535 }
536 return changed;
537 }
538
// Recomputes the value set of `parameter` from the instructions that call its
// computation. Inputs are gathered from each caller callsite (call, while, or
// conditional) and merged with Phi() when in SSA form and at least one caller
// is a while or conditional, otherwise with a plain union. Returns true if the
// parameter's value set changed; returns false without touching it when the
// computation is called in a parallel context or has no caller callsites.
bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
  CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
  const CallGraphNode& call_graph_node =
      call_graph_->GetNode(parameter->parent());

  // Subcomputations called in a parallel context (eg, map) do not have dataflow
  // from the caller operands.
  if (call_graph_node.context() == CallContext::kParallel ||
      call_graph_node.caller_callsites().empty()) {
    return false;
  }
  CHECK_EQ(call_graph_node.context(), CallContext::kSequential);

  std::vector<const InstructionValueSet*> inputs;
  // Set once any caller is a while or a conditional.
  bool need_phi = false;
  for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    if (callsite.instruction()->opcode() == HloOpcode::kCall) {
      // The operand values of a call instruction are forwarded to the
      // respective parameter instruction of the subcomputation.
      inputs.push_back(&GetInstructionValueSet(
          callsite.instruction()->operand(parameter->parameter_number())));
    } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
      // In a while instruction, the while operand (ie, the init value) and the
      // backedge are dataflow inputs to the parameter instruction. This is the
      // case for parameters of both the body and condition computations.
      CHECK_EQ(parameter->parameter_number(), 0);
      inputs.push_back(
          &GetInstructionValueSet(callsite.instruction()->operand(0)));
      // If the parameter *is* the root, then don't consider its current state
      // (InstructionValueSet) as we are recomputing its current
      // state. Otherwise, the parameter state would never be updated.
      if (parameter !=
          callsite.instruction()->while_body()->root_instruction()) {
        inputs.push_back(&GetInstructionValueSet(
            callsite.instruction()->while_body()->root_instruction()));
      }
      need_phi = true;
    } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
      CHECK_EQ(parameter->parameter_number(), 0);
      auto conditional = callsite.instruction();
      // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
      // operands 1 and onward are the arguments to the branch computations.
      //
      // If the parameter belongs to conditional's branch 0 computation, then
      // operand 1 is forwarded to this parameter instruction. If the parameter
      // belongs to conditional's branch 5 computation, then operand 6 is
      // forwarded to this parameter instruction.
      bool found_parent = false;
      for (int j = 0; j < conditional->branch_count(); ++j) {
        if (parameter->parent() == conditional->branch_computation(j)) {
          inputs.push_back(
              &GetInstructionValueSet(conditional->operand(j + 1)));
          found_parent = true;
          // Only the first matching branch contributes an input.
          break;
        }
      }
      CHECK(found_parent);
      need_phi = true;
    } else {
      LOG(FATAL) << "CallContext::kSequential computations should only be "
                    "called from call, while, or conditional instructions";
    }
  }

  if (ssa_form_ && need_phi) {
    return Phi(parameter, inputs);
  } else {
    return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
  }
}
609
UpdateTupleSelectValueSet(HloInstruction * select)610 bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) {
611 CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect);
612 // A phi value is not defined at a kTupleSelect instruction because
613 // kTupleSelect does not create a new value. Rather it forwards a value from
614 // its operands. This contrasts with kWhile instruction (which does define a
615 // phi value) which has in-place update semantics.
616 bool changed = false;
617 for (auto& pair : GetInstructionValueSet(select)) {
618 const ShapeIndex& index = pair.first;
619 if (index.empty()) {
620 // kTupleSelect copies (not forwards) the top-level value.
621 continue;
622 }
623 HloValueSet& value_set = pair.second;
624 changed |=
625 value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
626 &GetValueSet(select->operand(2), index)});
627 }
628 return changed;
629 }
630
UpdateTupleValueSet(HloInstruction * tuple)631 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
632 CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
633 bool changed = false;
634 for (int64 i = 0; i < tuple->operands().size(); ++i) {
635 // Copy the value set(s) of each operand into the respective position in the
636 // kTuple instruction's value sets.
637 for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
638 const ShapeIndex& operand_index = pair.first;
639 HloValueSet& operand_value_set = pair.second;
640
641 ShapeIndex index = {i};
642 for (int64 op_index : operand_index) {
643 index.push_back(op_index);
644 }
645 HloValueSet& value_set = GetValueSet(tuple, index);
646
647 if (value_set != operand_value_set) {
648 value_set = operand_value_set;
649 changed = true;
650 }
651 }
652 }
653 return changed;
654 }
655
UpdateWhileValueSet(HloInstruction * xla_while)656 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
657 CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
658 const InstructionValueSet* const inputs[] = {
659 &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
660 &GetInstructionValueSet(xla_while->operand(0))};
661 if (ssa_form_) {
662 return Phi(xla_while, inputs);
663 } else {
664 return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
665 }
666 }
667
UpdateInstructionValueSet(HloInstruction * instruction)668 bool HloDataflowAnalysis::UpdateInstructionValueSet(
669 HloInstruction* instruction) {
670 // Recompute from operands.
671 switch (instruction->opcode()) {
672 case HloOpcode::kAddDependency:
673 return UpdateAddDependencyValueSet(instruction);
674 case HloOpcode::kBitcast:
675 return UpdateBitcastValueSet(instruction);
676 case HloOpcode::kSetDimensionSize:
677 return UpdateSetDimensionSizeValueSet(instruction);
678 case HloOpcode::kDomain:
679 return UpdateDomainValueSet(instruction);
680 case HloOpcode::kCopy:
681 return UpdateCopyValueSet(instruction);
682 case HloOpcode::kGetTupleElement:
683 return UpdateGetTupleElementValueSet(instruction);
684 case HloOpcode::kTupleSelect:
685 return UpdateTupleSelectValueSet(instruction);
686 case HloOpcode::kTuple:
687 return UpdateTupleValueSet(instruction);
688 case HloOpcode::kParameter:
689 return UpdateParameterValueSet(instruction);
690 case HloOpcode::kCall:
691 return UpdateCallValueSet(instruction);
692 case HloOpcode::kWhile:
693 return UpdateWhileValueSet(instruction);
694 case HloOpcode::kSend:
695 return UpdateSendValueSet(instruction);
696 case HloOpcode::kRecvDone:
697 return UpdateRecvDoneValueSet(instruction);
698 case HloOpcode::kCopyStart:
699 return UpdateCopyStartValueSet(instruction);
700 case HloOpcode::kCopyDone:
701 return UpdateCopyDoneValueSet(instruction);
702 case HloOpcode::kConditional:
703 return UpdateConditionalValueSet(instruction);
704 default:
705 // Instruction does not forward HloValues (it defines all values in its
706 // output). No update is necessary.
707 return false;
708 }
709 }
710
// Runs the value-set updates to a fixed point with a FIFO worklist. Every
// instruction of every computation is enqueued once up front. Whenever
// UpdateInstructionValueSet() reports a change for an instruction, its users,
// the parameters of computations those users call sequentially, and (if the
// instruction is a computation root) the calling instructions are enqueued
// again. Terminates when the worklist is empty.
void HloDataflowAnalysis::Propagate() {
  std::queue<HloInstruction*> worklist;
  // Mirrors the contents of `worklist` so that an instruction is never queued
  // twice at the same time.
  absl::flat_hash_set<HloInstruction*> workset;
  auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
    if (workset.insert(instruction).second) {
      worklist.push(instruction);
    }
  };

  for (HloComputation* computation : module_.computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      add_to_worklist(instruction);
    }
  }

  while (!worklist.empty()) {
    HloInstruction* instruction = worklist.front();
    worklist.pop();
    // Remove it from the mirror set too, so it can be re-queued later.
    workset.erase(workset.find(instruction));

    VLOG(3) << "Worklist top: " << instruction->name();
    VLOG(3) << ToString();

    if (!UpdateInstructionValueSet(instruction)) {
      // No change to the instruction's value set.
      VLOG(4) << "No change.";
      continue;
    }

    VLOG(4) << "New value set for " << instruction->name() << ": "
            << GetInstructionValueSet(instruction);

    // Instruction value was updated. Add users to work list if we haven't
    // already.
    for (HloInstruction* user : instruction->users()) {
      add_to_worklist(user);

      // If user sequentially calls a computation, then the respective
      // parameter(s) of the computation need to be updated.
      if (user->opcode() == HloOpcode::kConditional) {
        // If operand 0 is the use of instruction, then no parameters need to be
        // updated, since that is the branch_index of the conditional.
        // If operand n+1 is the use of instruction, then the branch_computation
        // n's parameter need to be updated.
        //
        // Note that the same instruction can be used in multiple branches'
        // operands.
        for (int j = 0; j < user->branch_count(); ++j) {
          if (user->operand(j + 1) == instruction) {
            add_to_worklist(
                user->branch_computation(j)->parameter_instruction(0));
          }
        }
      } else {
        for (HloComputation* called_computation : user->called_computations()) {
          const CallGraphNode& call_graph_node =
              call_graph_->GetNode(called_computation);
          if (call_graph_node.context() == CallContext::kSequential) {
            // Operand number k of `user` corresponds to parameter k of the
            // called computation.
            for (int64 operand_number : user->OperandIndices(instruction)) {
              add_to_worklist(
                  called_computation->parameter_instruction(operand_number));
            }
          }
        }
      }
    }

    // If instruction is a root instruction, then propagate out to any calling
    // instruction and across any while backedge.
    if (instruction == instruction->parent()->root_instruction()) {
      const CallGraphNode& call_graph_node =
          call_graph_->GetNode(instruction->parent());
      for (const CallSite& callsite : call_graph_node.caller_callsites()) {
        if (callsite.instruction()->opcode() == HloOpcode::kCall ||
            callsite.instruction()->opcode() == HloOpcode::kConditional) {
          add_to_worklist(callsite.instruction());
        } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
          // Add the while itself, and the body and condition parameters.
          add_to_worklist(callsite.instruction());
          add_to_worklist(
              callsite.instruction()->while_body()->parameter_instruction(0));
          add_to_worklist(
              callsite.instruction()->while_condition()->parameter_instruction(
                  0));
        }
      }
    }
  }
}
800
GetInstructionValueSet(const HloInstruction * instruction) const801 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
802 const HloInstruction* instruction) const {
803 return value_sets_.at(instruction);
804 }
805
GetInstructionValueSet(const HloInstruction * instruction)806 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
807 const HloInstruction* instruction) {
808 return value_sets_.at(instruction);
809 }
810
InitializeInstructionValueSets()811 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
812 for (const HloComputation* computation : module_.computations()) {
813 const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
814 for (HloInstruction* instruction : computation->instructions()) {
815 // Create an empty shape tree.
816 value_sets_.emplace(std::piecewise_construct,
817 std::forward_as_tuple(instruction),
818 std::forward_as_tuple(instruction->shape()));
819
820 // For each sub-shape of the instruction shape, add a new HloValue to its
821 // HloValueSet.
822 auto define_all_values = [this, &instruction]() {
823 for (auto& pair : GetInstructionValueSet(instruction)) {
824 const ShapeIndex& index = pair.first;
825 HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
826 GetValueSet(instruction, index).AddValue(value);
827 }
828 };
829
830 // Add a new HloValue to the HloValueSet corresponding to the given index
831 // of the instruction shape.
832 auto define_value_at = [this, &instruction](const ShapeIndex& index) {
833 HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
834 GetValueSet(instruction, index).AddValue(value);
835 };
836
837 switch (instruction->opcode()) {
838 case HloOpcode::kBitcast:
839 if (bitcast_defines_value_) {
840 define_all_values();
841 }
842 break;
843 case HloOpcode::kSetDimensionSize:
844 case HloOpcode::kAddDependency:
845 case HloOpcode::kWhile:
846 case HloOpcode::kCall:
847 case HloOpcode::kConditional:
848 case HloOpcode::kGetTupleElement:
849 case HloOpcode::kDomain:
850 // These instructions define no values. The values in their output
851 // flow from their operands or from cross computation dataflow.
852 break;
853 case HloOpcode::kParameter:
854 if (call_graph_node.context() == CallContext::kBoth) {
855 // We do not support a subcomputation that is called from both a
856 // parallel and sequential context. In this case, the parameter
857 // would both define a value and propagate a value from its
858 // caller. This limitation is not really a problem because the call
859 // graph is typically flattened.
860 return Unimplemented(
861 "Computation %s is called in both a parallel (eg, kMap) and "
862 "sequential (eg, kCall) context",
863 computation->name());
864 }
865 if (call_graph_node.caller_callsites().empty() ||
866 call_graph_node.context() == CallContext::kParallel) {
867 // Parameters of computations called in a parallel context (eg, map
868 // and reduce) as well as parameters of dead computations define all
869 // values in their output. Otherwise the values of the parameter
870 // come from the caller (eg, operands to the kCall instruction).
871 define_all_values();
872 }
873 break;
874 case HloOpcode::kCopy:
875 case HloOpcode::kTupleSelect:
876 case HloOpcode::kTuple:
877 // These instructions only define their top-level values. Any other
878 // values flow from their operands.
879 define_value_at(/*index=*/{});
880 break;
881 case HloOpcode::kCopyStart:
882 // CopyStart produces a tuple of {destination buffer, aliased operand,
883 // U32 context}.
884 define_value_at(/*index=*/{});
885 define_value_at(/*index=*/{0});
886 define_value_at(/*index=*/{2});
887 break;
888 case HloOpcode::kCopyDone:
889 // CopyDone consumes a tuple produced by CopyStart and produces an
890 // element. Its output aliases its input tuple element {0}.
891 break;
892 case HloOpcode::kRecvDone:
893 // RecvDone produces a two-element tuple. Element zero aliases its
894 // input tuple element {0}; element one is a token.
895 define_value_at(/*index=*/{});
896 define_value_at(/*index=*/{1});
897 break;
898 case HloOpcode::kSend:
899 // Send produces a tuple of {aliased operand, U32 context, token},
900 // therefore only defines the top-level tuple and the tuple elements
901 // at {1} and {2}.
902 define_value_at(/*index=*/{});
903 define_value_at(/*index=*/{1});
904 define_value_at(/*index=*/{2});
905 break;
906 default:
907 define_all_values();
908 break;
909 }
910 }
911 }
912
913 return Status::OK();
914 }
915
916 /* static */
Run(const HloModule & module,bool ssa_form,bool bitcast_defines_value,const CanShareBuffer & can_share_buffer)917 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
918 const HloModule& module, bool ssa_form, bool bitcast_defines_value,
919 const CanShareBuffer& can_share_buffer) {
920 VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
921 XLA_VLOG_LINES(2, module.ToString());
922
923 auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
924 module, ssa_form, bitcast_defines_value, can_share_buffer));
925
926 TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
927 dataflow_analysis->Propagate();
928
929 // Delete all values marked for deletion.
930 dataflow_analysis->DeleteMarkedValues();
931
932 // Gather and set all non-definition positions of all values. Value deletion
933 // is rare, so just use a vector indexed by Value::Id rather than a map from
934 // Value::Id to positions. There should be very few holes in the vector, and
935 // lookup is faster.
936 std::vector<std::vector<HloPosition>> value_positions(
937 dataflow_analysis->next_value_id_);
938 for (const HloComputation* computation : module.computations()) {
939 for (HloInstruction* instruction : computation->instructions()) {
940 for (const auto& pair :
941 dataflow_analysis->GetInstructionValueSet(instruction)) {
942 const ShapeIndex& index = pair.first;
943 const HloValueSet& value_set = pair.second;
944 for (const HloValue* value : value_set.values()) {
945 if (value->defining_instruction() != instruction) {
946 value_positions[value->id()].push_back(
947 HloPosition{instruction, index});
948 }
949 }
950 }
951 }
952 }
953 for (auto& pair : dataflow_analysis->values_) {
954 HloValue::Id value_id = pair.first;
955 HloValue& value = pair.second;
956 value.SetPositionsAndComputeUses(value_positions[value_id]);
957 }
958
959 // Construct vector of values.
960 dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
961 for (auto& pair : dataflow_analysis->values_) {
962 dataflow_analysis->values_vector_.push_back(&pair.second);
963 }
964 absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
965
966 TF_DCHECK_OK(dataflow_analysis->Verify());
967
968 XLA_VLOG_LINES(1, dataflow_analysis->ToString());
969
970 return std::move(dataflow_analysis);
971 }
972
Verify() const973 Status HloDataflowAnalysis::Verify() const {
974 // Verify each HloValue appears in the value sets that the value's positions()
975 // indicate.
976 for (const HloValue* value : values()) {
977 for (const HloPosition& position : value->positions()) {
978 const HloValueSet& value_set = GetValueSet(position);
979 TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
980 << "Value set at position " << position << " does not contain value "
981 << value->ToShortString();
982 }
983 }
984
985 // For each value in each value set, verify that the value set's position
986 // appears in the value's positions().
987 for (const auto& computation : module_.computations()) {
988 for (const auto& instruction : computation->instructions()) {
989 for (const auto& pair : GetInstructionValueSet(instruction)) {
990 const ShapeIndex& index = pair.first;
991 const HloValueSet& value_set = pair.second;
992 const HloPosition position{instruction, index};
993 for (const HloValue* value : value_set.values()) {
994 TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
995 << "Value set at position " << position
996 << " unexpectedly contains value " << value->ToShortString();
997 }
998 }
999 }
1000 }
1001
1002 return Status::OK();
1003 }
1004
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const1005 bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
1006 const HloInstruction* operand, const ShapeIndex& index,
1007 const HloInstruction* user) const {
1008 // Return false if no value at 'operand' and 'index' is used at 'user'.
1009 for (const HloValue* value : GetValueSet(operand, index).values()) {
1010 for (const HloUse& use : value->uses()) {
1011 if (use.instruction == user) {
1012 if (user->IsLoopFusion()) {
1013 HloInstruction* fusion_param =
1014 user->fused_parameter(use.operand_number);
1015 const HloValue& value =
1016 GetValueDefinedAt(fusion_param, use.operand_index);
1017 return value.uses().empty();
1018 }
1019 return false;
1020 }
1021 }
1022 }
1023 return true;
1024 }
1025
1026 // Given a fusion whose root is a dynamic-update-slice op, determines whether
1027 // the fusion's output buffer can be shared with the buffer of fusion_param,
1028 // which must be a fused parameter of the fusion.
1029 //
1030 // Preconditions:
1031 //
1032 // - fusion's root is a dynamic-update-slice op.
1033 // - fusion_param is a parameter within the fusion.
1034 //
1035 // fusion_param may point to a subelement of the actual parameter instruction if
1036 // the param is a tuple; i.e. fusion_param->index() need not be the empty list.
1037 //
1038 // Returns true if:
1039 //
1040 // * fusion_param is used by the root of dynamic-update-slice as the "base" of
1041 // the update, i.e. the thing being updated, AND
1042 // * all other uses of fusion_param are dynamic-slices that slice the same
1043 // indices as are overwritten in the dynamic-update-slice.
1044 //
1045 // In the case that there are no other uses of fusion_param (last bullet point
1046 // is vacuously true) it's easy to see why an in-place DUS is safe; this is just
1047 // the "natural" implementation of DUS. If there are other users, in-place DUS
1048 // is safe on the assumption that the thread which writes element i of the
1049 // output will be the only one to read element i of fusion_param (via the
1050 // dynamic-slice ops).
CanDoInPlaceDynamicUpdateSlice(HloInstruction * fusion,const HloValue & fusion_param_value)1051 static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion,
1052 const HloValue& fusion_param_value) {
1053 auto* root =
1054 Cast<HloDynamicUpdateSliceInstruction>(fusion->fused_expression_root());
1055 auto* fusion_param = fusion_param_value.instruction();
1056 CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter);
1057 CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation());
1058
1059 // fusion_param must be used by the root as the "base" of the
1060 // dynamic-update-slice. The natural way to check this would be
1061 //
1062 // `if (root->operand(0) != fusion_param)`
1063 //
1064 // but we also have to handle the case where the fusion parameter is
1065 // tuple-shaped and we're considering just one element of that tuple, i.e.
1066 // fusion_param.index() != {}.
1067 if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) {
1068 return use.instruction == root;
1069 }) != 1) {
1070 return false;
1071 }
1072
1073 // All other uses of fusion_param must be dynamic-slices that slice the same
1074 // indices as are overwritten by the dynamic-update-slice.
1075 for (const HloUse& use : fusion_param_value.uses()) {
1076 auto* user = use.instruction;
1077 if (user == root) {
1078 continue;
1079 }
1080
1081 // Check that `user` is a dynamic-slice op and has the same slice indices as
1082 // `root`.
1083 auto* ds = DynCast<HloDynamicSliceInstruction>(user);
1084 if (!ds || ds->index_operands() != root->index_operands()) {
1085 return false;
1086 }
1087 }
1088 return true;
1089 }
1090
CanShareOperandBufferWithUser(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * user,const ShapeIndex & user_index) const1091 bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
1092 HloInstruction* operand, const ShapeIndex& operand_index,
1093 HloInstruction* user, const ShapeIndex& user_index) const {
1094 CHECK(user->IsUserOf(operand))
1095 << "user: " << user->ToString() << " operand: " << operand->ToString();
1096 const Shape& operand_subshape =
1097 ShapeUtil::GetSubshape(operand->shape(), operand_index);
1098 const Shape& user_subshape =
1099 ShapeUtil::GetSubshape(user->shape(), user_index);
1100
1101 // Check that operand and user emit the same shape and layout.
1102 if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
1103 return false;
1104 }
1105
1106 if (user->opcode() == HloOpcode::kFusion) {
1107 // Get the parameter associated with 'operand';
1108 HloInstruction* fusion_param =
1109 user->fused_parameter(user->operand_index(operand));
1110
1111 const HloValue& fusion_param_value =
1112 GetValueDefinedAt(fusion_param, operand_index);
1113
1114 // TODO(b/80315712): This code is in a bit of a weird intermediate state
1115 // at the moment. The in-place DUS check really needs to be common to all
1116 // backends, so it runs first. Then we run the backend-specific check if
1117 // provided, or go through the target-independent check if not.
1118 // Unfortunately, the notionally "target-independent" path actually contains
1119 // some target-specific code, so we can't run all of it *in addition* to the
1120 // target-specific function, like the interface documentation says.
1121 if (user->fused_expression_root()->opcode() ==
1122 HloOpcode::kDynamicUpdateSlice) {
1123 return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value);
1124 }
1125 }
1126
1127 if (can_share_buffer_ != nullptr) {
1128 if (absl::optional<bool> hint =
1129 can_share_buffer_(user, operand, user_index)) {
1130 return *hint;
1131 }
1132 }
1133
1134 if (user->opcode() == HloOpcode::kFusion) {
1135 HloInstruction* fusion_param =
1136 user->fused_parameter(user->operand_index(operand));
1137 const HloValue& fusion_param_value =
1138 GetValueDefinedAt(fusion_param, operand_index);
1139
1140 if (user->IsLoopFusion() || user->IsInputFusion()) {
1141 return AreTransitiveUsesElementwiseOrTuple(fusion_param);
1142 }
1143
1144 if (user->IsOutputFusion() &&
1145 user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
1146 // Output fusion with kAdd fused root.
1147
1148 // Check if one operand of kAdd fused root is kDot or kConvolution.
1149 auto* add = user->fused_expression_root();
1150 auto add_operand_it =
1151 absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
1152 return operand->opcode() == HloOpcode::kConvolution ||
1153 operand->opcode() == HloOpcode::kDot;
1154 });
1155 if (add_operand_it == add->operands().end()) {
1156 return false;
1157 }
1158 auto* matched_add_operand = *add_operand_it;
1159 // Calculate operand index of 'add' operand which was not matched above.
1160 const int64 other_add_operand_index =
1161 matched_add_operand == add->operand(0) ? 1 : 0;
1162 // Returns true iff there is exactly one use of 'operand' at shape index
1163 // 'operand_index', and this singleton use is the fused root (at operand
1164 // index 'other_add_operand_index').
1165 if (fusion_param_value.uses().size() == 1) {
1166 const HloUse& use = fusion_param_value.uses()[0];
1167 return use.instruction == user->fused_expression_root() &&
1168 use.operand_number == other_add_operand_index;
1169 }
1170 return false;
1171 }
1172 }
1173
1174 if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
1175 user->opcode() == HloOpcode::kScatter ||
1176 user->opcode() == HloOpcode::kTriangularSolve ||
1177 user->opcode() == HloOpcode::kWhile) {
1178 // We eliminated other users in HloOrdering::LiveRangeStrictlyBefore
1179 // so here we just need to check that the use is at the right operand index.
1180 const auto operand_indices = user->OperandIndices(operand);
1181 int64 operand_no = user->opcode() == HloOpcode::kTriangularSolve ? 1 : 0;
1182 return operand_indices.size() == 1 && operand_indices[0] == operand_no;
1183 }
1184 if (user->opcode() == HloOpcode::kSort) {
1185 // Only valid if there are no other users.
1186 if (operand->users().size() != 1) {
1187 return false;
1188 }
1189 // If we only sort keys, the output of sort is not a tuple, so we can always
1190 // share the buffer.
1191 if (user->operand_count() == 1) {
1192 return true;
1193 }
1194 CHECK(!user_index.empty());
1195 // Only share with the right tuple element buffer.
1196 const auto operand_indices = user->OperandIndices(operand);
1197 return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
1198 }
1199 if (user->opcode() == HloOpcode::kCall) {
1200 // Get all uses of value defined by 'operand' at 'operand_index'.
1201 const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
1202 // Return true iff:
1203 // *) There exists two uses of 'operand'.
1204 // *) One use is by 'user' (caller).
1205 // *) One use is by root instruction of called computation (callee root).
1206 // (Note: we check the root of the called computation, because the
1207 // root result buffer is required to alias with the Call result buffer).
1208 // *) The root instruction of the called computation is element-wise on
1209 // 'operand'.
1210 const bool found_caller_use =
1211 absl::c_find_if(uses, [user](const HloUse& use) {
1212 return use.instruction == user;
1213 }) != uses.end();
1214 auto* callee_root = user->to_apply()->root_instruction();
1215 const bool found_elementwise_callee_use =
1216 absl::c_find_if(uses, [callee_root](const HloUse& use) {
1217 return use.instruction == callee_root &&
1218 callee_root->IsElementwiseOnOperand(use.operand_number);
1219 }) != uses.end();
1220 return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
1221 }
1222
1223 // Loop fusions that contain transposing copies won't reach here as they have
1224 // different layouts, which fails the check in the beginning of this function.
1225 return user->IsElementwiseOnOperand(user->operand_index(operand));
1226 }
1227
1228 } // namespace xla
1229