Searched defs:average_attn_weights (Results 1 – 6 of 6) sorted by relevance

/external/pytorch/test/nn/
  test_multihead_attention.py
    41   def test_multihead_attention(self, average_attn_weights):  (argument)
    49   average_attn_weights=average_attn_weights,  (argument)
    145  average_attn_weights=average_attn_weights,  (argument)
/external/pytorch/test/
  test_native_mha.py
    113  …self, device, dtype, mode, use_nt, need_weights, average_attn_weights, use_padding=False, pad_all=…  (argument)
    278  … need_weights, average_attn_weights, use_padding, pad_all, fused):  (argument)
/external/pytorch/torch/csrc/api/src/nn/modules/
  activation.cpp
    447  bool average_attn_weights) {  (in forward())
/external/pytorch/aten/src/ATen/native/transformers/
  attention.cpp
    273  bool average_attn_weights,  (in native_multi_head_attention_cpu())
/external/pytorch/aten/src/ATen/native/transformers/cuda/
  attention.cu
    491  bool average_attn_weights,  (in native_multi_head_attention_cuda())
/external/pytorch/test/cpp/api/
  modules.cpp
    3437  bool average_attn_weights = true) {
    3508  bool average_attn_weights = true) {  (in _multihead_attn_test_helper())
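Every hit above passes the same flag through the multi-head attention path. Per the PyTorch documentation, `average_attn_weights` only has an effect when `need_weights=True`. If it is true (the default), `torch.nn.MultiheadAttention` returns attention weights averaged across heads, with shape `(N, L, S)` for batched input. If it is false, it returns one map per head, with shape `(N, num_heads, L, S)`. The following is a minimal pure-Python sketch of that head-averaging step. It assumes per-head weights are laid out as nested lists of shape `(num_heads, L, S)` for a single batch element. `softmax_rows` and `average_heads` are hypothetical helper names for illustration and are not PyTorch's implementation.

```python
import math
from typing import List

# Hypothetical illustration only: not PyTorch's kernel code.
Matrix = List[List[float]]  # one attention map, shape (L, S)


def softmax_rows(scores: Matrix) -> Matrix:
    """Row-wise softmax, so each query row becomes a distribution over S keys."""
    out = []
    for row in scores:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def average_heads(per_head: List[Matrix]) -> Matrix:
    """Element-wise mean over the head axis: (num_heads, L, S) -> (L, S).

    This mirrors what average_attn_weights=True requests; with False,
    the per-head maps would be returned unchanged instead.
    """
    if not per_head:
        raise ValueError("need at least one head")
    num_heads = len(per_head)
    L = len(per_head[0])
    S = len(per_head[0][0])
    return [
        [sum(per_head[h][i][j] for h in range(num_heads)) / num_heads for j in range(S)]
        for i in range(L)
    ]


# Usage: two heads, two queries (L=2), two keys (S=2).
head0 = softmax_rows([[1.0, 0.0], [0.0, 1.0]])
head1 = softmax_rows([[0.0, 0.0], [2.0, 0.0]])
avg = average_heads([head0, head1])
```

Each head's rows are softmax-normalized, so the averaged rows still sum to 1. The averaged map therefore remains a valid attention distribution per query, but it no longer shows which head attended where.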