/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/algorithm/copy.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/numeric_types.h>

#include <algorithm>    // std::max in constexpr smem-size computations
#include <cstdint>      // int64_t
#include <type_traits>  // std::conditional_t, std::is_same_v

namespace pytorch_flash {

using namespace cute;

template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#else
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

    using ElementAccum = float;
    using index_t = int64_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
    >;
#else
    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};

// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
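    // e.g. kHeadDim=128 -> kBlockKSmem=64, kSwizzle=3; kHeadDim=96 -> kBlockKSmem=32, kSwizzle=2.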

    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<kNWarps>, _1, _1>>,  // 4x1x1 or 8x1x1 thread group
        Tile<Int<16 * kNWarps>, _16, _16>>;

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockKSmem; using kHeadDim gives wrong results for d=128.
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
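    // tile_to_shape repeats the 8 x kBlockKSmem swizzled atom over the full tile,
    // e.g. a 128 x 128 sQ tile with kBlockKSmem=64 is covered by 16 x 2 copies of the atom.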

    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockN>, Int<kHeadDim>>{}));

    // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));

    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;

    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
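    // Illustration: kBlockM=128, kBlockN=128, kHeadDim=64, fp16 -> kSmemQSize = 128*64*2 B = 16 KiB,
    // kSmemKVSize = 2*128*64*2 B = 32 KiB, so kSmemSize = 48 KiB (32 KiB with Share_Q_K_smem).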

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages": the first covers columns 0-63 and the
    // second columns 64-127. If we used 16 threads per row for the gmem read, then when writing to
    // smem, threads 0-7 would write to the first page while threads 8-15 write to the second page,
    // landing on the same banks and causing conflicts.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
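    // e.g. fp16: kGmemElemsPerLoad = 16/2 = 8; kBlockKSmem=64 -> kGmemThreadsPerRow = 8, so with
    // kNThreads=128 this is a 16 x 8 thread layout, each thread issuing 128-bit loads.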

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since the same threadblock
    // never reads the same address twice. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_16, _8>,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <_8, _16>,  // Thread layout, 16 threads per row
               Stride<_16, _1>>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
    using GmemTiledCopyRotcossin = decltype(
        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinCont = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per load
};
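
// A minimal usage sketch (illustrative; not part of the original header):
//   using FwdTraits = Flash_fwd_kernel_traits</*kHeadDim=*/64, /*kBlockM=*/128, /*kBlockN=*/128, /*kNWarps=*/4>;
// This yields kNThreads = 128, kBlockKSmem = 64, kSwizzle = 3, and
// kSmemSize = 16 KiB (sQ) + 32 KiB (sK + sV) = 48 KiB in fp16.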

// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
// No_double_buffer is another option to reduce smem usage, but will slow things down.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
         bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_bwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Is_V_in_regs = Is_V_in_regs_;
    static constexpr bool No_double_buffer = No_double_buffer_;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
    static_assert(kNWarps % AtomLayoutMSdP == 0);
    static_assert(kNWarps % AtomLayoutNdKV == 0);
    static_assert(kNWarps % AtomLayoutMdQ == 0);

    using TiledMmaSdP = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
        Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
    using TiledMmadKV = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
        Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
    using TiledMmadQ = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>,  // 2x4x1 or 4x2x1 thread group
        Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
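    // e.g. with kNWarps=8 and the defaults: TiledMmaSdP is a 1x8x1 warp arrangement (tile 16x128x16),
    // while TiledMmadKV and TiledMmadQ are both 2x4x1 (tile 32x64x16).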
    using SmemLayoutAtomQdO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQdO = decltype(tile_to_shape(
        SmemLayoutAtomQdO{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));

    using SmemLayoutAtomKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));

    using SmemLayoutKtransposed = decltype(
        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
    using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));

    // TODO: generalize to other values of kBlockN
    // TODO: what should the Swizzle be here? 3 is faster than 1, and 1 is faster than 2
    // static constexpr int kPBlockN = kBlockN;
    // Temporarily disabling this for hdim 256 on sm86 and sm89:
    // static_assert(kBlockN >= 64);
    static_assert(kBlockN >= 32);
    // TD [2023-03-19]: I don't know why kPBlockN = 16 with kSwizzlePdS = 3 is the fastest combination.
    static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
    static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
    // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
    static constexpr int kSwizzlePdS = 3;
    using SmemLayoutAtomPdS = decltype(
        composition(Swizzle<kSwizzlePdS, 3, 3>{},
                    Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
                           Stride<Int<kPBlockN>, _1>>{}));
    using SmemLayoutPdS = decltype(tile_to_shape(
        SmemLayoutAtomPdS{},
        make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
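    // e.g. kBlockN=128 -> kPBlockN=64: the kBlockM x 64 swizzled atom is tiled twice along N.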
    using SmemLayoutPdStransposed = decltype(
        composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
    using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));

    using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;

    using SmemLayoutQdOtransposed = decltype(
        composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
    using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));

    using SmemLayoutAtomdKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdKV = decltype(tile_to_shape(
        SmemLayoutAtomdKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
    using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;

    using SmemLayoutAtomdQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;

    // sQ is double buffered, so sQ + sdO take 3 buffers (2 for sQ, 1 for sdO);
    // with No_double_buffer they take only 2.
    static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
    static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
    static constexpr int kSmemSize = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
    static constexpr int kSmemSize1colblock = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + kSmemPSize
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
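    // Illustration: kBlockM=kBlockN=128, kHeadDim=64, fp16, double-buffered sQ, V kept in smem ->
    // kSmemQdOSize = 48 KiB, kSmemKVSize = 32 KiB, kSmemdSSize = kSmemPSize = 32 KiB,
    // kSmemdQSize = 16 KiB, so kSmemSize = 48 + 32 + 32 + max(32, 16) = 144 KiB.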

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here avoids bank conflicts, but it doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since the same threadblock
    // never reads the same address twice. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydKV = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_32, _8>,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <_16, _16>,  // Thread layout, 16 threads per row
               Stride<_16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store

    using GmemTiledCopydQaccumAtomicAdd = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        Layout<Shape <_8, _32>,  // Thread layout, 32 threads per row
                               Stride<_32, _1>>{},
                        Layout<Shape<_1, _1>>{}));  // Val layout, 1 val per store
};
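
// A minimal usage sketch (illustrative; not part of the original header):
//   using BwdTraits = Flash_bwd_kernel_traits</*kHeadDim=*/64, /*kBlockM=*/128, /*kBlockN=*/128, /*kNWarps=*/8>;
// This yields kNThreads = 256, kPBlockN = 64, and kSwizzlePdS = 3.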

////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace pytorch_flash