From 9d3f1c8af535e32d7a4981217a0bc8ccd71c5179 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 24 Apr 2024 08:22:23 +0200
Subject: [PATCH] Add the phi-v3 quantized model. (#2118)

* Add the phi-v3 quantized model.

* Also include phi-3 in the main phi example.
---
 .../examples/quantized-phi/main.rs          | 35 +++++++++--
 candle-examples/examples/quantized/main.rs  | 59 +++++++++++--------
 2 files changed, 66 insertions(+), 28 deletions(-)

diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs
index 301c2e06..c18d25cf 100644
--- a/candle-examples/examples/quantized-phi/main.rs
+++ b/candle-examples/examples/quantized-phi/main.rs
@@ -13,8 +13,8 @@ use candle::Tensor;
 use candle_transformers::generation::{LogitsProcessor, Sampling};
 
 use candle_examples::token_output_stream::TokenOutputStream;
-use candle_transformers::models::quantized_phi as model;
-use model::ModelWeights;
+use candle_transformers::models::quantized_llama::ModelWeights as Phi3;
+use candle_transformers::models::quantized_phi::ModelWeights as Phi2;
 
 const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
 
@@ -22,6 +22,8 @@
 enum Which {
     #[value(name = "phi-2")]
     Phi2,
+    #[value(name = "phi-3")]
+    Phi3,
 }
 
 #[derive(Parser, Debug)]
@@ -92,7 +94,11 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let api = api.model("microsoft/phi-2".to_string());
+                let repo = match self.which {
+                    Which::Phi2 => "microsoft/phi-2",
+                    Which::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
+                };
+                let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
             }
         };
@@ -105,6 +111,10 @@
             None => {
                 let (repo, filename) = match self.which {
                     Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
+                    Which::Phi3 => (
+                        "microsoft/Phi-3-mini-4k-instruct-gguf",
+                        "Phi-3-mini-4k-instruct-q4.gguf",
+                    ),
                 };
                 let api = hf_hub::api::sync::Api::new()?;
                 let api = api.model(repo.to_string());
@@ -127,6 +137,20 @@ fn format_size(size_in_bytes: usize) -> String {
     }
 }
 
+enum Model {
+    Phi2(Phi2),
+    Phi3(Phi3),
+}
+
+impl Model {
+    fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
+        match self {
+            Self::Phi2(m) => m.forward(xs, pos),
+            Self::Phi3(m) => m.forward(xs, pos),
+        }
+    }
+}
+
 fn main() -> anyhow::Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -171,7 +195,10 @@
             &format_size(total_size_in_bytes),
             start.elapsed().as_secs_f32(),
         );
-        ModelWeights::from_gguf(model, &mut file, &device)?
+        match args.which {
+            Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
+            Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?),
+        }
     };
     println!("model built");
 
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 80865304..a5b830dd 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -69,6 +69,8 @@ enum Which {
     MixtralInstruct,
     #[value(name = "llama3-8b")]
     L8b,
+    #[value(name = "phi3")]
+    Phi3,
 }
 
 impl Which {
@@ -85,7 +87,8 @@ impl Which {
             | Self::L34bCode
             | Self::Leo7b
             | Self::Leo13b
-            | Self::L8b => false,
+            | Self::L8b
+            | Self::Phi3 => false,
             // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
             // same way. Starling is a fine tuned version of OpenChat.
             Self::OpenChat35
@@ -120,7 +123,8 @@ impl Which {
             | Self::Mistral7bInstructV02
             | Self::OpenChat35
             | Self::Starling7bAlpha
-            | Self::L8b => false,
+            | Self::L8b
+            | Self::Phi3 => false,
             Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
     }
@@ -145,34 +149,36 @@ impl Which {
             | Self::Mistral7bInstructV02
             | Self::Zephyr7bAlpha
             | Self::Zephyr7bBeta
-            | Self::L8b => false,
+            | Self::L8b
+            | Self::Phi3 => false,
             Self::OpenChat35 | Self::Starling7bAlpha => true,
         }
     }
 
     fn tokenizer_repo(&self) -> &'static str {
         match self {
-            Which::L7b
-            | Which::L13b
-            | Which::L70b
-            | Which::L7bChat
-            | Which::L13bChat
-            | Which::L70bChat
-            | Which::L7bCode
-            | Which::L13bCode
-            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
-            Which::Leo7b => "LeoLM/leo-hessianai-7b",
-            Which::Leo13b => "LeoLM/leo-hessianai-13b",
-            Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
-            Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
-            Which::Mistral7b
-            | Which::Mistral7bInstruct
-            | Which::Mistral7bInstructV02
-            | Which::Zephyr7bAlpha
-            | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
-            Which::OpenChat35 => "openchat/openchat_3.5",
-            Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode => "hf-internal-testing/llama-tokenizer",
+            Self::Leo7b => "LeoLM/leo-hessianai-7b",
+            Self::Leo13b => "LeoLM/leo-hessianai-13b",
+            Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
+            Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::Mistral7bInstructV02
+            | Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+            Self::OpenChat35 => "openchat/openchat_3.5",
+            Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
             Self::L8b => "meta-llama/Meta-Llama-3-8B",
+            Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
         }
     }
 }
@@ -333,6 +339,10 @@ impl Args {
                         "QuantFactory/Meta-Llama-3-8B-GGUF",
                         "Meta-Llama-3-8B.Q4_K_S.gguf",
                     ),
+                    Which::Phi3 => (
+                        "microsoft/Phi-3-mini-4k-instruct-gguf",
+                        "Phi-3-mini-4k-instruct-q4.gguf",
+                    ),
                 };
                 let api = hf_hub::api::sync::Api::new()?;
                 let api = api.model(repo.to_string());
@@ -432,7 +442,8 @@ fn main() -> anyhow::Result<()> {
                 | Which::L34bCode
                 | Which::Leo7b
                 | Which::Leo13b
-                | Which::L8b => 1,
+                | Which::L8b
+                | Which::Phi3 => 1,
                 Which::Mixtral
                 | Which::MixtralInstruct
                 | Which::Mistral7b
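
Note, not part of the patch: the phi-3 arm above reuses the existing quantized_llama weights instead of introducing a new model module, so loading reduces to reading the GGUF file and handing it to that loader. A minimal standalone sketch of that path follows; the helper name load_phi3 is hypothetical, and the gguf_file::Content::read call is assumed to be the same one the example already relies on.

    // Hypothetical helper, mirroring what the updated example does in main:
    // read the GGUF metadata, then build the quantized weights on a device.
    use candle::quantized::gguf_file;
    use candle::Device;
    use candle_transformers::models::quantized_llama::ModelWeights as Phi3;

    fn load_phi3(gguf_path: &std::path::Path, device: &Device) -> anyhow::Result<Phi3> {
        let mut file = std::fs::File::open(gguf_path)?;
        // Parse the GGUF header and tensor table first...
        let content = gguf_file::Content::read(&mut file)?;
        // ...then materialize the quantized weights on the requested device.
        Ok(Phi3::from_gguf(content, &mut file, device)?)
    }

The resulting weights expose the same forward(xs, pos) call as the phi-2 path, which is why the small Model enum dispatch in the patch is enough to cover both variants.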