• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /******************************************************************************
2  * Copyright (c) 2024, Tri Dao.
3  ******************************************************************************/
4 
5 #pragma once
6 
7 #include <cute/algorithm/copy.hpp>
8 
9 #include <ATen/native/transformers/cuda/flash_attn/utils.h>
10 
11 ////////////////////////////////////////////////////////////////////////////////////////////////////
12 
13 namespace pytorch_flash {
14 
15 using namespace cute;
16 
17 ////////////////////////////////////////////////////////////////////////////////////////////////////
18 
19 template <bool Is_even_K=true, bool Clear_OOB_K=true,
20           typename Engine0, typename Layout0, typename Engine1, typename Layout1,
21           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
copy_rotary_interleaved(Tensor<Engine0,Layout0> const & S,Tensor<Engine1,Layout1> & D,Tensor<Engine2,Layout2> const & Cos,Tensor<Engine2,Layout2> const & Sin,Tensor<Engine3,Layout3> const & identity_MN,const int max_MN,const int min_MN,const int dim,const int rotary_dim)22 __forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
23                                                Tensor<Engine1, Layout1> &D,
24                                                Tensor<Engine2, Layout2> const &Cos,
25                                                Tensor<Engine2, Layout2> const &Sin,
26                                                Tensor<Engine3, Layout3> const &identity_MN,
27                                                const int max_MN, const int min_MN,
28                                                const int dim, const int rotary_dim) {
29     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
30     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
31     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
32     CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
33     CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
34     CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                     // MMA_M
35     CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                     // MMA_K
36     CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                     // MMA_M
37     CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                     // MMA_K
38     CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));                     // MMA_K
39     static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
40     static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
41     Tensor rCos = make_fragment_like(Cos);
42     Tensor rSin = make_fragment_like(Sin);
43     Tensor rS = make_fragment_like(S);
44     #pragma unroll
45     for (int m = 0; m < size<1>(S); ++m) {
46         if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
47             #pragma unroll
48             for (int k = 0; k < size<2>(S); ++k) {
49                 if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
50                     cute::copy(S(_, m, k), rS(_, m, k));
51                     if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
52                         cute::copy(Cos(_, m, k), rCos(_, m, k));
53                         cute::copy(Sin(_, m, k), rSin(_, m, k));
54                         Tensor S_fp32 = convert_type<float>(rS(_, m, k));
55                         Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
56                         Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
57                         #pragma unroll
58                         for (int i = 0; i < size<0>(rS) / 2; ++i) {
59                             float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
60                             float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
61                             S_fp32(2 * i) = real;
62                             S_fp32(2 * i + 1) = imag;
63                         }
64                         // Idk but I need to copy for the convert_type to work
65                         Tensor S_fp32_copy = make_fragment_like(S_fp32);
66                         cute::copy(S_fp32, S_fp32_copy);
67                         using T = typename Engine0::value_type;
68                         Tensor S_og_type = convert_type<T>(S_fp32_copy);
69                         cute::copy(S_og_type, rS(_, m, k));
70                     }
71                     cute::copy(rS(_, m, k), D(_, m, k));
72                 } else if (Clear_OOB_K) {
73                     cute::clear(D(_, m, k));
74                 }
75             }
76         }
77     }
78 }
79 
80 ////////////////////////////////////////////////////////////////////////////////////////////////////
81 
82 template <bool Is_even_K=true, bool Clear_OOB_K=true,
83           typename Engine0, typename Layout0, typename Engine1, typename Layout1,
84           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
copy_rotary_contiguous(Tensor<Engine0,Layout0> const & S,Tensor<Engine1,Layout1> & D,Tensor<Engine2,Layout2> const & Cos,Tensor<Engine2,Layout2> const & Sin,Tensor<Engine3,Layout3> const & identity_MN,const int max_MN,const int min_MN,const int dim,const int rotary_dim)85 __forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
86                                               Tensor<Engine1, Layout1> &D,
87                                               Tensor<Engine2, Layout2> const &Cos,
88                                               Tensor<Engine2, Layout2> const &Sin,
89                                               Tensor<Engine3, Layout3> const &identity_MN,
90                                               const int max_MN, const int min_MN,
91                                               const int dim, const int rotary_dim) {
92     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
93     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
94     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
95     CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
96     CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
97     CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                     // MMA_M
98     CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                     // MMA_K
99     CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                     // MMA_M
100     CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                     // MMA_K
101     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos));                     // MMA
102     CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
103     static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
104     Tensor rCos = make_fragment_like(Cos);
105     Tensor rSin = make_fragment_like(Sin);
106     Tensor rS = make_fragment_like(S);
107     Tensor rS_other = make_fragment_like(rS(_, 0, 0));
108     #pragma unroll
109     for (int m = 0; m < size<1>(S); ++m) {
110         if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
111             #pragma unroll
112             for (int k = 0; k < size<2>(S); ++k) {
113                 if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
114                     cute::copy(S(_, m, k), rS(_, m, k));
115                     if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
116                         const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
117                         Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
118                         cute::copy(gS_other, rS_other);
119                         // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
120                         Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
121                         Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
122                         cute::copy(gCos, rCos(_, m, k));
123                         cute::copy(gSin, rSin(_, m, k));
124                         // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
125                         Tensor S_fp32 = convert_type<float>(rS(_, m, k));
126                         Tensor S_other_fp32 = convert_type<float>(rS_other);
127                         Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
128                         Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
129                         #pragma unroll
130                         for (int i = 0; i < size<0>(rS); ++i) {
131                             S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
132                         }
133                         // Idk but I need to copy for the convert_type to work
134                         Tensor S_fp32_copy = make_fragment_like(S_fp32);
135                         cute::copy(S_fp32, S_fp32_copy);
136                         using T = typename Engine0::value_type;
137                         Tensor S_og_type = convert_type<T>(S_fp32_copy);
138                         cute::copy(S_og_type, rS(_, m, k));
139                         // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
140                     }
141                     cute::copy(rS(_, m, k), D(_, m, k));
142                 } else if (Clear_OOB_K) {
143                     cute::clear(D(_, m, k));
144                 }
145             }
146         }
147     }
148 }
149 
150 ////////////////////////////////////////////////////////////////////////////////////////////////////
151 
152 }  // namespace pytorch_flash
153