1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ 18 19 #include <functional> 20 #include <string> 21 #include <utility> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "absl/container/flat_hash_set.h" 25 #include "tensorflow/compiler/xla/service/fusion_queue.h" 26 #include "tensorflow/compiler/xla/service/hlo_computation.h" 27 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 28 #include "tensorflow/compiler/xla/service/hlo_module.h" 29 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 30 #include "tensorflow/compiler/xla/service/hlo_reachability.h" 31 32 namespace xla { 33 34 struct NoFusionPossible; 35 36 // Propagating explanation of fusion decisions: if something could not be fused, 37 // explain the reason. 38 class FusionDecision { 39 public: 40 // Can not be fused: explain why. Implicit conversion due to optional-like 41 // semantics: waiver granted in cl/419938611. FusionDecision(absl::string_view explanation)42 FusionDecision(absl::string_view explanation) // NOLINT 43 : explanation_(explanation) {} 44 45 // Same constructor as string_view, to allow implicit string conversion (can't 46 // implicitly convert both char* to string_view and string_view to 47 // FusionDecision). FusionDecision(const char * explanation)48 FusionDecision(const char* explanation) // NOLINT 49 : explanation_(explanation) {} 50 51 // If condition is `true` means that we CAN fuse. In that case, explanation is 52 // discarded. FusionDecision(bool condition,absl::string_view explanation)53 FusionDecision(bool condition, absl::string_view explanation) { 54 if (!condition) { 55 explanation_ = std::string(explanation); 56 } 57 } 58 59 // Can be fused. FusionDecision()60 FusionDecision() {} 61 62 // A trick to declare and test fusion decision in a single statement (as TF 63 // is still on C++14 and can't use if statement with explicit initializer). 64 // 65 // Cf. NoFusionPossible definition for sample usage. 66 // TODO(b/157309856): Use conditional initializer instead. 67 NoFusionPossible operator!(); 68 69 // Returns whether it can be fused. 70 explicit operator bool() const { return CanFuse(); } 71 72 // Whether the fusion decision is positive. CanFuse()73 bool CanFuse() const { return !explanation_.has_value(); } 74 75 // Connects two decisions with a disjunction. This is different than just 76 // picking one, as we also have to propagate both explanations if only one of 77 // them is false to show why fusion wasn't performed. Or(const FusionDecision & decision)78 FusionDecision Or(const FusionDecision& decision) { 79 if (CanFuse() || decision.CanFuse()) { 80 return {}; 81 } 82 return {absl::StrCat(explanation_.value_or(""), " ; ", decision.Explain())}; 83 } 84 85 // Connects two fusion decision with a conjunction. Unlike disjunction, 86 // propagates only one explanation (as it is enough to show that fusion could 87 // not be done). And(const FusionDecision & decision)88 FusionDecision And(const FusionDecision& decision) { 89 if (CanFuse()) { 90 return decision; 91 } 92 if (decision.CanFuse()) { 93 return *this; 94 } 95 // Both conditions were violated: returning either is valid. 96 return *this; 97 } 98 99 // Appends to explanation, or turns the decision negative. 100 FusionDecision operator<<(absl::string_view explanation) { 101 return {absl::StrCat(explanation_.value_or(""), explanation)}; 102 } 103 104 // Appends to explanation, or turns the decision negative. 105 FusionDecision operator<<(int64_t explanation) { 106 return {absl::StrCat(explanation_.value_or(""), explanation)}; 107 } 108 109 // Explains why the fusion could not be performed. Explain()110 std::string Explain() const { return *explanation_; } 111 112 private: 113 // Empty IFF fusion is possible (explanation provided for negative cases). 114 std::optional<std::string> explanation_; 115 }; 116 117 // Helper class: contextually convertible to "no fusion possible" unlike 118 // FusionDecision. This is a trick to declare and test fusion decision in a 119 // single statement (as we are still on C++14). 120 // 121 // Sample usage: 122 // 123 // if (NoFusionPossible fusible = !FusabilityRestriction(producer, consume)) { 124 // return !fusible; // Note that negation converts it back to FusionDecision. 125 // } 126 struct NoFusionPossible { 127 // Inverts the test value (true <=> not fusible) on wrapped FusionDecision. 128 explicit operator bool() { return !static_cast<bool>(fusion_decision); } 129 130 // Returns wrapped fusion decision. 131 FusionDecision operator!() { return fusion_decision; } 132 133 FusionDecision fusion_decision; 134 }; 135 136 inline NoFusionPossible FusionDecision::operator!() { return {*this}; } 137 138 // HLO pass which performs instruction fusion. Instructions are fused 139 // "vertically", meaning producing instructions are fused into their consumers 140 // with the intent that the loops which compute their values will be fused in 141 // code generation. Derived classes define ShouldFuse method to select which 142 // instructions to fuse. 143 class InstructionFusion : public HloModulePass { 144 public: 145 explicit InstructionFusion( 146 std::function<bool(const HloInstruction& instruction)> is_expensive, 147 bool may_duplicate = true, 148 FusionConfigCollection config_collection_mode = 149 FusionConfigCollection::kOff) is_expensive_(is_expensive)150 : is_expensive_(is_expensive), 151 may_duplicate_(may_duplicate), 152 config_collection_mode_(config_collection_mode) {} 153 ~InstructionFusion() override = default; name()154 absl::string_view name() const override { return "fusion"; } 155 156 // Run instruction fusion on the given computation. Returns whether the 157 // computation was changed (instructions were fused). 158 using HloPassInterface::Run; 159 StatusOr<bool> Run( 160 HloModule* module, 161 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 162 163 // Returns true if the computation of the given instruction is significantly 164 // more expensive than just writing all the values of the instructions' result 165 // array. Expensive operations will not be duplicated. 166 static bool IsExpensive(const HloInstruction& instruction); 167 168 // Returns true if it's legal to fuse the producer instruction into consumer 169 // with regard to in-place semantics of the consumer. For example, it is 170 // illegal to fuse a slice into a dynamic-update-slice if the slice output is 171 // used as the update and if slice and dynamic-update-slice indices cannot be 172 // proven to be the same. 173 static FusionDecision ShouldFuseInPlaceOp(const HloInstruction* producer, 174 const HloInstruction* consumer); 175 176 protected: 177 // Returns a list of computations on which Fusion is performed. 178 virtual std::vector<HloComputation*> GetFusionComputations( 179 HloModule* module, 180 const absl::flat_hash_set<absl::string_view>& execution_threads); 181 182 // Returns a FusionQueue that implements custom order of instructions being 183 // fused. The default implementation processes consumers in reverse post 184 // order. 185 virtual std::unique_ptr<FusionQueue> GetFusionQueue( 186 HloComputation* computation); 187 188 // Returns whether the given producer instruction should be fused into the 189 // given consumer instruction. producer is necessarily an operand of consumer. 190 // Derived classes should define this method to specify which instructions 191 // should be fused. `operand_index` is which operand of the consumer the 192 // producer is. 193 // 194 // Instructions are traversed in reverse post order (computation root to 195 // leaves). This method is called for each operand of the instruction (where 196 // the operand is 'producer' and the instruction is 'consumer') 197 // 198 // Subtypes can override this with target-specific heuristics. 199 virtual FusionDecision ShouldFuse(HloInstruction* consumer, 200 int64_t operand_index); 201 202 // Returns whether multi-output fusion can be applied to fuse `producer` into 203 // `consumer`. In contrast to "regular" fusion, the `producer` is not 204 // duplicated by multi-output fusion. ShouldFuseIntoMultiOutput(HloInstruction * consumer,int64_t operand_index)205 virtual FusionDecision ShouldFuseIntoMultiOutput(HloInstruction* consumer, 206 int64_t operand_index) { 207 return "multi-output fusion not supported by this pass"; 208 } 209 210 // Chooses a fusion kind for `producer` and `consumer`. 211 // Default method chooses `kLoop`. 212 virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, 213 const HloInstruction* consumer); 214 215 // Fuses 'producer' into 'fusion_instruction'. 'fusion_instruction' needs to 216 // be a fusion instruction. Returns the newly created clone of 'producer' 217 // which is part of the fusion computation. 218 virtual HloInstruction* FuseInstruction(HloInstruction* fusion_instruction, 219 HloInstruction* producer); 220 221 // Fuses producer into consumer. Returns the fusion instruction. 222 virtual HloInstruction* Fuse(HloInstruction* producer, 223 HloInstruction* consumer, 224 HloComputation* computation); 225 226 // Creates a new fusion instruction containing `producer` and `consumer`. A 227 // tuple is added as the fusion instruction's root, which consumes from both, 228 // `producer` and `consumer`. This style of fusion is referred to as 229 // multi-output fusion. 230 virtual HloInstruction* FuseIntoMultiOutput(HloInstruction* producer, 231 HloInstruction* consumer, 232 HloComputation* computation); 233 234 // An "effectively unary" operation is one that has at most one "large" 235 // input with the others being negligible in terms of memory usage. 236 // We use "has a smaller true rank than the output" as a heuristic 237 // for "negligible" memory usage. 238 bool EffectivelyAtMostUnary(HloInstruction* hlo); 239 240 // Returns true if fusing producer into consumer would cause producer to be 241 // duplicated. This is the case if producer has uses other than consumer. FusionWouldDuplicate(const HloInstruction & producer,const HloInstruction & consumer)242 bool FusionWouldDuplicate(const HloInstruction& producer, 243 const HloInstruction& consumer) { 244 return !(producer.users().size() == 1 && consumer.IsUserOf(&producer)); 245 } 246 is_expensive(const HloInstruction & instruction)247 bool is_expensive(const HloInstruction& instruction) { 248 return is_expensive_(instruction); 249 } 250 251 // Overwrites the originally initialized is_expensive function. set_is_expensive(std::function<bool (const HloInstruction & instruction)> is_expensive)252 void set_is_expensive( 253 std::function<bool(const HloInstruction& instruction)> is_expensive) { 254 is_expensive_ = is_expensive; 255 } 256 257 // Whether multi-output fusion would introduce a cycle into the HLO graph. 258 bool MultiOutputFusionCreatesCycle(HloInstruction* producer, 259 HloInstruction* consumer, 260 const HloReachabilityMap& reachability); 261 config_collection_mode()262 FusionConfigCollection config_collection_mode() { 263 return config_collection_mode_; 264 } 265 266 // Returns whether 'consumer' may reuse elements of its `operand_index`th 267 // operand. 268 bool ReusesOperandElements(const HloInstruction* consumer, 269 int64_t operand_index); 270 271 // The set of producers whose consumers we cannot fuse into. 272 using HloInstructionSet = absl::flat_hash_set<HloInstruction*>; 273 274 // Computes the set of nodes that we do not want to fuse into any of their 275 // consumers based on a global analysis of the HLO graph. 276 virtual HloInstructionSet ComputeGloballyUnfusible( 277 absl::Span<HloInstruction* const> post_order, 278 const HloReachabilityMap& reachability); 279 280 private: 281 // Returns the reused operands of `instruction` from reused_fusion_operands_, 282 // computing them if they have not previously been computed for that 283 // instruction. 284 // The returned value has pointer stability, assuming entries are not deleted 285 // from reused_fusion_operands_. 286 absl::flat_hash_set<const HloInstruction*>& ReusedOperandsOf( 287 const HloInstruction* instruction); 288 289 // Updates reused_fusion_operands_ for a fusion when we are about to fuse 290 // `producer` into `fusion_instruction`. 291 void UpdateReusedOperandsForFusion(HloInstruction* producer, 292 HloInstruction* fusion_instruction); 293 294 HloInstruction* AddFusionInstruction(HloInstruction* producer, 295 HloInstruction* consumer, 296 HloComputation* computation); 297 298 // Whether or not we can fuse producer into consumer on all paths 299 // from the producer to the consumer where nodes are HLOs and edges are uses. 300 // 301 // A map from <producer, consumer> to a bool is required as the result cache 302 // to store and query the results of calls to this function, in order to avoid 303 // repeated computations. 304 bool CanFuseOnAllPaths( 305 HloInstruction* producer, HloInstruction* consumer, 306 const HloInstructionSet& do_not_fuse, 307 const HloReachabilityMap& reachability, 308 absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>* 309 result_cache); 310 311 // Used to determine if an HLO is expensive. Expensive operations will not be 312 // duplicated. 313 std::function<bool(const HloInstruction& instruction)> is_expensive_; 314 315 // Returns whether we may duplicate an instruction if we want to fuse it. 316 bool may_duplicate_; 317 318 // Configuration mode. 319 FusionConfigCollection config_collection_mode_; 320 321 // Caches which operands are reused inside fusion computations. 322 absl::flat_hash_map< 323 const HloInstruction*, 324 std::unique_ptr<absl::flat_hash_set<const HloInstruction*>>> 325 reused_fusion_operands_; 326 327 InstructionFusion(const InstructionFusion&) = delete; 328 InstructionFusion& operator=(const InstructionFusion&) = delete; 329 }; 330 331 } // namespace xla 332 333 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ 334