1 /******************************************************************************
2  * Copyright (c) 2024, Tri Dao.
3  ******************************************************************************/
4 
5 #pragma once
6 
7 #include <cute/algorithm/copy.hpp>
8 
9 #include <cutlass/cutlass.h>
10 #include <cutlass/array.h>
11 #include <cutlass/numeric_types.h>
12 
13 
14 #include <ATen/native/transformers/cuda/flash_attn/block_info.h>
15 #include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
16 #include <ATen/native/transformers/cuda/flash_attn/utils.h>
17 #include <ATen/native/transformers/cuda/flash_attn/softmax.h>
18 #include <ATen/native/transformers/cuda/flash_attn/mask.h>
19 #include <ATen/native/transformers/cuda/flash_attn/dropout.h>
20 #include <ATen/native/transformers/cuda/flash_attn/rotary.h>
21 
22 namespace pytorch_flash {
23 
24 using namespace cute;
25 
26 ////////////////////////////////////////////////////////////////////////////////////////////////////
27 
28 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
29 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
30 
31     using Element = typename Kernel_traits::Element;
32     using ElementAccum = typename Kernel_traits::ElementAccum;
33     using index_t = typename Kernel_traits::index_t;
34 
35     // Shared memory.
36     extern __shared__ char smem_[];
37 
38     // The thread index.
39     const int tidx = threadIdx.x;
40 
41     constexpr int kBlockM = Kernel_traits::kBlockM;
42     constexpr int kBlockN = Kernel_traits::kBlockN;
43     constexpr int kHeadDim = Kernel_traits::kHeadDim;
44     constexpr int kNWarps = Kernel_traits::kNWarps;
45 
46     auto seed_offset = at::cuda::philox::unpack(params.philox_args);
47     pytorch_flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
48                            bidb, bidh, tidx, params.h);
49 
50     // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
51     // exit early and no one saves the rng state.
52     if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
53         if (params.philox_args.captured_) {
54             *params.seed = std::get<0>(seed_offset);
55             *params.extragraph_offset = std::get<1>(seed_offset);
56         }
57     }
58 
59     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
60     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
61 
62     const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
63     int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
64     if (Is_causal || Is_local) {
65         n_block_max = std::min(n_block_max,
66                                cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
67         // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
68         //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
69         // }
70     }
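    // Worked example of the bounds above (the numbers are illustrative, not fixed by this file):
    // with kBlockM == kBlockN == 64, actual_seqlen_q == actual_seqlen_k == 512, causal and
    // window_size_right == 0, for m_block == 3 we get
    // n_block_max = min(ceil(512 / 64), ceil((4 * 64 + 0 + 0) / 64)) = min(8, 4) = 4,
    // so query rows [192, 256) only visit key blocks 0..3, i.e. keys [0, 256).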
71     // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
72     // Otherwise we might read OOB elements from gK and gV.
73     if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
74         Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
75                                               + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
76                                 make_shape(binfo.actual_seqlen_q, params.h, params.d),
77                                 make_stride(params.o_row_stride, params.o_head_stride, _1{}));
78         Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
79                               make_coord(m_block, 0));  // (kBlockM, kHeadDim)
80         Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
81                                   make_shape(params.b, params.h, params.seqlen_q),
82                                   make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
83         Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
84 
85 
86         typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
87         auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
88         Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
89         Tensor tOrO = make_tensor<Element>(shape(tOgO));
90         clear(tOrO);
91         // Construct identity layout for sO
92         Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
93         // Repeat the partitioning with identity layouts
94         Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
95         Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
96         if (!Is_even_K) {
97             #pragma unroll
98             for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
99         }
100         // Clear_OOB_K must be false since we don't want to write zeros to gmem
101         pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
102             gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
103         );
104         #pragma unroll
105         for (int m = 0; m < size<1>(tOgO); ++m) {
106             const int row = get<0>(tOcO(0, m, 0));
107             if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
108         }
109         return;
110     }
111     // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }
112 
113     // We iterate over the blocks in reverse order. This is because the last block is the only one
114     // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
115     // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
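    // For example, with n_block_min == 0 and n_block_max == 8, the loops below start at
    // n_block == 7 (the right-most key block) and walk down block by block to n_block_min.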
116 
117     const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
118         + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
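    // Descriptive note: P (the returned softmax / dropout pattern) is laid out as
    // (batch, head, seqlen_q_rounded, seqlen_k_rounded) in row-major order, so row_offset_p points
    // at query row m_block * kBlockM in the right-most key block (n_block_max - 1), which is the
    // first block the reversed main loop visits.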
119 
120     Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
121                                           + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
122                             make_shape(binfo.actual_seqlen_q, params.h, params.d),
123                             make_stride(params.q_row_stride, params.q_head_stride, _1{}));
124     Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
125                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
126     Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)
127                                           + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
128                             make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
129                             make_stride(params.k_row_stride, params.k_head_stride, _1{}));
130     Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
131                            make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
132     Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)
133                                           + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
134                             make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
135                             make_stride(params.v_row_stride, params.v_head_stride, _1{}));
136     Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
137                            make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
138     Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
139                             Shape<Int<kBlockM>, Int<kBlockN>>{},
140                             make_stride(params.seqlen_k_rounded, _1{}));
141 
142     Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
143                             typename Kernel_traits::SmemLayoutQ{});
144     // Careful: we're using the same smem for sQ and for sK | sV if Share_Q_K_smem.
145     Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
146                             typename Kernel_traits::SmemLayoutKV{});
147     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
148     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
149     Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
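    // Shared-memory layout (descriptive note): sQ sits at the start of dynamic smem; sK is placed
    // right after sQ, or on top of sQ when Share_Q_K_smem (Q is copied into registers below before
    // K overwrites that space); sV follows sK. sVt / sVtNoSwizzle are transposed views of the same
    // storage as sV, used as the B operand of the second GEMM (P @ V).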
150 
151     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
152     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
153 
154     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
155     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
156     Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K, nblocksN)
157     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
158     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K, nblocksN)
159     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
160 
161     typename Kernel_traits::TiledMma tiled_mma;
162     auto thr_mma = tiled_mma.get_thread_slice(tidx);
163     Tensor tSrQ  = thr_mma.partition_fragment_A(sQ);                           // (MMA,MMA_M,MMA_K)
164     Tensor tSrK  = thr_mma.partition_fragment_B(sK);                           // (MMA,MMA_N,MMA_K)
165     Tensor tOrVt  = thr_mma.partition_fragment_B(sVtNoSwizzle);                // (MMA, MMA_K,MMA_N)
166 
167     Tensor tSgS  = thr_mma.partition_C(gP);
168 
169     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
170 
171     //
172     // Copy Atom retiling
173     //
174 
175     auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
176     auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
177     // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
178     Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
179     // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
180 
181     auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
182     auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
183     Tensor tSsK = smem_thr_copy_K.partition_S(sK);
184 
185     auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
186     auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
187     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
188 
189     //
190     // PREDICATES
191     //
192 
193     // // Allocate predicate tensors for m and n
194     // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
195     // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
196 
197     // Construct identity layout for sQ and sK
198     Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
199     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)
200     // Tensor tScQ = thr_mma.partition_A(cQ);                           // (MMA,MMA_M,MMA_K)
201     // if (cute::thread0()) {
202     //     print(tScQ.layout()); printf("\n");
203     //     for (int i = 0; i < size(tScQ); ++i) {
204     //         printf("%d ", get<0>(tScQ(i)));
205     //     }
206     //     printf("\n");
207     //     for (int i = 0; i < size(tScQ); ++i) {
208     //         printf("%d ", get<1>(tScQ(i)));
209     //     }
210     //     printf("\n");
211     // }
212 
213     // Repeat the partitioning with identity layouts
214     Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
215     Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);   // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
216 
217     // Allocate predicate tensors for k
218     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
219     Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
220 
221     // Set predicates for k bounds
222     if (!Is_even_K) {
223         #pragma unroll
224         for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
225         #pragma unroll
226         for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
227     }
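    // Illustrative example (the numbers are assumptions, not fixed by Kernel_traits): if
    // kHeadDim == 64 but params.d == 48 and each thread copies 8-element chunks along the K mode,
    // tQpQ / tKVpKV mark which chunks start at a column < params.d, so the predicated copy skips
    // the out-of-bounds tail of the head dimension instead of touching memory past it.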
228 
229     // Prologue
230 
231     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
232     pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
233                                        binfo.actual_seqlen_q - m_block * kBlockM);
234     if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
235 
236     // // if (cute::thread(1, 0)) { print(tQsQ); }
237     // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
238     // // if (cute::thread0()) { print(sQNoSwizzle); }
239 
240     if (Kernel_traits::Share_Q_K_smem) {
241         pytorch_flash::cp_async_wait<0>();
242         __syncthreads();
243         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
244         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
245         cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
246         __syncthreads();
247     }
248 
249     int n_block = n_block_max - 1;
250     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
251     pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
252                                        binfo.actual_seqlen_k - n_block * kBlockN);
253     cute::cp_async_fence();
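    // Note on the async-copy protocol used throughout this kernel (descriptive, based on the
    // cp.async semantics that cute wraps): cp_async_fence() commits the cp.async copies issued so
    // far as one group, and cp_async_wait<N>() blocks until at most N committed groups are still
    // in flight (so cp_async_wait<0>() waits for everything); the following __syncthreads() then
    // makes the freshly written smem visible to the whole thread block.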
254     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
255     // __syncthreads();
256 
257     if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
258         pytorch_flash::cp_async_wait<1>();
259         __syncthreads();
260         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
261         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
262         cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
263     }
264 
265     clear(acc_o);
266 
267     pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax;
268 
269     const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
270     pytorch_flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
271 
272     // For performance reasons, we separate out two kinds of iterations:
273     // those that need masking on S, and those that don't.
274     // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
275     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
276     // We will have at least 1 "masking" iteration.
277 
278     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
279     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
280     constexpr int n_masking_steps = (!Is_causal && !Is_local)
281         ? 1
282         : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
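    // Example values (kBlockM / kBlockN come from Kernel_traits; these numbers are only
    // illustrative): with kBlockM == 128 and kBlockN == 64, a non-causal non-local kernel does
    // 1 masking step, a causal kernel with even seqlens does ceil(128 / 64) == 2, and a causal or
    // local kernel with an uneven seqlen_k does 2 + 1 == 3.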
283     #pragma unroll
284     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
285         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
286         clear(acc_s);
287         pytorch_flash::cp_async_wait<0>();
288         __syncthreads();
289 
290         // Advance gV
291         if (masking_step > 0) {
292             pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
293         } else {
294             // Clear the smem tiles to account for predicated off loads
295             pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
296                 gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
297             );
298         }
299         cute::cp_async_fence();
300         cute::cp_async_fence();
301 
302         pytorch_flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
303             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
304             smem_thr_copy_Q, smem_thr_copy_K
305         );
306         // if (cute::thread0()) { print(acc_s); }
307 
308         mask.template apply_mask<Is_causal, Is_even_MN>(
309             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
310         );
311 
312         pytorch_flash::cp_async_wait<0>();
313         __syncthreads();
314         if (n_block > n_block_min) {
315             pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
316             // This cp_async_fence needs to be in the if block, otherwise the synchronization
317             // isn't right and we get race conditions.
318             cute::cp_async_fence();
319         }
320 
321         // TODO: when we have key_padding_mask we'll need to Check_inf
322         masking_step == 0
323             ? softmax.template softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
324             : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
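        // Sketch of the online-softmax update performed above (the actual implementation lives in
        // softmax.h; this is the standard FlashAttention recurrence, written out for reference):
        //   m_new    = max(m_old, rowmax(acc_s))
        //   acc_s    = exp2((acc_s - m_new) * scale_softmax_log2)            // unnormalized probs
        //   acc_o   *= exp2((m_old - m_new) * scale_softmax_log2)            // rescale prior output
        //   row_sum  = row_sum * exp2((m_old - m_new) * scale_softmax_log2) + rowsum(acc_s)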
325 
326         // Convert acc_s from fp32 to fp16/bf16
327         Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
328         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
329         int block_col_idx = n_block * (kBlockN / 32);
330         if (Return_softmax) {
331             Tensor rP_drop = make_fragment_like(rP);
332             cute::copy(rP, rP_drop);
333             dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
334                 rP_drop, block_row_idx, block_col_idx, kNWarps
335             );
336             cute::copy(rP_drop, tSgS);
337             tSgS.data() = tSgS.data() + (-kBlockN);
338         }
339         if (Is_dropout) {
340             dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
341         }
342 
343         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
344         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
345         Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
346         // if (cute::thread0()) { print(tOrP); }
347         pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
348         // if (cute::thread0()) { print(scores); }
349 
350         // This check is at the end of the loop since we always have at least 1 iteration
351         if (n_masking_steps > 1 && n_block <= n_block_min) {
352             --n_block;
353             break;
354         }
355     }
356 
357     // These are the iterations where we don't need masking on S
358     for (; n_block >= n_block_min; --n_block) {
359         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
360         clear(acc_s);
361         pytorch_flash::cp_async_wait<0>();
362         __syncthreads();
363         pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
364         cute::cp_async_fence();
365 
366         pytorch_flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
367             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
368             smem_thr_copy_Q, smem_thr_copy_K
369         );
370 
371         pytorch_flash::cp_async_wait<0>();
372         __syncthreads();
373         if (n_block > n_block_min) {
374             pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
375             // This cp_async_fence needs to be in the if block, otherwise the synchronization
376             // isn't right and we get race conditions.
377             cute::cp_async_fence();
378         }
379 
380         mask.template apply_mask</*Causal_mask=*/false>(
381             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
382         );
383 
384         softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
385 
386         Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
387         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
388         int block_col_idx = n_block * (kBlockN / 32);
389         if (Return_softmax) {
390             Tensor rP_drop = make_fragment_like(rP);
391             cute::copy(rP, rP_drop);
392             dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
393                 rP_drop, block_row_idx, block_col_idx, kNWarps
394             );
395             cute::copy(rP_drop, tSgS);
396             tSgS.data() = tSgS.data() + (-kBlockN);
397         }
398         if (Is_dropout) {
399             dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
400         }
401 
402         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
403         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
404         Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
405         pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
406     }
407 
408     // Epilogue
409 
410     Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
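    // Sketch of what the epilogue normalization above is expected to produce (the actual code is
    // in softmax.h): lse(i) = m(i) * scale_softmax + log(row_sum(i)), and each row of acc_o is
    // divided by row_sum(i) (additionally scaled by params.rp_dropout when Is_dropout).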
411 
412     // Convert acc_o from fp32 to fp16/bf16
413     Tensor rO = pytorch_flash::convert_type<Element>(acc_o);
414     Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
415     // Partition sO to match the accumulator partitioning
416     auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
417     auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
418     Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
419     Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
420 
421     // sO has the same size as sQ, so we don't need to sync here.
422     if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
423 
424     cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
425 
426     Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
427                                           + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
428                             make_shape(binfo.actual_seqlen_q, params.h, params.d),
429                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
430     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
431                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
432     Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
433                               make_shape(params.b, params.h, params.seqlen_q),
434                               make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
435     Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
436 
437     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
438     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
439     Tensor tOsO = gmem_thr_copy_O.partition_S(sO);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
440     Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
441 
442     __syncthreads();
443 
444     Tensor tOrO = make_tensor<Element>(shape(tOgO));
445     cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
446 
447     Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
448     Tensor taccOcO = thr_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
449     static_assert(decltype(size<0>(taccOcO))::value == 4);
450     // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
451     Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
452     CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M
453     if (get<1>(taccOcO_row(0)) == 0) {
454         #pragma unroll
455         for (int mi = 0; mi < size(lse); ++mi) {
456             const int row = get<0>(taccOcO_row(mi));
457             if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
458         }
459     }
460 
461     // Construct identity layout for sO
462     Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
463     // Repeat the partitioning with identity layouts
464     Tensor tOcO = gmem_thr_copy_O.partition_D(cO);                           // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
465     Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
466     if (!Is_even_K) {
467         #pragma unroll
468         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
469     }
470     // Clear_OOB_K must be false since we don't want to write zeros to gmem
471     pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
472         gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
473     );
474 }
475 
476 ////////////////////////////////////////////////////////////////////////////////////////////////////
477 
478 template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
479 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
480 
481     using Element = typename Kernel_traits::Element;
482     using ElementAccum = typename Kernel_traits::ElementAccum;
483     using index_t = typename Kernel_traits::index_t;
484 
485     // Shared memory.
486     extern __shared__ char smem_[];
487 
488     // The thread index.
489     const int tidx = threadIdx.x;
490 
491     constexpr int kBlockM = Kernel_traits::kBlockM;
492     constexpr int kBlockN = Kernel_traits::kBlockN;
493     constexpr int kHeadDim = Kernel_traits::kHeadDim;
494     constexpr int kNWarps = Kernel_traits::kNWarps;
495 
496     using GmemTiledCopyO = std::conditional_t<
497         !Split,
498         typename Kernel_traits::GmemTiledCopyO,
499         typename Kernel_traits::GmemTiledCopyOaccum
500     >;
501     using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
502 
503     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
504     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
505     // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
506     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
507 
508     const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
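    // Worked example (illustrative numbers): seqlen_k == 1000 and kBlockN == 64 give 16 key
    // blocks; with num_n_splits == 3, n_blocks_per_split == ceil(16 / 3) == 6, so split 2 covers
    // n_block in [12, min(16, 18)) == [12, 16).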
509     const int n_block_min = !Is_local
510         ? n_split_idx * n_blocks_per_split
511         : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
512     int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
513     if (Is_causal || Is_local) {
514         n_block_max = std::min(n_block_max,
515                                cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
516     }
517     if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
518         // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
519         // Otherwise we might read OOB elements from gK and gV,
520         // or get wrong results when we combine gOaccum from different blocks.
521         const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
522             + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
523         const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
524             + m_block * kBlockM) * params.d_rounded;
525         const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
526         Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
527                                       Shape<Int<kBlockM>, Int<kHeadDim>>{},
528                                      make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
529         Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
530                                       Shape<Int<kBlockM>>{}, Stride<_1>{});
531 
532         GmemTiledCopyO gmem_tiled_copy_Oaccum;
533         auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
534         Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
535         Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
536         clear(tOrOaccum);
537         // Construct identity layout for sO
538         Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
539         // Repeat the partitioning with identity layouts
540         Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
541         Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
542         if (!Is_even_K) {
543             #pragma unroll
544             for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
545         }
546         // Clear_OOB_K must be false since we don't want to write zeros to gmem
547         pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
548             gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
549         );
550         #pragma unroll
551         for (int m = 0; m < size<1>(tOgOaccum); ++m) {
552             const int row = get<0>(tOcO(0, m, 0));
553             if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
554         }
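        // Descriptive note on the sentinel written above: with Split, the combine step later
        // reduces the per-split LSEs with a logsumexp, so -inf makes this (empty) split contribute
        // exp(-inf) == 0; without Split, +inf marks rows that attend to no keys, matching the
        // non-split early-exit path above.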
555         return;
556     }
557 
558     // We iterate over the blocks in reverse order. This is because the last block is the only one
559     // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
560     // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
561 
562 
563     // We move K and V to the last block.
564     const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
565     const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
566     const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
567     const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
568     const index_t row_offset_k = block_table == nullptr
569         ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
570           + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
571         : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
572     const index_t row_offset_v = block_table == nullptr
573         ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
574           + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
575         : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
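    // Paged-KV example (illustrative numbers, not fixed by this file): with
    // params.page_block_size == 256, kBlockN == 64 and n_block_max - 1 == 5, key rows [320, 384)
    // live in physical page block_table[1] at in-page row offset 320 - 256 == 64, which is exactly
    // what block_table_idx / block_table_offset compute above.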
576 
577     Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
578                             make_shape(binfo.actual_seqlen_q, params.h, params.d),
579                             make_stride(params.q_row_stride, params.q_head_stride, _1{}));
580     Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
581                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
582     Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
583                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
584                             make_stride(params.k_row_stride, _1{}));
585     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
586     Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
587                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
588                             make_stride(params.v_row_stride, _1{}));
589 
590     Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
591                             typename Kernel_traits::SmemLayoutQ{});
592     Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
593     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
594     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
595     Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
596 
597     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
598     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
599 
600     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
601     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
602     Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
603     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
604     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
605     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
606 
607     typename Kernel_traits::TiledMma tiled_mma;
608     auto thr_mma = tiled_mma.get_thread_slice(tidx);
609     Tensor tSrQ  = thr_mma.partition_fragment_A(sQ);                           // (MMA,MMA_M,MMA_K)
610     Tensor tSrK  = thr_mma.partition_fragment_B(sK);                           // (MMA,MMA_N,MMA_K)
611     Tensor tOrVt  = thr_mma.partition_fragment_B(sVtNoSwizzle);                // (MMA, MMA_K,MMA_N)
612 
613     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
614 
615     //
616     // Copy Atom retiling
617     //
618 
619     auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
620     auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
621     Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
622 
623     auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
624     auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
625     Tensor tSsK = smem_thr_copy_K.partition_S(sK);
626 
627     auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
628     auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
629     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
630 
631     // PREDICATES
632     //
633 
634     // // Allocate predicate tensors for m and n
635     // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
636     // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
637 
638     // Construct identity layout for sQ and sK
639     Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
640     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)
641 
642     // Repeat the partitioning with identity layouts
643     Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
644     Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);   // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
645 
646     // Allocate predicate tensors for k
647     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
648     Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
649 
650     // Set predicates for k bounds
651     if (!Is_even_K) {
652         #pragma unroll
653         for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
654         #pragma unroll
655         for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
656     }
657 
658     // Prologue
659 
660     // Copy from Knew to K, optionally apply rotary embedding.
661     typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
662     auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
663     typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
664     auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
665     if constexpr (Append_KV) {
666         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
667         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
668         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
669         const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
670         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
671                                   Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
672                                   make_stride(params.rotary_dim / 2, _1{}));
673         Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
674                                   Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
675                                   make_stride(params.rotary_dim / 2, _1{}));
676         Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
677                                       Shape<Int<kBlockN>, Int<kHeadDim>>{},
678                                       make_stride(params.rotary_dim / 2, _1{}));
679         Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
680                                       Shape<Int<kBlockN>, Int<kHeadDim>>{},
681                                       make_stride(params.rotary_dim / 2, _1{}));
682         Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
683         Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
684         Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
685         Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
686         // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
687         // if (cute::thread(8, 0)) { print_tensor(gCos); }
688         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
689 
690         const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
691             + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
692         const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
693             + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
694         // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
695         // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
696         // This maps to accessing the first 64 rows of knew_ptr.
697         Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
698                                                 + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
699                                   Shape<Int<kBlockN>, Int<kHeadDim>>{},
700                                   make_stride(params.knew_row_stride, _1{}));
701         // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
702         Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
703                                                 + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
704                                   Shape<Int<kBlockN>, Int<kHeadDim>>{},
705                                   make_stride(params.vnew_row_stride, _1{}));
706         Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
707         Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
708 
709         const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
710         auto tKgK_data = tKgK.data();
711         auto tVgV_data = tVgV.data();
712         for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
713             pytorch_flash::copy_w_min_idx<Is_even_K>(
714                 tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
715             );
716             tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
717             if (params.rotary_dim == 0) {
718                 pytorch_flash::copy_w_min_idx<Is_even_K>(
719                     tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
720                 );
721             } else {
722                 if (params.is_rotary_interleaved) {
723                     // Don't clear OOB_K because we're writing to global memory
724                     pytorch_flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
725                         tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
726                         binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
727                     );
728                     tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
729                     tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
730                 } else {
731                     // Don't clear OOB_K because we're writing to global memory
732                     pytorch_flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
733                         tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
734                         binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
735                     );
736                     tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
737                     tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
738 
739                 }
740             }
741             tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
742             if (block_table == nullptr) {
743                 tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
744                 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
745             } else {
746                 if (n_block > n_block_copy_min) {
747                     const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
748                     const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
749                     const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
750                     const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
751                     const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
752                     const int offset_diff = block_table_offset_next - block_table_offset_cur;
753                     tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
754                     tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
755                 }
756             }
757         }
758         // Need this before we can read in K again, so that we'll see the updated K values.
759         __syncthreads();
760         tKgK.data() = tKgK_data;
761         tVgV.data() = tVgV_data;
762     }
763 
764     // Read Q from gmem to smem, optionally apply rotary embedding.
765     if (!Append_KV || params.rotary_dim == 0) {
766         // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
767         pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
768                                            binfo.actual_seqlen_q - m_block * kBlockM);
769     } else {
770         const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
771         // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
772         // We do this by setting the row stride of gCos / gSin to 0.
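        // With a row stride of 0, every row index of gCos / gSin maps to the same physical row
        // (the cos/sin values at position seqlen_k_cache), which is how "all queries get the same
        // cos/sin" is implemented in the non-causal, non-local case.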
773         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
774                                   Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
775                                   make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
776         Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
777                                   Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
778                                   make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
779         Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
780                                   Shape<Int<kBlockM>, Int<kHeadDim>>{},
781                                   make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
782         Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
783                                   Shape<Int<kBlockM>, Int<kHeadDim>>{},
784                                   make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
785         Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
786         Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
787         Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
788         Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
789         if (params.is_rotary_interleaved) {
790             pytorch_flash::copy_rotary_interleaved<Is_even_K>(
791                 tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
792                 0, params.d, params.rotary_dim
793             );
794         } else {
795             pytorch_flash::copy_rotary_contiguous<Is_even_K>(
796                 tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
797                 0, params.d, params.rotary_dim
798             );
799         }
800     }
801 
802     int n_block = n_block_max - 1;
803     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
804     pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
805                                        binfo.actual_seqlen_k - n_block * kBlockN);
806     cute::cp_async_fence();
807 
808     // pytorch_flash::cp_async_wait<0>();
809     // __syncthreads();
810     // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
811     // __syncthreads();
812 
813     clear(acc_o);
814 
815     pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax;
816 
817     const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
818     pytorch_flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
819 
820     // For performance reasons, we separate out two kinds of iterations:
821     // those that need masking on S, and those that don't.
822     // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
823     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
824     // We will have at least 1 "masking" iteration.
825 
826     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
827     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
828     constexpr int n_masking_steps = (!Is_causal && !Is_local)
829         ? 1
830         : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
831     #pragma unroll
832     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
833         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
834         clear(acc_s);
835         pytorch_flash::cp_async_wait<0>();
836         __syncthreads();
837 
838         // Advance gV
839         if (masking_step > 0) {
840             if (block_table == nullptr) {
841                 tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
842             } else {
843                 const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
844                 const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
845                 const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
846                 const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
847                 tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
848             }
849             pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
850         } else {
851             // Clear the smem tiles to account for predicated off loads
852             pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
853                 gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
854             );
855         }
856         cute::cp_async_fence();
857 
858         pytorch_flash::gemm(
859             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
860             smem_thr_copy_Q, smem_thr_copy_K
861         );
862         // if (cute::thread0()) { print(acc_s); }
863 
864         mask.template apply_mask<Is_causal, Is_even_MN>(
865             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
866         );
867 
868         pytorch_flash::cp_async_wait<0>();
869         __syncthreads();
870         // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
871         // __syncthreads();
872 
873         if (n_block > n_block_min) {
874             // Advance gK
875             if (block_table == nullptr) {
876                 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
877             } else {
878                 const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
879                 const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
880                 const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
881                 const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
882                 tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
883             }
884             pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
885             // This cp_async_fence needs to be in the if block, otherwise the synchronization
886             // isn't right and we get race conditions.
887             cute::cp_async_fence();
888         }
889 
890         // We have key_padding_mask so we'll need to Check_inf
891         masking_step == 0
892             ? softmax.template softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2)
893             : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);
894         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
895 
896         // Convert acc_s from fp32 to fp16/bf16
897         Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
898         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
899         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
900         Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
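        // Illustration: convert_type above already produced the fp16/bf16 values in registers; this
        // make_tensor only re-views those registers with the operand-A layout expected by the P * V
        // MMA, e.g. (MMA=4, MMA_M, MMA_N) -> ((4, 2), MMA_M, MMA_N / 2) for m16n8k16, so no data
        // is moved here.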
901 
902         pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
903 
904         // This check is at the end of the loop since we always have at least 1 iteration
905         if (n_masking_steps > 1 && n_block <= n_block_min) {
906             --n_block;
907             break;
908         }
909     }
910 
911     // These are the iterations where we don't need masking on S
912     for (; n_block >= n_block_min; --n_block) {
913         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
914         clear(acc_s);
915         pytorch_flash::cp_async_wait<0>();
916         __syncthreads();
917         // Advance gV
918         if (block_table == nullptr) {
919             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
920         } else {
921             const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
922             const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
923             const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
924             const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
925             tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
926         }
927         pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
928         cute::cp_async_fence();
929 
930         pytorch_flash::gemm(
931             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
932             smem_thr_copy_Q, smem_thr_copy_K
933         );
934 
935         pytorch_flash::cp_async_wait<0>();
936         __syncthreads();
937         if (n_block > n_block_min) {
938             // Advance gK
939             if (block_table == nullptr) {
940                 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
941             } else {
942                 const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
943                 const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
944                 const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
945                 const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
946                 tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
947             }
948             pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
949             // This cp_async_fence needs to be in the if block, otherwise the synchronization
950             // isn't right and we get race conditions.
951             cute::cp_async_fence();
952         }
953 
954         mask.template apply_mask</*Causal_mask=*/false>(
955             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
956         );
957         softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
958 
959         Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
960         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
961         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
962         Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
963 
964         pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
965     }
966 
967     // Epilogue
968 
969     Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
970     // if (cute::thread0()) { print(lse); }
971 
972     Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
973     // Partition sO to match the accumulator partitioning
974     using SmemTiledCopyO = std::conditional_t<
975         !Split,
976         typename Kernel_traits::SmemCopyAtomO,
977         typename Kernel_traits::SmemCopyAtomOaccum
978     >;
979     auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
980     auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
981     Tensor rO = pytorch_flash::convert_type<ElementO>(acc_o);
982     Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
983     Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
984 
985     // sOaccum is larger than sQ, so we need to syncthreads here
986     // TODO: allocate enough smem for sOaccum
987     if constexpr (Split) { __syncthreads(); }
988 
989     cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
990 
991     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
992         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
993     const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
994                                          + m_block * kBlockM) * params.d_rounded;
995     const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
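    // Worked example (illustration only): with Split, the accumulation buffers are laid out as
    // [num_splits, batch, heads, seqlen_q(, d_rounded)], so for b = 2, h = 4, seqlen_q = 1024,
    // kBlockM = 64, a block with n_split_idx = 1, bidb = 0, bidh = 3, m_block = 2 starts at row
    // (((1 * 2 + 0) * 4 + 3) * 1024 + 128) = 11392 of softmax_lseaccum, and d_rounded times that
    // many elements into oaccum.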
996 
997     Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
998                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
999                                  make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
1000     Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
1001                                    Shape<Int<kBlockM>>{}, Stride<_1>{});
1002     // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }
1003 
1004     GmemTiledCopyO gmem_tiled_copy_Oaccum;
1005     auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
1006     Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
1007     Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
1008 
1009     __syncthreads();
1010 
1011     Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
1012     cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
1013 
1014     Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
1015     Tensor taccOcO = thr_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
1016     static_assert(decltype(size<0>(taccOcO))::value == 4);
1017     // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
1018     Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
1019     CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M
1020     if (get<1>(taccOcO_row(0)) == 0) {
1021         #pragma unroll
1022         for (int mi = 0; mi < size(lse); ++mi) {
1023             const int row = get<0>(taccOcO_row(mi));
1024             if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
1025         }
1026     }
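    // Illustrative note (assuming the usual m16n8k16 accumulator layout): only threads whose first
    // accumulator element sits in column 0 of the output tile take the branch above; the row
    // statistics are replicated across the four threads of each row quad, so one writer per row
    // is enough, and rows past actual_seqlen_q are skipped.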
1027 
1028     // Construct identity layout for sO
1029     Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
1030     // Repeat the partitioning with identity layouts
1031     Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);                           // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
1032     Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
1033     if (!Is_even_K) {
1034         #pragma unroll
1035         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
1036     }
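    // Illustration: if the head dim is padded (say params.d = 48 with kHeadDim = 64), every copy
    // chunk whose starting column is >= params.d gets tOpO(k) = false, so the predicated copy
    // below skips those columns instead of writing past the real head dimension.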
1037     // Clear_OOB_K must be false since we don't want to write zeros to gmem
1038     pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
1039         gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
1040     );
1041 }
1042 
1043 ////////////////////////////////////////////////////////////////////////////////////////////////////
1044 
1045 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
1046 inline __device__ void compute_attn(const Params &params) {
1047     const int m_block = blockIdx.x;
1048     // The block index for the batch.
1049     const int bidb = blockIdx.y;
1050     // The block index for the head.
1051     const int bidh = blockIdx.z;
1052 
1053     // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
1054     // them to have the same number of threads or have to traverse the attention matrix
1055     // in the same order.
1056     // In the Philox RNG, we use the offset to store the batch, head, and the lane id
1057     // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
1058     // the attention matrix. This way, as long as we have the batch, head, and the location of
1059     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
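    // For example, the forward and backward kernels may tile the attention matrix differently,
    // but as long as both derive the (batch, head, lane) offset and the 16 x 32 block position
    // for the same tile, they draw identical dropout bits for it. (Illustrative restatement of
    // the scheme above; the exact packing lives in the Dropout class and the philox helpers.)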
1060 
1061     pytorch_flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
1062 }
1063 
1064 ////////////////////////////////////////////////////////////////////////////////////////////////////
1065 
1066 template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
1067 inline __device__ void compute_attn_splitkv(const Params &params) {
1068     const int m_block = blockIdx.x;
1069     // The block index for the batch.
1070     const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
1071     // The block index for the head.
1072     const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
1073     const int n_split_idx = Split ? blockIdx.y : 0;
1074     const int num_n_splits = Split ? gridDim.y : 1;
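    // Illustration: with Split, (batch, head) is packed into blockIdx.z, so for params.h = 8 and
    // blockIdx.z = 19 this block handles bidb = 19 / 8 = 2 and bidh = 19 - 2 * 8 = 3, while
    // blockIdx.y selects which of the gridDim.y splits of the key/value sequence it processes.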
1075     pytorch_flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
1076 }
1077 
1078 ////////////////////////////////////////////////////////////////////////////////////////////////////
1079 
1080 template <typename T>
1081 constexpr T ceil_div(T numerator, T denominator) {
1082     return (numerator + denominator - 1) / denominator;
1083 }
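// Purely illustrative compile-time check (not upstream code): ceil_div rounds the quotient up,
// e.g. ceil_div(10, 4) == 3 while 10 / 4 == 2, and it is exact when the division is even.
static_assert(ceil_div(10, 4) == 3 && ceil_div(8, 4) == 2, "ceil_div rounds up");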
1084 
1085 
1086 template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
1087 inline __device__ void combine_attn_seqk_parallel(const Params &params) {
1088     using Element = typename Kernel_traits::Element;
1089     using ElementAccum = typename Kernel_traits::ElementAccum;
1090     using index_t = typename Kernel_traits::index_t;
1091     constexpr int kMaxSplits = 1 << Log_max_splits;
1092     constexpr int kHeadDim = Kernel_traits::kHeadDim;
1093     constexpr int kNThreads = Kernel_traits::kNThreads;
1094 
1095     static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
1096     static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
1097     static_assert(kNThreads == 128, "We assume that each block has 128 threads");
1098 
1099     // Shared memory.
1100     // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
1101     __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];
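    // Why the +1 (illustration, assuming 32 four-byte shared-memory banks): with a row length of
    // exactly kBlockM = 32 floats, sLSE[0][c], sLSE[1][c], ... would all land in bank c and the
    // column-wise reads in the transpose below would serialize; padding each row by one element
    // shifts consecutive rows onto different banks.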
1102 
1103     // The thread and block index.
1104     const int tidx = threadIdx.x;
1105     const int bidx = blockIdx.x;
1106 
1107     const index_t row_offset_lse = bidx * kBlockM;
1108     Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
1109                                    Shape<Int<kMaxSplits>, Int<kBlockM>>{},
1110                                    make_stride(params.b * params.h * params.seqlen_q, _1{}));
1111     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
1112                               Shape<Int<kBlockM>>{}, Stride<_1>{});
1113     constexpr int kNLsePerThread = ceil_div(kMaxSplits * kBlockM, kNThreads);
1114     // Read the LSE values from gmem and store them in shared memory, then transpose them.
1115     constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
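    // Worked numbers (illustrative): for kMaxSplits = 32, kBlockM = 16, kNThreads = 128, each
    // thread loads kNLsePerThread = ceil_div(32 * 16, 128) = 4 LSE values, and one pass of the
    // loop below covers kRowsPerLoadLSE = 128 / 16 = 8 split-rows of sLSE.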
1116     #pragma unroll
1117     for (int l = 0; l < kNLsePerThread; ++l) {
1118         const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
1119         const int col = tidx % kBlockM;
1120         ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
1121         if (row < kMaxSplits) { sLSE[row][col] = lse; }
1122         // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
1123     }
1124     // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
1125     __syncthreads();
1126     Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
1127     constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
1128     // To make sure that the kMaxSplits values for a given row end up within 1 warp: we decide how many
1129     // of them each thread should hold. If kMaxSplits = 16 and kBlockM = 16, then each thread holds 2
1130     // values (128 threads, kBlockM rows, so each load pass covers 128 / kBlockM rows).
1131     // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
1132     // static_assert(kThreadsPerSplit <= 32);
1133     static_assert(kRowsPerLoadTranspose <= 32);
1134     static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
1135     #pragma unroll
1136     for (int l = 0; l < kNLsePerThread; ++l) {
1137         const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
1138         const int col = tidx / kRowsPerLoadTranspose;
1139         lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
1140         // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
1141     }
1142 
1143     // Compute the logsumexp of the LSE along the split dimension.
1144     ElementAccum lse_max = lse_accum(0);
1145     #pragma unroll
1146     for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
1147     MaxOp<float> max_op;
1148     lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
1149     lse_max = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
1150     float lse_sum = expf(lse_accum(0) - lse_max);
1151     #pragma unroll
1152     for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
1153     SumOp<float> sum_op;
1154     lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
1155     // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
1156     // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
1157     ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
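    // In formula form (restating the reductions above): with per-split values lse_s,
    //   lse_logsum = log( sum_s exp(lse_s) ) = lse_max + log( sum_s exp(lse_s - lse_max) ),
    // computed in the max-shifted form for numerical stability; the special cases keep fully
    // masked-out rows from producing NaN in the exp(lse - lse_logsum) scales below.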
1158     // Calculate valid rows for this block
1159     const int total_rows = params.b * params.h * params.seqlen_q;
1160     const int local_row = tidx / kRowsPerLoadTranspose;
1161     const int global_row = blockIdx.x * kBlockM + local_row;
1162 
1163     const bool is_reduction_writer = tidx % kRowsPerLoadTranspose == 0;
1164     const bool is_valid_row = (local_row < kBlockM) && (global_row < total_rows);
1165 
1166     if (is_reduction_writer && is_valid_row) {
1167         gLSE(local_row) = lse_logsum;
1168     }
1169     // Store the scales exp(lse - lse_logsum) in shared memory.
1170     #pragma unroll
1171     for (int l = 0; l < kNLsePerThread; ++l) {
1172         const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
1173         const int col = tidx / kRowsPerLoadTranspose;
1174         if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); }
1175     }
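    // These scales implement the split reduction O = sum_s exp(lse_s - lse_logsum) * O_s, i.e.
    // each partial output is reweighted by its share of the total softmax mass (illustrative
    // restatement; the accumulation itself happens in the split loop further down).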
1176     __syncthreads();
1177 
1178     const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
1179     Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
1180                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
1181                                  Stride<Int<kHeadDim>, _1>{});
1182     constexpr int kBlockN = kNThreads / kBlockM;
1183     using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
1184     using GmemTiledCopyOaccum = decltype(
1185         make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
1186                         GmemLayoutAtomOaccum{},
1187                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
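    // Concretely (illustration): with kNThreads = 128 and kBlockM = 16, kBlockN = 8 threads
    // cooperate on each of the kBlockM rows and every thread moves 4 contiguous ElementAccum
    // values per copy, so one pass covers 8 * 4 = 32 columns of the kHeadDim-wide gOaccum tile.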
1188     GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
1189     auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
1190     Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
1191     Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
1192     Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
1193     clear(tOrO);
1194 
1195     // Predicates
1196     Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
1197     // Repeat the partitioning with identity layouts
1198     Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
1199     Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
1200     if (!Is_even_K) {
1201         #pragma unroll
1202         for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
1203     }
1204     // Load Oaccum, then scale and accumulate it into O
1205     for (int split = 0; split < params.num_splits; ++split) {
1206         pytorch_flash::copy</*Is_even_MN=*/false, Is_even_K>(
1207             gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
1208         );
1209         #pragma unroll
1210         for (int m = 0; m < size<1>(tOrOaccum); ++m) {
1211             int row = get<0>(tOcOaccum(0, m, 0));
1212             ElementAccum lse_scale = sLSE[split][row];
1213             #pragma unroll
1214             for (int k = 0; k < size<2>(tOrOaccum); ++k) {
1215                 #pragma unroll
1216                 for (int i = 0; i < size<0>(tOrOaccum); ++i) {
1217                     tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
1218                 }
1219             }
1220         // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); }
1221         }
1222         tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
1223     }
1224     // if (cute::thread0()) { print_tensor(tOrO); }
1225 
1226     Tensor rO = pytorch_flash::convert_type<Element>(tOrO);
1227     // Write to gO
1228     #pragma unroll
1229     for (int m = 0; m < size<1>(rO); ++m) {
1230         const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
1231         if (idx < params.b * params.h * params.seqlen_q) {
1232             const int batch_idx = idx / (params.h * params.seqlen_q);
1233             const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
1234             // The index to the rows of Q
1235             const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
1236             auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride
1237                 + head_idx * params.o_head_stride + row * params.o_row_stride;
1238             #pragma unroll
1239             for (int k = 0; k < size<2>(rO); ++k) {
1240                 if (Is_even_K || tOpOaccum(k)) {
1241                     const int col = get<1>(tOcOaccum(0, m, k));
1242                     Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
1243                                             Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
1244                     // TODO: Should check if this is using vectorized store, but it seems pretty fast
1245                     copy(rO(_, m, k), gO);
1246                     // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
1247                     // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);
1248                 }
1249             }
1250         }
1251     }
1252 }
1253 
1254 } // namespace pytorch_flash
1255