1 /******************************************************************************
2  * Copyright (c) 2024, Tri Dao.
3  ******************************************************************************/
4 
5 #pragma once
6 
7 #include <cute/tensor.hpp>
8 
9 namespace pytorch_flash {
10 
11 using namespace cute;
12 
/// Masks out key columns past the end of the key sequence.
///
/// Every element of `tensor` whose column index is >= `max_seqlen_k` is set to
/// -INFINITY, in place.
///
/// `tensor` is a 2D fragment viewed as (nrow=(2, MMA_M), ncol=(2, MMA_N)).
/// The column index of an element (mi, (j, nj)) is reconstructed as
///     col_idx_offset_ + (lane_id % 4) * 2 + nj * 8 + j
/// NOTE(review): this presumably mirrors the per-lane ownership of the tensor-core
/// MMA accumulator layout; confirm against the MMA atom used by the caller.
///
/// @param tensor          fragment modified in place.
/// @param max_seqlen_k    number of valid key columns; columns >= this are masked.
/// @param col_idx_offset_ column index of the first column covered by this fragment.
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                  const int col_idx_offset_ = 0) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    // Lanes within a group of 4 start 2 columns apart.
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
        // Successive MMA_N tiles are 8 columns apart.
        const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
            const int col_idx = col_idx_base + j;
            if (col_idx >= max_seqlen_k) {
                // Without the "make_coord" we get wrong results
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                }
            }
        }
    }
}
36 
/// Sliding-window ("local") attention mask for a 2D fragment, applied in place.
///
/// `tensor` is viewed as (nrow=(2, MMA_M), ncol=(2, MMA_N)). For element
/// ((i, mi), (j, nj)) the indices are reconstructed as
///     row_idx = row_idx_offset + mi * warp_row_stride + i * 8
///     col_idx = col_idx_offset_ + (lane_id % 4) * 2 + nj * 8 + j
///
/// With shift = max_seqlen_k - max_seqlen_q, row `row_idx` keeps the columns in
///     [max(0, row_idx + shift - window_size_left),
///      min(max_seqlen_k, row_idx + shift + 1 + window_size_right))
/// and every other column is set to -INFINITY.
///
/// @tparam HasWSLeft when false, the left limit is skipped entirely, so
///                   window_size_left has no effect on the result.
/// NOTE(review): the i * 8 / nj * 8 / (lane_id % 4) * 2 strides presumably match
/// the tensor-core MMA accumulator layout used by the caller; confirm.
template <bool HasWSLeft=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                        const int max_seqlen_k, const int row_idx_offset,
                                        const int max_seqlen_q, const int warp_row_stride,
                                        const int window_size_left, const int window_size_right) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
            const int row_idx = row_idx_base + i * 8;
            // Window bounds for this row; the right bound also enforces col < max_seqlen_k.
            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
        }
    }
}
73 
/// Causal mask for a 2D fragment, applied in place.
///
/// This is a local mask with an unbounded left window and a right window of 0.
/// Row `r` keeps only the columns c < min(max_seqlen_k, r + 1 + max_seqlen_k - max_seqlen_q).
/// The left window size handed to apply_mask_local is only a placeholder:
/// HasWSLeft=false makes it irrelevant to the result.
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                         const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride) {
    constexpr bool kApplyLeftLimit = false;
    constexpr int kPlaceholderLeftWindow = -1;
    constexpr int kRightWindow = 0;
    apply_mask_local<kApplyLeftLimit>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                      max_seqlen_q, warp_row_stride,
                                      kPlaceholderLeftWindow, kRightWindow);
}
82 
/// Causal mask driven by an explicit coordinate tensor, applied in place.
///
/// `idx_rowcol` must have the same mode-0 / mode-1 sizes as `tensor` (statically
/// asserted). get<0>(idx_rowcol(m, 0)) is used as the row coordinate of row m, and
/// get<1>(idx_rowcol(0, n)) as the column coordinate of column n.
/// Element (m, n) of `tensor` is set to -INFINITY when
///     col_idx_offset_ + get<1>(idx_rowcol(0, n))
///         >= min(max_seqlen_k, row_idx_offset + get<0>(idx_rowcol(m, 0)) + 1)
///
/// Unlike apply_mask_causal / apply_mask_local, no (max_seqlen_k - max_seqlen_q)
/// shift is applied here. The diagonal is wherever the offset-adjusted column
/// equals the offset-adjusted row.
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 2, "Only support 2D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
        // `ni` is a flat index into the whole column mode, so loop over size<1>.
        // Bounding it by the sub-mode size<1, 1> would leave the remaining columns unmasked.
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
                tensor(mi, ni) = -INFINITY;
            }
        }
    }
}
109 
/// Masking / ALiBi-bias policy applied to a 3D accumulator fragment.
///
/// Template flags:
///   Is_causal - with Has_alibi, selects the bias form `+ alibi_slope * col_idx`.
///               Otherwise the bias is `- alibi_slope * |row_idx + max_seqlen_k - max_seqlen_q - col_idx|`.
///   Is_local  - apply a sliding window using window_size_left / window_size_right.
///   Has_alibi - add an ALiBi bias scaled by alibi_slope; when false, alibi_slope is forced to 0.
template <bool Is_causal, bool Is_local, bool Has_alibi>
struct Mask {

