Add the phi-v3 quantized model. (#2118)

* Add the phi-v3 quantized model.

* Also include phi-3 in the main phi example.
This commit is contained in:
Laurent Mazare
2024-04-24 08:22:23 +02:00
committed by GitHub
parent 7211009179
commit 9d3f1c8af5
2 changed files with 66 additions and 28 deletions

View File

@@ -13,8 +13,8 @@ use candle::Tensor;
 use candle_transformers::generation::{LogitsProcessor, Sampling};
 use candle_examples::token_output_stream::TokenOutputStream;
-use candle_transformers::models::quantized_phi as model;
-use model::ModelWeights;
+use candle_transformers::models::quantized_llama::ModelWeights as Phi3;
+use candle_transformers::models::quantized_phi::ModelWeights as Phi2;
 const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
@@ -22,6 +22,8 @@ const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. "
 enum Which {
     #[value(name = "phi-2")]
     Phi2,
+    #[value(name = "phi-3")]
+    Phi3,
 }
 #[derive(Parser, Debug)]
@@ -92,7 +94,11 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let api = api.model("microsoft/phi-2".to_string());
+                let repo = match self.which {
+                    Which::Phi2 => "microsoft/phi-2",
+                    Which::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
+                };
+                let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
             }
         };
@@ -105,6 +111,10 @@ impl Args {
             None => {
                 let (repo, filename) = match self.which {
                     Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
+                    Which::Phi3 => (
+                        "microsoft/Phi-3-mini-4k-instruct-gguf",
+                        "Phi-3-mini-4k-instruct-q4.gguf",
+                    ),
                 };
                 let api = hf_hub::api::sync::Api::new()?;
                 let api = api.model(repo.to_string());
@@ -127,6 +137,20 @@ fn format_size(size_in_bytes: usize) -> String {
     }
 }
+enum Model {
+    Phi2(Phi2),
+    Phi3(Phi3),
+}
+
+impl Model {
+    fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
+        match self {
+            Self::Phi2(m) => m.forward(xs, pos),
+            Self::Phi3(m) => m.forward(xs, pos),
+        }
+    }
+}
 fn main() -> anyhow::Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -171,7 +195,10 @@ fn main() -> anyhow::Result<()> {
             &format_size(total_size_in_bytes),
             start.elapsed().as_secs_f32(),
         );
-        ModelWeights::from_gguf(model, &mut file, &device)?
+        match args.which {
+            Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
+            Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?),
+        }
     };
     println!("model built");

View File

@@ -69,6 +69,8 @@ enum Which {
     MixtralInstruct,
     #[value(name = "llama3-8b")]
     L8b,
+    #[value(name = "phi3")]
+    Phi3,
 }
 impl Which {
@@ -85,7 +87,8 @@ impl Which {
             | Self::L34bCode
             | Self::Leo7b
             | Self::Leo13b
-            | Self::L8b => false,
+            | Self::L8b
+            | Self::Phi3 => false,
             // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
             // same way. Starling is a fine tuned version of OpenChat.
             Self::OpenChat35
@@ -120,7 +123,8 @@ impl Which {
             | Self::Mistral7bInstructV02
             | Self::OpenChat35
             | Self::Starling7bAlpha
-            | Self::L8b => false,
+            | Self::L8b
+            | Self::Phi3 => false,
             Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
     }
@@ -145,34 +149,36 @@ impl Which {
             | Self::Mistral7bInstructV02
             | Self::Zephyr7bAlpha
             | Self::Zephyr7bBeta
-            | Self::L8b => false,
+            | Self::L8b
+            | Self::Phi3 => false,
             Self::OpenChat35 | Self::Starling7bAlpha => true,
         }
     }

     fn tokenizer_repo(&self) -> &'static str {
         match self {
-            Which::L7b
-            | Which::L13b
-            | Which::L70b
-            | Which::L7bChat
-            | Which::L13bChat
-            | Which::L70bChat
-            | Which::L7bCode
-            | Which::L13bCode
-            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
-            Which::Leo7b => "LeoLM/leo-hessianai-7b",
-            Which::Leo13b => "LeoLM/leo-hessianai-13b",
-            Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
-            Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
-            Which::Mistral7b
-            | Which::Mistral7bInstruct
-            | Which::Mistral7bInstructV02
-            | Which::Zephyr7bAlpha
-            | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
-            Which::OpenChat35 => "openchat/openchat_3.5",
-            Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode => "hf-internal-testing/llama-tokenizer",
+            Self::Leo7b => "LeoLM/leo-hessianai-7b",
+            Self::Leo13b => "LeoLM/leo-hessianai-13b",
+            Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
+            Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
+            Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::Mistral7bInstructV02
+            | Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+            Self::OpenChat35 => "openchat/openchat_3.5",
+            Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
             Self::L8b => "meta-llama/Meta-Llama-3-8B",
+            Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
         }
     }
 }
@@ -333,6 +339,10 @@ impl Args {
                         "QuantFactory/Meta-Llama-3-8B-GGUF",
                         "Meta-Llama-3-8B.Q4_K_S.gguf",
                     ),
+                    Which::Phi3 => (
+                        "microsoft/Phi-3-mini-4k-instruct-gguf",
+                        "Phi-3-mini-4k-instruct-q4.gguf",
+                    ),
                 };
                 let api = hf_hub::api::sync::Api::new()?;
                 let api = api.model(repo.to_string());
@@ -432,7 +442,8 @@ fn main() -> anyhow::Result<()> {
         | Which::L34bCode
         | Which::Leo7b
        | Which::Leo13b
-        | Which::L8b => 1,
+        | Which::L8b
+        | Which::Phi3 => 1,
         Which::Mixtral
         | Which::MixtralInstruct
         | Which::Mistral7b