• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
18 
19 #include <functional>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "tensorflow/compiler/xla/service/fusion_queue.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
30 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
31 
32 namespace xla {
33 
34 struct NoFusionPossible;
35 
36 // Propagating explanation of fusion decisions: if something could not be fused,
37 // explain the reason.
38 class FusionDecision {
39  public:
40   // Can not be fused: explain why. Implicit conversion due to optional-like
41   // semantics: waiver granted in cl/419938611.
FusionDecision(absl::string_view explanation)42   FusionDecision(absl::string_view explanation)  // NOLINT
43       : explanation_(explanation) {}
44 
45   // Same constructor as string_view, to allow implicit string conversion (can't
46   // implicitly convert both char* to string_view and string_view to
47   // FusionDecision).
FusionDecision(const char * explanation)48   FusionDecision(const char* explanation)  // NOLINT
49       : explanation_(explanation) {}
50 
51   // If condition is `true` means that we CAN fuse. In that case, explanation is
52   // discarded.
FusionDecision(bool condition,absl::string_view explanation)53   FusionDecision(bool condition, absl::string_view explanation) {
54     if (!condition) {
55       explanation_ = std::string(explanation);
56     }
57   }
58 
59   // Can be fused.
FusionDecision()60   FusionDecision() {}
61 
62   // A trick to declare and test fusion decision in a single statement (as TF
63   // is still on C++14 and can't use if statement with explicit initializer).
64   //
65   // Cf. NoFusionPossible definition for sample usage.
66   // TODO(b/157309856): Use conditional initializer instead.
67   NoFusionPossible operator!();
68 
69   // Returns whether it can be fused.
70   explicit operator bool() const { return CanFuse(); }
71 
72   // Whether the fusion decision is positive.
CanFuse()73   bool CanFuse() const { return !explanation_.has_value(); }
74 
75   // Connects two decisions with a disjunction. This is different than just
76   // picking one, as we also have to propagate both explanations if only one of
77   // them is false to show why fusion wasn't performed.
Or(const FusionDecision & decision)78   FusionDecision Or(const FusionDecision& decision) {
79     if (CanFuse() || decision.CanFuse()) {
80       return {};
81     }
82     return {absl::StrCat(explanation_.value_or(""), " ; ", decision.Explain())};
83   }
84 
85   // Connects two fusion decision with a conjunction. Unlike disjunction,
86   // propagates only one explanation (as it is enough to show that fusion could
87   // not be done).
And(const FusionDecision & decision)88   FusionDecision And(const FusionDecision& decision) {
89     if (CanFuse()) {
90       return decision;
91     }
92     if (decision.CanFuse()) {
93       return *this;
94     }
95     // Both conditions were violated: returning either is valid.
96     return *this;
97   }
98 
99   // Appends to explanation, or turns the decision negative.
100   FusionDecision operator<<(absl::string_view explanation) {
101     return {absl::StrCat(explanation_.value_or(""), explanation)};
102   }
103 
104   // Appends to explanation, or turns the decision negative.
105   FusionDecision operator<<(int64_t explanation) {
106     return {absl::StrCat(explanation_.value_or(""), explanation)};
107   }
108 
109   // Explains why the fusion could not be performed.
Explain()110   std::string Explain() const { return *explanation_; }
111 
112  private:
113   // Empty IFF fusion is possible (explanation provided for negative cases).
114   std::optional<std::string> explanation_;
115 };
116 
117 // Helper class: contextually convertible to "no fusion possible" unlike
118 // FusionDecision. This is a trick to declare and test fusion decision in a
119 // single statement (as we are still on C++14).
120 //
121 // Sample usage:
122 //
123 // if (NoFusionPossible fusible = !FusabilityRestriction(producer, consume)) {
124 //   return !fusible; // Note that negation converts it back to FusionDecision.
125 // }
126 struct NoFusionPossible {
127   // Inverts the test value (true <=> not fusible) on wrapped FusionDecision.
128   explicit operator bool() { return !static_cast<bool>(fusion_decision); }
129 
130   // Returns wrapped fusion decision.
131   FusionDecision operator!() { return fusion_decision; }
132 
133   FusionDecision fusion_decision;
134 };
135 
136 inline NoFusionPossible FusionDecision::operator!() { return {*this}; }
137 
138 // HLO pass which performs instruction fusion. Instructions are fused
139 // "vertically", meaning producing instructions are fused into their consumers
140 // with the intent that the loops which compute their values will be fused in
141 // code generation. Derived classes define ShouldFuse method to select which
142 // instructions to fuse.
143 class InstructionFusion : public HloModulePass {
144  public:
145   explicit InstructionFusion(
146       std::function<bool(const HloInstruction& instruction)> is_expensive,
147       bool may_duplicate = true,
148       FusionConfigCollection config_collection_mode =
149           FusionConfigCollection::kOff)
is_expensive_(is_expensive)150       : is_expensive_(is_expensive),
151         may_duplicate_(may_duplicate),
152         config_collection_mode_(config_collection_mode) {}
153   ~InstructionFusion() override = default;
name()154   absl::string_view name() const override { return "fusion"; }
155 
156   // Run instruction fusion on the given computation. Returns whether the
157   // computation was changed (instructions were fused).
158   using HloPassInterface::Run;
159   StatusOr<bool> Run(
160       HloModule* module,
161       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
162 
163   // Returns true if the computation of the given instruction is significantly
164   // more expensive than just writing all the values of the instructions' result
165   // array. Expensive operations will not be duplicated.
166   static bool IsExpensive(const HloInstruction& instruction);
167 
168   // Returns true if it's legal to fuse the producer instruction into consumer
169   // with regard to in-place semantics of the consumer. For example, it is
170   // illegal to fuse a slice into a dynamic-update-slice if the slice output is
171   // used as the update and if slice and dynamic-update-slice indices cannot be
172   // proven to be the same.
173   static FusionDecision ShouldFuseInPlaceOp(const HloInstruction* producer,
174                                             const HloInstruction* consumer);
175 
176  protected:
177   // Returns a list of computations on which Fusion is performed.
178   virtual std::vector<HloComputation*> GetFusionComputations(
179       HloModule* module,
180       const absl::flat_hash_set<absl::string_view>& execution_threads);
181 
182   // Returns a FusionQueue that implements custom order of instructions being
183   // fused. The default implementation processes consumers in reverse post
184   // order.
185   virtual std::unique_ptr<FusionQueue> GetFusionQueue(
186       HloComputation* computation);
187 
188   // Returns whether the given producer instruction should be fused into the
189   // given consumer instruction. producer is necessarily an operand of consumer.
190   // Derived classes should define this method to specify which instructions
191   // should be fused. `operand_index` is which operand of the consumer the
192   // producer is.
193   //
194   // Instructions are traversed in reverse post order (computation root to
195   // leaves). This method is called for each operand of the instruction (where
196   // the operand is 'producer' and the instruction is 'consumer')
197   //
198   // Subtypes can override this with target-specific heuristics.
199   virtual FusionDecision ShouldFuse(HloInstruction* consumer,
200                                     int64_t operand_index);
201 
202   // Returns whether multi-output fusion can be applied to fuse `producer` into
203   // `consumer`. In contrast to "regular" fusion, the `producer` is not
204   // duplicated by multi-output fusion.
ShouldFuseIntoMultiOutput(HloInstruction * consumer,int64_t operand_index)205   virtual FusionDecision ShouldFuseIntoMultiOutput(HloInstruction* consumer,
206                                                    int64_t operand_index) {
207     return "multi-output fusion not supported by this pass";
208   }
209 
210   // Chooses a fusion kind for `producer` and `consumer`.
211   // Default method chooses `kLoop`.
212   virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
213                                                 const HloInstruction* consumer);
214 
215   // Fuses 'producer' into 'fusion_instruction'. 'fusion_instruction' needs to
216   // be a fusion instruction. Returns the newly created clone of 'producer'
217   // which is part of the fusion computation.
218   virtual HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
219                                           HloInstruction* producer);
220 
221   // Fuses producer into consumer. Returns the fusion instruction.
222   virtual HloInstruction* Fuse(HloInstruction* producer,
223                                HloInstruction* consumer,
224                                HloComputation* computation);
225 
226   // Creates a new fusion instruction containing `producer` and `consumer`. A
227   // tuple is added as the fusion instruction's root, which consumes from both,
228   // `producer` and `consumer`. This style of fusion is referred to as
229   // multi-output fusion.
230   virtual HloInstruction* FuseIntoMultiOutput(HloInstruction* producer,
231                                               HloInstruction* consumer,
232                                               HloComputation* computation);
233 
234   // An "effectively unary" operation is one that has at most one "large"
235   // input with the others being negligible in terms of memory usage.
236   // We use "has a smaller true rank than the output" as a heuristic
237   // for "negligible" memory usage.
238   bool EffectivelyAtMostUnary(HloInstruction* hlo);
239 
240   // Returns true if fusing producer into consumer would cause producer to be
241   // duplicated. This is the case if producer has uses other than consumer.
FusionWouldDuplicate(const HloInstruction & producer,const HloInstruction & consumer)242   bool FusionWouldDuplicate(const HloInstruction& producer,
243                             const HloInstruction& consumer) {
244     return !(producer.users().size() == 1 && consumer.IsUserOf(&producer));
245   }
246 
is_expensive(const HloInstruction & instruction)247   bool is_expensive(const HloInstruction& instruction) {
248     return is_expensive_(instruction);
249   }
250 
251   // Overwrites the originally initialized is_expensive function.
set_is_expensive(std::function<bool (const HloInstruction & instruction)> is_expensive)252   void set_is_expensive(
253       std::function<bool(const HloInstruction& instruction)> is_expensive) {
254     is_expensive_ = is_expensive;
255   }
256 
257   // Whether multi-output fusion would introduce a cycle into the HLO graph.
258   bool MultiOutputFusionCreatesCycle(HloInstruction* producer,
259                                      HloInstruction* consumer,
260                                      const HloReachabilityMap& reachability);
261 
config_collection_mode()262   FusionConfigCollection config_collection_mode() {
263     return config_collection_mode_;
264   }
265 
266   // Returns whether 'consumer' may reuse elements of its `operand_index`th
267   // operand.
268   bool ReusesOperandElements(const HloInstruction* consumer,
269                              int64_t operand_index);
270 
271   // The set of producers whose consumers we cannot fuse into.
272   using HloInstructionSet = absl::flat_hash_set<HloInstruction*>;
273 
274   // Computes the set of nodes that we do not want to fuse into any of their
275   // consumers based on a global analysis of the HLO graph.
276   virtual HloInstructionSet ComputeGloballyUnfusible(
277       absl::Span<HloInstruction* const> post_order,
278       const HloReachabilityMap& reachability);
279 
280  private:
281   // Returns the reused operands of `instruction` from reused_fusion_operands_,
282   // computing them if they have not previously been computed for that
283   // instruction.
284   // The returned value has pointer stability, assuming entries are not deleted
285   // from reused_fusion_operands_.
286   absl::flat_hash_set<const HloInstruction*>& ReusedOperandsOf(
287       const HloInstruction* instruction);
288 
289   // Updates reused_fusion_operands_ for a fusion when we are about to fuse
290   // `producer` into `fusion_instruction`.
291   void UpdateReusedOperandsForFusion(HloInstruction* producer,
292                                      HloInstruction* fusion_instruction);
293 
294   HloInstruction* AddFusionInstruction(HloInstruction* producer,
295                                        HloInstruction* consumer,
296                                        HloComputation* computation);
297 
298   // Whether or not we can fuse producer into consumer on all paths
299   // from the producer to the consumer where nodes are HLOs and edges are uses.
300   //
301   // A map from <producer, consumer> to a bool is required as the result cache
302   // to store and query the results of calls to this function, in order to avoid
303   // repeated computations.
304   bool CanFuseOnAllPaths(
305       HloInstruction* producer, HloInstruction* consumer,
306       const HloInstructionSet& do_not_fuse,
307       const HloReachabilityMap& reachability,
308       absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
309           result_cache);
310 
311   // Used to determine if an HLO is expensive. Expensive operations will not be
312   // duplicated.
313   std::function<bool(const HloInstruction& instruction)> is_expensive_;
314 
315   // Returns whether we may duplicate an instruction if we want to fuse it.
316   bool may_duplicate_;
317 
318   // Configuration mode.
319   FusionConfigCollection config_collection_mode_;
320 
321   // Caches which operands are reused inside fusion computations.
322   absl::flat_hash_map<
323       const HloInstruction*,
324       std::unique_ptr<absl::flat_hash_set<const HloInstruction*>>>
325       reused_fusion_operands_;
326 
327   InstructionFusion(const InstructionFusion&) = delete;
328   InstructionFusion& operator=(const InstructionFusion&) = delete;
329 };
330 
331 }  // namespace xla
332 
333 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
334