mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Use cat for faster MQA computation. (#2043)
* Use cat for faster MQA computation. * Move the function to utils + use it in mistral. * Use the shared repeat-kv in a few more models. * Fix.
This commit is contained in:
@ -174,15 +174,7 @@ impl Attention {
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.num_heads / self.num_kv_heads;
|
||||
if n_rep == 1 {
|
||||
Ok(xs)
|
||||
} else {
|
||||
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
|
||||
xs.unsqueeze(2)?
|
||||
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
|
||||
}
|
||||
crate::utils::repeat_kv(xs, self.num_heads / self.num_kv_heads)
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
|
Reference in New Issue
Block a user