
Searched defs:SDPAParams (Results 1 – 3 of 3) sorted by relevance

/external/pytorch/torch/nested/_internal/
sdpa.py
74 def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
87 def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
162 def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
221 def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
233 def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
244 def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
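All six nested-tensor helpers above share one shape: take an SDPAParams, optionally print the reason for rejection when debug is set, and return a bool, with the _can_use_*_sdpa_jagged entry points ANDing the individual constraint checks together. Below is a minimal, self-contained sketch of that pattern; the SDPAParams stand-in, its field names, and the specific constraints are illustrative assumptions, not PyTorch's actual checks.

from dataclasses import dataclass
from typing import Callable, List

# Stand-in for torch._C._SDPAParams with just the fields this sketch inspects.
@dataclass
class SDPAParams:
    batch_size: int
    num_heads: int
    head_dim: int

def _check_batch_size(params: SDPAParams, debug: bool = False) -> bool:
    ok = params.batch_size > 0
    if debug and not ok:
        print(f"batch_size must be positive, got {params.batch_size}")
    return ok

def _check_head_dim(params: SDPAParams, debug: bool = False) -> bool:
    # Illustrative constraint: fused kernels often want head_dim % 8 == 0.
    ok = params.head_dim % 8 == 0
    if debug and not ok:
        print(f"head_dim must be a multiple of 8, got {params.head_dim}")
    return ok

def _can_use_flash_sdpa(params: SDPAParams, debug: bool = False) -> bool:
    # Entry point: AND the individual constraint checks together.
    checks: List[Callable[[SDPAParams, bool], bool]] = [
        _check_batch_size,
        _check_head_dim,
    ]
    return all(check(params, debug) for check in checks)

print(_can_use_flash_sdpa(SDPAParams(batch_size=2, num_heads=8, head_dim=64)))  # True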
/external/pytorch/torch/backends/cuda/
__init__.py
264 from torch._C import _SDPAParams as SDPAParams, _SDPBackend as SDPBackend
357 def can_use_flash_attention(params: SDPAParams, debug: bool = False) -> bool:
377 def can_use_efficient_attention(params: SDPAParams, debug: bool = False) -> bool:
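torch.backends.cuda re-exports SDPAParams from torch._C and exposes can_use_flash_attention / can_use_efficient_attention as public pre-flight checks for the fused kernels. A hedged usage sketch follows; the positional SDPAParams constructor (query, key, value, attn_mask, dropout, is_causal) is an assumption here, since newer PyTorch releases append an enable_gqa flag to it.

import torch
from torch.backends.cuda import (
    SDPAParams,
    can_use_efficient_attention,
    can_use_flash_attention,
)

# Query/key/value in the (batch, heads, seq_len, head_dim) layout expected by
# scaled_dot_product_attention; flash attention needs fp16/bf16 on CUDA.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Assumed signature SDPAParams(query, key, value, attn_mask, dropout,
# is_causal); verify the exact form for your PyTorch version.
params = SDPAParams(q, k, v, None, 0.0, True)

# debug=True makes the check warn about the failing constraint, if any.
if can_use_flash_attention(params, debug=True):
    print("flash attention kernel is eligible for these inputs")
elif can_use_efficient_attention(params):
    print("falling back to the memory-efficient kernel")

With debug=True the checks emit warnings explaining why a kernel was rejected, which is useful before pinning a backend.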
/external/pytorch/torch/nn/attention/
__init__.py
56 def _raise_kernel_warnings(params: SDPAParams) -> None: