Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)

Commit: Use cat for faster MQA computation. (#2043)

* Use cat for faster MQA computation.
* Move the function to utils + use it in mistral.
* Use the shared repeat-kv in a few more models.
* Fix.
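The change replaces the per-model, broadcast-based repeat_kv helpers with a single cat-based implementation in crate::utils. Below is a minimal, self-contained sketch of the two equivalent approaches as they appear in the hunks that follow; it assumes the candle_core crate (imported as `candle` inside the candle workspace), and the shapes in main are made up for illustration.

// Sketch only; both bodies mirror code visible in the diff below.
use candle_core::{DType, Device, Result, Tensor};

// Old per-model version: broadcast via unsqueeze/expand, then reshape.
fn repeat_kv_broadcast(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        return Ok(xs);
    }
    let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
    xs.unsqueeze(2)?
        .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
        .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}

// New shared version: concatenate n_rep references along dim 2, then reshape.
// Per the comment in the utils.rs hunk, cat avoids going through a potentially
// strided copy, which is what makes it faster for MQA/GQA.
fn repeat_kv_cat(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        return Ok(xs);
    }
    let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
    Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}

fn main() -> Result<()> {
    // Hypothetical GQA shapes: 2 KV heads repeated 4 times -> 8 attention heads.
    let xs = Tensor::zeros((1, 2, 5, 16), DType::F32, &Device::Cpu)?;
    let a = repeat_kv_broadcast(xs.clone(), 4)?;
    let b = repeat_kv_cat(xs, 4)?;
    assert_eq!(a.dims4()?, (1, 8, 5, 16));
    assert_eq!(b.dims4()?, (1, 8, 5, 16));
    Ok(())
}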
@@ -191,18 +191,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -239,8 +227,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?.contiguous()?;
-        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
@@ -256,17 +256,7 @@ impl CausalSelfAttention {
     }
 
     fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_attention_heads / self.num_key_value_heads;
-        if n_rep == 1 {
-            Ok(x)
-        } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
-            let x = x
-                .unsqueeze(2)?
-                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
-            Ok(x)
-        }
+        crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
     }
 
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
@@ -216,18 +216,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -266,8 +254,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = if self.use_flash_attn {
             // flash-attn expects (b_sz, seq_len, nheads, head_dim)
@@ -158,18 +158,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -206,8 +194,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = if self.use_flash_attn {
             // flash-attn expects (b_sz, seq_len, nheads, head_dim)
@@ -104,8 +104,8 @@ impl GroupedQueryAttention {
         };
         self.kv_cache = Some((key.clone(), value.clone()));
         let query = query.contiguous()?;
-        let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
-        let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
         let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
         let attn_bias = {
             let s_q = query.dim(D::Minus2)?;
@@ -134,20 +134,6 @@ impl GroupedQueryAttention {
     }
 }
 
-// This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
-// The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
-// (batch, num_attention_heads, seqlen, head_dim)
-pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
-    if n_rep == 1 {
-        Ok(xs)
-    } else {
-        let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-        xs.unsqueeze(2)?
-            .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-            .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-    }
-}
-
 #[derive(Debug, Clone)]
 struct Ffn {
     up_proj: Linear,
@@ -174,15 +174,7 @@ impl Attention {
     }
 
     fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_heads / self.num_kv_heads;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
+        crate::utils::repeat_kv(xs, self.num_heads / self.num_kv_heads)
     }
 
     fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
@@ -205,9 +205,9 @@ impl LayerWeights {
         };
         self.kv_cache = Some((k.clone(), v.clone()));
 
-        // Support for MQA, useful for 70B models.
-        let k = self.repeat_kv(k)?;
-        let v = self.repeat_kv(v)?;
+        // Support for MQA, useful for 70B models and mistral.
+        let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
+        let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let att = match mask {
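The second argument passed to the shared helper is the repetition factor n_rep = n_head / n_kv_head, so multi-head, grouped-query, and multi-query attention all go through the same code path. A tiny illustrative snippet (the head counts are hypothetical, not taken from this diff):

fn main() {
    // n_rep is the ratio of attention heads to KV heads, as passed to
    // crate::utils::repeat_kv in the hunks above.
    for (name, n_head, n_kv_head) in [("MHA", 32usize, 32usize), ("GQA", 32, 8), ("MQA", 32, 1)] {
        let n_rep = n_head / n_kv_head; // 1, 4 and 32 respectively
        println!("{name}: each KV head is repeated {n_rep} times");
    }
}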
@@ -224,20 +224,6 @@ impl LayerWeights {
         let y = self.attention_wo.forward(&y)?;
         Ok(y)
     }
-
-    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
-        let n_rep = self.n_head / self.n_kv_head;
-        if n_rep == 1 {
-            Ok(x)
-        } else {
-            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
-            let x = x
-                .unsqueeze(2)?
-                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
-            Ok(x)
-        }
-    }
 }
 
 #[derive(Debug, Clone)]
@@ -122,18 +122,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -172,8 +160,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
@@ -71,8 +71,8 @@ impl GroupedQueryAttention {
         };
         self.kv_cache = Some((key.clone(), value.clone()));
         let query = query.contiguous()?;
-        let key = super::mpt::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
-        let value = super::mpt::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
         let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
         let attn_bias = {
             let s_q = query.dim(D::Minus2)?;
@@ -94,18 +94,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -152,8 +140,9 @@ impl Attention {
             self.kv_cache = Some((key_states.clone(), value_states.clone()));
         }
 
-        let key_states = self.repeat_kv(key_states)?.contiguous()?;
-        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+        let value_states =
+            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
@@ -146,18 +146,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -194,8 +182,9 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?.contiguous()?;
-        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+        let value_states =
+            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
@@ -151,18 +151,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -199,8 +187,9 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?.contiguous()?;
-        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+        let value_states =
+            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
@@ -217,18 +217,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -275,8 +263,9 @@ impl Attention {
             self.kv_cache = Some((key_states.clone(), value_states.clone()));
         }
 
-        let key_states = self.repeat_kv(key_states)?.contiguous()?;
-        let value_states = self.repeat_kv(value_states)?.contiguous()?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+        let value_states =
+            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
 
         let attn_output = if self.use_flash_attn {
             // flash-attn expects (b_sz, seq_len, nheads, head_dim)
@@ -139,18 +139,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -187,8 +175,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let scale = 1f64 / f64::sqrt(self.head_dim as f64);
         let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
@@ -175,18 +175,6 @@ impl Attention {
         })
     }
 
-    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
-        let n_rep = self.num_kv_groups;
-        if n_rep == 1 {
-            Ok(xs)
-        } else {
-            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
-            xs.unsqueeze(2)?
-                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
-        }
-    }
-
     fn forward(
         &mut self,
         xs: &Tensor,
@@ -223,8 +211,8 @@ impl Attention {
         };
         self.kv_cache = Some((key_states.clone(), value_states.clone()));
 
-        let key_states = self.repeat_kv(key_states)?;
-        let value_states = self.repeat_kv(value_states)?;
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
+        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
@@ -20,3 +20,17 @@ pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> R
     let logits_len = logits.len();
     Tensor::from_vec(logits, logits_len, device)
 }
+
+/// Repeats a key or value tensor for grouped query attention
+/// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`,
+pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
+    if n_rep == 1 {
+        Ok(xs)
+    } else {
+        let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
+        // Using cat is faster than a broadcast as it avoids going through a potentially
+        // strided copy.
+        // https://github.com/huggingface/candle/pull/2043
+        Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
+    }
+}
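For context, here is a condensed sketch of how an attention block uses the new helper once keys and values come out of the KV cache. It inlines a copy of the repeat_kv added above so it runs standalone; the variable names and shapes are illustrative, it skips RoPE, masking, and output projections, and it assumes the candle_core and candle_nn crates as dependencies.

use candle_core::{Device, Result, Tensor};

mod utils {
    use candle_core::{Result, Tensor};
    // Same body as the repeat_kv added to the shared utils module above.
    pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
        if n_rep == 1 {
            return Ok(xs);
        }
        let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
        Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical shapes: batch 1, 8 query heads, 2 KV heads, seq 4, head_dim 16.
    let (num_heads, num_kv_heads, seq, head_dim) = (8, 2, 4, 16);
    let q = Tensor::randn(0f32, 1., (1, num_heads, seq, head_dim), &dev)?;
    let k = Tensor::randn(0f32, 1., (1, num_kv_heads, seq, head_dim), &dev)?;
    let v = Tensor::randn(0f32, 1., (1, num_kv_heads, seq, head_dim), &dev)?;

    // Bring K and V up to the query head count, as in the call sites above.
    let k = utils::repeat_kv(k, num_heads / num_kv_heads)?;
    let v = utils::repeat_kv(v, num_heads / num_kv_heads)?;

    let scale = 1f64 / f64::sqrt(head_dim as f64);
    let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
    let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
    let attn_output = attn_weights.matmul(&v)?;
    assert_eq!(attn_output.dims4()?, (1, num_heads, seq, head_dim));
    Ok(())
}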