Proper support for phi-4 (#2960)

* Add phi-4 support.

* Long-rope support.

* Get clippy to be happy.:
This commit is contained in:
Laurent Mazare
2025-05-21 10:18:33 +02:00
committed by GitHub
parent 92106c8762
commit 9a62c91643
2 changed files with 99 additions and 19 deletions

View File

@ -147,9 +147,9 @@ enum WhichModel {
V3,
#[value(name = "3-medium")]
V3Medium,
#[value(name = "2-old")]
V4Mini,
#[value(name = "4-mini")]
V4Mini,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2,
PhiHermes,

View File

@ -20,10 +20,24 @@
// This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
use std::sync::Arc;
#[derive(Debug, Clone, serde::Deserialize)]
pub enum RopeScalingType {
#[serde(rename = "longrope")]
LongRope,
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RopeScaling {
pub short_factor: Vec<f32>,
pub long_factor: Vec<f32>,
#[serde(rename = "type")]
pub type_: RopeScalingType,
}
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
@ -38,8 +52,12 @@ pub struct Config {
pub rope_theta: f64,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
pub rope_scaling: Option<String>,
pub rope_scaling: Option<RopeScaling>,
pub max_position_embeddings: usize,
pub original_max_position_embeddings: Option<usize>,
pub partial_rotary_factor: Option<f64>,
#[serde(default)]
pub tie_word_embeddings: bool,
}
impl Config {
@ -50,30 +68,88 @@ impl Config {
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
partial_dim: Option<usize>,
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim();
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let partial_dim = cfg
.partial_rotary_factor
.as_ref()
.map(|v| (v * cfg.head_dim() as f64) as usize);
let dim = partial_dim.unwrap_or(cfg.head_dim());
let freqs = match cfg.rope_scaling.as_ref() {
None => {
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq = Tensor::from_vec(inv_freq, (1, ()), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
t.matmul(&inv_freq)?
}
Some(rope_scaling) => {
let inv_freq_s: Vec<_> = (0..dim)
.step_by(2)
.zip(rope_scaling.short_factor.iter())
.map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_s = Tensor::from_vec(inv_freq_s, (1, ()), dev)?.to_dtype(dtype)?;
let max_seq_len = cfg.max_position_embeddings;
match cfg.original_max_position_embeddings {
None => {
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
t.matmul(&inv_freq_s)?
}
Some(original_max_seq_len) => {
let t_s = Tensor::arange(0u32, original_max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((original_max_seq_len, 1))?;
let freq_s = t_s.matmul(&inv_freq_s)?;
let inv_freq_l: Vec<_> = (0..dim)
.step_by(2)
.zip(rope_scaling.long_factor.iter())
.map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_l =
Tensor::from_vec(inv_freq_l, (1, ()), dev)?.to_dtype(dtype)?;
let t_l =
Tensor::arange(original_max_seq_len as u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape(((), 1))?;
let freq_l = t_l.matmul(&inv_freq_l)?;
Tensor::cat(&[&freq_s, &freq_l], 0)?
}
}
}
};
Ok(Self {
partial_dim,
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn rope(&self, xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let x = match self.partial_dim {
None => candle_nn::rotary_emb::rope(&xs.contiguous()?, cos, sin)?,
Some(dim) => {
let xs_rot = xs.i((.., .., .., ..dim))?.contiguous()?;
let xs_pass = xs.i((.., .., .., dim..))?;
let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, cos, sin)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?.contiguous()?
}
};
Ok(x)
}
pub fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
@ -83,8 +159,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
let q_embed = self.rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = self.rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
@ -292,7 +368,11 @@ impl Model {
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let lm_head = if cfg.tie_word_embeddings {
Linear::from_weights(embed_tokens.embeddings().clone(), None)
} else {
linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
};
Ok(Self {
embed_tokens,
layers,