• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include <limits>
19 #include "maxpool_with_argmax_v2_impl.cuh"
20 #include "include/cuda_fp16.h"
21 
/**
 * Writes std::numeric_limits<T>::lowest() into *init_val.
 * That is the smallest finite value of T (0 for unsigned types).
 * The pooling kernel uses it to seed its running maximum before scanning a window.
 */
template <typename T>
__device__ __forceinline__ void NumericLimits(T *init_val) {
  const T lowest_value = std::numeric_limits<T>::lowest();
  init_val[0] = lowest_value;
}
26 
// half overload: std::numeric_limits is not specialized for half, so the generic
// template cannot supply its lowest value. -65504 is the most negative finite
// IEEE binary16 value and converts to half exactly.
__device__ __forceinline__ void NumericLimits(half *init_val) {
  const half lowest_half = __int2half_rd(-65504);
  init_val[0] = lowest_half;
}
29 
/**
 * 2-D max pooling that also records where each maximum came from.
 *
 * Layout (as implied by the indexing below):
 *   - input is read as a contiguous [inputN, inputC, inputH, inputW] array;
 *   - output and index are written as contiguous [inputN, inputC, outH, outW] arrays.
 *
 * Work distribution: each thread starts at its global 1-D id and advances by
 * blockDim.x * gridDim.x (grid-stride loop), so any 1-D launch covers all outputs.
 *
 * index[pos] is the flattened position h * inputW + w of the selected element
 * inside its own (n, c) input plane. If no element of the window compares
 * greater than the NumericLimits seed (for example, every element equals the
 * seed), the plane-relative index of the window's first tap is reported.
 * Because the comparison is `>`, NaN inputs are never selected.
 *
 * Preconditions (not checked here): ksize*, strides*, dilation* >= 1.
 */
template <typename T, typename S>
__global__ void MaxPoolWithArgmaxV2(const T *input, T *output, S *index, const int inputN, const int inputC,
                                    const int inputH, const int inputW, const int ksizeH, const int ksizeW,
                                    const int stridesH, const int stridesW, const int padsH, const int padsW,
                                    const int dilationH, const int dilationW, const int outH, const int outW) {
  // Sizes and offsets in 64-bit: the int products overflow for large tensors.
  const size_t out_plane = static_cast<size_t>(outH) * outW;
  const size_t in_plane = static_cast<size_t>(inputH) * inputW;
  const size_t total = static_cast<size_t>(inputN) * inputC * out_plane;
  const size_t grid_stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < total; pos += grid_stride) {
    // pos / out_plane == pos_n * inputC + pos_c, i.e. the (n, c) plane number.
    const size_t plane = pos / out_plane;
    const int pos_h = static_cast<int>(pos / outW % outH);
    const int pos_w = static_cast<int>(pos % outW);
    int start_h = pos_h * stridesH - padsH;
    int start_w = pos_w * stridesW - padsW;
    const int end_h = min(start_h + (ksizeH - 1) * dilationH + 1, inputH);
    const int end_w = min(start_w + (ksizeW - 1) * dilationW + 1, inputW);
    // Skip taps that land in the top/left padding: advance to the first dilated
    // position >= 0. Integer ceil-division gives the same result as the
    // previous ceil(double) form without double-precision math.
    if (start_h < 0) {
      start_h += (-start_h + dilationH - 1) / dilationH * dilationH;
    }
    if (start_w < 0) {
      start_w += (-start_w + dilationW - 1) / dilationW * dilationW;
    }
    const T *plane_input = input + plane * in_plane;
    // Default index uses the same plane-relative convention as the updates
    // below. Previously it wrongly included the channel offset.
    S max_idx = static_cast<S>(start_h) * inputW + start_w;
    T max_data;
    NumericLimits(&max_data);
    for (int cur_h = start_h; cur_h < end_h; cur_h += dilationH) {
      for (int cur_w = start_w; cur_w < end_w; cur_w += dilationW) {
        const S input_idx = static_cast<S>(cur_h) * inputW + cur_w;
        const T input_data = plane_input[input_idx];
        if (input_data > max_data) {
          max_idx = input_idx;
          max_data = input_data;
        }
      }
    }
    output[pos] = max_data;
    index[pos] = max_idx;
  }
}
70 
/**
 * Host launcher: enqueues MaxPoolWithArgmaxV2 on `cuda_stream`.
 * The grid is sized for n * c * out_h * out_w output elements.
 *
 * @param input      device pointer to the pooled input (read as [n, c, h, w]).
 * @param output     device pointer receiving the maxima ([n, c, out_h, out_w]).
 * @param index      device pointer receiving the plane-relative argmax indices.
 * @param device_id  device used by CUDA_BLOCKS / CUDA_THREADS to size the launch.
 * @param cuda_stream stream the kernel is enqueued on; the call does not synchronize.
 * @return the value of GetCudaStatus() after the launch (defined elsewhere;
 *         presumably the last CUDA error, so TODO confirm).
 */
template <typename T, typename S>
cudaError_t CalMaxPoolWithArgmaxV2(const T *input, const int n, const int c, const int h, const int w,
                                   const int ksize_h, const int ksize_w, const int strides_h, const int strides_w,
                                   const int pads_h, const int pads_w, const int dilation_h, const int dilation_w,
                                   const int out_h, const int out_w, T *output, S *index, const uint32_t &device_id,
                                   cudaStream_t cuda_stream) {
  // Count outputs in size_t: the int product n * c * out_h * out_w can exceed
  // INT_MAX (signed overflow is UB). The kernel's grid-stride loop keeps the
  // result correct even if CUDA_BLOCKS caps the grid.
  const size_t total_outputs = static_cast<size_t>(n) * c * out_h * out_w;
  MaxPoolWithArgmaxV2<<<CUDA_BLOCKS(device_id, total_outputs), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    input, output, index, n, c, h, w, ksize_h, ksize_w, strides_h, strides_w, pads_h, pads_w, dilation_h, dilation_w,
    out_h, out_w);
  return GetCudaStatus();
}
82 
// Explicit instantiations of CalMaxPoolWithArgmaxV2 for every supported value
// type T. Each T is paired first with int32_t and then with int64_t index
// types S. The helper macro only stamps out the declaration and is undefined
// immediately afterwards.
#define MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(T, S)                                                                 \
  template CUDA_LIB_EXPORT cudaError_t CalMaxPoolWithArgmaxV2<T, S>(                                            \
    const T *input, const int n, const int c, const int h, const int w, const int ksize_h, const int ksize_w,   \
    const int strides_h, const int strides_w, const int pads_h, const int pads_w, const int dilation_h,         \
    const int dilation_w, const int out_h, const int out_w, T *output, S *index, const uint32_t &device_id,     \
    cudaStream_t cuda_stream)

// 32-bit argmax indices.
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(half, int32_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(float, int32_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(double, int32_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(int8_t, int32_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(int16_t, int32_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(int32_t, int32_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(int64_t, int32_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(uint8_t, int32_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(uint16_t, int32_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(uint32_t, int32_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(uint64_t, int32_t);

// 64-bit argmax indices.
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(half, int64_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(float, int64_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(double, int64_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(int8_t, int64_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(int16_t, int64_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(int32_t, int64_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(int64_t, int64_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(uint8_t, int64_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(uint16_t, int64_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(uint32_t, int64_t);
MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE(uint64_t, int64_t);

#undef MAXPOOL_WITH_ARGMAX_V2_INSTANTIATE
214