Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Flash attention without padding (varlen). (#281)
* Expose the seqlen variable for flash-attn without padding.
* Fix the batched call.
* Adapt for the varlen variant.
* No need to set the batch strides when in varlen mode.
* Add a test (disabled at the moment).
* Get the test to work properly.
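For reference, the varlen entry point operates on packed (unpadded) q/k/v tensors and describes the batch through cumulative sequence-length offsets rather than a batch stride, which is why the batch strides no longer need to be set in that mode. Below is a minimal sketch of that calling pattern, with several assumptions: the crate is imported as `candle` (as in the repo's own tests), and the shapes, sequence lengths, softmax scale, and the `varlen_call_sketch` name are purely illustrative; only the `flash_attn_varlen` argument order matches the test added in this commit.

use candle::{DType, Device, Result, Tensor};

// Illustrative sketch, not part of this commit.
fn varlen_call_sketch() -> Result<Tensor> {
    let device = Device::new_cuda(0)?;
    // Two sequences of lengths 3 and 5 packed along the first axis:
    // in varlen mode the layout is (total_tokens, num_heads, head_dim).
    let q = Tensor::zeros((8, 4, 64), DType::F16, &device)?;
    let k = q.clone();
    let v = q.clone();
    // Cumulative offsets [0, 3, 8] encode the two lengths (3 and 5).
    let seqlens_q = Tensor::new(&[0u32, 3, 8], &device)?;
    let seqlens_k = Tensor::new(&[0u32, 3, 8], &device)?;
    // Remaining arguments: max_seqlen_q, max_seqlen_k, softmax scale, causal flag.
    candle_flash_attn::flash_attn_varlen(&q, &k, &v, &seqlens_q, &seqlens_k, 5, 5, 0.125, false)
}

The 0.125 scale here is 1 / sqrt(64) for the assumed head dimension; the test in the diff below simply uses 0.5.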
@@ -88,3 +88,48 @@ fn flash_attn_acausal() -> Result<()> {
    assert!(diff.to_vec0::<f32>()?.abs() < 1e-5);
    Ok(())
}

#[test]
fn flash_attn_varlen() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let q = Tensor::arange(0u32, 48, &device)?
        .to_dtype(DType::F16)?
        .reshape((3, 2, 8))?;
    let k = (&q / 40.)?;
    let v = (&q / 50.)?;
    let q = (&q / 30.)?;

    let seqlens_q = Tensor::new(&[0u32, 2u32], &device)?;
    let seqlens_k = Tensor::new(&[0u32, 2u32], &device)?;

    let ys = {
        let q = q.transpose(0, 1)?;
        let k = k.transpose(0, 1)?;
        let v = v.transpose(0, 1)?;
        candle_flash_attn::flash_attn_varlen(
            &q, &k, &v, &seqlens_q, &seqlens_k, 32, 32, 0.5, false,
        )?
        .transpose(0, 1)?
    };
    let ys = ys.to_dtype(DType::F32)?;

    assert_eq!(ys.dims(), &[3, 2, 8]);
    assert_eq!(
        to_vec3_round(ys, 4)?,
        &[
            [
                [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238],
                [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322]
            ],
            [
                [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605],
                [0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684]
            ],
            [
                [0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955],
                [0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023]
            ]
        ]
    );
    Ok(())
}