Searched defs:cumulative_sequence_length_k (Results 1 – 3 of 3) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/
attention_backward.cu
  69: const Tensor& cumulative_sequence_length_k, in _flash_attention_backward()
  740: const Tensor& cumulative_sequence_length_k, in _scaled_dot_product_flash_attention_backward_cuda()
attention.cu
  848: const std::optional<Tensor>& cumulative_sequence_length_k, in _flash_attention_forward()

/external/pytorch/aten/src/ATen/native/nested/cuda/
NestedTensorTransformerFunctions.cpp
  334: const Tensor& cumulative_sequence_length_k, in _scaled_dot_product_flash_attention_backward_nested()