/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include <algorithm>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bias_op.h"
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

// There are no native fp16 atomics (we simulate them using 32-bit atomics),
// so fp16 sums are done in fp32 internally. (We don't have a lot of shared
// memory traffic; BiasGradNCHW_SharedAtomics in particular works almost
// entirely on a local variable.)
template <class T>
struct AccumulatorType {
  typedef T type;
};

template <>
struct AccumulatorType<Eigen::half> {
  typedef float type;
};
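// Note: the shared-atomics gradient kernels below accumulate Eigen::half
// values in float (AccumulatorType<Eigen::half>::type) and convert back to
// half only when the result is written out to global memory.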

// Definition of the GPU implementations declared in bias_op.cc.

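// In NHWC the channel is the innermost (fastest-varying) dimension, so for a
// flattened element index the bias slot is simply index % bias_size. For
// example, with bias_size == 4, flattened index 10 maps to channel 10 % 4 == 2.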
template <typename T>
__global__ void BiasNHWCKernel(int32 nthreads, const T* __restrict__ input,
                               const T* __restrict__ bias,
                               T* __restrict__ output, int32 bias_size) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int32 bias_offset = index % bias_size;
    output[index] = ldg(input + index) + ldg(bias + bias_offset);
  }
}

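// In NCHW the channel is the second dimension, so each channel owns a
// contiguous block of image_size elements: index / image_size enumerates
// (batch, channel) pairs, and taking that value modulo bias_size recovers the
// channel.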
template <typename T>
__global__ void BiasNCHWKernel(int32 nthreads, const T* __restrict__ input,
                               const T* __restrict__ bias,
                               T* __restrict__ output, int32 bias_size,
                               int32 image_size) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int32 index2 = index / image_size;
    int32 bias_offset = index2 % bias_size;
    output[index] = ldg(input + index) + ldg(bias + bias_offset);
  }
}

// Add "bias" to "input", broadcasting it on all dimensions but the bias
// dimension.
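// For example, with FORMAT_NHWC, an input of shape
// [batch, height, width, depth, channel] and a bias of shape [channel],
// output[n, h, w, d, c] = input[n, h, w, d, c] + bias[c].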
template <typename T>
void BiasGPU<T>::compute(const GPUDevice& d, const T* input, const T* bias,
                         T* output, int32 batch, int32 height, int32 width,
                         int depth, int32 channel, TensorFormat data_format) {
  const int32 bias_size = channel;
  const int32 image_size = height * width * depth;
  const int32 total_count = batch * bias_size * image_size;
  if (total_count == 0) {
    return;
  }
  if (data_format == FORMAT_NHWC) {
    GpuLaunchConfig config =
        GetGpuLaunchConfig(total_count, d, BiasNHWCKernel<T>, 0, 0);
    TF_CHECK_OK(GpuLaunchKernel(BiasNHWCKernel<T>, config.block_count,
                                config.thread_per_block, 0, d.stream(),
                                config.virtual_thread_count, input, bias,
                                output, bias_size));
  } else {
    GpuLaunchConfig config =
        GetGpuLaunchConfig(total_count, d, BiasNCHWKernel<T>, 0, 0);
    TF_CHECK_OK(GpuLaunchKernel(BiasNCHWKernel<T>, config.block_count,
                                config.thread_per_block, 0, d.stream(),
                                config.virtual_thread_count, input, bias,
                                output, bias_size, image_size));
  }
}

// A naive implementation that works in all cases.
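// Every element issues one global GpuAtomicAdd into its bias slot, which is
// correct for any bias_size but can be slow under heavy atomic contention; it
// serves as the fallback when shared memory is insufficient (see below).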
template <typename T>
__global__ void BiasGradNHWC_Naive(int32 nthreads,
                                   const T* __restrict__ output_backprop,
                                   T* __restrict__ bias_backprop,
                                   int32 bias_size) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int32 bias_offset = index % bias_size;
    GpuAtomicAdd(bias_backprop + bias_offset, ldg(output_backprop + index));
  }
}

// A naive implementation that works in all cases.
template <typename T>
__global__ void BiasGradNCHW_Naive(int32 nthreads,
                                   const T* __restrict__ output_backprop,
                                   T* __restrict__ bias_backprop,
                                   int32 bias_size, int32 image_size) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int32 index2 = index / image_size;
    int32 bias_offset = index2 % bias_size;
    GpuAtomicAdd(bias_backprop + bias_offset, ldg(output_backprop + index));
  }
}

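// Each block accumulates per-channel partial sums in dynamic shared memory
// (bias_size accumulators of AccumulatorType<T>::type) and then flushes them
// to global memory with one atomic add per channel per block, which greatly
// reduces contention on bias_backprop compared to the naive kernel.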
template <typename T>
__global__ void BiasGradNHWC_SharedAtomics(
    int32 nthreads, const T* __restrict__ output_backprop,
    T* __restrict__ bias_backprop, int32 bias_size) {
  typedef typename AccumulatorType<T>::type AccT;
  GPU_DYNAMIC_SHARED_MEM_DECL(8, char, s_buf);
  AccT* s_data = reinterpret_cast<AccT*>(s_buf);
  for (int32 index = threadIdx.x; index < bias_size; index += blockDim.x) {
    s_data[index] = AccT(0);
  }
  __syncthreads();

  for (int32 index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int32 bias_offset = index % bias_size;
    GpuAtomicAdd(s_data + bias_offset, AccT(ldg(output_backprop + index)));
  }
  __syncthreads();

  for (int32 index = threadIdx.x; index < bias_size; index += blockDim.x) {
    GpuAtomicAdd(bias_backprop + index, T(s_data[index]));
  }
}

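// Each block handles one (bias channel, group) pair: blockIdx.x % bias_size
// selects the channel and blockIdx.x / bias_size selects the group of elements
// it strides over. Threads first accumulate into a private register, combine
// their partial sums through a 32-slot shared buffer and a warp-level
// reduction, and finally issue a single atomic add per block into
// bias_backprop.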
template <typename T>
__global__ void BiasGradNCHW_SharedAtomics(
    const T* __restrict__ output_backprop, T* __restrict__ bias_backprop,
    int32 batch, int32 bias_size, int32 image_size, int group_size) {
  // Initialize the shared memory.
  typedef typename AccumulatorType<T>::type AccT;
  const int32 kSDataSize = 32;
  __shared__ AccT s_data[kSDataSize];
  for (int32 index = threadIdx.x; index < kSDataSize; index += blockDim.x) {
    s_data[index] = AccT(0);
  }
  __syncthreads();

  // Accumulate all the values within this thread. They all have the same bias
  // index.
  int32 bias_index = blockIdx.x % bias_size;
  int32 group_index = blockIdx.x / bias_size;
  int32 total_count = batch * image_size;
  AccT sum(0);
  for (int32 index = group_index * blockDim.x + threadIdx.x;
       index < total_count; index += blockDim.x * group_size) {
    int32 image_offset = index % image_size;
    int32 batch = index / image_size;
    T val = ldg(output_backprop +
                (batch * bias_size + bias_index) * image_size + image_offset);
    sum += AccT(val);
  }

  // Write this thread's accumulated sum to shared memory. Each thread in a
  // warp writes to a distinct one of the 32 slots, which avoids bank
  // conflicts.
  int bias_offset = threadIdx.x % 32;
  GpuAtomicAdd(s_data + bias_offset, sum);
  __syncthreads();

  // Accumulate the results in shared memory into the first element.
  // No __syncthreads() is needed since the reduction happens within a single
  // warp.
  int32 thread_index = threadIdx.x;
#if GOOGLE_CUDA
  if (thread_index < 32) {
    AccT data = s_data[thread_index];
    for (int32 delta = warpSize / 2; delta > 0; delta /= 2) {
      data += GpuShuffleXorSync(kCudaWarpAll, data, delta);
    }
    if (thread_index == 0) {
      GpuAtomicAdd(bias_backprop + bias_index, T(data));
    }
  }
#elif TENSORFLOW_USE_ROCM
  if (thread_index < 16) s_data[thread_index] += s_data[thread_index + 16];
  if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
  if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
  if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
  if (thread_index < 1) s_data[thread_index] += s_data[thread_index + 1];

  // The first thread writes out the accumulated result to the global location.
  if (thread_index == 0) {
    GpuAtomicAdd(bias_backprop + bias_index, T(s_data[0]));
  }
#endif
}

template <typename T>
void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
                             T* bias_backprop, int32 batch, int32 height,
                             int32 width, int32 depth, int32 channel,
                             TensorFormat data_format) {
  const int32 bias_size = channel;
  const int32 image_size = height * width * depth;
  const int32 total_count = batch * bias_size * image_size;
  if (total_count == 0) {
    return;
  }
  static constexpr int32 kWarpSize = 32;
  GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d);

  const int max_shared_memory_size = d.sharedMemPerBlock() / 2;
  int32 shared_memory_size = 0;
  if (data_format == FORMAT_NHWC) {
    shared_memory_size = bias_size * sizeof(typename AccumulatorType<T>::type);
  }
  // Check if we have enough shared memory.
  if (shared_memory_size <= max_shared_memory_size) {
    if (data_format == FORMAT_NHWC) {
      TF_CHECK_OK(GpuLaunchKernel(BiasGradNHWC_SharedAtomics<T>,
                                  config.block_count, config.thread_per_block,
                                  shared_memory_size, d.stream(), total_count,
                                  output_backprop, bias_backprop, bias_size));
    } else {
      // Round up the block count to a multiple of bias_size so that each bias
      // channel is processed by group_size blocks.
      int group_size = (config.block_count + bias_size - 1) / bias_size;
      config.block_count = group_size * bias_size;
      if (config.thread_per_block < kWarpSize) {
        config.thread_per_block = kWarpSize;
      }
      TF_CHECK_OK(GpuLaunchKernel(BiasGradNCHW_SharedAtomics<T>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), output_backprop, bias_backprop,
                                  batch, bias_size, image_size, group_size));
    }
  } else {
    // Note that even if we don't have enough shared memory to fit the entire
    // output block, it is possible to process one group of elements at a time.
    // But for now, we simply fall back to the naive implementation.
    if (data_format == FORMAT_NHWC) {
      TF_CHECK_OK(GpuLaunchKernel(
          BiasGradNHWC_Naive<T>, config.block_count, config.thread_per_block, 0,
          d.stream(), total_count, output_backprop, bias_backprop, bias_size));
    } else {
      TF_CHECK_OK(GpuLaunchKernel(BiasGradNCHW_Naive<T>, config.block_count,
                                  config.thread_per_block, 0, d.stream(),
                                  total_count, output_backprop, bias_backprop,
                                  bias_size, image_size));
    }
  }
}

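// Sums each row of a [rows, cols] matrix into a single value, producing an
// output of length `rows`; the reduction is dispatched to the shared
// gpuprim-based ReduceImpl with the column axis (constants.kOne) as the
// reduction axis.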
template <typename T>
void BiasGradGPU<T>::DoRowReduction(OpKernelContext* context, T* output,
                                    const T* input, int rows, int cols) {
  typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
  Constants<GPUDevice> constants;
  gpuprim::Sum op;
  functor::ReduceImpl<T, gpuprim::Sum, T*, const T*, ReductionAxes>(
      context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
}

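// Sums each column of a [rows, cols] matrix, producing an output of length
// `cols`; here the reduction axis is the row axis (constants.kZero).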
template <typename T>
void BiasGradGPU<T>::DoColReduction(OpKernelContext* context, T* output,
                                    const T* input, int rows, int cols) {
  typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
  Constants<GPUDevice> constants;
  gpuprim::Sum op;
  functor::ReduceImpl<T, gpuprim::Sum, T*, const T*, ReductionAxes>(
      context, output, input, 2, rows, cols, 1, 1, constants.kZero, op);
}

#define DEFINE_GPU_SPECS(T)   \
  template struct BiasGPU<T>; \
  template struct BiasGradGPU<T>;

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);

// No BiasGrad kernel for int32.
template struct BiasGPU<int32>;

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM