Proper flash-attn parameters. (#244)

* Proper flash-attn parameters.

* Set the flash attention parameters.

* Add more validations.

* Setup the o_ flash attn parameters.

* More flash-attn support.

* Set more flash attn parameters.
This commit is contained in:
Laurent Mazare
2023-07-26 10:13:40 +01:00
committed by GitHub
parent e40b150bbe
commit fa2b64d678
5 changed files with 147 additions and 12 deletions

View File

@ -203,3 +203,16 @@ impl Error {
}
}
}
#[macro_export]
macro_rules! bail {
($msg:literal $(,)?) => {
return Err($crate::Error::Wrapped(format!($msg).into()).bt())
};
($err:expr $(,)?) => {
return Err($crate::Error::Wrapped(format!($err).into()).bt())
};
($fmt:expr, $($arg:tt)*) => {
return Err($crate::Error::Wrapped(format!($fmt, $($arg)*).into()).bt())
};
}

View File

@ -146,12 +146,19 @@ struct CausalSelfAttention {
}
#[cfg(feature = "flash-attn")]
fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
q.custom_op3(k, v, candle_flash_attn::FlashHdim32Sm80)
fn flash_attn(softmax_scale: f32, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
q.custom_op3(
k,
v,
candle_flash_attn::FlashHdim32Sm80 {
softmax_scale,
causal: true,
},
)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
fn flash_attn(_: f32, _: &Tensor, _: &Tensor, _: &Tensor) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
@ -213,7 +220,8 @@ impl CausalSelfAttention {
let v = self.repeat_kv(v)?;
let y = if self.use_flash_attn {
flash_attn(&q, &k, &v)?
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
flash_attn(softmax_scale, &q, &k, &v)?
} else {
let in_dtype = q.dtype();
let q = q.to_dtype(DType::F32)?;

View File

@ -28,16 +28,22 @@ extern "C" void run_mha(
void *k_ptr,
void *v_ptr,
void *o_ptr,
void *softmax_lse_ptr,
uint32_t q_batch_stride,
uint32_t k_batch_stride,
uint32_t v_batch_stride,
uint32_t o_batch_stride,
uint32_t q_row_stride,
uint32_t k_row_stride,
uint32_t v_row_stride,
uint32_t o_row_stride,
uint32_t q_head_stride,
uint32_t k_head_stride,
uint32_t v_head_stride,
uint32_t o_head_stride,
uint32_t b,
uint32_t h,
@ -61,14 +67,24 @@ extern "C" void run_mha(
params.q_ptr = q_ptr;
params.k_ptr = k_ptr;
params.v_ptr = v_ptr;
params.o_ptr = o_ptr;
params.softmax_lse_ptr = softmax_lse_ptr;
// All stride are in elements, not bytes.
params.q_batch_stride = q_batch_stride;
params.k_batch_stride = k_batch_stride;
params.v_batch_stride = v_batch_stride;
params.o_batch_stride = o_batch_stride;
params.q_row_stride = q_row_stride;
params.k_row_stride = k_row_stride;
params.v_row_stride = v_row_stride;
params.o_row_stride = o_row_stride;
params.q_head_stride = q_head_stride;
params.k_head_stride = k_head_stride;
params.v_head_stride = v_head_stride;
params.o_ptr = o_ptr;
params.o_head_stride = o_head_stride;
// Set the dimensions.
params.b = b;
@ -87,6 +103,11 @@ extern "C" void run_mha(
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd_<cutlass::half_t, 32>(params, stream);
}

View File

@ -6,16 +6,22 @@ extern "C" {
k_ptr: *const c_void,
v_ptr: *const c_void,
o_ptr: *const c_void,
softmax_lse_ptr: *const c_void,
q_batch_stride: u32,
k_batch_stride: u32,
v_batch_stride: u32,
o_batch_stride: u32,
q_row_stride: u32,
k_row_stride: u32,
v_row_stride: u32,
o_row_stride: u32,
q_head_stride: u32,
k_head_stride: u32,
v_head_stride: u32,
o_head_stride: u32,
b: u32,
h: u32,

View File

@ -6,7 +6,14 @@ use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Error, Layout, Result, Shape};
use half::f16;
pub struct FlashHdim32Sm80;
pub struct FlashHdim32Sm80 {
pub softmax_scale: f32,
pub causal: bool,
}
fn round_multiple(x: usize, m: usize) -> usize {
(x + m - 1) / m * m
}
impl candle::CustomOp3 for FlashHdim32Sm80 {
fn name(&self) -> &'static str {
@ -28,28 +35,108 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
fn cuda_fwd(
&self,
q: &candle::CudaStorage,
_q_l: &Layout,
q_l: &Layout,
k: &candle::CudaStorage,
_k_l: &Layout,
k_l: &Layout,
v: &candle::CudaStorage,
_v_l: &Layout,
v_l: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
// https://github.com/Dao-AILab/flash-attention/blob/b252072409e69c25f2b9d473cc534e49b24decd2/csrc/flash_attn/flash_api.cpp#L187
let dev = q.device();
let out_shape = Shape::from(&[1]);
let out_shape = q_l.shape().clone();
let out_l = Layout::contiguous(&out_shape);
let q = q.as_cuda_slice::<f16>()?;
let k = k.as_cuda_slice::<f16>()?;
let v = v.as_cuda_slice::<f16>()?;
let q_stride = q_l.stride();
let k_stride = k_l.stride();
let v_stride = v_l.stride();
let o_stride = out_l.stride();
let q_rank = q_stride.len();
let k_rank = k_stride.len();
let v_rank = v_stride.len();
let o_rank = o_stride.len();
if q_rank != 4 || k_rank != 4 || v_rank != 4 {
candle::bail!(
"flash-attn expects input tensors of rank 4 (q: {q_rank}, k: {k_rank}, v: {v_rank}"
)
}
if q_stride[q_rank - 1] != 1 {
candle::bail!("the last dim of q must be contiguous {q_stride:?}")
}
if k_stride[k_rank - 1] != 1 {
candle::bail!("the last dim of k must be contiguous {k_stride:?}")
}
if v_stride[v_rank - 1] != 1 {
candle::bail!("the last dim of v must be contiguous {v_stride:?}")
}
let (b_sz, seqlen_q, num_heads, head_size_og) = q_l.shape().dims4()?;
let (_b_sz, seqlen_k, num_heads_k, _head_size_og) = k_l.shape().dims4()?;
let expected_kv = (b_sz, seqlen_k, num_heads_k, head_size_og);
if expected_kv != k_l.shape().dims4()? {
candle::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape())
}
if expected_kv != v_l.shape().dims4()? {
candle::bail!("shape mismatch q {:?} and v {:?}", q_l.shape(), v_l.shape())
}
if head_size_og > 256 {
candle::bail!("only supports head dimension at most 256 (got {head_size_og})")
}
if num_heads % num_heads_k != 0 {
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
}
let head_size = round_multiple(head_size_og, 8);
let head_size_rounded = round_multiple(head_size, 32);
let seqlen_q_rounded = round_multiple(seqlen_q, 128);
let seqlen_k_rounded = round_multiple(seqlen_k, 128);
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
let causal = if self.causal { 1 } else { 0 };
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
ffi::run_mha(
q_ptr, k_ptr, v_ptr, dst_ptr, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1.0, 1, 1,
1, 1, 1,
q_ptr,
k_ptr,
v_ptr,
dst_ptr,
softmax_lse_ptr,
/* q_batch_stride */ q_stride[0] as u32,
/* k_batch_stride */ k_stride[0] as u32,
/* v_batch_stride */ v_stride[0] as u32,
/* o_batch_stride */ o_stride[0] as u32,
/* q_row_stride */ q_stride[q_rank - 3] as u32,
/* k_row_stride */ k_stride[k_rank - 3] as u32,
/* v_row_stride */ v_stride[v_rank - 3] as u32,
/* o_row_stride */ o_stride[o_rank - 3] as u32,
/* q_head_stride */ q_stride[q_rank - 2] as u32,
/* k_head_stride */ k_stride[k_rank - 2] as u32,
/* v_head_stride */ v_stride[v_rank - 2] as u32,
/* o_head_stride */ o_stride[o_rank - 2] as u32,
/* b */ b_sz as u32,
/* h */ num_heads as u32,
/* h_k */ num_heads_k as u32,
/* d */ head_size as u32,
/* d_rounded */ head_size_rounded as u32,
/* softmax_scale*/ self.softmax_scale,
/* seqlen_q */ seqlen_q as u32,
/* seqlen_k */ seqlen_k as u32,
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_causal */ causal,
)
}