From af6767220711c32268fe1524896c261ce433e612 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 16 Oct 2023 20:54:21 +0100 Subject: [PATCH] Add support for Puffin-Phi-v2. (#1110) * Add support for Puffin-Phi-v2. * Tweak the file name. * Support the config for puffin-phi-v2. * Update the readme. --- candle-examples/examples/phi/README.md | 7 ++++++ candle-examples/examples/phi/main.rs | 26 ++++++++++++++++++--- candle-transformers/src/models/mixformer.rs | 17 ++++++++++++++ 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/candle-examples/examples/phi/README.md b/candle-examples/examples/phi/README.md index bbc252d6..20a1f3aa 100644 --- a/candle-examples/examples/phi/README.md +++ b/candle-examples/examples/phi/README.md @@ -41,3 +41,10 @@ def median(arr): else: return arr[n//2] ``` + +This also supports the [Puffin Phi v2 +model](https://huggingface.co/teknium/Puffin-Phi-v2) for human interaction. +```bash +$ cargo run --example phi --release -- --prompt "USER: What would you do on a sunny day in Paris?\nASSISTANT:" --sample-len 200 --model puffin-phi-v2 --quantized +USER: What would you do on a sunny day in Paris?\nASSISTANT: On a sunny day in Paris, you could visit the Musée du Louvre to admire the famous painting "Mona Lisa" by Leonardo da Vinci. You might also want to stroll along the Champs-Élysées and enjoy the beautiful architecture of the buildings around you. Don't forget to stop by a café for a cup of coffee and to soak up the sun!" 
+``` diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 7ee99ef8..9401299a 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -123,6 +123,7 @@ enum WhichModel { V1, #[value(name = "1.5")] V1_5, + PuffinPhiV2, } #[derive(Parser, Debug)] @@ -171,6 +172,9 @@ struct Args { #[arg(long)] weight_file: Option<String>, + #[arg(long)] + tokenizer: Option<String>, + #[arg(long)] quantized: bool, @@ -220,6 +224,7 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "microsoft/phi-1".to_string(), WhichModel::V1_5 => "microsoft/phi-1_5".to_string(), + WhichModel::PuffinPhiV2 => "lmz/candle-quantized-phi".to_string(), } } } @@ -233,12 +238,19 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "refs/pr/2".to_string(), WhichModel::V1_5 => "refs/pr/18".to_string(), + WhichModel::PuffinPhiV2 => "main".to_string(), } } } }; let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); - let tokenizer_filename = repo.get("tokenizer.json")?; + let tokenizer_filename = match args.tokenizer { + Some(file) => std::path::PathBuf::from(file), + None => match args.model { + WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?, + WhichModel::PuffinPhiV2 => repo.get("tokenizer-puffin-phi-v2.json")?, + }, + }; let filename = match args.weight_file { Some(weight_file) => std::path::PathBuf::from(weight_file), None => { @@ -246,9 +258,13 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => repo.get("model-v1-q4k.gguf")?, WhichModel::V1_5 => repo.get("model-q4k.gguf")?, + WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?, } } else { - repo.get("model.safetensors")? 
+ match args.model { + WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?, + WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?, + } } } }; @@ -256,7 +272,11 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let start = std::time::Instant::now(); - let config = Config::v1_5(); + let config = match args.model { + WhichModel::V1 => Config::v1(), + WhichModel::V1_5 => Config::v1_5(), + WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), + }; let (model, device) = if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?; let model = QMixFormer::new(&config, vb)?; diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index 1ef8a984..f1fd8256 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -55,6 +55,23 @@ impl Config { pad_vocab_size_multiple: 64, } } + + // https://huggingface.co/teknium/Puffin-Phi-v2/blob/main/config.json + pub fn puffin_phi_v2() -> Self { + Self { + vocab_size: 50304, + n_positions: 2048, + n_embd: 2048, + n_layer: 24, + n_inner: None, + n_head: 32, + rotary_dim: usize::min(32, 2048 / 32), + activation_function: Activation::Gelu, + layer_norm_epsilon: 1e-5, + tie_word_embeddings: false, + pad_vocab_size_multiple: 64, + } + } } #[derive(Debug)]