
Searched defs:sK (Results 1 – 2 of 2) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
flash_fwd_kernel.h
  145  Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),    in compute_attn_1rowblock(), local
  592  Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});    in compute_attn_1rowblock_splitkv(), local
flash_bwd_kernel.h
  163  Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{});    in compute_dq_dk_dv_1colblock(), local