Searched defs:head_size_og (Results 1–1 of 1)

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
flash_api.cpp
   391  const int head_size_og = sizes[3];        in mha_fwd()           local
   605  const int head_size_og = sizes[2];        in mha_varlen_fwd()    local
   867  const int head_size_og = dout.size(3);    in mha_bwd()           local
  1089  const int head_size_og = dout.size(2);    in mha_varlen_bwd()    local
  1311  const int head_size_og = sizes[3];        in mha_fwd_kvcache()   local
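The matches read the head dimension from index 3 in the dense entry points and from index 2 in the varlen ones. A plausible reading is that dense tensors are 4-D `(batch, seqlen, num_heads, head_size)` while varlen tensors pack all sequences into 3-D `(total_tokens, num_heads, head_size)`, so in both cases the head dimension is the last axis. The sketch below illustrates that reading. `head_size_og_of` is a hypothetical helper and is not part of flash_api.cpp. The `round_multiple` step is an assumption based on upstream FlashAttention-2, where "og" is taken to mean the original, unpadded head size that is later rounded up to a multiple of 8.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not in flash_api.cpp): reads the original head size
// from a tensor shape. Assumed layouts:
//   dense : (batch, seqlen, num_heads, head_size)   -> sizes[3]
//   varlen: (total_tokens, num_heads, head_size)    -> sizes[2]
int head_size_og_of(const std::vector<int64_t>& sizes, bool varlen) {
    const std::size_t expected_rank = varlen ? 3 : 4;
    if (sizes.size() != expected_rank) {
        throw std::invalid_argument("unexpected tensor rank");
    }
    // In both layouts the head dimension is the last axis.
    return static_cast<int>(sizes[expected_rank - 1]);
}

// Assumed padding step: round x up to the next multiple of m
// (e.g. head_size_og rounded up to a multiple of 8).
int round_multiple(int x, int m) { return (x + m - 1) / m * m; }
```

For example, `head_size_og_of({2, 128, 8, 72}, false)` and `head_size_og_of({256, 8, 72}, true)` both yield 72. Under the padding assumption, a head size of 65 would run as 72.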