• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
17 
18 #define EIGEN_USE_GPU
19 
20 #include <stdio.h>
21 
22 #include <iostream>
23 
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/tensor_types.h"
26 #include "tensorflow/core/kernels/avgpooling_op.h"
27 #include "tensorflow/core/util/gpu_kernel_helper.h"
28 
29 namespace tensorflow {
30 
// Shorthand for the Eigen GPU device type used by every definition below.
typedef Eigen::GpuDevice GPUDevice;

// Emits an explicit instantiation definition of
// functor::SpatialAvgPooling<GPUDevice, T> (presumably declared in
// avgpooling_op.h -- TODO confirm) for one element type.
// Each type must be listed exactly once: a second explicit instantiation
// definition of the same specialization is ill-formed.
#define DEFINE_GPU_KERNELS(T) \
  template struct functor::SpatialAvgPooling<GPUDevice, T>;

DEFINE_GPU_KERNELS(Eigen::half)
DEFINE_GPU_KERNELS(float)
DEFINE_GPU_KERNELS(double)

#undef DEFINE_GPU_KERNELS
41 
// Backward pass of average pooling for NHWC data.
//
// Each loop iteration produces one element of `bottom_diff`. The flat index
// is decomposed as [n][y][x][c] over (num, height, width, channels). The
// element's gradient is the sum, over every pooled output window that covers
// it, of that window's upstream gradient divided by the window's area. The
// area is clipped to the unpadded input, so padding positions are not counted.
//
//   nthreads       number of bottom_diff elements to produce (the launcher
//                  passes config.virtual_thread_count).
//   top_diff       upstream gradient, addressed as [n][row][col][c] with
//                  pooled_height x pooled_width spatial extent.
//   num            batch size; not read by the kernel body.
//   height, width, channels          unpadded input extent.
//   pooled_height, pooled_width      pooled output extent.
//   kernel_h, kernel_w, stride_h, stride_w   window geometry.
//   pad_t, pad_l   top / left padding of the input.
//   bottom_diff    output gradient; each visited element is overwritten.
//
// Iteration over [0, nthreads) is delegated to GPU_1D_KERNEL_LOOP (a project
// macro, presumably a grid-stride loop -- confirm in gpu_kernel_helper.h).
template <typename dtype>
__global__ void AvePoolBackwardNHWC(
    const int nthreads, const dtype* const __restrict__ top_diff, const int num,
    const int height, const int width, const int channels,
    const int pooled_height, const int pooled_width, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w, const int pad_t,
    const int pad_l, dtype* const __restrict__ bottom_diff) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    // Peel the flat NHWC index apart, innermost dimension first.
    int rest = index;
    const int c = rest % channels;
    rest /= channels;
    // x / y are expressed in the padded input coordinate frame.
    const int x = rest % width + pad_l;
    rest /= width;
    const int y = rest % height + pad_t;
    const int n = rest / height;

    // Pooled rows/cols whose window [p * stride, p * stride + kernel) contains
    // (y, x), clamped to the pooled output extent.
    const int row_begin = (y < kernel_h) ? 0 : (y - kernel_h) / stride_h + 1;
    const int row_end = min(y / stride_h + 1, pooled_height);
    const int col_begin = (x < kernel_w) ? 0 : (x - kernel_w) / stride_w + 1;
    const int col_end = min(x / stride_w + 1, pooled_width);

    // Start of this (batch, channel) slice of top_diff; neighbouring pooled
    // positions are `channels` elements apart.
    const dtype* const grad_plane =
        top_diff + n * pooled_height * pooled_width * channels + c;

    dtype acc(0);
    for (int row = row_begin; row < row_end; ++row) {
      // Vertical extent of this window, clipped to the unpadded input.
      const int top = max(row * stride_h - pad_t, 0);
      const int bottom = min(row * stride_h - pad_t + kernel_h, height);
      for (int col = col_begin; col < col_end; ++col) {
        // Horizontal extent of this window, clipped to the unpadded input.
        const int left = max(col * stride_w - pad_l, 0);
        const int right = min(col * stride_w - pad_l + kernel_w, width);
        const int window_area = (bottom - top) * (right - left);
        acc += grad_plane[(row * pooled_width + col) * channels] /
               dtype(window_area);
      }
    }
    bottom_diff[index] = acc;
  }
}
80 
// Enqueues AvePoolBackwardNHWC<T> on d.stream() to compute the input gradient
// of an NHWC average pooling op.
//
//   top_diff       upstream gradient, addressed by the kernel as
//                  [num][pooled_height][pooled_width][channels].
//   num, height, width, channels     shape of bottom_diff.
//   pooled_height, pooled_width      spatial extent of top_diff.
//   kernel_h, kernel_w, stride_h, stride_w   pooling window geometry.
//   pad_t, pad_l   top and left padding of the input.
//   bottom_diff    receives num * height * width * channels gradient values.
//   d              device whose stream the kernel is launched on.
//
// Returns d.ok(). The launch status is checked with TF_CHECK_OK; the kernel
// itself is not synchronized here.
template <typename T>
bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num,
                            const int height, const int width,
                            const int channels, const int pooled_height,
                            const int pooled_width, const int kernel_h,
                            const int kernel_w, const int stride_h,
                            const int stride_w, const int pad_t,
                            const int pad_l, T* const bottom_diff,
                            const GPUDevice& d) {
  // One logical thread per bottom_diff element.
  // NOTE(review): computed in int, matching the kernel's int indexing; inputs
  // with more than INT_MAX elements are not supported by this path.
  const int x_size = num * height * width * channels;
  // Empty gradient: nothing to write. Skip the launch instead of asking
  // GetGpuLaunchConfig / the driver to handle a zero-element grid.
  if (x_size == 0) {
    return d.ok();
  }
  GpuLaunchConfig config = GetGpuLaunchConfig(x_size, d);
  // The kernel's last padding argument is the LEFT padding. Forwarding pad_t
  // there (as before) silently produced wrong gradients whenever
  // pad_t != pad_l.
  TF_CHECK_OK(GpuLaunchKernel(
      AvePoolBackwardNHWC<T>, config.block_count, config.thread_per_block, 0,
      d.stream(), config.virtual_thread_count, top_diff, num, height, width,
      channels, pooled_height, pooled_width, kernel_h, kernel_w, stride_h,
      stride_w, pad_t, pad_l, bottom_diff));

  return d.ok();
}
100 
// Explicit instantiation definitions of RunAvePoolBackwardNHWC for the element
// types supported on GPU: double, float and Eigen::half. The template body
// lives in this file, so these are presumably what callers in other
// translation units link against -- confirm against avgpooling_op.h.
#define INSTANTIATE_RUN_AVE_POOL_BACKWARD_NHWC(T)                          \
  template bool RunAvePoolBackwardNHWC<T>(                                \
      const T* const top_diff, const int num, const int height,           \
      const int width, const int channels, const int pooled_height,       \
      const int pooled_width, const int kernel_h, const int kernel_w,     \
      const int stride_h, const int stride_w, const int pad_t,            \
      const int pad_l, T* const bottom_diff, const GPUDevice& d);

INSTANTIATE_RUN_AVE_POOL_BACKWARD_NHWC(double)
INSTANTIATE_RUN_AVE_POOL_BACKWARD_NHWC(float)
INSTANTIATE_RUN_AVE_POOL_BACKWARD_NHWC(Eigen::half)

#undef INSTANTIATE_RUN_AVE_POOL_BACKWARD_NHWC
119 
120 }  // end namespace tensorflow
121 
122 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
123