/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <stdint.h>
#include <stdio.h>
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pad_impl.cuh"

template <typename T>
using Complex = mindspore::utils::Complex<T>;
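
// Pad writes a constant-padded copy of an NCHW tensor: one output element per
// thread over a grid-stride loop, with out-of-image positions set to pad_value
// and the rest gathered from the input. The (num * channels) planes are treated
// as one flat batch dimension.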
// For internal OP use, not user facing
template <typename T>
__global__ void Pad(const size_t size, const T *input, const int num, const int channels, const int old_height,
                    const int old_width, const int padded_height, const int padded_width, const int pad_top,
                    const int pad_left, const float pad_value, T *output) {
  T pad_value_ = static_cast<T>(pad_value);
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    int block_num = pos / padded_width / padded_height;
    const int padded_w = pos % padded_width;
    const int padded_h = pos / padded_width % padded_height;
    if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height ||
        padded_w - pad_left >= old_width) {
      output[pos] = pad_value_;
    } else {
      output[pos] = input[(block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left];
    }
  }
}
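
// PadNHWC is the channels-last counterpart of Pad: channels vary fastest, so each
// spatial coordinate is recovered by first dividing the linear index by channels.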
// For internal OP use, not user facing
template <typename T>
__global__ void PadNHWC(const size_t size, const T *input, const int num, const int old_height, const int old_width,
                        const int channels, const int padded_height, const int padded_width, const int pad_top,
                        const int pad_left, float pad_value, T *output) {
  T pad_value_ = static_cast<T>(pad_value);
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    int block_num = pos / channels / padded_width / padded_height;
    const int padded_w = pos / channels % padded_width;
    const int padded_h = pos / channels / padded_width % padded_height;
    if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height ||
        padded_w - pad_left >= old_width) {
      output[pos] = pad_value_;
    } else {
      output[pos] = input[((block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left) * channels +
                          pos % channels];
    }
  }
}
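
// PadGeneral handles tensors of arbitrary rank: each input element's linear index
// is unraveled against the input shape, shifted by the leading pad of every
// dimension (paddings[2 * i]), and re-raveled with the padded output strides from
// PadInfo. Only positions covered by the input are written, so the pad region of
// the output is left untouched by this kernel.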
template <typename T>
__global__ void PadGeneral(const T *input, T *output, const PadInfo info, const int input_size,
                           const size_t input_rank) {
  const int *input_shape = info.shape;
  const int *strides = info.strides;
  const int *paddings = info.paddings;
  for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < input_size; gt_id += blockDim.x * gridDim.x) {
    int linear_index = gt_id;
    int padded_linear_index = 0;
    for (int i = input_rank - 1; i >= 0; i--) {
      int unravel_dimension = input_shape[i];
      int unraveled_index = linear_index % unravel_dimension;
      padded_linear_index += ((unraveled_index + paddings[2 * i]) * strides[i]);
      linear_index -= unraveled_index;
      linear_index /= unravel_dimension;
    }
    output[padded_linear_index] = input[gt_id];
  }
}
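
// PadGradNHWC propagates gradients through an NHWC pad: each dx element reads the
// corresponding offset element of the padded dy tensor.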
template <typename T>
__global__ void PadGradNHWC(const size_t size, const T *dy, const int num, const int old_height, const int old_width,
                            const int channels, const int padded_height, const int padded_width, const int pad_top,
                            const int pad_left, T *dx) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    int block_num = pos / channels / old_width / old_height;
    const int padded_w = pos / channels % old_width + pad_left;
    const int padded_h = pos / channels / old_width % old_height + pad_top;
    dx[pos] = dy[((block_num * padded_height + padded_h) * padded_width + padded_w) * channels + pos % channels];
  }
}
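
// PadGrad is the NCHW counterpart of PadGradNHWC.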
template <typename T>
__global__ void PadGrad(const size_t size, const T *dy, const int num, const int channels, const int old_height,
                        const int old_width, const int padded_height, const int padded_width, const int pad_top,
                        const int pad_left, T *dx) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    int block_num = pos / old_width / old_height;
    const int padded_w = pos % old_width + pad_left;
    const int padded_h = pos / old_width % old_height + pad_top;
    dx[pos] = dy[(block_num * padded_height + padded_h) * padded_width + padded_w];
  }
}
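
// Pad3d extends Pad to 5-D NCDHW tensors, padding depth, height, and width. The
// precomputed old_dhw/old_hw and padded_dhw/padded_hw products (see CalPad3d
// below) keep the per-element index arithmetic cheap.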
// For internal OP use, not user facing
template <typename T>
__global__ void Pad3d(const size_t size, const T *input, const int num, const int channels, const int old_depth,
                      const int old_height, const int old_width, const int old_dhw, const int old_hw,
                      const int padded_depth, const int padded_height, const int padded_width, const int padded_dhw,
                      const int padded_hw, const int pad_head, const int pad_top, const int pad_left,
                      const float pad_value, T *output) {
  T pad_value_ = static_cast<T>(pad_value);
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    const int pos_d = pos / padded_hw % padded_depth;
    const int pos_h = pos / padded_width % padded_height;
    const int pos_w = pos % padded_width;
    const int block_num = pos / padded_dhw;

    if (pos_d - pad_head < 0 || pos_h - pad_top < 0 || pos_w - pad_left < 0 || pos_d - pad_head >= old_depth ||
        pos_h - pad_top >= old_height || pos_w - pad_left >= old_width) {
      output[pos] = pad_value_;
    } else {
      int index = block_num * old_dhw + old_hw * (pos_d - pad_head) + old_width * (pos_h - pad_top) + pos_w - pad_left;
      output[pos] = input[index];
    }
  }
}
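
// PadGrad3d propagates gradients through a 3-D pad in NCDHW layout.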
template <typename T>
__global__ void PadGrad3d(const size_t size, const T *dy, const int num, const int channels, const int old_depth,
                          const int old_height, const int old_width, const int old_dhw, const int old_hw,
                          const int padded_depth, const int padded_height, const int padded_width, const int padded_dhw,
                          const int padded_hw, const int pad_head, const int pad_top, const int pad_left, T *dx) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    const int block_num = pos / old_dhw;
    const int pos_d = pos / old_hw % old_depth + pad_head;
    const int pos_h = pos / old_width % old_height + pad_top;
    const int pos_w = pos % old_width + pad_left;
    const int index = block_num * padded_dhw + pos_d * padded_hw + pos_h * padded_width + pos_w;
    dx[pos] = dy[index];
  }
}
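
// PadNDHWC is the channels-last (NDHWC) variant of Pad3d.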
// For internal OP use, not user facing
template <typename T>
__global__ void PadNDHWC(const size_t size, const T *input, const int num, const int old_depth, const int old_height,
                         const int old_width, const int channels, const int padded_depth, const int padded_height,
                         const int padded_width, const int pad_head, const int pad_top, const int pad_left,
                         float pad_value, T *output) {
  T pad_value_ = static_cast<T>(pad_value);
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    int block_num = pos / (channels * padded_width * padded_height * padded_depth);
    const int padded_w = pos / channels % padded_width;
    const int padded_h = pos / (channels * padded_width) % padded_height;
    const int padded_d = pos / (channels * padded_width * padded_height) % padded_depth;
    if (padded_d - pad_head < 0 || padded_h - pad_top < 0 || padded_w - pad_left < 0 ||
        padded_d - pad_head >= old_depth || padded_h - pad_top >= old_height || padded_w - pad_left >= old_width) {
      output[pos] = pad_value_;
    } else {
      output[pos] =
        input[(((block_num * old_depth + padded_d - pad_head) * old_height + padded_h - pad_top) * old_width +
               padded_w - pad_left) *
                channels +
              pos % channels];
    }
  }
}
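
// PadGradNDHWC propagates gradients through a 3-D pad in NDHWC layout; its
// block_num is the batch index, mirroring the forward PadNDHWC kernel above.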
template <typename T>
__global__ void PadGradNDHWC(const size_t size, const T *dy, const int num, const int old_depth, const int old_height,
                             const int old_width, const int channels, const int padded_depth, const int padded_height,
                             const int padded_width, const int pad_head, const int pad_top, const int pad_left, T *dx) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    int block_num = pos / (channels * old_width * old_height * old_depth);
    const int padded_w = pos / channels % old_width + pad_left;
    const int padded_h = pos / (channels * old_width) % old_height + pad_top;
    const int padded_d = pos / (channels * old_width * old_height) % old_depth + pad_head;
    dx[pos] =
      dy[(((block_num * padded_depth + padded_d) * padded_height + padded_h) * padded_width + padded_w) * channels +
         pos % channels];
  }
}
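
// Host-side launchers. Each wraps its kernel in a <<<GET_BLOCKS, GET_THREADS>>>
// launch on the caller's stream and returns the resulting CUDA status.
//
// A minimal usage sketch for CalPad (buffer and dimension names here are
// illustrative; the device buffers are assumed to be allocated by the caller):
//
//   float *d_in, *d_out;    // device buffers for the input and the padded output
//   cudaStream_t stream;    // caller-created stream
//   size_t out_size = num * channels * padded_h * padded_w;
//   cudaError_t status = CalPad(out_size, d_in, num, channels, old_h, old_w,
//                               padded_h, padded_w, pad_top, pad_left, 0.0f,
//                               d_out, stream);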
template <typename T>
cudaError_t CalPad(const size_t size, const T *input, const int num, const int channels, const int old_height,
                   const int old_width, const int padded_height, const int padded_width, const int pad_top,
                   const int pad_left, const float pad_value, T *output, cudaStream_t cuda_stream) {
  Pad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, num, channels, old_height, old_width,
                                                         padded_height, padded_width, pad_top, pad_left, pad_value,
                                                         output);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CalPadNHWC(const size_t size, const T *input, const int num, const int old_height, const int old_width,
                       const int channels, const int padded_height, const int padded_width, const int pad_top,
                       const int pad_left, const float pad_value, T *output, cudaStream_t cuda_stream) {
  PadNHWC<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, num, old_height, old_width, channels,
                                                             padded_height, padded_width, pad_top, pad_left, pad_value,
                                                             output);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CalPadGeneral(const T *input, T *output, const PadInfo &info, const int input_size, const size_t input_rank,
                          cudaStream_t cuda_stream) {
  PadGeneral<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input, output, info, input_size, input_rank);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CalPadGradNHWC(const size_t size, const T *dy, const int num, const int old_height, const int old_width,
                           const int channels, const int padded_height, const int padded_width, const int pad_top,
                           const int pad_left, T *dx, cudaStream_t cuda_stream) {
  PadGradNHWC<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, num, old_height, old_width, channels,
                                                                 padded_height, padded_width, pad_top, pad_left, dx);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CalPadGrad(const size_t size, const T *dy, const int num, const int channels, const int old_height,
                       const int old_width, const int padded_height, const int padded_width, const int pad_top,
                       const int pad_left, T *dx, cudaStream_t cuda_stream) {
  PadGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, num, channels, old_height, old_width,
                                                             padded_height, padded_width, pad_top, pad_left, dx);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CalPad3d(const size_t size, const T *input, const int num, const int channels, const int old_depth,
                     const int old_height, const int old_width, const int padded_depth, const int padded_height,
                     const int padded_width, const int pad_head, const int pad_top, const int pad_left,
                     const float pad_value, T *output, cudaStream_t cuda_stream) {
  const int old_hw = old_height * old_width;
  const int old_dhw = old_depth * old_hw;
  const int padded_hw = padded_height * padded_width;
  const int padded_dhw = padded_depth * padded_hw;
  Pad3d<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
    size, input, num, channels, old_depth, old_height, old_width, old_dhw, old_hw, padded_depth, padded_height,
    padded_width, padded_dhw, padded_hw, pad_head, pad_top, pad_left, pad_value, output);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CalPadGrad3d(const size_t size, const T *dy, const int num, const int channels, const int old_depth,
                         const int old_height, const int old_width, const int padded_depth, const int padded_height,
                         const int padded_width, const int pad_head, const int pad_top, const int pad_left, T *dx,
                         cudaStream_t cuda_stream) {
  const int old_hw = old_height * old_width;
  const int old_dhw = old_depth * old_hw;
  const int padded_hw = padded_height * padded_width;
  const int padded_dhw = padded_depth * padded_hw;
  PadGrad3d<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
    size, dy, num, channels, old_depth, old_height, old_width, old_dhw, old_hw, padded_depth, padded_height,
    padded_width, padded_dhw, padded_hw, pad_head, pad_top, pad_left, dx);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CalPadNDHWC(const size_t size, const T *input, const int num, const int old_depth, const int old_height,
                        const int old_width, const int channels, const int padded_depth, const int padded_height,
                        const int padded_width, const int pad_head, const int pad_top, const int pad_left,
                        const float pad_value, T *output, cudaStream_t cuda_stream) {
  PadNDHWC<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, num, old_depth, old_height, old_width,
                                                              channels, padded_depth, padded_height, padded_width,
                                                              pad_head, pad_top, pad_left, pad_value, output);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CalPadGradNDHWC(const size_t size, const T *dy, const int num, const int old_depth, const int old_height,
                            const int old_width, const int channels, const int padded_depth, const int padded_height,
                            const int padded_width, const int pad_head, const int pad_top, const int pad_left, T *dx,
                            cudaStream_t cuda_stream) {
  PadGradNDHWC<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, num, old_depth, old_height, old_width,
                                                                  channels, padded_depth, padded_height, padded_width,
                                                                  pad_head, pad_top, pad_left, dx);
  return GetCudaStatus();
}
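
// Explicit instantiations exported via CUDA_LIB_EXPORT.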
template CUDA_LIB_EXPORT cudaError_t CalPad<float>(const size_t size, const float *input, const int num,
                                                   const int channels, const int old_height, const int old_width,
                                                   const int padded_height, const int padded_width, const int pad_top,
                                                   const int pad_left, float pad_value, float *output,
                                                   cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGrad<float>(const size_t size, const float *dy, const int num,
                                                       const int channels, const int old_height, const int old_width,
                                                       const int padded_height, const int padded_width,
                                                       const int pad_top, const int pad_left, float *dx,
                                                       cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPad<half>(const size_t size, const half *input, const int num,
                                                  const int channels, const int old_height, const int old_width,
                                                  const int padded_height, const int padded_width, const int pad_top,
                                                  const int pad_left, float pad_value, half *output,
                                                  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGrad<half>(const size_t size, const half *dy, const int num,
                                                      const int channels, const int old_height, const int old_width,
                                                      const int padded_height, const int padded_width,
                                                      const int pad_top, const int pad_left, half *dx,
                                                      cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadNHWC<float>(const size_t size, const float *input, const int num,
                                                       const int old_height, const int old_width, const int channels,
                                                       const int padded_height, const int padded_width,
                                                       const int pad_top, const int pad_left, float pad_value,
                                                       float *output, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadNHWC<half>(const size_t size, const half *input, const int num,
                                                      const int old_height, const int old_width, const int channels,
                                                      const int padded_height, const int padded_width,
                                                      const int pad_top, const int pad_left, float pad_value,
                                                      half *output, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGradNHWC<float>(const size_t size, const float *dy, const int num,
                                                           const int old_height, const int old_width,
                                                           const int channels, const int padded_height,
                                                           const int padded_width, const int pad_top,
                                                           const int pad_left, float *dx, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGradNHWC<half>(const size_t size, const half *dy, const int num,
                                                          const int old_height, const int old_width, const int channels,
                                                          const int padded_height, const int padded_width,
                                                          const int pad_top, const int pad_left, half *dx,
                                                          cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<double>(const double *input, double *output, const PadInfo &info,
                                                           const int input_size, const size_t input_rank,
                                                           cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<float>(const float *input, float *output, const PadInfo &info,
                                                          const int input_size, const size_t input_rank,
                                                          cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<half>(const half *input, half *output, const PadInfo &info,
                                                         const int input_size, const size_t input_rank,
                                                         cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<int8_t>(const int8_t *input, int8_t *output, const PadInfo &info,
                                                           const int input_size, const size_t input_rank,
                                                           cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<int16_t>(const int16_t *input, int16_t *output, const PadInfo &info,
                                                            const int input_size, const size_t input_rank,
                                                            cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<int32_t>(const int32_t *input, int32_t *output, const PadInfo &info,
                                                            const int input_size, const size_t input_rank,
                                                            cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<int64_t>(const int64_t *input, int64_t *output, const PadInfo &info,
                                                            const int input_size, const size_t input_rank,
                                                            cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<uint8_t>(const uint8_t *input, uint8_t *output, const PadInfo &info,
                                                            const int input_size, const size_t input_rank,
                                                            cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<uint16_t>(const uint16_t *input, uint16_t *output,
                                                             const PadInfo &info, const int input_size,
                                                             const size_t input_rank, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<uint32_t>(const uint32_t *input, uint32_t *output,
                                                             const PadInfo &info, const int input_size,
                                                             const size_t input_rank, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<uint64_t>(const uint64_t *input, uint64_t *output,
                                                             const PadInfo &info, const int input_size,
                                                             const size_t input_rank, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<bool>(const bool *input, bool *output, const PadInfo &info,
                                                         const int input_size, const size_t input_rank,
                                                         cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<Complex<float>>(const Complex<float> *input, Complex<float> *output,
                                                                   const PadInfo &info, const int input_size,
                                                                   const size_t input_rank, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGeneral<Complex<double>>(const Complex<double> *input,
                                                                    Complex<double> *output, const PadInfo &info,
                                                                    const int input_size, const size_t input_rank,
                                                                    cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPad3d<float>(const size_t size, const float *input, const int num,
                                                     const int channels, const int old_depth, const int old_height,
                                                     const int old_width, const int padded_depth,
                                                     const int padded_height, const int padded_width,
                                                     const int pad_head, const int pad_top, const int pad_left,
                                                     const float pad_value, float *output, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPad3d<half>(const size_t size, const half *input, const int num,
                                                    const int channels, const int old_depth, const int old_height,
                                                    const int old_width, const int padded_depth,
                                                    const int padded_height, const int padded_width, const int pad_head,
                                                    const int pad_top, const int pad_left, const float pad_value,
                                                    half *output, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGrad3d<float>(const size_t size, const float *dy, const int num,
                                                         const int channels, const int old_depth, const int old_height,
                                                         const int old_width, const int padded_depth,
                                                         const int padded_height, const int padded_width,
                                                         const int pad_head, const int pad_top, const int pad_left,
                                                         float *dx, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGrad3d<half>(const size_t size, const half *dy, const int num,
                                                        const int channels, const int old_depth, const int old_height,
                                                        const int old_width, const int padded_depth,
                                                        const int padded_height, const int padded_width,
                                                        const int pad_head, const int pad_top, const int pad_left,
                                                        half *dx, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGradNDHWC<float>(
  const size_t size, const float *dy, const int num, const int old_depth, const int old_height, const int old_width,
  const int channels, const int padded_depth, const int padded_height, const int padded_width, const int pad_head,
  const int pad_top, const int pad_left, float *dx, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t CalPadGradNDHWC<half>(
  const size_t size, const half *dy, const int num, const int old_depth, const int old_height, const int old_width,
  const int channels, const int padded_depth, const int padded_height, const int padded_width, const int pad_head,
  const int pad_top, const int pad_left, half *dx, cudaStream_t cuda_stream);
400