/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/lib/hash/hash.h"

namespace xla {

// HLO pass which reduces the precision of some HLO instructions to BF16
// according to the backend-specific BFloat16Support rule provided by the
// caller.
//
// This pass can be used to reduce instruction precision without affecting the
// numerical accuracy of the module, i.e., the final output of the module would
// be bitwise identical to that without this pass; this is possible if the
// backend already reduces precision to BF16 on some HLO instructions.
//
// This pass will not modify the signature of a computation, unless it is a
// fusion computation or its only caller is a while.
//
// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs,
// which has two issues:
//
// 1) It is not guaranteed to respect the passed-in BFloat16Support
// specification with regard to mixed precision, so the backend may not support
// an HLO that this pass leaves with mixed precision. To address this issue,
// run BFloat16Normalization with the same BFloat16Support after this pass.
//
// 2) In general, mixed precision may break the assumptions of some other HLO
// passes even if the specific backend supports the individual HLOs. Such
// assumptions include that there are no HLOs using mixed precision, or that the
// precision of an HLO's output is determined by its inputs. Therefore, this
// pass should be run at the end of the HLO optimization pipeline but before
// BFloat16ConversionFolding. If other passes are needed after this pass, run
// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this
// pass.
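//
// Example of running this pass in a pipeline (a sketch; assumes a
// backend-provided BFloat16Support implementation named bf16_support and the
// standard HloPassPipeline):
//
//   HloPassPipeline pipeline("bf16");
//   pipeline.AddPass<BFloat16Propagation>(&bf16_support);
//   // Clean up any unsupported mixed precision introduced above.
//   pipeline.AddPass<BFloat16Normalization>(&bf16_support);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));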
class BFloat16Propagation : public HloModulePass {
 public:
  explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);

  ~BFloat16Propagation() override = default;

  absl::string_view name() const override { return "bfloat16-propagation"; }

  // Runs the pass on the given module. Returns whether the module was changed
  // (precision reductions were added).
  StatusOr<bool> Run(HloModule* module) override;

  // Returns whether we should avoid changing the precision of inst regardless
  // of the producers and users.
  virtual bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst);

  // Determines whether we should consider changing the precision of the given
  // instruction in the forward pass.
  virtual bool InstructionIsCandidateForBF16Output(HloInstruction* hlo);
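
  // A backend may customize the two hooks above by subclassing. A hypothetical
  // sketch (the extra rule and the opcode chosen are purely illustrative):
  //
  //   class KeepSortInF32Propagation : public BFloat16Propagation {
  //    public:
  //     using BFloat16Propagation::BFloat16Propagation;
  //     bool ShouldKeepPrecisionUnchanged(
  //         const HloInstruction* inst) override {
  //       // Keep sort outputs in F32 in addition to the default rules.
  //       return inst->opcode() == HloOpcode::kSort ||
  //              BFloat16Propagation::ShouldKeepPrecisionUnchanged(inst);
  //     }
  //   };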

 private:
  // ***************************
  // Function called and state produced by the forward analysis pass (from
  // parameters to root) that determines the candidate HLOs to use BF16 outputs.

  // The set of instructions to consider using bfloat16, computed in the forward
  // pass.
  absl::flat_hash_set<const HloInstruction*> consider_using_bfloat16_;

  // ***************************
  // Functions called and state produced by the backward pass (from root to
  // parameters) that finds opportunities to use BF16.

  // Determines the precision for the given instruction in the
  // opportunity-finding pass.
  void DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters);

  // Special handling in the opportunity-finding pass for fusion computations.
  //
  // Precondition: hlo->opcode() == kFusion
  void DetermineFusionComputationPrecision(HloInstruction* fusion);

  // Reverts changes to BF16 that will not propagate outside a fusion
  // computation. This avoids the overhead of BF16 casts inside a fusion, which
  // would not save any memory bandwidth.
  //
  // Precondition: hlo->opcode() == kFusion
  void RevertIfFusionInternalBF16Changes(HloInstruction* fusion);

  // Special handling in the opportunity-finding pass for while computations.
  //
  // Precondition: hlo->opcode() == kWhile
  void DetermineWhileComputationsPrecision(HloInstruction* while_hlo);

  // Special handling in the opportunity-finding pass for conditional branches.
  //
  // Precondition: hlo->opcode() == kConditional
  void DetermineConditionalComputationsPrecision(HloInstruction* cond);

  // The set of HloInstructions that have been visited in the
  // opportunity-finding pass.
  absl::flat_hash_set<const HloInstruction*>
      instructions_visited_in_backward_pass_;

  // The set of HloComputations that have been visited in the
  // opportunity-finding pass.
  absl::flat_hash_set<const HloComputation*>
      computations_visited_in_backward_pass_;

  // ***************************
  // Functions called by the final inconsistency resolving pass.

  // Adjusts the output shapes of HloInstructions such that if two
  // HloInstructions have aliasing buffers in their outputs, they must have the
  // same precision.
  void ResolveInconsistencyOfAliasingBuffers(HloModule* module);

  // Resolves inconsistency of aliasing buffers for the given computation, and
  // recursively runs on a while instruction's condition and body until a fixed
  // point is reached.
  bool ResolveInconsistencyOfAliasingBuffersHelper(
      HloComputation* computation,
      absl::flat_hash_set<const HloComputation*>* visited_computations);

  // Makes the parameters of called computations match how they are called by
  // the given HLO.
  void AdjustCalledComputationParameters(HloInstruction* hlo);

  // Makes the root instructions of called computations match how they are used
  // by the given HLO.
  void AdjustCalledComputationRoot(HloInstruction* hlo);

  // ***************************
  // Functions called after changes in changes_to_bf16_ are applied.

  // Resolves inconsistencies introduced by this pass for fusions with
  // tuple-type output.
  Status ResolveInconsistentFusions(HloModule* module);

  // Converts the literals in kConstant HLOs whose types have been changed to
  // BF16 by this pass.
  Status ResolveConvertedConstants(HloModule* module);

  // Skips no-op conversions (same source and target shapes) that can be
  // produced by this pass, i.e., replaces them in their uses with their
  // operands.
  Status SkipNoopConversions(HloModule* module);

  // ***************************
  // Functions called and state used by two or more passes.

  // Returns whether all uses of the given HloInstruction can consume BF16
  // input.
  bool AllUsersConsumeBF16(const HloInstruction& hlo,
                           const ShapeIndex& index) const;

  // The output element type of the HLO at the given shape index after changes
  // in changes_to_bf16_ are applied.
  PrimitiveType OutputTypeAfterChange(HloInstruction* hlo,
                                      const ShapeIndex& index) const;

  // The element type of the HLO value after changes in changes_to_bf16_ are
  // applied.
  PrimitiveType ValueTypeAfterChange(const HloValue* value) const;

  // If target_type == BF16, adds the HLO at the given index to
  // changes_to_bf16_; otherwise, target_type must be F32 and this function
  // removes the HLO at the given index from changes_to_bf16_ if it was earlier
  // added.
  void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo,
                                      const ShapeIndex& index,
                                      PrimitiveType target_type);

  // The set of F32 HLO values that must be kept in F32.
  absl::flat_hash_set<const HloValue*> values_that_must_be_kept_as_f32_;

  // Mapping from each HloComputation to the number of callers to it in the
  // module. Populated at the beginning of this pass.
  absl::flat_hash_map<const HloComputation*, int64> caller_counts_;

  // We first record the potential F32-to-BF16 changes in changes_to_bf16_,
  // which are subject to further adjustment, and only then apply them to the
  // HLOs. This avoids setting changed_ to true when all changes end up being
  // reverted during adjustment.
  //
  // For each HloInstruction, changes_to_bf16_ stores the affected buffers in
  // the output as a map from in-place pointers to subshapes to shape indices.
  absl::flat_hash_map<HloInstruction*, absl::flat_hash_map<Shape*, ShapeIndex>>
      changes_to_bf16_;
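
  // For illustration, each inner (subshape pointer, shape index) entry is
  // meant to be applied by switching the element type of that subshape of the
  // instruction's output shape from F32 to BF16; the actual application logic
  // lives in the .cc file.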

  // Whether the last processed HLO module has been changed by this pass.
  bool changed_ = false;

  const BFloat16Support* bfloat16_support_;
  std::unique_ptr<HloDataflowAnalysis> dataflow_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_