Mirror of https://github.com/huggingface/candle.git
Add some flash attn test (#253)

* Add some flash-attn test.
* Add the cpu test.
* Fail when the head is not a multiple of 8.
* Polish the flash attention test.
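The tests added by this commit are not shown in the diff below. As a rough illustration of the kind of check they perform, here is a hedged sketch that compares `flash_attn` against a plain f32 scaled-dot-product reference, the `softmax(Q @ K^T * softmax_scale) @ V` computation described in the new doc comment. The function and test names, shapes, and tolerance are invented for illustration; the sketch assumes candle's public Tensor helpers (`matmul`, `t`, `broadcast_sub`/`broadcast_div`, `max_keepdim`, `sum_keepdim`) and a CUDA device with the flash-attn kernels built, and it is not the actual test code from the commit.

use candle::{DType, Device, Result, Tensor, D};
use candle_flash_attn::flash_attn;

/// f32 reference for `softmax(Q @ K^T * softmax_scale) @ V`.
/// Here q, k and v are laid out as (batch, num_heads, seq_len, head_size) so
/// that matmul contracts over the last two dimensions.
fn attention_reference(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    let scores = (q.matmul(&k.t()?.contiguous()?)? * softmax_scale as f64)?;
    // Numerically stable softmax over the last dimension.
    let max = scores.max_keepdim(D::Minus1)?;
    let exp = scores.broadcast_sub(&max)?.exp()?;
    let att = exp.broadcast_div(&exp.sum_keepdim(D::Minus1)?)?;
    att.matmul(v)
}

#[test]
fn flash_attn_matches_reference() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // flash_attn takes (batch, seq_len, num_heads, head_size); head_size must be a multiple of 8.
    let q = Tensor::randn(0f32, 1f32, (1, 7, 4, 64), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (1, 7, 4, 64), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (1, 7, 4, 64), &device)?.to_dtype(DType::F16)?;
    let scale = 1.0 / (64f32).sqrt();

    let out = flash_attn(&q, &k, &v, scale, false)?;

    // The reference path runs in f32 on (batch, num_heads, seq_len, head_size).
    let to_ref = |t: &Tensor| t.transpose(1, 2)?.to_dtype(DType::F32)?.contiguous();
    let expected = attention_reference(&to_ref(&q)?, &to_ref(&k)?, &to_ref(&v)?, scale)?
        .transpose(1, 2)?;

    // The kernel accumulates in f16, so compare with a loose tolerance.
    let diff = (out.to_dtype(DType::F32)? - expected)?
        .abs()?
        .flatten_all()?
        .max(0)?
        .to_scalar::<f32>()?;
    assert!(diff < 1e-2);
    Ok(())
}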
@@ -6,7 +6,7 @@ use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, Layout, Result, Shape, Tensor};
 use half::f16;
 
-pub struct FlashHdim32Sm80 {
+pub struct FlashAttn {
     pub softmax_scale: f32,
     pub causal: bool,
 }
@@ -15,7 +15,7 @@ fn round_multiple(x: usize, m: usize) -> usize {
     (x + m - 1) / m * m
 }
 
-impl candle::CustomOp3 for FlashHdim32Sm80 {
+impl candle::CustomOp3 for FlashAttn {
     fn name(&self) -> &'static str {
         "flash-hdim32-sm80"
     }
@@ -87,6 +87,10 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
         if head_size_og > 256 {
             candle::bail!("only supports head dimension at most 256 (got {head_size_og})")
         }
+        if head_size_og % 8 != 0 {
+            // TODO: Handle head sizes that are not a multiple of 8 via some padding.
+            candle::bail!("only supports head sizes that are a multiple of 8 (got {head_size_og})")
+        }
         if num_heads % num_heads_k != 0 {
             candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
         }
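The new check above rejects head sizes that are not a multiple of 8 instead of padding them internally (see the TODO). Until that lands, a caller can work around the restriction by zero-padding the head dimension and slicing the result back. A hedged sketch of that workaround, assuming candle's `pad_with_zeros` and `narrow` helpers; the wrapper and helper names are made up here and this is not part of the commit:

use candle::{Result, Tensor};
use candle_flash_attn::flash_attn;

// Zero-pad the head (last) dimension of a (batch, seq_len, num_heads, head_size)
// tensor up to `target`.
fn pad_head(t: &Tensor, target: usize) -> Result<Tensor> {
    let head_size = t.dim(3)?;
    t.pad_with_zeros(3, 0, target - head_size)
}

// Hypothetical caller-side wrapper: pad q/k/v to the next multiple of 8, run the
// kernel, then slice the padding off again. The zero components of q and k add
// nothing to Q @ K^T, and the padded value columns are dropped by `narrow`, so
// the result matches the unpadded attention.
fn flash_attn_any_head_size(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    let head_size = q.dim(3)?;
    let padded = (head_size + 7) / 8 * 8; // same rounding as round_multiple(head_size, 8)
    if padded == head_size {
        return flash_attn(q, k, v, softmax_scale, causal);
    }
    let out = flash_attn(
        &pad_head(q, padded)?,
        &pad_head(k, padded)?,
        &pad_head(v, padded)?,
        softmax_scale,
        causal,
    )?;
    out.narrow(3, 0, head_size)
}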
@@ -145,6 +149,19 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
     }
 }
 
+/// Flash-attention v2 layer using flash-attention.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
 pub fn flash_attn(
     q: &Tensor,
     k: &Tensor,
@@ -152,12 +169,9 @@ pub fn flash_attn(
     softmax_scale: f32,
     causal: bool,
 ) -> Result<Tensor> {
-    q.custom_op3(
-        k,
-        v,
-        FlashHdim32Sm80 {
-            softmax_scale,
-            causal,
-        },
-    )
+    let op = FlashAttn {
+        softmax_scale,
+        causal,
+    };
+    q.custom_op3(k, v, op)
 }
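From the caller's side, the rename from FlashHdim32Sm80 to FlashAttn changes nothing: the public entry point is still `flash_attn`, as documented in the new doc comment above. A minimal usage sketch, assuming a CUDA device and f16 inputs; the shapes and scale are illustrative only:

use candle::{DType, Device, Result, Tensor};
use candle_flash_attn::flash_attn;

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // (batch, seq_len, num_heads, head_size); head_size must be a multiple of 8 and at most 256.
    let q = Tensor::randn(0f32, 1f32, (2, 128, 8, 64), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (2, 128, 8, 64), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (2, 128, 8, 64), &device)?.to_dtype(DType::F16)?;
    // Usual 1/sqrt(head_size) scaling; causal = true masks out future positions.
    let out = flash_attn(&q, &k, &v, 1. / (64f32).sqrt(), true)?;
    assert_eq!(out.dims(), &[2, 128, 8, 64]);
    Ok(())
}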