• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include "maxpool_grad_with_argmax_v2_impl.cuh"
19 #include "include/cuda_fp16.h"
20 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
21 
// Scatters the output gradient `dy` back onto the input gradient `dx` using the argmax
// positions in `index` (one entry per dy element).
//
// For flat dy position `pos` (0 <= pos < size):
//   n = pos / dy_chw, c = (pos / dy_hw) % (dy_chw / dy_hw)
//   dx[n * x_chw + c * x_hw + index[pos]] += dy[pos]
// so index[pos] is an offset inside one (n, c) plane of dx and must lie in [0, x_hw).
// Entries outside that range are skipped rather than written out of bounds.
//
// Several dy elements may map to the same dx element, so accumulation uses MsAtomicAdd.
// The kernel only adds into dx; the caller is responsible for zero-filling dx beforehand.
// Launch: any 1-D grid/block shape; the loop strides by blockDim.x * gridDim.x so every
// element in [0, size) is visited exactly once.
template <typename T, typename S>
__global__ void MaxPoolGradWithArgmaxV2(const T *dy, const S *index, const int64_t x_hw, const int64_t x_chw,
                                        const int64_t dy_hw, const int64_t dy_chw, const int64_t size, T *dx) {
  // Widen before multiplying: blockIdx.x * blockDim.x (and blockDim.x * gridDim.x) would
  // otherwise be computed in 32-bit unsigned arithmetic and could wrap for large grids.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t pos = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    const int64_t idx = static_cast<int64_t>(index[pos]);
    // Guard against malformed argmax values that would address memory outside this (n, c) plane.
    if (idx < 0 || idx >= x_hw) {
      continue;
    }
    const int64_t pos_n = pos / dy_chw;
    // dy_chw / dy_hw is the channel count; it is evaluated only inside the loop, where
    // size > 0 implies dy_hw != 0.
    const int64_t pos_c = pos / dy_hw % (dy_chw / dy_hw);
    MsAtomicAdd(dx + pos_n * x_chw + pos_c * x_hw + idx, dy[pos]);
  }
}
33 
// Sets output[0 .. size) to T(0).
//
// `size` is 64-bit so that int64_t element counts are not narrowed to int. Existing callers
// that pass an int still convert implicitly.
// Launch: any 1-D grid/block shape; each thread strides by blockDim.x * gridDim.x.
template <typename T>
__global__ void InitOutput(const int64_t size, T *output) {
  T zero = 0;
  // 64-bit start offset and stride: the raw blockIdx/blockDim products are 32-bit unsigned.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; id < size; id += stride) {
    output[id] = zero;
  }
}
42 
// Computes the max-pool input gradient dx from the output gradient dy and the argmax
// indices in `index`. Both steps are enqueued on `cuda_stream`; nothing is synchronized here.
//
// Parameters:
//   dy, index      - device buffers with dy_nchw elements each.
//   x_hw, x_chw    - per-plane and per-sample element counts of dx (H*W and C*H*W by naming).
//   x_nchw         - total element count of dx.
//   dy_hw, dy_chw  - per-plane and per-sample element counts of dy.
//   dy_nchw        - total element count of dy.
//   dx             - device output buffer with x_nchw elements; fully overwritten.
//   device_id      - accepted for interface compatibility; not used by this function.
//   cuda_stream    - stream on which all work is enqueued.
// Returns: the error from zero-filling dx if that fails, otherwise GetCudaStatus()
// (project helper; presumably reports pending launch errors — confirm in util.cuh).
template <typename T, typename S>
cudaError_t CalMaxPoolGradWithArgmaxV2(const T *dy, const S *index, const int64_t x_hw, const int64_t x_chw,
                                       const int64_t x_nchw, const int64_t dy_hw, const int64_t dy_chw,
                                       const int64_t dy_nchw, T *dx, const uint32_t device_id,
                                       cudaStream_t cuda_stream) {
  if (x_nchw <= 0) {
    // Empty dx: nothing to zero and nowhere to scatter gradients to.
    return GetCudaStatus();
  }
  // The scatter kernel only accumulates into dx, so dx must start at zero. An all-zero byte
  // pattern is zero for every instantiated T (half, float, double, integers), so a byte-wise
  // memset suffices. The byte count is computed in size_t, so large element counts are not
  // narrowed. The memset is ordered before the kernel because both use the same stream.
  const cudaError_t memset_status = cudaMemsetAsync(dx, 0, static_cast<size_t>(x_nchw) * sizeof(T), cuda_stream);
  if (memset_status != cudaSuccess) {
    return memset_status;
  }
  if (dy_nchw > 0) {
    MaxPoolGradWithArgmaxV2<<<GET_BLOCKS(dy_nchw), GET_THREADS, 0, cuda_stream>>>(dy, index, x_hw, x_chw, dy_hw,
                                                                                  dy_chw, dy_nchw, dx);
  }
  return GetCudaStatus();
}
53 
// Exported explicit instantiations of CalMaxPoolGradWithArgmaxV2.
// Each gradient dtype listed in MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE_ALL_DTYPES (half, float,
// double, and 8/16/32/64-bit signed and unsigned integers) is instantiated once with int32_t
// indices and once with int64_t indices. The helper macros are #undef'd right after use so
// they do not leak past this point.
#define MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(T, S)                                                 \
  template CUDA_LIB_EXPORT cudaError_t CalMaxPoolGradWithArgmaxV2<T, S>(                            \
    const T *dy, const S *index, const int64_t x_hw, const int64_t x_chw, const int64_t x_nchw,     \
    const int64_t dy_hw, const int64_t dy_chw, const int64_t dy_nchw, T *dx, const uint32_t device_id, \
    cudaStream_t cuda_stream);

#define MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE_ALL_DTYPES(S) \
  MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(half, S)            \
  MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(float, S)           \
  MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(double, S)          \
  MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(int8_t, S)          \
  MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(int16_t, S)         \
  MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(int32_t, S)         \
  MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(int64_t, S)         \
  MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(uint8_t, S)         \
  MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(uint16_t, S)        \
  MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(uint32_t, S)        \
  MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE(uint64_t, S)

MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE_ALL_DTYPES(int32_t)
MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE_ALL_DTYPES(int64_t)

#undef MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE_ALL_DTYPES
#undef MAXPOOL_GRAD_WITH_ARGMAX_V2_INSTANTIATE
143