#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {
namespace gpu {

// Rewrites BatchNorm HLOs into calls into cudnn where possible.
//
// A call into cudnn for performing a batchnorm op is represented as a
// CustomCall HLO with custom_call_target equal to one of
//
//   - kCudnnBatchNormForwardInferenceCallTarget
//   - kCudnnBatchNormForwardTrainingCallTarget, or
//   - kCudnnBatchNormBackwardCallTarget.
//
// A CustomCall created by this pass has the same operands as the
// corresponding batchnorm HLO, except that the epsilon() and feature_index()
// properties of the batchnorm HLO are converted into proper operands and
// appended to the end of the CustomCall's operand list.
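//
// For example, a forward-training batchnorm HLO along the lines of the
// following (an illustrative sketch; the shapes, value names, and constant
// encodings are hypothetical, and the custom_call_target string is the
// value of kCudnnBatchNormForwardTrainingCallTarget):
//
//   %bn = (f32[2,8,8,16], f32[16], f32[16]) batch-norm-training(
//             f32[2,8,8,16] %operand, f32[16] %scale, f32[16] %offset),
//         epsilon=0.001, feature_index=3
//
// becomes roughly
//
//   %epsilon = f32[] constant(0.001)
//   %feature_index = s64[] constant(3)
//   %bn = (f32[2,8,8,16], f32[16], f32[16]) custom-call(
//             f32[2,8,8,16] %operand, f32[16] %scale, f32[16] %offset,
//             f32[] %epsilon, s64[] %feature_index),
//         custom_call_target=kCudnnBatchNormForwardTrainingCallTarget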
//
// The inputs/outputs of the cudnn calls for BatchNormTraining and BatchNormGrad
// do not correspond exactly to the HLOs.  In particular, the training cudnn
// call returns 1/sqrt(variance + epsilon), while the HLO returns plain
// variance.  Similarly, the grad cudnn call expects 1/sqrt(variance + epsilon)
// as input, whereas the HLO expects plain variance.
//
// This pass adds HLOs in front of / behind the CustomCalls to fix up the
// inputs/outputs as appropriate, and we rely on the AlgebraicSimplifier to
// remove these where possible.
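//
// Concretely, if the training cudnn call returns
// %inv_stddev = 1/sqrt(%variance + epsilon) where the HLO wants plain
// %variance, the fix-up HLOs amount to inverting that formula, along the
// lines of (an illustrative sketch; the exact opcodes emitted may differ):
//
//   %var_plus_eps = power(%inv_stddev, constant(-2))
//   %variance     = subtract(%var_plus_eps, %epsilon)
//
// and, symmetrically, the grad cudnn call is fed
//
//   %inv_stddev = power(add(%variance, %epsilon), constant(-0.5))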
//
// Currently batchnorm ops over F32s are converted into cudnn calls, so long as
// epsilon is not too small.  This pass leaves other batchnorm ops unmodified.
//
// The GPU backend does not implement a lowering for the batchnorm HLOs -- it
// expects them to be lowered to cudnn calls via this pass or to HLO soup via
// BatchNormRewriter.
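//
// Typical usage is as one pass in the GPU compilation pipeline, e.g.
// (a sketch; the pipeline name and surrounding wiring are illustrative):
//
//   HloPassPipeline pipeline("gpu-simplification");
//   pipeline.AddPass<CudnnBatchNormRewriter>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));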
class CudnnBatchNormRewriter : public HloModulePass {
 public:
  absl::string_view name() const override {
    return "cudnn_batchnorm_rewriter";
  }
  StatusOr<bool> Run(HloModule* module) override;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_