mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 04:22:50 +00:00
Use cat for faster MQA computation. (#2043)
* Use cat for faster MQA computation.
* Move the function to utils + use it in mistral.
* Use the shared repeat-kv in a few more models.
* Fix.
This commit is contained in:
@ -20,3 +20,17 @@ pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> R
|
||||
let logits_len = logits.len();
|
||||
Tensor::from_vec(logits, logits_len, device)
|
||||
}
|
||||
|
||||
/// Repeats a key or value tensor for grouped query attention
|
||||
/// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`,
|
||||
pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
|
||||
if n_rep == 1 {
|
||||
Ok(xs)
|
||||
} else {
|
||||
let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
|
||||
// Using cat is faster than a broadcast as it avoids going through a potentially
|
||||
// strided copy.
|
||||
// https://github.com/huggingface/candle/pull/2043
|
||||
Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user