mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Fix the rotary embeddings for the new phi implementation. (#1582)
* Fix the rotary embeddings for the new phi implementation. * Match the activation. * KV cache fix. * Use the config activation function.
This commit is contained in:
@ -38,6 +38,7 @@ impl Config {
|
|||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct RotaryEmbedding {
|
struct RotaryEmbedding {
|
||||||
|
dim: usize,
|
||||||
sin: Tensor,
|
sin: Tensor,
|
||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
}
|
}
|
||||||
@ -55,29 +56,24 @@ impl RotaryEmbedding {
|
|||||||
.to_dtype(DType::F32)?
|
.to_dtype(DType::F32)?
|
||||||
.reshape((cfg.max_position_embeddings, 1))?;
|
.reshape((cfg.max_position_embeddings, 1))?;
|
||||||
let freqs = t.matmul(&inv_freq)?;
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
sin: freqs.sin()?,
|
dim,
|
||||||
cos: freqs.cos()?,
|
sin: emb.sin()?,
|
||||||
|
cos: emb.cos()?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
let (_b_size, seqlen, _, _headdim) = xs.dims4()?;
|
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
|
||||||
let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
|
let xs_rot = xs.i((.., .., .., ..self.dim))?;
|
||||||
let rotary_dim = rotary_dim * 2;
|
let xs_pass = xs.i((.., .., .., self.dim..))?;
|
||||||
let xs_rot = xs.i((.., .., .., ..rotary_dim))?;
|
|
||||||
let xs_pass = xs.i((.., .., .., rotary_dim..))?;
|
|
||||||
let xs12 = xs_rot.chunk(2, D::Minus1)?;
|
let xs12 = xs_rot.chunk(2, D::Minus1)?;
|
||||||
let (xs1, xs2) = (&xs12[0], &xs12[1]);
|
let (xs1, xs2) = (&xs12[0], &xs12[1]);
|
||||||
let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
|
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
|
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let xs_rot = Tensor::cat(
|
let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
|
||||||
&[
|
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
|
||||||
(xs1.broadcast_mul(&c)? - xs2.broadcast_mul(&s)?)?,
|
|
||||||
(xs1.broadcast_mul(&s)? + xs2.broadcast_mul(&c)?)?,
|
|
||||||
],
|
|
||||||
D::Minus1,
|
|
||||||
)?;
|
|
||||||
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
|
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -97,6 +93,8 @@ impl MLP {
|
|||||||
Ok(Self {
|
Ok(Self {
|
||||||
fc1,
|
fc1,
|
||||||
fc2,
|
fc2,
|
||||||
|
// This does not match the mixformers implementation where Gelu is used rather than
|
||||||
|
// GeluNew.
|
||||||
act: cfg.hidden_act,
|
act: cfg.hidden_act,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -216,7 +214,7 @@ impl Attention {
|
|||||||
// Rotary embeddings.
|
// Rotary embeddings.
|
||||||
let seqlen_offset = match &self.kv_cache {
|
let seqlen_offset = match &self.kv_cache {
|
||||||
None => 0,
|
None => 0,
|
||||||
Some((prev_k, _)) => prev_k.dim(1)?,
|
Some((prev_k, _)) => prev_k.dim(2)?,
|
||||||
};
|
};
|
||||||
let query_states = self
|
let query_states = self
|
||||||
.rotary_emb
|
.rotary_emb
|
||||||
@ -351,7 +349,7 @@ impl Model {
|
|||||||
Some(get_mask(seq_len, xs.device())?)
|
Some(get_mask(seq_len, xs.device())?)
|
||||||
};
|
};
|
||||||
for layer in self.layers.iter_mut() {
|
for layer in self.layers.iter_mut() {
|
||||||
xs = layer.forward(&xs, mask.as_ref())?
|
xs = layer.forward(&xs, mask.as_ref())?;
|
||||||
}
|
}
|
||||||
xs.apply(&self.final_layernorm)?
|
xs.apply(&self.final_layernorm)?
|
||||||
.narrow(1, seq_len - 1, 1)?
|
.narrow(1, seq_len - 1, 1)?
|
||||||
|
Reference in New Issue
Block a user