Add some flash attn test (#253)

* Add some flash-attn test.

* Add the cpu test.

* Fail when the head is not a multiple of 8.

* Polish the flash attention test.
This commit is contained in:
Laurent Mazare
2023-07-26 20:56:00 +01:00
committed by GitHub
parent ded197497c
commit 4f92420132
5 changed files with 125 additions and 14 deletions

View File

@ -18,3 +18,6 @@ half = { version = "2.3.1", features = ["num-traits"] }
anyhow = { version = "1", features = ["backtrace"] }
num_cpus = "1.15.0"
rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }

View File

@ -6,7 +6,7 @@ use rayon::prelude::*;
use std::path::PathBuf;
use std::str::FromStr;
const KERNEL_FILES: [&'static str; 9] = [
const KERNEL_FILES: [&str; 9] = [
"flash_api.cu",
"flash_fwd_hdim128_fp16_sm80.cu",
"flash_fwd_hdim160_fp16_sm80.cu",
@ -52,7 +52,11 @@ fn main() -> Result<()> {
println!("cargo:rerun-if-changed=kernels/static_switch.h");
let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
Err(_) => out_dir.clone(),
Err(_) =>
{
#[allow(clippy::redundant_clone)]
out_dir.clone()
}
Ok(build_dir) => PathBuf::from(build_dir),
};
set_cuda_include_dir()?;

View File

@ -6,7 +6,7 @@ use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use half::f16;
pub struct FlashHdim32Sm80 {
pub struct FlashAttn {
pub softmax_scale: f32,
pub causal: bool,
}
@ -15,7 +15,7 @@ fn round_multiple(x: usize, m: usize) -> usize {
(x + m - 1) / m * m
}
impl candle::CustomOp3 for FlashHdim32Sm80 {
impl candle::CustomOp3 for FlashAttn {
fn name(&self) -> &'static str {
"flash-hdim32-sm80"
}
@ -87,6 +87,10 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
if head_size_og > 256 {
candle::bail!("only supports head dimension at most 256 (got {head_size_og})")
}
if head_size_og % 8 != 0 {
// TODO: Handle head sizes that are not a multiple of 8 via some padding.
candle::bail!("only supports head sizes that are a multiple of 8 (got {head_size_og})")
}
if num_heads % num_heads_k != 0 {
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
}
@ -145,6 +149,19 @@ impl candle::CustomOp3 for FlashHdim32Sm80 {
}
}
/// Flash-attention v2 layer using flash-attention.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn(
q: &Tensor,
k: &Tensor,
@ -152,12 +169,9 @@ pub fn flash_attn(
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
q.custom_op3(
k,
v,
FlashHdim32Sm80 {
softmax_scale,
causal,
},
)
let op = FlashAttn {
softmax_scale,
causal,
};
q.custom_op3(k, v, op)
}

View File

@ -0,0 +1,90 @@
use anyhow::Result;
use candle::{DType, Device, IndexOp, Tensor, D};
fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
let b = 10f32.powi(digits);
let t = t.to_vec3::<f32>()?;
let t = t
.iter()
.map(|t| {
t.iter()
.map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
.collect()
})
.collect();
Ok(t)
}
fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
let in_dtype = q.dtype();
let q = q.to_dtype(DType::F32)?;
let k = k.to_dtype(DType::F32)?;
let v = v.to_dtype(DType::F32)?;
let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
let att = att.softmax(D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
Ok(output)
}
#[test]
fn flash_attn_acausal() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::arange(0u32, 48, &device)?
.to_dtype(DType::F16)?
.reshape((1, 3, 2, 8))?;
let k = (&q / 40.)?;
let v = (&q / 50.)?;
let q = (&q / 30.)?;
let ys1 = fa_acausal(&q, &k, &v, 0.5)?;
let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
let ys2 = {
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
candle_flash_attn::flash_attn(&q, &k, &v, 0.5, false)?.transpose(1, 2)?
};
let ys2 = ys2.i(0)?.to_dtype(DType::F32)?;
let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?;
assert_eq!(ys1.dims(), &[3, 2, 8]);
assert_eq!(
to_vec3_round(ys1, 4)?,
&[
[
[0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238],
[0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322]
],
[
[0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605],
[0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684]
],
[
[0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955],
[0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023]
]
]
);
assert_eq!(ys2.dims(), &[3, 2, 8]);
assert_eq!(
to_vec3_round(ys2, 4)?,
&[
[
[0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238],
[0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322]
],
[
[0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605],
[0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684]
],
[
[0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955],
[0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023]
]
]
);
assert!(diff.to_vec0::<f32>()?.abs() < 1e-5);
Ok(())
}

View File

@ -121,8 +121,8 @@ UNARY_OP(float, usin_f32, sing(x))
UNARY_OP(double, usin_f64, sing(x))
UNARY_OP(float, ucos_f32, cosg(x))
UNARY_OP(double, ucos_f64, cosg(x))
UNARY_OP(float, uabsg_f32, absg(x))
UNARY_OP(double, uabsg_f64, absg(x))
UNARY_OP(float, uabs_f32, absg(x))
UNARY_OP(double, uabs_f64, absg(x))
UNARY_OP(float, usqr_f32, x*x)
UNARY_OP(double, usqr_f64, x*x)
UNARY_OP(float, usqrt_f32, sqrtg(x))