1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #if GOOGLE_CUDA
17
18 #define EIGEN_USE_GPU
19
20 #include <stdio.h>
21 #include <iostream>
22
23 #include "tensorflow/core/kernels/avgpooling_op.h"
24
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/tensor_types.h"
27 #include "tensorflow/core/util/cuda_kernel_helper.h"
28
29 namespace tensorflow {
30
31 typedef Eigen::GpuDevice GPUDevice;
32
// Explicitly instantiate the functor::SpatialAvgPooling template on GPUDevice
// for Eigen::half, float and double. The helper macro is local to this spot
// and is undefined again right after its last use.
#define INSTANTIATE_SPATIAL_AVG_POOLING_GPU(T) \
  template struct functor::SpatialAvgPooling<GPUDevice, T>;

INSTANTIATE_SPATIAL_AVG_POOLING_GPU(Eigen::half)
INSTANTIATE_SPATIAL_AVG_POOLING_GPU(float)
INSTANTIATE_SPATIAL_AVG_POOLING_GPU(double)

#undef INSTANTIATE_SPATIAL_AVG_POOLING_GPU
41
// Average-pooling backward kernel for data laid out with the channel index
// varying fastest (N, H, W, C).
//
// Every flat index handled by CUDA_1D_KERNEL_LOOP is decoded into
// (batch, row, column, channel) of bottom_diff. For that element the kernel
// walks all pooled positions (ph, pw) whose window contains it, divides the
// matching top_diff value by the window's area (after clipping the window to
// [0, height) x [0, width)), and writes the accumulated sum to
// bottom_diff[index].
//
//   nthreads                     upper bound of the flat index loop
//   top_diff                     read as [n][ph][pw][c] with extents
//                                pooled_height x pooled_width x channels
//   num                          not read inside the kernel
//   height, width, channels      extents used to decode the flat index
//   pooled_height, pooled_width  extents of the pooled grid
//   kernel_h, kernel_w           pooling window size
//   stride_h, stride_w           pooling window step
//   pad_t, pad_l                 offsets added to row / column before the
//                                window lookup and subtracted from the window
//                                origin
//   bottom_diff                  output, one value written per flat index
template <typename dtype>
__global__ void AvePoolBackwardNHWC(const int nthreads,
                                    const dtype* const top_diff, const int num,
                                    const int height, const int width,
                                    const int channels, const int pooled_height,
                                    const int pooled_width, const int kernel_h,
                                    const int kernel_w, const int stride_h,
                                    const int stride_w, const int pad_t,
                                    const int pad_l, dtype* const bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // Split the flat index: channel is the fastest-varying component, the
    // remaining quotient enumerates (batch, row, column).
    const int pixel = index / channels;
    const int channel = index % channels;
    const int batch = pixel / width / height;

    // Row / column of this element shifted by the top / left padding.
    const int padded_row = (pixel / width) % height + pad_t;
    const int padded_col = pixel % width + pad_l;

    // Half-open ranges [first, last) of pooled rows and columns whose windows
    // contain this element.
    int ph_first = 0;
    if (padded_row >= kernel_h) {
      ph_first = (padded_row - kernel_h) / stride_h + 1;
    }
    int pw_first = 0;
    if (padded_col >= kernel_w) {
      pw_first = (padded_col - kernel_w) / stride_w + 1;
    }
    const int ph_last = min(padded_row / stride_h + 1, pooled_height);
    const int pw_last = min(padded_col / stride_w + 1, pooled_width);

    // Start of the (batch, channel) slice of top_diff; pooled positions are
    // then `channels` elements apart.
    const dtype* const slice_base =
        top_diff + batch * pooled_height * pooled_width * channels + channel;

    dtype accum(0);
    for (int ph = ph_first; ph < ph_last; ++ph) {
      // Number of window rows that fall inside the input for pooled row ph.
      const int win_top = ph * stride_h - pad_t;
      const int win_rows = min(win_top + kernel_h, height) - max(win_top, 0);
      for (int pw = pw_first; pw < pw_last; ++pw) {
        // Number of window columns that fall inside the input for column pw.
        const int win_left = pw * stride_w - pad_l;
        const int win_cols = min(win_left + kernel_w, width) - max(win_left, 0);
        const int window_area = win_rows * win_cols;
        accum += slice_base[(ph * pooled_width + pw) * channels] /
                 dtype(window_area);
      }
    }
    bottom_diff[index] = accum;
  }
}
82
83 template <typename T>
RunAvePoolBackwardNHWC(const T * const top_diff,const int num,const int height,const int width,const int channels,const int pooled_height,const int pooled_width,const int kernel_h,const int kernel_w,const int stride_h,const int stride_w,const int pad_t,const int pad_l,T * const bottom_diff,const GPUDevice & d)84 bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num,
85 const int height, const int width,
86 const int channels, const int pooled_height,
87 const int pooled_width, const int kernel_h,
88 const int kernel_w, const int stride_h,
89 const int stride_w, const int pad_t,
90 const int pad_l, T* const bottom_diff,
91 const GPUDevice& d) {
92 int x_size = num * height * width * channels;
93 CudaLaunchConfig config = GetCudaLaunchConfig(x_size, d);
94 TF_CHECK_OK(CudaLaunchKernel(
95 AvePoolBackwardNHWC<T>, config.block_count, config.thread_per_block, 0,
96 d.stream(), config.virtual_thread_count, top_diff, num, height, width,
97 channels, pooled_height, pooled_width, kernel_h, kernel_w, stride_h,
98 stride_w, pad_t, pad_t, bottom_diff));
99
100 return d.ok();
101 }
102
// Explicit instantiations of RunAvePoolBackwardNHWC for double, float and
// Eigen::half. The three declarations differ only in the element type, so they
// are stamped out from one local macro that is undefined after its last use.
#define INSTANTIATE_RUN_AVE_POOL_BACKWARD_NHWC(T)                          \
  template bool RunAvePoolBackwardNHWC<T>(                                 \
      const T* const top_diff, const int num, const int height,            \
      const int width, const int channels, const int pooled_height,        \
      const int pooled_width, const int kernel_h, const int kernel_w,      \
      const int stride_h, const int stride_w, const int pad_t,             \
      const int pad_l, T* const bottom_diff, const GPUDevice& d);

INSTANTIATE_RUN_AVE_POOL_BACKWARD_NHWC(double)
INSTANTIATE_RUN_AVE_POOL_BACKWARD_NHWC(float)
INSTANTIATE_RUN_AVE_POOL_BACKWARD_NHWC(Eigen::half)

#undef INSTANTIATE_RUN_AVE_POOL_BACKWARD_NHWC
121
122 } // end namespace tensorflow
123
124 #endif // GOOGLE_CUDA
125