Searched defs:attn_mask_ (Results 1 – 3 of 3) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/
attention.cpp
  427  const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> s…  in _fused_sdp_choice_cpp()
  449  const std::optional<Tensor>& attn_mask_,  in _fused_sdp_choice_meta()
  483  const std::optional<Tensor>& attn_mask_,  in validate_sdpa_input()
  602  const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> s…  in handle_private_use()
  653  const std::optional<Tensor>& attn_mask_,  in scaled_dot_product_attention()
  717  const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal,  in _scaled_dot_product_attention_math()
/external/pytorch/torch/csrc/api/include/torch/nn/functional/
activation.h
  795  Tensor attn_mask_ = attn_mask;  variable
/external/pytorch/aten/src/ATen/native/transformers/cuda/
attention.cu
  830  const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> s…  in _fused_sdp_choice_cuda()
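Most of these hits take the mask as an optional argument (const std::optional<Tensor>& attn_mask_) next to dropout_p and is_causal. To illustrate how such an optional mask can enter an unfused, math-style attention, here is a minimal C++ sketch on plain std::vector matrices. It is not PyTorch's implementation: Matrix and sdpa_math_sketch are hypothetical names. It assumes a single head, no batching, no dropout, and an additive floating-point mask. PyTorch's Python documentation describes float masks as added to the scores and boolean masks as allowing positions marked True; only the float form is shown here. The causal branch and the rejection of an explicit mask combined with is_causal follow the reference pseudocode in that documentation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical row-major matrix type for this sketch (not an ATen type).
using Matrix = std::vector<std::vector<double>>;

// Hypothetical helper, not a PyTorch API: single-head, unbatched attention
// without dropout. q is L x E, k is S x E, v is S x Ev.
// attn_mask, when present, is an L x S additive mask; put -infinity where a
// query must not attend to a key.
Matrix sdpa_math_sketch(const Matrix& q, const Matrix& k, const Matrix& v,
                        const std::optional<Matrix>& attn_mask,
                        bool is_causal, std::optional<double> scale) {
  // An explicit mask and is_causal are treated as mutually exclusive.
  if (attn_mask && is_causal) {
    throw std::invalid_argument("attn_mask must not be set when is_causal is true");
  }
  const std::size_t L = q.size();
  const std::size_t S = k.size();
  assert(v.size() == S);
  assert(!attn_mask || attn_mask->size() == L);
  const std::size_t E = q.empty() ? 0 : q[0].size();
  const std::size_t Ev = v.empty() ? 0 : v[0].size();

  // Default scale 1/sqrt(E), matching the default documented for the Python API.
  const double s = scale.value_or(1.0 / std::sqrt(static_cast<double>(E)));
  const double neg_inf = -std::numeric_limits<double>::infinity();

  Matrix out(L, std::vector<double>(Ev, 0.0));
  std::vector<double> scores(S);
  for (std::size_t i = 0; i < L; ++i) {
    double max_score = neg_inf;
    for (std::size_t j = 0; j < S; ++j) {
      double dot = 0.0;
      for (std::size_t e = 0; e < E; ++e) dot += q[i][e] * k[j][e];
      double sc = dot * s;
      if (attn_mask) sc += (*attn_mask)[i][j];  // additive float mask
      if (is_causal && j > i) sc = neg_inf;      // top-left aligned causal mask
      scores[j] = sc;
      if (sc > max_score) max_score = sc;
    }
    // Softmax over keys, shifted by the row maximum for numerical stability.
    // A row in which every score is -infinity produces NaN here.
    double denom = 0.0;
    for (double& sc : scores) {
      sc = std::exp(sc - max_score);
      denom += sc;
    }
    // Output row i is the softmax-weighted sum of the rows of v.
    for (std::size_t j = 0; j < S; ++j) {
      const double w = scores[j] / denom;
      for (std::size_t c = 0; c < Ev; ++c) out[i][c] += w * v[j][c];
    }
  }
  return out;
}
```

For example, with is_causal set to true the first query row can attend only to the first key, so its output equals the first row of v.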