Home
last modified time | relevance | path

Searched defs:kBlockN (Results 1 – 5 of 5) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
Dflash_bwd_preprocess_kernel.h153 constexpr int kBlockN = Kernel_traits::kBlockN; in clear_dKVaccum() local
285 constexpr int kBlockN = Kernel_traits::kBlockN; in convert_dKV() local
Dkernel_traits.h69 static constexpr int kBlockN = kBlockN_; member
185 static constexpr int kBlockN = kBlockN_; member
Dflash_fwd_launch_template.h170 constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); in run_mha_fwd_splitkv_dispatch() local
Dflash_fwd_kernel.h42 constexpr int kBlockN = Kernel_traits::kBlockN; in compute_attn_1rowblock() local
492 constexpr int kBlockN = Kernel_traits::kBlockN; in compute_attn_1rowblock_splitkv() local
1182 constexpr int kBlockN = kNThreads / kBlockM; in combine_attn_seqk_parallel() local
Dflash_bwd_kernel.h93 constexpr int kBlockN = Kernel_traits::kBlockN; in compute_dq_dk_dv_1colblock() local