mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
chore: update flash attention kernels (#1518)
* chore: update flash attention kernels * fmt * remove unused kernels * force f32 * correct stride
This commit is contained in:
@ -8,8 +8,7 @@
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "philox.cuh"
|
||||
#include "utils.h"
|
||||
@ -117,15 +116,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
|
||||
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
|
||||
const int col_idx_offset_ = 0) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const uint32_t lane_id = threadIdx.x % 32;
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
|
||||
const int col_idx = col_idx_base + j;
|
||||
if (col_idx >= max_seqlen_k) {
|
||||
// Without the "make_coord" we get wrong results
|
||||
#pragma unroll
|
||||
@ -137,30 +139,30 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
|
||||
const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
|
||||
const uint32_t warp_row_stride) {
|
||||
template <bool HasWSLeft=true, typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||
const int max_seqlen_k, const int row_idx_offset,
|
||||
const int max_seqlen_q, const int warp_row_stride,
|
||||
const int window_size_left, const int window_size_right) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const uint32_t lane_id = threadIdx.x % 32;
|
||||
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
|
||||
const uint32_t row_idx_offset = row_idx_offset_;
|
||||
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const uint32_t row_idx = row_idx_base + i * 8;
|
||||
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
|
||||
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const uint32_t col_idx_base = col_idx_offset + nj * 8;
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const uint32_t col_idx = col_idx_base + j;
|
||||
if (col_idx >= col_idx_limit) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
@ -174,10 +176,19 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||
const int max_seqlen_k, const int row_idx_offset,
|
||||
const int max_seqlen_q, const int warp_row_stride) {
|
||||
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
|
||||
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
|
||||
max_seqlen_q, warp_row_stride, -1, 0);
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void apply_mask_causal_w_idx(
|
||||
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
|
||||
const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
|
||||
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
|
||||
{
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
@ -186,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
|
||||
const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
|
||||
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
|
||||
@ -204,8 +215,8 @@ inline __device__ void apply_mask_causal_w_idx(
|
||||
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
|
||||
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
|
||||
unsigned long long seed, unsigned long long offset,
|
||||
uint32_t block_row_start, uint32_t block_col_start,
|
||||
uint32_t block_row_stride) {
|
||||
int block_row_start, int block_col_start,
|
||||
int block_row_stride) {
|
||||
// tensor has shape (8, MMA_M, MMA_N / 2)
|
||||
using T = typename Engine::value_type;
|
||||
auto encode_dropout = [](bool keep, T val) {
|
||||
|
Reference in New Issue
Block a user