/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {
namespace gpu {

// This optimization pass horizontally fuses computations to reduce kernel
// launch overhead while increasing kernel launch dims on GPU. The initial
// motivation of this horizontal fusion is the observation that the training
// optimizer phase (e.g., AdamOptimizer and L2Loss) typically has many small
// kernels as a result of applying the same formula to many training
// parameters (or variables in TensorFlow). Fusing these small kernels hence
// provides a performance gain.
//
// In principle, we could implement a cycle detection algorithm to make sure
// no cycles are created after fusion. However, cycle detection is somewhat
// cumbersome; also, we observe that naive horizontal fusion of arbitrary
// kernels may not be profitable due to control divergence and a possible
// increase in memory bandwidth pressure caused by uncoalesced memory accesses
// (note that horizontal fusion does not change the amount of memory
// read+written at all). In practice, a simple yet effective heuristic is used
// to avoid these issues while addressing the known beneficial cases. That is,
// we simply search for fusion candidates by looking for instructions whose
// outputs are all consumed by the same instruction. This catches the cases in
// the training optimizer phase, as the candidate instructions are typically
// consumed only by the ROOT tuple of the entry computation.
//
// The following illustrates the mechanism of the horizontal fusion. Before
// fusion, there are two trivial kernels in the illustrating example. One has
// only a Mul op, while the other consists of only an Add op. Since they are
// only consumed by the same (ROOT) tuple instruction, horizontal fusion is
// triggered.
//
//      i0 i1   i2 i3
//       | |     | |
//       v v     v v
//       Mul     Add
//        |       |
//        v       v
//      (ROOT) tuple
//
// We horizontally fuse them into the below pattern.
//
//      i0 i1   i2 i3          +++ (Slice) Input Fusion
//       | |     | |             +
//       v v     v v             +
//       Mul     Add             +
//        |       |              +
//        v       v              +
//   Reshape0  Reshape1          +
//        |       |              +
//        v       v              +
//      Concatenate              +
//        |       |              +
//        v       v              +
//     Slice0   Slice1         +++
//        |       |
//        v       v
//   Reshape2  Reshape3
//        |       |
//        v       v
//      (ROOT) tuple
//
// Note that this fusion style provides an important advantage: kernels of
// different shapes can be horizontally fused. The first pair of reshapes
// (i.e., Reshape0 and Reshape1) flatten the outputs to one dimension, so that
// the outputs of the fused kernels can (always) be concatenated. The second
// pair of reshapes (Reshape2 and Reshape3) restore the original shapes of the
// output tensors.
//
// No extra copies are introduced by the horizontal fusion. Besides Reshape2
// and Reshape3, the other instructions are fused into an input fusion; the
// output dims of the concatenate will be used as the kernel launch dims.
// Instruction bitcasts can be used for Reshape2 and Reshape3 as long as the
// outputs of Mul and Add are row-major.
class GpuHorizontalLoopFusion : public HloModulePass {
 public:
  GpuHorizontalLoopFusion() {}

  absl::string_view name() const override {
    return "gpu_horizontal_loop_fusion";
  }

  StatusOr<bool> Run(HloModule* module) override;

 private:
  StatusOr<bool> RunOnComputation(HloComputation*);
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
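
// -----------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the original header): one way
// this pass might be run over an HloModule, either directly via Run() or as
// part of an HloPassPipeline. The wrapper function RunGpuHorizontalLoopFusion
// and the pipeline name "horizontal-fusion" are hypothetical names introduced
// here for illustration; only GpuHorizontalLoopFusion itself comes from this
// header.
//
//   #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
//   #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
//
//   xla::StatusOr<bool> RunGpuHorizontalLoopFusion(xla::HloModule* module) {
//     xla::HloPassPipeline pipeline("horizontal-fusion");  // hypothetical name
//     pipeline.AddPass<xla::gpu::GpuHorizontalLoopFusion>();
//     // Returns true iff the pass changed the module.
//     return pipeline.Run(module);
//   }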