• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include "absl/container/flat_hash_map.h"
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
16 
17 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
18 #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
19 
#include <functional>
#include <memory>
#include <unordered_set>
#include <utility>

#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/core/platform/macros.h"
27 
28 namespace xla {
29 
30 // HLO pass which performs instruction fusion. Instructions are fused
31 // "vertically", meaning producing instructions are fused into their consumers
32 // with the intent that the loops which compute their values will be fused in
33 // code generation. Derived classes define ShouldFuse method to select which
34 // instructions to fuse.
35 class InstructionFusion : public HloModulePass {
36  public:
37   explicit InstructionFusion(
38       std::function<bool(const HloInstruction& instruction)> is_expensive,
39       bool may_duplicate = true)
is_expensive_(is_expensive)40       : is_expensive_(is_expensive), may_duplicate_(may_duplicate) {}
41   ~InstructionFusion() override = default;
name()42   absl::string_view name() const override { return "fusion"; }
43 
44   // Run instruction fusion on the given computation. Returns whether the
45   // computation was changed (instructions were fused).
46   StatusOr<bool> Run(HloModule* module) override;
47 
48   // Returns true if the computation of the given instruction is significantly
49   // more expensive than just writing all the values of the instructions' result
50   // array. Expensive operations will not be duplicated.
51   static bool IsExpensive(const HloInstruction& instruction);
52 
53  protected:
54   // Returns a FusionQueue that implements custom order of instructions being
55   // fused. The default implementation processes consumers in reverse post
56   // order.
57   virtual std::unique_ptr<FusionQueue> GetFusionQueue(
58       HloComputation* computation);
59 
60   // Returns whether the given producer instruction should be fused into the
61   // given consumer instruction. producer is necessarily an operand of consumer.
62   // Derived classes should define this method to specify which instructions
63   // should be fused. `operand_index` is which operand of the consumer the
64   // producer is.
65   //
66   // Instructions are traversed in reverse post order (computation root to
67   // leaves). This method is called for each operand of the instruction (where
68   // the operand is 'producer' and the instruction is 'consumer')
69   //
70   // Subtypes can override this with target-specific heuristics.
71   virtual bool ShouldFuse(HloInstruction* consumer, int64 operand_index);
72 
73   // Returns whether multi-output fusion can be applied to fuse `producer` into
74   // `consumer`. In contrast to "regular" fusion, the `producer` is not
75   // duplicated by multi-output fusion.
ShouldFuseIntoMultiOutput(HloInstruction * consumer,int64 operand_index)76   virtual bool ShouldFuseIntoMultiOutput(HloInstruction* consumer,
77                                          int64 operand_index) {
78     return false;
79   }
80 
81   // Chooses a fusion kind for `producer` and `consumer`.
82   // Default method chooses `kLoop`.
83   virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
84                                                 const HloInstruction* consumer);
85 
86   // Fuses producer into consumer.
87   virtual HloInstruction* Fuse(HloInstruction* producer,
88                                HloInstruction* consumer);
89 
90   // Creates a new fusion instruction containing `producer` and `consumer`. A
91   // tuple is added as the fusion instruction's root, which consumes from both,
92   // `producer` and `consumer`. This style of fusion is referred to as
93   // multi-output fusion.
94   virtual HloInstruction* FuseIntoMultiOutput(HloInstruction* producer,
95                                               HloInstruction* consumer);
96 
97   // An "effectively unary" operation is one that has at most one "large"
98   // input with the others being negligible in terms of memory usage.
99   // We use "has a smaller true rank than the output" as a heuristic
100   // for "negligible" memory usage.
101   bool EffectivelyAtMostUnary(HloInstruction* hlo);
102 
103   // Returns true if fusing producer into consumer would cause producer to be
104   // duplicated. This is the case if producer has uses other than consumer.
FusionWouldDuplicate(const HloInstruction & producer,const HloInstruction & consumer)105   bool FusionWouldDuplicate(const HloInstruction& producer,
106                             const HloInstruction& consumer) {
107     return !(producer.users().size() == 1 && consumer.IsUserOf(&producer));
108   }
109 
is_expensive(const HloInstruction & instruction)110   bool is_expensive(const HloInstruction& instruction) {
111     return is_expensive_(instruction);
112   }
113 
114   // Whether multi-output fusion would introduce a cycle into the HLO graph.
115   bool MultiOutputFusionCreatesCycle(HloInstruction* producer,
116                                      HloInstruction* consumer);
117 
118   // Current HloComputation instance the loop fuser is traversing.
119   HloComputation* computation_;
120   HloModule* module_;
121   // Reachability information for the current computation.
122   std::unique_ptr<HloReachabilityMap> reachability_;
123 
124  private:
125   // The set of producers whose consumers we cannot fuse into.
126   using HloInstructionSet = std::unordered_set<HloInstruction*>;
127 
128   HloInstruction* AddFusionInstruction(HloInstruction* producer,
129                                        HloInstruction* consumer);
130 
131   // Whether or not we can fuse producer into consumer on all paths
132   // from the producer to the consumer where nodes are HLOs and edges are uses.
133   //
134   // A map from <producer, consumer> to a bool is required as the result cache
135   // to store and query the results of calls to this function, in order to avoid
136   // repeated computations.
137   bool CanFuseOnAllPaths(
138       HloInstruction* producer, HloInstruction* consumer,
139       const HloInstructionSet& do_not_fuse,
140       absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
141           result_cache);
142 
143   // Computes the set of nodes that we do not want to fuse into any of their
144   // consumers based on a global analysis of the HLO graph.
145   HloInstructionSet ComputeGloballyUnfusible(
146       absl::Span<HloInstruction* const> post_order);
147 
148   // Used to determine if an HLO is expensive. Expensive operations will not be
149   // duplicated.
150   std::function<bool(const HloInstruction& instruction)> is_expensive_;
151 
152   // Returns whether we may duplicate an instruction if we want to fuse it.
153   bool may_duplicate_;
154 
155   TF_DISALLOW_COPY_AND_ASSIGN(InstructionFusion);
156 };
157 
158 }  // namespace xla
159 
160 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
161