Home
last modified time | relevance | path

Searched defs:bidh (Results 1 – 3 of 3) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
Dflash_bwd_preprocess_kernel.h66 const int bidh = blockIdx.z; in compute_dot_do_o() local
149 const int bidh = blockIdx.z; in clear_dKVaccum() local
193 const int bidh = blockIdx.z; in convert_dQ() local
281 const int bidh = blockIdx.z; in convert_dKV() local
Dflash_bwd_kernel.h80 …oid compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_b… in compute_dq_dk_dv_1colblock()
794 const int bidh = blockIdx.y; in compute_dq_dk_dv() local
820 const int bidh = blockIdx.z; in compute_dq_dk_dv_seqk_parallel() local
Dflash_fwd_kernel.h29 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh,… in compute_attn_1rowblock()
479 …compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_b… in compute_attn_1rowblock_splitkv()
1051 const int bidh = blockIdx.z; in compute_attn() local
1072 const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; in compute_attn_splitkv() local