Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)

* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace

* softcap is working, including test and API.

* make softcap test case better

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
Author: Michael Feil
Committed by: GitHub
Date: 2024-12-31 09:41:23 +01:00
Commit: a594ef669c (parent 71cd6d5533)
4 changed files with 182 additions and 3 deletions
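For context on the feature: soft-capping squashes the raw attention logits smoothly before the softmax. With $c$ the softcap value, the reference implementation in the diff below computes

$$\mathrm{att} = \operatorname{softmax}\!\left(c \cdot \tanh\!\left(\frac{Q K^\top}{c}\right)\right) V,$$

which bounds every pre-softmax logit to the open interval $(-c, c)$ while leaving small logits nearly unchanged.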


@@ -27,6 +27,20 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
    Ok(output)
}
fn fa_acausal_softcap(q: &Tensor, k: &Tensor, v: &Tensor, softcap: f32) -> Result<Tensor> {
    let in_dtype = q.dtype();
    let q = q.to_dtype(DType::F32)?;
    let k = k.to_dtype(DType::F32)?;
    let v = v.to_dtype(DType::F32)?;
    // Soft-cap the raw attention logits: softcap * tanh(logits / softcap).
    let att = q.matmul(&k.t()?)?;
    let att = (softcap as f64 * ((att / softcap as f64)?.tanh())?)?;
    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    // Convert to contiguous as matmul doesn't support strided vs for now.
    let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
    Ok(output)
}
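To see the capping behavior concretely, here is a minimal standalone sketch of the same transform applied to plain `f32` values rather than tensors; `softcap_scalar` is a hypothetical helper for illustration, not part of the diff:

```rust
// Minimal sketch of the softcap transform on scalars, mirroring the
// reference above: capped = softcap * tanh(x / softcap).
fn softcap_scalar(x: f32, softcap: f32) -> f32 {
    softcap * (x / softcap).tanh()
}

fn main() {
    let softcap = 5.0f32;
    for x in [0.5f32, 2.0, 5.0, 20.0, 100.0] {
        // Small logits pass through almost unchanged; large ones
        // saturate just below the cap of 5.0.
        println!("{x:>6.1} -> {:.4}", softcap_scalar(x, softcap));
    }
}
```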
#[test]
fn flash_attn_acausal() -> Result<()> {
    let device = Device::new_cuda(0)?;
@@ -89,6 +103,44 @@ fn flash_attn_acausal() -> Result<()> {
    Ok(())
}
#[test]
fn flash_attn_acausal_softcap() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let q = Tensor::arange(0u32, 3 * 5 * 8, &device)?
        .to_dtype(DType::F16)?
        .reshape((1, 3, 5, 8))?;
    let k = (&q / 40.)?;
    let v = (&q / 50.)?;
    let q = (&q / 30.)?;
    let softcap = 5.0f32;
    let ys1 = fa_acausal_softcap(&q, &k, &v, softcap)?;
    let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
    let ys2 = {
        let q = q.transpose(1, 2)?;
        let k = k.transpose(1, 2)?;
        let v = v.transpose(1, 2)?;
        candle_flash_attn::flash_attn_alibi_windowed_softcap(
            &q,
            &k,
            &v,
            None,    // alibi_slopes
            1.0,     // softmax_scale
            None,    // window_size_left
            None,    // window_size_right
            softcap, // softcap
        )?
        .transpose(1, 2)?
    };
    let ys2 = ys2.i(0)?.to_dtype(DType::F32)?;
    let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?;
    assert_eq!(ys1.dims(), &[3, 5, 8]);
    assert_eq!(ys2.dims(), &[3, 5, 8]);
    assert!(diff.to_vec0::<f32>()? < 1e-3);
    Ok(())
}
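For callers outside the test harness, here is a hedged usage sketch of the new entry point. It assumes only what the test above shows: the argument order (`alibi_slopes`, `softmax_scale`, window sizes, `softcap`) and the (batch, seqlen, nheads, head_dim) layout after the transpose. The wrapper name `softcap_attention` and the 1/sqrt(head_dim) scale are illustrative assumptions, not part of the diff (the test uses a scale of 1.0):

```rust
// Sketch, assuming candle-core is aliased as `candle` as in these tests.
use candle::{Result, Tensor, D};

// Hypothetical wrapper: q/k/v in (batch, seqlen, nheads, head_dim) layout,
// which is what flash_attn_alibi_windowed_softcap expects per the test above.
fn softcap_attention(q: &Tensor, k: &Tensor, v: &Tensor, softcap: f32) -> Result<Tensor> {
    let head_dim = q.dim(D::Minus1)?;
    // Conventional 1/sqrt(d) scaling; an assumption here, the test uses 1.0.
    let softmax_scale = 1.0 / (head_dim as f32).sqrt();
    candle_flash_attn::flash_attn_alibi_windowed_softcap(
        q, k, v,
        None,          // alibi_slopes: no ALiBi bias
        softmax_scale, // softmax scale (the test above passes 1.0)
        None,          // window_size_left: no sliding window
        None,          // window_size_right
        softcap,       // cap the logits to (-softcap, softcap)
    )
}
```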
#[test]
fn flash_attn_varlen() -> Result<()> {
    let device = Device::new_cuda(0)?;