
Searched defs:row_offset_o (Results 1 – 3 of 3) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
flash_bwd_preprocess_kernel.h:78   const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)   in compute_dot_do_o() (local)
flash_fwd_kernel.h:521   … const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)   in compute_attn_1rowblock_splitkv() (local)
flash_fwd_kernel.h:991   const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)   in compute_attn_1rowblock_splitkv() (local)
flash_bwd_kernel.h:115   const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)   in compute_dq_dk_dv_1colblock() (local)
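Every hit above computes row_offset_o from binfo.q_offset(o_batch_stride, o_row_stride, bidb). The following is a minimal host-side sketch of what that call likely returns, assuming it mirrors the BlockInfo helper in upstream FlashAttention's block_info.h. In that helper, sum_s_q == -1 marks fixed-length (padded) batches, where the offset is bidb * batch_stride. Otherwise sum_s_q is the first packed row of batch bidb in a variable-length layout, and the offset becomes sum_s_q * row_stride. BlockInfoSketch and row_offset_o_sketch are hypothetical names. The m_block and bidh terms in the second function are an assumption about how the kernels extend the per-batch offset to a per-tile, per-head offset; they do not appear in these search results.

```cpp
#include <cstdint>

using index_t = std::int64_t;

// Hypothetical stand-in for the kernels' BlockInfo (assumed semantics).
struct BlockInfoSketch {
    int sum_s_q;  // -1 for padded batches; cumulative start row of this batch for varlen (assumed)

    // Offset of the first row of batch `bidb` in a Q-shaped tensor such as O.
    constexpr index_t q_offset(index_t batch_stride, index_t row_stride, int bidb) const {
        return sum_s_q == -1
                   ? bidb * batch_stride                                   // padded: jump whole batches
                   : static_cast<index_t>(static_cast<std::uint32_t>(sum_s_q)) * row_stride;  // varlen: jump packed rows
    }
};

// Hypothetical: per-batch offset plus the row block (m_block * kBlockM rows)
// and the head (bidh) this thread block works on. Assumed, not confirmed here.
constexpr index_t row_offset_o_sketch(BlockInfoSketch binfo, index_t batch_stride,
                                      index_t row_stride, index_t head_stride,
                                      int bidb, int bidh, int m_block, int kBlockM) {
    return binfo.q_offset(batch_stride, row_stride, bidb)
         + static_cast<index_t>(m_block) * kBlockM * row_stride
         + static_cast<index_t>(bidh) * head_stride;
}
```

With padded batches (sum_s_q = -1), batch_stride = 1000 and bidb = 3, the base offset is 3000 elements. With a packed layout where batch 3 starts at row 70 and row_stride = 64, it is 70 * 64 = 4480. Branching on a sentinel keeps one kernel body usable for both layouts.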