Use cat for faster MQA computation. (#2043)

* Use cat for faster MQA computation.

* Move the function to utils + use it in mistral.

* Use the shared repeat-kv in a few more models.

* Fix.
Author: Laurent Mazare
Date: 2024-04-12 09:15:10 +02:00
Committed by: GitHub
parent a0460cd2b1
commit 3ad4770eb6
16 changed files with 47 additions and 195 deletions
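For context, the per-model helpers that this commit consolidates expanded the kv heads through a broadcast. The following is a minimal sketch of that broadcast-based approach (illustrative only, not the exact removed code): the final reshape runs on a non-contiguous expanded view, which is where the strided copy mentioned in the new code comment comes from.

use candle_core::{Result, Tensor};

// Broadcast-based repeat_kv sketch: insert a new axis, broadcast it n_rep times,
// then merge it into the head dimension. The expanded view is not contiguous, so
// the reshape has to materialize it through a strided copy.
fn repeat_kv_broadcast(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        return Ok(xs);
    }
    let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
    xs.unsqueeze(2)?
        .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
        .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}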


@@ -20,3 +20,17 @@ pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> R
    let logits_len = logits.len();
    Tensor::from_vec(logits, logits_len, device)
}
/// Repeats a key or value tensor for grouped query attention.
/// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`.
pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(xs)
    } else {
        let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
        // Using cat is faster than a broadcast as it avoids going through a potentially
        // strided copy.
        // https://github.com/huggingface/candle/pull/2043
        Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }
}
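
The shared helper now sits next to apply_repeat_penalty, so models such as mistral can call it instead of keeping their own copies. Below is a hedged usage sketch, assuming the helper is reachable as candle_transformers::utils::repeat_kv and using made-up shapes and head counts, showing where it slots into a grouped-query attention step:

use candle_core::{Device, Result, Tensor};
use candle_transformers::utils::repeat_kv;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Illustrative GQA shapes: 8 query heads sharing 2 kv heads, so n_rep = 4.
    let (b_sz, n_head, n_kv_head, seq_len, head_dim) = (1, 8, 2, 16, 64);
    let q = Tensor::randn(0f32, 1f32, (b_sz, n_head, seq_len, head_dim), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (b_sz, n_kv_head, seq_len, head_dim), &dev)?;
    let v = Tensor::randn(0f32, 1f32, (b_sz, n_kv_head, seq_len, head_dim), &dev)?;

    // Duplicate the kv heads so they line up with the query heads.
    let k = repeat_kv(k, n_head / n_kv_head)?;
    let v = repeat_kv(v, n_head / n_kv_head)?;

    // Standard scaled dot-product attention on the expanded tensors.
    let attn = (q.matmul(&k.transpose(2, 3)?)? / (head_dim as f64).sqrt())?;
    let attn = candle_nn::ops::softmax_last_dim(&attn)?;
    let out = attn.matmul(&v)?;
    assert_eq!(out.dims4()?, (b_sz, n_head, seq_len, head_dim));
    Ok(())
}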