/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cmath>      // INFINITY
#include <algorithm>  // std::min, std::max

#include <cute/tensor.hpp>

namespace pytorch_flash {

using namespace cute;

template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                           const int col_idx_offset_ = 0) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
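    // Thread-to-column mapping, assuming the standard 16x8 MMA accumulator layout this file is
    // written for: each lane in a group of 4 owns two adjacent columns of every 8-column tile.
    // For example, lane 5 covers columns col_idx_offset_ + {2, 3}, + {10, 11}, + {18, 19}, ...
    // which is exactly what the (nj, j) loops below enumerate.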
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
        const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
            const int col_idx = col_idx_base + j;
            if (col_idx >= max_seqlen_k) {
                // Without the "make_coord" we get wrong results
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                }
            }
        }
    }
}

template <bool HasWSLeft=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                 const int max_seqlen_k, const int row_idx_offset,
                                                 const int max_seqlen_q, const int warp_row_stride,
                                                 const int window_size_left, const int window_size_right) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
            const int row_idx = row_idx_base + i * 8;
            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
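            // The max_seqlen_k - max_seqlen_q term aligns the window to the bottom-right corner of
            // the attention matrix, i.e. query row r behaves as if it sat at key position
            // r + max_seqlen_k - max_seqlen_q. Worked example (assumed values): with max_seqlen_q = 2,
            // max_seqlen_k = 5 and window_size_left = window_size_right = 0, row 0 keeps only
            // column 3 (limits [3, 4)) and row 1 keeps only column 4 (limits [4, 5)).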
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
            // if (cute::thread0()) {
            //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
            //     print(tensor(make_coord(i, mi), _));
            //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
            // }
        }
    }
}

template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                  const int max_seqlen_k, const int row_idx_offset,
                                                  const int max_seqlen_q, const int warp_row_stride) {
    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
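    // With HasWSLeft=false the left-window check can never fire, so the -1 passed below for
    // window_size_left is only a placeholder and never affects the result.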
    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                          max_seqlen_q, warp_row_stride, -1, 0);
}

template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 2, "Only support 2D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
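    // idx_rowcol is expected to hold the absolute (row, col) coordinate of each accumulator
    // element (e.g. built from an identity tensor laid out like the accumulator), so the causal
    // limit can be read off directly instead of being re-derived from lane/warp indices.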
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
        #pragma unroll
        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
                tensor(mi, ni) = -INFINITY;
            }
        }
        // if (cute::thread0()) {
        //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
        //     print(tensor(_, make_coord(j, ni)));
        //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
        // }
    }
}
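
// Mask ties the standalone helpers above together: the masking mode is fixed at compile time via
// the template parameters, and apply_mask() below is intended to be called by the attention
// kernels on each tile of attention scores.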

template <bool Is_causal, bool Is_local, bool Has_alibi>
struct Mask {

    const int max_seqlen_k, max_seqlen_q;
    const int window_size_left, window_size_right;
    const float alibi_slope;

    __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
                                    const int window_size_left, const int window_size_right,
                                    const float alibi_slope=0.f)
        : max_seqlen_k(max_seqlen_k)
        , max_seqlen_q(max_seqlen_q)
        , window_size_left(window_size_left)
        , window_size_right(window_size_right)
        , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
    };

    // Causal_mask: whether this particular iteration needs causal masking
    template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
    __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
                                               const int col_idx_offset_,
                                               const int row_idx_offset,
                                               const int warp_row_stride) {
        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
        // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
        if constexpr (Need_masking) {
            // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
            Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_rowcol(tensor_.layout()));
            // Do we need both row and column indices, or just column indices?
            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
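            // Column indices alone suffice when the only remaining work is padding masking
            // (!Is_even_MN) and/or ALiBi in the causal setting; everything else (local windows,
            // causal masking, non-causal ALiBi) also depends on the row index.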
            const int lane_id = threadIdx.x % 32;
            const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
            if constexpr (Col_idx_only) {
                #pragma unroll
                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                    const int col_idx_base = col_idx_offset + nj * 8;
                    #pragma unroll
                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
                        const int col_idx = col_idx_base + j;
                        #pragma unroll
                        for (int mi = 0; mi < size<0>(tensor); ++mi) {
                            // No causal, no local
                            if constexpr (Has_alibi) {
                                tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                            }
                            if constexpr (!Is_even_MN) {
                                if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
                            }
                        }
                    }
                }
            } else {
                #pragma unroll
                for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
                    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                    #pragma unroll
                    for (int i = 0; i < size<0, 0>(tensor); ++i) {
                        const int row_idx = row_idx_base + i * 8;
                        const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
                        const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                        #pragma unroll
                        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                            const int col_idx_base = col_idx_offset + nj * 8;
                            #pragma unroll
                            for (int j = 0; j < size<1, 0>(tensor); ++j) {
                                const int col_idx = col_idx_base + j;
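                                // ALiBi, assuming the usual bias = -slope * |q_pos - k_pos| with
                                // q_pos = row_idx + max_seqlen_k - max_seqlen_q (bottom-right aligned).
                                // In the causal case the row-dependent part of the bias is constant
                                // per row and cancels in the softmax, so adding +slope * col_idx is
                                // equivalent and cheaper.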
                                if constexpr (Has_alibi) {
                                    if constexpr (Is_causal) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
                                    } else {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                                    }
                                }
                                if constexpr (Causal_mask) {
                                    if (col_idx >= col_idx_limit_right) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                                if constexpr (Is_local) {
                                    if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                                if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
                                    // Causal and Local already handle MN masking
                                    if (col_idx >= max_seqlen_k) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    };

};
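
// Illustrative usage sketch (an assumption about the surrounding kernel, not the exact call site;
// names such as acc_s, m_block, n_block, kBlockM, kBlockN, kNWarps and tidx are placeholders):
//
//     pytorch_flash::Mask</*Is_causal=*/true, /*Is_local=*/false, /*Has_alibi=*/false> mask(
//         seqlen_k, seqlen_q, /*window_size_left=*/-1, /*window_size_right=*/0);
//     // acc_s: the (MMA=4, MMA_M, MMA_N) attention-score accumulator for one tile.
//     mask.template apply_mask</*Causal_mask=*/true, /*Is_even_MN=*/false>(
//         acc_s, /*col_idx_offset_=*/n_block * kBlockN,
//         /*row_idx_offset=*/m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
//         /*warp_row_stride=*/kNWarps * 16);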

} // namespace pytorch_flash