1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ 18 19 #include <queue> 20 #include <vector> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/strings/string_view.h" 24 #include "tensorflow/compiler/xla/service/hlo_module.h" 25 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 26 #include "tensorflow/compiler/xla/service/hlo_reachability.h" 27 #include "tensorflow/compiler/xla/statusor.h" 28 29 namespace xla { 30 namespace gpu { 31 32 // Multi-output fusion of sibling and producer-consumer instructions for the 33 // GPU backend to reduce memory bandwidth requirements. 34 // 35 // 0) Before multi- 1) Sibling multi- 2) Producer-consumer 36 // output fusion output fusion multi-output fusion 37 // 38 // p p p 39 // | | | 40 // v v v 41 // A A +-fusion--+ 42 // / \ | | A | 43 // | | +-fusion--+ | / \ | 44 // v v | / \ | | B | | 45 // B C | B C | | | | | 46 // \ / | | | | | v v | 47 // v v | v v | | tuple | 48 // ROOT | tuple | +---------+ 49 // +---------+ / \ 50 // / \ gte_b gte_a 51 // gte_b gte_c | | 52 // | | | v 53 // \ / | C 54 // v v \ / 55 // ROOT v v 56 // ROOT 57 // 58 // Multi-output fusion ops have a tuple op at their root containing multiple 59 // elements as outputs. GetTupleElement ops (depicted as gte_* above) are 60 // inserted to extract tuple elements for consumers. 61 // 62 // The two different flavors of multi-output fusion this pass performs are 63 // depicted above. 64 // 1) Fusion of sibling ops reduces memory bandwidth requirements, because 65 // common input parameters have to be read only once. 66 // 2) Fusion of producer-consumer ops reduces memory bandwidth requirements by 67 // saving one read from memory. In the example above, B does not need to read 68 // the output of A from memory, while C still does (using gte_a). 69 // Note that sibling (1) and producer-consumer (2) multi-output fusion can be 70 // combined. 71 // 72 // The GpuMultiOutputFusion pass modifies the HLO in reverse post-order (defs 73 // before uses). First, it attempts to fuse the consumer ops of the current op, 74 // which are siblings (1). Hereafter, it attempts to fuse the current op with 75 // one of its consumers (2). This order avoids a phase ordering issue (described 76 // in go/fusionfusion). It ensures that all GetTupleElement ops inserted as a 77 // by-product of multi-output fusion will occur before the current op in the 78 // order of traversal, and hence, not get into the way of subsequent fusion 79 // attempts. 80 // 81 // The GpuMultiOutputFusion pass ensures several conditions are met for fusion. 82 // Some of them are relevant for correctness. In particular, no cycles must be 83 // introduced into the HLO module. Moreover, the code emitters for multi-output 84 // fusion must support the combination of ops and their shapes. Other 85 // restrictions are rather arbitrary and lifting them could be beneficial. 86 // * Sibling fusion (1) requires at least one op to be a kFusion. 87 // * Sibling fusion (1) does not fuse kInput fusions with kLoop fusions, i.e. 88 // the fusion kinds must match. 89 90 class GpuMultiOutputFusion : public HloModulePass { 91 public: 92 GpuMultiOutputFusion() = default; 93 name()94 absl::string_view name() const override { return "multi_output_fusion"; } 95 96 StatusOr<bool> Run(HloModule* module) override; 97 98 private: 99 bool FuseSiblings(HloInstruction* parent); 100 101 StatusOr<bool> DoMultiOutputFusion(); 102 103 // Recompute reachability for the current computation. 104 void RecomputeReachability(); 105 106 // Computation for the pass. 107 HloComputation* computation_; 108 109 // The reachability map of current computation. 110 std::unique_ptr<HloReachabilityMap> reachability_; 111 }; 112 113 } // namespace gpu 114 } // namespace xla 115 116 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ 117