/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <limits>
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
#include "backend/kernel_compiler/gpu/cuda_impl/general_reduction_impl.cuh"

const int kWarpSize = 32;
const int kBlockSize = 512;
const int kWarpGroup = 4;
const int kNumWarps = kBlockSize / kWarpSize;   // 16
const int kGroupSize = kWarpGroup * kWarpSize;  // 128

// Mode selection constants
const int kMaxThreadLoop = 4;
const int kMaxWarpLoop = kWarpSize * 3;    // 32 * 3 = 96
const int kMaxGroupLoop = kGroupSize * 3;  // 128 * 3 = 384

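// Key/index comparators for arg-min/arg-max. On equal keys the smaller index
// wins; the `i < 0` test makes an uninitialized index (init_V == -1) lose to
// any real candidate.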
template <typename T, typename S>
struct Cmp {
  __device__ static inline bool lt(T a, T b, S i, S j) { return (a < b) || ((a == b) && (i < 0 || j < i)); }
  __device__ static inline bool gt(T a, T b, S i, S j) { return (a > b) || ((a == b) && (i < 0 || j < i)); }
};

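// Branchless conditional assignment: overwrite *x with y only when is_assign
// is true, so all lanes of a warp follow the same instruction stream instead
// of diverging on a branch.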
template <typename T>
inline __device__ void ConditionAssign(bool is_assign, T *x, const T &y) {
  (*x) = is_assign ? y : (*x);
}

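// One thread per (outer, inner) output element, scanning the whole reduction
// axis serially. Chosen by the dispatcher when the axis is very short
// (bound <= kMaxThreadLoop).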
template <typename T, typename S>
__global__ void ThreadReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
                                T *output, S *output_index, bool fp16_flag, T init_K) {
  if (fp16_flag) {
    init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
  }

  const S init_V = static_cast<S>(-1);

  for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < outer_size * inner_size;
       t_idx += blockDim.x * gridDim.x) {
    int outer_id = t_idx / inner_size;
    int inner_id = t_idx % inner_size;

    T threadK = init_K;
    S threadV = init_V;

    for (int i = 0; i < bound; i++) {
      T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
      S other_V = i;
      bool is_winner =
        small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
      ConditionAssign(is_winner, &threadK, other_K);
      ConditionAssign(is_winner, &threadV, other_V);
    }

    output[outer_id * inner_size + inner_id] = threadK;
    output_index[outer_id * inner_size + inner_id] = threadV;
  }
}

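// One warp (32 threads) per output element: each lane strides over the
// reduction axis, then a shuffle-down tree merges the per-lane winners so
// that lane 0 holds the final key/index pair.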
template <typename T, typename S>
__global__ void WarpReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output,
                              S *output_index, bool fp16_flag, T init_K) {
  if (fp16_flag) {
    init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
  }
  const S init_V = static_cast<S>(-1);

  for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kWarpSize * outer_size * inner_size;
       t_idx += blockDim.x * gridDim.x) {
    int outer_id = t_idx / kWarpSize / inner_size;
    int inner_id = t_idx / kWarpSize % inner_size;

    int laneId = threadIdx.x % kWarpSize;

    T threadK = init_K;
    S threadV = init_V;

    for (int i = laneId; i < bound; i += kWarpSize) {
      T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
      S other_V = i;
      bool is_winner =
        small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
      ConditionAssign(is_winner, &threadK, other_K);
      ConditionAssign(is_winner, &threadV, other_V);
    }
    __syncwarp();

    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
      T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
      S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

      bool is_winner =
        small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
      ConditionAssign(is_winner, &threadK, other_K);
      ConditionAssign(is_winner, &threadV, other_V);
    }

    __syncwarp();

    if (laneId == 0) {
      output[outer_id * inner_size + inner_id] = threadK;
      output_index[outer_id * inner_size + inner_id] = threadV;
    }
    __syncthreads();
  }
}

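// One warp group (kWarpGroup = 4 warps, 128 threads) per output element.
// Each warp reduces its lanes with shuffles, the four warp winners are staged
// in shared memory, and two pairwise merge steps (stride 2, then stride 1)
// leave the group result at shared_K/shared_V[groupId * kWarpGroup].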
template <typename T, typename S>
__global__ void Warp4Reduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
                               T *output, S *output_index, bool fp16_flag, T init_K) {
  __shared__ T shared_K[kNumWarps];
  __shared__ S shared_V[kNumWarps];
  if (fp16_flag) {
    init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
  }
  const S init_V = static_cast<S>(-1);

  for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kGroupSize * outer_size * inner_size;
       t_idx += blockDim.x * gridDim.x) {
    int outer_id = t_idx / kGroupSize / inner_size;
    int inner_id = t_idx / kGroupSize % inner_size;

    int groupId = threadIdx.x / kGroupSize;
    int tgId = threadIdx.x % kGroupSize;
    int warpId = threadIdx.x / kWarpSize;
    int laneId = threadIdx.x % kWarpSize;

    T threadK = init_K;
    S threadV = init_V;

    if (laneId == 0) {
      shared_K[warpId] = init_K;
      shared_V[warpId] = init_V;
    }
    __syncthreads();

    for (int i = tgId; i < bound; i += kGroupSize) {
      T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
      S other_V = i;
      bool is_winner =
        small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
      ConditionAssign(is_winner, &threadK, other_K);
      ConditionAssign(is_winner, &threadV, other_V);
    }
    __syncwarp();

    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
      T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
      S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

      bool is_winner =
        small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
      ConditionAssign(is_winner, &threadK, other_K);
      ConditionAssign(is_winner, &threadV, other_V);
    }

    __syncwarp();

    if (laneId == 0) {
      shared_K[warpId] = threadK;
      shared_V[warpId] = threadV;
    }
    __syncthreads();

    if (tgId < 2) {
      bool is_winner =
        small ? Cmp<T, S>::gt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 2],
                              shared_V[(groupId * kWarpGroup) + tgId], shared_V[(groupId * kWarpGroup) + tgId + 2])
              : Cmp<T, S>::lt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 2],
                              shared_V[(groupId * kWarpGroup) + tgId], shared_V[(groupId * kWarpGroup) + tgId + 2]);
      ConditionAssign(is_winner, (shared_K + (groupId * kWarpGroup) + tgId),
                      (shared_K[(groupId * kWarpGroup) + tgId + 2]));
      ConditionAssign(is_winner, (shared_V + (groupId * kWarpGroup) + tgId),
                      (shared_V[(groupId * kWarpGroup) + tgId + 2]));
    }
    __syncwarp();

    if (tgId == 0) {
      bool is_winner =
        small ? Cmp<T, S>::gt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 1],
                              shared_V[(groupId * kWarpGroup) + tgId], shared_V[(groupId * kWarpGroup) + tgId + 1])
              : Cmp<T, S>::lt(shared_K[(groupId * kWarpGroup) + tgId], shared_K[(groupId * kWarpGroup) + tgId + 1],
                              shared_V[(groupId * kWarpGroup) + tgId], shared_V[(groupId * kWarpGroup) + tgId + 1]);
      ConditionAssign(is_winner, (shared_K + (groupId * kWarpGroup) + tgId),
                      (shared_K[(groupId * kWarpGroup) + tgId + 1]));
      ConditionAssign(is_winner, (shared_V + (groupId * kWarpGroup) + tgId),
                      (shared_V[(groupId * kWarpGroup) + tgId + 1]));

      // The first thread of each group writes the output
      output[outer_id * inner_size + inner_id] = shared_K[groupId * kWarpGroup];
      output_index[outer_id * inner_size + inner_id] = shared_V[groupId * kWarpGroup];
    }
    __syncthreads();
  }
}

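// A full block (kBlockSize = 512 threads, 16 warps) per output element:
// per-warp shuffle reduction, warp winners staged in shared memory, then
// warp 0 reduces the 16 staged candidates with a final shuffle tree.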
template <typename T, typename S>
__global__ void BlockReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
                               T *output, S *output_index, bool fp16_flag, T init_K) {
  __shared__ T shared_K[kNumWarps];
  __shared__ S shared_V[kNumWarps];
  if (fp16_flag) {
    init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
  }
  const S init_V = static_cast<S>(-1);

  for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kBlockSize * outer_size * inner_size;
       t_idx += blockDim.x * gridDim.x) {
    int outer_id = t_idx / kBlockSize / inner_size;
    int inner_id = t_idx / kBlockSize % inner_size;

    int tgId = threadIdx.x % kBlockSize;
    int warpId = threadIdx.x / kWarpSize;
    int laneId = threadIdx.x % kWarpSize;

    T threadK = init_K;
    S threadV = init_V;

    if (laneId == 0) {
      shared_K[warpId] = init_K;
      shared_V[warpId] = init_V;
    }
    __syncthreads();

    for (int i = tgId; i < bound; i += kBlockSize) {
      T other_K = input[outer_id * bound * inner_size + i * inner_size + inner_id];
      S other_V = i;
      bool is_winner =
        small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
      ConditionAssign(is_winner, &threadK, other_K);
      ConditionAssign(is_winner, &threadV, other_V);
    }
    __syncwarp();

    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
      T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
      S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

      bool is_winner =
        small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
      ConditionAssign(is_winner, &threadK, other_K);
      ConditionAssign(is_winner, &threadV, other_V);
    }

    __syncwarp();

    if (laneId == 0) {
      shared_K[warpId] = threadK;
      shared_V[warpId] = threadV;
    }
    __syncthreads();
    // Shared-memory reduction: the 16 warp winners fit within a single warp,
    // so warp 0 can finish the reduction on its own.
    if (warpId == 0) {
      threadK = laneId < kNumWarps ? shared_K[laneId] : init_K;
      threadV = laneId < kNumWarps ? shared_V[laneId] : init_V;
    }
    __syncwarp();

    if (warpId == 0) {
      for (int offset = kWarpSize / 4; offset > 0; offset /= 2) {
        T other_K = __shfl_down_sync(0xffffffff, threadK, offset);
        S other_V = __shfl_down_sync(0xffffffff, threadV, offset);

        bool is_winner =
          small ? Cmp<T, S>::gt(threadK, other_K, threadV, other_V) : Cmp<T, S>::lt(threadK, other_K, threadV, other_V);
        ConditionAssign(is_winner, &threadK, other_K);
        ConditionAssign(is_winner, &threadV, other_V);
      }
    }
    __syncwarp();

    if (warpId == 0 && laneId == 0) {
      output[outer_id * inner_size + inner_id] = threadK;
      output_index[outer_id * inner_size + inner_id] = threadV;
    }
  }
}

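// Dispatcher: picks one thread, one warp, one 4-warp group, or one full block
// per output element, depending on how long the reduction axis (bound) is.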
template <typename T, typename S>
void GeneralReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output,
                      S *output_index, cudaStream_t stream) {
  int block_num_limit = outer_size * inner_size;
  bool fp16_flag = false;
  if (std::is_same<T, half>::value) {
    fp16_flag = true;
  }
  T init_K = small ? std::numeric_limits<T>::max() : std::numeric_limits<T>::lowest();

  if (bound <= kMaxThreadLoop) {
    ThreadReduction<T, S><<<GET_BLOCKS(block_num_limit), kBlockSize, 0, stream>>>(
      small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
  } else if (bound <= kMaxWarpLoop) {
    WarpReduction<T, S><<<GET_BLOCKS(block_num_limit * kWarpSize), kBlockSize, 0, stream>>>(
      small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
  } else if (bound <= kMaxGroupLoop) {
    Warp4Reduction<T, S><<<GET_BLOCKS(block_num_limit * kGroupSize), kBlockSize, 0, stream>>>(
      small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
  } else {
    BlockReduction<T, S><<<GET_BLOCKS(block_num_limit * kBlockSize), kBlockSize, 0, stream>>>(
      small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
  }
}

template <typename T, typename S>
void CalGeneralReduction(bool small, const T *input, const size_t bound, const size_t outerSize, const size_t innerSize,
                         S *index, T *output, cudaStream_t cuda_stream) {
  GeneralReduction(small, outerSize, bound, innerSize, input, output, index, cuda_stream);
  return;
}
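
// A minimal host-side usage sketch (illustrative only, not part of this
// file): argmax over the innermost axis of an [outer, bound] float tensor,
// so outerSize = outer, innerSize = 1, and small = false. The buffer names
// below are assumptions made for this example, not an API defined here.
//
//   size_t outer = 8, bound = 1024;
//   float *d_in = nullptr, *d_out = nullptr;
//   int *d_idx = nullptr;
//   cudaMalloc(&d_in, outer * bound * sizeof(float));
//   cudaMalloc(&d_out, outer * sizeof(float));
//   cudaMalloc(&d_idx, outer * sizeof(int));
//   // ... copy input data into d_in ...
//   cudaStream_t stream;
//   cudaStreamCreate(&stream);
//   // bound = 1024 > kMaxGroupLoop, so this dispatches to BlockReduction.
//   CalGeneralReduction(false, d_in, bound, outer, 1, d_idx, d_out, stream);
//   cudaStreamSynchronize(stream);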
template void CalGeneralReduction(bool small, const double *input, const size_t bound_, const size_t outerSize_,
                                  const size_t innerSize_, int *index, double *output, cudaStream_t cuda_stream);
template void CalGeneralReduction(bool small, const float *input, const size_t bound_, const size_t outerSize_,
                                  const size_t innerSize_, int *index, float *output, cudaStream_t cuda_stream);
template void CalGeneralReduction(bool small, const half *input, const size_t bound_, const size_t outerSize_,
                                  const size_t innerSize_, int *index, half *output, cudaStream_t cuda_stream);