/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_

#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {
namespace gpu {

// This file contains thunks that call into cudnn to run the various flavors of
// batch normalization: BatchNormInference, BatchNormTraining, and
// BatchNormGrad, known to cudnn as BatchNormForwardInference,
// BatchNormForwardTraining, and BatchNormBackward.
//
// As an alternative to using these thunks, XLA can decompose batchnorm HLOs
// into smaller components using the BatchNormRewriter pass.  This can result in
// faster code because those individual components can fuse into their
// inputs/outputs, but it may also be slower if cudnn's batchnorm implementation
// outperforms the code XLA generates for these components.
//
// Currently these thunks require their inputs to be F32.
//
// Note that these thunks do not take full advantage of the cudnn batchnorm
// functions.  For example, cudnn lets you bias and/or scale the input/output,
// but these thunks don't currently support that.

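// Thunk that runs cudnn's BatchNormForwardInference. As a rough sketch of the
// arithmetic this delegates to cudnn (not an exact restatement of the
// implementation), each element of the output is, per feature channel,
//
//   output = (operand - mean) / sqrt(variance + epsilon) * scale + offset
//
// where mean and variance are the precomputed per-channel statistics passed
// in as operands, and epsilon is assumed to travel in CudnnBatchNormConfig.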
class CudnnBatchNormForwardInferenceThunk : public Thunk {
 public:
  CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info,
                                      CudnnBatchNormConfig config,
                                      const BufferAllocation::Slice& operand,
                                      const BufferAllocation::Slice& scale,
                                      const BufferAllocation::Slice& offset,
                                      const BufferAllocation::Slice& mean,
                                      const BufferAllocation::Slice& variance,
                                      const BufferAllocation::Slice& output);

  CudnnBatchNormForwardInferenceThunk(
      const CudnnBatchNormForwardInferenceThunk&) = delete;
  CudnnBatchNormForwardInferenceThunk& operator=(
      const CudnnBatchNormForwardInferenceThunk&) = delete;

  Status ExecuteOnStream(const ExecuteParams& params) override;

 private:
  CudnnBatchNormConfig config_;
  BufferAllocation::Slice operand_;
  BufferAllocation::Slice scale_;
  BufferAllocation::Slice offset_;
  BufferAllocation::Slice mean_;
  BufferAllocation::Slice variance_;
  BufferAllocation::Slice output_;
};

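// Thunk that runs cudnn's BatchNormForwardTraining. Unlike the inference
// flavor, the mean and variance here are computed from the current batch
// rather than passed in; in addition to the normalized output_data, the thunk
// emits the batch mean and the inverse stddev (presumably
// 1/sqrt(variance + epsilon), matching the output_inv_stddev name) so the
// backward pass can reuse them.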
class CudnnBatchNormForwardTrainingThunk : public Thunk {
 public:
  CudnnBatchNormForwardTrainingThunk(
      ThunkInfo thunk_info, CudnnBatchNormConfig config,
      const BufferAllocation::Slice& operand,
      const BufferAllocation::Slice& scale,
      const BufferAllocation::Slice& offset,
      const BufferAllocation::Slice& output_data,
      const BufferAllocation::Slice& output_mean,
      const BufferAllocation::Slice& output_inv_stddev);

  CudnnBatchNormForwardTrainingThunk(
      const CudnnBatchNormForwardTrainingThunk&) = delete;
  CudnnBatchNormForwardTrainingThunk& operator=(
      const CudnnBatchNormForwardTrainingThunk&) = delete;

  Status ExecuteOnStream(const ExecuteParams& params) override;

 private:
  CudnnBatchNormConfig config_;
  BufferAllocation::Slice operand_;
  BufferAllocation::Slice scale_;
  BufferAllocation::Slice offset_;
  BufferAllocation::Slice output_data_;
  BufferAllocation::Slice output_mean_;
  BufferAllocation::Slice output_inv_stddev_;
};

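// Thunk that runs cudnn's BatchNormBackward. Given the original operand, the
// mean and inverse stddev saved by the forward-training pass, and the
// incoming gradient grad_output, it produces the gradients with respect to
// the operand (output_grad_data), the scale (output_grad_scale), and the
// offset (output_grad_offset).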
class CudnnBatchNormBackwardThunk : public Thunk {
 public:
  CudnnBatchNormBackwardThunk(
      ThunkInfo thunk_info, CudnnBatchNormConfig config,
      const BufferAllocation::Slice& operand,
      const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean,
      const BufferAllocation::Slice& inv_stddev,
      const BufferAllocation::Slice& grad_output,
      const BufferAllocation::Slice& output_grad_data,
      const BufferAllocation::Slice& output_grad_scale,
      const BufferAllocation::Slice& output_grad_offset);

  CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete;
  CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) =
      delete;

  Status ExecuteOnStream(const ExecuteParams& params) override;

 private:
  const CudnnBatchNormConfig config_;
  BufferAllocation::Slice operand_;
  BufferAllocation::Slice scale_;
  BufferAllocation::Slice mean_;
  BufferAllocation::Slice inv_stddev_;
  BufferAllocation::Slice grad_output_;
  BufferAllocation::Slice output_grad_data_;
  BufferAllocation::Slice output_grad_scale_;
  BufferAllocation::Slice output_grad_offset_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_