#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {
namespace gpu {

// Rewrites BatchNorm HLOs into calls to cudnn where possible.
//
// A call into cudnn for performing a batchnorm op is represented as a
// CustomCall HLO with custom_call_target equal to one of
//
//   - kCudnnBatchNormForwardInferenceCallTarget,
//   - kCudnnBatchNormForwardTrainingCallTarget, or
//   - kCudnnBatchNormBackwardCallTarget.
//
// A CustomCall created by this pass has the same operands as the corresponding
// batchnorm HLO, except that the epsilon() and feature_index() properties of
// the batchnorm HLO are converted into proper operands, added to the end of
// the CustomCall's operand list.
//
// The inputs/outputs of the cudnn calls for BatchNormTraining and
// BatchNormGrad do not correspond exactly to the HLOs.  In particular, the
// training cudnn call returns 1/sqrt(variance + epsilon), while the HLO
// returns plain variance.  Similarly, the grad cudnn call expects
// 1/sqrt(variance + epsilon) as input, whereas the HLO expects plain variance.
//
// This pass adds HLOs in front of / behind the CustomCalls to fix up the
// inputs/outputs as appropriate, and we rely on the AlgebraicSimplifier to
// remove these where possible.
//
// Currently only batchnorm ops over F32s are converted into cudnn calls, and
// only so long as epsilon is not too small.  This pass leaves all other
// batchnorm ops unmodified.
//
// The GPU backend does not implement a lowering for the batchnorm HLOs -- it
// expects them to be lowered to cudnn calls via this pass or to HLO soup via
// BatchNormRewriter.
class CudnnBatchNormRewriter : public HloModulePass {
 public:
  absl::string_view name() const override {
    return "cudnn_batchnorm_rewriter";
  }
  StatusOr<bool> Run(HloModule* module) override;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_
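
// Illustrative sketch (not part of this header): the doc comment above says
// the cudnn forward-training call returns 1/sqrt(variance + epsilon) while the
// BatchNormTraining HLO returns plain variance, and that the backward call
// expects the inverse conversion on its input.  Assuming only that
// relationship, the scalar math the pass must emit around the CustomCalls
// looks like the commented sketch below; the function names are hypothetical
// and exist purely to illustrate the conversion, not any actual API.
//
//   #include <cmath>
//
//   // Fix-up emitted after the forward-training CustomCall: recover plain
//   // variance from the inverse stddev that cudnn returns.
//   float VarianceFromInverseStddev(float inverse_stddev, float epsilon) {
//     return 1.0f / (inverse_stddev * inverse_stddev) - epsilon;
//   }
//
//   // Fix-up emitted before the backward CustomCall: convert the plain
//   // variance the HLO provides into the inverse stddev cudnn expects.
//   float InverseStddevFromVariance(float variance, float epsilon) {
//     return 1.0f / std::sqrt(variance + epsilon);
//   }
//
// Composing the two functions is the identity (up to rounding), which is why
// the AlgebraicSimplifier can often cancel the inserted fix-up HLOs when a
// training op feeds directly into a grad op.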