From 88618255cb3c20b511a2f0e6db35d84081ce3c4a Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 13 Jan 2024 19:44:41 +0100
Subject: [PATCH] Fix the rotary embeddings for the new phi implementation.
 (#1582)

* Fix the rotary embeddings for the new phi implementation.

* Match the activation.

* KV cache fix.

* Use the config activation function.
---
 candle-transformers/src/models/phi.rs | 34 +++++++++++++--------------
 1 file changed, 16 insertions(+), 18 deletions(-)

diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs
index a635f3ce..8bf357e7 100644
--- a/candle-transformers/src/models/phi.rs
+++ b/candle-transformers/src/models/phi.rs
@@ -38,6 +38,7 @@ impl Config {
 
 #[derive(Debug, Clone)]
 struct RotaryEmbedding {
+    dim: usize,
     sin: Tensor,
     cos: Tensor,
 }
@@ -55,29 +56,24 @@ impl RotaryEmbedding {
             .to_dtype(DType::F32)?
             .reshape((cfg.max_position_embeddings, 1))?;
         let freqs = t.matmul(&inv_freq)?;
+        let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
         Ok(Self {
-            sin: freqs.sin()?,
-            cos: freqs.cos()?,
+            dim,
+            sin: emb.sin()?,
+            cos: emb.cos()?,
         })
     }
 
     fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
-        let (_b_size, seqlen, _, _headdim) = xs.dims4()?;
-        let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
-        let rotary_dim = rotary_dim * 2;
-        let xs_rot = xs.i((.., .., .., ..rotary_dim))?;
-        let xs_pass = xs.i((.., .., .., rotary_dim..))?;
+        let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
+        let xs_rot = xs.i((.., .., .., ..self.dim))?;
+        let xs_pass = xs.i((.., .., .., self.dim..))?;
         let xs12 = xs_rot.chunk(2, D::Minus1)?;
         let (xs1, xs2) = (&xs12[0], &xs12[1]);
-        let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
-        let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
-        let xs_rot = Tensor::cat(
-            &[
-                (xs1.broadcast_mul(&c)? - xs2.broadcast_mul(&s)?)?,
-                (xs1.broadcast_mul(&s)? + xs2.broadcast_mul(&c)?)?,
-            ],
-            D::Minus1,
-        )?;
+        let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
+        let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
+        let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
+        let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
         Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
     }
 }
@@ -97,6 +93,8 @@ impl MLP {
         Ok(Self {
             fc1,
             fc2,
+            // This does not match the mixformers implementation where Gelu is used rather than
+            // GeluNew.
            act: cfg.hidden_act,
         })
     }
@@ -216,7 +214,7 @@ impl Attention {
         // Rotary embeddings.
         let seqlen_offset = match &self.kv_cache {
             None => 0,
-            Some((prev_k, _)) => prev_k.dim(1)?,
+            Some((prev_k, _)) => prev_k.dim(2)?,
         };
         let query_states = self
             .rotary_emb
@@ -351,7 +349,7 @@ impl Model {
             Some(get_mask(seq_len, xs.device())?)
         };
         for layer in self.layers.iter_mut() {
-            xs = layer.forward(&xs, mask.as_ref())?
+            xs = layer.forward(&xs, mask.as_ref())?;
         }
         xs.apply(&self.final_layernorm)?
             .narrow(1, seq_len - 1, 1)?
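
Note on the rotary math (editor's sketch, not part of the patch): duplicating
the frequency table with Tensor::cat(&[&freqs, &freqs], D::Minus1) gives lanes
i and i + dim/2 the same angle, so xs_rot * cos + rotate_half(xs_rot) * sin is
the standard non-interleaved RoPE rotation. Below is a minimal scalar sketch in
plain Rust of what the patched apply_rotary_emb computes on the rotary slice of
one head vector, assuming the usual 1/theta^(2i/dim) inverse-frequency schedule
(the hunk does not show how inv_freq is built); all names are illustrative.

// Scalar sketch, no candle dependency. `x` is the rotary slice (length
// `dim`, assumed even; the xs_pass remainder is untouched in the patch),
// `pos` the absolute position, `theta` the RoPE base (10_000 is typical).
fn rope_rotate_half(x: &[f32], pos: usize, theta: f32) -> Vec<f32> {
    let dim = x.len();
    let half = dim / 2;
    let mut out = vec![0.0f32; dim];
    for i in 0..half {
        // cat(&[&freqs, &freqs]) in the patch gives lanes i and i + half
        // the same angle, so they rotate together as one 2-D pair.
        let inv_freq = 1.0 / theta.powf(2.0 * i as f32 / dim as f32);
        let (sin, cos) = (pos as f32 * inv_freq).sin_cos();
        // xs_rot * cos + rotate_half(xs_rot) * sin,
        // with rotate_half([x1, x2]) = [-x2, x1]:
        out[i] = x[i] * cos - x[i + half] * sin;
        out[i + half] = x[i + half] * cos + x[i] * sin;
    }
    out
}

fn main() {
    let x: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    let y = rope_rotate_half(&x, 3, 10_000.0);
    // The removed code built cat([x1*c - x2*s, x1*s + x2*c]) from the two
    // halves; with duplicated cos/sin tables that is the same rotation, so
    // the functional fixes are the layout and offset changes, not the math.
    assert_eq!(y.len(), x.len());
    println!("{y:?}");
}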
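
Note on the KV-cache fix (editor's sketch): the switch from prev_k.dim(1)? to
prev_k.dim(2)? tracks the layout change to (batch, num_heads, seq_len,
head_dim), which is also why the (seq_len, dim)-shaped cos/sin tables now
broadcast against xs_rot over the trailing axes without the old
.unsqueeze(1)?. A shape-only sketch with a stand-in dimension array, not the
candle Tensor type:

type Dims = [usize; 4];

// Mirrors the seqlen_offset computation in Attention::forward: the number
// of positions already cached is the size of the seq_len axis of the keys.
fn seqlen_offset(kv_cache: Option<&(Dims, Dims)>) -> usize {
    match kv_cache {
        None => 0,
        // `prev_k.dim(2)?` in the patch: the seq_len axis of (b, h, s, d);
        // under the old (b, s, h, d) layout this was axis 1.
        Some((prev_k, _)) => prev_k[2],
    }
}

fn main() {
    let cached = ([1, 32, 7, 80], [1, 32, 7, 80]);
    assert_eq!(seqlen_offset(Some(&cached)), 7); // 7 tokens already cached
    assert_eq!(seqlen_offset(None), 0); // first forward pass
}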