Home
last modified time | relevance | path

Searched refs:max_seqlen_batch_k (Results 1 – 4 of 4) sorted by relevance

/external/pytorch/aten/src/ATen/native/nested/cuda/
DNestedTensorTransformerFunctions.cpp298 max_seqlen_batch_k, in _scaled_dot_product_efficient_attention_nestedtensor_cuda()
315 max_seqlen_batch_k, in _scaled_dot_product_efficient_attention_nestedtensor_cuda()
336 const int64_t max_seqlen_batch_k, in _scaled_dot_product_flash_attention_backward_nested() argument
360 max_seqlen_batch_k); in _scaled_dot_product_flash_attention_backward_nested()
372 max_seqlen_batch_k, in _scaled_dot_product_flash_attention_backward_nested()
/external/pytorch/aten/src/ATen/native/transformers/cuda/
Dattention_backward.cu71 int64_t max_seqlen_batch_k, in _flash_attention_backward() argument
128 max_seqlen_batch_k, in _flash_attention_backward()
199 const int64_t max_seqlen_batch_k = key.size(2); in _scaled_dot_product_cudnn_attention_backward_cuda() local
210 … attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); in _scaled_dot_product_cudnn_attention_backward_cuda()
212 … attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); in _scaled_dot_product_cudnn_attention_backward_cuda()
214 …_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); in _scaled_dot_product_cudnn_attention_backward_cuda()
742 const int64_t max_seqlen_batch_k, in _scaled_dot_product_flash_attention_backward_cuda() argument
769 max_seqlen_batch_k, in _scaled_dot_product_flash_attention_backward_cuda()
Dattention.cu699 const int64_t max_seqlen_batch_k = key.size(2); in _scaled_dot_product_flash_attention_cuda() local
702 max_seqlen_batch_k == max_seqlen_batch_v, in _scaled_dot_product_flash_attention_cuda()
725 max_seqlen_batch_k, in _scaled_dot_product_flash_attention_cuda()
735 …uple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, philox_seed… in _scaled_dot_product_flash_attention_cuda()
756 const int64_t max_seqlen_batch_k = key.size(2); in _scaled_dot_product_cudnn_attention_cuda() local
759 max_seqlen_batch_k == max_seqlen_batch_v, in _scaled_dot_product_cudnn_attention_cuda()
771 max_seqlen_batch_k/*int64_t s_kv*/, in _scaled_dot_product_cudnn_attention_cuda()
786 …ple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed,… in _scaled_dot_product_cudnn_attention_cuda()
850 int64_t max_seqlen_batch_k, in _flash_attention_forward() argument
900 max_seqlen_batch_k, in _flash_attention_forward()
/external/pytorch/torch/
D_meta_registrations.py5058 max_seqlen_batch_k = key.size(2)
5071 if max_seqlen_batch_k <= 128:
5073 elif max_seqlen_batch_k <= 256:
5094 max_seqlen_batch_k,
5411 max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
5426 if max_seqlen_batch_k <= 128:
5428 elif max_seqlen_batch_k <= 256: