mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add the phi-v3 quantized model. (#2118)
* Add the phi-v3 quantized model. * Also include phi-3 in the main phi example.
This commit is contained in:
@ -13,8 +13,8 @@ use candle::Tensor;
|
|||||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
use candle_transformers::models::quantized_phi as model;
|
use candle_transformers::models::quantized_llama::ModelWeights as Phi3;
|
||||||
use model::ModelWeights;
|
use candle_transformers::models::quantized_phi::ModelWeights as Phi2;
|
||||||
|
|
||||||
const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
|
const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
|
||||||
|
|
||||||
@ -22,6 +22,8 @@ const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. "
|
|||||||
enum Which {
|
enum Which {
|
||||||
#[value(name = "phi-2")]
|
#[value(name = "phi-2")]
|
||||||
Phi2,
|
Phi2,
|
||||||
|
#[value(name = "phi-3")]
|
||||||
|
Phi3,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -92,7 +94,11 @@ impl Args {
|
|||||||
Some(config) => std::path::PathBuf::from(config),
|
Some(config) => std::path::PathBuf::from(config),
|
||||||
None => {
|
None => {
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model("microsoft/phi-2".to_string());
|
let repo = match self.which {
|
||||||
|
Which::Phi2 => "microsoft/phi-2",
|
||||||
|
Which::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
|
||||||
|
};
|
||||||
|
let api = api.model(repo.to_string());
|
||||||
api.get("tokenizer.json")?
|
api.get("tokenizer.json")?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -105,6 +111,10 @@ impl Args {
|
|||||||
None => {
|
None => {
|
||||||
let (repo, filename) = match self.which {
|
let (repo, filename) = match self.which {
|
||||||
Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
|
Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
|
||||||
|
Which::Phi3 => (
|
||||||
|
"microsoft/Phi-3-mini-4k-instruct-gguf",
|
||||||
|
"Phi-3-mini-4k-instruct-q4.gguf",
|
||||||
|
),
|
||||||
};
|
};
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model(repo.to_string());
|
let api = api.model(repo.to_string());
|
||||||
@ -127,6 +137,20 @@ fn format_size(size_in_bytes: usize) -> String {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum Model {
|
||||||
|
Phi2(Phi2),
|
||||||
|
Phi3(Phi3),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::Phi2(m) => m.forward(xs, pos),
|
||||||
|
Self::Phi3(m) => m.forward(xs, pos),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
@ -171,7 +195,10 @@ fn main() -> anyhow::Result<()> {
|
|||||||
&format_size(total_size_in_bytes),
|
&format_size(total_size_in_bytes),
|
||||||
start.elapsed().as_secs_f32(),
|
start.elapsed().as_secs_f32(),
|
||||||
);
|
);
|
||||||
ModelWeights::from_gguf(model, &mut file, &device)?
|
match args.which {
|
||||||
|
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
|
||||||
|
Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?),
|
||||||
|
}
|
||||||
};
|
};
|
||||||
println!("model built");
|
println!("model built");
|
||||||
|
|
||||||
|
@ -69,6 +69,8 @@ enum Which {
|
|||||||
MixtralInstruct,
|
MixtralInstruct,
|
||||||
#[value(name = "llama3-8b")]
|
#[value(name = "llama3-8b")]
|
||||||
L8b,
|
L8b,
|
||||||
|
#[value(name = "phi3")]
|
||||||
|
Phi3,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Which {
|
impl Which {
|
||||||
@ -85,7 +87,8 @@ impl Which {
|
|||||||
| Self::L34bCode
|
| Self::L34bCode
|
||||||
| Self::Leo7b
|
| Self::Leo7b
|
||||||
| Self::Leo13b
|
| Self::Leo13b
|
||||||
| Self::L8b => false,
|
| Self::L8b
|
||||||
|
| Self::Phi3 => false,
|
||||||
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
||||||
// same way. Starling is a fine tuned version of OpenChat.
|
// same way. Starling is a fine tuned version of OpenChat.
|
||||||
Self::OpenChat35
|
Self::OpenChat35
|
||||||
@ -120,7 +123,8 @@ impl Which {
|
|||||||
| Self::Mistral7bInstructV02
|
| Self::Mistral7bInstructV02
|
||||||
| Self::OpenChat35
|
| Self::OpenChat35
|
||||||
| Self::Starling7bAlpha
|
| Self::Starling7bAlpha
|
||||||
| Self::L8b => false,
|
| Self::L8b
|
||||||
|
| Self::Phi3 => false,
|
||||||
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -145,34 +149,36 @@ impl Which {
|
|||||||
| Self::Mistral7bInstructV02
|
| Self::Mistral7bInstructV02
|
||||||
| Self::Zephyr7bAlpha
|
| Self::Zephyr7bAlpha
|
||||||
| Self::Zephyr7bBeta
|
| Self::Zephyr7bBeta
|
||||||
| Self::L8b => false,
|
| Self::L8b
|
||||||
|
| Self::Phi3 => false,
|
||||||
Self::OpenChat35 | Self::Starling7bAlpha => true,
|
Self::OpenChat35 | Self::Starling7bAlpha => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tokenizer_repo(&self) -> &'static str {
|
fn tokenizer_repo(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Which::L7b
|
Self::L7b
|
||||||
| Which::L13b
|
| Self::L13b
|
||||||
| Which::L70b
|
| Self::L70b
|
||||||
| Which::L7bChat
|
| Self::L7bChat
|
||||||
| Which::L13bChat
|
| Self::L13bChat
|
||||||
| Which::L70bChat
|
| Self::L70bChat
|
||||||
| Which::L7bCode
|
| Self::L7bCode
|
||||||
| Which::L13bCode
|
| Self::L13bCode
|
||||||
| Which::L34bCode => "hf-internal-testing/llama-tokenizer",
|
| Self::L34bCode => "hf-internal-testing/llama-tokenizer",
|
||||||
Which::Leo7b => "LeoLM/leo-hessianai-7b",
|
Self::Leo7b => "LeoLM/leo-hessianai-7b",
|
||||||
Which::Leo13b => "LeoLM/leo-hessianai-13b",
|
Self::Leo13b => "LeoLM/leo-hessianai-13b",
|
||||||
Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
|
Self::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
|
||||||
Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
Self::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||||
Which::Mistral7b
|
Self::Mistral7b
|
||||||
| Which::Mistral7bInstruct
|
| Self::Mistral7bInstruct
|
||||||
| Which::Mistral7bInstructV02
|
| Self::Mistral7bInstructV02
|
||||||
| Which::Zephyr7bAlpha
|
| Self::Zephyr7bAlpha
|
||||||
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
|
| Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
|
||||||
Which::OpenChat35 => "openchat/openchat_3.5",
|
Self::OpenChat35 => "openchat/openchat_3.5",
|
||||||
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
|
Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
|
||||||
Self::L8b => "meta-llama/Meta-Llama-3-8B",
|
Self::L8b => "meta-llama/Meta-Llama-3-8B",
|
||||||
|
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -333,6 +339,10 @@ impl Args {
|
|||||||
"QuantFactory/Meta-Llama-3-8B-GGUF",
|
"QuantFactory/Meta-Llama-3-8B-GGUF",
|
||||||
"Meta-Llama-3-8B.Q4_K_S.gguf",
|
"Meta-Llama-3-8B.Q4_K_S.gguf",
|
||||||
),
|
),
|
||||||
|
Which::Phi3 => (
|
||||||
|
"microsoft/Phi-3-mini-4k-instruct-gguf",
|
||||||
|
"Phi-3-mini-4k-instruct-q4.gguf",
|
||||||
|
),
|
||||||
};
|
};
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model(repo.to_string());
|
let api = api.model(repo.to_string());
|
||||||
@ -432,7 +442,8 @@ fn main() -> anyhow::Result<()> {
|
|||||||
| Which::L34bCode
|
| Which::L34bCode
|
||||||
| Which::Leo7b
|
| Which::Leo7b
|
||||||
| Which::Leo13b
|
| Which::Leo13b
|
||||||
| Which::L8b => 1,
|
| Which::L8b
|
||||||
|
| Which::Phi3 => 1,
|
||||||
Which::Mixtral
|
Which::Mixtral
|
||||||
| Which::MixtralInstruct
|
| Which::MixtralInstruct
|
||||||
| Which::Mistral7b
|
| Which::Mistral7b
|
||||||
|
Reference in New Issue
Block a user