mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
chore: update flash attention kernels (#1518)
* chore: update flash attention kernels * fmt * remove unused kernels * force f32 * correct stride
This commit is contained in:
@ -14,9 +14,12 @@ struct BlockInfo {
|
||||
template<typename Params>
|
||||
__device__ BlockInfo(const Params ¶ms, const int bidb)
|
||||
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
|
||||
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
|
||||
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
|
||||
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
|
||||
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
|
||||
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
|
||||
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
|
||||
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
|
||||
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
|
||||
{
|
||||
}
|
||||
|
||||
@ -32,8 +35,10 @@ struct BlockInfo {
|
||||
|
||||
const int sum_s_q;
|
||||
const int sum_s_k;
|
||||
const uint32_t actual_seqlen_q;
|
||||
const uint32_t actual_seqlen_k;
|
||||
const int actual_seqlen_q;
|
||||
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
|
||||
const int seqlen_k_cache;
|
||||
const int actual_seqlen_k;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
Reference in New Issue
Block a user