From 70d06ab4b0065576e779a628fc024ef46003cdbc Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 27 Oct 2023 05:57:08 +0100
Subject: [PATCH] Add support for the phi-hermes finetuned model. (#1192)

---
 candle-examples/examples/phi/main.rs        | 14 +++++++++++---
 candle-transformers/src/models/mixformer.rs | 17 +++++++++++++++++
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index 9401299a..720a4441 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -124,6 +124,7 @@ enum WhichModel {
     #[value(name = "1.5")]
     V1_5,
     PuffinPhiV2,
+    PhiHermes,
 }
 
 #[derive(Parser, Debug)]
@@ -224,7 +225,9 @@ fn main() -> Result<()> {
                     match args.model {
                         WhichModel::V1 => "microsoft/phi-1".to_string(),
                         WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
-                        WhichModel::PuffinPhiV2 => "lmz/candle-quantized-phi".to_string(),
+                        WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+                            "lmz/candle-quantized-phi".to_string()
+                        }
                     }
                 }
             }
@@ -238,7 +241,7 @@ fn main() -> Result<()> {
                     match args.model {
                         WhichModel::V1 => "refs/pr/2".to_string(),
                         WhichModel::V1_5 => "refs/pr/18".to_string(),
-                        WhichModel::PuffinPhiV2 => "main".to_string(),
+                        WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
                     }
                 }
             }
@@ -248,7 +251,9 @@ fn main() -> Result<()> {
         Some(file) => std::path::PathBuf::from(file),
         None => match args.model {
             WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
-            WhichModel::PuffinPhiV2 => repo.get("tokenizer-puffin-phi-v2.json")?,
+            WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
+                repo.get("tokenizer-puffin-phi-v2.json")?
+            }
         },
     };
     let filename = match args.weight_file {
@@ -259,11 +264,13 @@ fn main() -> Result<()> {
                     WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
                     WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
                     WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
+                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
                 }
             } else {
                 match args.model {
                     WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
                     WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
+                    WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
                 }
             }
         }
@@ -276,6 +283,7 @@ fn main() -> Result<()> {
         WhichModel::V1 => Config::v1(),
         WhichModel::V1_5 => Config::v1_5(),
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
+        WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
     };
     let (model, device) = if args.quantized {
         let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index 33aefbfe..e822ca14 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -73,6 +73,23 @@ impl Config {
             pad_vocab_size_multiple: 64,
         }
     }
+
+    // https://huggingface.co/teknium/Phi-Hermes-1.3B/blob/main/config.json
+    pub fn phi_hermes_1_3b() -> Self {
+        Self {
+            vocab_size: 50304,
+            n_positions: 2048,
+            n_embd: 2048,
+            n_layer: 24,
+            n_inner: None,
+            n_head: 32,
+            rotary_dim: usize::min(32, 2048 / 32),
+            activation_function: Activation::NewGelu,
+            layer_norm_epsilon: 1e-5,
+            tie_word_embeddings: false,
+            pad_vocab_size_multiple: 64,
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
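
Usage sketch (not part of the applied patch): once the new variant is in the
clap-derived enum, it should be selectable from the example's CLI. The value
name `phi-hermes` is assumed from clap's default kebab-case ValueEnum naming
for `PhiHermes` (V1 and V1_5 override theirs explicitly), the `--prompt` flag
is assumed from the existing phi example CLI, and the prompt string is only
illustrative:

    cargo run --example phi --release -- \
        --model phi-hermes --quantized \
        --prompt "def print_prime(n):"

With `--quantized`, the weight file resolves to model-phi-hermes-1_3B-q4k.gguf
from the lmz/candle-quantized-phi repo, as wired up in the patch.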