mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 12:28:06 +00:00
Update the flash attn kernels. (#2333)
This commit is contained in:
@ -24,12 +24,12 @@ struct BlockInfo {
|
||||
}
|
||||
|
||||
template <typename index_t>
|
||||
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
|
||||
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
|
||||
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
|
||||
}
|
||||
|
||||
template <typename index_t>
|
||||
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
|
||||
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
|
||||
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user