
Searched defs:q_k_v (Results 1 – 3 of 3) sorted by relevance

/external/pytorch/aten/src/ATen/native/cpu/
NativeMultiheadAttnKernel.cpp:93  scalar_t* q_k_v = static_cast<scalar_t*>(_q_k_v); in transform_bias_rescale_qkv_kernel_impl() local
/external/pytorch/aten/src/ATen/native/transformers/cuda/
attention.cu:104  PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v, in transform_bias_rescale_qkv_kernel()
attention.cu:207  PackedTensorAccessor64<scalar_t, 5, RestrictPtrTraits> q_k_v, in transform_bias_rescale_qkv_add_padding_kernel()
attention.cu:402  auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_bias.options()); in transform_bias_rescale_qkv_cuda() local
attention.cu:620  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head); in native_multi_head_attention_cuda() local
/external/pytorch/aten/src/ATen/native/transformers/
attention.cpp:244  auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_->options()); in transform_bias_rescale_qkv_cpu() local
attention.cpp:361  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head); in native_multi_head_attention_cpu() local
attention.cpp:916  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head); in triton_multi_head_attention() local
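The allocations above suggest the packed layout: q_k_v is a single {3, B, num_head, T, dim_per_head} tensor whose leading dimension indexes the query, key, and value projections. Below is a minimal sketch of how such a packed buffer could be unpacked and used; the standalone setup, random inputs, and attention math are illustrative assumptions, not the implementation in the listed files.

    // Sketch only: shapes follow the at::empty({3, B, num_head, T, dim_per_head}, ...)
    // calls in the results above; everything else here is assumed for illustration.
    #include <ATen/ATen.h>
    #include <cmath>

    int main() {
      const int64_t B = 2, num_head = 4, T = 8, dim_per_head = 16;

      // Packed Q/K/V buffer: dimension 0 selects among {q, k, v}.
      auto q_k_v = at::randn({3, B, num_head, T, dim_per_head});

      // Views into the packed tensor (no copies).
      auto q = q_k_v[0];  // [B, num_head, T, dim_per_head]
      auto k = q_k_v[1];
      auto v = q_k_v[2];

      // Scaled dot-product attention over the unpacked views.
      auto scores = at::matmul(q, k.transpose(-2, -1)) /
                    std::sqrt(static_cast<double>(dim_per_head));
      auto attn = at::softmax(scores, -1);
      auto out = at::matmul(attn, v);  // [B, num_head, T, dim_per_head]
      return 0;
    }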