/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <algorithm>
#include <cstdint>
#include <type_traits>

#include <cute/algorithm/copy.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/numeric_types.h>

namespace pytorch_flash {

using namespace cute;

template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#else
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

    using ElementAccum = float;
    using index_t = int64_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
    >;
#else
    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
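// Illustrative note (not part of the original header): on an SM80+ build,
//   using BaseTraits = Flash_kernel_traits<64, 128, 128, 4, cutlass::bfloat16_t>;
// selects Element = cutlass::bfloat16_t, Has_cp_async = true, the SM80
// 16x8x16 BF16 MMA atom, and the LDSM shared-memory copy atoms; an SM75 build
// of the same traits falls back to Element = cutlass::half_t,
// Has_cp_async = false, and the SM75 16x8x8 FP16 MMA atom.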

// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<kNWarps>,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        Tile<Int<16 * kNWarps>, _16, _16>>;

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));

    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockN>, Int<kHeadDim>>{}));

    // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));

    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;

    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
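    // Worked example (illustrative numbers, not from the original source): for
    // kBlockM = 128, kBlockN = 128, kHeadDim = 64 with 2-byte elements,
    //   kSmemQSize  = 128 * 64 * 2       = 16 KiB
    //   kSmemKVSize = 2 * (128 * 64 * 2) = 32 KiB
    // so kSmemSize is 48 KiB without Q/K sharing, and max(16, 32) = 32 KiB with
    // Share_Q_K_smem.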

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages": one covering columns 0-63 and the
    // other covering columns 64-127. If we had 16 threads per row for the gmem read, then when
    // writing to smem, threads 0-7 would write to the first page and threads 8-15 to the second
    // page, hitting the same banks.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since a threadblock never
    // reads the same address twice. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_16, _8>,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <_8, _16>,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
    using GmemTiledCopyRotcossin = decltype(
        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinCont = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per load
};
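
// Usage sketch (illustrative only; the concrete tile sizes here are examples,
// chosen by the launch code and not by this header):
//
//   using FwdTraits = Flash_fwd_kernel_traits</*kHeadDim=*/128, /*kBlockM=*/128,
//                                             /*kBlockN=*/64, /*kNWarps=*/4>;
//   static_assert(FwdTraits::kNThreads == 128);
//   // A kernel would be launched with FwdTraits::kNThreads threads per block
//   // and FwdTraits::kSmemSize bytes of dynamic shared memory.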

// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
// No_double_buffer is another option to reduce smem usage, but will slow things down.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
         bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_bwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Is_V_in_regs = Is_V_in_regs_;
    static constexpr bool No_double_buffer = No_double_buffer_;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
    static_assert(kNWarps % AtomLayoutMSdP == 0);
    static_assert(kNWarps % AtomLayoutNdKV == 0);
    static_assert(kNWarps % AtomLayoutMdQ == 0);

    using TiledMmaSdP = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
        Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
    using TiledMmadKV = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
        Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
    using TiledMmadQ = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>,  // 2x4x1 or 4x2x1 thread group
        Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
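    // Illustrative example (not in the original source): for kNWarps = 8 with the
    // default atom layouts (AtomLayoutMSdP = 1, AtomLayoutNdKV = 2, AtomLayoutMdQ = 2),
    //   TiledMmaSdP arranges warps 1 x 8 x 1 with a 16 x 128 x 16 tile, while
    //   TiledMmadKV and TiledMmadQ arrange warps 2 x 4 x 1 with a 32 x 64 x 16 tile.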
    using SmemLayoutAtomQdO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQdO = decltype(tile_to_shape(
        SmemLayoutAtomQdO{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));

    using SmemLayoutAtomKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutKV = decltype(tile_to_shape(
        // SmemLayoutAtomQdO{},
        SmemLayoutAtomKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));

    using SmemLayoutKtransposed = decltype(
        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
    using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));

    // TODO: generalize to other values of kBlockN
    // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
    // static constexpr int kPBlockN = kBlockN;
    // Temporarily disabling this for hdim 256 on sm86 and sm89
    // static_assert(kBlockN >= 64);
    static_assert(kBlockN >= 32);
    // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
    static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
    static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
    // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
    static constexpr int kSwizzlePdS = 3;
    using SmemLayoutAtomPdS = decltype(
        composition(Swizzle<kSwizzlePdS, 3, 3>{},
                    Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
                           Stride<Int<kPBlockN>, _1>>{}));
    using SmemLayoutPdS = decltype(tile_to_shape(
        SmemLayoutAtomPdS{},
        make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
    using SmemLayoutPdStransposed = decltype(
        composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
    using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));

    using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;

    using SmemLayoutQdOtransposed = decltype(
        composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
    using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));

    using SmemLayoutAtomdKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdKV = decltype(tile_to_shape(
        SmemLayoutAtomdKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
    using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;

    using SmemLayoutAtomdQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;

    // sQ is double-buffered (unless No_double_buffer) and sdO shares this allocation:
    // 3 tiles with double buffering, 2 without.
    static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
    static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
    static constexpr int kSmemSize = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
    static constexpr int kSmemSize1colblock = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + kSmemPSize
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
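    // Worked example (illustrative numbers, not from the original source): for
    // kBlockM = 64, kBlockN = 128, kHeadDim = 64, 2-byte elements, double
    // buffering enabled, and V kept in smem (Is_V_in_regs = false),
    //   kSmemQdOSize = 3 * (64 * 64 * 2)  = 24 KiB
    //   kSmemKVSize  = 2 * (128 * 64 * 2) = 32 KiB
    //   kSmemdSSize  = kSmemPSize = 64 * 128 * 2 = 16 KiB
    //   kSmemdQSize  = 64 * 64 * 2        = 8 KiB
    // giving kSmemSize = 24 + 32 + 16 + max(16, 8) = 88 KiB.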

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since a threadblock never
    // reads the same address twice. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydKV = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_32, _8>,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <_16, _16>,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store

    using GmemTiledCopydQaccumAtomicAdd = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        Layout<Shape <_8, _32>,  // Thread layout, 32 threads per row
                               Stride<_32, _1>>{},
                        Layout<Shape < _1, _1>>{}));  // Val layout, 1 val per store

};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace pytorch_flash