Searched refs:max_seqlen_batch_k (Results 1 – 4 of 4) sorted by relevance
/external/pytorch/aten/src/ATen/native/nested/cuda/ |
D | NestedTensorTransformerFunctions.cpp | 298 max_seqlen_batch_k, in _scaled_dot_product_efficient_attention_nestedtensor_cuda() 315 max_seqlen_batch_k, in _scaled_dot_product_efficient_attention_nestedtensor_cuda() 336 const int64_t max_seqlen_batch_k, in _scaled_dot_product_flash_attention_backward_nested() argument 360 max_seqlen_batch_k); in _scaled_dot_product_flash_attention_backward_nested() 372 max_seqlen_batch_k, in _scaled_dot_product_flash_attention_backward_nested()
|
/external/pytorch/aten/src/ATen/native/transformers/cuda/ |
D | attention_backward.cu | 71 int64_t max_seqlen_batch_k, in _flash_attention_backward() argument 128 max_seqlen_batch_k, in _flash_attention_backward() 199 const int64_t max_seqlen_batch_k = key.size(2); in _scaled_dot_product_cudnn_attention_backward_cuda() local 210 … attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); in _scaled_dot_product_cudnn_attention_backward_cuda() 212 … attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); in _scaled_dot_product_cudnn_attention_backward_cuda() 214 …_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); in _scaled_dot_product_cudnn_attention_backward_cuda() 742 const int64_t max_seqlen_batch_k, in _scaled_dot_product_flash_attention_backward_cuda() argument 769 max_seqlen_batch_k, in _scaled_dot_product_flash_attention_backward_cuda()
|
D | attention.cu | 699 const int64_t max_seqlen_batch_k = key.size(2); in _scaled_dot_product_flash_attention_cuda() local 702 max_seqlen_batch_k == max_seqlen_batch_v, in _scaled_dot_product_flash_attention_cuda() 725 max_seqlen_batch_k, in _scaled_dot_product_flash_attention_cuda() 735 …uple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, philox_seed… in _scaled_dot_product_flash_attention_cuda() 756 const int64_t max_seqlen_batch_k = key.size(2); in _scaled_dot_product_cudnn_attention_cuda() local 759 max_seqlen_batch_k == max_seqlen_batch_v, in _scaled_dot_product_cudnn_attention_cuda() 771 max_seqlen_batch_k/*int64_t s_kv*/, in _scaled_dot_product_cudnn_attention_cuda() 786 …ple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed,… in _scaled_dot_product_cudnn_attention_cuda() 850 int64_t max_seqlen_batch_k, in _flash_attention_forward() argument 900 max_seqlen_batch_k, in _flash_attention_forward()
|
/external/pytorch/torch/ |
D | _meta_registrations.py | 5058 max_seqlen_batch_k = key.size(2) 5071 if max_seqlen_batch_k <= 128: 5073 elif max_seqlen_batch_k <= 256: 5094 max_seqlen_batch_k, 5411 max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k 5426 if max_seqlen_batch_k <= 128: 5428 elif max_seqlen_batch_k <= 256:
|