• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/map_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_dce.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
28 #include "tensorflow/compiler/xla/shape_tree.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/core/lib/gtl/cleanup.h"
31 #include "tensorflow/core/platform/logging.h"
32 
33 namespace xla {
34 
BFloat16Propagation(const BFloat16Support * bfloat16_support)35 BFloat16Propagation::BFloat16Propagation(
36     const BFloat16Support* bfloat16_support)
37     : bfloat16_support_(bfloat16_support) {}
38 
DetermineFusionComputationPrecision(HloInstruction * fusion)39 void BFloat16Propagation::DetermineFusionComputationPrecision(
40     HloInstruction* fusion) {
41   CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
42   if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
43     return;
44   }
45 
46   // We are depending on the fusion node itself having already been analyzed
47   // for whether it can output BF16 and this has been adjusted in the output
48   // shape, and now we're looking to update the interior of the fusion node to
49   // match the new output shape, as well as recursively process the whole fusion
50   // node even if the output shape was not modified.
51   auto root = fusion->fused_instructions_computation()->root_instruction();
52 
53   // Adjust root's element types according to the fusion's output shape.
54   ShapeUtil::ForEachSubshape(
55       root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
56         if (subshape.element_type() != F32) {
57           return;
58         }
59         if (OutputTypeAfterChange(fusion, index) == BF16) {
60           AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
61           VLOG(2) << "Fused root " << root->ToString() << " at shape index "
62                   << index << " changed to BF16 precision for fusion "
63                   << fusion->ToString();
64         }
65       });
66 
67   // Propagate BF16 in the fusion computation.
68   auto insts =
69       fusion->fused_instructions_computation()->MakeInstructionPostOrder();
70   for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
71     DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
72   }
73   computations_visited_in_backward_pass_.insert(
74       fusion->fused_instructions_computation());
75 
76   RevertIfFusionInternalBF16Changes(fusion);
77 }
78 
// Reverts BF16 changes that are purely internal to `fusion`'s computation:
// a change on an interior instruction only pays off if either the fusion's
// output (root buffers) or some already-changed operand chain benefits from
// it. Changes that feed the (changed) root buffers, or that are fed by other
// kept changes, are preserved; isolated ones are erased from
// changes_to_bf16_. Recurses into nested fusions whose parameters get
// reverted.
void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
    HloInstruction* fusion) {
  // True iff `inst` still has at least one pending BF16 change recorded.
  auto has_changes = [this](HloInstruction* inst) {
    auto it = changes_to_bf16_.find(inst);
    return it != changes_to_bf16_.end() && !it->second.empty();
  };

  auto root = fusion->fused_instructions_computation()->root_instruction();
  // Dataflow values reachable from the root at indices that were changed to
  // BF16 — changes aliasing these buffers are visible at the fusion output
  // and must be kept.
  absl::flat_hash_set<const HloValue*> changed_root_buffers;

  auto root_changes_it = changes_to_bf16_.find(root);
  if (root_changes_it != changes_to_bf16_.end()) {
    for (const auto& entry : root_changes_it->second) {
      for (const HloValue* value :
           dataflow_->GetValueSet(root, entry.second).values()) {
        changed_root_buffers.insert(value);
      }
    }
  }

  // True iff any F32 subshape of `inst` aliases a changed root buffer.
  auto aliases_changed_root_buffer =
      [this, &changed_root_buffers](const HloInstruction* inst) {
        bool aliasing = false;
        ShapeUtil::ForEachSubshape(
            inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
              if (aliasing) {
                // Skip if aliasing is already found.
                return;
              }
              // Only F32 buffers are considered for changing to BF16 in this
              // pass.
              if (subshape.element_type() != F32) {
                return;
              }
              for (const HloValue* value :
                   dataflow_->GetValueSet(inst, index).values()) {
                if (ContainsKey(changed_root_buffers, value)) {
                  aliasing = true;
                  break;
                }
              }
            });
        return aliasing;
      };

  for (auto inst :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (inst->opcode() == HloOpcode::kParameter) {
      continue;
    }
    if (aliases_changed_root_buffer(inst)) {
      continue;
    }
    if (inst->opcode() == HloOpcode::kFusion) {
      // A nested fusion's parameters mirror its operands: if an operand's
      // change was reverted, revert the matching fused parameter too, then
      // recursively clean up inside the nested fusion.
      bool parameter_reverted = false;
      for (int64_t i = 0; i < inst->operand_count(); ++i) {
        if (has_changes(inst->mutable_operand(i))) {
          // Changes on the operand have not been reverted.
          continue;
        }
        auto* fused_parameter = inst->fused_parameter(i);
        if (has_changes(fused_parameter)) {
          changes_to_bf16_.erase(fused_parameter);
          parameter_reverted = true;
        }
      }
      if (parameter_reverted) {
        RevertIfFusionInternalBF16Changes(inst);
      }
    }
    if (!has_changes(inst)) {
      continue;
    }
    // Keep this instruction's change only if some operand also has a kept
    // change (so the conversion removal actually saves work); post-order
    // traversal means operands have already been processed.
    bool revert_changes = true;
    for (auto operand : inst->operands()) {
      if (has_changes(operand)) {
        revert_changes = false;
        break;
      }
    }
    if (revert_changes) {
      changes_to_bf16_.erase(inst);
    }
  }
}
164 
DetermineWhileComputationsPrecision(HloInstruction * while_hlo)165 void BFloat16Propagation::DetermineWhileComputationsPrecision(
166     HloInstruction* while_hlo) {
167   CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
168 
169   // We are depending on the while node itself having already been analyzed for
170   // whether it can output BF16 and this has been adjusted in the output shape,
171   // and now we're looking to update the body and condition computations to
172   // match the new output shape, as well as recursively process the whole while
173   // node even if the output shape was not modified.
174   HloComputation* body = while_hlo->while_body();
175   auto body_root = body->root_instruction();
176   HloComputation* condition = while_hlo->while_condition();
177 
178   ShapeUtil::ForEachSubshape(
179       body_root->shape(), [this, while_hlo, body_root](
180                               const Shape& subshape, const ShapeIndex& index) {
181         if (subshape.element_type() != F32) {
182           return;
183         }
184         if (OutputTypeAfterChange(while_hlo, index) == BF16) {
185           AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16);
186           VLOG(2) << "While body root " << body_root->ToString()
187                   << " at shape index " << index
188                   << " changed to BF16 precision for while "
189                   << while_hlo->ToString();
190         }
191       });
192 
193   auto body_insts = body->MakeInstructionPostOrder();
194   for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
195        ++inst_it) {
196     DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
197   }
198   computations_visited_in_backward_pass_.insert(body);
199 
200   auto condition_insts = condition->MakeInstructionPostOrder();
201   for (auto inst_it = condition_insts.rbegin();
202        inst_it != condition_insts.rend(); ++inst_it) {
203     DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
204   }
205   computations_visited_in_backward_pass_.insert(condition);
206 }
207 
DetermineConditionalComputationsPrecision(HloInstruction * cond)208 void BFloat16Propagation::DetermineConditionalComputationsPrecision(
209     HloInstruction* cond) {
210   CHECK_EQ(cond->opcode(), HloOpcode::kConditional);
211   for (int64_t i = 0; i < cond->branch_count(); ++i) {
212     auto branch = cond->branch_computation(i);
213     auto root = branch->root_instruction();
214     ShapeUtil::ForEachSubshape(
215         root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
216           if (subshape.element_type() != F32) {
217             return;
218           }
219           if (OutputTypeAfterChange(cond, index) == BF16) {
220             AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
221             VLOG(2) << "Conditional branch " << i << " root "
222                     << root->ToString() << " at shape index " << index
223                     << " changed to BF16 precision for conditional "
224                     << cond->ToString();
225           }
226         });
227     auto insts = branch->MakeInstructionPostOrder();
228     for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
229       DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
230     }
231     computations_visited_in_backward_pass_.insert(branch);
232   }
233 }
234 
// Returns true iff every use of `hlo`'s value at `index` can consume BF16, so
// the output at that index may be changed to BF16. Conservatively returns
// false for values pinned to F32, side-effecting users, and any use whose
// BF16-compatibility cannot be established.
bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
                                              const ShapeIndex& index) const {
  // If the subshape isn't floating point then none of the users will be BF16.
  const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index);
  if (subshape.element_type() != BF16 && subshape.element_type() != F32) {
    return false;
  }

  auto& value_set = dataflow_->GetValueSet(&hlo, index);
  for (const HloValue* value : value_set.values()) {
    // Values recorded as must-stay-F32 (e.g., entry/root buffers) block the
    // change outright.
    if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
      return false;
    }
    // We use the original type for the value because we are going to examine
    // the uses of it, instead of the value itself. If ValueTypeAfterChange()
    // were used, it would cause problems when there are aliasing buffers, i.e.,
    // ResolveInconsistencyOfAliasingBuffers() would fail to revert the
    // tentative change to BF16 even if the uses require F32.
    if (value->shape().element_type() == BF16) {
      continue;
    }
    for (const HloUse& use : value->uses()) {
      if (!ContainsKey(instructions_visited_in_backward_pass_,
                       use.instruction)) {
        // We don't know yet whether use.instruction will consume BF16 since it
        // hasn't been visited. Although we visit instructions in reverse
        // topological order, this is still possible because there may be
        // unvisited instruction that alias the same buffer. In this case, we
        // aggressively skip this use, and if this causes inconsistency (e.g.,
        // one use is in BF16 but another use is in F32), it will be resolved at
        // the end of the BFloat16Propagation pass.
        continue;
      }
      if (use.instruction->HasSideEffectNoRecurse()) {
        // Keep side-effecting instruction's operands unchanged.
        return false;
      }
      // Any visited user that can accept BF16 has already been updated if
      // necessary, e.g., the output has been changed to BF16 if it propagates
      // precision, or a called computation's parameters have been changed to
      // BF16 for fusions or whiles.
      if (use.instruction->opcode() == HloOpcode::kFusion) {
        // A fusion consumes BF16 iff the matching fused parameter (at the
        // same shape index) has been changed to BF16.
        auto* fused_parameter =
            use.instruction->fused_parameter(use.operand_number);
        if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      } else if (use.instruction->opcode() == HloOpcode::kWhile) {
        // Both the condition and the body parameter must accept BF16, since
        // the while operand feeds both computations.
        auto* cond_parameter =
            use.instruction->while_condition()->parameter_instruction(
                use.operand_number);
        if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
          return false;
        }
        auto* body_parameter =
            use.instruction->while_body()->parameter_instruction(
                use.operand_number);
        if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      } else if (use.instruction->opcode() == HloOpcode::kConditional) {
        // Conditional operand i (i >= 1) is the argument of branch i-1;
        // operand 0 is the branch selector and never reaches here as F32 use
        // of this value at a BF16-eligible index.
        auto* cond_parameter =
            use.instruction->branch_computation(use.operand_number - 1)
                ->parameter_instruction(0);
        if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      }
      if (bfloat16_support_->EffectiveOperandPrecisionIsBF16(
              *use.instruction, use.operand_number)) {
        continue;
      }
      // If the op propagates precision and it outputs a BF16, then it's OK to
      // supply BF16 also as the input. In the backward pass, the users shapes
      // should have already been processed.
      if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
              *use.instruction, use.operand_number)) {
        if (use.instruction->opcode() == HloOpcode::kTuple ||
            (use.instruction->opcode() == HloOpcode::kAllReduce &&
             use.instruction->shape().IsTuple())) {
          // The tuple output index corresponding to this operand is
          // {operand_number} prepended to the operand's shape index.
          ShapeIndex use_output_index{use.operand_number};
          for (int64_t i : use.operand_index) {
            use_output_index.push_back(i);
          }
          if (OutputTypeAfterChange(use.instruction, use_output_index) ==
              BF16) {
            continue;
          }
        } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) {
          // GTE strips the leading tuple index: drop operand_index[0].
          ShapeIndex use_output_index;
          for (int64_t i = 1; i < use.operand_index.size(); ++i) {
            use_output_index.push_back(use.operand_index[i]);
          }
          if (OutputTypeAfterChange(use.instruction, use_output_index) ==
              BF16) {
            continue;
          }
        } else {
          if (OutputTypeAfterChange(use.instruction, use.operand_index) ==
              BF16) {
            continue;
          }
        }
      }
      return false;
    }
  }
  return true;
}
347 
ShouldKeepPrecisionUnchanged(const HloInstruction * inst)348 bool BFloat16Propagation::ShouldKeepPrecisionUnchanged(
349     const HloInstruction* inst) {
350   if (inst->opcode() == HloOpcode::kFusion &&
351       inst->fusion_kind() == HloInstruction::FusionKind::kCustom) {
352     return ShouldKeepPrecisionUnchanged(
353         inst->fused_instructions_computation()->root_instruction());
354   }
355   // Do not change precision for side-effecting instructions, control flow, and
356   // bitcast-convert, because this pass might break the interfaces or
357   // assumptions for them.
358   return inst->opcode() == HloOpcode::kCustomCall ||
359          inst->opcode() == HloOpcode::kCall ||
360          inst->opcode() == HloOpcode::kBitcastConvert ||
361          inst->HasSideEffectNoRecurse();
362 }
363 
// Backward-pass analysis of a single instruction: decides whether each F32
// subshape of `hlo`'s output can be changed to BF16 (when all users consume
// BF16), and then recurses into any called computations (fusion/while/
// conditional) via the cleanup object, unless recursion must be postponed
// because a called computation has multiple callers.
void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
                                                        bool skip_parameters) {
  // We handle any fusion computation, while body/condition or conditional
  // branches after the instruction is handled, because we need to know the
  // output shape of a fusion or while before propagating inside its
  // computations.
  bool postpone_processing_called_computations = false;
  // Runs on every exit path (including the early returns below): process
  // called computations if not postponed, and mark `hlo` visited.
  auto cleaner = tensorflow::gtl::MakeCleanup(
      [this, hlo, &postpone_processing_called_computations] {
        if (!postpone_processing_called_computations) {
          if (hlo->opcode() == HloOpcode::kFusion) {
            DetermineFusionComputationPrecision(hlo);
          } else if (hlo->opcode() == HloOpcode::kWhile) {
            DetermineWhileComputationsPrecision(hlo);
          } else if (hlo->opcode() == HloOpcode::kConditional) {
            DetermineConditionalComputationsPrecision(hlo);
          }
        }
        instructions_visited_in_backward_pass_.insert(hlo);
      });

  // A while whose condition/body is shared with other callers cannot be
  // specialized per this call site; postpone.
  if (hlo->opcode() == HloOpcode::kWhile &&
      (caller_counts_[hlo->while_condition()] > 1 ||
       caller_counts_[hlo->while_body()] > 1)) {
    postpone_processing_called_computations = true;
    return;
  }

  // Same for conditionals with any multiply-called branch computation.
  if (hlo->opcode() == HloOpcode::kConditional &&
      absl::c_any_of(hlo->branch_computations(), [&](const HloComputation* c) {
        return caller_counts_[c] > 1;
      })) {
    postpone_processing_called_computations = true;
    return;
  }

  // Prevent root instructions from having their output modified by recording
  // all F32 output values as needing to stay as F32.
  CHECK(hlo->parent() != nullptr);
  if (hlo == hlo->parent()->root_instruction()) {
    // Fusion roots are exempt: their precision is derived from the fusion
    // node's output in DetermineFusionComputationPrecision.
    if (!hlo->parent()->IsFusionComputation()) {
      ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */,
                                                   const ShapeIndex& index) {
        if (OutputTypeAfterChange(hlo, index) != F32) {
          return;
        }
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
          // Since we use HloValues from the dataflow analysis, this can also
          // affect HLO instructions beyond the root, e.g., if the root is a
          // Tuple HLO, then its operands are also affected.
          values_that_must_be_kept_as_f32_.insert(value);
        }
      });
    }
    return;
  }

  // Instructions whose precision must not be touched (control flow,
  // side effects, bitcast-convert), and parameters when the caller asks to
  // skip them, are left unchanged.
  if (ShouldKeepPrecisionUnchanged(hlo) ||
      (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
    return;
  }

  // Only instructions pre-selected as BF16 candidates are considered.
  if (!ContainsKey(consider_using_bfloat16_, hlo)) {
    return;
  }

  if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
    return;
  }

  // Change each F32 output subshape to BF16 when every user can consume BF16.
  ShapeUtil::ForEachSubshape(
      hlo->shape(),
      [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) {
        if (OutputTypeAfterChange(hlo, index) == F32 &&
            AllUsersConsumeBF16(*hlo, index)) {
          AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16);
          VLOG(2) << "HloInstruction output at shape index " << index
                  << " changed to BF16 precision: " << hlo->ToString();
        }
      });
}
445 
InstructionIsCandidateForBF16Output(HloInstruction * hlo)446 bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
447     HloInstruction* hlo) {
448   if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) &&
449       hlo->opcode() != HloOpcode::kTuple &&
450       hlo->opcode() != HloOpcode::kGetTupleElement &&
451       hlo->opcode() != HloOpcode::kDomain &&
452       hlo->shape().element_type() != BF16) {
453     for (int64_t i = 0; i < hlo->operand_count(); ++i) {
454       if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
455                                                                          i) ||
456           !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) {
457         return false;
458       }
459     }
460   }
461   return true;
462 }
463 
AdjustCalledComputationParameters(HloInstruction * hlo)464 void BFloat16Propagation::AdjustCalledComputationParameters(
465     HloInstruction* hlo) {
466   auto adjust_computation =
467       [this, hlo](HloComputation* computation,
468                   absl::Span<HloInstruction* const> operands) {
469         // Adjust parameters.
470         CHECK_EQ(operands.size(), computation->num_parameters());
471         for (int64_t i = 0; i < operands.size(); ++i) {
472           auto parameter = computation->parameter_instruction(i);
473           ShapeUtil::ForEachSubshape(
474               parameter->shape(),
475               [this, i, hlo, &operands, parameter](const Shape& /* subshape */,
476                                                    const ShapeIndex& index) {
477                 if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
478                   return;
479                 }
480                 PrimitiveType operand_type =
481                     OutputTypeAfterChange(operands[i], index);
482                 if (OutputTypeAfterChange(parameter, index) == operand_type) {
483                   return;
484                 }
485                 AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type);
486                 VLOG(2) << "Called computation parameter "
487                         << parameter->ToString() << " at shape index " << index
488                         << " adjusted to "
489                         << (operand_type == BF16 ? "BF16" : "F32")
490                         << " to match operand in HLO " << hlo->ToString();
491               });
492         }
493       };
494 
495   switch (hlo->opcode()) {
496     case HloOpcode::kFusion:
497       adjust_computation(hlo->fused_instructions_computation(),
498                          hlo->operands());
499       break;
500     case HloOpcode::kWhile:
501       adjust_computation(hlo->while_condition(), hlo->operands());
502       adjust_computation(hlo->while_body(), hlo->operands());
503       break;
504     case HloOpcode::kConditional:
505       for (int64_t i = 0; i < hlo->branch_count(); ++i) {
506         adjust_computation(hlo->branch_computation(i),
507                            {hlo->mutable_operand(i + 1)});
508       }
509       break;
510     default:
511       break;
512   }
513 }
514 
// Makes the root of each computation called by `hlo` (fusion computation,
// while body, conditional branches) match `hlo`'s (possibly changed) output
// precision at every leaf shape index.
void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
  auto adjust_computation = [this, hlo](HloComputation* computation,
                                        HloInstruction* output) {
    // Adjust root.
    HloInstruction* root = computation->root_instruction();
    ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output](
                                                  const Shape& /* subshape */,
                                                  const ShapeIndex& index) {
      // NOTE(review): the leaf check is against hlo->shape() while the
      // traversal is over root->shape(); presumably these shapes are
      // structurally identical for fusion/while/conditional callees — confirm.
      if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
        return;
      }
      const PrimitiveType output_type = OutputTypeAfterChange(output, index);
      if (OutputTypeAfterChange(root, index) == output_type) {
        return;
      }
      AddToOrRemoveFromBF16ChangeSet(root, index, output_type);
      // It's possible that output_type is F32, but the root instruction's
      // type is BF16; e.g., a fusion node's output was changed to BF16
      // initially but then adjusted back to F32, and the fusion computation
      // is now being adjusted after the fusion node.
      if (output_type == F32) {
        for (const auto* value : dataflow_->GetValueSet(root, index).values()) {
          // We rely on the fact that this adjustment works in reverse
          // topological order so that called computation will be
          // processed later. Adding the value to
          // values_that_must_be_kept_as_f32_ will ensure the
          // correctness of the adjustment for HLOs that will be
          // processed later.
          values_that_must_be_kept_as_f32_.insert(value);
        }
      }
      VLOG(2) << "Called computation root " << root->ToString()
              << " at shape index " << index << " adjusted to "
              << (output_type == BF16 ? "BF16" : "F32")
              << " to match output shape of " << hlo->ToString();
    });
  };

  switch (hlo->opcode()) {
    case HloOpcode::kFusion:
      adjust_computation(hlo->fused_instructions_computation(), hlo);
      break;
    case HloOpcode::kWhile:
      // Only the body root is adjusted; the condition root is a predicate and
      // does not alias the while's output.
      adjust_computation(hlo->while_body(), hlo);
      break;
    case HloOpcode::kConditional:
      for (auto* branch : hlo->branch_computations()) {
        adjust_computation(branch, hlo);
      }
      break;
    default:
      break;
  }
}
569 
// Fixed-point pass over one computation: forces every set of aliasing buffers
// to agree on a single precision. A subshape stays BF16 only if all aliased
// dataflow values (including in-place input/output pairs) are BF16 and all
// users still consume BF16; otherwise it is adjusted back to F32. Returns
// true iff a parameter of `computation` was changed (used by callers to drive
// the while-loop fixed point). Visited computations are added to
// `visited_computations`.
bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
    HloComputation* computation,
    absl::flat_hash_set<const HloComputation*>* visited_computations) {
  bool parameter_changed = false;
  auto insts = computation->MakeInstructionPostOrder();
  // Do the adjustment on each instruction in the computation in reverse
  // topological order.
  while (true) {
    bool any_change = false;
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      auto hlo = *inst_it;
      // Decides the precision for hlo's output at `index` and records any
      // change; sets any_change/parameter_changed as needed.
      auto adjust_hlo_output = [&](const Shape& /* subshape */,
                                   const ShapeIndex& index) {
        auto output_type = OutputTypeAfterChange(hlo, index);
        VLOG(2) << "output_type is " << ((output_type == BF16) ? "BF16" : "F32")
                << " for :" << hlo->ToString() << "\n";
        if (output_type != F32 && output_type != BF16) {
          return;
        }
        // Start optimistic (BF16) and downgrade to F32 if any aliased value
        // or user requires it.
        PrimitiveType type = BF16;
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
          auto value_type = ValueTypeAfterChange(value);
          if (value_type == BF16) {
            continue;
          }
          VLOG(2) << "Adjust to F32 due to aliased dataflow value: "
                  << value->ToString() << "\n";
          CHECK_EQ(value_type, F32);
          type = F32;
          break;
        }
        // In order to find aliases due to in-place operations, use
        // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
        // but this code works with HloModules that aren't ready yet to use
        // HloAliasAnalysis (e.g., their computation graphs may not have been
        // flattened yet).
        for (const auto& operand_and_output_index :
             HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
          if (operand_and_output_index.second == index) {
            const HloUse& operand = operand_and_output_index.first;
            for (const auto* value :
                 dataflow_
                     ->GetValueSet(hlo->operand(operand.operand_number),
                                   operand.operand_index)
                     .values()) {
              auto value_type = ValueTypeAfterChange(value);
              if (value_type == BF16) {
                continue;
              }
              VLOG(2) << "Adjust to F32 due to InputOutPair: "
                      << value->ToString() << "\n";
              CHECK_EQ(value_type, F32);
              type = F32;
              break;
            }
          }
        }

        // It's possible that a user has been changed from BF16 to F32
        // during this final adjustment pass, so we need to check
        // AllUsersConsumeBF16() again.
        if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
          VLOG(2) << "Adjust to F32 due to All user consumeBF16 fail\n";
          type = F32;
        }
        if (type == F32) {
          for (const auto* value :
               dataflow_->GetValueSet(hlo, index).values()) {
            // We rely on the fact that this adjustment works in reverse
            // topological order. Adding the value to
            // values_that_must_be_kept_as_f32_ will ensure the correctness
            // of the adjustment for HLOs that will be processed later.
            values_that_must_be_kept_as_f32_.insert(value);
          }
        }
        if (type != output_type) {
          any_change = true;
          AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
          VLOG(2) << "HloInstruction output at shape index " << index
                  << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
                  << hlo->ToString();
          if (hlo->opcode() == HloOpcode::kParameter) {
            parameter_changed = true;
          }
        }
      };
      ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
      AdjustCalledComputationRoot(hlo);
      if (hlo->opcode() == HloOpcode::kWhile) {
        // We need to run on the while body and condition repeatedly until a
        // fixed point is reached, i.e., the parameters do not change any more.
        // We may need more than one iteration because the while input and
        // output alias each other, so changing one input parameter requires
        // changing the corresponding output element and thus may transitively
        // require changing another input parameter. A fixed point will be
        // reached because the parameters can only be changed from BF16 to F32,
        // not the other way around.
        absl::flat_hash_set<const HloComputation*> visited_in_while;
        while (ResolveInconsistencyOfAliasingBuffersHelper(
                   hlo->while_condition(), &visited_in_while) ||
               ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
                                                           &visited_in_while)) {
          visited_in_while.clear();
          ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
          AdjustCalledComputationRoot(hlo);
        }
        visited_computations->insert(visited_in_while.begin(),
                                     visited_in_while.end());
      } else if (hlo->opcode() == HloOpcode::kFusion) {
        ResolveInconsistencyOfAliasingBuffersHelper(
            hlo->fused_instructions_computation(), visited_computations);
      } else if (hlo->opcode() == HloOpcode::kConditional) {
        for (auto* branch : hlo->branch_computations()) {
          ResolveInconsistencyOfAliasingBuffersHelper(branch,
                                                      visited_computations);
        }
      }
    }
    if (!any_change) {
      break;
    }
  }
  // Now adjust parameters of called computations.
  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
    AdjustCalledComputationParameters(*inst_it);
  }
  return parameter_changed;
}
698 
ResolveInconsistencyOfAliasingBuffers(HloModule * module)699 void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
700     HloModule* module) {
701   const auto& computations_topological_order =
702       module->MakeComputationPostOrder();
703   absl::flat_hash_set<const HloComputation*> resolved;
704   for (auto comp_it = computations_topological_order.rbegin();
705        comp_it != computations_topological_order.rend(); ++comp_it) {
706     if (ContainsKey(resolved, *comp_it)) {
707       continue;
708     }
709     ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
710   }
711 }
712 
ResolveInconsistentFusions(HloModule * module)713 Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
714   // We could have changed a fusion computation's root shape to have a different
715   // precision than the fusion node's output, if the fusion root does not
716   // define a buffer (e.g., a tuple). Now we add conversions after such fusion
717   // roots to make them match the fusion output. If the fusion output is a
718   // (possibly nested) tuple, we first create get-tuple-elements, then convert
719   // the unmatching leaf nodes, and finally create a new tuple as the fusion
720   // computation's root. If tuples and get-tuple-elements are created, we will
721   // run tuple simplifier and dead code elimination at the end (dead code is not
722   // allowed in fusion computation). E.g.,
723   //
724   // (1)             (2)             (3)
725   // a  b            a  b            a  b
726   // |\ |            |\ |            |\ |
727   // \ add   ->      |add    ->      | add
728   //  \ |            \ |        convert |
729   //  tuple         tuple             \ |
730   //                 / \              tuple
731   //               gte gte
732   //                |   |
733   //           convert  |
734   //                 \  /
735   //                 tuple
736   // (1) a is F32 but tuple is BF16
737   // (2) after adding conversion
738   // (3) after tuple simplifier and DCE.
739   for (auto computation : module->MakeComputationPostOrder()) {
740     auto insts = computation->MakeInstructionPostOrder();
741     for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
742       auto hlo = *inst_it;
743       if (hlo->opcode() != HloOpcode::kFusion) {
744         continue;
745       }
746       auto fusion_computation = hlo->fused_instructions_computation();
747       auto fusion_root = fusion_computation->root_instruction();
748       if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) {
749         continue;
750       }
751       ShapeTree<HloInstruction*> converted_outputs(hlo->shape());
752       // Deep copy the fusion root, and convert a leaf node only if its shape
753       // does not match the fusion output.
754       TF_ASSIGN_OR_RETURN(
755           HloInstruction * copy,
756           fusion_computation->DeepCopyInstructionWithCustomCopier(
757               fusion_root,
758               [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index,
759                     HloComputation* comp) {
760                 const Shape& hlo_subshape =
761                     ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
762                 if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) {
763                   return leaf;
764                 }
765                 return comp->AddInstruction(
766                     HloInstruction::CreateConvert(hlo_subshape, leaf));
767               }));
768       fusion_computation->set_root_instruction(copy);
769     }
770   }
771   return Status::OK();
772 }
773 
ResolveConvertedConstants(HloModule * module)774 Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
775   // We may have converted some constants from F32 to BF16, so adjust the
776   // constant literals in such cases. We do this here instead of when the
777   // constant node's is changed because 1) the HloInstruction interface does not
778   // allow resetting the literal so we have to create a new kConstant
779   // instruction to replace the old one, which invalidates dataflow analysis,
780   // and 2) it's possible that a kConstant's output gets changed to BF16 at the
781   // beginning but later on adjusted back to F32, so converting literals here
782   // can avoid repeated conversions.
783   //
784   // TODO(b/73833576): Consider resetting literal in HloInstruction.
785   for (auto computation : module->MakeComputationPostOrder()) {
786     for (auto hlo : computation->MakeInstructionPostOrder()) {
787       if (hlo->opcode() != HloOpcode::kConstant) {
788         continue;
789       }
790       if (!Shape::Equal().MinorToMajorOnlyInLayout()(hlo->literal().shape(),
791                                                      hlo->shape())) {
792         TF_ASSIGN_OR_RETURN(auto converted_literal,
793                             hlo->literal().ConvertToShape(hlo->shape()));
794         auto new_constant = computation->AddInstruction(
795             HloInstruction::CreateConstant(std::move(converted_literal)));
796         UpdateLayout(new_constant->mutable_shape());
797         TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
798       }
799     }
800   }
801   return Status::OK();
802 }
803 
SkipNoopConversions(HloModule * module)804 Status BFloat16Propagation::SkipNoopConversions(HloModule* module) {
805   for (auto computation : module->computations()) {
806     for (auto hlo : computation->MakeInstructionPostOrder()) {
807       if (hlo->opcode() != HloOpcode::kConvert) {
808         continue;
809       }
810       auto source = hlo->mutable_operand(0);
811       if (!ShapeUtil::Equal(source->shape(), hlo->shape())) {
812         continue;
813       }
814       const bool is_root = hlo == computation->root_instruction();
815       TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source));
816       if (is_root) {
817         computation->set_root_instruction(source);
818       }
819     }
820   }
821   return Status::OK();
822 }
823 
// The algorithm first does a forward pass (parameters to root) to determine a
// set of instructions to consider using bfloat16, then does a backward pass to
// determine the precisions of those instructions according to the need of
// their users. During the backward pass, the potential changes are stored in
// changes_to_bf16_ which are subject to further adjustments then applied to the
// HLOs.
//
// Returns true iff any HLO output type was actually changed to BF16.
StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
  // Reset all per-run state so the pass object can be reused across modules.
  consider_using_bfloat16_.clear();
  instructions_visited_in_backward_pass_.clear();
  computations_visited_in_backward_pass_.clear();
  values_that_must_be_kept_as_f32_.clear();
  caller_counts_.clear();
  changes_to_bf16_.clear();
  changed_ = false;

  auto computations_topological_order = module->MakeComputationPostOrder();

  // Before running the propagation pass, we insert copies (kConvert to the same
  // type) of F32 inputs to while loops. This prevents other uses of the same
  // input from aliasing the while loop input/output, so that there's greater
  // chance to use BF16 inside the loop. If some of these added copies do not
  // help, they will remain F32 after BF16 propagation and will be removed since
  // they are no-ops.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (inst->opcode() != HloOpcode::kWhile) {
        continue;
      }

      auto operand = inst->mutable_operand(0);
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          computation->DeepCopyInstructionWithCustomCopier(
              operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* comp) {
                // Only F32 leaves get a same-type convert; other types keep
                // aliasing the original operand.
                if (leaf->shape().element_type() != F32) {
                  return leaf;
                }
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(leaf->shape(), leaf));
              }));
      // Only this while's use is redirected to the copy; other users of the
      // operand keep the original, breaking the aliasing.
      TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy));
    }
  }

  // Dataflow analysis must run AFTER the copies above so it reflects the
  // de-aliased graph; the result is used throughout the backward pass.
  TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));

  // The first step is a forward pass (parameters to root), where we determine
  // the potential candidate instructions to use bfloat16 in the outputs that
  // are not likely to cause overhead from extra explicit conversions. This is
  // done forwardly because we determine whether an HLO is a candidate partially
  // based on whether its operands are candidates.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (InstructionIsCandidateForBF16Output(inst)) {
        consider_using_bfloat16_.insert(inst);
      }
    }
  }

  // The second step is a backward pass (root to parameters), where we modify
  // the precisions of the instructions identified in the first step when
  // feasible. This is done backwardly because we determine the precision of an
  // HLO's output based on how it is later used.
  //
  // The precision of an instruction is determined by its users, so we do the
  // propagation in reverse topological order.
  for (auto comp_it = computations_topological_order.rbegin();
       comp_it != computations_topological_order.rend(); ++comp_it) {
    if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) {
      continue;
    }
    auto insts = (*comp_it)->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      DetermineInstructionPrecision(*inst_it,
                                    /*skip_parameters=*/true);
    }
    computations_visited_in_backward_pass_.insert(*comp_it);
  }

  // It's possible that an instruction does not define a buffer, but the
  // defining instruction's shape has changed. So we need to adjust the output
  // shapes of instructions according to the HLO values they refer to.
  ResolveInconsistencyOfAliasingBuffers(module);

  // Apply the changes in changes_to_bf16_.
  for (auto& change : changes_to_bf16_) {
    auto inst = change.first;
    // It is possible that we marked inst to change precision even if it is an
    // unsupported change, when inst is the root of a fusion computation and it
    // has to match the fusion node's output precision. We do a convert instead
    // of in-place change for such cases.
    if (ShouldKeepPrecisionUnchanged(inst)) {
      // Snapshot users before rewiring: DeepCopy adds new users of inst.
      auto users = inst->users();
      bool is_root = inst == inst->parent()->root_instruction();
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          inst->parent()->DeepCopyInstructionWithCustomCopier(
              inst, [&](HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* comp) {
                // Convert only the leaves that were marked for BF16 in
                // change.second; other leaves alias the original.
                if (!ContainsKey(change.second,
                                 ShapeUtil::GetMutableSubshape(
                                     inst->mutable_shape(), leaf_index))) {
                  return leaf;
                }
                auto converted_shape =
                    ShapeUtil::ChangeElementType(leaf->shape(), BF16);
                UpdateLayout(&converted_shape);
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(converted_shape, leaf));
              }));
      for (auto user : users) {
        TF_RETURN_IF_ERROR(inst->ReplaceUseWithDifferentShape(user, copy));
      }
      if (is_root) {
        inst->parent()->set_root_instruction(copy,
                                             /*accept_different_shape=*/true);
      }
      continue;
    }
    // Supported in-place change: flip each recorded F32 subshape to BF16.
    for (const auto& entry : change.second) {
      auto subshape = entry.first;
      CHECK_EQ(subshape->element_type(), F32);
      subshape->set_element_type(BF16);
      UpdateLayout(subshape);
      changed_ = true;
    }
  }

  // Removes redundant HLOs added by this pass, either when inserting
  // de-aliasing copies to while loop inputs, or later when converting output
  // types.
  auto clean_up = [this, module]() {
    TF_RETURN_IF_ERROR(SkipNoopConversions(module));
    TupleSimplifier tuple_simplifier;
    TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
    HloDCE dce;
    TF_RETURN_IF_ERROR(dce.Run(module).status());
    return Status::OK();
  };

  if (!changed_) {
    // Nothing flipped to BF16; still clean up the de-aliasing copies inserted
    // at the top of this function.
    TF_RETURN_IF_ERROR(clean_up());
    return false;
  }

  TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module));
  TF_RETURN_IF_ERROR(ResolveConvertedConstants(module));

  TF_RETURN_IF_ERROR(clean_up());
  return true;
}
976 
OutputTypeAfterChange(HloInstruction * hlo,const ShapeIndex & index) const977 PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
978     HloInstruction* hlo, const ShapeIndex& index) const {
979   Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index);
980   const PrimitiveType type_on_hlo = subshape->element_type();
981   if (type_on_hlo != F32) {
982     return type_on_hlo;
983   }
984   auto it = changes_to_bf16_.find(hlo);
985   if (it == changes_to_bf16_.end()) {
986     return type_on_hlo;
987   }
988   return ContainsKey(it->second, subshape) ? BF16 : F32;
989 }
990 
ValueTypeAfterChange(const HloValue * value) const991 PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
992     const HloValue* value) const {
993   auto hlo = value->defining_instruction();
994   const auto& position = value->defining_position();
995   return OutputTypeAfterChange(hlo, position.index);
996 }
997 
AddToOrRemoveFromBF16ChangeSet(HloInstruction * hlo,const ShapeIndex & index,PrimitiveType target_type)998 void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
999     HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
1000   if (target_type == BF16) {
1001     auto& entry = changes_to_bf16_[hlo];
1002     entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index),
1003                   index);
1004   } else {
1005     CHECK_EQ(target_type, F32);
1006     auto it = changes_to_bf16_.find(hlo);
1007     if (it == changes_to_bf16_.end()) {
1008       return;
1009     }
1010     it->second.erase(
1011         ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index));
1012   }
1013 }
1014 
1015 }  // namespace xla
1016