/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

namespace pytorch_flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool Varlen=true>
struct BlockInfo {

    template<typename Params>
    __device__ BlockInfo(const Params &params, const int bidb)
        : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
        , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
        {
        }

    template <typename index_t>
    __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
    }

    template <typename index_t>
    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
    }

    const int sum_s_q;
    const int sum_s_k;
    const int actual_seqlen_q;
    // seqlen_k_cache must be declared before actual_seqlen_k: members are initialized in declaration
    // order, so actual_seqlen_k's initializer would otherwise read seqlen_k_cache before it is set.
    const int seqlen_k_cache;
    const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace pytorch_flash
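
// ----------------------------------------------------------------------------
// Usage sketch (illustrative only; not part of the original header). It shows
// how a flash-attention kernel might consume BlockInfo for one batch index
// `bidb`. The names `binfo`, `m_block`, `kBlockM`, `Element`, and the params
// fields referenced below are assumptions about the surrounding kernel code,
// not definitions made in this file.
//
//   const pytorch_flash::BlockInfo</*Varlen=*/true> binfo(params, bidb);
//   // Skip blocks whose query rows lie past the actual sequence length.
//   if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
//   // Base pointers for this batch: fixed-length layouts index by the batch
//   // stride, varlen (cu_seqlens) layouts index by the cumulative row offset.
//   Element *q_ptr = reinterpret_cast<Element *>(params.q_ptr)
//       + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb);
//   Element *k_ptr = reinterpret_cast<Element *>(params.k_ptr)
//       + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb);
// ----------------------------------------------------------------------------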