/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cmath>
#include <iostream>
#include <limits>
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/softmax_impl.cuh"

#define WARPSIZE 32
const int max_threads = 1024;
constexpr int ALIGN_BYTES = 16;

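// Picks the 2-D grid for the spatial kernel: enough blocks along y to cover the inner dimension
// (capped by the number of blocks that can be resident), the remaining budget along the outer dimension.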
inline dim3 SpatialSoftMaxGetGridSize(dim3 *block, uint32_t activate_block, uint64_t outer_size, uint64_t dim_size,
                                      uint64_t inner_size) {
  uint32_t inner = (inner_size + block->y - 1) / block->y;
  if (inner > activate_block) {
    inner = activate_block;
  }
  uint32_t outer = (activate_block + inner - 1) / inner;
  if (outer > outer_size) {
    outer = outer_size;
  }
  return dim3(outer, inner);
}

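// Chooses the block shape (dim_threads, inner_threads): threads go to the inner dimension first,
// and extra threads are spent along the softmax dimension only when the inner dimension is small.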
inline dim3 SpatialSoftMaxGetBlockSize(uint64_t outer_size, uint64_t dim_size, uint64_t inner_size) {
  uint32_t inner_ths = inner_size;
  inner_ths = std::min(inner_ths, static_cast<uint32_t>(max_threads));
  uint32_t dim_threads = 1;
  if (inner_ths <= 64 && dim_size >= 64) {
    while ((inner_ths * dim_threads <= max_threads) && (dim_threads <= dim_size)) {
      dim_threads *= 2;
    }
    dim_threads /= 2;
  }
  return dim3(dim_threads, inner_ths);
}

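// Computes block, grid and dynamic shared-memory size for SpatialSoftMaxForward, using the CUDA
// occupancy API to scale the grid to the number of blocks that can be resident on the device.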
template <typename accumulate_t, typename Kernel>
void SpatialSoftMaxGetLaunchSizes(Kernel k, uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, dim3 *grid,
                                  dim3 *block, uint32_t *smem_size, uint32_t device_id) {
  *block = SpatialSoftMaxGetBlockSize(outer_size, dim_size, inner_size);
  uint32_t block_ths = block->x * block->y;
  if (block->x == 1) {
    *smem_size = 0;
  } else {
    *smem_size = block_ths * sizeof(accumulate_t);
  }
  int activate_size;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&activate_size, k, block_ths, *smem_size);
  cudaDeviceProp prop;
  (void)cudaGetDeviceProperties(&prop, device_id);
  activate_size *= prop.multiProcessorCount;
  *grid = SpatialSoftMaxGetGridSize(block, activate_size, outer_size, dim_size, inner_size);
}

int log2_ceil(int val) {
  int final_val = 0;
  while ((1 << final_val) < val) ++final_val;
  return final_val;
}

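// Block size for the non-spatial (inner_size == 1) path: the smallest power of two covering
// dim / ins (halved when vectorized loads are used, capped at max_threads), and at least one warp.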
inline dim3 SoftMaxGetBlockSize(int ins, uint64_t dim) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim / ins, static_cast<uint64_t>(max_threads));

  if (ins > 1) {
    max_block_size /= 2;
  }

  while (block_size < (max_block_size)) {
    block_size *= 2;
  }
  block_size = std::max(block_size, static_cast<uint64_t>(WARPSIZE));
  return dim3(block_size);
}

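// Reduction functors used by ILPReduce: running maximum, and running sum of exp(v - max_k).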
template <typename T, typename AccT>
struct GetMaxFloat {
  __device__ __forceinline__ AccT operator()(AccT max, T v) const { return ::max(max, (AccT)v); }
};

template <typename T, typename AccT>
struct GetSumExpFloat {
  __device__ __forceinline__ GetSumExpFloat(AccT v) : max_k(v) {}

  __device__ __forceinline__ AccT operator()(AccT sum, T v) const { return sum + std::exp((AccT)v - max_k); }

  const AccT max_k;
};

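// Warp shuffle XOR that works on both CUDA (__shfl_xor_sync) and HIP (__shfl_xor).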
template <typename T>
__device__ __forceinline__ T WARPSHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
#ifndef __HIP_PLATFORM_HCC__
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}

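// Tree reduction across blockDim.x within one shared-memory row (one row per threadIdx.y).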
template <typename T, typename Function>
__forceinline__ __device__ T SpatialBlockReduceX(T *memsha, T val) {
  Function r = Function();
  memsha += threadIdx.y * blockDim.x;
  __syncthreads();
  memsha[threadIdx.x] = val;
  int offset = blockDim.x / 2;
  while (offset > 0) {
    __syncthreads();
    if (threadIdx.x < offset) memsha[threadIdx.x] = r(memsha[threadIdx.x], memsha[threadIdx.x + offset]);
    offset /= 2;
  }
  __syncthreads();
  return memsha[0];
}

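// Softmax over the middle dimension of an (outer, dim, inner) tensor. Each (outer, inner) pair is
// handled by blockDim.x threads: max reduction, sum of exponentials, then the epilogue writes the result.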
template <typename input_t, typename accumulate_t, typename output_t, bool is_log_softmax>
__global__ void SpatialSoftMaxForward(output_t *output, input_t *input, uint32_t outer_size, uint32_t dim_size,
                                      uint32_t inner_size) {
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accumulate_t *>(smem);
  const uint32_t outer_stride = inner_size * dim_size;
  const uint32_t dim_stride = inner_size;
  for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
    const uint32_t outer_offset = outer_index * outer_stride;
    for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size;
         inner_index += blockDim.y * gridDim.y) {
      const uint32_t offset = outer_offset + inner_index;
      if (blockDim.x > 1) {
        accumulate_t max_data_input = std::numeric_limits<accumulate_t>::lowest();
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
          const accumulate_t value = static_cast<accumulate_t>(input[offset + d * dim_stride]);
          max_data_input = atomic::Max()(max_data_input, value);
        }
        max_data_input = SpatialBlockReduceX<accumulate_t, atomic::Max>(sdata, max_data_input);
        accumulate_t sum = 0;
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          sum += std::exp(static_cast<accumulate_t>(input[offset + d * dim_stride]) - max_data_input);
        sum = SpatialBlockReduceX<accumulate_t, atomic::Add>(sdata, sum);
        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t, is_log_softmax> epilogue(max_data_input, sum);
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          output[offset + d * dim_stride] = epilogue(input[offset + d * dim_stride]);
      } else {
        accumulate_t max_data_input = std::numeric_limits<accumulate_t>::lowest();
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
          const accumulate_t value = static_cast<accumulate_t>(input[offset + d * dim_stride]);
          max_data_input = atomic::Max()(max_data_input, value);
        }
        accumulate_t sum = 0;
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          sum += std::exp(static_cast<accumulate_t>(input[offset + d * dim_stride]) - max_data_input);

        SoftMaxForwardEpilogue<input_t, accumulate_t, output_t, is_log_softmax> epilogue(max_data_input, sum);
        for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
          output[offset + d * dim_stride] = epilogue(input[offset + d * dim_stride]);
      }
    }
  }
}

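// Non-vectorized epilogue: applies the softmax epilogue element by element, processing InsP
// strided elements per thread per iteration, with a scalar tail.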
template <int InsP, typename input_t, typename accum_t, typename output_t, bool is_log_softmax>
__device__ __forceinline__ void WriteResults(int clas, input_t *input, output_t *output, input_t max_k,
                                             input_t sum_all) {
  SoftMaxForwardEpilogue<input_t, accum_t, output_t, is_log_softmax> epilogue(max_k, sum_all);
  int offset = threadIdx.x;
  int last = clas % (InsP * blockDim.x);
  for (; offset < clas - last; offset += blockDim.x * InsP) {
    input_t tmp[InsP];
#pragma unroll
    for (int j = 0; j < InsP; ++j) {
      tmp[j] = input[offset + j * blockDim.x];
    }
#pragma unroll
    for (int j = 0; j < InsP; ++j) {
      output[offset + j * blockDim.x] = epilogue(tmp[j]);
    }
  }
  for (; offset < clas; offset += blockDim.x) {
    output[offset] = epilogue(input[offset]);
  }
}

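// Vectorized epilogue: handles the unaligned head scalar-wise, then reads/writes aligned_vector
// packets of InsP elements, and finishes the tail scalar-wise.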
template <int InsP, typename input_t, typename accum_t, typename output_t, bool is_log_softmax>
__device__ __forceinline__ void WriteResultsVectorized(int size, const int deviate, input_t *input, output_t *output,
                                                       input_t max_k, input_t sum_all) {
  SoftMaxForwardEpilogue<input_t, accum_t, output_t, is_log_softmax> epilogue(max_k, sum_all);
  using loadT = aligned_vector<input_t>;
  using storeT = aligned_vector<output_t>;
  int offset = threadIdx.x;
  if (deviate > 0) {
    input -= deviate;
    output -= deviate;
    size += deviate;
    if (threadIdx.x >= deviate) {
      output[offset] = epilogue(input[offset]);
    }
    size -= blockDim.x;
    input += blockDim.x;
    output += blockDim.x;
  }
  const int last = size % (InsP * blockDim.x);
  input_t in_v[InsP];
  loadT *in_value = reinterpret_cast<loadT *>(&in_v);
  output_t out_v[InsP];
  storeT *out_value = reinterpret_cast<storeT *>(&out_v);
  for (; offset * InsP < (size - last); offset += blockDim.x) {
    *in_value = reinterpret_cast<loadT *>(input)[offset];
#pragma unroll
    for (int j = 0; j < InsP; ++j) {
      out_v[j] = epilogue(in_v[j]);
    }
    reinterpret_cast<storeT *>(output)[offset] = *out_value;
  }
  offset = size - last + threadIdx.x;
  for (; offset < size; offset += blockDim.x) {
    output[offset] = epilogue(input[offset]);
  }
}

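// Block-wide reduction: every thread writes its value to shared memory, each lane of the first warp
// reduces one 32-element slice, and thread 0 combines the per-warp partials.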
template <typename Reduction, typename AccT>
__device__ __forceinline__ AccT ReduceBlock(AccT *sharemen, AccT val, AccT initVal) {
  Reduction r = Reduction();
  __syncthreads();
  sharemen[threadIdx.x] = val;
  __syncthreads();
  AccT warpVal = initVal;
  uint32_t mask = (((uint64_t)1) << (blockDim.x / WARPSIZE)) - 1;
  if (threadIdx.x < WARPSIZE) {
    int lane = threadIdx.x % WARPSIZE;
    if (lane < blockDim.x / WARPSIZE) {
#pragma unroll
      for (int i = 0; i < WARPSIZE; ++i) {
        warpVal = r(warpVal, sharemen[lane * WARPSIZE + i]);
      }
#ifndef __HIP_PLATFORM_HCC__
      __syncwarp(mask);
#endif
      sharemen[lane] = warpVal;
    }
  }
  __syncthreads();
  AccT blockVal = initVal;
  if (threadIdx.x == 0) {
    for (int i = 0; i < blockDim.x / WARPSIZE; ++i) {
      blockVal = r(blockVal, sharemen[i]);
    }
    sharemen[0] = blockVal;
  }
  __syncthreads();
  return sharemen[0];
}

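// Per-thread reduction over one row with instruction-level parallelism: a scalar head to reach
// alignment, a vectorized body of InsP elements per load, and a scalar tail.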
template <template <typename, typename> class Reduction, int InsP, typename T, typename AccT>
__device__ __forceinline__ AccT ILPReduce(int shift, T *data, int size, const Reduction<T, AccT> &r, AccT initVal) {
  using loadT = aligned_vector<T>;
  AccT threadVal = initVal;
  int offset = threadIdx.x;
  if (shift > 0) {
    data -= shift;
    size += shift;
    if (threadIdx.x >= shift) {
      threadVal = r(threadVal, data[offset]);
    }
    size -= blockDim.x;
    data += blockDim.x;
  }
  int last = size % (InsP * blockDim.x);
  T v[InsP];
  loadT *value = reinterpret_cast<loadT *>(&v);
  for (; offset * InsP < (size - last); offset += blockDim.x) {
    *value = reinterpret_cast<loadT *>(data)[offset];
#pragma unroll
    for (int j = 0; j < InsP; ++j) {
      threadVal = r(threadVal, v[j]);
    }
  }
  offset = size - last + threadIdx.x;
  for (; offset < size; offset += blockDim.x) threadVal = r(threadVal, data[offset]);
  return threadVal;
}

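// Butterfly reduction across a warp for WARP_BATCH independent rows, using XOR shuffles.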
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, typename func>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      acc_t b = WARPSHFL_XOR(sum[i], offset, WARP_SIZE);
      sum[i] = func()(sum[i], b);
    }
  }
}

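// Warp-level softmax for small rows (element_count <= 1024). Each warp handles up to WARP_BATCH
// rows entirely in registers: load, (optionally masked) max, sum of exponentials, then normalize and store.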
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
__global__ void SoftMaxWarpForward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count,
                                   const bool *mask = nullptr, const int head_chunk_size = -1,
                                   bool is_transformer_mask = false) {
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < WARPSIZE) ? next_power_of_two : WARPSIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
  int local_idx = threadIdx.x;
  int idx_offset = first_batch * stride + local_idx;

  src += idx_offset;
  dst += idx_offset;

  if (is_transformer_mask) {
    mask += ((first_batch * stride) / head_chunk_size) * stride + local_idx;
  } else {
    mask += idx_offset;
  }
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        elements[i][it] = src[i * element_count + it * WARP_SIZE];
      } else {
        elements[i][it] = -std::numeric_limits<acc_t>::infinity();
      }
    }
  }

  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    bool is_meaningful_max = false;
    max_value[i] = elements[i][0];
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      if (is_masked) {
        int idx = it * WARP_SIZE;
        if ((idx + local_idx) < batch_element_count) {
          if (!is_transformer_mask) {
            idx += i * element_count;
          }
          if (!mask[idx]) {
            max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
            is_meaningful_max = true;
          }
        }
      } else {
        max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
      }
    }
    if (is_masked) {
      if (!is_meaningful_max) {
        max_value[i] = -std::numeric_limits<acc_t>::infinity();
      }
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, atomic::Max>(max_value);

  acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      if (!is_masked) {
        if (is_log_softmax) {
          sum[i] += std::exp(elements[i][it] - max_value[i]);
        } else {
          elements[i][it] = std::exp(elements[i][it] - max_value[i]);
          sum[i] += elements[i][it];
        }
      } else {
        int idx = it * WARP_SIZE;
        bool valid = (idx + local_idx) < batch_element_count;
        if (!is_transformer_mask) {
          idx += i * element_count;
        }
        if (valid) {
          if (!mask[idx]) {
            if (is_log_softmax) {
              sum[i] += std::exp(elements[i][it] - max_value[i]);
            } else {
              elements[i][it] = std::exp(elements[i][it] - max_value[i]);
              sum[i] += elements[i][it];
            }
          } else {
            if (!is_log_softmax) {
              elements[i][it] = 0;
            }
          }
        } else {
          if (!is_log_softmax) {
            elements[i][it] = 0.;
          }
        }
      }
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, atomic::Add>(sum);
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    if (is_log_softmax) sum[i] = std::log(sum[i]);
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      int element_index = local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        if (is_log_softmax) {
          dst[i * element_count + it * WARP_SIZE] = elements[i][it] - max_value[i] - sum[i];
        } else if (sum[i] == 0) {
          dst[i * element_count + it * WARP_SIZE] = std::numeric_limits<acc_t>::quiet_NaN();
        } else {
          dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
        }
      } else {
        break;
      }
    }
  }
}

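// Block-per-row softmax for large rows: ILP reductions for the max and the exp-sum, block-wide
// combination in shared memory, then a (vectorized, when input and output alignment match) epilogue.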
template <int InsP, typename T, typename accumulate_t, bool is_log_softmax>
__global__ void cunn_SoftMaxForward(T *output, T *input, int classes) {
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accumulate_t *>(smem);

  using loadT = aligned_vector<T>;
  using storeT = aligned_vector<T>;
  input += blockIdx.x * classes;
  output += blockIdx.x * classes;
  const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(T);
  const int output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(T);

  accumulate_t threadMax = ILPReduce<GetMaxFloat, InsP, T, accumulate_t>(
    shift, input, classes, GetMaxFloat<T, accumulate_t>(), -std::numeric_limits<accumulate_t>::max());
  accumulate_t max_k =
    ReduceBlock<atomic::Max, accumulate_t>(sdata, threadMax, -std::numeric_limits<accumulate_t>::max());

  accumulate_t threadExp = ILPReduce<GetSumExpFloat, InsP, T, accumulate_t>(
    shift, input, classes, GetSumExpFloat<T, accumulate_t>(max_k), static_cast<accumulate_t>(0));
  accumulate_t sumAll = ReduceBlock<atomic::Add, accumulate_t>(sdata, threadExp, static_cast<accumulate_t>(0));

  if (shift == output_shift) {
    WriteResultsVectorized<InsP, T, accumulate_t, T, is_log_softmax>(classes, shift, input, output, max_k, sumAll);
  } else {
    WriteResults<InsP, T, accumulate_t, T, is_log_softmax>(classes, input, output, max_k, sumAll);
  }
}

// end of kernel function
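// Host-side dispatcher for the warp-level kernel: picks the kernel specialization from
// log2(softmax_elements) and derives the launch geometry from the warp/batch layout.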
template <typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride,
                              int batch_count, cudaStream_t stream, const bool *mask = nullptr, int chunk_size = -1,
                              bool is_transformer_mask = false) {
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < WARPSIZE) ? next_power_of_two : WARPSIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    switch (log2_elements) {
#define LAUNCH_SOFTMAX_WARP_FORWARD(L2E)                                                                          \
  case L2E:                                                                                                       \
    SoftMaxWarpForward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked><<<blocks, threads, 0, stream>>>( \
      dst, src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask);   \
    break;

      LAUNCH_SOFTMAX_WARP_FORWARD(0);   // 1
      LAUNCH_SOFTMAX_WARP_FORWARD(1);   // 2
      LAUNCH_SOFTMAX_WARP_FORWARD(2);   // 4
      LAUNCH_SOFTMAX_WARP_FORWARD(3);   // 8
      LAUNCH_SOFTMAX_WARP_FORWARD(4);   // 16
      LAUNCH_SOFTMAX_WARP_FORWARD(5);   // 32
      LAUNCH_SOFTMAX_WARP_FORWARD(6);   // 64
      LAUNCH_SOFTMAX_WARP_FORWARD(7);   // 128
      LAUNCH_SOFTMAX_WARP_FORWARD(8);   // 256
      LAUNCH_SOFTMAX_WARP_FORWARD(9);   // 512
      LAUNCH_SOFTMAX_WARP_FORWARD(10);  // 1024
      default:
        break;
    }
  }
}

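// Entry point. The problem is viewed as (outer, dim, inner): when inner == 1 it chooses between the
// warp kernel (small rows, chunked over the batch) and the block-per-row kernel; otherwise it
// falls back to the spatial kernel.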
template <typename T, bool is_log_softmax>
cudaError_t Softmax(T *input_, T *output_, size_t dim_size_, size_t outer_size_, size_t inner_size_, size_t device_id,
                    cudaStream_t cuda_stream) {
  using accumulate_t = acc_type<T, true>;
  if (inner_size_ == 1) {
    dim3 grid(outer_size_);
    if (dim_size_ <= 1024 && dim_size_ * sizeof(T) <= 4096) {
      int64_t remaining = outer_size_;
      int64_t chunk_size = (1L << 30L) / dim_size_;
      while (remaining > 0) {
        dispatch_softmax_forward<T, T, accumulate_t, is_log_softmax, false>(output_, input_, dim_size_, dim_size_,
                                                                            std::min<int64_t>(remaining, chunk_size),
                                                                            cuda_stream, nullptr /* not masked */);
        input_ += chunk_size * dim_size_;
        output_ += chunk_size * dim_size_;
        remaining -= chunk_size;
      }
    } else {
      constexpr int InsP = sizeof(float4) / sizeof(T);
      dim3 block = SoftMaxGetBlockSize(InsP, dim_size_);
      cunn_SoftMaxForward<InsP, T, accumulate_t, is_log_softmax>
        <<<grid, block, block.x * sizeof(accumulate_t), cuda_stream>>>(output_, input_, dim_size_);
    }
  } else {
    uint32_t smem_size;
    dim3 grid, block;
    // Size the dynamic shared memory with the accumulator type, which is what the kernel stores there.
    SpatialSoftMaxGetLaunchSizes<accumulate_t>(&SpatialSoftMaxForward<T, accumulate_t, T, is_log_softmax>, outer_size_,
                                               dim_size_, inner_size_, &grid, &block, &smem_size, device_id);
    SpatialSoftMaxForward<T, accumulate_t, T, is_log_softmax>
      <<<grid, block, smem_size, cuda_stream>>>(output_, input_, outer_size_, dim_size_, inner_size_);
  }
  return GetCudaStatus();
}

template CUDA_LIB_EXPORT cudaError_t Softmax<double, false>(double *input_, double *output_, size_t dim_,
                                                            size_t outer_size_, size_t inner_size_, size_t device_id,
                                                            cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t Softmax<double, true>(double *input_, double *output_, size_t dim_,
                                                           size_t outer_size_, size_t inner_size_, size_t device_id,
                                                           cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t Softmax<float, false>(float *input_, float *output_, size_t dim_,
                                                           size_t outer_size_, size_t inner_size_, size_t device_id,
                                                           cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t Softmax<float, true>(float *input_, float *output_, size_t dim_,
                                                          size_t outer_size_, size_t inner_size_, size_t device_id,
                                                          cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t Softmax<half, false>(half *input_, half *output_, size_t dim_, size_t outer_size_,
                                                          size_t inner_size_, size_t device_id,
                                                          cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t Softmax<half, true>(half *input_, half *output_, size_t dim_, size_t outer_size_,
                                                         size_t inner_size_, size_t device_id,
                                                         cudaStream_t cuda_stream);
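
// A minimal host-side usage sketch (illustrative only; d_in, d_out, N, C, the stream setup and the
// outer/dim/inner split are assumptions made here, not part of this file):
//
//   // Softmax over the last axis of a contiguous [N, C] float tensor on the GPU:
//   //   outer_size = N, dim_size = C, inner_size = 1.
//   cudaStream_t stream;
//   cudaStreamCreate(&stream);
//   cudaError_t status =
//     Softmax<float, /*is_log_softmax=*/false>(d_in, d_out, C, N, 1, /*device_id=*/0, stream);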