Searched defs:attn_weight (Results 1 – 2 of 2) sorted by relevance

/external/executorch/backends/vulkan/runtime/graph/ops/impl/
SDPA.cpp
  73: const ValueRef attn_weight) { in add_attn_weight_scale_and_mask_node()
  297: TmpTensor attn_weight( in sdpa_with_kv_cache_impl() local
/external/executorch/backends/vulkan/test/op_tests/
sdpa_test.cpp
  201: at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask; in sdpa_reference_impl() local