/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

namespace pytorch_flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Per-batch bookkeeping for one attention kernel launch: where this batch's Q and K
// sequences start in a packed (varlen) layout, and their actual sequence lengths.
template<bool Varlen=true>
struct BlockInfo {

    template<typename Params>
    __device__ BlockInfo(const Params &params, const int bidb)
        : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
        , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
        {
        }

    template <typename index_t>
    __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
    }

    template <typename index_t>
    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
    }

    const int sum_s_q;
    const int sum_s_k;
    const int actual_seqlen_q;
    // seqlen_k_cache must be declared before actual_seqlen_k: members are initialized in
    // declaration order, so otherwise actual_seqlen_k would read seqlen_k_cache before it
    // is initialized and end up as 0.
    const int seqlen_k_cache;
    const int actual_seqlen_k;
};
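
// Illustrative usage sketch (not part of the original header): a kernel typically builds
// one BlockInfo per batch index, early-exits blocks past the actual query length, and adds
// q_offset()/k_offset() to the base pointers. With packed varlen input, sum_s_q != -1 and
// the row stride locates the sequence start; otherwise the batch stride is used. The names
// Element, kBlockM, m_block, and the Params members q_ptr / q_batch_stride / q_row_stride
// below are assumptions for illustration only.
//
//   const int bidb = blockIdx.z;
//   const BlockInfo</*Varlen=*/true> binfo(params, bidb);
//   if (m_block * kBlockM >= binfo.actual_seqlen_q) return;   // nothing to do for this block
//   const Element *q = reinterpret_cast<const Element *>(params.q_ptr)
//       + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb);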

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace pytorch_flash