mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Update the flash attn kernels. (#2333)
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
@ -20,7 +20,7 @@ using namespace cute;
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
@ -35,7 +35,7 @@ __device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Te
|
||||
}
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++){
|
||||
@ -44,26 +44,26 @@ __device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Eng
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
thread_reduce_<zero_init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||
MaxOp<float> max_op;
|
||||
reduce_<zero_init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
SumOp<float> sum_op;
|
||||
reduce_(tensor, sum, sum_op);
|
||||
thread_reduce_<zero_init>(tensor, sum, sum_op);
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
@ -78,14 +78,21 @@ inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
// The following macro will disable the use of fma.
|
||||
// See: https://github.com/pytorch/pytorch/issues/121558 for more details
|
||||
// This macro is set in PyTorch and not FlashAttention
|
||||
#ifdef UNFUSE_FMA
|
||||
tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
|
||||
#else
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
@ -115,169 +122,67 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
|
||||
const int col_idx_offset_ = 0) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
if (col_idx >= max_seqlen_k) {
|
||||
// Without the "make_coord" we get wrong results
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
tensor(mi, make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool HasWSLeft=true, typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||
const int max_seqlen_k, const int row_idx_offset,
|
||||
const int max_seqlen_q, const int warp_row_stride,
|
||||
const int window_size_left, const int window_size_right) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
|
||||
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
|
||||
template <int kNRows>
|
||||
struct Softmax {
|
||||
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
|
||||
TensorT row_max, row_sum;
|
||||
|
||||
__forceinline__ __device__ Softmax() {};
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
|
||||
__forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
if (Is_first) {
|
||||
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = !Check_inf
|
||||
? row_max(mi)
|
||||
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
|
||||
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
row_sum(mi) *= scores_scale;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
|
||||
}
|
||||
// if (cute::thread0()) {
|
||||
// printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
|
||||
// print(tensor(make_coord(i, mi), _));
|
||||
// // print(tensor(_, j + nj * size<1, 0>(tensor)));
|
||||
// }
|
||||
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
// We don't do the reduce across threads here since we don't need to use the row_sum.
|
||||
// We do that reduce at the end when we need to normalize the softmax.
|
||||
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||
const int max_seqlen_k, const int row_idx_offset,
|
||||
const int max_seqlen_q, const int warp_row_stride) {
|
||||
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
|
||||
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
|
||||
max_seqlen_q, warp_row_stride, -1, 0);
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void apply_mask_causal_w_idx(
|
||||
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
|
||||
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
|
||||
{
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 2, "Only support 2D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
|
||||
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
|
||||
tensor(mi, ni) = -INFINITY;
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) {
|
||||
// printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
|
||||
// print(tensor(_, make_coord(j, ni)));
|
||||
// // print(tensor(_, j + ni * size<1, 0>(tensor)));
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
|
||||
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
|
||||
unsigned long long seed, unsigned long long offset,
|
||||
int block_row_start, int block_col_start,
|
||||
int block_row_stride) {
|
||||
// tensor has shape (8, MMA_M, MMA_N / 2)
|
||||
using T = typename Engine::value_type;
|
||||
auto encode_dropout = [](bool keep, T val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
|
||||
};
|
||||
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
|
||||
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
|
||||
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
|
||||
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
|
||||
uint2 rowcol = make_uint2(block_row_start, block_col_start);
|
||||
|
||||
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
TensorT lse = make_fragment_like(row_sum);
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
|
||||
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
|
||||
uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
|
||||
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
|
||||
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
|
||||
// Special implementation for 16-bit types: we duplicate the threshold to the
|
||||
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
|
||||
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
|
||||
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
|
||||
// the random value is less than the threshold.
|
||||
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
|
||||
// We're exploiting the fact that floating point comparison is equivalent to integer
|
||||
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
|
||||
if (!encode_dropout_in_sign_bit
|
||||
&& (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
|
||||
uint16_t rnd_16[16];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
|
||||
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++) {
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
|
||||
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
uint32_t mask;
|
||||
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
|
||||
tensor_uint32(i) &= mask;
|
||||
}
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
|
||||
}
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
}
|
||||
}
|
||||
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
|
||||
// // }
|
||||
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
|
||||
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
||||
}
|
||||
}
|
||||
}
|
||||
return lse;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace flash
|
||||
|
Reference in New Issue
Block a user