Enable the new layer-norm. (#2213)

* Enable the new layer-norm.

* Shape fixes.
This commit is contained in:
Laurent Mazare
2024-05-24 16:48:21 +02:00
committed by GitHub
parent 1df2bddccf
commit 3ceca9901a
3 changed files with 23 additions and 13 deletions

View File

@@ -56,24 +56,20 @@ impl RotaryEmbedding {
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
dim,
sin: emb.sin()?,
cos: emb.cos()?,
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
let xs_rot = xs.i((.., .., .., ..self.dim))?;
let xs_rot = xs.i((.., .., .., ..self.dim))?.contiguous()?;
let xs_pass = xs.i((.., .., .., self.dim..))?;
let xs12 = xs_rot.chunk(2, D::Minus1)?;
let (xs1, xs2) = (&xs12[0], &xs12[1]);
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
let rotate_half = Tensor::cat(&[&xs2.neg()?, xs1], D::Minus1)?;
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, &c, &s)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
}
}