mirror of
https://github.com/huggingface/candle.git
synced 2025-06-14 09:57:10 +00:00

* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab. * More flash attn. * Set up the flash attn parameters. * Get things to compile locally. * Move the flash attention files in a different directory. * Build the static C library with nvcc. * Add more flash attention. * Update the build part. * Better caching. * Exclude flash attention from the default workspace. * Put flash-attn behind a feature gate. * Get the flash attn kernel to run. * Move the flags to a more appropriate place. * Enable flash attention in llama. * Use flash attention in llama.
42 lines
1.6 KiB
C++
42 lines
1.6 KiB
C++
/******************************************************************************
|
|
* Copyright (c) 2023, Tri Dao.
|
|
******************************************************************************/
|
|
|
|
#pragma once
|
|
|
|
namespace flash {
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Per-batch-element sequence bookkeeping: where the batch element's rows start
// and how many query/key rows it has.
//
// Two layouts are handled, selected per tensor (Q and K independently):
//   * "fixed length": chosen when Varlen is false or the corresponding
//     params.cu_seqlens_* pointer is null. The start offset is recorded as the
//     sentinel -1 and the length is taken from params.seqlen_*.
//   * "variable length": chosen when Varlen is true and the pointer is
//     non-null. The start offset is cu_seqlens_*[bidb] and the length is
//     cu_seqlens_*[bidb + 1] - cu_seqlens_*[bidb].
//     NOTE(review): this presumably expects cu_seqlens_* to hold cumulative
//     sequence lengths with at least bidb + 2 readable entries -- confirm
//     against the caller that fills Params.
template<bool Varlen=true>
struct BlockInfo {

    // Reads the sequence bounds of batch element `bidb` out of `params`.
    //
    // `Params` is any type exposing `cu_seqlens_q`, `cu_seqlens_k` (pointers
    // comparable with nullptr and indexable by int) and `seqlen_q`, `seqlen_k`.
    //
    // The initializers of actual_seqlen_* read sum_s_*; this relies on the
    // members being declared (and therefore initialized) in that order below.
    template<typename Params>
    __device__ BlockInfo(const Params &params, const int bidb)
        : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
        , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
        , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
        {
        }

    // Offset of this batch element's first query row.
    //
    // Returns bidb * batch_stride in the fixed-length case (sum_s_q == -1),
    // otherwise sum_s_q (reinterpreted as unsigned) * row_stride.
    // `bidb` is only used in the fixed-length case.
    template <typename index_t>
    inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
    }

    // Offset of this batch element's first key row; same rule as q_offset but
    // driven by sum_s_k.
    template <typename index_t>
    inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
    }

    // Start offsets read from cu_seqlens_*[bidb], or -1 in the fixed-length case.
    // Must stay declared before actual_seqlen_* (see constructor note).
    const int sum_s_q;
    const int sum_s_k;
    // Number of query / key rows belonging to this batch element.
    const uint32_t actual_seqlen_q;
    const uint32_t actual_seqlen_k;
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace flash
|