Use cat for faster MQA computation. (#2043)

* Use cat for faster MQA computation.

* Move the function to utils + use it in mistral.

* Use the shared repeat-kv in a few more models.

* Fix.
This commit is contained in:
Laurent Mazare
2024-04-12 09:15:10 +02:00
committed by GitHub
parent a0460cd2b1
commit 3ad4770eb6
16 changed files with 47 additions and 195 deletions

View File

@ -175,18 +175,6 @@ impl Attention {
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
@ -223,8 +211,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);