Proper support for phi-4 (#2960)

* Add phi-4 support.

* Long-rope support.

* Get clippy to be happy.:
This commit is contained in:
Laurent Mazare
2025-05-21 10:18:33 +02:00
committed by GitHub
parent 92106c8762
commit 9a62c91643
2 changed files with 99 additions and 19 deletions

View File

@ -147,9 +147,9 @@ enum WhichModel {
V3, V3,
#[value(name = "3-medium")] #[value(name = "3-medium")]
V3Medium, V3Medium,
#[value(name = "2-old")]
V4Mini,
#[value(name = "4-mini")] #[value(name = "4-mini")]
V4Mini,
#[value(name = "2-old")]
V2Old, V2Old,
PuffinPhiV2, PuffinPhiV2,
PhiHermes, PhiHermes,

View File

@ -20,10 +20,24 @@
// This implementation is based on: // This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
use candle::{DType, Device, Module, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use std::sync::Arc; use std::sync::Arc;
#[derive(Debug, Clone, serde::Deserialize)]
pub enum RopeScalingType {
#[serde(rename = "longrope")]
LongRope,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RopeScaling {
pub short_factor: Vec<f32>,
pub long_factor: Vec<f32>,
#[serde(rename = "type")]
pub type_: RopeScalingType,
}
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
#[derive(Debug, Clone, serde::Deserialize)] #[derive(Debug, Clone, serde::Deserialize)]
pub struct Config { pub struct Config {
@ -38,8 +52,12 @@ pub struct Config {
pub rope_theta: f64, pub rope_theta: f64,
pub bos_token_id: Option<u32>, pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>, pub eos_token_id: Option<u32>,
pub rope_scaling: Option<String>, pub rope_scaling: Option<RopeScaling>,
pub max_position_embeddings: usize, pub max_position_embeddings: usize,
pub original_max_position_embeddings: Option<usize>,
pub partial_rotary_factor: Option<f64>,
#[serde(default)]
pub tie_word_embeddings: bool,
} }
impl Config { impl Config {
@ -50,30 +68,88 @@ impl Config {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RotaryEmbedding { pub struct RotaryEmbedding {
partial_dim: Option<usize>,
sin: Tensor, sin: Tensor,
cos: Tensor, cos: Tensor,
} }
impl RotaryEmbedding { impl RotaryEmbedding {
pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> { pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim(); let partial_dim = cfg
let max_seq_len = cfg.max_position_embeddings; .partial_rotary_factor
let inv_freq: Vec<_> = (0..dim) .as_ref()
.step_by(2) .map(|v| (v * cfg.head_dim() as f64) as usize);
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) let dim = partial_dim.unwrap_or(cfg.head_dim());
.collect(); let freqs = match cfg.rope_scaling.as_ref() {
let inv_freq_len = inv_freq.len(); None => {
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; let max_seq_len = cfg.max_position_embeddings;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)? let inv_freq: Vec<_> = (0..dim)
.to_dtype(dtype)? .step_by(2)
.reshape((max_seq_len, 1))?; .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
let freqs = t.matmul(&inv_freq)?; .collect();
let inv_freq = Tensor::from_vec(inv_freq, (1, ()), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
t.matmul(&inv_freq)?
}
Some(rope_scaling) => {
let inv_freq_s: Vec<_> = (0..dim)
.step_by(2)
.zip(rope_scaling.short_factor.iter())
.map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_s = Tensor::from_vec(inv_freq_s, (1, ()), dev)?.to_dtype(dtype)?;
let max_seq_len = cfg.max_position_embeddings;
match cfg.original_max_position_embeddings {
None => {
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
t.matmul(&inv_freq_s)?
}
Some(original_max_seq_len) => {
let t_s = Tensor::arange(0u32, original_max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((original_max_seq_len, 1))?;
let freq_s = t_s.matmul(&inv_freq_s)?;
let inv_freq_l: Vec<_> = (0..dim)
.step_by(2)
.zip(rope_scaling.long_factor.iter())
.map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_l =
Tensor::from_vec(inv_freq_l, (1, ()), dev)?.to_dtype(dtype)?;
let t_l =
Tensor::arange(original_max_seq_len as u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape(((), 1))?;
let freq_l = t_l.matmul(&inv_freq_l)?;
Tensor::cat(&[&freq_s, &freq_l], 0)?
}
}
}
};
Ok(Self { Ok(Self {
partial_dim,
sin: freqs.sin()?, sin: freqs.sin()?,
cos: freqs.cos()?, cos: freqs.cos()?,
}) })
} }
fn rope(&self, xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let x = match self.partial_dim {
None => candle_nn::rotary_emb::rope(&xs.contiguous()?, cos, sin)?,
Some(dim) => {
let xs_rot = xs.i((.., .., .., ..dim))?.contiguous()?;
let xs_pass = xs.i((.., .., .., dim..))?;
let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, cos, sin)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?.contiguous()?
}
};
Ok(x)
}
pub fn apply_rotary_emb_qkv( pub fn apply_rotary_emb_qkv(
&self, &self,
q: &Tensor, q: &Tensor,
@ -83,8 +159,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; let q_embed = self.rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; let k_embed = self.rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed)) Ok((q_embed, k_embed))
} }
} }
@ -292,7 +368,11 @@ impl Model {
layers.push(layer) layers.push(layer)
} }
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let lm_head = if cfg.tie_word_embeddings {
Linear::from_weights(embed_tokens.embeddings().clone(), None)
} else {
linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
};
Ok(Self { Ok(Self {
embed_tokens, embed_tokens,
layers, layers,