/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <ATen/cuda/CUDAContextLight.h>

#include <ATen/native/transformers/cuda/flash_attn/flash.h>
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h>

namespace pytorch_flash {

// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#endif

// __grid_constant__ kernel parameters require nvcc 11.8 or newer.
#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && defined(__CUDACC_VER_MINOR__) && \
    (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8))
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif

// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)

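// For reference (a sketch, not generated code): DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel,
// bool Is_dropout, ...) expands to roughly
//
//   template<typename Kernel_traits, bool Is_dropout, ...>
//   __global__ void flash_fwd_kernel(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
//
// i.e. every kernel below is a templated __global__ function that takes the parameter struct by
// value; when KERNEL_PARAM_MODIFIER is __grid_constant__, the const struct stays in grid-constant
// memory instead of being copied into per-thread local storage when its address is taken.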
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
    #if defined(ARCH_SUPPORTS_FLASH)
        static_assert(!(Is_causal && Is_local)); // Enforce constraints
        pytorch_flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
    #if defined(ARCH_SUPPORTS_FLASH)
        pytorch_flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
    static_assert(Log_max_splits >= 1);
    pytorch_flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    // printf("smem_size = %d\n", smem_size);

    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21

    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.b, params.h);
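    // Grid layout (for reference): blockIdx.x indexes the kBlockM-sized query tiles,
    // blockIdx.y the batch, and blockIdx.z the heads, so each CTA handles one
    // (query tile, batch, head) combination.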
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    const bool return_softmax = params.p_ptr != nullptr;
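    // The *_SWITCH macros from static_switch.h lift each runtime flag into a compile-time
    // template argument by branching once and invoking a lambda with a constexpr constant.
    // A minimal sketch of the pattern (the actual macros may differ in details):
    //
    //   #define BOOL_SWITCH(COND, CONST_NAME, ...)       \
    //     [&] {                                           \
    //       if (COND) {                                   \
    //         constexpr static bool CONST_NAME = true;    \
    //         return __VA_ARGS__();                       \
    //       } else {                                      \
    //         constexpr static bool CONST_NAME = false;   \
    //         return __VA_ARGS__();                       \
    //       }                                             \
    //     }()
    //
    // Nesting the switches below therefore picks exactly one fully specialized
    // flash_fwd_kernel instantiation per combination of runtime flags.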
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        // Will only return softmax if dropout, to reduce compilation time.
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                        // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                        // If Is_local, set Is_causal to false
                        auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                        // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
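                        // By default a kernel may use at most 48 KB of dynamic shared memory;
                        // larger configurations have to opt in explicitly via cudaFuncSetAttribute
                        // before launch, otherwise the launch below would fail.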
                        if (smem_size >= 48 * 1024) {
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                        }
                        // int ctas_per_sm;
                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });
}

template<typename Kernel_traits>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
    static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
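    // When num_splits > 1, the split index travels in blockIdx.y and batch * heads is flattened
    // into blockIdx.z; otherwise the grid matches run_flash_fwd: (query tile, batch, head).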
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
                            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                                // If Is_local, set Is_causal to false
                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                                if (smem_size >= 48 * 1024) {
                                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                                }
                                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                                C10_CUDA_KERNEL_LAUNCH_CHECK();
                            });
                        });
                    });
                });
            });
        });
    });
    if (params.num_splits > 1) {
        // We want kBlockM to be as small as possible for more parallelism.
        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
        // If headdim is divisible by 64, then we set kBlockM = 8, etc.
        constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
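            // The third template argument is Log_max_splits ~ ceil(log2(num_splits)); the combine
            // kernel uses it to size the per-row reduction over splits, so the ladder below picks
            // the smallest exponent that covers params.num_splits (at most 128 splits).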
            if (params.num_splits <= 2) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 4) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 8) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 16) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 32) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 64) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 128) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            }
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    }
}

template<typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int kBlockM = 64;  // Fixed for all head dimensions
    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
    // and for headdim 192 with block size 64 x 128.
    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
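    // For example, a (hypothetical) fp16 call dispatched as
    //     run_mha_fwd_splitkv_dispatch<cutlass::half_t, /*Headdim=*/128>(params, stream);
    // ends up here with a 64 x 128 tile and 4 warps per CTA.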
    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
}

template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr(!Is_dropout) {
                // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
                // Using block size (64 x 256) is 27% slower for seqlen=2k
                // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
        });
    });
}

template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
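    // is_sm8x is true on sm86 / sm89-class GPUs (e.g. RTX 30xx / 40xx), which have noticeably less
    // shared memory per SM (about 100 KB) than sm80 (A100, 164 KB), hence the smaller tiles chosen below.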
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
            if (is_sm8x) {
                if constexpr(!Is_causal) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            // These two are always slower
            // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr(!Is_dropout) {
                // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
                // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
                if (is_sm8x) {
                    if constexpr(!Is_causal) {
                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                    } else {
                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                    }
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // 1st ones are good for H100, A100
                // 2nd one is good for A6000 bc we get slightly better occupancy
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            }
        });
    });
}

template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 160;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For A100, H100, 128 x 32 is the fastest.
            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
            // and 128 x 64 with 8 warps is the fastest for non-causal.
            if (is_sm8x) {
                if constexpr(!Is_causal) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr(!Is_dropout) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 224;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
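            // Shared-memory estimate for the 128 x 64 tile: 2 bytes/element * Headdim * (kBlockM + 2 * kBlockN)
            // = 2 * 224 * (128 + 128) = 112 KB for the Q tile plus the K and V tiles; sm86/sm89 only allow
            // about 99 KB per block, so they take the 64 x 64 fallback below.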
            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
            // If we have N = 32, there are only 1024 elements to load at once, where each load
            // is 8 elements. This means we can only use 128 threads and not 256 threads.
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_sm, max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For A100, we want to run with 128 x 64 (128KB smem).
            // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
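            // Worked out: 2 * 256 * (128 + 2 * 64) = 128 KB is the 128 x 64 tile's shared memory, and
            // 4 * 256 * (64 + 2 * 64) = 192 KB is what two 64 x 64 CTAs would need on one SM.
            // A100 (164 KB/SM, 163 KB/block opt-in) satisfies both conditions and gets 128 x 64;
            // H100 (228 KB/SM) fails the second, so it takes 64 x 64 and fits 2 CTAs per SM.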
            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // 64 KB
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // 96 KB
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}

} // namespace pytorch_flash