/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {
namespace {

struct CudnnBatchNormParamsCommon {
  se::DeviceMemoryBase operand;
  se::dnn::BatchDescriptor operand_desc;
  se::dnn::BatchDescriptor scale_offset_desc;
  se::DeviceMemory<float> scale;
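  // Stored as float (matching the HLO's epsilon attribute); the
  // stream-executor API takes a double, so it is widened at the call sites.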
  float epsilon;
};

struct CudnnBatchNormForwardInferenceParams {
  CudnnBatchNormParamsCommon common;
  se::DeviceMemoryBase output;
  se::DeviceMemory<float> offset;
  se::DeviceMemory<float> mean;
  se::DeviceMemory<float> variance;
};

struct CudnnBatchNormForwardTrainingParams {
  CudnnBatchNormParamsCommon common;
  se::DeviceMemoryBase output_data;
  se::DeviceMemory<float> offset;
  se::DeviceMemory<float> output_mean;
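  // Note: cudnn reports the *inverse* stddev (per cudnn's saved-inv-variance
  // convention, i.e. 1/sqrt(variance + epsilon)), not the batch variance.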
  se::DeviceMemory<float> output_inv_stddev;
};

struct CudnnBatchNormBackwardParams {
  CudnnBatchNormParamsCommon common;
  se::DeviceMemoryBase output_grad_data;
  se::DeviceMemoryBase grad_output;
  se::DeviceMemory<float> output_grad_scale;
  se::DeviceMemory<float> output_grad_offset;
  se::DeviceMemory<float> mean;
  se::DeviceMemory<float> inv_stddev;
};

struct DnnBatchDescriptors {
  se::dnn::BatchDescriptor input_desc;
  se::dnn::BatchDescriptor scale_offset_desc;
};

DnnBatchDescriptors MakeBatchNormDescriptors(const Shape& shape,
                                             int64 feature_index) {
  std::vector<int64> logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(shape.layout());

  auto physical_dim_size = [&](int64 physical_dim) {
    return shape.dimensions(LayoutUtil::Major(shape.layout(), physical_dim));
  };

  // Batchnorm only cares about the location of the depth (aka "feature") dim.
  // The other dims are all treated the same.  Thus we can use the kBatchDepthYX
  // cudnn layout for any XLA shape+layout, even XLA shapes that don't have
  // exactly 4 dimensions: We put everything that comes before the feature dim
  // into "batch", and everything that comes after the feature dim into "Y".
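  // For example (illustrative): for f32[2,3,4,5] with the default descending
  // layout and feature_index=1, the dims before the feature dim fold into
  // batch = 2 and the dims after it fold into y = 4 * 5 = 20, so cudnn sees an
  // N=2, C=3, H=20, W=1 tensor.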
  int64 batch_size = 1;
  int64 y_size = 1;
  int64 physical_dim;
  for (physical_dim = 0; physical_dim != logical_to_physical[feature_index];
       ++physical_dim) {
    CHECK_LT(physical_dim, shape.dimensions_size());
    batch_size *= physical_dim_size(physical_dim);
  }
  ++physical_dim;  // Skip the feature dimension.
  for (; physical_dim < shape.dimensions_size(); ++physical_dim) {
    y_size *= physical_dim_size(physical_dim);
  }

  DnnBatchDescriptors batch_descs;
  batch_descs.input_desc.set_layout(se::dnn::DataLayout::kBatchDepthYX)
      .set_count(batch_size)
      .set_feature_map_count(shape.dimensions(feature_index))
      .set_height(y_size)
      .set_width(1);

  batch_descs.scale_offset_desc.set_layout(se::dnn::DataLayout::kBatchDepthYX)
      .set_feature_map_count(batch_descs.input_desc.feature_map_count())
      .set_height(1)
      .set_width(1)
      .set_count(1);

  return batch_descs;
}

void AssignCommonParams(const CudnnBatchNormConfig& config,
                        CudnnBatchNormParamsCommon* params,
                        const se::DeviceMemoryBase& operand,
                        const se::DeviceMemory<float>& scale) {
  // The BatchNormTraining HLO outputs a tuple of three elements: output data,
  // batch mean, and batch variance.  We want to make our descriptors based on
  // the shape of the output data.  The batchnorm backward call outputs a tuple
  // of three elements: grad data, grad offset, and grad scale.  We want to
  // make our descriptors based on the shape of the grad data.
  const Shape& shape = config.output_shape;
  DnnBatchDescriptors batch_descs =
      MakeBatchNormDescriptors(shape, config.feature_index);
  params->operand_desc = batch_descs.input_desc;
  params->scale_offset_desc = batch_descs.scale_offset_desc;
  params->operand = operand;
  params->scale = scale;
  params->epsilon = config.epsilon;
}

template <typename ElemType>
void RunCudnnBatchNormForwardInferenceImpl(
    CudnnBatchNormForwardInferenceParams* params, se::Stream* stream) {
  se::DeviceMemory<float> null_device_ptr(nullptr);
  auto output_buf = se::DeviceMemory<ElemType>(params->output);
  stream->ThenBatchNormalizationForward(
      se::DeviceMemory<ElemType>(params->common.operand),
      params->common.scale,                                          //
      params->offset,                                                //
      params->mean,                                                  //
      params->variance,                                              //
      /*side_input=*/null_device_ptr, params->common.operand_desc,  //
      params->common.scale_offset_desc,                              //
      static_cast<double>(params->common.epsilon),                   //
      // TODO(b/137108598): Extend method to allow use of non-trivial
      // exponential averaging.
      /*exponential_average_factor=*/1.0,
      se::dnn::ActivationMode::kNone,       //
      &output_buf,                          //
      /*batch_mean=*/nullptr,               //
      /*batch_var=*/nullptr,                //
      /*saved_mean=*/nullptr,               //
      /*saved_inv_var=*/nullptr,            //
      /*is_training=*/false,                //
      /*reserve_space_allocator=*/nullptr,  //
      /*workspace_allocator=*/nullptr);
}

template <typename ElemType>
void RunCudnnBatchNormForwardTrainingImpl(
    CudnnBatchNormForwardTrainingParams* params, se::Stream* stream) {
  se::DeviceMemory<float> null_device_ptr(nullptr);
  auto output_data = se::DeviceMemory<ElemType>(params->output_data);
  stream->ThenBatchNormalizationForward(
      se::DeviceMemory<ElemType>(params->common.operand),
      params->common.scale,                     //
      params->offset,                           //
      /*estimated_mean=*/null_device_ptr,       //
      /*estimated_variance=*/null_device_ptr,   //
      /*side_input=*/null_device_ptr,           //
      params->common.operand_desc,              //
      params->common.scale_offset_desc,         //
      params->common.epsilon,                   //
      // TODO(b/137108598): Extend method to allow use of non-trivial
      // exponential averaging.
      /*exponential_average_factor=*/1.0,
      se::dnn::ActivationMode::kNone,                 //
      &output_data,                                   //
      /*batch_mean=*/&null_device_ptr,                //
      /*batch_var=*/&null_device_ptr,                 //
      /*saved_mean=*/&params->output_mean,            //
      /*saved_inv_var=*/&params->output_inv_stddev,   //
      /*is_training=*/true,                           //
      /*reserve_space_allocator=*/nullptr,            //
      /*workspace_allocator=*/nullptr);
}

template <typename ElemType>
void RunCudnnBatchNormBackwardImpl(CudnnBatchNormBackwardParams* params,
                                   se::Stream* stream) {
  se::DeviceMemory<float> null_device_ptr(nullptr);
  auto output_grad_data = se::DeviceMemory<ElemType>(params->output_grad_data);
  stream->ThenBatchNormalizationBackward(
      se::DeviceMemory<ElemType>(params->grad_output),     //
      se::DeviceMemory<ElemType>(params->common.operand),  //
      params->common.scale,                                //
      params->mean,                                        //
      params->inv_stddev,                                  //
      params->common.operand_desc,                         //
      params->common.scale_offset_desc,                    //
      params->common.epsilon,                              //
      &output_grad_data,                                   //
      &params->output_grad_scale,                          //
      &params->output_grad_offset,                         //
      /*reserve_space_allocator=*/nullptr,                 //
      /*workspace_allocator=*/nullptr);
}

}  // namespace

CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction* instr,
                                             float epsilon,
                                             int64 feature_index) {
  CudnnBatchNormConfig config;

  config.output_shape = instr->shape().IsTuple()
                            ? instr->shape().tuple_shapes(0)
                            : instr->shape();
  config.output_type = config.output_shape.element_type();
  config.epsilon = epsilon;
  config.feature_index = feature_index;
  return config;
}
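
// Example usage (illustrative sketch; `instr`, `epsilon`, `feature_index`,
// and the device buffers are assumed to be supplied by the caller):
//
//   CudnnBatchNormConfig config =
//       GetCudnnBatchNormConfig(instr, epsilon, feature_index);
//   TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardInference(
//       config, operand, output, scale, offset, mean, variance, stream));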

Status RunCudnnBatchNormForwardInference(
    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
    se::DeviceMemoryBase output, se::DeviceMemory<float> scale,
    se::DeviceMemory<float> offset, se::DeviceMemory<float> mean,
    se::DeviceMemory<float> variance, se::Stream* stream) {
  CudnnBatchNormForwardInferenceParams inference_params;
  AssignCommonParams(config, &inference_params.common, operand, scale);
  inference_params.offset = offset;
  inference_params.mean = mean;
  inference_params.variance = variance;
  inference_params.output = output;

  switch (config.output_type) {
    case F16:
      RunCudnnBatchNormForwardInferenceImpl<Eigen::half>(&inference_params,
                                                         stream);
      break;
    case F32:
      RunCudnnBatchNormForwardInferenceImpl<float>(&inference_params, stream);
      break;
    default:
      return Unimplemented(
          "Primitive type %s not implemented for batchnorm forward inference",
          primitive_util::LowercasePrimitiveTypeName(config.output_type)
              .c_str());
  }
  return Status::OK();
}

Status RunCudnnBatchNormForwardTraining(
    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
    se::DeviceMemoryBase output_data, se::DeviceMemory<float> output_mean,
    se::DeviceMemory<float> output_inv_stddev, se::DeviceMemory<float> scale,
    se::DeviceMemory<float> offset, se::Stream* stream) {
  CudnnBatchNormForwardTrainingParams forward_params;
  AssignCommonParams(config, &forward_params.common, operand, scale);
  forward_params.offset = offset;
  forward_params.output_data = output_data;
  forward_params.output_mean = output_mean;
  forward_params.output_inv_stddev = output_inv_stddev;

  switch (config.output_type) {
    case F16:
      RunCudnnBatchNormForwardTrainingImpl<Eigen::half>(&forward_params,
                                                        stream);
      break;
    case F32:
      RunCudnnBatchNormForwardTrainingImpl<float>(&forward_params, stream);
      break;
    default:
      return Unimplemented(
          "Primitive type %s not implemented for batchnorm forward training",
          primitive_util::LowercasePrimitiveTypeName(config.output_type)
              .c_str());
  }
  return Status::OK();
}

Status RunCudnnBatchNormBackward(
    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
    se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output,
    se::DeviceMemory<float> output_grad_scale,
    se::DeviceMemory<float> output_grad_offset, se::DeviceMemory<float> scale,
    se::DeviceMemory<float> mean, se::DeviceMemory<float> inv_stddev,
    se::Stream* stream) {
  CudnnBatchNormBackwardParams backward_params;
  AssignCommonParams(config, &backward_params.common, operand, scale);
  backward_params.output_grad_data = output_grad_data;
  backward_params.grad_output = grad_output;
  backward_params.output_grad_scale = output_grad_scale;
  backward_params.output_grad_offset = output_grad_offset;
  backward_params.mean = mean;
  backward_params.inv_stddev = inv_stddev;

  switch (config.output_type) {
    case F16:
      RunCudnnBatchNormBackwardImpl<Eigen::half>(&backward_params, stream);
      break;
    case F32:
      RunCudnnBatchNormBackwardImpl<float>(&backward_params, stream);
      break;
    default:
      return Unimplemented(
          "Primitive type %s not implemented for batchnorm backward",
          primitive_util::LowercasePrimitiveTypeName(config.output_type)
              .c_str());
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla