Home
last modified time | relevance | path

Searched refs:max_seqlen_batch_q (Results 1 – 7 of 7) sorted by relevance

/external/pytorch/aten/src/ATen/native/nested/cuda/
DNestedTensorTransformerFunctions.cpp243 max_seqlen_batch_q, in _scaled_dot_product_flash_attention_nestedtensor_cuda()
259 max_seqlen_batch_q, in _scaled_dot_product_flash_attention_nestedtensor_cuda()
274 max_seqlen_batch_q, in _scaled_dot_product_flash_attention_nestedtensor_cuda()
297 max_seqlen_batch_q, in _scaled_dot_product_efficient_attention_nestedtensor_cuda()
314 max_seqlen_batch_q, in _scaled_dot_product_efficient_attention_nestedtensor_cuda()
335 const int64_t max_seqlen_batch_q, in _scaled_dot_product_flash_attention_backward_nested() argument
359 max_seqlen_batch_q, in _scaled_dot_product_flash_attention_backward_nested()
371 max_seqlen_batch_q, in _scaled_dot_product_flash_attention_backward_nested()
DNestedTensorTransformerUtils.cpp250 int64_t max_seqlen_batch_q = 0, Nnz_q = 0; in sdpa_nested_preprocessing_with_broadcast() local
253 max_seqlen_batch_q = q_t.size(1); in sdpa_nested_preprocessing_with_broadcast()
256 (output_batch_size + 1) * max_seqlen_batch_q, in sdpa_nested_preprocessing_with_broadcast()
257 max_seqlen_batch_q, in sdpa_nested_preprocessing_with_broadcast()
259 Nnz_q = output_batch_size * max_seqlen_batch_q; in sdpa_nested_preprocessing_with_broadcast()
264 max_seqlen_batch_q = std::get<1>(cumulative_and_max_q_and_nnz_q); in sdpa_nested_preprocessing_with_broadcast()
377 max_seqlen_batch_q, in sdpa_nested_preprocessing_with_broadcast()
413 …auto [cumulative_sequence_length_q, max_seqlen_batch_q, Nnz_q] = cumulative_and_max_seq_len_nnz(q_… in sdpa_nested_preprocessing()
452 max_seqlen_batch_q, in sdpa_nested_preprocessing()
466 const int64_t max_seqlen_batch_q, in sdpa_nested_preprocessing_backward() argument
/external/pytorch/torch/nested/_internal/
Dsdpa.py542 max_seqlen_batch_q,
580 max_seqlen_batch_q,
673 max_seqlen_batch_q,
690 max_seqlen_batch_q,
710 max_seqlen_batch_q,
728 max_seqlen_batch_q,
/external/pytorch/aten/src/ATen/native/transformers/cuda/
Dattention_backward.cu70 int64_t max_seqlen_batch_q, in _flash_attention_backward() argument
127 max_seqlen_batch_q, in _flash_attention_backward()
198 const int64_t max_seqlen_batch_q = query.size(2); in _scaled_dot_product_cudnn_attention_backward_cuda() local
210 … attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); in _scaled_dot_product_cudnn_attention_backward_cuda()
212 … attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); in _scaled_dot_product_cudnn_attention_backward_cuda()
214 …attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_… in _scaled_dot_product_cudnn_attention_backward_cuda()
741 const int64_t max_seqlen_batch_q, in _scaled_dot_product_flash_attention_backward_cuda() argument
768 max_seqlen_batch_q, in _scaled_dot_product_flash_attention_backward_cuda()
Dattention.cu698 const int64_t max_seqlen_batch_q = query.size(2); in _scaled_dot_product_flash_attention_cuda() local
724 max_seqlen_batch_q, in _scaled_dot_product_flash_attention_cuda()
735 …return std::make_tuple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_ba… in _scaled_dot_product_flash_attention_cuda()
753 const int64_t max_seqlen_batch_q = query.size(2); in _scaled_dot_product_cudnn_attention_cuda() local
770 max_seqlen_batch_q/*int64_t s_q*/, in _scaled_dot_product_cudnn_attention_cuda()
786 …return std::make_tuple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_b… in _scaled_dot_product_cudnn_attention_cuda()
811 …auto [attention, log_sumexp, seed, offset, max_seqlen_batch_q, max_seqlen_batch_kv] = at::_efficie… in _scaled_dot_product_efficient_attention_cuda()
849 int64_t max_seqlen_batch_q, in _flash_attention_forward() argument
899 max_seqlen_batch_q, in _flash_attention_forward()
/external/pytorch/aten/src/ATen/native/nested/
DNestedTensorTransformerUtils.h36 const int64_t max_seqlen_batch_q,
/external/pytorch/torch/
D_meta_registrations.py5056 max_seqlen_batch_q = query.size(2)
5063 (batch_size, num_heads, max_seqlen_batch_q),
5070 max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
5076 (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
5093 max_seqlen_batch_q,
5188 max_seqlen_batch_q = query.size(2)
5195 max_seqlen_batch_q,
5410 max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
5418 (batch_size, num_heads, max_seqlen_batch_q),
5425 max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
[all …]