Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 11:37:11 +00:00
Support for phi-2. (#1429)
* Support for phi-2.
* Use the v2 naming scheme.
@@ -123,6 +123,8 @@ enum WhichModel {
     V1,
     #[value(name = "1.5")]
     V1_5,
+    #[value(name = "2")]
+    V2,
     PuffinPhiV2,
     PhiHermes,
 }
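The new variant plugs into clap's ValueEnum derive, so `--model 2` on the command line selects `WhichModel::V2`. A minimal sketch of that mechanism, assuming clap 4 with the derive feature (the `Args` struct and its default value below are illustrative, not taken from this commit):

use clap::{Parser, ValueEnum};

// Only the variants relevant to the hunk above; the real enum has more.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
    #[value(name = "1")]
    V1,
    #[value(name = "1.5")]
    V1_5,
    #[value(name = "2")]
    V2,
}

#[derive(Parser)]
struct Args {
    /// Model version; `--model 2` parses to `WhichModel::V2`.
    #[arg(long, value_enum, default_value = "1.5")]
    model: WhichModel,
}

fn main() {
    let args = Args::parse();
    println!("running {:?}", args.model);
}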
@@ -225,6 +227,7 @@ fn main() -> Result<()> {
         match args.model {
             WhichModel::V1 => "microsoft/phi-1".to_string(),
             WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
+            WhichModel::V2 => "microsoft/phi-2".to_string(),
             WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                 "lmz/candle-quantized-phi".to_string()
             }
@@ -241,7 +244,9 @@ fn main() -> Result<()> {
         match args.model {
             WhichModel::V1 => "refs/pr/2".to_string(),
             WhichModel::V1_5 => "refs/pr/18".to_string(),
-            WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
+            WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+                "main".to_string()
+            }
         }
     }
 }
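The two hunks above choose a Hugging Face Hub model id and a git revision for it: phi-1 and phi-1.5 still point at refs/pr branches for their converted weights, while phi-2 resolves "microsoft/phi-2" at "main". A hedged sketch of how such strings are typically consumed via the hf-hub crate (variable names are illustrative; API as of hf-hub 0.3):

use anyhow::Result;
use hf_hub::{api::sync::Api, Repo, RepoType};

fn main() -> Result<()> {
    // Values matching the phi-2 arms of the two matches above.
    let model_id = "microsoft/phi-2".to_string();
    let revision = "main".to_string();

    let api = Api::new()?;
    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

    // Files are downloaded on first use and cached afterwards.
    let tokenizer = repo.get("tokenizer.json")?;
    println!("tokenizer cached at {}", tokenizer.display());
    Ok(())
}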
@@ -250,27 +255,32 @@ fn main() -> Result<()> {
     let tokenizer_filename = match args.tokenizer {
         Some(file) => std::path::PathBuf::from(file),
         None => match args.model {
-            WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
+            WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
             WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                 repo.get("tokenizer-puffin-phi-v2.json")?
             }
         },
     };
-    let filename = match args.weight_file {
-        Some(weight_file) => std::path::PathBuf::from(weight_file),
+    let filenames = match args.weight_file {
+        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
         None => {
             if args.quantized {
                 match args.model {
-                    WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
-                    WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
-                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
-                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
+                    WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
+                    WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
+                    WhichModel::V2 => anyhow::bail!("phi-2 is not supported in quantized mode"),
+                    WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
+                    WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
                 }
             } else {
                 match args.model {
-                    WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
-                    WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
-                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
+                    WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
+                    WhichModel::V2 => vec![
+                        repo.get("model-00001-of-00002.safetensors")?,
+                        repo.get("model-00002-of-00002.safetensors")?,
+                    ],
+                    WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
+                    WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
                 }
             }
         }
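The rename from `filename` to `filenames` is the heart of this hunk: phi-2 ships as two safetensors shards, so the example now always carries a Vec of paths, and the single-file models are wrapped in one-element vectors. A reduced sketch of the pattern (the helper below is assumed for illustration, not part of the commit):

use std::path::PathBuf;

// Normalize "one user-supplied file" and "N hub-provided shards"
// into the same Vec<PathBuf> shape the loader expects.
fn weight_files(single: Option<&str>, shards: &[&str]) -> Vec<PathBuf> {
    match single {
        Some(file) => vec![PathBuf::from(file)],
        None => shards.iter().map(|s| PathBuf::from(s)).collect(),
    }
}

fn main() {
    let files = weight_files(
        None,
        &[
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
        ],
    );
    assert_eq!(files.len(), 2);
    println!("{files:?}");
}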
@@ -282,17 +292,21 @@ fn main() -> Result<()> {
     let config = match args.model {
         WhichModel::V1 => Config::v1(),
         WhichModel::V1_5 => Config::v1_5(),
+        WhichModel::V2 => Config::v2(),
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
         WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
     };
     let (model, device) = if args.quantized {
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
         let model = QMixFormer::new(&config, vb)?;
         (Model::Quantized(model), Device::Cpu)
     } else {
         let device = candle_examples::device(args.cpu)?;
-        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-        let model = MixFormer::new(&config, vb)?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+        let model = match args.model {
+            WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
+            _ => MixFormer::new(&config, vb)?,
+        };
         (Model::MixFormer(model), device)
     };
     println!("loaded the model in {:?}", start.elapsed());
@@ -57,6 +57,22 @@ impl Config {
         }
     }
 
+    pub fn v2() -> Self {
+        Self {
+            vocab_size: 51200,
+            n_positions: 2048,
+            n_embd: 2560,
+            n_layer: 32,
+            n_inner: None,
+            n_head: 32,
+            rotary_dim: usize::min(32, 2560 / 32),
+            activation_function: Activation::Gelu,
+            layer_norm_epsilon: 1e-5,
+            tie_word_embeddings: false,
+            pad_vocab_size_multiple: 64,
+        }
+    }
+
     // https://huggingface.co/teknium/Puffin-Phi-v2/blob/main/config.json
     pub fn puffin_phi_v2() -> Self {
         Self {
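The derived numbers are worth spelling out: with n_embd = 2560 and n_head = 32, each head is 80-dimensional, so `usize::min(32, 2560 / 32)` pins rotary_dim at 32 and rotary position embeddings touch only the first 32 of each head's 80 dimensions. A quick worked check (the vocabulary round-up is the usual formula, assumed rather than copied from candle):

fn main() {
    let (n_embd, n_head, vocab_size) = (2560usize, 32usize, 51200usize);
    let head_dim = n_embd / n_head;            // 2560 / 32 = 80
    let rotary_dim = usize::min(32, head_dim); // min(32, 80) = 32 -> partial rotary
    // pad_vocab_size_multiple is 64; 51200 = 800 * 64, so no padding is added.
    let padded_vocab = vocab_size.div_ceil(64) * 64;
    assert_eq!((head_dim, rotary_dim, padded_vocab), (80, 32, 51200));
}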
@@ -372,6 +388,24 @@ pub struct MixFormerSequentialForCausalLM {
 }
 
 impl MixFormerSequentialForCausalLM {
+    pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_head = vb.pp("lm_head");
+        let vb = vb.pp("transformer");
+        let embedding = Embedding::new(cfg, vb.pp("embd"))?;
+        let mut blocks = Vec::new();
+        for i in 0..cfg.n_layer {
+            let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?;
+            blocks.push(block)
+        }
+        let head = CausalLMHead::new(cfg, vb_head)?;
+        Ok(Self {
+            embedding,
+            blocks,
+            head,
+            span: tracing::span!(tracing::Level::TRACE, "mixformer"),
+        })
+    }
+
     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let vb = vb.pp("layers");
         let embedding = Embedding::new(cfg, vb.pp(0))?;
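The "v2 naming scheme" from the commit message is visible here: phi-2 checkpoints keep the embedding and blocks under a `transformer` prefix (`transformer.embd`, `transformer.h.{i}`) with a top-level `lm_head`, whereas the original constructor reads everything from a flat `layers` list. An illustrative listing of the tensor-name prefixes each constructor binds; the indices for the v1 blocks and head are inferred from the structure, not shown in this hunk:

// phi-1/1.5 layout: one flat "layers" list; index 0 is the embedding,
// 1..=n_layer are (inferred to be) the blocks, and the head comes last.
fn v1_prefixes(n_layer: usize) -> Vec<String> {
    let mut p = vec!["layers.0".to_string()];
    p.extend((1..=n_layer).map(|i| format!("layers.{i}")));
    p.push(format!("layers.{}", n_layer + 1));
    p
}

// phi-2 layout: embedding and blocks live under "transformer",
// while the LM head sits at the top level as "lm_head".
fn v2_prefixes(n_layer: usize) -> Vec<String> {
    let mut p = vec!["transformer.embd".to_string()];
    p.extend((0..n_layer).map(|i| format!("transformer.h.{i}")));
    p.push("lm_head".to_string());
    p
}

fn main() {
    println!("{:?}", v1_prefixes(2));
    println!("{:?}", v2_prefixes(2));
}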