Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
* softcap is working, including test and API
* make softcap test case better
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
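For context (this paragraph is editorial, not part of the commit message): soft-capping squashes the raw attention logits through a tanh before the softmax, so their magnitude stays strictly below the cap. A minimal sketch of the math as implemented by the reference fa_acausal_softcap in the diff below, with c denoting the softcap value and the softmax scale left at 1.0 as in the test:

$$
\text{att} = c \cdot \tanh\!\left(\frac{Q K^{\top}}{c}\right), \qquad
\text{out} = \operatorname{softmax}(\text{att})\, V
$$

For logits much smaller in magnitude than c the transform is close to the identity, while very large logits saturate near ±c, which keeps the subsequent softmax numerically well behaved.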
@@ -27,6 +27,20 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
     Ok(output)
 }
 
+fn fa_acausal_softcap(q: &Tensor, k: &Tensor, v: &Tensor, softcap: f32) -> Result<Tensor> {
+    let in_dtype = q.dtype();
+    let q = q.to_dtype(DType::F32)?;
+    let k = k.to_dtype(DType::F32)?;
+    let v = v.to_dtype(DType::F32)?;
+    // let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
+    let att = q.matmul(&k.t()?)?;
+    let att = (softcap as f64 * ((att / softcap as f64)?.tanh())?)?;
+    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+    // Convert to contiguous as matmul doesn't support strided vs for now.
+    let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
+    Ok(output)
+}
+
 #[test]
 fn flash_attn_acausal() -> Result<()> {
     let device = Device::new_cuda(0)?;
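To make the capping behaviour concrete, here is a small, dependency-free Rust sketch (not part of the commit; the helper name softcap is purely illustrative) of the scalar transform that fa_acausal_softcap applies element-wise to the logits:

fn softcap(x: f32, cap: f32) -> f32 {
    // Same mapping as the reference above: cap * tanh(x / cap).
    cap * (x / cap).tanh()
}

fn main() {
    let cap = 5.0f32;
    for x in [0.5f32, 2.0, 10.0, 100.0] {
        // Small logits pass through nearly unchanged; large ones saturate near ±cap.
        println!("x = {x:>6.1} -> capped = {:.4}", softcap(x, cap));
    }
}

Running this prints values that approach but never exceed 5.0 for the large inputs, which is why the cap can be applied to the logits before the softmax without blowing up in f16.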
@@ -89,6 +103,44 @@ fn flash_attn_acausal() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn flash_attn_acausal_softcap() -> Result<()> {
+    let device = Device::new_cuda(0)?;
+    let q = Tensor::arange(0u32, 3 * 5 * 8, &device)?
+        .to_dtype(DType::F16)?
+        .reshape((1, 3, 5, 8))?;
+    let k = (&q / 40.)?;
+    let v = (&q / 50.)?;
+    let q = (&q / 30.)?;
+    let softcap = 5.0f32;
+
+    let ys1 = fa_acausal_softcap(&q, &k, &v, softcap.clone())?;
+    let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
+    let ys2 = {
+        let q = q.transpose(1, 2)?;
+        let k = k.transpose(1, 2)?;
+        let v = v.transpose(1, 2)?;
+        candle_flash_attn::flash_attn_alibi_windowed_softcap(
+            &q,
+            &k,
+            &v,
+            None,            // alibi_slopes
+            1.0,             // softmax_scale
+            None,            // window_size_left
+            None,            // window_size_right
+            softcap.clone(), // softcap
+        )?
+        .transpose(1, 2)?
+    };
+    let ys2 = ys2.i(0)?.to_dtype(DType::F32)?;
+    let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?;
+
+    assert_eq!(ys1.dims(), &[3, 5, 8]);
+    assert_eq!(ys2.dims(), &[3, 5, 8]);
+    assert!(diff.to_vec0::<f32>()?.abs() < 1e-3);
+    Ok(())
+}
+
 #[test]
 fn flash_attn_varlen() -> Result<()> {
     let device = Device::new_cuda(0)?;
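A note on the test's design as it reads in the diff: the reference path (ys1) accumulates in F32 inside fa_acausal_softcap and only casts back to F16 at the end, while the flash-attn kernel (ys2) consumes the F16 inputs directly, so the two results are not expected to match bit-for-bit; the 1e-3 bound on the maximum absolute difference is presumably there to absorb F16 rounding. Running either path requires a CUDA device, since both start from Device::new_cuda(0).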