chore: update flash attention kernels (#1518)

* chore: update flash attention kernels

* fmt

* remove unused kernels

* force f32

* correct stride
This commit is contained in:
OlivierDehaene
2024-01-05 18:28:55 +01:00
committed by GitHub
parent 3a7304cb0d
commit 8d1a57c9a0
28 changed files with 1087 additions and 466 deletions

View File

@ -3,12 +3,14 @@ mod ffi;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
pub struct FlashAttn {
pub softmax_scale: f32,
pub causal: bool,
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
}
fn round_multiple(x: usize, m: usize) -> usize {
@ -85,6 +87,51 @@ impl FlashAttn {
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
}
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
"DType mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes.dtype(),
DType::F32
);
}
let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
if num_heads != alibi_slopes_layout.shape().dims1()? {
candle::bail!(
"shape mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes_layout.shape(),
(num_heads)
);
}
let alibi_slopes = match &*alibi_slopes {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("alibi_slopes must be a cuda tensor"),
};
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
*alibi_slopes.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};
// if window_size_left > self.max_seqlen_k or None => -1
let mut window_size_left = self
.window_size_left
.filter(|v| v <= &seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
// if window_size_right > self.max_seqlen_k or None => -1
let mut window_size_right = self
.window_size_right
.filter(|v| v <= &seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
let head_size = round_multiple(head_size_og, 8);
let head_size_rounded = round_multiple(head_size, 32);
let seqlen_q_rounded = round_multiple(seqlen_q, 128);
@ -94,9 +141,22 @@ impl FlashAttn {
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
let causal = if self.causal { 1 } else { 0 };
let is_bf16 = if is_bf16 { 1 } else { 0 };
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
let is_causal = if window_size_left < 0 && window_size_right == 0 {
1
} else {
0
};
if window_size_left < 0 && window_size_right >= 0 {
window_size_left = seqlen_k as i32;
}
if window_size_left >= 0 && window_size_right < 0 {
window_size_right = seqlen_k as i32;
}
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@ -109,12 +169,14 @@ impl FlashAttn {
v_ptr,
dst_ptr,
softmax_lse_ptr,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ std::ptr::null(),
/* cu_seqlens_k_ptr */ std::ptr::null(),
/* q_batch_stride */ q_stride[0] as u32,
/* k_batch_stride */ k_stride[0] as u32,
/* v_batch_stride */ v_stride[0] as u32,
/* o_batch_stride */ o_stride[0] as u32,
/* alibi_slopes_batch_stride */ 0,
/* q_row_stride */ q_stride[q_rank - 3] as u32,
/* k_row_stride */ k_stride[k_rank - 3] as u32,
/* v_row_stride */ v_stride[v_rank - 3] as u32,
@ -133,8 +195,10 @@ impl FlashAttn {
/* seqlen_k */ seqlen_k as u32,
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_causal */ causal,
/* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
)
}
@ -197,20 +261,137 @@ pub fn flash_attn(
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttn {
softmax_scale,
causal,
alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttn {
softmax_scale,
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
struct FlashAttnVarLen {
softmax_scale: f32,
causal: bool,
max_seqlen_q: usize,
max_seqlen_k: usize,
seqlens_q: Tensor,
seqlens_k: Tensor,
pub softmax_scale: f32,
pub max_seqlen_q: usize,
pub max_seqlen_k: usize,
pub seqlens_q: Tensor,
pub seqlens_k: Tensor,
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
}
impl FlashAttnVarLen {
@ -311,7 +492,54 @@ impl FlashAttnVarLen {
if nseqlens_k != nseqlens_q {
candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}")
}
let batch_size = nseqlens_q - 1;
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
"DType mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes.dtype(),
DType::F32
);
}
let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
if num_heads != alibi_slopes_layout.shape().dims1()? {
candle::bail!(
"shape mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes_layout.shape(),
(num_heads)
);
}
let alibi_slopes = match &*alibi_slopes {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("alibi_slopes must be a cuda tensor"),
};
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
*alibi_slopes.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};
// if window_size_left > self.max_seqlen_k or None => -1
let mut window_size_left = self
.window_size_left
.filter(|v| v <= &self.max_seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
// if window_size_right > self.max_seqlen_k or None => -1
let mut window_size_right = self
.window_size_right
.filter(|v| v <= &self.max_seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
let head_size = round_multiple(head_size_og, 8);
let head_size_rounded = round_multiple(head_size, 32);
let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128);
@ -323,9 +551,22 @@ impl FlashAttnVarLen {
.alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
.w()?;
let causal = if self.causal { 1 } else { 0 };
let is_bf16 = if is_bf16 { 1 } else { 0 };
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
let is_causal = if window_size_left < 0 && window_size_right == 0 {
1
} else {
0
};
if window_size_left < 0 && window_size_right >= 0 {
window_size_left = self.max_seqlen_k as i32;
}
if window_size_left >= 0 && window_size_right < 0 {
window_size_right = self.max_seqlen_k as i32;
}
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@ -340,12 +581,14 @@ impl FlashAttnVarLen {
v_ptr,
dst_ptr,
softmax_lse_ptr,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ seqlens_q_ptr,
/* cu_seqlens_k_ptr */ seqlens_k_ptr,
/* q_batch_stride */ 0,
/* k_batch_stride */ 0,
/* v_batch_stride */ 0,
/* o_batch_stride */ 0,
/* alibi_slopes_batch_stride */ 0,
/* q_row_stride */ q_stride[q_rank - 3] as u32,
/* k_row_stride */ k_stride[k_rank - 3] as u32,
/* v_row_stride */ v_stride[v_rank - 3] as u32,
@ -364,8 +607,10 @@ impl FlashAttnVarLen {
/* seqlen_k */ self.max_seqlen_k as u32,
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_causal */ causal,
/* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
)
}
@ -440,13 +685,176 @@ pub fn flash_attn_varlen(
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttnVarLen {
softmax_scale,
causal,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
pub fn flash_attn_varlen_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
pub fn flash_attn_varlen_alibi(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
pub fn flash_attn_varlen_alibi_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}