1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
17
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/map_util.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_dce.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_module.h"
25 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
26 #include "tensorflow/compiler/xla/shape_tree.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/core/lib/gtl/cleanup.h"
29 #include "tensorflow/core/platform/logging.h"
30
31 namespace xla {
32
// Constructs the pass with a backend-specific BFloat16Support oracle that
// answers which HLOs can consume/produce BF16 and which support mixed
// precision. Stores the raw pointer; the caller must keep it alive for the
// lifetime of this pass.
BFloat16Propagation::BFloat16Propagation(
    const BFloat16Support* bfloat16_support)
    : bfloat16_support_(bfloat16_support) {}
36
// Propagates precision decisions from a fusion node's (already adjusted)
// output shape into its fused computation, then reverts interior changes that
// bring no benefit (see RevertIfFusionInternalBF16Changes).
void BFloat16Propagation::DetermineFusionComputationPrecision(
    HloInstruction* fusion) {
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
  if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
    return;
  }

  // We are depending on the fusion node itself having already been analyzed
  // for whether it can output BF16 and this has been adjusted in the output
  // shape, and now we're looking to update the interior of the fusion node to
  // match the new output shape, as well as recursively process the whole fusion
  // node even if the output shape was not modified.
  auto root = fusion->fused_instructions_computation()->root_instruction();

  // Adjust root's element types according to the fusion's output shape.
  ShapeUtil::ForEachSubshape(
      root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.element_type() != F32) {
          return;
        }
        if (OutputTypeAfterChange(fusion, index) == BF16) {
          AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
          VLOG(2) << "Fused root " << root->ToString() << " at shape index "
                  << index << " changed to BF16 precision for fusion "
                  << fusion->ToString();
        }
      });

  // Propagate BF16 in the fusion computation. Reverse post-order visits users
  // before operands, which the backward precision analysis requires.
  auto insts =
      fusion->fused_instructions_computation()->MakeInstructionPostOrder();
  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
    DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
  }
  computations_visited_in_backward_pass_.insert(
      fusion->fused_instructions_computation());

  // Undo interior changes that would not remove any conversion.
  RevertIfFusionInternalBF16Changes(fusion);
}
76
// Reverts tentative BF16 changes to instructions inside a fusion computation
// when they would not help: a change is kept only if the instruction aliases a
// fusion-output buffer that was changed, or at least one of its operands was
// also changed to BF16. Iterating in post-order lets operand reverts cascade
// to their users within the same sweep.
void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
    HloInstruction* fusion) {
  // Whether inst has a pending, non-empty BF16 change set.
  auto has_changes = [this](HloInstruction* inst) {
    auto it = changes_to_bf16_.find(inst);
    return it != changes_to_bf16_.end() && !it->second.empty();
  };

  auto root = fusion->fused_instructions_computation()->root_instruction();
  absl::flat_hash_set<const HloValue*> changed_root_buffers;

  // Collect the HLO values at the fused root whose type is being changed; any
  // instruction aliasing one of these buffers affects the fusion output and
  // must keep its change.
  auto root_changes_it = changes_to_bf16_.find(root);
  if (root_changes_it != changes_to_bf16_.end()) {
    for (const auto& entry : root_changes_it->second) {
      for (const HloValue* value :
           dataflow_->GetValueSet(root, entry.second).values()) {
        changed_root_buffers.insert(value);
      }
    }
  }

  auto aliases_changed_root_buffer =
      [this, &changed_root_buffers](const HloInstruction* inst) {
        bool aliasing = false;
        ShapeUtil::ForEachSubshape(
            inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
              if (aliasing) {
                // Skip if aliasing is already found.
                return;
              }
              // Only F32 buffers are considered for changing to BF16 in this
              // pass.
              if (subshape.element_type() != F32) {
                return;
              }
              for (const HloValue* value :
                   dataflow_->GetValueSet(inst, index).values()) {
                if (ContainsKey(changed_root_buffers, value)) {
                  aliasing = true;
                  break;
                }
              }
            });
        return aliasing;
      };

  for (auto inst :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    // Parameters are handled through the fusion's operands (see the kFusion
    // case below), not reverted directly here.
    if (inst->opcode() == HloOpcode::kParameter) {
      continue;
    }
    if (aliases_changed_root_buffer(inst)) {
      continue;
    }
    if (inst->opcode() == HloOpcode::kFusion) {
      // If an operand's change was reverted, the nested fusion's matching
      // parameter must be reverted too, and the nested fusion re-examined.
      bool parameter_reverted = false;
      for (int64 i = 0; i < inst->operand_count(); ++i) {
        if (has_changes(inst->mutable_operand(i))) {
          // Changes on the operand have not been reverted.
          continue;
        }
        auto* fused_parameter = inst->fused_parameter(i);
        if (has_changes(fused_parameter)) {
          changes_to_bf16_.erase(fused_parameter);
          parameter_reverted = true;
        }
      }
      if (parameter_reverted) {
        RevertIfFusionInternalBF16Changes(inst);
      }
    }
    if (!has_changes(inst)) {
      continue;
    }
    // Revert the change unless at least one operand is also being changed to
    // BF16.
    bool revert_changes = true;
    for (auto operand : inst->operands()) {
      if (has_changes(operand)) {
        revert_changes = false;
        break;
      }
    }
    if (revert_changes) {
      changes_to_bf16_.erase(inst);
    }
  }
}
162
// Propagates precision decisions from a while node's (already adjusted)
// output shape into its body and condition computations.
void BFloat16Propagation::DetermineWhileComputationsPrecision(
    HloInstruction* while_hlo) {
  CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);

  // We are depending on the while node itself having already been analyzed for
  // whether it can output BF16 and this has been adjusted in the output shape,
  // and now we're looking to update the body and condition computations to
  // match the new output shape, as well as recursively process the whole while
  // node even if the output shape was not modified.
  HloComputation* body = while_hlo->while_body();
  auto body_root = body->root_instruction();
  HloComputation* condition = while_hlo->while_condition();

  // Match the body root's element types to the while node's output shape.
  ShapeUtil::ForEachSubshape(
      body_root->shape(), [this, while_hlo, body_root](
                              const Shape& subshape, const ShapeIndex& index) {
        if (subshape.element_type() != F32) {
          return;
        }
        if (OutputTypeAfterChange(while_hlo, index) == BF16) {
          AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16);
          VLOG(2) << "While body root " << body_root->ToString()
                  << " at shape index " << index
                  << " changed to BF16 precision for while "
                  << while_hlo->ToString();
        }
      });

  // Run the backward precision pass over the body and condition interiors, in
  // reverse post-order so users are visited before operands.
  auto body_insts = body->MakeInstructionPostOrder();
  for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
       ++inst_it) {
    DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
  }
  computations_visited_in_backward_pass_.insert(body);

  auto condition_insts = condition->MakeInstructionPostOrder();
  for (auto inst_it = condition_insts.rbegin();
       inst_it != condition_insts.rend(); ++inst_it) {
    DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
  }
  computations_visited_in_backward_pass_.insert(condition);
}
205
// Returns whether every use of the HLO values defined at (hlo, index) can
// consume BF16 input, i.e., whether it is safe (precision-wise, per the
// backend's BFloat16Support) to change this output subshape to BF16.
bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
                                              const ShapeIndex& index) const {
  // If the subshape isn't floating point then none of the users will be BF16.
  const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index);
  if (subshape.element_type() != BF16 && subshape.element_type() != F32) {
    return false;
  }

  auto& value_set = dataflow_->GetValueSet(&hlo, index);
  for (const HloValue* value : value_set.values()) {
    // Values pinned to F32 (e.g., by non-fusion computation roots) can never
    // be changed.
    if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
      return false;
    }
    // We use the original type for the value because we are going to examine
    // the uses of it, instead of the value itself. If ValueTypeAfterChange()
    // were used, it would cause problems when there are aliasing buffers, i.e.,
    // ResolveInconsistencyOfAliasingBuffers() would fail to revert the
    // tentative change to BF16 even if the uses require F32.
    if (value->shape().element_type() == BF16) {
      continue;
    }
    for (const HloUse& use : value->uses()) {
      if (!ContainsKey(instructions_visited_in_backward_pass_,
                       use.instruction)) {
        // We don't know yet whether use.instruction will consume BF16 since it
        // hasn't been visited. Although we visit instructions in reverse
        // topological order, this is still possible because there may be
        // unvisited instruction that alias the same buffer. In this case, we
        // aggressively skip this use, and if this causes inconsistency (e.g.,
        // one use is in BF16 but another use is in F32), it will be resolved at
        // the end of the BFloat16Propagation pass.
        continue;
      }
      if (use.instruction->HasSideEffectNoRecurse()) {
        // Keep side-effecting instruction's operands unchanged.
        return false;
      }
      // Any visited user that can accept BF16 has already been updated if
      // necessary, e.g., the output has been changed to BF16 if it propagates
      // precision, or a called computation's parameters have been changed to
      // BF16 for fusions or whiles.
      if (use.instruction->opcode() == HloOpcode::kFusion) {
        // For a fusion user, check the corresponding fused parameter instead
        // of the fusion node itself.
        auto* fused_parameter =
            use.instruction->fused_parameter(use.operand_number);
        if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      } else if (use.instruction->opcode() == HloOpcode::kWhile) {
        // For a while user, both the condition's and the body's parameter at
        // this position must accept BF16.
        auto* cond_parameter =
            use.instruction->while_condition()->parameter_instruction(
                use.operand_number);
        if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
          return false;
        }
        auto* body_parameter =
            use.instruction->while_body()->parameter_instruction(
                use.operand_number);
        if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      }
      if (bfloat16_support_->EffectiveOperandPrecisionIsBF16(
              *use.instruction, use.operand_number)) {
        continue;
      }
      // If the op propagates precision and it outputs a BF16, then it's OK to
      // supply BF16 also as the input. In the backward pass, the users shapes
      // should have already been processed.
      if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
              *use.instruction, use.operand_number)) {
        // Map the operand's shape index to the corresponding index in the
        // user's output shape before querying the user's (changed) type.
        if (use.instruction->opcode() == HloOpcode::kTuple ||
            (use.instruction->opcode() == HloOpcode::kAllReduce &&
             use.instruction->shape().IsTuple())) {
          ShapeIndex use_output_index{use.operand_number};
          for (int64 i : use.operand_index) {
            use_output_index.push_back(i);
          }
          if (OutputTypeAfterChange(use.instruction, use_output_index) ==
              BF16) {
            continue;
          }
        } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) {
          ShapeIndex use_output_index;
          for (int64 i = 1; i < use.operand_index.size(); ++i) {
            use_output_index.push_back(use.operand_index[i]);
          }
          if (OutputTypeAfterChange(use.instruction, use_output_index) ==
              BF16) {
            continue;
          }
        } else {
          if (OutputTypeAfterChange(use.instruction, use.operand_index) ==
              BF16) {
            continue;
          }
        }
      }
      return false;
    }
  }
  return true;
}
310
// Decides whether hlo's F32 output subshapes may be changed to BF16, based on
// how its users consume them. Runs during the backward pass, in reverse
// topological order. If skip_parameters is true, kParameter instructions are
// left unchanged (used for entry computations whose interfaces are fixed).
void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
                                                        bool skip_parameters) {
  // We handle any fusion computation or while body/condition after the
  // instruction is handled, because we need to know the output shape of a
  // fusion or while before propagating inside its computations.
  bool postpone_processing_called_computations = false;
  // The cleanup runs on every exit path: it recurses into called computations
  // (unless postponed) and always marks hlo as visited.
  auto cleaner = tensorflow::gtl::MakeCleanup(
      [this, hlo, &postpone_processing_called_computations] {
        if (!postpone_processing_called_computations) {
          if (hlo->opcode() == HloOpcode::kFusion) {
            DetermineFusionComputationPrecision(hlo);
          } else if (hlo->opcode() == HloOpcode::kWhile) {
            DetermineWhileComputationsPrecision(hlo);
          }
        }
        instructions_visited_in_backward_pass_.insert(hlo);
      });

  // A while condition/body with multiple callers cannot be specialized for
  // this caller; skip processing its interior here.
  if (hlo->opcode() == HloOpcode::kWhile &&
      (caller_counts_[hlo->while_condition()] > 1 ||
       caller_counts_[hlo->while_body()] > 1)) {
    postpone_processing_called_computations = true;
    return;
  }

  // Prevent root instructions from having their output modified by recording
  // all F32 output values as needing to stay as F32.
  CHECK(hlo->parent() != nullptr);
  if (hlo == hlo->parent()->root_instruction()) {
    if (!hlo->parent()->IsFusionComputation()) {
      ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */,
                                                   const ShapeIndex& index) {
        if (OutputTypeAfterChange(hlo, index) != F32) {
          return;
        }
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
          // Since we use HloValues from the dataflow analysis, this can also
          // affect HLO instructions beyond the root, e.g., if the root is a
          // Tuple HLO, then its operands are also affected.
          values_that_must_be_kept_as_f32_.insert(value);
        }
      });
    }
    return;
  }

  // Do not change precision for instructions related to entry and exit of a
  // computation, side-effecting instructions, and control flow, because this
  // pass might break the interfaces or assumptions for them.
  if (hlo->opcode() == HloOpcode::kCustomCall ||     //
      hlo->opcode() == HloOpcode::kCall ||           //
      hlo->opcode() == HloOpcode::kConditional ||    //
      hlo->HasSideEffectNoRecurse() ||               //
      (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
    return;
  }

  // Only candidates identified in the forward pass are considered.
  if (!ContainsKey(consider_using_bfloat16_, hlo)) {
    return;
  }

  if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
    return;
  }

  // Change each F32 output subshape to BF16 if all of its users accept BF16.
  ShapeUtil::ForEachSubshape(
      hlo->shape(),
      [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) {
        if (OutputTypeAfterChange(hlo, index) == F32 &&
            AllUsersConsumeBF16(*hlo, index)) {
          AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16);
          VLOG(2) << "HloInstruction output at shape index " << index
                  << " changed to BF16 precision: " << hlo->ToString();
        }
      });
}
387
InstructionIsCandidateForBF16Output(HloInstruction * hlo)388 bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
389 HloInstruction* hlo) {
390 if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) &&
391 hlo->opcode() != HloOpcode::kTuple &&
392 hlo->opcode() != HloOpcode::kGetTupleElement &&
393 hlo->opcode() != HloOpcode::kDomain &&
394 hlo->shape().element_type() != BF16) {
395 for (int64 i = 0; i < hlo->operand_count(); ++i) {
396 if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
397 i) ||
398 !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) {
399 return false;
400 }
401 }
402 }
403 return true;
404 }
405
// Adjusts the parameter element types of hlo's called computations (the
// fusion computation, or the while condition and body) to match the adjusted
// types of hlo's corresponding operands. Non-calling opcodes are a no-op.
void BFloat16Propagation::AdjustCalledComputationParameters(
    HloInstruction* hlo) {
  auto adjust_computation =
      [this, hlo](HloComputation* computation,
                  absl::Span<HloInstruction* const> operands) {
        // Adjust parameters.
        CHECK_EQ(operands.size(), computation->num_parameters());
        for (int64 i = 0; i < operands.size(); ++i) {
          auto parameter = computation->parameter_instruction(i);
          ShapeUtil::ForEachSubshape(
              parameter->shape(),
              [this, i, hlo, &operands, parameter](const Shape& /* subshape */,
                                                   const ShapeIndex& index) {
                // Only leaf (array) subshapes carry element types to adjust.
                if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
                  return;
                }
                PrimitiveType operand_type =
                    OutputTypeAfterChange(operands[i], index);
                if (OutputTypeAfterChange(parameter, index) == operand_type) {
                  return;
                }
                AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type);
                VLOG(2) << "Called computation parameter "
                        << parameter->ToString() << " at shape index " << index
                        << " adjusted to "
                        << (operand_type == BF16 ? "BF16" : "F32")
                        << " to match operand in HLO " << hlo->ToString();
              });
        }
      };

  switch (hlo->opcode()) {
    case HloOpcode::kFusion:
      adjust_computation(hlo->fused_instructions_computation(),
                         hlo->operands());
      break;
    case HloOpcode::kWhile:
      // Both the condition and the body take the while operands as
      // parameters.
      adjust_computation(hlo->while_condition(), hlo->operands());
      adjust_computation(hlo->while_body(), hlo->operands());
      break;
    default:
      break;
  }
}
450
// Adjusts the root element types of hlo's called computations (the fusion
// computation, or the while body) to match hlo's adjusted output types.
// Non-calling opcodes are a no-op.
void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
  auto adjust_computation = [this, hlo](HloComputation* computation,
                                        HloInstruction* output) {
    // Adjust root.
    HloInstruction* root = computation->root_instruction();
    ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output](
                                                  const Shape& /* subshape */,
                                                  const ShapeIndex& index) {
      // Only leaf (array) subshapes carry element types to adjust.
      if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
        return;
      }
      const PrimitiveType output_type = OutputTypeAfterChange(output, index);
      if (OutputTypeAfterChange(root, index) == output_type) {
        return;
      }
      AddToOrRemoveFromBF16ChangeSet(root, index, output_type);
      // It's possible that output_type is F32, but the root instruction's
      // type is BF16; e.g., a fusion node's output was changed to BF16
      // initially but then adjusted back to F32, and the fusion computation
      // is now being adjusted after the fusion node.
      if (output_type == F32) {
        for (const auto* value : dataflow_->GetValueSet(root, index).values()) {
          // We rely on the fact that this adjustment works in reverse
          // topological order so that called computation will be
          // processed later. Adding the value to
          // values_that_must_be_kept_as_f32_ will ensure the
          // correctness of the adjustment for HLOs that will be
          // processed later.
          values_that_must_be_kept_as_f32_.insert(value);
        }
      }
      VLOG(2) << "Called computation root " << root->ToString()
              << " at shape index " << index << " adjusted to "
              << (output_type == BF16 ? "BF16" : "F32")
              << " to match output shape of " << hlo->ToString();
    });
  };

  switch (hlo->opcode()) {
    case HloOpcode::kFusion:
      adjust_computation(hlo->fused_instructions_computation(), hlo);
      break;
    case HloOpcode::kWhile:
      adjust_computation(hlo->while_body(), hlo);
      break;
    default:
      break;
  }
}
500
// Adjusts output types within `computation` so that all HLO values aliasing
// the same buffer agree on one type: a subshape stays BF16 only if every
// aliased value (and every user) can be BF16; otherwise it is forced back to
// F32. Returns whether any parameter's type changed, which drives the
// fixed-point iteration for while computations. Visited computations are
// recorded in *visited_computations.
bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
    HloComputation* computation,
    absl::flat_hash_set<const HloComputation*>* visited_computations) {
  bool parameter_changed = false;
  auto insts = computation->MakeInstructionPostOrder();
  // Do the adjustment on each instruction in the computation in reverse
  // topological order.
  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
    auto hlo = *inst_it;
    // Computes the final type for hlo's output at `index` and records the
    // change if it differs from the currently planned type.
    auto adjust_hlo_output = [this, hlo, &parameter_changed](
                                 const Shape& /* subshape */,
                                 const ShapeIndex& index) {
      auto output_type = OutputTypeAfterChange(hlo, index);
      if (output_type != F32 && output_type != BF16) {
        return;
      }
      PrimitiveType type = BF16;
      for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
        auto value_type = ValueTypeAfterChange(value);
        if (value_type == BF16) {
          continue;
        }
        CHECK_EQ(value_type, F32);
        type = F32;
        break;
      }
      // It's possible that a user has been changed from BF16 to F32
      // during this final adjustment pass, so we need to check
      // AllUsersConsumeBF16() again.
      if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
        type = F32;
      }
      if (type == F32) {
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
          // We rely on the fact that this adjustment works in reverse
          // topological order. Adding the value to
          // values_that_must_be_kept_as_f32_ will ensure the correctness
          // of the adjustment for HLOs that will be processed later.
          values_that_must_be_kept_as_f32_.insert(value);
        }
      }
      if (type != output_type) {
        AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
        VLOG(2) << "HloInstruction output at shape index " << index
                << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
                << hlo->ToString();
        if (hlo->opcode() == HloOpcode::kParameter) {
          parameter_changed = true;
        }
      }
    };
    ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
    AdjustCalledComputationRoot(hlo);
    if (hlo->opcode() == HloOpcode::kWhile) {
      // We need to run on the while body and condition repeatedly until a fixed
      // point is reached, i.e., the parameters do not change any more. We may
      // need more than one iteration because the while input and output alias
      // each other, so changing one input parameter requires changing the
      // corresponding output element and thus may transitively require changing
      // another input parameter. A fixed point will be reached because the
      // parameters can only be changed from BF16 to F32, not the other way
      // around.
      absl::flat_hash_set<const HloComputation*> visited_in_while;
      while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(),
                                                         &visited_in_while) ||
             ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
                                                         &visited_in_while)) {
        visited_in_while.clear();
        ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
        AdjustCalledComputationRoot(hlo);
      }
      visited_computations->insert(visited_in_while.begin(),
                                   visited_in_while.end());
    } else if (hlo->opcode() == HloOpcode::kFusion) {
      ResolveInconsistencyOfAliasingBuffersHelper(
          hlo->fused_instructions_computation(), visited_computations);
    }
  }
  // Now adjust parameters of called computations.
  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
    AdjustCalledComputationParameters(*inst_it);
  }
  return parameter_changed;
}
585
ResolveInconsistencyOfAliasingBuffers(HloModule * module)586 void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
587 HloModule* module) {
588 const auto& computations_topological_order =
589 module->MakeComputationPostOrder();
590 absl::flat_hash_set<const HloComputation*> resolved;
591 for (auto comp_it = computations_topological_order.rbegin();
592 comp_it != computations_topological_order.rend(); ++comp_it) {
593 if (ContainsKey(resolved, *comp_it)) {
594 continue;
595 }
596 ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
597 }
598 }
599
// Inserts converts (via a deep copy of the fused root) so that every fusion
// computation's root shape matches the fusion node's output shape.
Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
  // We could have changed a fusion computation's root shape to have a different
  // precision than the fusion node's output, if the fusion root does not
  // define a buffer (e.g., a tuple). Now we add conversions after such fusion
  // roots to make them match the fusion output. If the fusion output is a
  // (possibly nested) tuple, we first create get-tuple-elements, then convert
  // the unmatching leaf nodes, and finally create a new tuple as the fusion
  // computation's root. If tuples and get-tuple-elements are created, we will
  // run tuple simplifier and dead code elimination at the end (dead code is not
  // allowed in fusion computation). E.g.,
  //
  // (1)             (2)             (3)
  // a  b            a  b            a  b
  // |\ |            |\ |            |\ |
  // \ add   ->      |add    ->      | add
  //  \ |            \ |        convert |
  //  tuple         tuple             \ |
  //                 / \              tuple
  //               gte gte
  //                |    |
  //           convert   |
  //                  \  /
  //                 tuple
  // (1) a is F32 but tuple is BF16
  // (2) after adding conversion
  // (3) after tuple simplifier and DCE.
  for (auto computation : module->MakeComputationPostOrder()) {
    auto insts = computation->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      auto hlo = *inst_it;
      if (hlo->opcode() != HloOpcode::kFusion) {
        continue;
      }
      auto fusion_computation = hlo->fused_instructions_computation();
      auto fusion_root = fusion_computation->root_instruction();
      // Skip fusions whose root shape already matches the fusion output.
      if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) {
        continue;
      }
      ShapeTree<HloInstruction*> converted_outputs(hlo->shape());
      // Deep copy the fusion root, and convert a leaf node only if its shape
      // does not match the fusion output.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          fusion_computation->DeepCopyInstructionWithCustomCopier(
              fusion_root,
              [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index,
                    HloComputation* comp) {
                const Shape& hlo_subshape =
                    ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
                if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) {
                  return leaf;
                }
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(hlo_subshape, leaf));
              }));
      // The converted copy becomes the fused root; the old root is cleaned up
      // later by DCE.
      fusion_computation->set_root_instruction(copy);
    }
  }
  return Status::OK();
}
660
ResolveConvertedConstants(HloModule * module)661 Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
662 // We may have converted some constants from F32 to BF16, so adjust the
663 // constant literals in such cases. We do this here instead of when the
664 // constant node's is changed because 1) the HloInstruction interface does not
665 // allow resetting the literal so we have to create a new kConstant
666 // instruction to replace the old one, which invalidates dataflow analysis,
667 // and 2) it's possible that a kConstant's output gets changed to BF16 at the
668 // beginning but later on adjusted back to F32, so converting literals here
669 // can avoid repeated conversions.
670 //
671 // TODO(b/73833576): Consider resetting literal in HloInstruction.
672 for (auto computation : module->MakeComputationPostOrder()) {
673 for (auto hlo : computation->MakeInstructionPostOrder()) {
674 if (hlo->opcode() != HloOpcode::kConstant) {
675 continue;
676 }
677 if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
678 TF_ASSIGN_OR_RETURN(auto converted_literal,
679 hlo->literal().ConvertToShape(hlo->shape()));
680 auto new_constant = computation->AddInstruction(
681 HloInstruction::CreateConstant(std::move(converted_literal)));
682 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
683 }
684 }
685 }
686 return Status::OK();
687 }
688
SkipNoopConversions(HloModule * module)689 Status BFloat16Propagation::SkipNoopConversions(HloModule* module) {
690 for (auto computation : module->computations()) {
691 for (auto hlo : computation->MakeInstructionPostOrder()) {
692 if (hlo->opcode() != HloOpcode::kConvert) {
693 continue;
694 }
695 auto source = hlo->mutable_operand(0);
696 if (!ShapeUtil::Equal(source->shape(), hlo->shape())) {
697 continue;
698 }
699 const bool is_root = hlo == computation->root_instruction();
700 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source));
701 if (is_root) {
702 computation->set_root_instruction(source);
703 }
704 }
705 }
706 return Status::OK();
707 }
708
// The algorithm first does a forward pass (parameters to root) to determine a
// set of instructions to consider using bfloat16, then does a backward pass to
// determine the precisions of those instructions according to the need of
// their users. During the backward pass, the potential changes are stored in
// changes_to_bf16_ which are subject to further adjustments then applied to the
// HLOs.
StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
  // Reset all per-run state so the pass can be run on multiple modules.
  consider_using_bfloat16_.clear();
  instructions_visited_in_backward_pass_.clear();
  computations_visited_in_backward_pass_.clear();
  values_that_must_be_kept_as_f32_.clear();
  caller_counts_.clear();
  changes_to_bf16_.clear();
  changed_ = false;

  auto computations_topological_order = module->MakeComputationPostOrder();

  // Before running the propagation pass, we insert copies (kConvert to the same
  // type) of F32 inputs to while loops. This prevents other uses of the same
  // input from aliasing the while loop input/output, so that there's greater
  // chance to use BF16 inside the loop. If some of these added copies do not
  // help, they will remain F32 after BF16 propagation and will be removed since
  // they are no-ops.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (inst->opcode() != HloOpcode::kWhile) {
        continue;
      }

      auto operand = inst->mutable_operand(0);
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          computation->DeepCopyInstructionWithCustomCopier(
              operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* comp) {
                // Only F32 leaves need de-aliasing; other types are never
                // changed by this pass.
                if (leaf->shape().element_type() != F32) {
                  return leaf;
                }
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(leaf->shape(), leaf));
              }));
      TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy));
    }
  }

  // Dataflow analysis is used throughout the pass to reason about aliasing
  // buffers; it must be run after the copies above are inserted.
  TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));

  // The first step is a forward pass (parameters to root), where we determine
  // the potential candidate instructions to use bfloat16 in the outputs that
  // are not likely to cause overhead from extra explicit conversions. This is
  // done forwardly because we determine whether an HLO is a candidate partially
  // based on whether its operands are candidates.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (InstructionIsCandidateForBF16Output(inst)) {
        consider_using_bfloat16_.insert(inst);
      }
    }
  }

  // The second step is a backward pass (root to parameters), where we modify
  // the precisions of the instructions identified in the first step when
  // feasible. This is done backwardly because we determine the precision of an
  // HLO's output based on how it is later used.
  //
  // The precision of an instruction is determined by its users, so we do the
  // propagation in reverse topological order.
  for (auto comp_it = computations_topological_order.rbegin();
       comp_it != computations_topological_order.rend(); ++comp_it) {
    // Skip computations already processed while handling a caller (e.g.,
    // fusion computations and while bodies/conditions).
    if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) {
      continue;
    }
    auto insts = (*comp_it)->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      DetermineInstructionPrecision(*inst_it,
                                    /*skip_parameters=*/true);
    }
    computations_visited_in_backward_pass_.insert(*comp_it);
  }

  // It's possible that an instruction does not define a buffer, but the
  // defining instruction's shape has changed. So we need to adjust the output
  // shapes of instructions according to the HLO values they refer to.
  ResolveInconsistencyOfAliasingBuffers(module);

  // Apply the changes in changes_to_bf16_.
  for (auto& change : changes_to_bf16_) {
    for (const auto& entry : change.second) {
      auto subshape = entry.first;
      CHECK_EQ(subshape->element_type(), F32);
      subshape->set_element_type(BF16);
      changed_ = true;
    }
  }

  // Removes redundant HLOs added by this pass, either when inserting
  // de-aliasing copies to while loop inputs, or later when converting output
  // types.
  auto clean_up = [this, module]() {
    TF_RETURN_IF_ERROR(SkipNoopConversions(module));
    TupleSimplifier tuple_simplifier;
    TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
    HloDCE dce;
    TF_RETURN_IF_ERROR(dce.Run(module).status());
    return Status::OK();
  };

  if (!changed_) {
    // Still clean up, since the de-aliasing copies were inserted above.
    TF_RETURN_IF_ERROR(clean_up());
    return false;
  }

  TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module));
  TF_RETURN_IF_ERROR(ResolveConvertedConstants(module));

  TF_RETURN_IF_ERROR(clean_up());
  return true;
}
827
OutputTypeAfterChange(HloInstruction * hlo,const ShapeIndex & index) const828 PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
829 HloInstruction* hlo, const ShapeIndex& index) const {
830 Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index);
831 const PrimitiveType type_on_hlo = subshape->element_type();
832 if (type_on_hlo != F32) {
833 return type_on_hlo;
834 }
835 auto it = changes_to_bf16_.find(hlo);
836 if (it == changes_to_bf16_.end()) {
837 return type_on_hlo;
838 }
839 return ContainsKey(it->second, subshape) ? BF16 : F32;
840 }
841
ValueTypeAfterChange(const HloValue * value) const842 PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
843 const HloValue* value) const {
844 auto hlo = value->defining_instruction();
845 const auto& position = value->defining_position();
846 return OutputTypeAfterChange(hlo, position.index);
847 }
848
AddToOrRemoveFromBF16ChangeSet(HloInstruction * hlo,const ShapeIndex & index,PrimitiveType target_type)849 void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
850 HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
851 if (target_type == BF16) {
852 auto& entry = changes_to_bf16_[hlo];
853 entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index),
854 index);
855 } else {
856 CHECK_EQ(target_type, F32);
857 auto it = changes_to_bf16_.find(hlo);
858 if (it == changes_to_bf16_.end()) {
859 return;
860 }
861 it->second.erase(
862 ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index));
863 }
864 }
865
866 } // namespace xla
867