/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace gpu {
namespace {

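// Parameter bundles marshalled from the public entry points into the typed
// Run*Impl helpers below.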
struct CudnnBatchNormParamsCommon {
  se::DeviceMemoryBase operand;
  se::dnn::BatchDescriptor operand_desc;
  se::dnn::BatchDescriptor scale_offset_desc;
  se::DeviceMemory<float> scale;
  float epsilon;
};

struct CudnnBatchNormForwardInferenceParams {
  CudnnBatchNormParamsCommon common;
  se::DeviceMemoryBase output;
  se::DeviceMemory<float> offset;
  se::DeviceMemory<float> mean;
  se::DeviceMemory<float> variance;
};

struct CudnnBatchNormForwardTrainingParams {
  CudnnBatchNormParamsCommon common;
  se::DeviceMemoryBase output_data;
  se::DeviceMemory<float> offset;
  se::DeviceMemory<float> output_mean;
  se::DeviceMemory<float> output_inv_stddev;
};

struct CudnnBatchNormBackwardParams {
  CudnnBatchNormParamsCommon common;
  se::DeviceMemoryBase output_grad_data;
  se::DeviceMemoryBase grad_output;
  se::DeviceMemory<float> output_grad_scale;
  se::DeviceMemory<float> output_grad_offset;
  se::DeviceMemory<float> mean;
  se::DeviceMemory<float> inv_stddev;
};

struct DnnBatchDescriptors {
  se::dnn::BatchDescriptor input_desc;
  se::dnn::BatchDescriptor scale_offset_desc;
};

DnnBatchDescriptors MakeBatchNormDescriptors(const Shape& shape,
                                             int64_t feature_index) {
  std::vector<int64_t> logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(shape.layout());

  auto physical_dim_size = [&](int64_t physical_dim) {
    return shape.dimensions(LayoutUtil::Major(shape.layout(), physical_dim));
  };

  // Batchnorm only cares about the location of the depth (aka "feature") dim.
  // The other dims are all treated the same. Thus we can use the kBatchDepthYX
  // cudnn layout for any XLA shape+layout, even XLA shapes that don't have
  // exactly 4 dimensions: We put everything that comes before the feature dim
  // into "batch", and everything that comes after the feature dim into "Y".
  int64_t batch_size = 1;
  int64_t y_size = 1;
  int64_t physical_dim;
  for (physical_dim = 0; physical_dim != logical_to_physical[feature_index];
       ++physical_dim) {
    CHECK_LT(physical_dim, shape.dimensions_size());
    batch_size *= physical_dim_size(physical_dim);
  }
  ++physical_dim;  // Skip the feature dimension.
  for (; physical_dim < shape.dimensions_size(); ++physical_dim) {
    y_size *= physical_dim_size(physical_dim);
  }

  DnnBatchDescriptors batch_descs;
  batch_descs.input_desc.set_layout(se::dnn::DataLayout::kBatchDepthYX)
      .set_count(batch_size)
      .set_feature_map_count(shape.dimensions(feature_index))
      .set_height(y_size)
      .set_width(1);

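  // Scale, offset, mean, and variance are all per-feature vectors, so they
  // share a [1, C, 1, 1] descriptor.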
  batch_descs.scale_offset_desc.set_layout(se::dnn::DataLayout::kBatchDepthYX)
      .set_feature_map_count(batch_descs.input_desc.feature_map_count())
      .set_height(1)
      .set_width(1)
      .set_count(1);

  return batch_descs;
}

void AssignCommonParams(const CudnnBatchNormConfig& config,
                        CudnnBatchNormParamsCommon* params,
                        const se::DeviceMemoryBase& operand,
                        const se::DeviceMemory<float>& scale) {
  // The BatchNormTraining HLO outputs a tuple of three elements: output data,
  // batch mean, and batch variance. We want to make our descriptors based on
  // the shape of the output data. Batchnorm backward call outputs a tuple of
  // three elements: grad data, grad offset, and grad scale. We want to make
  // our descriptors based on the shape of the grad data.
  const Shape& shape = config.output_shape;
  DnnBatchDescriptors batch_descs =
      MakeBatchNormDescriptors(shape, config.feature_index);
  params->operand_desc = batch_descs.input_desc;
  params->scale_offset_desc = batch_descs.scale_offset_desc;
  params->operand = operand;
  params->scale = scale;
  params->epsilon = config.epsilon;
}

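// Inference-mode batchnorm normalizes with the precomputed running mean and
// variance, so cudnn produces no batch statistics; the statistics output
// pointers below are all null.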
template <typename ElemType>
void RunCudnnBatchNormForwardInferenceImpl(
    CudnnBatchNormForwardInferenceParams* params, se::Stream* stream) {
  se::DeviceMemory<ElemType> null_device_ptr(nullptr);
  auto output_buf = se::DeviceMemory<ElemType>(params->output);
  stream->ThenBatchNormalizationForward(
      se::DeviceMemory<ElemType>(params->common.operand),
      params->common.scale,                                          //
      params->offset,                                                //
      params->mean,                                                  //
      params->variance,                                              //
      /*side_input=*/null_device_ptr, params->common.operand_desc,   //
      params->common.scale_offset_desc,                              //
      static_cast<double>(params->common.epsilon),                   //
      // TODO(b/137108598): Extend method to allow use of non-trivial
      // exponential averaging.
      /*exponential_average_factor=*/1.0,
      se::dnn::ActivationMode::kNone,  //
      &output_buf,                     //
      /*batch_mean=*/nullptr,          //
      /*batch_var=*/nullptr,           //
      /*saved_mean=*/nullptr,          //
      /*saved_inv_var=*/nullptr,       //
      /*is_training=*/false,           //
      /*reserve_space_allocator=*/nullptr,  //
      /*workspace_allocator=*/nullptr);
}

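// Training-mode batchnorm computes statistics from the batch itself. cudnn
// additionally saves the batch mean and inverse stddev, which XLA surfaces as
// the extra outputs of batch-norm-training and later feeds into the backward
// pass.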
template <typename ElemType>
void RunCudnnBatchNormForwardTrainingImpl(
    CudnnBatchNormForwardTrainingParams* params, se::Stream* stream) {
  se::DeviceMemory<float> null_device_ptr(nullptr);
  se::DeviceMemory<ElemType> null_elem_device_ptr(nullptr);
  auto output_data = se::DeviceMemory<ElemType>(params->output_data);
  stream->ThenBatchNormalizationForward(
      se::DeviceMemory<ElemType>(params->common.operand),
      params->common.scale,                     //
      params->offset,                           //
      /*estimated_mean=*/null_device_ptr,       //
      /*estimated_variance=*/null_device_ptr,   //
      /*side_input=*/null_elem_device_ptr,      //
      params->common.operand_desc,              //
      params->common.scale_offset_desc,         //
      params->common.epsilon,                   //
      // TODO(b/137108598): Extend method to allow use of non-trivial
      // exponential averaging.
      /*exponential_average_factor=*/1.0,
      se::dnn::ActivationMode::kNone,                 //
      &output_data,                                   //
      /*batch_mean=*/&null_device_ptr,                //
      /*batch_var=*/&null_device_ptr,                 //
      /*saved_mean=*/&params->output_mean,            //
      /*saved_inv_var=*/&params->output_inv_stddev,   //
      /*is_training=*/true,                           //
      /*reserve_space_allocator=*/nullptr,            //
      /*workspace_allocator=*/nullptr);
}

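// The backward pass consumes the mean and inverse stddev saved by the forward
// training pass and emits gradients with respect to the operand data, scale,
// and offset.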
template <typename ElemType>
void RunCudnnBatchNormBackwardImpl(CudnnBatchNormBackwardParams* params,
                                   se::Stream* stream) {
  se::DeviceMemory<float> null_device_ptr(nullptr);
  auto output_grad_data = se::DeviceMemory<ElemType>(params->output_grad_data);
  stream->ThenBatchNormalizationBackward(
      se::DeviceMemory<ElemType>(params->grad_output),     //
      se::DeviceMemory<ElemType>(params->common.operand),  //
      params->common.scale,                                //
      params->mean,                                        //
      params->inv_stddev,                                  //
      params->common.operand_desc,                         //
      params->common.scale_offset_desc,                    //
      params->common.epsilon,                              //
      &output_grad_data,                                   //
      &params->output_grad_scale,                          //
      &params->output_grad_offset,                         //
      /*reserve_space_allocator=*/nullptr,                 //
      /*workspace_allocator=*/nullptr);
}

}  // namespace

CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction* instr,
                                             float epsilon,
                                             int64_t feature_index) {
  CudnnBatchNormConfig config;

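  // Training and backward batchnorm custom-calls return a tuple; element 0 is
  // the data output, whose shape determines the cudnn descriptors.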
  config.output_shape = instr->shape().IsTuple()
                            ? instr->shape().tuple_shapes(0)
                            : instr->shape();
  config.output_type = config.output_shape.element_type();
  config.epsilon = epsilon;
  config.feature_index = feature_index;
  return config;
}

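// Inference entry point: normalizes `operand` with the provided running
// `mean` and `variance`, applies `scale` and `offset`, and writes the result
// to `output`.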
Status RunCudnnBatchNormForwardInference(
    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
    se::DeviceMemoryBase output, se::DeviceMemory<float> scale,
    se::DeviceMemory<float> offset, se::DeviceMemory<float> mean,
    se::DeviceMemory<float> variance, se::Stream* stream) {
  CudnnBatchNormForwardInferenceParams inference_params;
  AssignCommonParams(config, &inference_params.common, operand, scale);
  inference_params.offset = offset;
  inference_params.mean = mean;
  inference_params.variance = variance;
  inference_params.output = output;

  switch (config.output_type) {
    case F16:
      RunCudnnBatchNormForwardInferenceImpl<Eigen::half>(&inference_params,
                                                         stream);
      break;
    case F32:
      RunCudnnBatchNormForwardInferenceImpl<float>(&inference_params, stream);
      break;
    default:
      return Unimplemented(
          "Primitive type %s not implemented for batchnorm forward inference",
          primitive_util::LowercasePrimitiveTypeName(config.output_type)
              .c_str());
  }
  return Status::OK();
}

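// Training entry point: normalizes `operand` with statistics computed from
// the batch, writing the normalized data to `output_data` and the batch mean
// and inverse stddev to `output_mean` and `output_inv_stddev`.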
Status RunCudnnBatchNormForwardTraining(
    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
    se::DeviceMemoryBase output_data, se::DeviceMemory<float> output_mean,
    se::DeviceMemory<float> output_inv_stddev, se::DeviceMemory<float> scale,
    se::DeviceMemory<float> offset, se::Stream* stream) {
  CudnnBatchNormForwardTrainingParams forward_params;
  AssignCommonParams(config, &forward_params.common, operand, scale);
  forward_params.offset = offset;
  forward_params.output_data = output_data;
  forward_params.output_mean = output_mean;
  forward_params.output_inv_stddev = output_inv_stddev;

  switch (config.output_type) {
    case F16:
      RunCudnnBatchNormForwardTrainingImpl<Eigen::half>(&forward_params,
                                                        stream);
      break;
    case F32:
      RunCudnnBatchNormForwardTrainingImpl<float>(&forward_params, stream);
      break;
    default:
      return Unimplemented(
          "Primitive type %s not implemented for batchnorm forward training",
          primitive_util::LowercasePrimitiveTypeName(config.output_type)
              .c_str());
  }
  return Status::OK();
}

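// Backward entry point: given `grad_output` and the `mean` and `inv_stddev`
// saved by the forward pass, computes `output_grad_data`,
// `output_grad_scale`, and `output_grad_offset`.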
Status RunCudnnBatchNormBackward(
    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
    se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output,
    se::DeviceMemory<float> output_grad_scale,
    se::DeviceMemory<float> output_grad_offset, se::DeviceMemory<float> scale,
    se::DeviceMemory<float> mean, se::DeviceMemory<float> inv_stddev,
    se::Stream* stream) {
  CudnnBatchNormBackwardParams backward_params;
  AssignCommonParams(config, &backward_params.common, operand, scale);
  backward_params.output_grad_data = output_grad_data;
  backward_params.grad_output = grad_output;
  backward_params.output_grad_scale = output_grad_scale;
  backward_params.output_grad_offset = output_grad_offset;
  backward_params.mean = mean;
  backward_params.inv_stddev = inv_stddev;

  switch (config.output_type) {
    case F16:
      RunCudnnBatchNormBackwardImpl<Eigen::half>(&backward_params, stream);
      break;
    case F32:
      RunCudnnBatchNormBackwardImpl<float>(&backward_params, stream);
      break;
    default:
      return Unimplemented(
          "Primitive type %s not implemented for batchnorm backward",
          primitive_util::LowercasePrimitiveTypeName(config.output_type)
              .c_str());
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla