1 /******************************************************************************
2 * Copyright (c) 2024, Tri Dao.
3 ******************************************************************************/
4
5 #pragma once
6
7 #include <cute/algorithm/copy.hpp>
8
9 #include <cutlass/cutlass.h>
10 #include <cutlass/array.h>
11 #include <cutlass/numeric_types.h>
12
13
14 #include <ATen/native/transformers/cuda/flash_attn/block_info.h>
15 #include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
16 #include <ATen/native/transformers/cuda/flash_attn/utils.h>
17 #include <ATen/native/transformers/cuda/flash_attn/softmax.h>
18 #include <ATen/native/transformers/cuda/flash_attn/mask.h>
19 #include <ATen/native/transformers/cuda/flash_attn/dropout.h>
20 #include <ATen/native/transformers/cuda/flash_attn/rotary.h>
21
22 namespace pytorch_flash {
23
24 using namespace cute;
25
26 ////////////////////////////////////////////////////////////////////////////////////////////////////
27
28 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
29 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
30
31 using Element = typename Kernel_traits::Element;
32 using ElementAccum = typename Kernel_traits::ElementAccum;
33 using index_t = typename Kernel_traits::index_t;
34
35 // Shared memory.
36 extern __shared__ char smem_[];
37
38 // The thread index.
39 const int tidx = threadIdx.x;
40
41 constexpr int kBlockM = Kernel_traits::kBlockM;
42 constexpr int kBlockN = Kernel_traits::kBlockN;
43 constexpr int kHeadDim = Kernel_traits::kHeadDim;
44 constexpr int kNWarps = Kernel_traits::kNWarps;
45
46 auto seed_offset = at::cuda::philox::unpack(params.philox_args);
47 pytorch_flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
48 bidb, bidh, tidx, params.h);
49
50 // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
51 // exit early and no one saves the rng state.
52 if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
53 if (params.philox_args.captured_) {
54 *params.seed = std::get<0>(seed_offset);
55 *params.extragraph_offset = std::get<1>(seed_offset);
56 }
57 }
58
59 const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
60 if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
61
62 const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
63 int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
64 if (Is_causal || Is_local) {
65 n_block_max = std::min(n_block_max,
66 cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
67 // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
68 // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
69 // }
70 }
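    // Illustrative example (tile sizes come from Kernel_traits): with kBlockM == kBlockN == 64,
    // seqlen_q == seqlen_k == 256, causal attention and window_size_right == 0, row block m_block == 1
    // covers query rows [64, 128), whose last row can attend keys up to index 127, so
    // n_block_max = ceil_div(2 * 64 + 256 - 256 + 0, 64) = 2 and only key blocks 0 and 1 are visited.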
71 // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
72 // Otherwise we might read OOB elements from gK and gV.
73 if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
74 Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
75 + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
76 make_shape(binfo.actual_seqlen_q, params.h, params.d),
77 make_stride(params.o_row_stride, params.o_head_stride, _1{}));
78 Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
79 make_coord(m_block, 0)); // (kBlockM, kHeadDim)
80 Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
81 make_shape(params.b, params.h, params.seqlen_q),
82 make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
83 Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
84
85
86 typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
87 auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
88 Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
89 Tensor tOrO = make_tensor<Element>(shape(tOgO));
90 clear(tOrO);
91 // Construct identity layout for sO
92 Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
93 // Repeat the partitioning with identity layouts
94 Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
95 Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
96 if (!Is_even_K) {
97 #pragma unroll
98 for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
99 }
100 // Clear_OOB_K must be false since we don't want to write zeros to gmem
101 pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
102 gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
103 );
104 #pragma unroll
105 for (int m = 0; m < size<1>(tOgO); ++m) {
106 const int row = get<0>(tOcO(0, m, 0));
107 if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
108 }
109 return;
110 }
111 // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }
112
113 // We iterate over the blocks in reverse order. This is because the last block is the only one
114 // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
115 // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
116
117 const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
118 + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
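    // row_offset_p indexes the returned-softmax buffer P (params.p_ptr, only written when
    // Return_softmax), which is effectively laid out as (batch, head, seqlen_q_rounded,
    // seqlen_k_rounded); we start at the last key block because the main loop walks n_block in reverse.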
119
120 Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
121 + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
122 make_shape(binfo.actual_seqlen_q, params.h, params.d),
123 make_stride(params.q_row_stride, params.q_head_stride, _1{}));
124 Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
125 make_coord(m_block, 0)); // (kBlockM, kHeadDim)
126 Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)
127 + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
128 make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
129 make_stride(params.k_row_stride, params.k_head_stride, _1{}));
130 Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
131 make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN)
132 Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)
133 + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
134 make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
135 make_stride(params.v_row_stride, params.v_head_stride, _1{}));
136 Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
137 make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN)
138 Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
139 Shape<Int<kBlockM>, Int<kBlockN>>{},
140 make_stride(params.seqlen_k_rounded, _1{}));
141
142 Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
143 typename Kernel_traits::SmemLayoutQ{});
144     // Careful: we're using the same smem for sQ and sK | sV if Share_Q_K_smem.
145 Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
146 typename Kernel_traits::SmemLayoutKV{});
147 Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
148 Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
149 Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
150
151 typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
152 auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
153
154 Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
155 Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
156 Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN)
157 Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
158 Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN)
159 Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
160
161 typename Kernel_traits::TiledMma tiled_mma;
162 auto thr_mma = tiled_mma.get_thread_slice(tidx);
163 Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
164 Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
165 Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
166
167 Tensor tSgS = thr_mma.partition_C(gP);
168
169 Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
170
171 //
172 // Copy Atom retiling
173 //
174
175 auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
176 auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
177 // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
178 Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
179 // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
180
181 auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
182 auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
183 Tensor tSsK = smem_thr_copy_K.partition_S(sK);
184
185 auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
186 auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
187 Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
188
189 //
190 // PREDICATES
191 //
192
193 // // Allocate predicate tensors for m and n
194 // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
195 // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
196
197 // Construct identity layout for sQ and sK
198 Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
199 Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
200 // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K)
201 // if (cute::thread0()) {
202 // print(tScQ.layout()); printf("\n");
203 // for (int i = 0; i < size(tScQ); ++i) {
204 // printf("%d ", get<0>(tScQ(i)));
205 // }
206 // printf("\n");
207 // for (int i = 0; i < size(tScQ); ++i) {
208 // printf("%d ", get<1>(tScQ(i)));
209 // }
210 // printf("\n");
211 // }
212
213 // Repeat the partitioning with identity layouts
214 Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
215 Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
216
217 // Allocate predicate tensors for k
218 Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
219 Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
220
221 // Set predicates for k bounds
222 if (!Is_even_K) {
223 #pragma unroll
224 for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
225 #pragma unroll
226 for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
227 }
228
229 // Prologue
230
231 // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
232 pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
233 binfo.actual_seqlen_q - m_block * kBlockM);
234 if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
235
236 // // if (cute::thread(1, 0)) { print(tQsQ); }
237 // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
238 // // if (cute::thread0()) { print(sQNoSwizzle); }
239
240 if (Kernel_traits::Share_Q_K_smem) {
241 pytorch_flash::cp_async_wait<0>();
242 __syncthreads();
243 Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
244 CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
245 cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
246 __syncthreads();
247 }
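    // When Q and K share the same smem buffer (Share_Q_K_smem), Q must be pulled into registers
    // (tSrQ) before the K tiles loaded below overwrite that buffer; the second __syncthreads()
    // guarantees every thread has finished reading sQ before any thread starts filling it with K.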
248
249 int n_block = n_block_max - 1;
250 // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
251 pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
252 binfo.actual_seqlen_k - n_block * kBlockN);
253 cute::cp_async_fence();
254 // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
255 // __syncthreads();
256
257 if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
258 pytorch_flash::cp_async_wait<1>();
259 __syncthreads();
260 Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
261 CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
262 cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
263 }
264
265 clear(acc_o);
266
267 pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax;
268
269 const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
270 pytorch_flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
271
272     // For performance reasons, we separate out two kinds of iterations:
273     // those that need masking on S, and those that don't.
274     // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
275 // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
276 // We will have at least 1 "masking" iteration.
277
278 // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
279 // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
280 constexpr int n_masking_steps = (!Is_causal && !Is_local)
281 ? 1
282 : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
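    // Example (illustrative tile sizes): a non-causal, non-local head needs a single masking step
    // (only the possibly-partial last key block). A causal head with kBlockM == 128 and kBlockN == 64
    // needs ceil_div(128, 64) == 2 steps to cover the blocks crossing the diagonal, plus one extra
    // step when seqlen_k is not a multiple of kBlockN (!Is_even_MN).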
283 #pragma unroll
284 for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
285 Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
286 clear(acc_s);
287 pytorch_flash::cp_async_wait<0>();
288 __syncthreads();
289
290 // Advance gV
291 if (masking_step > 0) {
292 pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
293 } else {
294 // Clear the smem tiles to account for predicated off loads
295 pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
296 gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
297 );
298 }
299 cute::cp_async_fence();
300 cute::cp_async_fence();
301
302 pytorch_flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
303 acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
304 smem_thr_copy_Q, smem_thr_copy_K
305 );
306 // if (cute::thread0()) { print(acc_s); }
307
308 mask.template apply_mask<Is_causal, Is_even_MN>(
309 acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
310 );
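        // The second argument is the global column index of this key block; the third is the global
        // row index of the first accumulator row owned by this thread. With the m16n8k16 accumulator
        // layout, each warp covers 16 rows and each group of 4 lanes shares one row; kNWarps * 16 is
        // the row stride between successive MMA_M tiles.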
311
312 pytorch_flash::cp_async_wait<0>();
313 __syncthreads();
314 if (n_block > n_block_min) {
315 pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
316 // This cp_async_fence needs to be in the if block, otherwise the synchronization
317 // isn't right and we get race conditions.
318 cute::cp_async_fence();
319 }
320
321 // TODO: when we have key_padding_mask we'll need to Check_inf
322 masking_step == 0
323 ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
324 : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
325
326 // Convert acc_s from fp32 to fp16/bf16
327 Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
328 int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
329 int block_col_idx = n_block * (kBlockN / 32);
330 if (Return_softmax) {
331 Tensor rP_drop = make_fragment_like(rP);
332 cute::copy(rP, rP_drop);
333 dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
334 rP_drop, block_row_idx, block_col_idx, kNWarps
335 );
336 cute::copy(rP_drop, tSgS);
337 tSgS.data() = tSgS.data() + (-kBlockN);
338 }
339 if (Is_dropout) {
340 dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
341 }
342
343 // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
344 // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
345 Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
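        // tOrP reinterprets the probabilities already sitting in registers (a C-fragment of the
        // S = Q @ K^T GEMM) as the A operand of the P @ V GEMM below; no data is moved.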
346 // if (cute::thread0()) { print(tOrP); }
347 pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
348 // if (cute::thread0()) { print(scores); }
349
350 // This check is at the end of the loop since we always have at least 1 iteration
351 if (n_masking_steps > 1 && n_block <= n_block_min) {
352 --n_block;
353 break;
354 }
355 }
356
357 // These are the iterations where we don't need masking on S
358 for (; n_block >= n_block_min; --n_block) {
359 Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
360 clear(acc_s);
361 pytorch_flash::cp_async_wait<0>();
362 __syncthreads();
363 pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
364 cute::cp_async_fence();
365
366 pytorch_flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
367 acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
368 smem_thr_copy_Q, smem_thr_copy_K
369 );
370
371 pytorch_flash::cp_async_wait<0>();
372 __syncthreads();
373 if (n_block > n_block_min) {
374 pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
375 // This cp_async_fence needs to be in the if block, otherwise the synchronization
376 // isn't right and we get race conditions.
377 cute::cp_async_fence();
378 }
379
380 mask.template apply_mask</*Causal_mask=*/false>(
381 acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
382 );
383
384 softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
385
386 Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
387 int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
388 int block_col_idx = n_block * (kBlockN / 32);
389 if (Return_softmax) {
390 Tensor rP_drop = make_fragment_like(rP);
391 cute::copy(rP, rP_drop);
392 dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
393 rP_drop, block_row_idx, block_col_idx, kNWarps
394 );
395 cute::copy(rP_drop, tSgS);
396 tSgS.data() = tSgS.data() + (-kBlockN);
397 }
398 if (Is_dropout) {
399 dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
400 }
401
402 // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
403 // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
404 Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
405 pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
406 }
407
408 // Epilogue
409
410 Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
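    // normalize_softmax_lse divides each row of acc_o by that row's softmax denominator (and, with
    // dropout, additionally rescales by params.rp_dropout, the reciprocal of the keep probability),
    // and returns the per-row log-sum-exp that the backward pass consumes.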
411
412 // Convert acc_o from fp32 to fp16/bf16
413 Tensor rO = pytorch_flash::convert_type<Element>(acc_o);
414 Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
415 // Partition sO to match the accumulator partitioning
416 auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
417 auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
418 Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
419 Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
420
421 // sO has the same size as sQ, so we don't need to sync here.
422 if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
423
424 cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
425
426 Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
427 + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
428 make_shape(binfo.actual_seqlen_q, params.h, params.d),
429 make_stride(params.o_row_stride, params.o_head_stride, _1{}));
430 Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
431 make_coord(m_block, 0)); // (kBlockM, kHeadDim)
432 Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
433 make_shape(params.b, params.h, params.seqlen_q),
434 make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
435 Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
436
437 typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
438 auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
439 Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
440 Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
441
442 __syncthreads();
443
444 Tensor tOrO = make_tensor<Element>(shape(tOgO));
445 cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
446
447 Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
448 Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
449 static_assert(decltype(size<0>(taccOcO))::value == 4);
450 // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
451 Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
452 CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
453 if (get<1>(taccOcO_row(0)) == 0) {
454 #pragma unroll
455 for (int mi = 0; mi < size(lse); ++mi) {
456 const int row = get<0>(taccOcO_row(mi));
457 if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
458 }
459 }
460
461 // Construct identity layout for sO
462 Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
463 // Repeat the partitioning with identity layouts
464 Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
465 Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
466 if (!Is_even_K) {
467 #pragma unroll
468 for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
469 }
470 // Clear_OOB_K must be false since we don't want to write zeros to gmem
471 pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
472 gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
473 );
474 }
475
476 ////////////////////////////////////////////////////////////////////////////////////////////////////
477
478 template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
479 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
480
481 using Element = typename Kernel_traits::Element;
482 using ElementAccum = typename Kernel_traits::ElementAccum;
483 using index_t = typename Kernel_traits::index_t;
484
485 // Shared memory.
486 extern __shared__ char smem_[];
487
488 // The thread index.
489 const int tidx = threadIdx.x;
490
491 constexpr int kBlockM = Kernel_traits::kBlockM;
492 constexpr int kBlockN = Kernel_traits::kBlockN;
493 constexpr int kHeadDim = Kernel_traits::kHeadDim;
494 constexpr int kNWarps = Kernel_traits::kNWarps;
495
496 using GmemTiledCopyO = std::conditional_t<
497 !Split,
498 typename Kernel_traits::GmemTiledCopyO,
499 typename Kernel_traits::GmemTiledCopyOaccum
500 >;
501 using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
502
503 const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
504 // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
505 // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
506 if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
507
508 const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
509 const int n_block_min = !Is_local
510 ? n_split_idx * n_blocks_per_split
511 : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
512 int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
513 if (Is_causal || Is_local) {
514 n_block_max = std::min(n_block_max,
515 cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
516 }
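    // Example (illustrative): with seqlen_k == 1024 and kBlockN == 64 there are 16 key blocks; with
    // num_n_splits == 4, each split owns n_blocks_per_split == 4 of them, so split n_split_idx == 2
    // processes key blocks [8, 12), further clamped by the causal/local bound above.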
517 if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0
518 // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
519 // Otherwise we might read OOB elements from gK and gV,
520 // or get wrong results when we combine gOaccum from different blocks.
521 const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
522 + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
523 const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
524 + m_block * kBlockM) * params.d_rounded;
525 const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
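        // When Split, partial results go to separate accumulation buffers laid out as
        // (num_splits, batch, head, seqlen_q, d_rounded) for Oaccum and (num_splits, batch, head,
        // seqlen_q) for LSEaccum; combine_attn_seqk_parallel later reduces across the split dimension.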
526 Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
527 Shape<Int<kBlockM>, Int<kHeadDim>>{},
528 make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
529 Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
530 Shape<Int<kBlockM>>{}, Stride<_1>{});
531
532 GmemTiledCopyO gmem_tiled_copy_Oaccum;
533 auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
534 Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
535 Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
536 clear(tOrOaccum);
537 // Construct identity layout for sO
538 Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
539 // Repeat the partitioning with identity layouts
540 Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
541 Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
542 if (!Is_even_K) {
543 #pragma unroll
544 for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
545 }
546 // Clear_OOB_K must be false since we don't want to write zeros to gmem
547 pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
548 gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
549 );
550 #pragma unroll
551 for (int m = 0; m < size<1>(tOgOaccum); ++m) {
552 const int row = get<0>(tOcO(0, m, 0));
553 if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
554 }
555 return;
556 }
557
558 // We iterate over the blocks in reverse order. This is because the last block is the only one
559 // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
560 // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
561
562
563 // We move K and V to the last block.
564 const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
565 const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
566 const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
567 const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
568 const index_t row_offset_k = block_table == nullptr
569 ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
570 + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
571 : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
572 const index_t row_offset_v = block_table == nullptr
573 ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
574 + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
575 : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
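    // Paged-KV example (illustrative): with page_block_size == 256 and kBlockN == 64, the key block
    // starting at token (n_block_max - 1) * kBlockN == 320 lies in logical page 1 at offset 64, so
    // block_table[1] gives the physical page and block_table_offset == 64 selects the rows inside it.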
576
577 Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
578 make_shape(binfo.actual_seqlen_q, params.h, params.d),
579 make_stride(params.q_row_stride, params.q_head_stride, _1{}));
580 Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
581 make_coord(m_block, 0)); // (kBlockM, kHeadDim)
582 Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
583 Shape<Int<kBlockN>, Int<kHeadDim>>{},
584 make_stride(params.k_row_stride, _1{}));
585 // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
586 Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
587 Shape<Int<kBlockN>, Int<kHeadDim>>{},
588 make_stride(params.v_row_stride, _1{}));
589
590 Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
591 typename Kernel_traits::SmemLayoutQ{});
592 Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
593 Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
594 Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
595 Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
596
597 typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
598 auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
599
600 Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
601 Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
602 Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
603 Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
604 Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
605 Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
606
607 typename Kernel_traits::TiledMma tiled_mma;
608 auto thr_mma = tiled_mma.get_thread_slice(tidx);
609 Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
610 Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
611 Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
612
613 Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
614
615 //
616 // Copy Atom retiling
617 //
618
619 auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
620 auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
621 Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
622
623 auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
624 auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
625 Tensor tSsK = smem_thr_copy_K.partition_S(sK);
626
627 auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
628 auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
629 Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
630
631 // PREDICATES
632 //
633
634 // // Allocate predicate tensors for m and n
635 // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
636 // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
637
638 // Construct identity layout for sQ and sK
639 Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
640 Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
641
642 // Repeat the partitioning with identity layouts
643 Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
644 Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
645
646 // Allocate predicate tensors for k
647 Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
648 Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
649
650 // Set predicates for k bounds
651 if (!Is_even_K) {
652 #pragma unroll
653 for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
654 #pragma unroll
655 for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
656 }
657
658 // Prologue
659
660 // Copy from Knew to K, optionally apply rotary embedding.
661 typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
662 auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
663 typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
664 auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
665 if constexpr (Append_KV) {
666 // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
667 // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
668 // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
669 const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
670 Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
671 Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
672 make_stride(params.rotary_dim / 2, _1{}));
673 Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
674 Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
675 make_stride(params.rotary_dim / 2, _1{}));
676 Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
677 Shape<Int<kBlockN>, Int<kHeadDim>>{},
678 make_stride(params.rotary_dim / 2, _1{}));
679 Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
680 Shape<Int<kBlockN>, Int<kHeadDim>>{},
681 make_stride(params.rotary_dim / 2, _1{}));
682 Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
683 Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
684 Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
685 Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
686 // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
687 // if (cute::thread(8, 0)) { print_tensor(gCos); }
688 // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
689
690 const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
691 + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
692 const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
693 + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
694 // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
695 // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
696 // This maps to accessing the first 64 rows of knew_ptr.
697 Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
698 + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
699 Shape<Int<kBlockN>, Int<kHeadDim>>{},
700 make_stride(params.knew_row_stride, _1{}));
701 // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
702 Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
703 + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
704 Shape<Int<kBlockN>, Int<kHeadDim>>{},
705 make_stride(params.vnew_row_stride, _1{}));
706 Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
707 Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
708
709 const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
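        // Only key blocks overlapping the freshly appended region [seqlen_k_cache, actual_seqlen_k)
        // need to be written, so the copy loop below stops at the block containing row seqlen_k_cache.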
710 auto tKgK_data = tKgK.data();
711 auto tVgV_data = tVgV.data();
712 for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
713 pytorch_flash::copy_w_min_idx<Is_even_K>(
714 tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
715 );
716 tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
717 if (params.rotary_dim == 0) {
718 pytorch_flash::copy_w_min_idx<Is_even_K>(
719 tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
720 );
721 } else {
722 if (params.is_rotary_interleaved) {
723 // Don't clear OOB_K because we're writing to global memory
724 pytorch_flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
725 tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
726 binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
727 );
728 tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
729 tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
730 } else {
731 // Don't clear OOB_K because we're writing to global memory
732 pytorch_flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
733 tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
734 binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
735 );
736 tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
737 tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
738
739 }
740 }
741 tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
742 if (block_table == nullptr) {
743 tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
744 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
745 } else {
746 if (n_block > n_block_copy_min) {
747 const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
748 const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
749 const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
750 const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
751 const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
752 const int offset_diff = block_table_offset_next - block_table_offset_cur;
753 tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
754 tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
755 }
756 }
757 }
758 // Need this before we can read in K again, so that we'll see the updated K values.
759 __syncthreads();
760 tKgK.data() = tKgK_data;
761 tVgV.data() = tVgV_data;
762 }
763
764 // Read Q from gmem to smem, optionally apply rotary embedding.
765 if (!Append_KV || params.rotary_dim == 0) {
766 // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
767 pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
768 binfo.actual_seqlen_q - m_block * kBlockM);
769 } else {
770 const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
771         // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
772 // We do this by setting the row stride of gCos / gSin to 0.
773 Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
774 Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
775 make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
776 Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
777 Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
778 make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
779 Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
780 Shape<Int<kBlockM>, Int<kHeadDim>>{},
781 make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
782 Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
783 Shape<Int<kBlockM>, Int<kHeadDim>>{},
784 make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
785 Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
786 Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
787 Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
788 Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
789 if (params.is_rotary_interleaved) {
790 pytorch_flash::copy_rotary_interleaved<Is_even_K>(
791 tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
792 0, params.d, params.rotary_dim
793 );
794 } else {
795 pytorch_flash::copy_rotary_contiguous<Is_even_K>(
796 tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
797 0, params.d, params.rotary_dim
798 );
799 }
800 }
801
802 int n_block = n_block_max - 1;
803 // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
804 pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
805 binfo.actual_seqlen_k - n_block * kBlockN);
806 cute::cp_async_fence();
807
808 // pytorch_flash::cp_async_wait<0>();
809 // __syncthreads();
810 // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
811 // __syncthreads();
812
813 clear(acc_o);
814
815 pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax;
816
817 const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
818 pytorch_flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
819
820     // For performance reasons, we separate out two kinds of iterations:
821     // those that need masking on S, and those that don't.
822     // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
823 // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
824 // We will have at least 1 "masking" iteration.
825
826 // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
827 // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
828 constexpr int n_masking_steps = (!Is_causal && !Is_local)
829 ? 1
830 : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
831 #pragma unroll
832 for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
833 Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
834 clear(acc_s);
835 pytorch_flash::cp_async_wait<0>();
836 __syncthreads();
837
838 // Advance gV
839 if (masking_step > 0) {
840 if (block_table == nullptr) {
841 tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
842 } else {
843 const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
844 const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
845 const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
846 const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
847 tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
848 }
849 pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
850 } else {
851 // Clear the smem tiles to account for predicated off loads
852 pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
853 gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
854 );
855 }
856 cute::cp_async_fence();
857
858 pytorch_flash::gemm(
859 acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
860 smem_thr_copy_Q, smem_thr_copy_K
861 );
862 // if (cute::thread0()) { print(acc_s); }
863
864 mask.template apply_mask<Is_causal, Is_even_MN>(
865 acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
866 );
867
868 pytorch_flash::cp_async_wait<0>();
869 __syncthreads();
870 // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
871 // __syncthreads();
872
873 if (n_block > n_block_min) {
874 // Advance gK
875 if (block_table == nullptr) {
876 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
877 } else {
878 const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
879 const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
880 const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
881                 const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
882 tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
883 }
884 pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
885 // This cp_async_fence needs to be in the if block, otherwise the synchronization
886 // isn't right and we get race conditions.
887 cute::cp_async_fence();
888 }
889
890 // We have key_padding_mask so we'll need to Check_inf
891 masking_step == 0
892 ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2)
893 : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);
894 // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
895
896 // Convert acc_s from fp32 to fp16/bf16
897 Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
898 // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
899 // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
900 Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
901
902 pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
903
904 // This check is at the end of the loop since we always have at least 1 iteration
905 if (n_masking_steps > 1 && n_block <= n_block_min) {
906 --n_block;
907 break;
908 }
909 }
910
911 // These are the iterations where we don't need masking on S
912 for (; n_block >= n_block_min; --n_block) {
913 Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
914 clear(acc_s);
915 pytorch_flash::cp_async_wait<0>();
916 __syncthreads();
917 // Advance gV
918 if (block_table == nullptr) {
919 tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
920 } else {
921 const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
922 const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
923 const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
924 const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
925 tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
926 }
927 pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
928 cute::cp_async_fence();
929
930 pytorch_flash::gemm(
931 acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
932 smem_thr_copy_Q, smem_thr_copy_K
933 );
934
935 pytorch_flash::cp_async_wait<0>();
936 __syncthreads();
937 if (n_block > n_block_min) {
938 // Advance gK
939 if (block_table == nullptr) {
940 tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
941 } else {
942 const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
943 const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
944 const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
945 const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
946 tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
947 }
948 pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
949 // This cp_async_fence needs to be in the if block, otherwise the synchronization
950 // isn't right and we get race conditions.
951 cute::cp_async_fence();
952 }
953
954 mask.template apply_mask</*Causal_mask=*/false>(
955 acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
956 );
957 softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
958
959 Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
960 // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
961 // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
962 Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
963
964 pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
965 }
966
967 // Epilogue
968
969 Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
970 // if (cute::thread0()) { print(lse); }
971
972 Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
973 // Partition sO to match the accumulator partitioning
974 using SmemTiledCopyO = std::conditional_t<
975 !Split,
976 typename Kernel_traits::SmemCopyAtomO,
977 typename Kernel_traits::SmemCopyAtomOaccum
978 >;
979 auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
980 auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
981 Tensor rO = pytorch_flash::convert_type<ElementO>(acc_o);
982 Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
983 Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
984
985 // sOaccum is larger than sQ, so we need to syncthreads here
986 // TODO: allocate enough smem for sOaccum
987 if constexpr (Split) { __syncthreads(); }
988
989 cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
990
991 const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
992 + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
993 const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
994 + m_block * kBlockM) * params.d_rounded;
995 const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
996
997 Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
998 Shape<Int<kBlockM>, Int<kHeadDim>>{},
999 make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
1000 Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
1001 Shape<Int<kBlockM>>{}, Stride<_1>{});
1002 // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }
1003
1004 GmemTiledCopyO gmem_tiled_copy_Oaccum;
1005 auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
1006 Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
1007 Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
1008
1009 __syncthreads();
1010
1011 Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
1012 cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
1013
1014 Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
1015 Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
1016 static_assert(decltype(size<0>(taccOcO))::value == 4);
1017 // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
1018 Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
1019 CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
    if (get<1>(taccOcO_row(0)) == 0) {
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
        }
    }

    // Construct identity layout for sO
    Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    // Repeat the partitioning with identity layouts
    Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);                                // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
    );
}

////////////////////////////////////////////////////////////////////////////////////////////////////

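// Entry point for the standard (non-split) forward kernel body. The launcher is expected to
// size the grid as (ceil_div(seqlen_q, kBlockM), batch, num_heads), which is how blockIdx is
// decoded below.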
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;

    // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
    // them to have the same number of threads or have to traverse the attention matrix
    // in the same order.
    // In the Philox RNG, we use the offset to store the batch, head, and the lane id
    // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
    // the attention matrix. This way, as long as we have the batch, head, and the location of
    // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

    pytorch_flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}
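// A minimal sketch of how this is wrapped and launched (illustrative only -- the real
// __global__ wrapper, flag dispatch, and shared-memory sizing live in the launcher,
// flash_fwd_launch_template.h; the kernel name below is hypothetical):
//
//   template<typename Kernel_traits, bool... Flags>
//   __global__ void flash_fwd_kernel_sketch(const Flash_fwd_params params) {
//       pytorch_flash::compute_attn<Kernel_traits, Flags...>(params);
//   }
//
//   dim3 grid(ceil_div(params.seqlen_q, Kernel_traits::kBlockM), params.b, params.h);
//   flash_fwd_kernel_sketch<Kernel_traits, Flags...>
//       <<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);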

////////////////////////////////////////////////////////////////////////////////////////////////////

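// Split-KV entry point. When Split is true the grid is expected to be
// (num_m_blocks, num_splits, batch * num_heads), so blockIdx.y carries the split index and
// blockIdx.z is decoded into (batch, head) below; without Split it matches compute_attn's
// (num_m_blocks, batch, num_heads) layout.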
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_splitkv(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
    // The block index for the head.
    const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
    const int n_split_idx = Split ? blockIdx.y : 0;
    const int num_n_splits = Split ? gridDim.y : 1;
    pytorch_flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
constexpr T ceil_div(T numerator, T denominator) {
    return (numerator + denominator - 1) / denominator;
}
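// e.g. ceil_div(33, 32) == 2 and ceil_div(32, 32) == 1; used below to size the per-thread
// LSE loads (kNLsePerThread). Assumes numerator + denominator does not overflow T.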


template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
inline __device__ void combine_attn_seqk_parallel(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;
    constexpr int kMaxSplits = 1 << Log_max_splits;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNThreads = Kernel_traits::kNThreads;

    static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
    static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
    static_assert(kNThreads == 128, "We assume that each block has 128 threads");

    // Shared memory.
    // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
    __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];

    // The thread and block index.
    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;

    const index_t row_offset_lse = bidx * kBlockM;
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                   Shape<Int<kMaxSplits>, Int<kBlockM>>{},
                                   make_stride(params.b * params.h * params.seqlen_q, _1{}));
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});
    constexpr int kNLsePerThread = ceil_div(kMaxSplits * kBlockM, kNThreads);
    // Read the LSE values from gmem and store them in shared memory, then transpose them.
    constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
        const int col = tidx % kBlockM;
        ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
        if (row < kMaxSplits) { sLSE[row][col] = lse; }
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
    }
    // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
    __syncthreads();
    Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
    constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
    // To keep the reduction over the split dimension within a single warp, we transpose: after
    // the loop below each thread holds kNLsePerThread of the split-LSEs for one row, and the
    // splits of a given row are spread across kRowsPerLoadTranspose consecutive lanes
    // (each gmem load above covers 128 / kBlockM rows).
    // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
    // static_assert(kThreadsPerSplit <= 32);
    static_assert(kRowsPerLoadTranspose <= 32);
    static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
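    // Worked example with kBlockM = 16, kMaxSplits = 16 and 128 threads: kRowsPerLoadLSE = 8,
    // kNLsePerThread = 2, kRowsPerLoadTranspose = 8. After the transpose, lanes 8*c .. 8*c+7
    // jointly hold the 16 split-LSEs of block-row c (2 per lane), so Allreduce<8> below reduces
    // over exactly those lanes.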
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
    }

    // Compute the logsumexp of the LSE along the split dimension.
    ElementAccum lse_max = lse_accum(0);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
    MaxOp<float> max_op;
    lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
    lse_max = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
    float lse_sum = expf(lse_accum(0) - lse_max);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
    SumOp<float> sum_op;
    lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
    // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
    // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
    ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
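    // Combine identity: lse_total = log(sum_s exp(lse_s)) = lse_max + log(sum_s exp(lse_s - lse_max)),
    // and the final output is O = sum_s exp(lse_s - lse_total) * O_s. The scales
    // exp(lse_s - lse_total) are stored to shared memory below and applied in the accumulation
    // loop over splits further down.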
    // Calculate valid rows for this block
    const int total_rows = params.b * params.h * params.seqlen_q;
    const int local_row = tidx / kRowsPerLoadTranspose;
    const int global_row = blockIdx.x * kBlockM + local_row;

    const bool is_reduction_writer = tidx % kRowsPerLoadTranspose == 0;
    const bool is_valid_row = (local_row < kBlockM) && (global_row < total_rows);

    if (is_reduction_writer && is_valid_row) {
        gLSE(local_row) = lse_logsum;
    }
    // Store the scales exp(lse - lse_logsum) in shared memory.
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); }
    }
    __syncthreads();

    const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                 Stride<Int<kHeadDim>, _1>{});
    constexpr int kBlockN = kNThreads / kBlockM;
    using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
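    // With this tiled copy the 128 threads are arranged kBlockM x kBlockN, each moving 4
    // consecutive floats, so one pass covers a kBlockM x (4 * kBlockN) tile of gOaccum and cute
    // tiles that across the full kHeadDim.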
    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
    Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
    clear(tOrO);

    // Predicates
    Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
    // Repeat the partitioning with identity layouts
    Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
    Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
    }
    // Load Oaccum for each split, then scale it and accumulate into O.
    for (int split = 0; split < params.num_splits; ++split) {
        pytorch_flash::copy</*Is_even_MN=*/false, Is_even_K>(
            gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
        );
        #pragma unroll
        for (int m = 0; m < size<1>(tOrOaccum); ++m) {
            int row = get<0>(tOcOaccum(0, m, 0));
            ElementAccum lse_scale = sLSE[split][row];
            #pragma unroll
            for (int k = 0; k < size<2>(tOrOaccum); ++k) {
                #pragma unroll
                for (int i = 0; i < size<0>(tOrOaccum); ++i) {
                    tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
                }
            }
            // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); }
        }
        tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
    }
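    // At this point tOrO = sum_s exp(lse_s - lse_total) * Oaccum_s for the rows owned by this
    // thread, i.e. the attention output under the global softmax normalization.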
    // if (cute::thread0()) { print_tensor(tOrO); }

    Tensor rO = pytorch_flash::convert_type<Element>(tOrO);
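    // The rows handled by this block form a flat range over (batch, head, seqlen_q), so each
    // register row is decoded back to (batch_idx, head_idx, row) before indexing o_ptr with the
    // real output strides; rows within one block may belong to different batches/heads, hence
    // the per-row pointer arithmetic instead of a single tiled copy.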
    // Write to gO
    #pragma unroll
    for (int m = 0; m < size<1>(rO); ++m) {
        const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
        if (idx < params.b * params.h * params.seqlen_q) {
            const int batch_idx = idx / (params.h * params.seqlen_q);
            const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
            // The index to the rows of Q
            const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
            auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride
                + head_idx * params.o_head_stride + row * params.o_row_stride;
            #pragma unroll
            for (int k = 0; k < size<2>(rO); ++k) {
                if (Is_even_K || tOpOaccum(k)) {
                    const int col = get<1>(tOcOaccum(0, m, k));
                    Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
                                            Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
                    // TODO: Should check if this is using vectorized store, but it seems pretty fast
                    copy(rO(_, m, k), gO);
                    // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
                    // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);
                }
            }
        }
    }
}

} // namespace pytorch_flash