Home
last modified time | relevance | path

Searched defs:Flash_bwd_params (Results 1 – 1 of 1) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
Dflash.h143 struct Flash_bwd_params : public Flash_fwd_params { struct
146 void *__restrict__ do_ptr;
147 void *__restrict__ dq_ptr;
148 void *__restrict__ dk_ptr;
149 void *__restrict__ dv_ptr;
152 void *__restrict__ dq_accum_ptr;
153 void *__restrict__ dk_accum_ptr;
154 void *__restrict__ dv_accum_ptr;
163 index_t do_batch_stride;
164 index_t do_row_stride;
[all …]