1 /******************************************************************************
2 * Copyright (c) 2024, Tri Dao.
3 ******************************************************************************/
4
5 #pragma once
6
7 #include <cute/algorithm/copy.hpp>
8
9 #include <ATen/native/transformers/cuda/flash_attn/utils.h>
10
11 ////////////////////////////////////////////////////////////////////////////////////////////////////
12
13 namespace pytorch_flash {
14
15 using namespace cute;
16
17 ////////////////////////////////////////////////////////////////////////////////////////////////////
18
template <bool Is_even_K=true, bool Clear_OOB_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
                                                        Tensor<Engine1, Layout1> &D,
                                                        Tensor<Engine2, Layout2> const &Cos,
                                                        Tensor<Engine2, Layout2> const &Sin,
                                                        Tensor<Engine3, Layout3> const &identity_MN,
                                                        const int max_MN, const int min_MN,
                                                        const int dim, const int rotary_dim) {
    // Copy S -> D while applying rotary embedding in the "interleaved" layout: within each
    // per-thread MMA vector, elements (2*i, 2*i+1) form one pair rotated by the angle whose
    // cos/sin live at index i of the Cos/Sin fragments (hence size<0>(S) == 2 * size<0>(Cos)).
    //  - Rows whose global M coordinate (from identity_MN) falls outside [min_MN, max_MN) are
    //    skipped entirely.
    //  - Columns with global K coordinate >= dim are zero-filled in D when Clear_OOB_K is set,
    //    otherwise left untouched; Is_even_K == true elides that bounds check.
    //  - Columns with K coordinate >= rotary_dim are copied through without rotation.
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                   // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                   // MMA_K
    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));                 // MMA
    static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
    // Register-resident staging fragments so the rotation happens on local copies.
    Tensor rCos = make_fragment_like(Cos);
    Tensor rSin = make_fragment_like(Sin);
    Tensor rS = make_fragment_like(S);
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
                    cute::copy(S(_, m, k), rS(_, m, k));
                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
                        cute::copy(Cos(_, m, k), rCos(_, m, k));
                        cute::copy(Sin(_, m, k), rSin(_, m, k));
                        // Rotate in fp32 for accuracy regardless of the storage type of S.
                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
                        #pragma unroll
                        for (int i = 0; i < size<0>(rS) / 2; ++i) {
                            // Complex rotation of the pair (x, y) = (S[2i], S[2i+1]) by theta_i:
                            // (x', y') = (x*cos - y*sin, x*sin + y*cos).
                            float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
                            float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
                            S_fp32(2 * i) = real;
                            S_fp32(2 * i + 1) = imag;
                        }
                        // Idk but I need to copy for the convert_type to work
                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
                        cute::copy(S_fp32, S_fp32_copy);
                        using T = typename Engine0::value_type;
                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
                        cute::copy(S_og_type, rS(_, m, k));
                    }
                    cute::copy(rS(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        }
    }
}
79
80 ////////////////////////////////////////////////////////////////////////////////////////////////////
81
template <bool Is_even_K=true, bool Clear_OOB_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
                                                       Tensor<Engine1, Layout1> &D,
                                                       Tensor<Engine2, Layout2> const &Cos,
                                                       Tensor<Engine2, Layout2> const &Sin,
                                                       Tensor<Engine3, Layout3> const &identity_MN,
                                                       const int max_MN, const int min_MN,
                                                       const int dim, const int rotary_dim) {
    // Copy S -> D while applying rotary embedding in the "contiguous" (half-rotated) layout:
    // an element at K coordinate k pairs with the element rotary_dim/2 positions away
    // (k + rotary_dim/2 for the left half, k - rotary_dim/2 for the right half), and
    //   left:  out = x * cos - other * sin
    //   right: out = x * cos + other * sin
    // Row/column skipping and Clear_OOB_K/Is_even_K semantics match copy_rotary_interleaved.
    // NOTE(review): the partner element and the cos/sin values are fetched by offsetting the
    // raw pointer of S/Cos/Sin by rotary_dim/2 elements while reusing the same layout — this
    // presumably assumes the K dimension is element-contiguous in memory; confirm with callers.
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                   // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                   // MMA_K
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos));                   // MMA
    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
    // Register-resident staging fragments; rS_other holds the partner half's elements.
    Tensor rCos = make_fragment_like(Cos);
    Tensor rSin = make_fragment_like(Sin);
    Tensor rS = make_fragment_like(S);
    Tensor rS_other = make_fragment_like(rS(_, 0, 0));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
                    cute::copy(S(_, m, k), rS(_, m, k));
                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
                        const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
                        // Partner element lives rotary_dim/2 columns to the right (left half)
                        // or to the left (right half) in the source tensor.
                        Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
                        cute::copy(gS_other, rS_other);
                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
                        // Cos/Sin tables span only rotary_dim/2 columns, so the right half
                        // reads them back at (k - rotary_dim/2).
                        Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
                        Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
                        cute::copy(gCos, rCos(_, m, k));
                        cute::copy(gSin, rSin(_, m, k));
                        // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
                        // Rotate in fp32 for accuracy regardless of the storage type of S.
                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
                        Tensor S_other_fp32 = convert_type<float>(rS_other);
                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
                        #pragma unroll
                        for (int i = 0; i < size<0>(rS); ++i) {
                            S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
                        }
                        // Idk but I need to copy for the convert_type to work
                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
                        cute::copy(S_fp32, S_fp32_copy);
                        using T = typename Engine0::value_type;
                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
                        cute::copy(S_og_type, rS(_, m, k));
                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
                    }
                    cute::copy(rS(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        }
    }
}
149
150 ////////////////////////////////////////////////////////////////////////////////////////////////////
151
152 } // namespace pytorch_flash
153