Home
last modified time | relevance | path

Searched defs:Flash_fwd_kernel_traits (Results 1 – 1 of 1) sorted by relevance

/external/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/
Dkernel_traits.h53 struct Flash_fwd_kernel_traits : public Base { struct
54 using Element = typename Base::Element;
55 using ElementAccum = typename Base::ElementAccum;
56 using index_t = typename Base::index_t;
57 static constexpr bool Has_cp_async = Base::Has_cp_async;
58 using SmemCopyAtom = typename Base::SmemCopyAtom;
59 using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
61 static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
62 static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
65 static constexpr int kNWarps = kNWarps_;
[all …]