Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00.
Flash attention without padding (varlen). (#281)
* Expose the seqlen variable for flash-attn without padding.
* Fix the batched call.
* Adapt for the varlen variant.
* No need to set the batch strides when in varlen mode.
* Add a test (disabled at the moment).
* Get the test to work properly.
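For context, the new `flash_attn_varlen` entry point added below works on q/k/v tensors where all sequences are packed along the first dimension and indexed through cumulative sequence lengths. A minimal usage sketch, assuming a CUDA device and made-up shapes and lengths (nothing here is taken verbatim from the commit):

    use candle::{DType, Device, Result, Tensor};

    fn varlen_example() -> Result<Tensor> {
        let device = Device::new_cuda(0)?;
        // Two sequences of lengths 3 and 5 packed into (total_q = 8, num_heads = 2, head_size = 64).
        let q = Tensor::randn(0f32, 1., (8, 2, 64), &device)?.to_dtype(DType::F16)?;
        let k = Tensor::randn(0f32, 1., (8, 2, 64), &device)?.to_dtype(DType::F16)?;
        let v = Tensor::randn(0f32, 1., (8, 2, 64), &device)?.to_dtype(DType::F16)?;
        // Cumulative offsets, batch_size + 1 entries: 0, 3, 3 + 5.
        let seqlens = Tensor::new(&[0u32, 3, 8], &device)?;
        candle_flash_attn::flash_attn_varlen(
            &q, &k, &v, &seqlens, &seqlens, /* max_seqlen_q */ 5, /* max_seqlen_k */ 5,
            /* softmax_scale */ 0.125, /* causal */ false,
        )
    }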
@@ -22,6 +22,9 @@ extern "C" void run_mha(
     void *o_ptr,
     void *softmax_lse_ptr,
 
+    int32_t *cu_seqlens_q_ptr,
+    int32_t *cu_seqlens_k_ptr,
+
     uint32_t q_batch_stride,
     uint32_t k_batch_stride,
     uint32_t v_batch_stride,
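The new `cu_seqlens_q_ptr` / `cu_seqlens_k_ptr` arguments carry the cumulative (prefix-sum) sequence lengths described in the Rust doc comment further down: `batch_size + 1` offsets into the packed tensors. A small sketch of the relationship between per-sequence lengths and these offsets, with invented lengths:

    fn main() {
        // Per-sequence lengths -> cumulative offsets used to index the packed q/k/v tensors.
        let lens = [3i32, 5, 2];
        let mut cu_seqlens = vec![0i32];
        for len in lens {
            cu_seqlens.push(cu_seqlens.last().unwrap() + len);
        }
        assert_eq!(cu_seqlens, vec![0, 3, 8, 10]);
        // Sequence i occupies rows cu_seqlens[i]..cu_seqlens[i + 1]; the last entry is the total token count.
    }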
@@ -100,9 +103,9 @@ extern "C" void run_mha(
     params.rp_dropout = 1.f / params.p_dropout;
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
     params.is_bf16 = 0;
-    params.cu_seqlens_q = nullptr;
-    params.cu_seqlens_k = nullptr;
-    params.p_ptr = nullptr;
+    params.cu_seqlens_q = cu_seqlens_q_ptr;
+    params.cu_seqlens_k = cu_seqlens_k_ptr;
+    params.p_ptr = nullptr; // used for `return_softmax`.
 
     cudaStream_t stream = 0; // Use the default stream.
     run_mha_fwd(params, stream);
@@ -7,6 +7,8 @@ extern "C" {
        v_ptr: *const c_void,
        o_ptr: *const c_void,
        softmax_lse_ptr: *const c_void,
+        cu_seqlens_q_ptr: *const i32,
+        cu_seqlens_k_ptr: *const i32,
 
        q_batch_stride: u32,
        k_batch_stride: u32,
@@ -49,6 +49,9 @@ impl candle::CustomOp3 for FlashAttn {
        let q = q.as_cuda_slice::<f16>()?;
        let k = k.as_cuda_slice::<f16>()?;
        let v = v.as_cuda_slice::<f16>()?;
+        let q = q.slice(q_l.start_offset()..);
+        let k = k.slice(k_l.start_offset()..);
+        let v = v.slice(v_l.start_offset()..);
 
        let q_stride = q_l.stride();
        let k_stride = k_l.stride();
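The added `slice(.._l.start_offset()..)` calls make the batched path take its device pointers from where the layout actually starts, so views whose data does not begin at element 0 of the underlying storage (e.g. views produced by `narrow`) are handled correctly. A rough illustration of when a non-zero start offset appears, with an invented tensor:

    use candle::{Device, Result, Tensor};

    fn offset_view() -> Result<()> {
        let dev = Device::Cpu;
        let t = Tensor::arange(0f32, 24f32, &dev)?.reshape((6, 4))?;
        // `narrow` returns a view sharing storage with `t`; its data begins
        // 2 * 4 = 8 elements into the buffer, so a kernel that takes raw device
        // pointers must first advance them by that start offset.
        let _view = t.narrow(0, 2, 3)?;
        Ok(())
    }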
@@ -118,6 +121,8 @@ impl candle::CustomOp3 for FlashAttn {
                v_ptr,
                dst_ptr,
                softmax_lse_ptr,
+                /* cu_seqlens_q_ptr */ std::ptr::null(),
+                /* cu_seqlens_k_ptr */ std::ptr::null(),
                /* q_batch_stride */ q_stride[0] as u32,
                /* k_batch_stride */ k_stride[0] as u32,
                /* v_batch_stride */ v_stride[0] as u32,
@@ -149,7 +154,7 @@ impl candle::CustomOp3 for FlashAttn {
     }
 }
 
-/// Flash-attention v2 layer using flash-attention.
+/// Flash-attention v2 layer.
 ///
 /// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
 /// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
@@ -175,3 +180,227 @@ pub fn flash_attn(
     };
     q.custom_op3(k, v, op)
 }
+
+struct FlashAttnVarLen {
+    softmax_scale: f32,
+    causal: bool,
+    max_seqlen_q: usize,
+    max_seqlen_k: usize,
+    seqlens_q: Tensor,
+    seqlens_k: Tensor,
+}
+
+impl candle::CustomOp3 for FlashAttnVarLen {
+    fn name(&self) -> &'static str {
+        "flash-hdim32-sm80"
+    }
+
+    fn cpu_fwd(
+        &self,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+        _: &CpuStorage,
+        _: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        candle::bail!("no cpu support for flash-attn")
+    }
+
+    fn cuda_fwd(
+        &self,
+        q: &candle::CudaStorage,
+        q_l: &Layout,
+        k: &candle::CudaStorage,
+        k_l: &Layout,
+        v: &candle::CudaStorage,
+        v_l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        // https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327
+        let dev = q.device();
+        let out_shape = q_l.shape().clone();
+        let out_l = Layout::contiguous(&out_shape);
+
+        let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
+        let seqlens_q = match &*seqlens_q {
+            candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
+            candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+        };
+        let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
+            Some((o1, o2)) => seqlens_q.slice(o1..o2),
+            None => candle::bail!("seqlens_q has to be contiguous"),
+        };
+
+        let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
+        let seqlens_k = match &*seqlens_k {
+            candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
+            candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+        };
+        let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
+            Some((o1, o2)) => seqlens_k.slice(o1..o2),
+            None => candle::bail!("seqlens_k has to be contiguous"),
+        };
+
+        let q = q.as_cuda_slice::<f16>()?;
+        let k = k.as_cuda_slice::<f16>()?;
+        let v = v.as_cuda_slice::<f16>()?;
+        let q = q.slice(q_l.start_offset()..);
+        let k = k.slice(k_l.start_offset()..);
+        let v = v.slice(v_l.start_offset()..);
+
+        let q_stride = q_l.stride();
+        let k_stride = k_l.stride();
+        let v_stride = v_l.stride();
+        let o_stride = out_l.stride();
+
+        let q_rank = q_stride.len();
+        let k_rank = k_stride.len();
+        let v_rank = v_stride.len();
+        let o_rank = o_stride.len();
+
+        if q_rank != 3 || k_rank != 3 || v_rank != 3 {
+            candle::bail!(
+                "flash-attn-varlen expects input tensors of rank 3 (q: {q_rank}, k: {k_rank}, v: {v_rank}"
+            )
+        }
+        if q_stride[q_rank - 1] != 1 {
+            candle::bail!("the last dim of q must be contiguous {q_stride:?}")
+        }
+        if k_stride[k_rank - 1] != 1 {
+            candle::bail!("the last dim of k must be contiguous {k_stride:?}")
+        }
+        if v_stride[v_rank - 1] != 1 {
+            candle::bail!("the last dim of v must be contiguous {v_stride:?}")
+        }
+
+        let (_total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
+        let (total_k, num_heads_k, _head_size_og) = k_l.shape().dims3()?;
+        let expected_kv = (total_k, num_heads_k, head_size_og);
+        if expected_kv != k_l.shape().dims3()? {
+            candle::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape())
+        }
+        if expected_kv != v_l.shape().dims3()? {
+            candle::bail!("shape mismatch q {:?} and v {:?}", q_l.shape(), v_l.shape())
+        }
+        if head_size_og > 256 {
+            candle::bail!("only supports head dimension at most 256 (got {head_size_og})")
+        }
+        if head_size_og % 8 != 0 {
+            // TODO: Handle head sizes that are not a multiple of 8 via some padding.
+            candle::bail!("only supports head sizes that are a multiple of 8 (got {head_size_og})")
+        }
+        if num_heads % num_heads_k != 0 {
+            candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
+        }
+
+        let nseqlens_q = seqlens_q_layout.shape().dims1()?;
+        if nseqlens_q < 2 {
+            candle::bail!("seqlens_q should have a len >= 2 {nseqlens_q}")
+        }
+        let nseqlens_k = seqlens_k_layout.shape().dims1()?;
+        if nseqlens_k != nseqlens_q {
+            candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}")
+        }
+        let batch_size = nseqlens_q - 1;
+        let head_size = round_multiple(head_size_og, 8);
+        let head_size_rounded = round_multiple(head_size, 32);
+        let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128);
+        let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128);
+
+        let elem_count = out_shape.elem_count();
+        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+        let softmax_lse = dev
+            .alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
+            .w()?;
+
+        let causal = if self.causal { 1 } else { 0 };
+
+        unsafe {
+            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
+            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
+            let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
+            let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
+            let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
+            let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
+            let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
+            ffi::run_mha(
+                q_ptr,
+                k_ptr,
+                v_ptr,
+                dst_ptr,
+                softmax_lse_ptr,
+                /* cu_seqlens_q_ptr */ seqlens_q_ptr,
+                /* cu_seqlens_k_ptr */ seqlens_k_ptr,
+                /* q_batch_stride */ 0,
+                /* k_batch_stride */ 0,
+                /* v_batch_stride */ 0,
+                /* o_batch_stride */ 0,
+                /* q_row_stride */ q_stride[q_rank - 3] as u32,
+                /* k_row_stride */ k_stride[k_rank - 3] as u32,
+                /* v_row_stride */ v_stride[v_rank - 3] as u32,
+                /* o_row_stride */ o_stride[o_rank - 3] as u32,
+                /* q_head_stride */ q_stride[q_rank - 2] as u32,
+                /* k_head_stride */ k_stride[k_rank - 2] as u32,
+                /* v_head_stride */ v_stride[v_rank - 2] as u32,
+                /* o_head_stride */ o_stride[o_rank - 2] as u32,
+                /* b */ batch_size as u32,
+                /* h */ num_heads as u32,
+                /* h_k */ num_heads_k as u32,
+                /* d */ head_size as u32,
+                /* d_rounded */ head_size_rounded as u32,
+                /* softmax_scale*/ self.softmax_scale,
+                /* seqlen_q */ self.max_seqlen_q as u32,
+                /* seqlen_k */ self.max_seqlen_k as u32,
+                /* seqlen_q_rounded */ seqlen_q_rounded as u32,
+                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
+                /* is_causal */ causal,
+            )
+        }
+
+        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev.clone());
+        Ok((dst, out_shape))
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+pub fn flash_attn_varlen(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    seqlens_q: &Tensor,
+    seqlens_k: &Tensor,
+    max_seqlen_q: usize,
+    max_seqlen_k: usize,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    let op = FlashAttnVarLen {
+        softmax_scale,
+        causal,
+        max_seqlen_q,
+        max_seqlen_k,
+        seqlens_q: seqlens_q.clone(),
+        seqlens_k: seqlens_k.clone(),
+    };
+    q.custom_op3(k, v, op)
+}
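One caveat before the test below: the FFI declares `cu_seqlens_q_ptr` / `cu_seqlens_k_ptr` as `*const i32`, but the varlen op currently reads the seqlens tensors as `u32` device slices (hence the `// Should be i32!` remarks above), so callers pass a `u32` tensor of non-negative offsets. A hypothetical construction, with made-up lengths:

    use candle::{Device, Result, Tensor};

    fn seqlens_for_three_sequences(device: &Device) -> Result<Tensor> {
        // batch_size + 1 offsets for packed sequences of lengths 3, 5 and 2,
        // created as u32 to match what the op expects today.
        Tensor::new(&[0u32, 3, 8, 10], device)
    }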
@@ -88,3 +88,48 @@ fn flash_attn_acausal() -> Result<()> {
     assert!(diff.to_vec0::<f32>()?.abs() < 1e-5);
     Ok(())
 }
+
+#[test]
+fn flash_attn_varlen() -> Result<()> {
+    let device = Device::new_cuda(0)?;
+    let q = Tensor::arange(0u32, 48, &device)?
+        .to_dtype(DType::F16)?
+        .reshape((3, 2, 8))?;
+    let k = (&q / 40.)?;
+    let v = (&q / 50.)?;
+    let q = (&q / 30.)?;
+
+    let seqlens_q = Tensor::new(&[0u32, 2u32], &device)?;
+    let seqlens_k = Tensor::new(&[0u32, 2u32], &device)?;
+
+    let ys = {
+        let q = q.transpose(0, 1)?;
+        let k = k.transpose(0, 1)?;
+        let v = v.transpose(0, 1)?;
+        candle_flash_attn::flash_attn_varlen(
+            &q, &k, &v, &seqlens_q, &seqlens_k, 32, 32, 0.5, false,
+        )?
+        .transpose(0, 1)?
+    };
+    let ys = ys.to_dtype(DType::F32)?;
+
+    assert_eq!(ys.dims(), &[3, 2, 8]);
+    assert_eq!(
+        to_vec3_round(ys, 4)?,
+        &[
+            [
+                [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238],
+                [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322]
+            ],
+            [
+                [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605],
+                [0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684]
+            ],
+            [
+                [0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955],
+                [0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023]
+            ]
+        ]
+    );
+    Ok(())
+}