Again set a few extra params in flash-attn. (#245)

* Again set a few extra params.

* Use the appropriate kernel sizes.

* Add all the kernel sizes.

* Parallel compiling (see the build-script sketch after this list).

* Reduce the amount of parallelism.

* Add the missing kernel.

* Fix a typo.

* Remove bf16 support for now.
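
The "Parallel compiling" / "Reduce the amount of parallelism" items concern how the per-head-dim CUDA kernels get built. Below is a minimal illustrative sketch of that idea (bounded-parallel nvcc invocations from a build script); it is not the actual candle-flash-attn build.rs, and the kernel file names, nvcc flags, worker count, and use of rayon are all assumptions.

```rust
// Illustrative sketch only, NOT the actual candle-flash-attn build.rs.
// Compile one .cu file per kernel with nvcc, a few at a time, so the build
// does not launch one compiler per kernel and exhaust memory.
// Assumes rayon as a build-dependency; file names and flags are examples.
use rayon::prelude::*;
use std::process::Command;

fn main() {
    let kernels: &[&str] = &[
        "flash_fwd_hdim32_fp16_sm80.cu",
        "flash_fwd_hdim64_fp16_sm80.cu",
        "flash_fwd_hdim128_fp16_sm80.cu",
    ];
    // Bounded pool: each nvcc invocation is slow and can be very memory
    // hungry, so cap the number of concurrent compilations rather than
    // using every available core.
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(4)
        .build()
        .expect("failed to build the nvcc thread pool");
    pool.install(|| {
        kernels.par_iter().for_each(|&kernel| {
            println!("cargo:rerun-if-changed={kernel}");
            let status = Command::new("nvcc")
                .args(["-O3", "--expt-relaxed-constexpr", "-c", kernel])
                .arg("-o")
                .arg(format!("{kernel}.o"))
                .status()
                .expect("failed to spawn nvcc");
            assert!(status.success(), "nvcc failed on {kernel}");
        });
    });
}
```

Capping the worker count is the point of the "reduce the amount of parallelism" item: fully parallel template-heavy nvcc builds can run a machine out of RAM.
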
Author: Laurent Mazare
Date: 2023-07-26 14:16:37 +01:00
Committed by: GitHub
Parent: fa2b64d678
Commit: 2ce5f12513
21 changed files with 476 additions and 116 deletions

@@ -220,8 +220,12 @@ impl CausalSelfAttention {
         let v = self.repeat_kv(v)?;
         let y = if self.use_flash_attn {
+            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
+            let q = q.transpose(1, 2)?;
+            let k = k.transpose(1, 2)?;
+            let v = v.transpose(1, 2)?;
             let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-            flash_attn(softmax_scale, &q, &k, &v)?
+            flash_attn(softmax_scale, &q, &k, &v)?.transpose(1, 2)?
         } else {
             let in_dtype = q.dtype();
             let q = q.to_dtype(DType::F32)?;
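
For reference, a standalone sketch of the layout convention the new lines establish, using the scale-first `flash_attn(softmax_scale, &q, &k, &v)` call exactly as it appears in the hunk above; the helper name and the import paths are assumptions, not code from this commit.

```rust
// Minimal sketch, not code from this commit: only the call pattern and the
// dimension shuffling mirror the hunk above.
use candle::{Result, Tensor};
use candle_flash_attn::flash_attn;

/// q, k and v come in as (b_sz, n_heads, seq_len, head_dim), the layout the
/// rest of the attention code works with.
fn flash_attn_nhd(q: &Tensor, k: &Tensor, v: &Tensor, head_dim: usize) -> Result<Tensor> {
    // flash-attn wants (b_sz, seq_len, n_heads, head_dim), so swap dims 1 and 2.
    let q = q.transpose(1, 2)?;
    let k = k.transpose(1, 2)?;
    let v = v.transpose(1, 2)?;
    let softmax_scale = 1f32 / (head_dim as f32).sqrt();
    // The kernel returns (b_sz, seq_len, n_heads, head_dim) as well; transpose
    // back so the caller keeps working in (b_sz, n_heads, seq_len, head_dim).
    flash_attn(softmax_scale, &q, &k, &v)?.transpose(1, 2)
}
```
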