diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 9034367d..0f4cf1bb 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -147,9 +147,9 @@ enum WhichModel {
     V3,
     #[value(name = "3-medium")]
     V3Medium,
-    #[value(name = "2-old")]
-    V4Mini,
     #[value(name = "4-mini")]
+    V4Mini,
+    #[value(name = "2-old")]
     V2Old,
     PuffinPhiV2,
     PhiHermes,
diff --git a/candle-transformers/src/models/phi3.rs b/candle-transformers/src/models/phi3.rs
index 7ce9e987..6535d9a4 100644
--- a/candle-transformers/src/models/phi3.rs
+++ b/candle-transformers/src/models/phi3.rs
@@ -20,10 +20,24 @@
 // This implementation is based on:
 // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
 use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
-use candle::{DType, Device, Module, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 use std::sync::Arc;
 
+#[derive(Debug, Clone, serde::Deserialize)]
+pub enum RopeScalingType {
+    #[serde(rename = "longrope")]
+    LongRope,
+}
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct RopeScaling {
+    pub short_factor: Vec<f32>,
+    pub long_factor: Vec<f32>,
+    #[serde(rename = "type")]
+    pub type_: RopeScalingType,
+}
+
 // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct Config {
@@ -38,8 +52,12 @@ pub struct Config {
     pub rope_theta: f64,
     pub bos_token_id: Option<u32>,
     pub eos_token_id: Option<u32>,
-    pub rope_scaling: Option<String>,
+    pub rope_scaling: Option<RopeScaling>,
     pub max_position_embeddings: usize,
+    pub original_max_position_embeddings: Option<usize>,
+    pub partial_rotary_factor: Option<f64>,
+    #[serde(default)]
+    pub tie_word_embeddings: bool,
 }
 
 impl Config {
@@ -50,30 +68,88 @@ impl Config {
 
 #[derive(Debug, Clone)]
 pub struct RotaryEmbedding {
+    partial_dim: Option<usize>,
     sin: Tensor,
     cos: Tensor,
 }
 
 impl RotaryEmbedding {
     pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
-        let dim = cfg.head_dim();
-        let max_seq_len = cfg.max_position_embeddings;
-        let inv_freq: Vec<_> = (0..dim)
-            .step_by(2)
-            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
-            .collect();
-        let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
-        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(dtype)?
-            .reshape((max_seq_len, 1))?;
-        let freqs = t.matmul(&inv_freq)?;
+        let partial_dim = cfg
+            .partial_rotary_factor
+            .as_ref()
+            .map(|v| (v * cfg.head_dim() as f64) as usize);
+        let dim = partial_dim.unwrap_or(cfg.head_dim());
+        let freqs = match cfg.rope_scaling.as_ref() {
+            None => {
+                let max_seq_len = cfg.max_position_embeddings;
+                let inv_freq: Vec<_> = (0..dim)
+                    .step_by(2)
+                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+                    .collect();
+                let inv_freq = Tensor::from_vec(inv_freq, (1, ()), dev)?.to_dtype(dtype)?;
+                let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+                    .to_dtype(dtype)?
+                    .reshape((max_seq_len, 1))?;
+                t.matmul(&inv_freq)?
+            }
+            Some(rope_scaling) => {
+                let inv_freq_s: Vec<_> = (0..dim)
+                    .step_by(2)
+                    .zip(rope_scaling.short_factor.iter())
+                    .map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+                    .collect();
+                let inv_freq_s = Tensor::from_vec(inv_freq_s, (1, ()), dev)?.to_dtype(dtype)?;
+                let max_seq_len = cfg.max_position_embeddings;
+                match cfg.original_max_position_embeddings {
+                    None => {
+                        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+                            .to_dtype(dtype)?
+                            .reshape((max_seq_len, 1))?;
+                        t.matmul(&inv_freq_s)?
+                    }
+                    Some(original_max_seq_len) => {
+                        let t_s = Tensor::arange(0u32, original_max_seq_len as u32, dev)?
+                            .to_dtype(dtype)?
+                            .reshape((original_max_seq_len, 1))?;
+                        let freq_s = t_s.matmul(&inv_freq_s)?;
+                        let inv_freq_l: Vec<_> = (0..dim)
+                            .step_by(2)
+                            .zip(rope_scaling.long_factor.iter())
+                            .map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+                            .collect();
+                        let inv_freq_l =
+                            Tensor::from_vec(inv_freq_l, (1, ()), dev)?.to_dtype(dtype)?;
+                        let t_l =
+                            Tensor::arange(original_max_seq_len as u32, max_seq_len as u32, dev)?
+                                .to_dtype(dtype)?
+                                .reshape(((), 1))?;
+                        let freq_l = t_l.matmul(&inv_freq_l)?;
+                        Tensor::cat(&[&freq_s, &freq_l], 0)?
+                    }
+                }
+            }
+        };
         Ok(Self {
+            partial_dim,
             sin: freqs.sin()?,
             cos: freqs.cos()?,
         })
     }
 
+    fn rope(&self, xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
+        let x = match self.partial_dim {
+            None => candle_nn::rotary_emb::rope(&xs.contiguous()?, cos, sin)?,
+            Some(dim) => {
+                let xs_rot = xs.i((.., .., .., ..dim))?.contiguous()?;
+                let xs_pass = xs.i((.., .., .., dim..))?;
+                let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, cos, sin)?;
+                Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?.contiguous()?
+            }
+        };
+        Ok(x)
+    }
+
     pub fn apply_rotary_emb_qkv(
         &self,
         q: &Tensor,
@@ -83,8 +159,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
-        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
+        let q_embed = self.rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = self.rope(&k.contiguous()?, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
@@ -292,7 +368,11 @@ impl Model {
             layers.push(layer)
         }
         let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
-        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+        let lm_head = if cfg.tie_word_embeddings {
+            Linear::from_weights(embed_tokens.embeddings().clone(), None)
+        } else {
+            linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+        };
         Ok(Self {
             embed_tokens,
             layers,
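A quick way to sanity-check the new `rope_scaling` deserialization is to feed it a JSON fragment shaped like the `rope_scaling` block of a Phi-4-mini style config.json. A minimal sketch; the factor values are illustrative placeholders, far shorter than a real config's lists:

```rust
// Hypothetical round-trip check for the new `RopeScaling` struct; the
// factor values are made up, not Phi-4-mini's real configuration.
fn main() -> serde_json::Result<()> {
    let json = r#"{
        "short_factor": [1.0, 1.0, 1.0, 1.0],
        "long_factor": [1.1, 1.2, 1.3, 1.4],
        "type": "longrope"
    }"#;
    let rs: candle_transformers::models::phi3::RopeScaling = serde_json::from_str(json)?;
    // The JSON key `type` maps onto `type_` via #[serde(rename = "type")].
    println!("{rs:?}");
    Ok(())
}
```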
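The long-rope branch above builds one position/frequency table in two pieces: positions below `original_max_position_embeddings` use the short factors, positions from there up to `max_position_embeddings` use the long factors, and the two angle tables are concatenated along the position axis. A dependency-free sketch of that schedule, with toy `dim`, `theta`, and factor values rather than the model's real ones:

```rust
// Scaled inverse frequencies, mirroring the diff's
// `f / rope_theta.powf(i as f64 / dim as f64)` computation.
fn inv_freqs(dim: usize, theta: f64, factors: &[f32]) -> Vec<f32> {
    (0..dim)
        .step_by(2)
        .zip(factors)
        .map(|(i, &f)| f / theta.powf(i as f64 / dim as f64) as f32)
        .collect()
}

fn main() {
    // Toy sizes; a real model would use head_dim and the config's lengths.
    let (dim, theta) = (8, 10_000.0);
    let (original_max, max) = (4u32, 8u32);
    let short = inv_freqs(dim, theta, &[1.0; 4]);
    let long = inv_freqs(dim, theta, &[4.0; 4]);
    // Row t holds the rotation angles t * inv_freq[j] for position t.
    let table: Vec<Vec<f32>> = (0..original_max)
        .map(|t| short.iter().map(|f| t as f32 * f).collect())
        .chain((original_max..max).map(|t| long.iter().map(|f| t as f32 * f).collect()))
        .collect();
    assert_eq!(table.len(), max as usize);
    println!("{table:?}");
}
```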
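`partial_rotary_factor` makes the new `rope` method rotate only the first `partial_dim` channels of each head and pass the remainder through unchanged. A sketch of that split using candle directly, under the assumption that `candle_nn::rotary_emb::rope` takes `cos`/`sin` of shape `(seq, rotated_dim / 2)`; all shapes and values here are illustrative:

```rust
use candle::{Device, IndexOp, Tensor, D};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    // Toy activations: (batch, heads, seq, head_dim) with head_dim = 8
    // and a partial_rotary_factor of 0.5, i.e. partial_dim = 4.
    let xs = Tensor::randn(0f32, 1., (1, 2, 3, 8), &dev)?;
    let (seq, partial_dim) = (3, 4);
    let angles = Tensor::randn(0f32, 1., (seq, partial_dim / 2), &dev)?;
    let (cos, sin) = (angles.cos()?, angles.sin()?);
    // Rotate the first `partial_dim` channels, pass the rest through.
    let xs_rot = xs.i((.., .., .., ..partial_dim))?.contiguous()?;
    let xs_pass = xs.i((.., .., .., partial_dim..))?;
    let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, &cos, &sin)?;
    let out = Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?;
    assert_eq!(out.dims4()?, (1, 2, 3, 8));
    Ok(())
}
```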
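With `tie_word_embeddings` set, the model skips loading a separate `lm_head` tensor and reuses the token-embedding matrix as the output projection. The same idea in plain `candle_nn` terms (using `candle_nn::Linear::new` instead of the `with_tracing` wrapper, with toy sizes):

```rust
use candle::{Device, Tensor};
use candle_nn::Module;

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let (vocab, hidden) = (8, 4);
    // One shared (vocab, hidden) matrix doubles as embedding and lm_head.
    let embeddings = Tensor::randn(0f32, 1., (vocab, hidden), &dev)?;
    let lm_head = candle_nn::Linear::new(embeddings.clone(), None);
    let h = Tensor::randn(0f32, 1., (1, hidden), &dev)?;
    // Linear computes h @ W^T, so logits come out as (1, vocab).
    let logits = lm_head.forward(&h)?;
    assert_eq!(logits.dims2()?, (1, vocab));
    Ok(())
}
```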