Searched defs:row_idx_base (Results 1 – 2 of 2) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
mask.h
  48: const int row_idx_base = row_idx_offset + mi * warp_row_stride; (local, in apply_mask_local())
  167: const int row_idx_base = row_idx_offset + mi * warp_row_stride; (local, in apply_mask())
alibi.h
  54: const int row_idx_base = row_idx_offset + mi * warp_row_stride; (local, in apply_alibi())