Flash attention without padding (varlen). (#281)

* Expose the seqlen variable for flash-attn without padding.

* Fix the batched call.

* Adapt for the varlen variant.

* No need to set the batch strides when in varlen mode.

* Add a test (disabled at the moment).

* Get the test to work properly.
This commit is contained in:
Laurent Mazare
2023-07-31 09:45:39 +01:00
committed by GitHub
parent a8d8f9f206
commit 0ace420e66
4 changed files with 283 additions and 4 deletions

View File

@ -88,3 +88,48 @@ fn flash_attn_acausal() -> Result<()> {
assert!(diff.to_vec0::<f32>()?.abs() < 1e-5);
Ok(())
}
#[test]
fn flash_attn_varlen() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::arange(0u32, 48, &device)?
.to_dtype(DType::F16)?
.reshape((3, 2, 8))?;
let k = (&q / 40.)?;
let v = (&q / 50.)?;
let q = (&q / 30.)?;
let seqlens_q = Tensor::new(&[0u32, 2u32], &device)?;
let seqlens_k = Tensor::new(&[0u32, 2u32], &device)?;
let ys = {
let q = q.transpose(0, 1)?;
let k = k.transpose(0, 1)?;
let v = v.transpose(0, 1)?;
candle_flash_attn::flash_attn_varlen(
&q, &k, &v, &seqlens_q, &seqlens_k, 32, 32, 0.5, false,
)?
.transpose(0, 1)?
};
let ys = ys.to_dtype(DType::F32)?;
assert_eq!(ys.dims(), &[3, 2, 8]);
assert_eq!(
to_vec3_round(ys, 4)?,
&[
[
[0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238],
[0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322]
],
[
[0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605],
[0.428, 0.448, 0.468, 0.488, 0.5083, 0.5283, 0.5483, 0.5684]
],
[
[0.7554, 0.7754, 0.7954, 0.8154, 0.8354, 0.8555, 0.8755, 0.8955],
[0.7622, 0.7822, 0.8022, 0.8223, 0.8423, 0.8623, 0.8823, 0.9023]
]
]
);
Ok(())
}