1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17
18 #include <algorithm>
19 #include <queue>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/types/optional.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
36 #include "tensorflow/compiler/xla/service/hlo_module.h"
37 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
38 #include "tensorflow/compiler/xla/service/hlo_value.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/platform/logging.h"
45
46 namespace xla {
47 namespace {
48 // CalculatePostOrderSchedule traverses a module and assigns an ordinal to each
49 // instruction based on the post-order dependency order.
50 int64 CalculatePostOrderScheduleHelper(
51 const HloComputation* comp, int64 start_ordinal,
52 absl::flat_hash_map<HloInstruction*, int64>* ordinal_map) {
53 int64 ordinal = start_ordinal;
54 for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
55 if (instruction->opcode() == HloOpcode::kCall ||
56 instruction->opcode() == HloOpcode::kConditional) {
57 for (const HloComputation* called_computation :
58 instruction->called_computations()) {
59 ordinal = CalculatePostOrderScheduleHelper(called_computation, ordinal,
60 ordinal_map);
61 }
62 }
63 if (instruction->opcode() == HloOpcode::kWhile) {
64 ordinal = CalculatePostOrderScheduleHelper(instruction->while_condition(),
65 ordinal, ordinal_map);
66 ordinal = CalculatePostOrderScheduleHelper(instruction->while_body(),
67 ordinal, ordinal_map);
68 }
69 // It's possible that in some unit tests the computation graph is not
70 // flattened (meaning one computation could have multiple callers). In
71 // that case the ordinal_map will see the instruction multiple times. We
72 // consider that case to be OK as it only shows up in unit tests.
73 ordinal_map->insert({instruction, ordinal++});
74 }
75 return ordinal;
76 }
77
78 absl::flat_hash_map<HloInstruction*, int64> CalculatePostOrderSchedule(
79 const HloModule& module) {
80 absl::flat_hash_map<HloInstruction*, int64> map;
81 CalculatePostOrderScheduleHelper(module.entry_computation(), 0, &map);
82 return map;
83 }
84
85 } // namespace
86 using absl::StrAppend;
87 using absl::StrCat;
88
89 HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
90 bool bitcast_defines_value,
91 const CanShareBuffer& can_share_buffer)
92 : module_(module),
93 ssa_form_(ssa_form),
94 bitcast_defines_value_(bitcast_defines_value),
95 call_graph_(CallGraph::Build(&module)),
96 can_share_buffer_(can_share_buffer) {}
97
98 bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
99 const HloInstruction* inst) {
100 absl::flat_hash_set<const HloInstruction*> visited;
101 absl::InlinedVector<const HloInstruction*, 4> stack;
102 stack.push_back(inst);
103 while (!stack.empty()) {
104 const HloInstruction* current = stack.back();
105 stack.pop_back();
106 visited.insert(current);
107 for (const HloInstruction* user : current->users()) {
108 // Bail out if this user is non-elementwise on the current instruction (a kTuple user is fine).
109 for (const int64 use_index : user->OperandIndices(current)) {
110 if (!user->IsElementwiseOnOperand(use_index) &&
111 user->opcode() != HloOpcode::kTuple) {
112 return false;
113 }
114 }
115 if (!visited.contains(user)) {
116 stack.push_back(user);
117 }
118 }
119 }
120 return true;
121 }
122
123 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
124 const ShapeIndex& index) const {
125 const HloValueSet& value_set = GetValueSet(instruction, index);
126 if (value_set.values().size() != 1) {
127 return false;
128 }
129 return value_set.GetUniqueValue().defining_instruction() == instruction;
130 }
131
132 const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
133 const HloInstruction* instruction, const ShapeIndex& index) const {
134 CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
135 return GetUniqueValueAt(instruction, index);
136 }
137
138 HloValue& HloDataflowAnalysis::GetValueDefinedAt(
139 const HloInstruction* instruction, const ShapeIndex& index) {
140 CHECK(ValueIsDefinedAt(instruction, index));
141 return GetUniqueValueAt(instruction, index);
142 }
143
144 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
145 const ShapeIndex& index,
146 bool is_phi) {
147 const int64 value_id = next_value_id_++;
148 auto emplaced = values_.emplace(
149 std::piecewise_construct, std::forward_as_tuple(value_id),
150 std::forward_as_tuple(value_id, instruction, index, is_phi));
151 CHECK(emplaced.second);
152
153 VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
154
155 return &emplaced.first->second;
156 }
157
158 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
159 HloValue& value = values_.at(value_id);
160 VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
161
162 value_ids_to_delete_.push_back(value_id);
163 }
164
165 void HloDataflowAnalysis::DeleteMarkedValues() {
166 // Use a set to prevent deleting an id twice.
167 absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
168 value_ids_to_delete_.end());
169 #ifndef NDEBUG
170 // Verify that no marked-for-deletion values are in any of the value sets.
171 for (const auto& pair : value_sets_) {
172 const HloInstruction* instruction = pair.first;
173 const InstructionValueSet& instruction_value_set = pair.second;
174 for (const auto& index_value_set : instruction_value_set) {
175 const HloValueSet& value_set = index_value_set.second;
176 for (const HloValue* value : value_set.values()) {
177 DCHECK(!ContainsKey(id_set, value->id()))
178 << "Value " << value->ToShortString()
179 << " marked for deletion, but still exists in value set for "
180 "instruction "
181 << instruction->name();
182 }
183 }
184 }
185 #endif
186
187 for (HloValue::Id value_id : id_set) {
188 values_.erase(value_id);
189 }
190 value_ids_to_delete_.clear();
191 }
192
193 string HloDataflowAnalysis::ToString() const {
194 string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
195 StrAppend(&out, " Instruction value sets:\n");
196 for (const HloComputation* computation : module_.computations()) {
197 for (const HloInstruction* instruction : computation->instructions()) {
198 StrAppend(&out, "Instruction: \n ", instruction->name(), ":\n");
199 if (instruction->shape().IsTuple()) {
200 GetInstructionValueSet(instruction)
201 .ForEachElement([this, &instruction, &out](
202 const ShapeIndex& index,
203 const HloValueSet& value_set) {
204 StrAppend(&out, " tuple index ", index.ToString(), ":\n");
205 for (const HloValue* value : value_set.values()) {
206 StrAppend(&out, " ", value->ToShortString(),
207 ValueIsDefinedAt(instruction, index) ? " (def)" : "",
208 "\n");
209 }
210 });
211 } else {
212 const HloValueSet& top_level_value_set =
213 GetValueSet(instruction, /*index=*/{});
214 for (const HloValue* value : top_level_value_set.values()) {
215 StrAppend(&out, " ", value->ToShortString(),
216 ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
217 }
218 }
219 }
220 }
221 StrAppend(&out, " HloValues:\n");
222 for (const HloValue* value : values()) {
223 StrAppend(&out, value->ToString(/*indent=*/4));
224 }
225 return out;
226 }
227
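// Merges the value sets flowing into `instruction` from `inputs` (SSA form
// only): for each shape index, a phi value is created, re-registered with new
// inputs, or removed, depending on how many distinct values reach that
// position. Returns true if any value set changed.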
228 bool HloDataflowAnalysis::Phi(
229 HloInstruction* instruction,
230 absl::Span<const InstructionValueSet* const> inputs) {
231 CHECK(ssa_form_);
232 VLOG(4) << "Phi(" << instruction->name() << ")";
233 VLOG(5) << "instruction value set = "
234 << GetInstructionValueSet(instruction).ToString();
235 for (const InstructionValueSet* input : inputs) {
236 VLOG(5) << "input value set = " << input->ToString();
237 }
238
239 if (bitcast_defines_value_) {
240 absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
241 DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
242 });
243 } else {
244 const Shape& shape = instruction->shape();
245 PrimitiveType ty = shape.element_type();
246 bool is_array = shape.IsArray();
247 absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
248 DCHECK(ty == input->shape().element_type() &&
249 (!is_array || ShapeUtil::ElementsIn(shape) ==
250 ShapeUtil::ElementsIn(input->shape())));
251 });
252 }
253
254 bool changed = false;
255 for (auto& pair : GetInstructionValueSet(instruction)) {
256 const ShapeIndex& index = pair.first;
257 HloValueSet& value_set = pair.second;
258
259 // Positions with phi values should never have more than one value in the
260 // value set.
261 CHECK_LE(value_set.values().size(), 1);
262 const HloValue* current_value =
263 value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
264
265 // Construct a vector of value IDs of the inputs.
266 std::vector<HloValue::Id> input_value_ids;
267 for (const InstructionValueSet* input : inputs) {
268 for (const HloValue* value : input->element(index).values()) {
269 input_value_ids.push_back(value->id());
270 }
271 }
272
273 // Remove the existing phi value (if it exists). The phi can be its own
274 // input, for example, in while body parameters where the body passes
275 // through the parameter value.
276 bool current_value_defined_here =
277 (current_value != nullptr &&
278 current_value->defining_instruction() == instruction &&
279 current_value->defining_index() == index);
280
281 VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
282 if (input_value_ids.empty()) {
283 // A value set which has at least one element should never have its value
284 // set reduced to zero elements. During dataflow, value sets can only go
285 // from empty to non-empty, not the reverse.
286 CHECK_EQ(value_set.values().size(), 0)
287 << "Instruction " << instruction->name() << " at index " << index
288 << " previously had non-empty value set. Value set: " << value_set;
289 } else if (input_value_ids.size() == 1) {
290 // Only a single value reaches this point. There should be no phi, and
291 // this value set should contain this single value.
292 const HloValue& new_value = GetValue(input_value_ids[0]);
293 if (current_value == nullptr) {
294 value_set.Clear();
295 value_set.AddValue(&new_value);
296 changed = true;
297 } else if (current_value != &new_value) {
298 if (current_value_defined_here) {
299 // Remove the existing phi.
300 MarkValueForDeletion(current_value->id());
301 }
302 value_set.Clear();
303 value_set.AddValue(&new_value);
304 changed = true;
305 }
306 } else {
307 // Multiple distinct values reach this point. A phi value is
308 // necessary.
309 CHECK_GT(input_value_ids.size(), 1);
310 bool phi_defined_here =
311 current_value_defined_here && current_value->is_phi();
312 if (current_value == nullptr || !phi_defined_here) {
313 value_set.Clear();
314 value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
315
316 std::vector<HloValue*> inputs;
317 inputs.reserve(input_value_ids.size());
318 for (HloValue::Id id : input_value_ids) {
319 inputs.push_back(&GetValue(id));
320 }
321 // Register the phi into phi graph.
322 phi_graph_.RegisterPhi(*value_set.values()[0], inputs);
323 changed = true;
324 } else if (phi_defined_here) {
325 std::vector<HloValue*> new_inputs;
326 new_inputs.reserve(input_value_ids.size());
327 for (HloValue::Id id : input_value_ids) {
328 new_inputs.push_back(&GetValue(id));
329 }
330
331 if (!phi_graph_.InputsEqualTo(*current_value, new_inputs)) {
332 VLOG(1) << current_value->ToShortString() << " has new phi inputs: ";
333 // Update phi inputs.
334 phi_graph_.RegisterPhi(*current_value, new_inputs);
335 changed = true;
336 }
337 }
338 }
339 }
340 return changed;
341 }
342
343 const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
344 return values_.at(value_id);
345 }
346
347 HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
348 return values_.at(value_id);
349 }
350
351 HloValueSet HloDataflowAnalysis::GetFlattenedValueSet(
352 const HloInstruction* instruction) const {
353 HloValueSet value_set;
354
355 const InstructionValueSet& value_set_tree =
356 GetInstructionValueSet(instruction);
357
358 std::vector<const HloValueSet*> all_sets;
359 for (auto& pair : value_set_tree) {
360 const HloValueSet& value_set = pair.second;
361 all_sets.push_back(&value_set);
362 }
363 value_set.AssignUnionOf(all_sets);
364
365 return value_set;
366 }
367
368 const HloValueSet& HloDataflowAnalysis::GetValueSet(
369 const HloInstruction* instruction, const ShapeIndex& index) const {
370 return GetInstructionValueSet(instruction).element(index);
371 }
372
373 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
374 const ShapeIndex& index) {
375 return *GetInstructionValueSet(instruction).mutable_element(index);
376 }
377
378 const HloValueSet& HloDataflowAnalysis::GetValueSet(
379 const HloPosition& position) const {
380 return GetValueSet(position.instruction, position.index);
381 }
382
383 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
384 return GetValueSet(position.instruction, position.index);
385 }
386
387 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
388 CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
389 const InstructionValueSet& operand_set =
390 GetInstructionValueSet(bitcast->operand(0));
391 InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
392 if (!bitcast_defines_value_ && operand_set != bitcast_set) {
393 bitcast_set = operand_set;
394 return true;
395 }
396 return false;
397 }
398
399 bool HloDataflowAnalysis::UpdateSetDimensionSizeValueSet(
400 HloInstruction* set_dimension_size) {
401 CHECK_EQ(set_dimension_size->opcode(), HloOpcode::kSetDimensionSize);
402 const InstructionValueSet& operand_set =
403 GetInstructionValueSet(set_dimension_size->operand(0));
404 InstructionValueSet& set_dimension_size_set =
405 GetInstructionValueSet(set_dimension_size);
406 if (operand_set != set_dimension_size_set) {
407 set_dimension_size_set = operand_set;
408 return true;
409 }
410 return false;
411 }
412
413 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
414 CHECK_EQ(send->opcode(), HloOpcode::kSend);
415 bool changed = false;
416 // Send forwards the operand value to the output tuple at {0}.
417 for (auto& pair : GetInstructionValueSet(send->operand(0))) {
418 const ShapeIndex& operand_index = pair.first;
419 const HloValueSet& operand_value_set = pair.second;
420
421 ShapeIndex index = {0};
422 for (int64 i : operand_index) {
423 index.push_back(i);
424 }
425
426 HloValueSet& value_set = GetValueSet(send, index);
427 if (value_set != operand_value_set) {
428 value_set = operand_value_set;
429 changed = true;
430 }
431 }
432 return changed;
433 }
434
435 bool HloDataflowAnalysis::UpdateCustomCallValueSet(
436 HloInstruction* custom_call) {
437 CHECK_EQ(custom_call->opcode(), HloOpcode::kCustomCall);
438 bool changed = false;
439 for (const auto& aliasing : Cast<HloCustomCallInstruction>(custom_call)
440 ->output_to_operand_aliasing()) {
441 const HloValueSet& operand_value_set = GetValueSet(
442 custom_call->operand(aliasing.second.first), aliasing.second.second);
443 HloValueSet& value_set = GetValueSet(custom_call, aliasing.first);
444 if (value_set != operand_value_set) {
445 value_set = operand_value_set;
446 changed = true;
447 }
448 }
449 return changed;
450 }
451
452 bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
453 CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
454 bool changed = false;
455 // CopyStart forwards the operand value to element {1} of its output.
456 const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0));
457 HloValueSet& value_set = GetValueSet(copy_start, {1});
458 if (value_set != operand_value_set) {
459 value_set = operand_value_set;
460 changed = true;
461 }
462 return changed;
463 }
464
465 bool HloDataflowAnalysis::UpdateCopyDoneValueSet(HloInstruction* copy_done) {
466 CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone);
467 bool changed = false;
468 // CopyDone forwards the operand value at {0} to element {} of its output.
469 const HloValueSet& operand_value_set =
470 GetValueSet(copy_done->operand(0), {0});
471 HloValueSet& value_set = GetValueSet(copy_done);
472 if (value_set != operand_value_set) {
473 value_set = operand_value_set;
474 changed = true;
475 }
476 return changed;
477 }
478
479 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
480 CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
481 bool changed = false;
482 // RecvDone forwards the operand value at {0} to element {0} of its output.
483 for (auto& pair : GetInstructionValueSet(recv_done)) {
484 ShapeIndex& index = pair.first;
485 HloValueSet& value_set = pair.second;
486
487 if (index.empty() || index[0] != 0) {
488 continue;
489 }
490
491 const HloValueSet& operand_value_set =
492 GetValueSet(recv_done->operand(0), index);
493 if (value_set != operand_value_set) {
494 value_set = operand_value_set;
495 changed = true;
496 }
497 }
498 return changed;
499 }
500
501 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
502 CHECK_EQ(call->opcode(), HloOpcode::kCall);
503 InstructionValueSet& value_set = GetInstructionValueSet(call);
504 InstructionValueSet& root_value_set =
505 GetInstructionValueSet(call->to_apply()->root_instruction());
506 if (value_set != root_value_set) {
507 value_set = root_value_set;
508 return true;
509 }
510 return false;
511 }
512
513 bool HloDataflowAnalysis::UpdateConditionalValueSet(
514 HloInstruction* conditional) {
515 CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
516 std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
517 for (int j = 0; j < conditional->branch_count(); ++j) {
518 inputs[j] = &GetInstructionValueSet(
519 conditional->branch_computation(j)->root_instruction());
520 }
521 if (ssa_form_) {
522 return Phi(conditional, inputs);
523 } else {
524 return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
525 }
526 }
527
528 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
529 CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
530 bool changed = false;
531 for (auto& pair : GetInstructionValueSet(copy)) {
532 const ShapeIndex& index = pair.first;
533 if (index.empty()) {
534 // kCopy shallow copies and thus defines the top-level value, so there is
535 // nothing to update.
536 continue;
537 }
538
539 HloValueSet& value_set = pair.second;
540 HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
541 if (value_set != operand_value_set) {
542 value_set = operand_value_set;
543 changed = true;
544 }
545 }
546 return changed;
547 }
548
549 bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
550 // Domain instructions just forward their operand. Given that domains can have
551 // a tuple operand, we iterate through its indices, as for copies.
552 // Unlike copies, though, we also propagate the top-level value.
553 CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
554 bool changed = false;
555 for (auto& pair : GetInstructionValueSet(domain)) {
556 const ShapeIndex& index = pair.first;
557 HloValueSet& value_set = pair.second;
558 HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
559 if (value_set != operand_value_set) {
560 value_set = operand_value_set;
561 changed = true;
562 }
563 }
564 return changed;
565 }
566
567 bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
568 HloInstruction* add_dependency) {
569 // AddDependency just forwards the value of its zero-th operand.
570 CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
571 const InstructionValueSet& operand_set =
572 GetInstructionValueSet(add_dependency->operand(0));
573 InstructionValueSet& add_dependency_set =
574 GetInstructionValueSet(add_dependency);
575 if (operand_set != add_dependency_set) {
576 add_dependency_set = operand_set;
577 return true;
578 }
579 return false;
580 }
581
582 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
583 CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
584 bool changed = false;
585 // The GetTupleElement instruction forwards the values from the specified
586 // tuple element.
587 for (auto& pair : GetInstructionValueSet(gte)) {
588 const ShapeIndex& index = pair.first;
589 HloValueSet& value_set = pair.second;
590
591 // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
592 // with the tuple element number prefixed.
593 ShapeIndex operand_index = {gte->tuple_index()};
594 for (int64 i : index) {
595 operand_index.push_back(i);
596 }
597
598 HloValueSet& operand_value_set =
599 GetValueSet(gte->operand(0), operand_index);
600 if (value_set != operand_value_set) {
601 value_set = operand_value_set;
602 changed = true;
603 }
604 }
605 return changed;
606 }
607
608 bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
609 CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
610 const CallGraphNode& call_graph_node =
611 call_graph_->GetNode(parameter->parent());
612
613 // Subcomputations called in a parallel context (eg, map) do not have dataflow
614 // from the caller operands.
615 if (call_graph_node.context() == CallContext::kParallel ||
616 call_graph_node.caller_callsites().empty()) {
617 return false;
618 }
619 CHECK_EQ(call_graph_node.context(), CallContext::kSequential);
620
621 std::vector<const InstructionValueSet*> inputs;
622 bool need_phi = false;
623 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
624 if (callsite.instruction()->opcode() == HloOpcode::kCall) {
625 // The operand values of a call instruction are forwarded to the
626 // respective parameter instruction of the subcomputation.
627 inputs.push_back(&GetInstructionValueSet(
628 callsite.instruction()->operand(parameter->parameter_number())));
629 } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
630 // In a while instruction, the while operand (ie, the init value) and the
631 // backedge are dataflow inputs to the parameter instruction. This is the
632 // case for parameters of both the body and condition computations.
633 CHECK_EQ(parameter->parameter_number(), 0);
634 inputs.push_back(
635 &GetInstructionValueSet(callsite.instruction()->operand(0)));
636 // If the parameter *is not* the root, the body root's value set (the
637 // backedge) also feeds this parameter; otherwise skip it, since the
638 // parameter's current state (InstructionValueSet) is what we are recomputing.
639 if (parameter !=
640 callsite.instruction()->while_body()->root_instruction()) {
641 inputs.push_back(&GetInstructionValueSet(
642 callsite.instruction()->while_body()->root_instruction()));
643 }
644 need_phi = true;
645 } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
646 CHECK_EQ(parameter->parameter_number(), 0);
647 auto conditional = callsite.instruction();
648 // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
649 // operands 1 and onward are the arguments to the branch computations.
650 //
651 // If the parameter belongs to conditional's branch 0 computation, then
652 // operand 1 is forwarded to this parameter instruction. If the parameter
653 // belongs to conditional's branch 5 computation, then operand 6 is
654 // forwarded to this parameter instruction.
655 bool found_parent = false;
656 for (int j = 0; j < conditional->branch_count(); ++j) {
657 if (parameter->parent() == conditional->branch_computation(j)) {
658 inputs.push_back(
659 &GetInstructionValueSet(conditional->operand(j + 1)));
660 found_parent = true;
661 break;
662 }
663 }
664 CHECK(found_parent);
665 need_phi = true;
666 } else {
667 LOG(FATAL) << "CallContext::kSequential computations should only be "
668 "called from call, while, or conditional instructions";
669 }
670 }
671 if (ssa_form_ && need_phi) {
672 return Phi(parameter, inputs);
673 } else {
674 return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
675 }
676 }
677
678 bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) {
679 CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect);
680 // A phi value is not defined at a kTupleSelect instruction because
681 // kTupleSelect does not create a new value. Rather it forwards a value from
682 // its operands. This contrasts with kWhile instruction (which does define a
683 // phi value) which has in-place update semantics.
684 bool changed = false;
685 for (auto& pair : GetInstructionValueSet(select)) {
686 const ShapeIndex& index = pair.first;
687 if (index.empty()) {
688 // kTupleSelect copies (not forwards) the top-level value.
689 continue;
690 }
691 HloValueSet& value_set = pair.second;
692 changed |=
693 value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
694 &GetValueSet(select->operand(2), index)});
695 }
696 return changed;
697 }
698
699 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
700 CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
701 bool changed = false;
702 for (int64 i = 0; i < tuple->operands().size(); ++i) {
703 // Copy the value set(s) of each operand into the respective position in the
704 // kTuple instruction's value sets.
705 for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
706 const ShapeIndex& operand_index = pair.first;
707 HloValueSet& operand_value_set = pair.second;
708
709 ShapeIndex index = {i};
710 for (int64 op_index : operand_index) {
711 index.push_back(op_index);
712 }
713 HloValueSet& value_set = GetValueSet(tuple, index);
714
715 if (value_set != operand_value_set) {
716 value_set = operand_value_set;
717 changed = true;
718 }
719 }
720 }
721 return changed;
722 }
723
724 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
725 CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
726 const InstructionValueSet* const inputs[] = {
727 &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
728 &GetInstructionValueSet(xla_while->operand(0))};
729 if (ssa_form_) {
730 return Phi(xla_while, inputs);
731 } else {
732 return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
733 }
734 }
735
736 bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet(
737 HloInstruction* collective_permute_start) {
738 CHECK_EQ(collective_permute_start->opcode(),
739 HloOpcode::kCollectivePermuteStart);
740 bool changed = false;
741 // CollectivePermuteStart forwards the operand value to element {0} of its
742 // output.
743 const HloValueSet& operand_value_set =
744 GetValueSet(collective_permute_start->operand(0));
745 HloValueSet& value_set = GetValueSet(collective_permute_start, {0});
746 if (value_set != operand_value_set) {
747 value_set = operand_value_set;
748 changed = true;
749 }
750 return changed;
751 }
752
753 bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet(
754 HloInstruction* collective_permute_done) {
755 CHECK_EQ(collective_permute_done->opcode(),
756 HloOpcode::kCollectivePermuteDone);
757 bool changed = false;
758 // CollectivePermuteDone forwards the operand value at {1} to its output.
759 const HloValueSet& operand_value_set =
760 GetValueSet(collective_permute_done->operand(0), {1});
761 HloValueSet& value_set = GetValueSet(collective_permute_done);
762 if (value_set != operand_value_set) {
763 value_set = operand_value_set;
764 changed = true;
765 }
766 return changed;
767 }
768
769 bool HloDataflowAnalysis::UpdateInstructionValueSet(
770 HloInstruction* instruction) {
771 // Recompute from operands.
772 switch (instruction->opcode()) {
773 case HloOpcode::kAddDependency:
774 return UpdateAddDependencyValueSet(instruction);
775 case HloOpcode::kBitcast:
776 return UpdateBitcastValueSet(instruction);
777 case HloOpcode::kCustomCall:
778 return UpdateCustomCallValueSet(instruction);
779 case HloOpcode::kSetDimensionSize:
780 return UpdateSetDimensionSizeValueSet(instruction);
781 case HloOpcode::kDomain:
782 return UpdateDomainValueSet(instruction);
783 case HloOpcode::kCopy:
784 return UpdateCopyValueSet(instruction);
785 case HloOpcode::kGetTupleElement:
786 return UpdateGetTupleElementValueSet(instruction);
787 case HloOpcode::kTupleSelect:
788 return UpdateTupleSelectValueSet(instruction);
789 case HloOpcode::kTuple:
790 return UpdateTupleValueSet(instruction);
791 case HloOpcode::kParameter:
792 return UpdateParameterValueSet(instruction);
793 case HloOpcode::kCall:
794 return UpdateCallValueSet(instruction);
795 case HloOpcode::kWhile:
796 return UpdateWhileValueSet(instruction);
797 case HloOpcode::kSend:
798 return UpdateSendValueSet(instruction);
799 case HloOpcode::kRecvDone:
800 return UpdateRecvDoneValueSet(instruction);
801 case HloOpcode::kCopyStart:
802 return UpdateCopyStartValueSet(instruction);
803 case HloOpcode::kCopyDone:
804 return UpdateCopyDoneValueSet(instruction);
805 case HloOpcode::kConditional:
806 return UpdateConditionalValueSet(instruction);
807 case HloOpcode::kCollectivePermuteStart:
808 return UpdateCollectivePermuteStartValueSet(instruction);
809 case HloOpcode::kCollectivePermuteDone:
810 return UpdateCollectivePermuteDoneValueSet(instruction);
811 default:
812 // Instruction does not forward HloValues (it defines all values in its
813 // output). No update is necessary.
814 return false;
815 }
816 }
817
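// Runs dataflow to a fixed point. Instructions are drained from a worklist
// ordered by post-order ordinal; whenever UpdateInstructionValueSet changes an
// instruction's value set, its users, the parameters of sequentially called
// computations, and (for roots) the calling instructions are re-enqueued.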
818 void HloDataflowAnalysis::Propagate() {
819 using Work = std::pair<int64, HloInstruction*>;
820 // Avoid duplicating work by preferring work items early in the post-order
821 // schedule. Intuitively, we start from entry parameters and propagate buffer
822 // updates throughout the module only once.
823 std::priority_queue<Work, std::vector<Work>, std::greater<Work>> worklist;
824 absl::flat_hash_set<HloInstruction*> workset;
825 auto priority_map = CalculatePostOrderSchedule(module_);
826 auto add_to_worklist = [&priority_map, &worklist,
827 &workset](HloInstruction* instruction) {
828 if (workset.insert(instruction).second) {
829 worklist.emplace(priority_map[instruction], instruction);
830 }
831 };
832
833 auto comps = module_.MakeComputationPostOrder();
834 for (HloComputation* computation : comps) {
835 for (HloInstruction* instruction :
836 computation->MakeInstructionPostOrder()) {
837 add_to_worklist(instruction);
838 }
839 }
840 VLOG(1) << "SSA_FORM_: " << ssa_form_;
841
842 while (!worklist.empty()) {
843 HloInstruction* instruction = worklist.top().second;
844 auto add_to_worklist = [&](HloInstruction* todo) {
845 if (workset.insert(todo).second) {
846 VLOG(1) << " Adding todo : " << todo->name();
847 worklist.emplace(priority_map[todo], todo);
848 }
849 };
850 worklist.pop();
851
852 workset.erase(workset.find(instruction));
853
854 VLOG(3) << "Worklist top: " << instruction->name();
855 VLOG(3) << ToString();
856
857 if (!UpdateInstructionValueSet(instruction)) {
858 // No change to the instruction's value set.
859 VLOG(4) << "No change.";
860 continue;
861 }
862
863 VLOG(4) << "New value set for " << instruction->name() << ": "
864 << GetInstructionValueSet(instruction);
865
866 // Instruction value was updated. Add users to work list if we haven't
867 // already.
868 for (HloInstruction* user : instruction->users()) {
869 add_to_worklist(user);
870
871 // If user sequentially calls a computation, then the respective
872 // parameter(s) of the computation need to be updated.
873 if (user->opcode() == HloOpcode::kConditional) {
874 // If the instruction is used as operand 0, no parameters need to be
875 // updated, since that operand is the branch_index of the conditional.
876 // If the instruction is used as operand n+1, then branch_computation
877 // n's parameter needs to be updated.
878 //
879 // Note that the same instruction can be used in multiple branches'
880 // operands.
881 for (int j = 0; j < user->branch_count(); ++j) {
882 if (user->operand(j + 1) == instruction) {
883 add_to_worklist(
884 user->branch_computation(j)->parameter_instruction(0));
885 }
886 }
887 } else {
888 for (HloComputation* called_computation : user->called_computations()) {
889 const CallGraphNode& call_graph_node =
890 call_graph_->GetNode(called_computation);
891 if (call_graph_node.context() == CallContext::kSequential) {
892 for (int64 operand_number : user->OperandIndices(instruction)) {
893 add_to_worklist(
894 called_computation->parameter_instruction(operand_number));
895 }
896 }
897 }
898 }
899 }
900
901 // If instruction is a root instruction, then propagate out to any calling
902 // instruction and across any while backedge.
903 if (instruction == instruction->parent()->root_instruction()) {
904 const CallGraphNode& call_graph_node =
905 call_graph_->GetNode(instruction->parent());
906 for (const CallSite& callsite : call_graph_node.caller_callsites()) {
907 if (callsite.instruction()->opcode() == HloOpcode::kCall ||
908 callsite.instruction()->opcode() == HloOpcode::kConditional) {
909 add_to_worklist(callsite.instruction());
910 } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
911 // Add the while itself, and the body and condition parameters.
912 add_to_worklist(callsite.instruction());
913 add_to_worklist(
914 callsite.instruction()->while_body()->parameter_instruction(0));
915 add_to_worklist(
916 callsite.instruction()->while_condition()->parameter_instruction(
917 0));
918 }
919 }
920 }
921 }
922 }
923
924 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
925 const HloInstruction* instruction) const {
926 return value_sets_.at(instruction);
927 }
928
929 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
930 const HloInstruction* instruction) {
931 return value_sets_.at(instruction);
932 }
933
934 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
935 for (const HloComputation* computation : module_.computations()) {
936 const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
937 for (HloInstruction* instruction : computation->instructions()) {
938 // Create an empty shape tree.
939 value_sets_.emplace(std::piecewise_construct,
940 std::forward_as_tuple(instruction),
941 std::forward_as_tuple(instruction->shape()));
942
943 // For each sub-shape of the instruction shape, add a new HloValue to its
944 // HloValueSet.
945 auto define_all_values = [this, &instruction]() {
946 for (auto& pair : GetInstructionValueSet(instruction)) {
947 const ShapeIndex& index = pair.first;
948 HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
949 GetValueSet(instruction, index).AddValue(value);
950 }
951 };
952
953 // Add a new HloValue to the HloValueSet corresponding to the given index
954 // of the instruction shape.
955 auto define_value_at = [this, &instruction](const ShapeIndex& index) {
956 HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
957 GetValueSet(instruction, index).AddValue(value);
958 };
959
960 switch (instruction->opcode()) {
961 case HloOpcode::kBitcast:
962 if (bitcast_defines_value_) {
963 define_all_values();
964 }
965 break;
966 case HloOpcode::kSetDimensionSize:
967 case HloOpcode::kAddDependency:
968 case HloOpcode::kWhile:
969 case HloOpcode::kCall:
970 case HloOpcode::kConditional:
971 case HloOpcode::kGetTupleElement:
972 case HloOpcode::kDomain:
973 // These instructions define no values. The values in their output
974 // flow from their operands or from cross computation dataflow.
975 break;
976 case HloOpcode::kParameter:
977 if (call_graph_node.context() == CallContext::kBoth) {
978 // We do not support a subcomputation that is called from both a
979 // parallel and sequential context. In this case, the parameter
980 // would both define a value and propagate a value from its
981 // caller. This limitation is not really a problem because the call
982 // graph is typically flattened.
983 return Unimplemented(
984 "Computation %s is called in both a parallel (eg, kMap) and "
985 "sequential (eg, kCall) context",
986 computation->name());
987 }
988 if (call_graph_node.caller_callsites().empty() ||
989 call_graph_node.context() == CallContext::kParallel) {
990 // Parameters of computations called in a parallel context (eg, map
991 // and reduce) as well as parameters of dead computations define all
992 // values in their output. Otherwise the values of the parameter
993 // come from the caller (eg, operands to the kCall instruction).
994 define_all_values();
995 }
996 break;
997 case HloOpcode::kCopy:
998 case HloOpcode::kTupleSelect:
999 case HloOpcode::kTuple:
1000 // These instructions only define their top-level values. Any other
1001 // values flow from their operands.
1002 define_value_at(/*index=*/{});
1003 break;
1004 case HloOpcode::kCopyStart:
1005 // CopyStart produces a tuple of {destination buffer, aliased operand,
1006 // U32 context}.
1007 define_value_at(/*index=*/{});
1008 define_value_at(/*index=*/{0});
1009 define_value_at(/*index=*/{2});
1010 break;
1011 case HloOpcode::kCopyDone:
1012 // CopyDone consumes a tuple produced by CopyStart and produces an
1013 // element. Its output aliases its input tuple element {0}.
1014 break;
1015 case HloOpcode::kCollectivePermuteStart:
1016 // CollectivePermuteStart produces a tuple of
1017 // {aliased operand, destination buffer, U32 context, U32 context}.
1018 define_value_at(/*index=*/{});
1019 define_value_at(/*index=*/{1});
1020 define_value_at(/*index=*/{2});
1021 define_value_at(/*index=*/{3});
1022 break;
1023 case HloOpcode::kCollectivePermuteDone:
1024 // CollectivePermuteDone's output aliases its input tuple element {1}.
1025 break;
1026 case HloOpcode::kRecvDone:
1027 // RecvDone produces a two-element tuple. Element zero aliases its
1028 // input tuple element {0}; element one is a token.
1029 define_value_at(/*index=*/{});
1030 define_value_at(/*index=*/{1});
1031 break;
1032 case HloOpcode::kSend:
1033 // Send produces a tuple of {aliased operand, U32 context, token},
1034 // therefore only defines the top-level tuple and the tuple elements
1035 // at {1} and {2}.
1036 define_value_at(/*index=*/{});
1037 define_value_at(/*index=*/{1});
1038 define_value_at(/*index=*/{2});
1039 break;
1040 case HloOpcode::kCustomCall: {
1041 absl::flat_hash_set<ShapeIndex> aliasing_indices;
1042 for (const auto& aliasing :
1043 Cast<HloCustomCallInstruction>(instruction)
1044 ->output_to_operand_aliasing()) {
1045 aliasing_indices.insert(aliasing.first);
1046 }
1047 ShapeUtil::ForEachSubshape(
1048 instruction->shape(),
1049 [&](const Shape& /*subshape*/, const ShapeIndex& index) {
1050 if (!aliasing_indices.contains(index)) {
1051 define_value_at(index);
1052 }
1053 });
1054 break;
1055 }
1056 default:
1057 define_all_values();
1058 break;
1059 }
1060 }
1061 }
1062
1063 return Status::OK();
1064 }
1065
1066 void HloDataflowAnalysis::OptimizePhiValues() {
1067 // Only applicable to SSA form where phis are defined.
1068 if (!ssa_form_) {
1069 return;
1070 }
1071
1072 VLOG(1) << "Before phi graph optimization";
1073 XLA_VLOG_LINES(1, phi_graph_.ToString());
1074 phi_graph_.Optimize();
1075 VLOG(1) << "After phi graph optimization";
1076 XLA_VLOG_LINES(1, phi_graph_.ToString());
1077
1078 for (const HloComputation* computation : module_.computations()) {
1079 for (HloInstruction* instruction : computation->instructions()) {
1080 InstructionValueSet& instruction_value_set =
1081 GetInstructionValueSet(instruction);
1082 VLOG(1) << "inst: " << instruction->name();
1083 VLOG(1) << instruction_value_set.ToString();
1084 instruction_value_set.ForEachMutableElement(
1085 [&](const xla::ShapeIndex& index, HloValueSet* value_set) {
1086 auto values = value_set->values();
1087 if (!(values.size() == 1 && values[0]->is_phi())) {
1088 return;
1089 }
1090 HloValue::Id phi_id = values[0]->id();
1091 HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id);
1092 if (new_id != phi_id) {
1093 VLOG(1) << "Replacing " << values[0]->ToString() << " with "
1094 << GetValue(new_id).ToString();
1095 value_set->Clear();
1096 const HloValue& new_value = GetValue(new_id);
1097 value_set->AddValue(&new_value);
1098 MarkValueForDeletion(phi_id);
1099 }
1100 });
1101 }
1102 }
1103 }
1104
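// Typical usage (a sketch; assumes a valid HloModule* `module` and the default
// values for the remaining parameters of Run):
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> analysis,
//                       HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
//   const HloValue& value =
//       analysis->GetValueDefinedAt(instruction, /*index=*/{});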
1105 /* static */
1106 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
1107 const HloModule& module, bool ssa_form, bool bitcast_defines_value,
1108 const CanShareBuffer& can_share_buffer) {
1109 VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
1110 XLA_VLOG_LINES(2, module.ToString());
1111
1112 auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
1113 module, ssa_form, bitcast_defines_value, can_share_buffer));
1114
1115 TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
1116 dataflow_analysis->Propagate();
1117 dataflow_analysis->OptimizePhiValues();
1118
1119 // Delete all values marked for deletion.
1120 dataflow_analysis->DeleteMarkedValues();
1121
1122 // Gather and set all non-definition positions of all values. Value deletion
1123 // is rare, so just use a vector indexed by Value::Id rather than a map from
1124 // Value::Id to positions. There should be very few holes in the vector, and
1125 // lookup is faster.
1126 std::vector<std::vector<HloPosition>> value_positions(
1127 dataflow_analysis->next_value_id_);
1128 for (const HloComputation* computation : module.computations()) {
1129 for (HloInstruction* instruction : computation->instructions()) {
1130 for (const auto& pair :
1131 dataflow_analysis->GetInstructionValueSet(instruction)) {
1132 const ShapeIndex& index = pair.first;
1133 const HloValueSet& value_set = pair.second;
1134 for (const HloValue* value : value_set.values()) {
1135 if (value->defining_instruction() != instruction) {
1136 value_positions[value->id()].push_back(
1137 HloPosition{instruction, index});
1138 }
1139 }
1140 }
1141 }
1142 }
1143 for (auto& pair : dataflow_analysis->values_) {
1144 HloValue::Id value_id = pair.first;
1145 HloValue& value = pair.second;
1146 value.SetPositionsAndComputeUses(value_positions[value_id]);
1147 }
1148
1149 // Construct vector of values.
1150 dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
1151 for (auto& pair : dataflow_analysis->values_) {
1152 dataflow_analysis->values_vector_.push_back(&pair.second);
1153 }
1154 absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
1155
1156 TF_DCHECK_OK(dataflow_analysis->Verify());
1157
1158 XLA_VLOG_LINES(1, dataflow_analysis->ToString());
1159
1160 return std::move(dataflow_analysis);
1161 }
1162
1163 Status HloDataflowAnalysis::Verify() const {
1164 // Verify each HloValue appears in the value sets that the value's positions()
1165 // indicate.
1166 for (const HloValue* value : values()) {
1167 for (const HloPosition& position : value->positions()) {
1168 const HloValueSet& value_set = GetValueSet(position);
1169 TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
1170 << "Value set at position " << position << " does not contain value "
1171 << value->ToShortString();
1172 }
1173 }
1174
1175 // For each value in each value set, verify that the value set's position
1176 // appears in the value's positions().
1177 for (const auto& computation : module_.computations()) {
1178 for (const auto& instruction : computation->instructions()) {
1179 for (const auto& pair : GetInstructionValueSet(instruction)) {
1180 const ShapeIndex& index = pair.first;
1181 const HloValueSet& value_set = pair.second;
1182 const HloPosition position{instruction, index};
1183 for (const HloValue* value : value_set.values()) {
1184 TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
1185 << "Value set at position " << position
1186 << " unexpectedly contains value " << value->ToShortString();
1187 }
1188 }
1189 }
1190 }
1191
1192 return Status::OK();
1193 }
1194
1195 bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
1196 const HloInstruction* operand, const ShapeIndex& index,
1197 const HloInstruction* user) const {
1198 // Return true if no value at 'operand' and 'index' is used at 'user'.
1199 for (const HloValue* value : GetValueSet(operand, index).values()) {
1200 for (const HloUse& use : value->uses()) {
1201 if (use.instruction == user) {
1202 if (user->IsLoopFusion()) {
1203 HloInstruction* fusion_param =
1204 user->fused_parameter(use.operand_number);
1205 const HloValue& value =
1206 GetValueDefinedAt(fusion_param, use.operand_index);
1207 return value.uses().empty();
1208 }
1209 return false;
1210 }
1211 }
1212 }
1213 return true;
1214 }
1215
1216 /*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) {
1217 return opcode == HloOpcode::kDynamicUpdateSlice ||
1218 opcode == HloOpcode::kScatter;
1219 }
1220
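// Returns the (operand use, output shape index) pairs for which `instruction`
// updates an operand buffer in place: in-place ops themselves (see
// IsInPlaceOperation), and fusions whose per-output root is such an op whose
// destination traces back through get-tuple-elements to a fusion parameter.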
1221 /*static*/ std::vector<std::pair<HloUse, ShapeIndex>>
1222 HloDataflowAnalysis::GetInPlaceInputOutputPairs(HloInstruction* instruction) {
1223 if (IsInPlaceOperation(instruction->opcode())) {
1224 return {{HloUse{instruction, 0, {}}, {}}};
1225 } else if (instruction->opcode() != HloOpcode::kFusion) {
1226 return {};
1227 }
1228 std::vector<std::pair<HloUse, ShapeIndex>> input_output_pairs;
1229 for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) {
1230 const HloInstruction* hlo_generating_output =
1231 instruction->fused_expression_root();
1232 for (int64 i = 0; i < indexed_shape.index.size(); ++i) {
1233 if (hlo_generating_output->opcode() == HloOpcode::kTuple) {
1234 hlo_generating_output =
1235 hlo_generating_output->operand(indexed_shape.index[i]);
1236 } else {
1237 CHECK_EQ(i, indexed_shape.index.size() - 1);
1238 }
1239 }
1240
1241 if (IsInPlaceOperation(hlo_generating_output->opcode())) {
1242 ShapeIndex operand_index;
1243 const HloInstruction* fusion_parameter =
1244 hlo_generating_output->operand(0);
1245 while (fusion_parameter->opcode() == HloOpcode::kGetTupleElement) {
1246 operand_index.push_front(fusion_parameter->tuple_index());
1247 fusion_parameter = fusion_parameter->operand(0);
1248 }
1249
1250 if (fusion_parameter->opcode() == HloOpcode::kParameter) {
1251 input_output_pairs.emplace_back(
1252 HloUse{instruction, fusion_parameter->parameter_number(),
1253 operand_index},
1254 indexed_shape.index);
1255 }
1256 }
1257 }
1258 return input_output_pairs;
1259 }
1260
1261 bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
1262 HloInstruction* operand, const ShapeIndex& operand_index,
1263 HloInstruction* user, const ShapeIndex& user_index) const {
1264 CHECK(user->IsUserOf(operand))
1265 << "user: " << user->ToString() << " operand: " << operand->ToString();
1266 if (operand->opcode() == HloOpcode::kConstant) {
1267 return false;
1268 }
1269 const Shape& operand_subshape =
1270 ShapeUtil::GetSubshape(operand->shape(), operand_index);
1271 const Shape& user_subshape =
1272 ShapeUtil::GetSubshape(user->shape(), user_index);
1273
1274 // Check that operand and user emit the same shape and layout.
1275 if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
1276 return false;
1277 }
1278
1279 // Must-alias relationship returns true for in-place operations (DUS and DUS
1280 // fusions), regardless of the backend.
1281 for (const auto& operand_and_output_index :
1282 GetInPlaceInputOutputPairs(user)) {
1283 if (operand_and_output_index.second != user_index) {
1284 continue;
1285 }
1286 for (const HloUse& use : GetUniqueValueAt(operand, operand_index).uses()) {
1287 if (use == operand_and_output_index.first) {
1288 return true;
1289 }
1290 }
1291 }
1292
1293 if (can_share_buffer_ != nullptr) {
1294 if (absl::optional<bool> hint =
1295 can_share_buffer_(user, operand, user_index)) {
1296 return *hint;
1297 }
1298 }
1299
1300 if (user->opcode() == HloOpcode::kFusion) {
1301 HloInstruction* fusion_param =
1302 user->fused_parameter(user->operand_index(operand));
1303 const HloValue& fusion_param_value =
1304 GetValueDefinedAt(fusion_param, operand_index);
1305
1306 if (user->IsLoopFusion() || user->IsInputFusion()) {
1307 return AreTransitiveUsesElementwiseOrTuple(fusion_param);
1308 }
1309
1310 if (user->IsOutputFusion() &&
1311 user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
1312 // Output fusion with kAdd fused root.
1313
1314 // Check if one operand of kAdd fused root is kDot or kConvolution.
1315 auto* add = user->fused_expression_root();
1316 auto add_operand_it =
1317 absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
1318 return operand->opcode() == HloOpcode::kConvolution ||
1319 operand->opcode() == HloOpcode::kDot;
1320 });
1321 if (add_operand_it == add->operands().end()) {
1322 return false;
1323 }
1324 auto* matched_add_operand = *add_operand_it;
1325 // Calculate operand index of 'add' operand which was not matched above.
1326 const int64 other_add_operand_index =
1327 matched_add_operand == add->operand(0) ? 1 : 0;
1328 // Returns true iff there is exactly one use of 'operand' at shape index
1329 // 'operand_index', and this singleton use is the fused root (at operand
1330 // index 'other_add_operand_index').
1331 if (fusion_param_value.uses().size() == 1) {
1332 const HloUse& use = fusion_param_value.uses()[0];
1333 return use.instruction == user->fused_expression_root() &&
1334 use.operand_number == other_add_operand_index;
1335 }
1336 return false;
1337 }
1338 }
1339
1340 if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
1341 user->opcode() == HloOpcode::kScatter ||
1342 user->opcode() == HloOpcode::kTriangularSolve ||
1343 user->opcode() == HloOpcode::kWhile) {
1344 // We eliminated other users in HloOrdering::LiveRangeStrictlyBefore
1345 // so here we just need to check that the use is at the right operand index.
1346 const auto operand_indices = user->OperandIndices(operand);
1347 int64 operand_no = user->opcode() == HloOpcode::kTriangularSolve ? 1 : 0;
1348 return operand_indices.size() == 1 && operand_indices[0] == operand_no;
1349 }
1350 if (user->opcode() == HloOpcode::kSort) {
1351 // Only valid if there are no other users.
1352 if (operand->users().size() != 1) {
1353 return false;
1354 }
1355 // If we only sort keys, the output of sort is not a tuple, so we can always
1356 // share the buffer.
1357 if (user->operand_count() == 1) {
1358 return true;
1359 }
1360 CHECK(!user_index.empty());
1361 // Only share with the right tuple element buffer.
1362 const auto operand_indices = user->OperandIndices(operand);
1363 return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
1364 }
1365 if (user->opcode() == HloOpcode::kCall) {
1366 // Get all uses of value defined by 'operand' at 'operand_index'.
1367 const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
1368 // Return true iff:
1369 // *) There exists two uses of 'operand'.
1370 // *) One use is by 'user' (caller).
1371 // *) One use is by root instruction of called computation (callee root).
1372 // (Note: we check the root of the called computation, because the
1373 // root result buffer is required to alias with the Call result buffer).
1374 // *) The root instruction of the called computation is element-wise on
1375 // 'operand'.
1376 const bool found_caller_use =
1377 absl::c_find_if(uses, [user](const HloUse& use) {
1378 return use.instruction == user;
1379 }) != uses.end();
1380 auto* callee_root = user->to_apply()->root_instruction();
1381 const bool found_elementwise_callee_use =
1382 absl::c_find_if(uses, [callee_root](const HloUse& use) {
1383 return use.instruction == callee_root &&
1384 callee_root->IsElementwiseOnOperand(use.operand_number);
1385 }) != uses.end();
1386 return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
1387 }
1388
1389 // Loop fusions that contain transposing copies won't reach here as they have
1390 // different layouts, which fails the check at the beginning of this function.
1391 return user->IsElementwiseOnOperand(user->operand_index(operand));
1392 }
1393
1394 } // namespace xla
1395