1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17
18 #include <algorithm>
19 #include <queue>
20 #include <vector>
21
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/compiler/xla/map_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/logging.h"
38
39 namespace xla {
40
41 using absl::StrAppend;
42 using absl::StrCat;
43
HloDataflowAnalysis(const HloModule & module,bool ssa_form,bool bitcast_defines_value,const FusionCanShareBufferFunction & fusion_can_share_buffer)44 HloDataflowAnalysis::HloDataflowAnalysis(
45 const HloModule& module, bool ssa_form, bool bitcast_defines_value,
46 const FusionCanShareBufferFunction& fusion_can_share_buffer)
47 : module_(module),
48 ssa_form_(ssa_form),
49 bitcast_defines_value_(bitcast_defines_value),
50 call_graph_(CallGraph::Build(&module)),
51 fusion_can_share_buffer_(fusion_can_share_buffer) {}
52
AreTransitiveUsesElementwiseOrTuple(const HloInstruction * inst)53 bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
54 const HloInstruction* inst) {
55 absl::flat_hash_set<const HloInstruction*> visited;
56 absl::InlinedVector<const HloInstruction*, 4> stack;
57 stack.push_back(inst);
58 while (!stack.empty()) {
59 const HloInstruction* current = stack.back();
60 stack.pop_back();
61 visited.insert(current);
62 for (const HloInstruction* user : current->users()) {
63 // Found a user that is non-elementwise on current instruction.
64 for (const int64 use_index : user->OperandIndices(current)) {
65 if (!user->IsElementwiseOnOperand(use_index) &&
66 user->opcode() != HloOpcode::kTuple) {
67 return false;
68 }
69 }
70 if (!visited.contains(user)) {
71 stack.push_back(user);
72 }
73 }
74 }
75 return true;
76 }
77
ValueIsDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const78 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
79 const ShapeIndex& index) const {
80 const HloValueSet& value_set = GetValueSet(instruction, index);
81 if (value_set.values().size() != 1) {
82 return false;
83 }
84 return value_set.GetUniqueValue().defining_instruction() == instruction;
85 }
86
GetValueDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const87 const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
88 const HloInstruction* instruction, const ShapeIndex& index) const {
89 CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
90 return GetUniqueValueAt(instruction, index);
91 }
92
GetValueDefinedAt(const HloInstruction * instruction,const ShapeIndex & index)93 HloValue& HloDataflowAnalysis::GetValueDefinedAt(
94 const HloInstruction* instruction, const ShapeIndex& index) {
95 CHECK(ValueIsDefinedAt(instruction, index));
96 return GetUniqueValueAt(instruction, index);
97 }
98
NewHloValue(HloInstruction * instruction,const ShapeIndex & index,bool is_phi)99 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
100 const ShapeIndex& index,
101 bool is_phi) {
102 const int64 value_id = next_value_id_++;
103 auto emplaced = values_.emplace(
104 std::piecewise_construct, std::forward_as_tuple(value_id),
105 std::forward_as_tuple(value_id, instruction, index, is_phi));
106 CHECK(emplaced.second);
107
108 VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
109
110 return &emplaced.first->second;
111 }
112
MarkValueForDeletion(HloValue::Id value_id)113 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
114 HloValue& value = values_.at(value_id);
115 VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
116
117 value_ids_to_delete_.push_back(value_id);
118 }
119
DeleteMarkedValues()120 void HloDataflowAnalysis::DeleteMarkedValues() {
121 #ifndef NDEBUG
122 // Verify that no marked-for-deletion values are in any of the value sets.
123 absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
124 value_ids_to_delete_.end());
125 for (const auto& pair : value_sets_) {
126 const HloInstruction* instruction = pair.first;
127 const InstructionValueSet& instruction_value_set = pair.second;
128 for (const auto& index_value_set : instruction_value_set) {
129 const HloValueSet& value_set = index_value_set.second;
130 for (const HloValue* value : value_set.values()) {
131 DCHECK(!ContainsKey(id_set, value->id()))
132 << "Value " << value->ToShortString()
133 << " marked for deletion, but still exists in value set for "
134 "instruction "
135 << instruction->name();
136 }
137 }
138 }
139 #endif
140
141 for (HloValue::Id value_id : value_ids_to_delete_) {
142 values_.erase(value_id);
143 }
144 value_ids_to_delete_.clear();
145 }
146
ToString() const147 string HloDataflowAnalysis::ToString() const {
148 string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
149 StrAppend(&out, " Instruction value sets:\n");
150 for (const HloComputation* computation : module_.computations()) {
151 for (const HloInstruction* instruction : computation->instructions()) {
152 StrAppend(&out, " ", instruction->name(), ":\n");
153 if (instruction->shape().IsTuple()) {
154 GetInstructionValueSet(instruction)
155 .ForEachElement([this, &instruction, &out](
156 const ShapeIndex& index,
157 const HloValueSet& value_set) {
158 StrAppend(&out, " tuple index ", index.ToString(), ":\n");
159 for (const HloValue* value : value_set.values()) {
160 StrAppend(&out, " ", value->ToShortString(),
161 ValueIsDefinedAt(instruction, index) ? " (def)" : "",
162 "\n");
163 }
164 });
165 } else {
166 const HloValueSet& top_level_value_set =
167 GetValueSet(instruction, /*index=*/{});
168 for (const HloValue* value : top_level_value_set.values()) {
169 StrAppend(&out, " ", value->ToShortString(),
170 ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
171 }
172 }
173 }
174 }
175 StrAppend(&out, " HloValues:\n");
176 for (const HloValue* value : values()) {
177 StrAppend(&out, value->ToString(/*indent=*/4));
178 }
179 return out;
180 }
181
Phi(HloInstruction * instruction,absl::Span<const InstructionValueSet * const> inputs)182 bool HloDataflowAnalysis::Phi(
183 HloInstruction* instruction,
184 absl::Span<const InstructionValueSet* const> inputs) {
185 CHECK(ssa_form_);
186 VLOG(4) << "Phi(" << instruction->name() << ")";
187 VLOG(5) << "instruction value set = "
188 << GetInstructionValueSet(instruction).ToString();
189 for (const InstructionValueSet* input : inputs) {
190 VLOG(5) << "input value set = " << input->ToString();
191 }
192 for (const InstructionValueSet* input : inputs) {
193 DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
194 }
195
196 bool changed = false;
197 for (auto& pair : GetInstructionValueSet(instruction)) {
198 const ShapeIndex& index = pair.first;
199 HloValueSet& value_set = pair.second;
200
201 // Positions with phi values should never have more than one value in the
202 // value set.
203 CHECK_LE(value_set.values().size(), 1);
204 const HloValue* current_value =
205 value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
206
207 // Construct a vector of unique value IDs of the inputs.
208 // Don't add value ids where the input is equal to the definition.
209 std::vector<HloValue::Id> input_value_ids;
210 for (const InstructionValueSet* input : inputs) {
211 for (const HloValue* value : input->element(index).values()) {
212 if (value->defining_instruction() == instruction &&
213 value->defining_index() == index) {
214 continue;
215 }
216 input_value_ids.push_back(value->id());
217 }
218 }
219 absl::c_sort(input_value_ids);
220 input_value_ids.erase(
221 std::unique(input_value_ids.begin(), input_value_ids.end()),
222 input_value_ids.end());
223
224 // Remove the existing phi value (if it exists). The phi can be its own
225 // input, for example, in while body parameters where the body passes
226 // through the parameter value.
227 bool current_value_defined_here =
228 (current_value != nullptr &&
229 current_value->defining_instruction() == instruction &&
230 current_value->defining_index() == index);
231 if (current_value_defined_here) {
232 VLOG(5) << "current_value_defined_here: " << current_value->ToString();
233 CHECK(current_value->is_phi());
234 auto it = absl::c_find(input_value_ids, current_value->id());
235 if (it != input_value_ids.end()) {
236 input_value_ids.erase(it);
237 }
238 }
239 VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
240 if (input_value_ids.empty()) {
241 // A value set which has at least one element should never have its value
242 // set reduced to zero elements. During dataflow value sets only can go
243 // from empty to non-empty, not the reverse.
244 CHECK_EQ(value_set.values().size(), 0)
245 << "Instruction " << instruction->name() << " at index " << index
246 << " previously had non-empty value set. Value set: " << value_set;
247 } else if (input_value_ids.size() == 1) {
248 // Only a single value reaches this point. There should be no phi, and
249 // this value set should contain this single value.
250 const HloValue& new_value = GetValue(input_value_ids[0]);
251 if (current_value == nullptr) {
252 value_set.Clear();
253 value_set.AddValue(&new_value);
254 changed = true;
255 } else if (current_value != &new_value) {
256 if (current_value_defined_here) {
257 // Remove the existing phi.
258 MarkValueForDeletion(current_value->id());
259 }
260 value_set.Clear();
261 value_set.AddValue(&new_value);
262 changed = true;
263 }
264 } else {
265 // Multiple distinct values reach this point. A phi value is
266 // necessary.
267 CHECK_GT(input_value_ids.size(), 1);
268 if (current_value == nullptr ||
269 !(current_value->is_phi() && current_value_defined_here)) {
270 value_set.Clear();
271 value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
272 changed = true;
273 }
274 }
275 }
276 return changed;
277 }
278
GetValue(HloValue::Id value_id) const279 const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
280 return values_.at(value_id);
281 }
282
GetValue(HloValue::Id value_id)283 HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
284 return values_.at(value_id);
285 }
286
GetValueSet(const HloInstruction * instruction,const ShapeIndex & index) const287 const HloValueSet& HloDataflowAnalysis::GetValueSet(
288 const HloInstruction* instruction, const ShapeIndex& index) const {
289 return GetInstructionValueSet(instruction).element(index);
290 }
291
GetValueSet(const HloInstruction * instruction,const ShapeIndex & index)292 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
293 const ShapeIndex& index) {
294 return *GetInstructionValueSet(instruction).mutable_element(index);
295 }
296
GetValueSet(const HloPosition & position) const297 const HloValueSet& HloDataflowAnalysis::GetValueSet(
298 const HloPosition& position) const {
299 return GetValueSet(position.instruction, position.index);
300 }
301
GetValueSet(const HloPosition & position)302 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
303 return GetValueSet(position.instruction, position.index);
304 }
305
UpdateBitcastValueSet(HloInstruction * bitcast)306 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
307 CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
308 const InstructionValueSet& operand_set =
309 GetInstructionValueSet(bitcast->operand(0));
310 InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
311 if (!bitcast_defines_value_ && operand_set != bitcast_set) {
312 bitcast_set = operand_set;
313 return true;
314 }
315 return false;
316 }
317
UpdateSendValueSet(HloInstruction * send)318 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
319 CHECK_EQ(send->opcode(), HloOpcode::kSend);
320 bool changed = false;
321 // Send forwards the operand value to the output tuple at {0}.
322 for (auto& pair : GetInstructionValueSet(send->operand(0))) {
323 const ShapeIndex& operand_index = pair.first;
324 const HloValueSet& operand_value_set = pair.second;
325
326 ShapeIndex index = {0};
327 for (int64 i : operand_index) {
328 index.push_back(i);
329 }
330
331 HloValueSet& value_set = GetValueSet(send, index);
332 if (value_set != operand_value_set) {
333 value_set = operand_value_set;
334 changed = true;
335 }
336 }
337 return changed;
338 }
339
UpdateRecvDoneValueSet(HloInstruction * recv_done)340 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
341 CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
342 bool changed = false;
343 // RecvDone forwards the operand value at {0} to element {0} of its output.
344 for (auto& pair : GetInstructionValueSet(recv_done)) {
345 ShapeIndex& index = pair.first;
346 HloValueSet& value_set = pair.second;
347
348 if (index.empty() || index[0] != 0) {
349 continue;
350 }
351
352 const HloValueSet& operand_value_set =
353 GetValueSet(recv_done->operand(0), index);
354 if (value_set != operand_value_set) {
355 value_set = operand_value_set;
356 changed = true;
357 }
358 }
359 return changed;
360 }
361
UpdateCallValueSet(HloInstruction * call)362 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
363 CHECK_EQ(call->opcode(), HloOpcode::kCall);
364 InstructionValueSet& value_set = GetInstructionValueSet(call);
365 InstructionValueSet& root_value_set =
366 GetInstructionValueSet(call->to_apply()->root_instruction());
367 if (value_set != root_value_set) {
368 value_set = root_value_set;
369 return true;
370 }
371 return false;
372 }
373
UpdateConditionalValueSet(HloInstruction * conditional)374 bool HloDataflowAnalysis::UpdateConditionalValueSet(
375 HloInstruction* conditional) {
376 CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
377 std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
378 for (int j = 0; j < conditional->branch_count(); ++j) {
379 inputs[j] = &GetInstructionValueSet(
380 conditional->branch_computation(j)->root_instruction());
381 }
382 if (ssa_form_) {
383 return Phi(conditional, inputs);
384 } else {
385 return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
386 }
387 }
388
UpdateCopyValueSet(HloInstruction * copy)389 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
390 CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
391 bool changed = false;
392 for (auto& pair : GetInstructionValueSet(copy)) {
393 const ShapeIndex& index = pair.first;
394 if (index.empty()) {
395 // kCopy shallow copies and thus defines the top-level value so nothing to
396 // update.
397 continue;
398 }
399
400 HloValueSet& value_set = pair.second;
401 HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
402 if (value_set != operand_value_set) {
403 value_set = operand_value_set;
404 changed = true;
405 }
406 }
407 return changed;
408 }
409
UpdateDomainValueSet(HloInstruction * domain)410 bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
411 // Domain instructions just forward their operand. Given that domains can have
412 // a tuple operand, we iterate through its indexes, like for copies.
413 // Unlike copies though we also propagate the top-level value.
414 CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
415 bool changed = false;
416 for (auto& pair : GetInstructionValueSet(domain)) {
417 const ShapeIndex& index = pair.first;
418 HloValueSet& value_set = pair.second;
419 HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
420 if (value_set != operand_value_set) {
421 value_set = operand_value_set;
422 changed = true;
423 }
424 }
425 return changed;
426 }
427
UpdateAddDependencyValueSet(HloInstruction * add_dependency)428 bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
429 HloInstruction* add_dependency) {
430 // AddDependency just forwards the value of its zero-th operand.
431 CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
432 const InstructionValueSet& operand_set =
433 GetInstructionValueSet(add_dependency->operand(0));
434 InstructionValueSet& add_dependency_set =
435 GetInstructionValueSet(add_dependency);
436 if (operand_set != add_dependency_set) {
437 add_dependency_set = operand_set;
438 return true;
439 }
440 return false;
441 }
442
UpdateGetTupleElementValueSet(HloInstruction * gte)443 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
444 CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
445 bool changed = false;
446 // The GetTupleElement instruction forwards the values from the specified
447 // tuple element.
448 for (auto& pair : GetInstructionValueSet(gte)) {
449 const ShapeIndex& index = pair.first;
450 HloValueSet& value_set = pair.second;
451
452 // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
453 // with the tuple element number prefixed.
454 ShapeIndex operand_index = {gte->tuple_index()};
455 for (int64 i : index) {
456 operand_index.push_back(i);
457 }
458
459 HloValueSet& operand_value_set =
460 GetValueSet(gte->operand(0), operand_index);
461 if (value_set != operand_value_set) {
462 value_set = operand_value_set;
463 changed = true;
464 }
465 }
466 return changed;
467 }
468
UpdateParameterValueSet(HloInstruction * parameter)469 bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
470 CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
471 const CallGraphNode& call_graph_node =
472 call_graph_->GetNode(parameter->parent());
473
474 // Subcomputations called in a parallel context (eg, map) do not have dataflow
475 // from the caller operands.
476 if (call_graph_node.context() == CallContext::kParallel ||
477 call_graph_node.caller_callsites().empty()) {
478 return false;
479 }
480 CHECK_EQ(call_graph_node.context(), CallContext::kSequential);
481
482 std::vector<const InstructionValueSet*> inputs;
483 bool need_phi = false;
484 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
485 if (callsite.instruction()->opcode() == HloOpcode::kCall) {
486 // The operand values of a call instruction are forwarded to the
487 // respective parameter instruction of the subcomputation.
488 inputs.push_back(&GetInstructionValueSet(
489 callsite.instruction()->operand(parameter->parameter_number())));
490 } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
491 // In a while instruction, the while operand (ie, the init value) and the
492 // backedge are dataflow inputs to the parameter instruction. This is the
493 // case for parameters of both the body and condition computations.
494 CHECK_EQ(parameter->parameter_number(), 0);
495 inputs.push_back(
496 &GetInstructionValueSet(callsite.instruction()->operand(0)));
497 // If the parameter *is* the root, then don't consider it's current state
498 // (InstructionValueSet) as we are recomputing its current
499 // state. Otherwise, the parameter state would never be updated.
500 if (parameter !=
501 callsite.instruction()->while_body()->root_instruction()) {
502 inputs.push_back(&GetInstructionValueSet(
503 callsite.instruction()->while_body()->root_instruction()));
504 }
505 need_phi = true;
506 } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
507 CHECK_EQ(parameter->parameter_number(), 0);
508 auto conditional = callsite.instruction();
509 // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
510 // operands 1 and onward are the arguments to the branch computations.
511 //
512 // If the parameter belongs to conditional's branch 0 computation, then
513 // operand 1 is forwarded to this parameter instruction. If the parameter
514 // belongs to conditional's branch 5 computation, then operand 6 is
515 // forwarded to this parameter instruction.
516 bool found_parent = false;
517 for (int j = 0; j < conditional->branch_count(); ++j) {
518 if (parameter->parent() == conditional->branch_computation(j)) {
519 inputs.push_back(
520 &GetInstructionValueSet(conditional->operand(j + 1)));
521 found_parent = true;
522 break;
523 }
524 }
525 CHECK(found_parent);
526 need_phi = true;
527 } else {
528 LOG(FATAL) << "CallContext::kSequential computations should only be "
529 "called from call, while, or conditional instructions";
530 }
531 }
532
533 if (ssa_form_ && need_phi) {
534 return Phi(parameter, inputs);
535 } else {
536 return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
537 }
538 }
539
UpdateTupleSelectValueSet(HloInstruction * select)540 bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) {
541 CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect);
542 // A phi value is not defined at a kTupleSelect instruction because
543 // kTupleSelect does not create a new value. Rather it forwards a value from
544 // its operands. This contrasts with kWhile instruction (which does define a
545 // phi value) which has in-place update semantics.
546 bool changed = false;
547 for (auto& pair : GetInstructionValueSet(select)) {
548 const ShapeIndex& index = pair.first;
549 if (index.empty()) {
550 // kTupleSelect copies (not forwards) the top-level value.
551 continue;
552 }
553 HloValueSet& value_set = pair.second;
554 changed |=
555 value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
556 &GetValueSet(select->operand(2), index)});
557 }
558 return changed;
559 }
560
UpdateTupleValueSet(HloInstruction * tuple)561 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
562 CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
563 bool changed = false;
564 for (int64 i = 0; i < tuple->operands().size(); ++i) {
565 // Copy the value set(s) of each operand into the respective position in the
566 // kTuple instruction's value sets.
567 for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
568 const ShapeIndex& operand_index = pair.first;
569 HloValueSet& operand_value_set = pair.second;
570
571 ShapeIndex index = {i};
572 for (int64 op_index : operand_index) {
573 index.push_back(op_index);
574 }
575 HloValueSet& value_set = GetValueSet(tuple, index);
576
577 if (value_set != operand_value_set) {
578 value_set = operand_value_set;
579 changed = true;
580 }
581 }
582 }
583 return changed;
584 }
585
UpdateWhileValueSet(HloInstruction * xla_while)586 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
587 CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
588 const InstructionValueSet* const inputs[] = {
589 &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
590 &GetInstructionValueSet(xla_while->operand(0))};
591 if (ssa_form_) {
592 return Phi(xla_while, inputs);
593 } else {
594 return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
595 }
596 }
597
UpdateInstructionValueSet(HloInstruction * instruction)598 bool HloDataflowAnalysis::UpdateInstructionValueSet(
599 HloInstruction* instruction) {
600 // Recompute from operands.
601 switch (instruction->opcode()) {
602 case HloOpcode::kAddDependency:
603 return UpdateAddDependencyValueSet(instruction);
604 case HloOpcode::kBitcast:
605 return UpdateBitcastValueSet(instruction);
606 case HloOpcode::kDomain:
607 return UpdateDomainValueSet(instruction);
608 case HloOpcode::kCopy:
609 return UpdateCopyValueSet(instruction);
610 case HloOpcode::kGetTupleElement:
611 return UpdateGetTupleElementValueSet(instruction);
612 case HloOpcode::kTupleSelect:
613 return UpdateTupleSelectValueSet(instruction);
614 case HloOpcode::kTuple:
615 return UpdateTupleValueSet(instruction);
616 case HloOpcode::kParameter:
617 return UpdateParameterValueSet(instruction);
618 case HloOpcode::kCall:
619 return UpdateCallValueSet(instruction);
620 case HloOpcode::kWhile:
621 return UpdateWhileValueSet(instruction);
622 case HloOpcode::kSend:
623 return UpdateSendValueSet(instruction);
624 case HloOpcode::kRecvDone:
625 return UpdateRecvDoneValueSet(instruction);
626 case HloOpcode::kConditional:
627 return UpdateConditionalValueSet(instruction);
628 default:
629 // Instruction does not forward HloValues (it defines all values in its
630 // output). No update is necessary.
631 return false;
632 }
633 }
634
Propagate()635 void HloDataflowAnalysis::Propagate() {
636 std::queue<HloInstruction*> worklist;
637 absl::flat_hash_set<HloInstruction*> workset;
638 auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
639 if (workset.insert(instruction).second) {
640 worklist.push(instruction);
641 }
642 };
643
644 for (HloComputation* computation : module_.computations()) {
645 for (HloInstruction* instruction : computation->instructions()) {
646 add_to_worklist(instruction);
647 }
648 }
649
650 while (!worklist.empty()) {
651 HloInstruction* instruction = worklist.front();
652 worklist.pop();
653 workset.erase(workset.find(instruction));
654
655 VLOG(3) << "Worklist top: " << instruction->name();
656 VLOG(3) << ToString();
657
658 if (!UpdateInstructionValueSet(instruction)) {
659 // No change to the instruction's value set.
660 VLOG(4) << "No change.";
661 continue;
662 }
663
664 VLOG(4) << "New value set for " << instruction->name() << ": "
665 << GetInstructionValueSet(instruction);
666
667 // Instruction value was updated. Add users to work list if we haven't
668 // already.
669 for (HloInstruction* user : instruction->users()) {
670 add_to_worklist(user);
671
672 // If user sequentially calls a computation, then the respective
673 // parameter(s) of the computation need to be updated.
674 if (user->opcode() == HloOpcode::kConditional) {
675 // If operand 0 is the use of instruction, then no parameters need to be
676 // updated, since that is the branch_index of the conditional.
677 // If operand n+1 is the use of instruction, then the branch_computation
678 // n's parameter need to be updated.
679 //
680 // Note that the same instruction can be used in multiple branches'
681 // operands.
682 for (int j = 0; j < user->branch_count(); ++j) {
683 if (user->operand(j + 1) == instruction) {
684 add_to_worklist(
685 user->branch_computation(j)->parameter_instruction(0));
686 }
687 }
688 } else {
689 for (HloComputation* called_computation : user->called_computations()) {
690 const CallGraphNode& call_graph_node =
691 call_graph_->GetNode(called_computation);
692 if (call_graph_node.context() == CallContext::kSequential) {
693 for (int64 operand_number : user->OperandIndices(instruction)) {
694 add_to_worklist(
695 called_computation->parameter_instruction(operand_number));
696 }
697 }
698 }
699 }
700 }
701
702 // If instruction is a root instruction, then propagate out to any calling
703 // instruction and across any while backedge.
704 if (instruction == instruction->parent()->root_instruction()) {
705 const CallGraphNode& call_graph_node =
706 call_graph_->GetNode(instruction->parent());
707 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
708 if (callsite.instruction()->opcode() == HloOpcode::kCall ||
709 callsite.instruction()->opcode() == HloOpcode::kConditional) {
710 add_to_worklist(callsite.instruction());
711 } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
712 // Add the while itself, and the body and condition parameters.
713 add_to_worklist(callsite.instruction());
714 add_to_worklist(
715 callsite.instruction()->while_body()->parameter_instruction(0));
716 add_to_worklist(
717 callsite.instruction()->while_condition()->parameter_instruction(
718 0));
719 }
720 }
721 }
722 }
723 }
724
GetInstructionValueSet(const HloInstruction * instruction) const725 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
726 const HloInstruction* instruction) const {
727 return value_sets_.at(instruction);
728 }
729
GetInstructionValueSet(const HloInstruction * instruction)730 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
731 const HloInstruction* instruction) {
732 return value_sets_.at(instruction);
733 }
734
InitializeInstructionValueSets()735 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
736 for (const HloComputation* computation : module_.computations()) {
737 const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
738 for (HloInstruction* instruction : computation->instructions()) {
739 // Create an empty shape tree.
740 value_sets_.emplace(std::piecewise_construct,
741 std::forward_as_tuple(instruction),
742 std::forward_as_tuple(instruction->shape()));
743
744 // Lambda to set the value set to define all values in the output of the
745 // instruction.
746 auto define_all_values = [this, &instruction](bool is_phi = false) {
747 for (auto& pair : GetInstructionValueSet(instruction)) {
748 const ShapeIndex& index = pair.first;
749 HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
750 GetValueSet(instruction, index).AddValue(value);
751 }
752 };
753
754 // Lambda to set the value set to define only the top-level buffer in the
755 // output of the instruction. Any other values flow from the operands of
756 // the instruction (or from cross-computation dataflow).
757 auto define_top_level_only = [this, &instruction]() {
758 HloValue* value =
759 NewHloValue(instruction, /*index=*/{}, /*is_phi=*/false);
760 GetValueSet(instruction, /*index=*/{}).AddValue(value);
761 };
762
763 // Lambda to set the value set at the given index of the output.
764 auto define_value_at = [this, &instruction](const ShapeIndex& index) {
765 HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
766 GetValueSet(instruction, index).AddValue(value);
767 };
768
769 switch (instruction->opcode()) {
770 case HloOpcode::kBitcast:
771 if (bitcast_defines_value_) {
772 define_all_values();
773 }
774 break;
775 case HloOpcode::kAddDependency:
776 case HloOpcode::kWhile:
777 case HloOpcode::kCall:
778 case HloOpcode::kConditional:
779 case HloOpcode::kGetTupleElement:
780 case HloOpcode::kDomain:
781 // These instructions define no values. The values in their output
782 // flow from their operands or from cross computation dataflow.
783 break;
784 case HloOpcode::kParameter:
785 if (call_graph_node.context() == CallContext::kBoth) {
786 // We do not support a subcomputation that is called from both a
787 // parallel and sequential context. In this case, the parameter
788 // would both define a value and propagate a value from its
789 // caller. This limitation is not really a problem because the call
790 // graph is typically flattened.
791 return Unimplemented(
792 "Computation %s is called in both a parallel (eg, kMap) and "
793 "sequential (eg, kCall) context",
794 computation->name());
795 }
796 if (call_graph_node.caller_callsites().empty() ||
797 call_graph_node.context() == CallContext::kParallel) {
798 // Parameters of computations called in a parallel context (eg, map
799 // and reduce) as well as parameters of dead computations define all
800 // values in their output. Otherwise the values of the parameter
801 // come from the caller (eg, operands to the kCall instruction).
802 define_all_values();
803 }
804 break;
805 case HloOpcode::kCopy:
806 case HloOpcode::kTupleSelect:
807 case HloOpcode::kTuple:
808 // These instructions only define their top-level values. Any other
809 // values flow from their operands.
810 define_top_level_only();
811 break;
812 case HloOpcode::kRecvDone:
813 // RecvDone produces a two-element tuple. Element zero aliases its
814 // input tuple element {0}; element one is a token.
815 define_value_at(/*index=*/{});
816 define_value_at(/*index=*/{1});
817 break;
818 case HloOpcode::kSend:
819 // Send produces a tuple of {aliased operand, U32 context, token},
820 // therefore only defines the top-level tuple and the tuple elements
821 // at {1} and {2}.
822 define_value_at(/*index=*/{});
823 define_value_at(/*index=*/{1});
824 define_value_at(/*index=*/{2});
825 break;
826 default:
827 define_all_values();
828 break;
829 }
830 }
831 }
832
833 return Status::OK();
834 }
835
836 /* static */
Run(const HloModule & module,bool ssa_form,bool bitcast_defines_value,const FusionCanShareBufferFunction & fusion_can_share_buffer)837 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
838 const HloModule& module, bool ssa_form, bool bitcast_defines_value,
839 const FusionCanShareBufferFunction& fusion_can_share_buffer) {
840 VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
841 XLA_VLOG_LINES(2, module.ToString());
842
843 auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
844 module, ssa_form, bitcast_defines_value, fusion_can_share_buffer));
845
846 TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
847 dataflow_analysis->Propagate();
848
849 // Delete all values marked for deletion.
850 dataflow_analysis->DeleteMarkedValues();
851
852 // Gather and set all non-definition positions of all values. Value deletion
853 // is rare, so just use a vector indexed by Value::Id rather than a map from
854 // Value::Id to positions. There should be very few holes in the vector, and
855 // lookup is faster.
856 std::vector<std::vector<HloPosition>> value_positions(
857 dataflow_analysis->next_value_id_);
858 for (const HloComputation* computation : module.computations()) {
859 for (HloInstruction* instruction : computation->instructions()) {
860 for (const auto& pair :
861 dataflow_analysis->GetInstructionValueSet(instruction)) {
862 const ShapeIndex& index = pair.first;
863 const HloValueSet& value_set = pair.second;
864 for (const HloValue* value : value_set.values()) {
865 if (value->defining_instruction() != instruction) {
866 value_positions[value->id()].push_back(
867 HloPosition{instruction, index});
868 }
869 }
870 }
871 }
872 }
873 for (auto& pair : dataflow_analysis->values_) {
874 HloValue::Id value_id = pair.first;
875 HloValue& value = pair.second;
876 value.SetPositionsAndComputeUses(value_positions[value_id]);
877 }
878
879 // Construct vector of values.
880 dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
881 for (auto& pair : dataflow_analysis->values_) {
882 dataflow_analysis->values_vector_.push_back(&pair.second);
883 }
884 absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
885
886 TF_DCHECK_OK(dataflow_analysis->Verify());
887
888 XLA_VLOG_LINES(1, dataflow_analysis->ToString());
889
890 return std::move(dataflow_analysis);
891 }
892
Verify() const893 Status HloDataflowAnalysis::Verify() const {
894 // Verify each HloValue appears in the value sets that the value's positions()
895 // indicate.
896 for (const HloValue* value : values()) {
897 for (const HloPosition& position : value->positions()) {
898 const HloValueSet& value_set = GetValueSet(position);
899 TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
900 << "Value set at position " << position << " does not contain value "
901 << value->ToShortString();
902 }
903 }
904
905 // For each value in each value set, verify that the value set's position
906 // appears in the value's positions().
907 for (const auto& computation : module_.computations()) {
908 for (const auto& instruction : computation->instructions()) {
909 for (const auto& pair : GetInstructionValueSet(instruction)) {
910 const ShapeIndex& index = pair.first;
911 const HloValueSet& value_set = pair.second;
912 const HloPosition position{instruction, index};
913 for (const HloValue* value : value_set.values()) {
914 TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
915 << "Value set at position " << position
916 << " unexpectedly contains value " << value->ToShortString();
917 }
918 }
919 }
920 }
921
922 return Status::OK();
923 }
924
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const925 bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
926 const HloInstruction* operand, const ShapeIndex& index,
927 const HloInstruction* user) const {
928 // Return false if no value at 'operand' and 'index' is used at 'user'.
929 for (const HloValue* value : GetValueSet(operand, index).values()) {
930 for (const HloUse& use : value->uses()) {
931 if (use.instruction == user) {
932 if (user->opcode() == HloOpcode::kFusion &&
933 user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
934 HloInstruction* fusion_param =
935 user->fused_parameter(use.operand_number);
936 const HloValue& value =
937 GetValueDefinedAt(fusion_param, use.operand_index);
938 return value.uses().empty();
939 }
940 return false;
941 }
942 }
943 }
944 return true;
945 }
946
947 // Given a fusion whose root is a dynamic-update-slice op, determines whether
948 // the fusion's output buffer can be shared with the buffer of fusion_param,
949 // which must be a fused parameter of the fusion.
950 //
951 // Preconditions:
952 //
953 // - fusion's root is a dynamic-update-slice op.
954 // - fusion_param is a parameter within the fusion.
955 //
956 // fusion_param may point to a subelement of the actual parameter instruction if
957 // the param is a tuple; i.e. fusion_param->index() need not be the empty list.
958 //
959 // Returns true if:
960 //
961 // * fusion is a loop or input fusion, AND
962 // * fusion_param is used by the root of dynamic-update-slice as the "base" of
963 // the update, i.e. the thing being updated, AND
964 // * all other uses of fusion_param are dynamic-slices that slice the same
965 // indices as are overwritten in the dynamic-update-slice.
966 //
967 // In the case that there are no other uses of fusion_param (last bullet point
968 // is vacuously true) it's easy to see why an in-place DUS is safe; this is just
969 // the "natural" implementation of DUS. If there are other users, in-place DUS
970 // is safe on the assumption that the thread which writes element i of the
971 // output will be the only one to read element i of fusion_param (via the
972 // dynamic-slice ops).
CanDoInPlaceDynamicUpdateSlice(HloInstruction * fusion,const HloValue & fusion_param_value)973 static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion,
974 const HloValue& fusion_param_value) {
975 auto* root =
976 Cast<HloDynamicUpdateSliceInstruction>(fusion->fused_expression_root());
977 auto* fusion_param = fusion_param_value.instruction();
978 CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter);
979 CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation());
980
981 // fusion must be a loop or input fusion.
982 auto kind = fusion->fusion_kind();
983 if (kind != HloInstruction::FusionKind::kLoop &&
984 kind != HloInstruction::FusionKind::kInput) {
985 return false;
986 }
987
988 // fusion_param must be used by the root as the "base" of the
989 // dynamic-update-slice. The natural way to check this would be
990 //
991 // `if (root->operand(0) != fusion_param)`
992 //
993 // but we also have to handle the case where the fusion parameter is
994 // tuple-shaped and we're considering just one element of that tuple, i.e.
995 // fusion_param.index() != {}.
996 if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) {
997 return use.instruction == root;
998 }) != 1) {
999 return false;
1000 }
1001
1002 // All other uses of fusion_param must be dynamic-slices that slice the same
1003 // indices as are overwritten by the dynamic-update-slice.
1004 for (const HloUse& use : fusion_param_value.uses()) {
1005 auto* user = use.instruction;
1006 if (user == root) {
1007 continue;
1008 }
1009
1010 // Check that `user` is a dynamic-slice op and has the same slice indices as
1011 // `root`.
1012 auto* ds = DynCast<HloDynamicSliceInstruction>(user);
1013 if (!ds || ds->index_operands() != root->index_operands()) {
1014 return false;
1015 }
1016 }
1017 return true;
1018 }
1019
CanShareOperandBufferWithUser(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * user,const ShapeIndex & user_index) const1020 bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
1021 HloInstruction* operand, const ShapeIndex& operand_index,
1022 HloInstruction* user, const ShapeIndex& user_index) const {
1023 CHECK(user->IsUserOf(operand))
1024 << "user: " << user->ToString() << " operand: " << operand->ToString();
1025 const Shape& operand_subshape =
1026 ShapeUtil::GetSubshape(operand->shape(), operand_index);
1027 const Shape& user_subshape =
1028 ShapeUtil::GetSubshape(user->shape(), user_index);
1029
1030 // Check that operand and user emit the same shape and layout.
1031 if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
1032 return false;
1033 }
1034
1035 if (user->opcode() == HloOpcode::kFusion) {
1036 // Get the parameter associated with 'operand';
1037 HloInstruction* fusion_param =
1038 user->fused_parameter(user->operand_index(operand));
1039
1040 const HloValue& fusion_param_value =
1041 GetValueDefinedAt(fusion_param, operand_index);
1042
1043 // TODO(b/80315712): This code is in a bit of a weird intermediate state
1044 // at the moment. The in-place DUS check really needs to be common to all
1045 // backends, so it runs first. Then we run the backend-specific check if
1046 // provided, or go through the target-indepdendent check if not.
1047 // Unfortunately, the notionally "target-independent" path actually contains
1048 // some target-specific code, so we can't run all of it *in addition* to the
1049 // target-specific function, like the interface documentation says.
1050 if (user->fused_expression_root()->opcode() ==
1051 HloOpcode::kDynamicUpdateSlice) {
1052 return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value);
1053 }
1054
1055 if (fusion_can_share_buffer_ != nullptr) {
1056 return fusion_can_share_buffer_(user, operand);
1057 }
1058
1059 if (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
1060 user->fusion_kind() == HloInstruction::FusionKind::kInput) {
1061 return AreTransitiveUsesElementwiseOrTuple(fusion_param);
1062 }
1063
1064 if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
1065 user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
1066 // Output fusion with kAdd fused root.
1067
1068 // Check if one operand of kAdd fused root is kDot or kConvolution.
1069 auto* add = user->fused_expression_root();
1070 auto add_operand_it =
1071 absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
1072 return operand->opcode() == HloOpcode::kConvolution ||
1073 operand->opcode() == HloOpcode::kDot;
1074 });
1075 if (add_operand_it == add->operands().end()) {
1076 return false;
1077 }
1078 auto* matched_add_operand = *add_operand_it;
1079 // Calculate operand index of 'add' operand which was not matched above.
1080 const int64 other_add_operand_index =
1081 matched_add_operand == add->operand(0) ? 1 : 0;
1082 // Returns true iff there is exactly one use of 'operand' at shape index
1083 // 'operand_index', and this singleton use is the fused root (at operand
1084 // index 'other_add_operand_index').
1085 if (fusion_param_value.uses().size() == 1) {
1086 const HloUse& use = fusion_param_value.uses()[0];
1087 return use.instruction == user->fused_expression_root() &&
1088 use.operand_number == other_add_operand_index;
1089 }
1090 return false;
1091 }
1092 }
1093
1094 if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
1095 user->opcode() == HloOpcode::kScatter ||
1096 user->opcode() == HloOpcode::kWhile) {
1097 // We eliminated other users in BufferLiveness::live_range_strictly_before,
1098 // so here we just need to check that the use is at operand index 0.
1099 std::vector<int64> operand_indices = user->OperandIndices(operand);
1100 return operand_indices.size() == 1 && operand_indices[0] == 0;
1101 }
1102 if (user->opcode() == HloOpcode::kSort) {
1103 // Only valid if there are no other users.
1104 if (operand->users().size() != 1) {
1105 return false;
1106 }
1107 // If we only sort keys, the output of sort is not a tuple, so we can always
1108 // share the buffer.
1109 if (user->operand_count() == 1) {
1110 return true;
1111 }
1112 CHECK(!user_index.empty());
1113 // Only share with the right tuple element buffer.
1114 std::vector<int64> operand_indices = user->OperandIndices(operand);
1115 return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
1116 }
1117 if (user->opcode() == HloOpcode::kCall) {
1118 // Get all uses of value defined by 'operand' at 'operand_index'.
1119 const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
1120 // Return true iff:
1121 // *) There exists two uses of 'operand'.
1122 // *) One use is by 'user' (caller).
1123 // *) One use is by root instruction of called computation (callee root).
1124 // (Note: we check the root of the called computation, because the
1125 // root result buffer is required to alias with the Call result buffer).
1126 // *) The root instruction of the called computation is element-wise on
1127 // 'operand'.
1128 const bool found_caller_use =
1129 absl::c_find_if(uses, [user](const HloUse& use) {
1130 return use.instruction == user;
1131 }) != uses.end();
1132 auto* callee_root = user->to_apply()->root_instruction();
1133 const bool found_elementwise_callee_use =
1134 absl::c_find_if(uses, [callee_root](const HloUse& use) {
1135 return use.instruction == callee_root &&
1136 callee_root->IsElementwiseOnOperand(use.operand_number);
1137 }) != uses.end();
1138 return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
1139 }
1140
1141 // Loop fusions that contain transposing copies won't reach here as they have
1142 // different layouts, which fails the check in the beginning of this function.
1143 return user->IsElementwiseOnOperand(user->operand_index(operand));
1144 }
1145
1146 } // namespace xla
1147