    const int max_seqlen_k, max_seqlen_q;
    const int window_size_left, window_size_right;
    const float alibi_slope;

    __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
                                    const int window_size_left, const int window_size_right,
                                    const float alibi_slope=0.f)
        : max_seqlen_k(max_seqlen_k)
        , max_seqlen_q(max_seqlen_q)
        , window_size_left(window_size_left)
        , window_size_right(window_size_right)
        // Float literal keeps the conditional in float (0.0 would promote it to double).
        , alibi_slope(!Has_alibi ? 0.f : alibi_slope) {
    }

    /// Applies bias and masking to `tensor_` in place.
    ///
    /// @tparam Causal_mask whether this particular iteration needs causal masking.
    ///                     It must not be combined with Is_local.
    /// @tparam Is_even_MN  when false, columns with index >= max_seqlen_k are masked.
    /// @param tensor_         fragment of shape (MMA=4, MMA_M, MMA_N).
    /// @param col_idx_offset_ column index of the first column covered by this fragment.
    /// @param row_idx_offset  row index of the first row covered by this fragment.
    /// @param warp_row_stride row distance between successive MMA_M tiles.
    ///
    /// Masked elements are set to -INFINITY. Row r (shift = max_seqlen_k - max_seqlen_q) keeps:
    ///   - causal: columns < min(max_seqlen_k, r + shift + 1 + window_size_right);
    ///   - local:  columns in [max(0, r + shift - window_size_left),
    ///             min(max_seqlen_k, r + shift + 1 + window_size_right)).
    template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
    __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
                                               const int col_idx_offset_,
                                               const int row_idx_offset,
                                               const int warp_row_stride) {
        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
        if constexpr (Need_masking) {
            // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
            Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_rowcol(tensor_.layout()));
            // Do we need both row and column indices, or just column indices?
            // Only non-causal ALiBi, local windows and causal masking depend on the row index.
            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
            const int lane_id = threadIdx.x % 32;
            const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
            if constexpr (Col_idx_only) {
                #pragma unroll
                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                    const int col_idx_base = col_idx_offset + nj * 8;
                    #pragma unroll
                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
                        const int col_idx = col_idx_base + j;
                        #pragma unroll
                        for (int mi = 0; mi < size<0>(tensor); ++mi) {
                            // No causal, no local
                            if constexpr (Has_alibi) {
                                tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                            }
                            if constexpr (!Is_even_MN) {
                                if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
                            }
                        }
                    }
                }
            } else {
                #pragma unroll
                for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
                    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                    #pragma unroll
                    for (int i = 0; i < size<0, 0>(tensor); ++i) {
                        const int row_idx = row_idx_base + i * 8;
                        const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
                        const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                        #pragma unroll
                        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                            const int col_idx_base = col_idx_offset + nj * 8;
                            #pragma unroll
                            for (int j = 0; j < size<1, 0>(tensor); ++j) {
                                const int col_idx = col_idx_base + j;
                                if constexpr (Has_alibi) {
                                    if constexpr (Is_causal) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
                                    } else {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                                    }
                                }
                                if constexpr (Causal_mask) {
                                    if (col_idx >= col_idx_limit_right) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                                if constexpr (Is_local) {
                                    if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                                if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
                                    // Causal and Local already handles MN masking
                                    if (col_idx >= max_seqlen_k) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

};
212 
213 }  // namespace pytorch_flash
214