/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {
namespace {

struct CudnnBatchNormParamsCommon {
  se::DeviceMemoryBase operand;
  se::dnn::BatchDescriptor operand_desc;
  se::dnn::BatchDescriptor scale_offset_desc;
  se::DeviceMemory<float> scale;
  float epsilon;
};

struct CudnnBatchNormForwardInferenceParams {
  CudnnBatchNormParamsCommon common;
  se::DeviceMemoryBase output;
  se::DeviceMemory<float> offset;
  se::DeviceMemory<float> mean;
  se::DeviceMemory<float> variance;
};

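// Note: in the training-forward params below, output_mean and
// output_inv_stddev receive the statistics cudnn saves during the forward
// pass: the batch mean and the inverse stddev, 1/sqrt(variance + epsilon).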
struct CudnnBatchNormForwardTrainingParams {
  CudnnBatchNormParamsCommon common;
  se::DeviceMemoryBase output_data;
  se::DeviceMemory<float> offset;
  se::DeviceMemory<float> output_mean;
  se::DeviceMemory<float> output_inv_stddev;
};

struct CudnnBatchNormBackwardParams {
  CudnnBatchNormParamsCommon common;
  se::DeviceMemoryBase output_grad_data;
  se::DeviceMemoryBase grad_output;
  se::DeviceMemory<float> output_grad_scale;
  se::DeviceMemory<float> output_grad_offset;
  se::DeviceMemory<float> mean;
  se::DeviceMemory<float> inv_stddev;
};

struct DnnBatchDescriptors {
  se::dnn::BatchDescriptor input_desc;
  se::dnn::BatchDescriptor scale_offset_desc;
};

DnnBatchDescriptors MakeBatchNormDescriptors(const Shape& shape,
                                             int64_t feature_index) {
  std::vector<int64_t> logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(shape.layout());

  auto physical_dim_size = [&](int64_t physical_dim) {
    return shape.dimensions(LayoutUtil::Major(shape.layout(), physical_dim));
  };

  // Batchnorm only cares about the location of the depth (aka "feature") dim.
  // The other dims are all treated the same.  Thus we can use the kBatchDepthYX
  // cudnn layout for any XLA shape+layout, even XLA shapes that don't have
  // exactly 4 dimensions: We put everything that comes before the feature dim
  // into "batch", and everything that comes after the feature dim into "Y".
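  // E.g. (an illustrative walkthrough): for f32[2,3,4,5] with layout
  // {3,2,1,0} (row-major) and feature_index=1, this yields count (batch) = 2,
  // feature_map_count = 3, height (y) = 4 * 5 = 20, and width = 1.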
  int64_t batch_size = 1;
  int64_t y_size = 1;
  int64_t physical_dim;
  for (physical_dim = 0; physical_dim != logical_to_physical[feature_index];
       ++physical_dim) {
    CHECK_LT(physical_dim, shape.dimensions_size());
    batch_size *= physical_dim_size(physical_dim);
  }
  ++physical_dim;  // Skip the feature dimension.
  for (; physical_dim < shape.dimensions_size(); ++physical_dim) {
    y_size *= physical_dim_size(physical_dim);
  }

  DnnBatchDescriptors batch_descs;
  batch_descs.input_desc.set_layout(se::dnn::DataLayout::kBatchDepthYX)
      .set_count(batch_size)
      .set_feature_map_count(shape.dimensions(feature_index))
      .set_height(y_size)
      .set_width(1);

  batch_descs.scale_offset_desc.set_layout(se::dnn::DataLayout::kBatchDepthYX)
      .set_feature_map_count(batch_descs.input_desc.feature_map_count())
      .set_height(1)
      .set_width(1)
      .set_count(1);

  return batch_descs;
}

void AssignCommonParams(const CudnnBatchNormConfig& config,
                        CudnnBatchNormParamsCommon* params,
                        const se::DeviceMemoryBase& operand,
                        const se::DeviceMemory<float>& scale) {
  // The BatchNormTraining HLO outputs a tuple of three elements: output data,
  // batch mean, and batch variance.  We want to make our descriptors based on
  // the shape of the output data.  The batchnorm backward call also outputs a
  // tuple of three elements: grad data, grad offset, and grad scale.  We want
  // to make our descriptors based on the shape of the grad data.
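  // E.g. (illustrative): BatchNormTraining on f32[2,3,4,5] with
  // feature_index=1 produces the tuple (f32[2,3,4,5], f32[3], f32[3]), so
  // config.output_shape is the data shape f32[2,3,4,5].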
  const Shape& shape = config.output_shape;
  DnnBatchDescriptors batch_descs =
      MakeBatchNormDescriptors(shape, config.feature_index);
  params->operand_desc = batch_descs.input_desc;
  params->scale_offset_desc = batch_descs.scale_offset_desc;
  params->operand = operand;
  params->scale = scale;
  params->epsilon = config.epsilon;
}

template <typename ElemType>
void RunCudnnBatchNormForwardInferenceImpl(
    CudnnBatchNormForwardInferenceParams* params, se::Stream* stream) {
  se::DeviceMemory<ElemType> null_device_ptr(nullptr);
  auto output_buf = se::DeviceMemory<ElemType>(params->output);
  stream->ThenBatchNormalizationForward(
      se::DeviceMemory<ElemType>(params->common.operand),
      params->common.scale,                                         //
      params->offset,                                               //
      params->mean,                                                 //
      params->variance,                                             //
      /*side_input=*/null_device_ptr, params->common.operand_desc,  //
      params->common.scale_offset_desc,                             //
      static_cast<double>(params->common.epsilon),                  //
      // TODO(b/137108598): Extend method to allow use of non-trivial
      // exponential averaging.
      /*exponential_average_factor=*/1.0,
      se::dnn::ActivationMode::kNone,       //
      &output_buf,                          //
      /*batch_mean=*/nullptr,               //
      /*batch_var=*/nullptr,                //
      /*saved_mean=*/nullptr,               //
      /*saved_inv_var=*/nullptr,            //
      /*is_training=*/false,                //
      /*reserve_space_allocator=*/nullptr,  //
      /*workspace_allocator=*/nullptr);
}

template <typename ElemType>
void RunCudnnBatchNormForwardTrainingImpl(
    CudnnBatchNormForwardTrainingParams* params, se::Stream* stream) {
  se::DeviceMemory<float> null_device_ptr(nullptr);
  se::DeviceMemory<ElemType> null_elem_device_ptr(nullptr);
  auto output_data = se::DeviceMemory<ElemType>(params->output_data);
  stream->ThenBatchNormalizationForward(
      se::DeviceMemory<ElemType>(params->common.operand),
      params->common.scale,                    //
      params->offset,                          //
      /*estimated_mean=*/null_device_ptr,      //
      /*estimated_variance=*/null_device_ptr,  //
      /*side_input=*/null_elem_device_ptr,     //
      params->common.operand_desc,             //
      params->common.scale_offset_desc,        //
      params->common.epsilon,                  //
      // TODO(b/137108598): Extend method to allow use of non-trivial
      // exponential averaging.
      /*exponential_average_factor=*/1.0,
      se::dnn::ActivationMode::kNone,                //
      &output_data,                                  //
      /*batch_mean=*/&null_device_ptr,               //
      /*batch_var=*/&null_device_ptr,                //
      /*saved_mean=*/&params->output_mean,           //
      /*saved_inv_var=*/&params->output_inv_stddev,  //
      /*is_training=*/true,                          //
      /*reserve_space_allocator=*/nullptr,           //
      /*workspace_allocator=*/nullptr);
}

template <typename ElemType>
void RunCudnnBatchNormBackwardImpl(CudnnBatchNormBackwardParams* params,
                                   se::Stream* stream) {
  se::DeviceMemory<float> null_device_ptr(nullptr);
  auto output_grad_data = se::DeviceMemory<ElemType>(params->output_grad_data);
  stream->ThenBatchNormalizationBackward(
      se::DeviceMemory<ElemType>(params->grad_output),     //
      se::DeviceMemory<ElemType>(params->common.operand),  //
      params->common.scale,                                //
      params->mean,                                        //
      params->inv_stddev,                                  //
      params->common.operand_desc,                         //
      params->common.scale_offset_desc,                    //
      params->common.epsilon,                              //
      &output_grad_data,                                   //
      &params->output_grad_scale,                          //
      &params->output_grad_offset,                         //
      /*reserve_space_allocator=*/nullptr,                 //
      /*workspace_allocator=*/nullptr);
}

}  // namespace

CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction* instr,
                                             float epsilon,
                                             int64_t feature_index) {
  CudnnBatchNormConfig config;

  config.output_shape = instr->shape().IsTuple()
                            ? instr->shape().tuple_shapes(0)
                            : instr->shape();
  config.output_type = config.output_shape.element_type();
  config.epsilon = epsilon;
  config.feature_index = feature_index;
  return config;
}
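
// Example usage (an illustrative sketch; in XLA GPU the callers are the cudnn
// batchnorm thunks, which supply the device buffers):
//
//   CudnnBatchNormConfig config = GetCudnnBatchNormConfig(
//       instr, /*epsilon=*/1e-3f, /*feature_index=*/1);
//   TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardInference(
//       config, operand, output, scale, offset, mean, variance, stream));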

Status RunCudnnBatchNormForwardInference(
    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
    se::DeviceMemoryBase output, se::DeviceMemory<float> scale,
    se::DeviceMemory<float> offset, se::DeviceMemory<float> mean,
    se::DeviceMemory<float> variance, se::Stream* stream) {
  CudnnBatchNormForwardInferenceParams inference_params;
  AssignCommonParams(config, &inference_params.common, operand, scale);
  inference_params.offset = offset;
  inference_params.mean = mean;
  inference_params.variance = variance;
  inference_params.output = output;

  switch (config.output_type) {
    case F16:
      RunCudnnBatchNormForwardInferenceImpl<Eigen::half>(&inference_params,
                                                         stream);
      break;
    case F32:
      RunCudnnBatchNormForwardInferenceImpl<float>(&inference_params, stream);
      break;
    default:
      return Unimplemented(
          "Primitive type %s not implemented for batchnorm forward inference",
          primitive_util::LowercasePrimitiveTypeName(config.output_type)
              .c_str());
  }
  return Status::OK();
}

Status RunCudnnBatchNormForwardTraining(
    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
    se::DeviceMemoryBase output_data, se::DeviceMemory<float> output_mean,
    se::DeviceMemory<float> output_inv_stddev, se::DeviceMemory<float> scale,
    se::DeviceMemory<float> offset, se::Stream* stream) {
  CudnnBatchNormForwardTrainingParams forward_params;
  AssignCommonParams(config, &forward_params.common, operand, scale);
  forward_params.offset = offset;
  forward_params.output_data = output_data;
  forward_params.output_mean = output_mean;
  forward_params.output_inv_stddev = output_inv_stddev;

  switch (config.output_type) {
    case F16:
      RunCudnnBatchNormForwardTrainingImpl<Eigen::half>(&forward_params,
                                                        stream);
      break;
    case F32:
      RunCudnnBatchNormForwardTrainingImpl<float>(&forward_params, stream);
      break;
    default:
      return Unimplemented(
          "Primitive type %s not implemented for batchnorm forward training",
          primitive_util::LowercasePrimitiveTypeName(config.output_type)
              .c_str());
  }
  return Status::OK();
}

Status RunCudnnBatchNormBackward(
    const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand,
    se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output,
    se::DeviceMemory<float> output_grad_scale,
    se::DeviceMemory<float> output_grad_offset, se::DeviceMemory<float> scale,
    se::DeviceMemory<float> mean, se::DeviceMemory<float> inv_stddev,
    se::Stream* stream) {
  CudnnBatchNormBackwardParams backward_params;
  AssignCommonParams(config, &backward_params.common, operand, scale);
  backward_params.output_grad_data = output_grad_data;
  backward_params.grad_output = grad_output;
  backward_params.output_grad_scale = output_grad_scale;
  backward_params.output_grad_offset = output_grad_offset;
  backward_params.mean = mean;
  backward_params.inv_stddev = inv_stddev;

  switch (config.output_type) {
    case F16:
      RunCudnnBatchNormBackwardImpl<Eigen::half>(&backward_params, stream);
      break;
    case F32:
      RunCudnnBatchNormBackwardImpl<float>(&backward_params, stream);
      break;
    default:
      return Unimplemented(
          "Primitive type %s not implemented for batchnorm backward",
          primitive_util::LowercasePrimitiveTypeName(config.output_type)
              .c_str());
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla