Proper support for phi-4 (#2960)
* Add phi-4 support.
* Long-rope support.
* Get clippy to be happy.
candle-examples/examples/phi/main.rs
@@ -147,9 +147,9 @@ enum WhichModel {
     V3,
     #[value(name = "3-medium")]
     V3Medium,
-    #[value(name = "2-old")]
-    V4Mini,
+    #[value(name = "4-mini")]
+    V4Mini,
     #[value(name = "2-old")]
     V2Old,
     PuffinPhiV2,
     PhiHermes,
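The hunk above corrects the clap value name so that `--model 4-mini` selects `WhichModel::V4Mini`; previously the attribute over `V4Mini` reused the "2-old" name. A minimal standalone sketch of how such a value name maps to a variant, assuming clap's derive feature (the argument set here is illustrative, not the full phi example CLI):

use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
    #[value(name = "4-mini")]
    V4Mini,
    #[value(name = "2-old")]
    V2Old,
}

#[derive(Parser)]
struct Args {
    /// Which model variant to run.
    #[arg(long)]
    model: WhichModel,
}

fn main() {
    // "4-mini" now resolves to V4Mini; before the fix the attribute above
    // V4Mini carried the wrong name.
    let args = Args::parse_from(["phi", "--model", "4-mini"]);
    println!("{:?}", args.model);
}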
candle-transformers/src/models/phi3.rs
@@ -20,10 +20,24 @@
 // This implementation is based on:
 // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
 use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
-use candle::{DType, Device, Module, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 use std::sync::Arc;
 
+#[derive(Debug, Clone, serde::Deserialize)]
+pub enum RopeScalingType {
+    #[serde(rename = "longrope")]
+    LongRope,
+}
+
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct RopeScaling {
+    pub short_factor: Vec<f32>,
+    pub long_factor: Vec<f32>,
+    #[serde(rename = "type")]
+    pub type_: RopeScalingType,
+}
+
 // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct Config {
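The two new types deserialize the `rope_scaling` entry of a long-rope config.json. A standalone sketch of the expected JSON shape, assuming serde_json and abridged factor lists (real checkpoints carry one factor per rotary frequency, i.e. head_dim / 2 values per list):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
enum RopeScalingType {
    #[serde(rename = "longrope")]
    LongRope,
}

#[derive(Debug, Deserialize)]
struct RopeScaling {
    short_factor: Vec<f32>,
    long_factor: Vec<f32>,
    #[serde(rename = "type")]
    type_: RopeScalingType,
}

fn main() -> serde_json::Result<()> {
    // Abridged stand-in for the "rope_scaling" object in config.json.
    let raw = r#"{ "type": "longrope", "short_factor": [1.0, 1.1], "long_factor": [2.0, 2.5] }"#;
    let rs: RopeScaling = serde_json::from_str(raw)?;
    println!("{rs:?}");
    Ok(())
}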
@@ -38,8 +52,12 @@ pub struct Config {
     pub rope_theta: f64,
     pub bos_token_id: Option<u32>,
     pub eos_token_id: Option<u32>,
-    pub rope_scaling: Option<String>,
+    pub rope_scaling: Option<RopeScaling>,
     pub max_position_embeddings: usize,
+    pub original_max_position_embeddings: Option<usize>,
+    pub partial_rotary_factor: Option<f64>,
+    #[serde(default)]
+    pub tie_word_embeddings: bool,
 }
 
 impl Config {
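`partial_rotary_factor` decides how many channels of each attention head receive rotary embeddings; the rest pass through unrotated (see the `rope` helper in the next hunk). A hedged arithmetic sketch, with numbers chosen for illustration rather than read from an actual phi-4 checkpoint:

// Mirrors the partial_dim computation in RotaryEmbedding::new below.
fn partial_dim(head_dim: usize, partial_rotary_factor: Option<f64>) -> usize {
    partial_rotary_factor
        .map(|v| (v * head_dim as f64) as usize)
        .unwrap_or(head_dim)
}

fn main() {
    assert_eq!(partial_dim(128, Some(0.75)), 96); // only the first 96 channels rotate
    assert_eq!(partial_dim(128, None), 128); // no factor: full-head RoPE as before
}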
@@ -50,30 +68,88 @@ impl Config {
 
 #[derive(Debug, Clone)]
 pub struct RotaryEmbedding {
+    partial_dim: Option<usize>,
     sin: Tensor,
     cos: Tensor,
 }
 
 impl RotaryEmbedding {
     pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
-        let dim = cfg.head_dim();
-        let max_seq_len = cfg.max_position_embeddings;
-        let inv_freq: Vec<_> = (0..dim)
-            .step_by(2)
-            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
-            .collect();
-        let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
-        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-            .to_dtype(dtype)?
-            .reshape((max_seq_len, 1))?;
-        let freqs = t.matmul(&inv_freq)?;
+        let partial_dim = cfg
+            .partial_rotary_factor
+            .as_ref()
+            .map(|v| (v * cfg.head_dim() as f64) as usize);
+        let dim = partial_dim.unwrap_or(cfg.head_dim());
+        let freqs = match cfg.rope_scaling.as_ref() {
+            None => {
+                let max_seq_len = cfg.max_position_embeddings;
+                let inv_freq: Vec<_> = (0..dim)
+                    .step_by(2)
+                    .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+                    .collect();
+                let inv_freq = Tensor::from_vec(inv_freq, (1, ()), dev)?.to_dtype(dtype)?;
+                let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+                    .to_dtype(dtype)?
+                    .reshape((max_seq_len, 1))?;
+                t.matmul(&inv_freq)?
+            }
+            Some(rope_scaling) => {
+                let inv_freq_s: Vec<_> = (0..dim)
+                    .step_by(2)
+                    .zip(rope_scaling.short_factor.iter())
+                    .map(|(i, &f)| 1f32 / (f * cfg.rope_theta.powf(i as f64 / dim as f64) as f32))
+                    .collect();
+                let inv_freq_s = Tensor::from_vec(inv_freq_s, (1, ()), dev)?.to_dtype(dtype)?;
+                let max_seq_len = cfg.max_position_embeddings;
+                match cfg.original_max_position_embeddings {
+                    None => {
+                        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+                            .to_dtype(dtype)?
+                            .reshape((max_seq_len, 1))?;
+                        t.matmul(&inv_freq_s)?
+                    }
+                    Some(original_max_seq_len) => {
+                        let t_s = Tensor::arange(0u32, original_max_seq_len as u32, dev)?
+                            .to_dtype(dtype)?
+                            .reshape((original_max_seq_len, 1))?;
+                        let freq_s = t_s.matmul(&inv_freq_s)?;
+                        let inv_freq_l: Vec<_> = (0..dim)
+                            .step_by(2)
+                            .zip(rope_scaling.long_factor.iter())
+                            .map(|(i, &f)| 1f32 / (f * cfg.rope_theta.powf(i as f64 / dim as f64) as f32))
+                            .collect();
+                        let inv_freq_l =
+                            Tensor::from_vec(inv_freq_l, (1, ()), dev)?.to_dtype(dtype)?;
+                        let t_l =
+                            Tensor::arange(original_max_seq_len as u32, max_seq_len as u32, dev)?
+                                .to_dtype(dtype)?
+                                .reshape(((), 1))?;
+                        let freq_l = t_l.matmul(&inv_freq_l)?;
+                        Tensor::cat(&[&freq_s, &freq_l], 0)?
+                    }
+                }
+            }
+        };
         Ok(Self {
+            partial_dim,
             sin: freqs.sin()?,
             cos: freqs.cos()?,
         })
     }
 
+    fn rope(&self, xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
+        let x = match self.partial_dim {
+            None => candle_nn::rotary_emb::rope(&xs.contiguous()?, cos, sin)?,
+            Some(dim) => {
+                let xs_rot = xs.i((.., .., .., ..dim))?.contiguous()?;
+                let xs_pass = xs.i((.., .., .., dim..))?;
+                let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, cos, sin)?;
+                Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?.contiguous()?
+            }
+        };
+        Ok(x)
+    }
+
     pub fn apply_rotary_emb_qkv(
         &self,
         q: &Tensor,
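The long-rope branch builds the position-frequency table in two pieces: positions below original_max_position_embeddings use short_factor, the tail up to max_position_embeddings uses long_factor, and the two tables are concatenated along the position axis. A minimal sketch of the per-channel scaling on plain vectors, assuming the 1 / (factor * theta^(2i/dim)) form of the reference modeling_phi3.py:

// Hedged sketch: larger factors shrink the inverse frequency, i.e. lengthen
// the wavelength of that rotary channel.
fn scaled_inv_freq(dim: usize, rope_theta: f64, factor: &[f32]) -> Vec<f32> {
    (0..dim)
        .step_by(2)
        .zip(factor.iter())
        .map(|(i, &f)| 1f32 / (f * rope_theta.powf(i as f64 / dim as f64) as f32))
        .collect()
}

fn main() {
    let base = scaled_inv_freq(4, 10_000.0, &[1.0, 1.0]); // factor 1.0: unscaled
    let long = scaled_inv_freq(4, 10_000.0, &[4.0, 4.0]); // factor 4.0: 4x longer wavelengths
    assert!(long.iter().zip(base.iter()).all(|(l, b)| l < b));
    println!("{base:?} {long:?}");
}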
@@ -83,8 +159,8 @@ impl RotaryEmbedding {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
-        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
+        let q_embed = self.rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = self.rope(&k.contiguous()?, &cos, &sin)?;
         Ok((q_embed, k_embed))
     }
 }
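apply_rotary_emb_qkv now routes q and k through the `rope` helper so the partial-rotary split is honored. For intuition, a hedged candle snippet showing just the split-and-reassemble on a dummy tensor (shapes illustrative):

use candle::{Device, IndexOp, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Dummy (batch, heads, seq, head_dim) = (2, 3, 4, 8) activations.
    let xs = Tensor::arange(0f32, 192f32, &dev)?.reshape((2, 3, 4, 8))?;
    let dim = 6; // channels that would receive RoPE
    let xs_rot = xs.i((.., .., .., ..dim))?;
    let xs_pass = xs.i((.., .., .., dim..))?; // untouched tail channels
    let back = Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?;
    assert_eq!(back.dims(), xs.dims());
    Ok(())
}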
@@ -292,7 +368,11 @@ impl Model {
             layers.push(layer)
         }
         let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
-        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
+        let lm_head = if cfg.tie_word_embeddings {
+            Linear::from_weights(embed_tokens.embeddings().clone(), None)
+        } else {
+            linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
+        };
         Ok(Self {
             embed_tokens,
             layers,
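When the config sets tie_word_embeddings (it defaults to false via #[serde(default)]), the output head reuses the token-embedding matrix instead of loading a separate lm_head tensor. A standalone sketch of that weight sharing with a plain candle_nn::Linear and a dummy table (sizes illustrative):

use candle::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for embed_tokens.embeddings(): (vocab_size, hidden_size) = (8, 4).
    let embeddings = Tensor::randn(0f32, 1f32, (8, 4), &dev)?;
    let lm_head = Linear::new(embeddings.clone(), None); // tied: same weights, no bias
    let hidden = Tensor::randn(0f32, 1f32, (1, 4), &dev)?;
    let logits = lm_head.forward(&hidden)?;
    assert_eq!(logits.dims(), &[1, 8]); // one logit per vocab entry
    Ok(())
}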