/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <math.h>
#include <iostream>
#include <limits>
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/softmax_impl.cuh"

#define WARPSIZE 32
const int max_threads = 1024;
constexpr int ALIGN_BYTES = 16;

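// Picks a 2-D grid for the spatial softmax kernel: grid.x covers the outer dimension and grid.y covers the inner
// dimension, capped so the total number of blocks does not exceed the device's active-block budget.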
inline dim3 SpatialSoftMaxGetGridSize(dim3 *block, uint32_t activate_block, uint64_t outer_size, uint64_t dim_size,
                                      uint64_t inner_size) {
  uint32_t inner = (inner_size + block->y - 1) / block->y;
  if (inner > activate_block) {
    inner = activate_block;
  }
  uint32_t outer = (activate_block + inner - 1) / inner;
  if (outer > outer_size) {
    outer = outer_size;
  }
  return dim3(outer, inner);
}

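// Picks a 2-D block for the spatial softmax kernel: block.y threads walk the inner dimension, and block.x threads
// cooperate along the softmax dimension when the inner dimension is small and the softmax dimension is large.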
inline dim3 SpatialSoftMaxGetBlockSize(uint64_t outer_size, uint64_t dim_size, uint64_t inner_size) {
  uint32_t inner_ths = inner_size;
  inner_ths = std::min(inner_ths, static_cast<uint32_t>(max_threads));
  uint32_t dim_threads = 1;
  if (inner_ths <= 64 && dim_size >= 64) {
    while ((inner_ths * dim_threads <= max_threads) && (dim_threads <= dim_size)) {
      dim_threads *= 2;
    }
    dim_threads /= 2;
  }
  return dim3(dim_threads, inner_ths);
}

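// Computes grid, block and dynamic shared-memory sizes for SpatialSoftMaxForward, using the occupancy API to bound
// the grid by the number of blocks the device can keep resident.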
template <typename accumulate_t, typename Kernel>
void SpatialSoftMaxGetLaunchSizes(Kernel k, uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, dim3 *grid,
                                  dim3 *block, uint32_t *smem_size, uint32_t device_id) {
  *block = SpatialSoftMaxGetBlockSize(outer_size, dim_size, inner_size);
  uint32_t block_ths = block->x * block->y;
  if (block->x == 1) {
    *smem_size = 0;
  } else {
    *smem_size = block_ths * sizeof(accumulate_t);
  }
  int activate_size;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&activate_size, k, block_ths, *smem_size);
  cudaDeviceProp prop;
  (void)cudaGetDeviceProperties(&prop, device_id);
  activate_size *= prop.multiProcessorCount;
  *grid = SpatialSoftMaxGetGridSize(block, activate_size, outer_size, dim_size, inner_size);
}

int log2_ceil(int val) {
  int final_val = 0;
  while ((1 << final_val) < val) ++final_val;
  return final_val;
}

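// Block size for the block-per-row kernel (cunn_SoftMaxForward): rounds dim / ins (capped at max_threads, and halved
// when each thread loads several elements) up to a power of two, with a floor of WARPSIZE.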
inline dim3 SoftMaxGetBlockSize(int ins, uint64_t dim) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim / ins, static_cast<uint64_t>(max_threads));

  if (ins > 1) {
    max_block_size /= 2;
  }

  while (block_size < (max_block_size)) {
    block_size *= 2;
  }
  block_size = std::max(block_size, static_cast<uint64_t>(WARPSIZE));
  return dim3(block_size);
}

template <typename T, typename AccT>
struct GetMaxFloat {
  __device__ __forceinline__ AccT operator()(AccT max, T v) const { return ::max(max, (AccT)v); }
};

template <typename T, typename AccT>
struct GetSumExpFloat {
  __device__ __forceinline__ GetSumExpFloat(AccT v) : max_k(v) {}

  __device__ __forceinline__ AccT operator()(AccT sum, T v) const { return sum + std::exp((AccT)v - max_k); }

  const AccT max_k;
};

template <typename T>
__device__ __forceinline__ T WARPSHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
#ifndef __HIP_PLATFORM_HCC__
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}

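// Tree reduction across the x dimension of the block, one independent reduction per threadIdx.y row of shared memory.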
template <typename T, typename Function>
__forceinline__ __device__ T SpatialBlockReduceX(T *memsha, T val) {
  Function r = Function();
  memsha += threadIdx.y * blockDim.x;
  __syncthreads();
  memsha[threadIdx.x] = val;
  int offset = blockDim.x / 2;
  while (offset > 0) {
    __syncthreads();
    if (threadIdx.x < offset) memsha[threadIdx.x] = r(memsha[threadIdx.x], memsha[threadIdx.x + offset]);
    offset /= 2;
  }
  __syncthreads();
  return memsha[0];
}

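// Softmax over a non-innermost axis. The input is viewed as [outer_size, dim_size, inner_size]; each (outer, inner)
// pair is reduced along dim_size with the usual max / sum-of-exp / normalize passes.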
template <typename input_t, typename accumulate_t, typename output_t, bool is_log_softmax>
__global__ void SpatialSoftMaxForward(output_t *output, input_t *input, uint32_t outer_size, uint32_t dim_size,
                                      uint32_t inner_size) {
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accumulate_t *>(smem);
  const uint32_t outer_stride = inner_size * dim_size;
  const uint32_t dim_stride = inner_size;
  for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
    const uint32_t outer_offset = outer_index * outer_stride;
    for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size;
         inner_index += blockDim.y * gridDim.y) {
      const uint32_t offset = outer_offset + inner_index;
      if (blockDim.x > 1) {
        // Cooperative path: blockDim.x threads reduce along the softmax dimension through shared memory.
        accumulate_t max_data_input = std::numeric_limits<accumulate_t>::lowest();
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
          const accumulate_t value = static_cast<accumulate_t>(input[offset + d * dim_stride]);
          max_data_input = atomic::Max()(max_data_input, value);
        }
        max_data_input = SpatialBlockReduceX<accumulate_t, atomic::Max>(sdata, max_data_input);
        accumulate_t sum = 0;
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          sum += std::exp(static_cast<accumulate_t>(input[offset + d * dim_stride]) - max_data_input);
        sum = SpatialBlockReduceX<accumulate_t, atomic::Add>(sdata, sum);
        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t, is_log_softmax> epilogue(max_data_input, sum);
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          output[offset + d * dim_stride] = epilogue(input[offset + d * dim_stride]);
      } else {
        // Serial path: a single thread along x handles the whole softmax dimension; no shared-memory reduction needed.
        accumulate_t max_data_input = std::numeric_limits<accumulate_t>::lowest();
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
          const accumulate_t value = static_cast<accumulate_t>(input[offset + d * dim_stride]);
          max_data_input = atomic::Max()(max_data_input, value);
        }
        accumulate_t sum = 0;
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          sum += std::exp(static_cast<accumulate_t>(input[offset + d * dim_stride]) - max_data_input);

        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t, is_log_softmax> epilogue(max_data_input, sum);
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          output[offset + d * dim_stride] = epilogue(input[offset + d * dim_stride]);
      }
    }
  }
}

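// Scalar epilogue: applies the softmax / log-softmax normalization element by element, with InsP-way unrolling over
// the bulk of the row and a plain loop over the remainder.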
template <int InsP, typename input_t, typename accum_t, typename output_t, bool is_log_softmax>
__device__ __forceinline__ void WriteResults(int clas, input_t *input, output_t *output, input_t max_k,
                                             input_t sum_all) {
  SoftMaxForwardEpilogue<input_t, accum_t, output_t, is_log_softmax> epilogue(max_k, sum_all);
  int offset = threadIdx.x;
  int last = clas % (InsP * blockDim.x);
  for (; offset < clas - last; offset += blockDim.x * InsP) {
    input_t tmp[InsP];
    #pragma unroll
    for (int j = 0; j < InsP; ++j) {
      tmp[j] = input[offset + j * blockDim.x];
    }
    #pragma unroll
    for (int j = 0; j < InsP; ++j) {
      output[offset + j * blockDim.x] = epilogue(tmp[j]);
    }
  }
  for (; offset < clas; offset += blockDim.x) {
    output[offset] = epilogue(input[offset]);
  }
}

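// Vectorized epilogue: the same normalization as WriteResults, but reading and writing aligned_vector packets of InsP
// elements; `deviate` is the misalignment (in elements) that is handled scalar-wise before the vectorized loop.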
template <int InsP, typename input_t, typename accum_t, typename output_t, bool is_log_softmax>
__device__ __forceinline__ void WriteResultsVectorized(int size, const int deviate, input_t *input, output_t *output,
                                                       input_t max_k, input_t sum_all) {
  SoftMaxForwardEpilogue<input_t, accum_t, output_t, is_log_softmax> epilogue(max_k, sum_all);
  using loadT = aligned_vector<input_t>;
  using storeT = aligned_vector<output_t>;
  int offset = threadIdx.x;
  if (deviate > 0) {
    // Handle the misaligned head element by element so the vectorized loop below starts on an aligned address.
    input -= deviate;
    output -= deviate;
    size += deviate;
    if (threadIdx.x >= deviate) {
      output[offset] = epilogue(input[offset]);
    }
    size -= blockDim.x;
    input += blockDim.x;
    output += blockDim.x;
  }
  const int last = size % (InsP * blockDim.x);
  input_t in_v[InsP];
  loadT *in_value = reinterpret_cast<loadT *>(&in_v);
  output_t out_v[InsP];
  storeT *out_value = reinterpret_cast<storeT *>(&out_v);
  for (; offset * InsP < (size - last); offset += blockDim.x) {
    *in_value = reinterpret_cast<loadT *>(input)[offset];
    #pragma unroll
    for (int j = 0; j < InsP; ++j) {
      out_v[j] = epilogue(in_v[j]);
    }
    reinterpret_cast<storeT *>(output)[offset] = *out_value;
  }
  offset = size - last + threadIdx.x;
  for (; offset < size; offset += blockDim.x) {
    output[offset] = epilogue(input[offset]);
  }
}

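// Block-wide reduction through shared memory: each warp leader accumulates one WARPSIZE-sized slice of the partials,
// then thread 0 combines the per-warp results. Assumes blockDim.x is a multiple of WARPSIZE, which the launch code
// guarantees by choosing a power-of-two block size of at least WARPSIZE.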
template <typename Reduction, typename AccT>
__device__ __forceinline__ AccT ReduceBlock(AccT *sharemen, AccT val, AccT initVal) {
  Reduction r = Reduction();
  __syncthreads();
  sharemen[threadIdx.x] = val;
  __syncthreads();
  AccT warpVal = initVal;
  uint32_t mask = (((uint64_t)1) << (blockDim.x / WARPSIZE)) - 1;
  if (threadIdx.x < WARPSIZE) {
    int lane = threadIdx.x % WARPSIZE;
    if (lane < blockDim.x / WARPSIZE) {
      #pragma unroll
      for (int i = 0; i < WARPSIZE; ++i) {
        warpVal = r(warpVal, sharemen[lane * WARPSIZE + i]);
      }
#ifndef __HIP_PLATFORM_HCC__
      __syncwarp(mask);
#endif
      sharemen[lane] = warpVal;
    }
  }
  __syncthreads();
  AccT blockVal = initVal;
  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / WARPSIZE; ++i) {
      blockVal = r(blockVal, sharemen[i]);
    }
    sharemen[0] = blockVal;
  }
  __syncthreads();
  return sharemen[0];
}

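// Per-thread reduction over one row with instruction-level parallelism: a misaligned head of `shift` elements is
// handled scalar-wise, the aligned bulk is streamed as aligned_vector loads of InsP elements, and the tail is scalar.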
template <template <typename, typename> class Reduction, int InsP, typename T, typename AccT>
__device__ __forceinline__ AccT ILPReduce(int shift, T *data, int size, const Reduction<T, AccT> &r, AccT initVal) {
  using loadT = aligned_vector<T>;
  AccT threadVal = initVal;
  int offset = threadIdx.x;
  if (shift > 0) {
    data -= shift;
    size += shift;
    if (threadIdx.x >= shift) {
      threadVal = r(threadVal, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }
  int last = size % (InsP * blockDim.x);
  T v[InsP];
  loadT *value = reinterpret_cast<loadT *>(&v);
  for (; offset * InsP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<loadT *>(data)[offset];
    #pragma unroll
    for (int j = 0; j < InsP; ++j) {
      threadVal = r(threadVal, v[j]);
    }
  }
  offset = size - last + threadIdx.x;
  for (; offset < size; offset += blockDim.x) threadVal = r(threadVal, data[offset]);
  return threadVal;
}

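// Butterfly reduction within a warp using shuffle-XOR; reduces WARP_BATCH independent values held in registers.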
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, typename func>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      acc_t b = WARPSHFL_XOR(sum[i], offset, WARP_SIZE);
      sum[i] = func()(sum[i], b);
    }
  }
}

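// Warp-per-row softmax for rows of at most 1024 elements. Each warp owns WARP_BATCH rows held entirely in registers,
// so the three passes (max, sum of exp, normalize) need only warp shuffles and no shared memory. The masked variant
// treats masked-out positions as -inf for the max pass and excludes them from the sum.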
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
__global__ void SoftMaxWarpForward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count,
                                   const bool *mask = nullptr, const int head_chunk_size = -1,
                                   bool is_transformer_mask = false) {
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < WARPSIZE) ? next_power_of_two : WARPSIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
  int local_idx = threadIdx.x;
  int idx_offset = first_batch * stride + local_idx;

  src += idx_offset;
  dst += idx_offset;

  if (is_transformer_mask) {
    mask += ((first_batch * stride) / head_chunk_size) * stride + local_idx;
  } else {
    mask += idx_offset;
  }
  // Load WARP_BATCH rows into registers; out-of-range elements are padded with -inf so they do not affect the max.
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        elements[i][it] = src[i * element_count + it * WARP_SIZE];
      } else {
        elements[i][it] = -std::numeric_limits<acc_t>::infinity();
      }
    }
  }

  // Per-thread max, then a warp-level reduction to get the row maximum.
  acc_t max_value[WARP_BATCH];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    bool is_meaningful_max = false;
    max_value[i] = elements[i][0];
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      if (is_masked) {
        int idx = it * WARP_SIZE;
        if ((idx + local_idx) < batch_element_count) {
          if (!is_transformer_mask) {
            idx += i * element_count;
          }
          if (!mask[idx]) {
            max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
            is_meaningful_max = true;
          }
        }
      } else {
        max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
      }
    }
    if (is_masked) {
      if (!is_meaningful_max) {
        max_value[i] = -std::numeric_limits<acc_t>::infinity();
      }
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, atomic::Max>(max_value);

  // Per-thread sum of exp(x - max), then a warp-level reduction to get the row denominator.
  acc_t sum[WARP_BATCH]{0.0f};
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      if (!is_masked) {
        if (is_log_softmax) {
          sum[i] += std::exp(elements[i][it] - max_value[i]);
        } else {
          elements[i][it] = std::exp(elements[i][it] - max_value[i]);
          sum[i] += elements[i][it];
        }
      } else {
        int idx = it * WARP_SIZE;
        bool valid = (idx + local_idx) < batch_element_count;
        if (!is_transformer_mask) {
          idx += i * element_count;
        }
        if (valid) {
          if (!mask[idx]) {
            if (is_log_softmax) {
              sum[i] += std::exp(elements[i][it] - max_value[i]);
            } else {
              elements[i][it] = std::exp(elements[i][it] - max_value[i]);
              sum[i] += elements[i][it];
            }
          } else {
            if (!is_log_softmax) {
              elements[i][it] = 0;
            }
          }
        } else {
          if (!is_log_softmax) {
            elements[i][it] = 0.;
          }
        }
      }
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, atomic::Add>(sum);

  // Normalize and write back: log-softmax writes x - max - log(sum); softmax writes exp(x - max) / sum.
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    if (is_log_softmax) sum[i] = std::log(sum[i]);
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        if (is_log_softmax) {
          dst[i * element_count + it * WARP_SIZE] = elements[i][it] - max_value[i] - sum[i];
        } else if (sum[i] == 0) {
          dst[i * element_count + it * WARP_SIZE] = std::numeric_limits<acc_t>::quiet_NaN();
        } else {
          dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
        }
      } else {
        break;
      }
    }
  }
}

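// Block-per-row softmax for rows too large for the warp kernel: one block normalizes one row of `classes` elements,
// using ILPReduce/ReduceBlock for the max and sum passes and a (vectorized where possible) epilogue for the writes.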
template <int InsP, typename T, typename accumulate_t, bool is_log_softmax>
__global__ void cunn_SoftMaxForward(T *output, T *input, int classes) {
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accumulate_t *>(smem);

  using loadT = aligned_vector<T>;
  using storeT = aligned_vector<T>;
  input += blockIdx.x * classes;
  output += blockIdx.x * classes;
  const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(T);
  const int output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(T);

  accumulate_t threadMax = ILPReduce<GetMaxFloat, InsP, T, accumulate_t>(
    shift, input, classes, GetMaxFloat<T, accumulate_t>(), -std::numeric_limits<accumulate_t>::max());
  accumulate_t max_k =
    ReduceBlock<atomic::Max, accumulate_t>(sdata, threadMax, -std::numeric_limits<accumulate_t>::max());

  accumulate_t threadExp = ILPReduce<GetSumExpFloat, InsP, T, accumulate_t>(
    shift, input, classes, GetSumExpFloat<T, accumulate_t>(max_k), static_cast<accumulate_t>(0));
  accumulate_t sumAll = ReduceBlock<atomic::Add, accumulate_t>(sdata, threadExp, static_cast<accumulate_t>(0));

  // Vectorized writes are only safe when input and output share the same alignment offset.
  if (shift == output_shift) {
    WriteResultsVectorized<InsP, T, accumulate_t, T, is_log_softmax>(classes, shift, input, output, max_k, sumAll);
  } else {
    WriteResults<InsP, T, accumulate_t, T, is_log_softmax>(classes, input, output, max_k, sumAll);
  }
}

// end of kernel functions
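// Host-side launcher for the warp-per-row kernel: pads the row length to the next power of two and instantiates the
// matching SoftMaxWarpForward specialization (log2_elements in [0, 10], i.e. rows of 1 to 1024 elements).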
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride,
                              int batch_count, cudaStream_t stream, const bool *mask = nullptr, int chunk_size = -1,
                              bool is_transformer_mask = false) {
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;

    // Each warp handles one or two rows; the block shape is derived from the padded row length.
    int warp_size = (next_power_of_two < WARPSIZE) ? next_power_of_two : WARPSIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    switch (log2_elements) {
#define LAUNCH_SOFTMAX_WARP_FORWARD(L2E)                                                                            \
  case L2E:                                                                                                         \
    SoftMaxWarpForward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked><<<blocks, threads, 0, stream>>>(   \
      dst, src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask);     \
    break;

      LAUNCH_SOFTMAX_WARP_FORWARD(0);   // 1
      LAUNCH_SOFTMAX_WARP_FORWARD(1);   // 2
      LAUNCH_SOFTMAX_WARP_FORWARD(2);   // 4
      LAUNCH_SOFTMAX_WARP_FORWARD(3);   // 8
      LAUNCH_SOFTMAX_WARP_FORWARD(4);   // 16
      LAUNCH_SOFTMAX_WARP_FORWARD(5);   // 32
      LAUNCH_SOFTMAX_WARP_FORWARD(6);   // 64
      LAUNCH_SOFTMAX_WARP_FORWARD(7);   // 128
      LAUNCH_SOFTMAX_WARP_FORWARD(8);   // 256
      LAUNCH_SOFTMAX_WARP_FORWARD(9);   // 512
      LAUNCH_SOFTMAX_WARP_FORWARD(10);  // 1024
      default:
        break;
    }
  }
}

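// Entry point. Dispatch logic:
//   inner_size_ == 1 and dim_size_ <= 1024 (and row <= 4096 bytes): warp-per-row kernel, launched in bounded chunks;
//   inner_size_ == 1 otherwise: block-per-row kernel with vectorized loads;
//   inner_size_ > 1: spatial kernel that reduces along a non-innermost axis.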
template <typename T, bool is_log_softmax>
cudaError_t Softmax(T *input_, T *output_, size_t dim_size_, size_t outer_size_, size_t inner_size_, size_t device_id,
                    cudaStream_t cuda_stream) {
  using accumulate_t = acc_type<T, true>;
  if (inner_size_ == 1) {
    dim3 grid(outer_size_);
    if (dim_size_ <= 1024 && dim_size_ * sizeof(T) <= 4096) {
      // Small rows: use the warp-per-row kernel, launched in chunks to keep the grid size bounded.
      int64_t remaining = outer_size_;
      int64_t chunk_size = (1L << 30L) / dim_size_;
      while (remaining > 0) {
        dispatch_softmax_forward<T, T, accumulate_t, is_log_softmax, false>(output_, input_, dim_size_, dim_size_,
                                                                            std::min<int64_t>(remaining, chunk_size),
                                                                            cuda_stream, nullptr /* not masked */);
        input_ += chunk_size * dim_size_;
        output_ += chunk_size * dim_size_;
        remaining -= chunk_size;
      }
    } else {
      // Large rows: use the block-per-row kernel with vectorized (ILP) loads.
      constexpr int InsP = sizeof(float4) / sizeof(T);
      dim3 block = SoftMaxGetBlockSize(InsP, dim_size_);
      cunn_SoftMaxForward<InsP, T, accumulate_t, is_log_softmax>
        <<<grid, block, block.x * sizeof(accumulate_t), cuda_stream>>>(output_, input_, dim_size_);
    }
  } else {
    // Softmax over a non-innermost axis: fall back to the spatial kernel. The shared-memory size must be computed
    // from accumulate_t, since that is the element type the kernel stores in shared memory.
    uint32_t smem_size;
    dim3 grid, block;
    SpatialSoftMaxGetLaunchSizes<accumulate_t>(&SpatialSoftMaxForward<T, accumulate_t, T, is_log_softmax>, outer_size_,
                                               dim_size_, inner_size_, &grid, &block, &smem_size, device_id);
    SpatialSoftMaxForward<T, accumulate_t, T, is_log_softmax>
      <<<grid, block, smem_size, cuda_stream>>>(output_, input_, outer_size_, dim_size_, inner_size_);
  }
  return GetCudaStatus();
}

template CUDA_LIB_EXPORT cudaError_t Softmax<double, false>(double *input_, double *output_, size_t dim_,
                                                            size_t outer_size_, size_t inner_size_, size_t device_id,
                                                            cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t Softmax<double, true>(double *input_, double *output_, size_t dim_,
                                                           size_t outer_size_, size_t inner_size_, size_t device_id,
                                                           cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t Softmax<float, false>(float *input_, float *output_, size_t dim_,
                                                           size_t outer_size_, size_t inner_size_, size_t device_id,
                                                           cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t Softmax<float, true>(float *input_, float *output_, size_t dim_,
                                                          size_t outer_size_, size_t inner_size_, size_t device_id,
                                                          cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t Softmax<half, false>(half *input_, half *output_, size_t dim_, size_t outer_size_,
                                                          size_t inner_size_, size_t device_id,
                                                          cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t Softmax<half, true>(half *input_, half *output_, size_t dim_, size_t outer_size_,
                                                         size_t inner_size_, size_t device_id,
                                                         cudaStream_t cuda_stream);