mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Use cat for faster MQA computation. (#2043)
* Use cat for faster MQA computation. * Move the function to utils + use it in mistral. * Use the shared repeat-kv in a few more models. * Fix.
This commit is contained in:
@ -104,8 +104,8 @@ impl GroupedQueryAttention {
|
||||
};
|
||||
self.kv_cache = Some((key.clone(), value.clone()));
|
||||
let query = query.contiguous()?;
|
||||
let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
|
||||
let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
|
||||
let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
|
||||
let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
|
||||
let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
|
||||
let attn_bias = {
|
||||
let s_q = query.dim(D::Minus2)?;
|
||||
@ -134,20 +134,6 @@ impl GroupedQueryAttention {
|
||||
}
|
||||
}
|
||||
|
||||
// This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
|
||||
// The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
|
||||
// (batch, num_attention_heads, seqlen, head_dim)
|
||||
pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
|
||||
if n_rep == 1 {
|
||||
Ok(xs)
|
||||
} else {
|
||||
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
|
||||
xs.unsqueeze(2)?
|
||||
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Ffn {
|
||||
up_proj: Linear,
|
||||
|
Reference in New Issue
Block a user