Home
last modified time | relevance | path

Searched defs:Flash_fwd_params (Results 1 – 1 of 1) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
Dflash.h50 struct Flash_fwd_params : public Qkv_params { struct
53 void * __restrict__ o_ptr;
54 void * __restrict__ oaccum_ptr;
57 index_t o_batch_stride;
58 index_t o_row_stride;
59 index_t o_head_stride;
62 void * __restrict__ p_ptr;
65 void * __restrict__ softmax_lse_ptr;
66 void * __restrict__ softmax_lseaccum_ptr;
69 …t b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
[all …]