Searched defs:gLSE (Results 1 – 2 of 2) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
flash_fwd_kernel.h
83 Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block)); in compute_attn_1rowblock() local
435 Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block)); in compute_attn_1rowblock() local
1111 …Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) +… in combine_attn_seqk_parallel() local
flash_bwd_kernel.h
149 …Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) +… in compute_dq_dk_dv_1colblock() local