mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix the rotary embeddings for the new phi implementation. (#1582)
* Fix the rotary embeddings for the new phi implementation. * Match the activation. * KV cache fix. * Use the config activation function.
This commit is contained in:
@ -38,6 +38,7 @@ impl Config {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
dim: usize,
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
@ -55,29 +56,24 @@ impl RotaryEmbedding {
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((cfg.max_position_embeddings, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
dim,
|
||||
sin: emb.sin()?,
|
||||
cos: emb.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (_b_size, seqlen, _, _headdim) = xs.dims4()?;
|
||||
let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
|
||||
let rotary_dim = rotary_dim * 2;
|
||||
let xs_rot = xs.i((.., .., .., ..rotary_dim))?;
|
||||
let xs_pass = xs.i((.., .., .., rotary_dim..))?;
|
||||
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
|
||||
let xs_rot = xs.i((.., .., .., ..self.dim))?;
|
||||
let xs_pass = xs.i((.., .., .., self.dim..))?;
|
||||
let xs12 = xs_rot.chunk(2, D::Minus1)?;
|
||||
let (xs1, xs2) = (&xs12[0], &xs12[1]);
|
||||
let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
|
||||
let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
|
||||
let xs_rot = Tensor::cat(
|
||||
&[
|
||||
(xs1.broadcast_mul(&c)? - xs2.broadcast_mul(&s)?)?,
|
||||
(xs1.broadcast_mul(&s)? + xs2.broadcast_mul(&c)?)?,
|
||||
],
|
||||
D::Minus1,
|
||||
)?;
|
||||
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
|
||||
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
|
||||
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
|
||||
}
|
||||
}
|
||||
@ -97,6 +93,8 @@ impl MLP {
|
||||
Ok(Self {
|
||||
fc1,
|
||||
fc2,
|
||||
// This does not match the mixformers implementation where Gelu is used rather than
|
||||
// GeluNew.
|
||||
act: cfg.hidden_act,
|
||||
})
|
||||
}
|
||||
@ -216,7 +214,7 @@ impl Attention {
|
||||
// Rotary embeddings.
|
||||
let seqlen_offset = match &self.kv_cache {
|
||||
None => 0,
|
||||
Some((prev_k, _)) => prev_k.dim(1)?,
|
||||
Some((prev_k, _)) => prev_k.dim(2)?,
|
||||
};
|
||||
let query_states = self
|
||||
.rotary_emb
|
||||
@ -351,7 +349,7 @@ impl Model {
|
||||
Some(get_mask(seq_len, xs.device())?)
|
||||
};
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, mask.as_ref())?
|
||||
xs = layer.forward(&xs, mask.as_ref())?;
|
||||
}
|
||||
xs.apply(&self.final_layernorm)?
|
||||
.narrow(1, seq_len - 1, 1)?
|
||||
|
Reference in New Issue
Block a user