1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/map_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_dce.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
28 #include "tensorflow/compiler/xla/shape_tree.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/core/lib/gtl/cleanup.h"
31 #include "tensorflow/core/platform/logging.h"
32
33 namespace xla {
34
// Constructs the pass. |bfloat16_support| describes which HLOs the target
// backend can execute in BF16/mixed precision; it is stored as a raw pointer
// and must outlive this pass (not owned here).
BFloat16Propagation::BFloat16Propagation(
    const BFloat16Support* bfloat16_support)
    : bfloat16_support_(bfloat16_support) {}
38
// Propagates precision decisions into a fusion's nested computation after the
// fusion instruction itself has been processed in the backward pass: the
// fused root is made to agree with the fusion's (possibly changed) output
// shape, then the fused instructions are analyzed in reverse post-order.
void BFloat16Propagation::DetermineFusionComputationPrecision(
    HloInstruction* fusion) {
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
  // Without mixed-precision support the fusion interior must keep a single
  // uniform precision, so there is nothing to propagate inside.
  if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
    return;
  }

  // We are depending on the fusion node itself having already been analyzed
  // for whether it can output BF16 and this has been adjusted in the output
  // shape, and now we're looking to update the interior of the fusion node to
  // match the new output shape, as well as recursively process the whole fusion
  // node even if the output shape was not modified.
  auto root = fusion->fused_instructions_computation()->root_instruction();

  // Adjust root's element types according to the fusion's output shape.
  ShapeUtil::ForEachSubshape(
      root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        // Only F32 leaves are candidates for a change to BF16 in this pass.
        if (subshape.element_type() != F32) {
          return;
        }
        if (OutputTypeAfterChange(fusion, index) == BF16) {
          AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
          VLOG(2) << "Fused root " << root->ToString() << " at shape index "
                  << index << " changed to BF16 precision for fusion "
                  << fusion->ToString();
        }
      });

  // Propagate BF16 in the fusion computation. Reverse post-order visits users
  // before producers, which is what the backward pass requires.
  auto insts =
      fusion->fused_instructions_computation()->MakeInstructionPostOrder();
  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
    DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
  }
  computations_visited_in_backward_pass_.insert(
      fusion->fused_instructions_computation());

  // Undo tentative interior changes that turned out not to be beneficial.
  RevertIfFusionInternalBF16Changes(fusion);
}
78
// Reverts BF16 changes that are purely internal to a fusion computation and
// would only introduce convert pairs without saving memory: an internal
// instruction's change is kept only if it aliases a changed root buffer or
// some operand was also changed (so the BF16 value actually flows somewhere).
void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
    HloInstruction* fusion) {
  // True iff |inst| has at least one pending (non-empty) BF16 change recorded.
  auto has_changes = [this](HloInstruction* inst) {
    auto it = changes_to_bf16_.find(inst);
    return it != changes_to_bf16_.end() && !it->second.empty();
  };

  auto root = fusion->fused_instructions_computation()->root_instruction();
  absl::flat_hash_set<const HloValue*> changed_root_buffers;

  // Collect all dataflow values reaching the root at indices that were
  // changed to BF16; anything aliasing these is externally visible.
  auto root_changes_it = changes_to_bf16_.find(root);
  if (root_changes_it != changes_to_bf16_.end()) {
    for (const auto& entry : root_changes_it->second) {
      for (const HloValue* value :
           dataflow_->GetValueSet(root, entry.second).values()) {
        changed_root_buffers.insert(value);
      }
    }
  }

  // True iff any F32 leaf of |inst| shares a dataflow value with a changed
  // root buffer.
  auto aliases_changed_root_buffer =
      [this, &changed_root_buffers](const HloInstruction* inst) {
        bool aliasing = false;
        ShapeUtil::ForEachSubshape(
            inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
              if (aliasing) {
                // Skip if aliasing is already found.
                return;
              }
              // Only F32 buffers are considered for changing to BF16 in this
              // pass.
              if (subshape.element_type() != F32) {
                return;
              }
              for (const HloValue* value :
                   dataflow_->GetValueSet(inst, index).values()) {
                if (ContainsKey(changed_root_buffers, value)) {
                  aliasing = true;
                  break;
                }
              }
            });
        return aliasing;
      };

  for (auto inst :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    // Parameter changes are driven by the fusion's operands, not reverted
    // here.
    if (inst->opcode() == HloOpcode::kParameter) {
      continue;
    }
    // Changes visible through the fusion output must be kept.
    if (aliases_changed_root_buffer(inst)) {
      continue;
    }
    if (inst->opcode() == HloOpcode::kFusion) {
      // A nested fusion's parameter change is only justified by a change on
      // the corresponding operand; if the operand change was reverted, revert
      // the parameter and recurse so the nested interior is re-checked.
      bool parameter_reverted = false;
      for (int64_t i = 0; i < inst->operand_count(); ++i) {
        if (has_changes(inst->mutable_operand(i))) {
          // Changes on the operand have not been reverted.
          continue;
        }
        auto* fused_parameter = inst->fused_parameter(i);
        if (has_changes(fused_parameter)) {
          changes_to_bf16_.erase(fused_parameter);
          parameter_reverted = true;
        }
      }
      if (parameter_reverted) {
        RevertIfFusionInternalBF16Changes(inst);
      }
    }
    if (!has_changes(inst)) {
      continue;
    }
    // Keep the change only if some operand also changed (post-order
    // guarantees operands were already finalized above).
    bool revert_changes = true;
    for (auto operand : inst->operands()) {
      if (has_changes(operand)) {
        revert_changes = false;
        break;
      }
    }
    if (revert_changes) {
      changes_to_bf16_.erase(inst);
    }
  }
}
164
DetermineWhileComputationsPrecision(HloInstruction * while_hlo)165 void BFloat16Propagation::DetermineWhileComputationsPrecision(
166 HloInstruction* while_hlo) {
167 CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
168
169 // We are depending on the while node itself having already been analyzed for
170 // whether it can output BF16 and this has been adjusted in the output shape,
171 // and now we're looking to update the body and condition computations to
172 // match the new output shape, as well as recursively process the whole while
173 // node even if the output shape was not modified.
174 HloComputation* body = while_hlo->while_body();
175 auto body_root = body->root_instruction();
176 HloComputation* condition = while_hlo->while_condition();
177
178 ShapeUtil::ForEachSubshape(
179 body_root->shape(), [this, while_hlo, body_root](
180 const Shape& subshape, const ShapeIndex& index) {
181 if (subshape.element_type() != F32) {
182 return;
183 }
184 if (OutputTypeAfterChange(while_hlo, index) == BF16) {
185 AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16);
186 VLOG(2) << "While body root " << body_root->ToString()
187 << " at shape index " << index
188 << " changed to BF16 precision for while "
189 << while_hlo->ToString();
190 }
191 });
192
193 auto body_insts = body->MakeInstructionPostOrder();
194 for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
195 ++inst_it) {
196 DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
197 }
198 computations_visited_in_backward_pass_.insert(body);
199
200 auto condition_insts = condition->MakeInstructionPostOrder();
201 for (auto inst_it = condition_insts.rbegin();
202 inst_it != condition_insts.rend(); ++inst_it) {
203 DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
204 }
205 computations_visited_in_backward_pass_.insert(condition);
206 }
207
// Propagates precision decisions into each branch computation of a
// conditional after the conditional itself has been processed: each branch
// root is made to agree with the conditional's (possibly changed) output
// shape, then the branch is analyzed in reverse post-order.
void BFloat16Propagation::DetermineConditionalComputationsPrecision(
    HloInstruction* cond) {
  CHECK_EQ(cond->opcode(), HloOpcode::kConditional);
  for (int64_t i = 0; i < cond->branch_count(); ++i) {
    auto branch = cond->branch_computation(i);
    auto root = branch->root_instruction();
    ShapeUtil::ForEachSubshape(
        root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
          // Only F32 leaves are candidates for a change to BF16 in this pass.
          if (subshape.element_type() != F32) {
            return;
          }
          if (OutputTypeAfterChange(cond, index) == BF16) {
            AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
            VLOG(2) << "Conditional branch " << i << " root "
                    << root->ToString() << " at shape index " << index
                    << " changed to BF16 precision for conditional "
                    << cond->ToString();
          }
        });
    // Reverse post-order visits users before producers, as the backward pass
    // requires.
    auto insts = branch->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
    }
    computations_visited_in_backward_pass_.insert(branch);
  }
}
234
// Returns true if every use of |hlo|'s output at |index| can consume BF16, so
// the output may be (tentatively) changed to BF16 without requiring an F32
// value anywhere downstream. Works on dataflow values so aliasing users are
// included.
bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
                                              const ShapeIndex& index) const {
  // If the subshape isn't floating point then none of the users will be BF16.
  const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index);
  if (subshape.element_type() != BF16 && subshape.element_type() != F32) {
    return false;
  }

  auto& value_set = dataflow_->GetValueSet(&hlo, index);
  for (const HloValue* value : value_set.values()) {
    // Values pinned to F32 (e.g. root outputs) veto the change outright.
    if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
      return false;
    }
    // We use the original type for the value because we are going to examine
    // the uses of it, instead of the value itself. If ValueTypeAfterChange()
    // were used, it would cause problems when there are aliasing buffers, i.e.,
    // ResolveInconsistencyOfAliasingBuffers() would fail to revert the
    // tentative change to BF16 even if the uses require F32.
    if (value->shape().element_type() == BF16) {
      continue;
    }
    for (const HloUse& use : value->uses()) {
      if (!ContainsKey(instructions_visited_in_backward_pass_,
                       use.instruction)) {
        // We don't know yet whether use.instruction will consume BF16 since it
        // hasn't been visited. Although we visit instructions in reverse
        // topological order, this is still possible because there may be
        // unvisited instruction that alias the same buffer. In this case, we
        // aggressively skip this use, and if this causes inconsistency (e.g.,
        // one use is in BF16 but another use is in F32), it will be resolved at
        // the end of the BFloat16Propagation pass.
        continue;
      }
      if (use.instruction->HasSideEffectNoRecurse()) {
        // Keep side-effecting instruction's operands unchanged.
        return false;
      }
      // Any visited user that can accept BF16 has already been updated if
      // necessary, e.g., the output has been changed to BF16 if it propagates
      // precision, or a called computation's parameters have been changed to
      // BF16 for fusions or whiles.
      if (use.instruction->opcode() == HloOpcode::kFusion) {
        // For fusions, consult the corresponding fused parameter instead of
        // the fusion node itself.
        auto* fused_parameter =
            use.instruction->fused_parameter(use.operand_number);
        if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      } else if (use.instruction->opcode() == HloOpcode::kWhile) {
        // A while's operand feeds both the condition and the body parameter;
        // both must accept BF16.
        auto* cond_parameter =
            use.instruction->while_condition()->parameter_instruction(
                use.operand_number);
        if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
          return false;
        }
        auto* body_parameter =
            use.instruction->while_body()->parameter_instruction(
                use.operand_number);
        if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      } else if (use.instruction->opcode() == HloOpcode::kConditional) {
        // Conditional operand 0 is the branch selector, so operand i (i >= 1)
        // feeds parameter 0 of branch computation i - 1.
        auto* cond_parameter =
            use.instruction->branch_computation(use.operand_number - 1)
                ->parameter_instruction(0);
        if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      }
      if (bfloat16_support_->EffectiveOperandPrecisionIsBF16(
              *use.instruction, use.operand_number)) {
        continue;
      }
      // If the op propagates precision and it outputs a BF16, then it's OK to
      // supply BF16 also as the input. In the backward pass, the users shapes
      // should have already been processed.
      if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
              *use.instruction, use.operand_number)) {
        if (use.instruction->opcode() == HloOpcode::kTuple ||
            (use.instruction->opcode() == HloOpcode::kAllReduce &&
             use.instruction->shape().IsTuple())) {
          // The operand flows into output element use.operand_number; extend
          // the operand's index with that prefix to find the output index.
          ShapeIndex use_output_index{use.operand_number};
          for (int64_t i : use.operand_index) {
            use_output_index.push_back(i);
          }
          if (OutputTypeAfterChange(use.instruction, use_output_index) ==
              BF16) {
            continue;
          }
        } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) {
          // GTE strips one tuple level, so drop the leading index element.
          ShapeIndex use_output_index;
          for (int64_t i = 1; i < use.operand_index.size(); ++i) {
            use_output_index.push_back(use.operand_index[i]);
          }
          if (OutputTypeAfterChange(use.instruction, use_output_index) ==
              BF16) {
            continue;
          }
        } else {
          if (OutputTypeAfterChange(use.instruction, use.operand_index) ==
              BF16) {
            continue;
          }
        }
      }
      return false;
    }
  }
  return true;
}
347
ShouldKeepPrecisionUnchanged(const HloInstruction * inst)348 bool BFloat16Propagation::ShouldKeepPrecisionUnchanged(
349 const HloInstruction* inst) {
350 if (inst->opcode() == HloOpcode::kFusion &&
351 inst->fusion_kind() == HloInstruction::FusionKind::kCustom) {
352 return ShouldKeepPrecisionUnchanged(
353 inst->fused_instructions_computation()->root_instruction());
354 }
355 // Do not change precision for side-effecting instructions, control flow, and
356 // bitcast-convert, because this pass might break the interfaces or
357 // assumptions for them.
358 return inst->opcode() == HloOpcode::kCustomCall ||
359 inst->opcode() == HloOpcode::kCall ||
360 inst->opcode() == HloOpcode::kBitcastConvert ||
361 inst->HasSideEffectNoRecurse();
362 }
363
// Backward-pass step for a single instruction: decides, per output leaf,
// whether |hlo| may produce BF16 (recorded via the change set). On exit —
// including early returns — a cleanup object recurses into any called fusion /
// while / conditional computations and marks |hlo| visited.
// |skip_parameters| suppresses decisions on entry parameters, whose precision
// is dictated by their callers.
void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
                                                        bool skip_parameters) {
  // We handle any fusion computation, while body/condition or conditional
  // branches after the instruction is handled, because we need to know the
  // output shape of a fusion or while before propagating inside its
  // computations.
  bool postpone_processing_called_computations = false;
  auto cleaner = tensorflow::gtl::MakeCleanup(
      [this, hlo, &postpone_processing_called_computations] {
        if (!postpone_processing_called_computations) {
          if (hlo->opcode() == HloOpcode::kFusion) {
            DetermineFusionComputationPrecision(hlo);
          } else if (hlo->opcode() == HloOpcode::kWhile) {
            DetermineWhileComputationsPrecision(hlo);
          } else if (hlo->opcode() == HloOpcode::kConditional) {
            DetermineConditionalComputationsPrecision(hlo);
          }
        }
        instructions_visited_in_backward_pass_.insert(hlo);
      });

  // Called computations shared by multiple callers cannot be specialized for
  // this caller, so skip recursing into them.
  if (hlo->opcode() == HloOpcode::kWhile &&
      (caller_counts_[hlo->while_condition()] > 1 ||
       caller_counts_[hlo->while_body()] > 1)) {
    postpone_processing_called_computations = true;
    return;
  }

  if (hlo->opcode() == HloOpcode::kConditional &&
      absl::c_any_of(hlo->branch_computations(), [&](const HloComputation* c) {
        return caller_counts_[c] > 1;
      })) {
    postpone_processing_called_computations = true;
    return;
  }

  // Prevent root instructions from having their output modified by recording
  // all F32 output values as needing to stay as F32.
  CHECK(hlo->parent() != nullptr);
  if (hlo == hlo->parent()->root_instruction()) {
    // Fusion roots are exempt: their precision is adjusted to the fusion
    // node's output by DetermineFusionComputationPrecision().
    if (!hlo->parent()->IsFusionComputation()) {
      ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */,
                                                   const ShapeIndex& index) {
        if (OutputTypeAfterChange(hlo, index) != F32) {
          return;
        }
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
          // Since we use HloValues from the dataflow analysis, this can also
          // affect HLO instructions beyond the root, e.g., if the root is a
          // Tuple HLO, then its operands are also affected.
          values_that_must_be_kept_as_f32_.insert(value);
        }
      });
    }
    return;
  }

  if (ShouldKeepPrecisionUnchanged(hlo) ||
      (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
    return;
  }

  // Only instructions selected by the forward pass are candidates.
  if (!ContainsKey(consider_using_bfloat16_, hlo)) {
    return;
  }

  if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
    return;
  }

  // Change each F32 leaf to BF16 if every (visited) user can consume BF16.
  ShapeUtil::ForEachSubshape(
      hlo->shape(),
      [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) {
        if (OutputTypeAfterChange(hlo, index) == F32 &&
            AllUsersConsumeBF16(*hlo, index)) {
          AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16);
          VLOG(2) << "HloInstruction output at shape index " << index
                  << " changed to BF16 precision: " << hlo->ToString();
        }
      });
}
445
InstructionIsCandidateForBF16Output(HloInstruction * hlo)446 bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
447 HloInstruction* hlo) {
448 if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) &&
449 hlo->opcode() != HloOpcode::kTuple &&
450 hlo->opcode() != HloOpcode::kGetTupleElement &&
451 hlo->opcode() != HloOpcode::kDomain &&
452 hlo->shape().element_type() != BF16) {
453 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
454 if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
455 i) ||
456 !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) {
457 return false;
458 }
459 }
460 }
461 return true;
462 }
463
// Makes the parameters of |hlo|'s called computations (fusion / while /
// conditional) match the post-change precision of the corresponding operands,
// leaf by leaf.
void BFloat16Propagation::AdjustCalledComputationParameters(
    HloInstruction* hlo) {
  auto adjust_computation =
      [this, hlo](HloComputation* computation,
                  absl::Span<HloInstruction* const> operands) {
        // Adjust parameters.
        CHECK_EQ(operands.size(), computation->num_parameters());
        for (int64_t i = 0; i < operands.size(); ++i) {
          auto parameter = computation->parameter_instruction(i);
          ShapeUtil::ForEachSubshape(
              parameter->shape(),
              [this, i, hlo, &operands, parameter](const Shape& /* subshape */,
                                                   const ShapeIndex& index) {
                // Only leaves carry element types to adjust.
                if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
                  return;
                }
                PrimitiveType operand_type =
                    OutputTypeAfterChange(operands[i], index);
                if (OutputTypeAfterChange(parameter, index) == operand_type) {
                  return;
                }
                AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type);
                VLOG(2) << "Called computation parameter "
                        << parameter->ToString() << " at shape index " << index
                        << " adjusted to "
                        << (operand_type == BF16 ? "BF16" : "F32")
                        << " to match operand in HLO " << hlo->ToString();
              });
        }
      };

  switch (hlo->opcode()) {
    case HloOpcode::kFusion:
      adjust_computation(hlo->fused_instructions_computation(),
                         hlo->operands());
      break;
    case HloOpcode::kWhile:
      // The single while operand feeds both condition and body parameters.
      adjust_computation(hlo->while_condition(), hlo->operands());
      adjust_computation(hlo->while_body(), hlo->operands());
      break;
    case HloOpcode::kConditional:
      // Operand 0 is the branch selector; operand i + 1 feeds branch i.
      for (int64_t i = 0; i < hlo->branch_count(); ++i) {
        adjust_computation(hlo->branch_computation(i),
                           {hlo->mutable_operand(i + 1)});
      }
      break;
    default:
      break;
  }
}
514
// Makes the root of |hlo|'s called computations (fusion / while body /
// conditional branches) match |hlo|'s post-change output precision, leaf by
// leaf. Runs during the reverse-topological adjustment pass.
void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
  auto adjust_computation = [this, hlo](HloComputation* computation,
                                        HloInstruction* output) {
    // Adjust root.
    HloInstruction* root = computation->root_instruction();
    ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output](
                                                  const Shape& /* subshape */,
                                                  const ShapeIndex& index) {
      // Only leaves carry element types to adjust.
      if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
        return;
      }
      const PrimitiveType output_type = OutputTypeAfterChange(output, index);
      if (OutputTypeAfterChange(root, index) == output_type) {
        return;
      }
      AddToOrRemoveFromBF16ChangeSet(root, index, output_type);
      // It's possible that output_type is F32, but the root instruction's
      // type is BF16; e.g., a fusion node's output was changed to BF16
      // initially but then adjusted back to F32, and the fusion computation
      // is now being adjusted after the fusion node.
      if (output_type == F32) {
        for (const auto* value : dataflow_->GetValueSet(root, index).values()) {
          // We rely on the fact that this adjustment works in reverse
          // topological order so that called computation will be
          // processed later. Adding the value to
          // values_that_must_be_kept_as_f32_ will ensure the
          // correctness of the adjustment for HLOs that will be
          // processed later.
          values_that_must_be_kept_as_f32_.insert(value);
        }
      }
      VLOG(2) << "Called computation root " << root->ToString()
              << " at shape index " << index << " adjusted to "
              << (output_type == BF16 ? "BF16" : "F32")
              << " to match output shape of " << hlo->ToString();
    });
  };

  switch (hlo->opcode()) {
    case HloOpcode::kFusion:
      adjust_computation(hlo->fused_instructions_computation(), hlo);
      break;
    case HloOpcode::kWhile:
      // Only the body root is shape-coupled to the while's output; the
      // condition root is a predicate.
      adjust_computation(hlo->while_body(), hlo);
      break;
    case HloOpcode::kConditional:
      for (auto* branch : hlo->branch_computations()) {
        adjust_computation(branch, hlo);
      }
      break;
    default:
      break;
  }
}
569
// Runs the post-propagation consistency fix-up over one computation: any
// output leaf whose aliased dataflow values disagree on precision (or whose
// users no longer all consume BF16) is adjusted back to F32. Iterates to a
// fixed point within the computation, recurses into called computations, and
// returns whether any of this computation's parameters changed — which
// callers of while bodies/conditions use as their own fixed-point signal.
bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
    HloComputation* computation,
    absl::flat_hash_set<const HloComputation*>* visited_computations) {
  bool parameter_changed = false;
  auto insts = computation->MakeInstructionPostOrder();
  // Do the adjustment on each instruction in the computation in reverse
  // topological order.
  while (true) {
    bool any_change = false;
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      auto hlo = *inst_it;
      auto adjust_hlo_output = [&](const Shape& /* subshape */,
                                   const ShapeIndex& index) {
        auto output_type = OutputTypeAfterChange(hlo, index);
        VLOG(2) << "output_type is " << ((output_type == BF16) ? "BF16" : "F32")
                << " for :" << hlo->ToString() << "\n";
        if (output_type != F32 && output_type != BF16) {
          return;
        }
        // Start optimistic (BF16); downgrade to F32 if any aliased value or
        // user demands it. Changes only ever go BF16 -> F32, guaranteeing
        // termination of the fixed-point loops.
        PrimitiveType type = BF16;
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
          auto value_type = ValueTypeAfterChange(value);
          if (value_type == BF16) {
            continue;
          }
          VLOG(2) << "Adjust to F32 due to aliased dataflow value: "
                  << value->ToString() << "\n";
          CHECK_EQ(value_type, F32);
          type = F32;
          break;
        }
        // In order to find aliases due to in-place operations, use
        // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
        // but this code works with HloModules that aren't ready yet to use
        // HloAliasAnalysis (e.g., their computation graphs may not have been
        // flattened yet).
        for (const auto& operand_and_output_index :
             HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
          if (operand_and_output_index.second == index) {
            const HloUse& operand = operand_and_output_index.first;
            for (const auto* value :
                 dataflow_
                     ->GetValueSet(hlo->operand(operand.operand_number),
                                   operand.operand_index)
                     .values()) {
              auto value_type = ValueTypeAfterChange(value);
              if (value_type == BF16) {
                continue;
              }
              VLOG(2) << "Adjust to F32 due to InputOutPair: "
                      << value->ToString() << "\n";
              CHECK_EQ(value_type, F32);
              type = F32;
              break;
            }
          }
        }

        // It's possible that a user has been changed from BF16 to F32
        // during this final adjustment pass, so we need to check
        // AllUsersConsumeBF16() again.
        if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
          VLOG(2) << "Adjust to F32 due to All user consumeBF16 fail\n";
          type = F32;
        }
        if (type == F32) {
          for (const auto* value :
               dataflow_->GetValueSet(hlo, index).values()) {
            // We rely on the fact that this adjustment works in reverse
            // topological order. Adding the value to
            // values_that_must_be_kept_as_f32_ will ensure the correctness
            // of the adjustment for HLOs that will be processed later.
            values_that_must_be_kept_as_f32_.insert(value);
          }
        }
        if (type != output_type) {
          any_change = true;
          AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
          VLOG(2) << "HloInstruction output at shape index " << index
                  << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
                  << hlo->ToString();
          if (hlo->opcode() == HloOpcode::kParameter) {
            parameter_changed = true;
          }
        }
      };
      ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
      AdjustCalledComputationRoot(hlo);
      if (hlo->opcode() == HloOpcode::kWhile) {
        // We need to run on the while body and condition repeatedly until a
        // fixed point is reached, i.e., the parameters do not change any more.
        // We may need more than one iteration because the while input and
        // output alias each other, so changing one input parameter requires
        // changing the corresponding output element and thus may transitively
        // require changing another input parameter. A fixed point will be
        // reached because the parameters can only be changed from BF16 to F32,
        // not the other way around.
        absl::flat_hash_set<const HloComputation*> visited_in_while;
        while (ResolveInconsistencyOfAliasingBuffersHelper(
                   hlo->while_condition(), &visited_in_while) ||
               ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
                                                           &visited_in_while)) {
          visited_in_while.clear();
          ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
          AdjustCalledComputationRoot(hlo);
        }
        visited_computations->insert(visited_in_while.begin(),
                                     visited_in_while.end());
      } else if (hlo->opcode() == HloOpcode::kFusion) {
        ResolveInconsistencyOfAliasingBuffersHelper(
            hlo->fused_instructions_computation(), visited_computations);
      } else if (hlo->opcode() == HloOpcode::kConditional) {
        for (auto* branch : hlo->branch_computations()) {
          ResolveInconsistencyOfAliasingBuffersHelper(branch,
                                                      visited_computations);
        }
      }
    }
    if (!any_change) {
      break;
    }
  }
  // Now adjust parameters of called computations.
  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
    AdjustCalledComputationParameters(*inst_it);
  }
  return parameter_changed;
}
698
ResolveInconsistencyOfAliasingBuffers(HloModule * module)699 void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
700 HloModule* module) {
701 const auto& computations_topological_order =
702 module->MakeComputationPostOrder();
703 absl::flat_hash_set<const HloComputation*> resolved;
704 for (auto comp_it = computations_topological_order.rbegin();
705 comp_it != computations_topological_order.rend(); ++comp_it) {
706 if (ContainsKey(resolved, *comp_it)) {
707 continue;
708 }
709 ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
710 }
711 }
712
// Inserts converts after fusion roots whose shape no longer matches the
// fusion node's output shape, so the two agree again.
Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
  // We could have changed a fusion computation's root shape to have a different
  // precision than the fusion node's output, if the fusion root does not
  // define a buffer (e.g., a tuple). Now we add conversions after such fusion
  // roots to make them match the fusion output. If the fusion output is a
  // (possibly nested) tuple, we first create get-tuple-elements, then convert
  // the unmatching leaf nodes, and finally create a new tuple as the fusion
  // computation's root. If tuples and get-tuple-elements are created, we will
  // run tuple simplifier and dead code elimination at the end (dead code is not
  // allowed in fusion computation). E.g.,
  //
  // (1)             (2)             (3)
  // a  b            a  b            a  b
  // |\ |            |\ |            |\ |
  // \ add   ->      |add    ->      | add
  //  \ |            \ |        convert |
  //  tuple         tuple             \ |
  //                 / \              tuple
  //               gte gte
  //                |   |
  //           convert  |
  //                  \ /
  //                 tuple
  // (1) a is F32 but tuple is BF16
  // (2) after adding conversion
  // (3) after tuple simplifier and DCE.
  for (auto computation : module->MakeComputationPostOrder()) {
    auto insts = computation->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      auto hlo = *inst_it;
      if (hlo->opcode() != HloOpcode::kFusion) {
        continue;
      }
      auto fusion_computation = hlo->fused_instructions_computation();
      auto fusion_root = fusion_computation->root_instruction();
      // Nothing to fix if the root already matches the fusion output.
      if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) {
        continue;
      }
      ShapeTree<HloInstruction*> converted_outputs(hlo->shape());
      // Deep copy the fusion root, and convert a leaf node only if its shape
      // does not match the fusion output.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          fusion_computation->DeepCopyInstructionWithCustomCopier(
              fusion_root,
              [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index,
                    HloComputation* comp) {
                const Shape& hlo_subshape =
                    ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
                if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) {
                  return leaf;
                }
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(hlo_subshape, leaf));
              }));
      fusion_computation->set_root_instruction(copy);
    }
  }
  return Status::OK();
}
773
// Rebuilds kConstant instructions whose recorded literal no longer matches
// their (possibly BF16-changed) shape, replacing all uses with a freshly
// converted constant.
Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
  // We may have converted some constants from F32 to BF16, so adjust the
  // constant literals in such cases. We do this here instead of when the
  // constant node's is changed because 1) the HloInstruction interface does not
  // allow resetting the literal so we have to create a new kConstant
  // instruction to replace the old one, which invalidates dataflow analysis,
  // and 2) it's possible that a kConstant's output gets changed to BF16 at the
  // beginning but later on adjusted back to F32, so converting literals here
  // can avoid repeated conversions.
  //
  // TODO(b/73833576): Consider resetting literal in HloInstruction.
  for (auto computation : module->MakeComputationPostOrder()) {
    for (auto hlo : computation->MakeInstructionPostOrder()) {
      if (hlo->opcode() != HloOpcode::kConstant) {
        continue;
      }
      // Layout (minor-to-major) differences are ignored when comparing; only
      // type/dimension mismatches trigger a rebuild.
      if (!Shape::Equal().MinorToMajorOnlyInLayout()(hlo->literal().shape(),
                                                     hlo->shape())) {
        TF_ASSIGN_OR_RETURN(auto converted_literal,
                            hlo->literal().ConvertToShape(hlo->shape()));
        auto new_constant = computation->AddInstruction(
            HloInstruction::CreateConstant(std::move(converted_literal)));
        UpdateLayout(new_constant->mutable_shape());
        // The old constant becomes dead and is left for DCE.
        TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
      }
    }
  }
  return Status::OK();
}
803
SkipNoopConversions(HloModule * module)804 Status BFloat16Propagation::SkipNoopConversions(HloModule* module) {
805 for (auto computation : module->computations()) {
806 for (auto hlo : computation->MakeInstructionPostOrder()) {
807 if (hlo->opcode() != HloOpcode::kConvert) {
808 continue;
809 }
810 auto source = hlo->mutable_operand(0);
811 if (!ShapeUtil::Equal(source->shape(), hlo->shape())) {
812 continue;
813 }
814 const bool is_root = hlo == computation->root_instruction();
815 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source));
816 if (is_root) {
817 computation->set_root_instruction(source);
818 }
819 }
820 }
821 return Status::OK();
822 }
823
824 // The algorithm first does a forward pass (parameters to root) to determine a
825 // set of instructions to consider using bfloat16, then does a backward pass to
826 // determine the precisions of those instructions according to the need of
827 // their users. During the backward pass, the potential changes are stored in
828 // changes_to_bf16_ which are subject to further adjustments then applied to the
829 // HLOs.
StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
  // Reset all per-run state so the pass object can be reused across modules.
  consider_using_bfloat16_.clear();
  instructions_visited_in_backward_pass_.clear();
  computations_visited_in_backward_pass_.clear();
  values_that_must_be_kept_as_f32_.clear();
  caller_counts_.clear();
  changes_to_bf16_.clear();
  changed_ = false;

  auto computations_topological_order = module->MakeComputationPostOrder();

  // Before running the propagation pass, we insert copies (kConvert to the same
  // type) of F32 inputs to while loops. This prevents other uses of the same
  // input from aliasing the while loop input/output, so that there's greater
  // chance to use BF16 inside the loop. If some of these added copies do not
  // help, they will remain F32 after BF16 propagation and will be removed since
  // they are no-ops.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (inst->opcode() != HloOpcode::kWhile) {
        continue;
      }

      auto operand = inst->mutable_operand(0);
      // Deep-copy the while operand; each F32 leaf is "copied" via a
      // same-type kConvert, which breaks aliasing with other users of the
      // operand. Non-F32 leaves are passed through unchanged.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          computation->DeepCopyInstructionWithCustomCopier(
              operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* comp) {
                if (leaf->shape().element_type() != F32) {
                  return leaf;
                }
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(leaf->shape(), leaf));
              }));
      TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy));
    }
  }

  // Dataflow analysis must run after the copy insertion above so it sees the
  // de-aliased graph.
  TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));

  // The first step is a forward pass (parameters to root), where we determine
  // the potential candidate instructions to use bfloat16 in the outputs that
  // are not likely to cause overhead from extra explicit conversions. This is
  // done forwardly because we determine whether an HLO is a candidate partially
  // based on whether its operands are candidates.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (InstructionIsCandidateForBF16Output(inst)) {
        consider_using_bfloat16_.insert(inst);
      }
    }
  }

  // The second step is a backward pass (root to parameters), where we modify
  // the precisions of the instructions identified in the first step when
  // feasible. This is done backwardly because we determine the precision of an
  // HLO's output based on how it is later used.
  //
  // The precision of an instruction is determined by its users, so we do the
  // propagation in reverse topological order.
  for (auto comp_it = computations_topological_order.rbegin();
       comp_it != computations_topological_order.rend(); ++comp_it) {
    // A computation may already have been processed when visited as a callee
    // (e.g. a fusion or while body) from a caller handled earlier.
    if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) {
      continue;
    }
    auto insts = (*comp_it)->MakeInstructionPostOrder();
    // Iterate the post-order in reverse so every user is visited before its
    // operands.
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      DetermineInstructionPrecision(*inst_it,
                                    /*skip_parameters=*/true);
    }
    computations_visited_in_backward_pass_.insert(*comp_it);
  }

  // It's possible that an instruction does not define a buffer, but the
  // defining instruction's shape has changed. So we need to adjust the output
  // shapes of instructions according to the HLO values they refer to.
  ResolveInconsistencyOfAliasingBuffers(module);

  // Apply the changes in changes_to_bf16_.
  for (auto& change : changes_to_bf16_) {
    auto inst = change.first;
    // It is possible that we marked inst to change precision even if it is an
    // unsupported change, when inst is the root of a fusion computation and it
    // has to match the fusion node's output precision. We do a convert instead
    // of in-place change for such cases.
    if (ShouldKeepPrecisionUnchanged(inst)) {
      // Snapshot users before the deep copy so the newly created converts are
      // not themselves rewired to use the copy.
      auto users = inst->users();
      bool is_root = inst == inst->parent()->root_instruction();
      // Copy inst, inserting an F32->BF16 convert at each leaf that was
      // marked for change; unmarked leaves are reused as-is.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          inst->parent()->DeepCopyInstructionWithCustomCopier(
              inst, [&](HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* comp) {
                if (!ContainsKey(change.second,
                                 ShapeUtil::GetMutableSubshape(
                                     inst->mutable_shape(), leaf_index))) {
                  return leaf;
                }
                auto converted_shape =
                    ShapeUtil::ChangeElementType(leaf->shape(), BF16);
                UpdateLayout(&converted_shape);
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(converted_shape, leaf));
              }));
      for (auto user : users) {
        TF_RETURN_IF_ERROR(inst->ReplaceUseWithDifferentShape(user, copy));
      }
      if (is_root) {
        inst->parent()->set_root_instruction(copy,
                                             /*accept_different_shape=*/true);
      }
      continue;
    }
    // Supported case: flip each recorded F32 subshape to BF16 in place.
    for (const auto& entry : change.second) {
      auto subshape = entry.first;
      CHECK_EQ(subshape->element_type(), F32);
      subshape->set_element_type(BF16);
      UpdateLayout(subshape);
      changed_ = true;
    }
  }

  // Removes redundant HLOs added by this pass, either when inserting
  // de-aliasing copies to while loop inputs, or later when converting output
  // types.
  auto clean_up = [this, module]() {
    TF_RETURN_IF_ERROR(SkipNoopConversions(module));
    TupleSimplifier tuple_simplifier;
    TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
    HloDCE dce;
    TF_RETURN_IF_ERROR(dce.Run(module).status());
    return Status::OK();
  };

  // Even when nothing changed, clean-up must run to remove the no-op convert
  // copies inserted before the propagation passes.
  if (!changed_) {
    TF_RETURN_IF_ERROR(clean_up());
    return false;
  }

  TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module));
  TF_RETURN_IF_ERROR(ResolveConvertedConstants(module));

  TF_RETURN_IF_ERROR(clean_up());
  return true;
}
976
OutputTypeAfterChange(HloInstruction * hlo,const ShapeIndex & index) const977 PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
978 HloInstruction* hlo, const ShapeIndex& index) const {
979 Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index);
980 const PrimitiveType type_on_hlo = subshape->element_type();
981 if (type_on_hlo != F32) {
982 return type_on_hlo;
983 }
984 auto it = changes_to_bf16_.find(hlo);
985 if (it == changes_to_bf16_.end()) {
986 return type_on_hlo;
987 }
988 return ContainsKey(it->second, subshape) ? BF16 : F32;
989 }
990
ValueTypeAfterChange(const HloValue * value) const991 PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
992 const HloValue* value) const {
993 auto hlo = value->defining_instruction();
994 const auto& position = value->defining_position();
995 return OutputTypeAfterChange(hlo, position.index);
996 }
997
AddToOrRemoveFromBF16ChangeSet(HloInstruction * hlo,const ShapeIndex & index,PrimitiveType target_type)998 void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
999 HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
1000 if (target_type == BF16) {
1001 auto& entry = changes_to_bf16_[hlo];
1002 entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index),
1003 index);
1004 } else {
1005 CHECK_EQ(target_type, F32);
1006 auto it = changes_to_bf16_.find(hlo);
1007 if (it == changes_to_bf16_.end()) {
1008 return;
1009 }
1010 it->second.erase(
1011 ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index));
1012 }
1013 }
1014
1015 } // namespace xla
1016