Update the flash attn kernels. (#2333)

This commit is contained in:
Laurent Mazare
2024-07-15 20:37:36 +02:00
committed by GitHub
parent d74fbed334
commit 30cdd769f9
51 changed files with 2279 additions and 904 deletions

View File

@ -24,12 +24,12 @@ struct BlockInfo {
}
template <typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template <typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}