Home
last modified time | relevance | path

Searched defs:kBlockM (Results 1 – 5 of 5) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
Dflash_fwd_launch_template.h52 DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool… in DEFINE_FLASH_FORWARD_KERNEL()
141 …constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim %… in run_flash_splitkv_fwd() local
166 constexpr static int kBlockM = 64; // Fixed for all head dimensions in run_mha_fwd_splitkv_dispatch() local
Dflash_bwd_preprocess_kernel.h70 constexpr int kBlockM = Kernel_traits::kBlockM; in compute_dot_do_o() local
197 constexpr int kBlockM = Kernel_traits::kBlockM; in convert_dQ() local
Dkernel_traits.h68 static constexpr int kBlockM = kBlockM_; member
184 static constexpr int kBlockM = kBlockM_; member
Dflash_bwd_kernel.h92 constexpr int kBlockM = Kernel_traits::kBlockM; in compute_dq_dk_dv_1colblock() local
513 …if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - … in compute_dq_dk_dv_1colblock() local
Dflash_fwd_kernel.h41 constexpr int kBlockM = Kernel_traits::kBlockM; in compute_attn_1rowblock() local
491 constexpr int kBlockM = Kernel_traits::kBlockM; in compute_attn_1rowblock_splitkv() local