Add some flash attn test (#253)

* Add some flash-attn test.

* Add the cpu test.

* Fail when the head size is not a multiple of 8.

* Polish the flash attention test.
Laurent Mazare
2023-07-26 20:56:00 +01:00
committed by GitHub
parent ded197497c
commit 4f92420132
5 changed files with 125 additions and 14 deletions


@@ -6,7 +6,7 @@ use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, Layout, Result, Shape, Tensor};
 use half::f16;
 
-pub struct FlashHdim32Sm80 {
+pub struct FlashAttn {
     pub softmax_scale: f32,
     pub causal: bool,
 }
@@ -15,7 +15,7 @@ fn round_multiple(x: usize, m: usize) -> usize {
     (x + m - 1) / m * m
 }
 
-impl candle::CustomOp3 for FlashHdim32Sm80 {
+impl candle::CustomOp3 for FlashAttn {
     fn name(&self) -> &'static str {
         "flash-hdim32-sm80"
     }
@@ -87,6 +87,10 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
         if head_size_og > 256 {
             candle::bail!("only supports head dimension at most 256 (got {head_size_og})")
         }
+        if head_size_og % 8 != 0 {
+            // TODO: Handle head sizes that are not a multiple of 8 via some padding.
+            candle::bail!("only supports head sizes that are a multiple of 8 (got {head_size_og})")
+        }
         if num_heads % num_heads_k != 0 {
             candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
         }
@@ -145,6 +149,19 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
     }
 }
 
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q has to be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
 pub fn flash_attn(
     q: &Tensor,
     k: &Tensor,
@@ -152,12 +169,9 @@ pub fn flash_attn(
     softmax_scale: f32,
     causal: bool,
 ) -> Result<Tensor> {
-    q.custom_op3(
-        k,
-        v,
-        FlashHdim32Sm80 {
-            softmax_scale,
-            causal,
-        },
-    )
+    let op = FlashAttn {
+        softmax_scale,
+        causal,
+    };
+    q.custom_op3(k, v, op)
 }
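
For context, here is a minimal usage sketch of the `flash_attn` function documented above, together with a plain `softmax(Q @ K^T . softmax_scale) @ V` computation of the kind a reference test could compare against. This is an illustration, not the test added by this commit: it assumes the library is consumed as `candle_flash_attn`, uses the same `candle` imports as the file above, assumes `candle_nn::ops::softmax` is available for the reference path, and requires a CUDA device; the shapes follow the `(batch, seq_len, num_heads, head_size)` layout from the doc comment.

use candle::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // Shapes are (batch, seq_len, num_heads, head_size); head_size must be a
    // multiple of 8 and at most 256 to pass the checks added above.
    let (b, seq_len, num_heads, head_size) = (2usize, 16usize, 8usize, 32usize);
    let q = Tensor::randn(0f32, 1f32, (b, seq_len, num_heads, head_size), &device)?
        .to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (b, seq_len, num_heads, head_size), &device)?
        .to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (b, seq_len, num_heads, head_size), &device)?
        .to_dtype(DType::F16)?;
    let softmax_scale = 1.0 / (head_size as f32).sqrt();

    // Flash-attention path: the output has the same (batch, seq_len_q, num_heads_q, head_size) shape.
    let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, /* causal */ false)?;

    // Reference path in f32: move heads next to the batch dimension, then compute
    // softmax(Q @ K^T * softmax_scale) @ V with ordinary candle ops.
    let qr = q.to_dtype(DType::F32)?.transpose(1, 2)?.contiguous()?;
    let kr = k.to_dtype(DType::F32)?.transpose(1, 2)?.contiguous()?;
    let vr = v.to_dtype(DType::F32)?.transpose(1, 2)?.contiguous()?;
    let att = qr.matmul(&kr.t()?)?.affine(softmax_scale as f64, 0.0)?;
    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    let reference = att.matmul(&vr)?.transpose(1, 2)?.to_dtype(DType::F16)?;

    // A test would now compare `out` against `reference` within an f16-friendly tolerance.
    assert_eq!(out.dims(), reference.dims());
    Ok(())
}

Running the reference in f32 before converting back keeps a tolerance-based comparison meaningful given the f16 inputs.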
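
The "fail when the head size is not a multiple of 8" behaviour introduced above can likewise be exercised from the caller's side. A hypothetical negative test, again assuming the `candle_flash_attn` crate name and a CUDA device:

use candle::{DType, Device, Result, Tensor};

fn head_size_must_be_a_multiple_of_8() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // head_size = 20 is below the 256 limit but not a multiple of 8, so the new check should reject it.
    let q = Tensor::randn(0f32, 1f32, (1usize, 4usize, 2usize, 20usize), &device)?
        .to_dtype(DType::F16)?;
    let res = candle_flash_attn::flash_attn(&q, &q, &q, 1.0, false);
    assert!(res.is_err(), "head sizes that are not a multiple of 8 should be rejected");
    Ok(())
}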