Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Quantized model small tweaks (#1290)
* Support the shape op in ONNX.
* Share the axis normalization bits.
* Add some limited support for gather.
* Unsqueeze.
* Comparison with broadcasting.
* Add Not + handle i32.
* Tweaks for the quantized model.
@@ -12,6 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
 use candle::{Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
 
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_transformers::models::quantized_llama as model;
 use model::ModelWeights;
 
@@ -48,8 +49,10 @@ enum Which {
     Mistral7b,
     #[value(name = "7b-mistral-instruct")]
     Mistral7bInstruct,
-    #[value(name = "7b-zephyr")]
-    Zephyr7b,
+    #[value(name = "7b-zephyr-a")]
+    Zephyr7bAlpha,
+    #[value(name = "7b-zephyr-b")]
+    Zephyr7bBeta,
 }
 
 impl Which {
@@ -65,7 +68,27 @@ impl Which {
             | Self::L13bCode
             | Self::L34bCode => false,
             // Zephyr is a fine tuned version of mistral and should be treated in the same way.
-            Self::Zephyr7b | Self::Mistral7b | Self::Mistral7bInstruct => true,
+            Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct => true,
+        }
+    }
+
+    fn is_zephyr(&self) -> bool {
+        match self {
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct => false,
+            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
     }
 }
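Note: the two predicate helpers keep model-family dispatch in one place. `is_mistral` already treats the Zephyr fine-tunes as Mistral models, while the new `is_zephyr` singles them out for the chat template used in a later hunk. A minimal sketch of the pattern, with the enum trimmed down to a few variants for illustration (the real file carries the full list and clap's `#[value(...)]` attributes):

// Sketch only: a trimmed-down version of the Which enum and its predicates.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Which {
    L7b,
    Mistral7b,
    Zephyr7bAlpha,
    Zephyr7bBeta,
}

impl Which {
    // Zephyr is a fine-tuned Mistral, so it shares the Mistral handling...
    fn is_mistral(&self) -> bool {
        match self {
            Self::L7b => false,
            Self::Mistral7b | Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
        }
    }

    // ...but needs its own chat template, hence a separate predicate.
    fn is_zephyr(&self) -> bool {
        matches!(self, Self::Zephyr7bAlpha | Self::Zephyr7bBeta)
    }
}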
@@ -84,7 +107,7 @@ struct Args {
     prompt: Option<String>,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(short = 'n', long, default_value_t = 100)]
+    #[arg(short = 'n', long, default_value_t = 1000)]
     sample_len: usize,
 
     /// The tokenizer config in json format.
@@ -177,10 +200,13 @@ impl Args {
                         "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
                         "mistral-7b-instruct-v0.1.Q4_K_S.gguf",
                     ),
-                    Which::Zephyr7b => (
+                    Which::Zephyr7bAlpha => (
                         "TheBloke/zephyr-7B-alpha-GGUF",
                         "zephyr-7b-alpha.Q4_K_M.gguf",
                     ),
+                    Which::Zephyr7bBeta => {
+                        ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
+                    }
                 };
                 let api = hf_hub::api::sync::Api::new()?;
                 let api = api.model(repo.to_string());
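Note: the `(repo, filename)` pair feeds the `hf_hub` sync API shown at the bottom of the hunk. A self-contained sketch of that download path, using the repo/file names from this hunk (the function name `fetch_zephyr_beta` is hypothetical, not part of the PR):

// Sketch of the hf_hub download path; assumes the hf-hub crate's sync API,
// as used in the surrounding code.
fn fetch_zephyr_beta() -> anyhow::Result<std::path::PathBuf> {
    let (repo, filename) = ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf");
    let api = hf_hub::api::sync::Api::new()?;
    let api = api.model(repo.to_string());
    // `get` downloads the file into the local HF cache (or reuses a cached
    // copy) and returns the on-disk path.
    let path = api.get(filename)?;
    Ok(path)
}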
@@ -191,31 +217,6 @@ impl Args {
     }
 }
 
-fn print_token(next_token: u32, tokenizer: &Tokenizer) {
-    // Extracting the last token as a string is complicated, here we just apply some simple
-    // heuristics as it seems to work well enough for this example. See the following for more
-    // details:
-    // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-    if let Some(text) = tokenizer.id_to_token(next_token) {
-        let text = text.replace('▁', " ");
-        let ascii = text
-            .strip_prefix("<0x")
-            .and_then(|t| t.strip_suffix('>'))
-            .and_then(|t| u8::from_str_radix(t, 16).ok());
-        match ascii {
-            None => print!("{text}"),
-            Some(ascii) => {
-                if let Some(chr) = char::from_u32(ascii as u32) {
-                    if chr.is_ascii() {
-                        print!("{chr}")
-                    }
-                }
-            }
-        }
-        let _ = std::io::stdout().flush();
-    }
-}
-
 fn format_size(size_in_bytes: usize) -> String {
     if size_in_bytes < 1_000 {
         format!("{}B", size_in_bytes)
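Note: the deleted `print_token` handled two SentencePiece artifacts by hand: the `▁` word-boundary marker and byte-fallback tokens spelled `<0x..>`. `TokenOutputStream` (imported in the first hunk) packages the same idea behind a buffered decode, so partial bytes of a multi-byte UTF-8 character are held back until they form valid text. For reference, the byte-token parsing dropped here, restated as a standalone helper:

// The heuristic removed above, as a self-contained helper: returns the byte
// encoded by a SentencePiece byte-fallback token such as "<0x0A>", if any.
fn byte_fallback(token: &str) -> Option<u8> {
    token
        .strip_prefix("<0x")
        .and_then(|t| t.strip_suffix('>'))
        .and_then(|t| u8::from_str_radix(t, 16).ok())
}

For example, `byte_fallback("<0x0A>")` yields `Some(0x0a)` (a newline byte), while an ordinary token yields `None`.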
@@ -304,7 +305,8 @@ fn main() -> anyhow::Result<()> {
                 | Which::L34bCode => 1,
                 Which::Mistral7b
                 | Which::Mistral7bInstruct
-                | Which::Zephyr7b
+                | Which::Zephyr7bAlpha
+                | Which::Zephyr7bBeta
                 | Which::L70b
                 | Which::L70bChat => 8,
             };
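Note: the `=> 8` arm that `Zephyr7bBeta` joins is the grouped-query-attention value used when loading ggml-format weights; the Zephyr fine-tunes naturally share it with Mistral-7B, which, like the Llama-2 70B models, uses 8 key/value heads, while the smaller Llama variants in the `=> 1` arm use plain multi-head attention.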
@@ -314,6 +316,7 @@ fn main() -> anyhow::Result<()> {
     println!("model built");
 
     let tokenizer = args.tokenizer()?;
+    let mut tos = TokenOutputStream::new(tokenizer);
     let prompt = match args.prompt.as_deref() {
         Some("chat") => Prompt::Chat,
         Some("interactive") => Prompt::Interactive,
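Note: `TokenOutputStream` takes ownership of the tokenizer (hence `tos.tokenizer()` wherever direct access is still needed below) and exposes an incremental decode; the two print sites in later hunks switch to it. A minimal sketch of the usage pattern, assuming the API exactly as it appears in this diff (`next_token` per sampled id, `decode_rest` once at the end); `sample_next` is a hypothetical stand-in for the logits-processor sampling done in `main`:

// Sketch of the streaming-decode loop this diff moves to.
use std::io::Write;

fn stream_tokens(
    mut tos: candle_examples::token_output_stream::TokenOutputStream,
    mut sample_next: impl FnMut() -> anyhow::Result<u32>,
    eos_token: u32,
    to_sample: usize,
) -> anyhow::Result<()> {
    for _ in 0..to_sample {
        let next_token = sample_next()?;
        // `next_token` buffers ids and only returns text once the pending
        // ids decode to a complete UTF-8 fragment.
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if next_token == eos_token {
            break;
        }
    }
    // Flush whatever is still buffered in the stream.
    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }
    std::io::stdout().flush()?;
    Ok(())
}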
@@ -336,7 +339,7 @@ fn main() -> anyhow::Result<()> {
                         prompt.pop();
                     }
                 }
-                if args.which == Which::Zephyr7b {
+                if args.which.is_zephyr() {
                     format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>")
                 } else if args.which.is_mistral() {
                     format!("[INST] {prompt} [/INST]")
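Note: with `is_zephyr` in place, prompt formatting becomes a three-way dispatch and no longer hard-codes a single Zephyr variant. A sketch of that dispatch as a free function (`format_prompt` is hypothetical; it assumes a `Which` with the two predicates from this PR, e.g. the trimmed-down enum sketched earlier):

// Hypothetical free-function version of the prompt dispatch above.
fn format_prompt(which: &Which, prompt: &str) -> String {
    if which.is_zephyr() {
        // Zephyr chat template: empty system turn, then user/assistant markers.
        format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>")
    } else if which.is_mistral() {
        // Mistral instruct template.
        format!("[INST] {prompt} [/INST]")
    } else {
        // Base Llama-style models take the raw prompt.
        prompt.to_string()
    }
}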
@@ -346,7 +349,8 @@ fn main() -> anyhow::Result<()> {
             }
         };
         print!("{}", &prompt_str);
-        let tokens = tokenizer
+        let tokens = tos
+            .tokenizer()
             .encode(prompt_str, true)
             .map_err(anyhow::Error::msg)?;
         if args.verbose_prompt {
@@ -376,11 +380,15 @@ fn main() -> anyhow::Result<()> {
         };
         let prompt_dt = start_prompt_processing.elapsed();
         all_tokens.push(next_token);
-        print_token(next_token, &tokenizer);
+        if let Some(t) = tos.next_token(next_token)? {
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
 
-        let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
+        let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
 
         let start_post_prompt = std::time::Instant::now();
+        let mut sampled = 0;
         for index in 0..to_sample {
             let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
             let logits = model.forward(&input, prompt_tokens.len() + index)?;
@@ -397,11 +405,19 @@ fn main() -> anyhow::Result<()> {
             };
             next_token = logits_processor.sample(&logits)?;
             all_tokens.push(next_token);
-            print_token(next_token, &tokenizer);
+            if let Some(t) = tos.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
+            sampled += 1;
             if next_token == eos_token {
                 break;
             };
         }
+        if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
+            print!("{rest}");
+        }
+        std::io::stdout().flush()?;
         let dt = start_post_prompt.elapsed();
         println!(
             "\n\n{:4} prompt tokens processed: {:.2} token/s",
@@ -409,9 +425,8 @@ fn main() -> anyhow::Result<()> {
             prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
         );
         println!(
-            "{:4} tokens generated: {:.2} token/s",
-            to_sample,
-            to_sample as f64 / dt.as_secs_f64(),
+            "{sampled:4} tokens generated: {:.2} token/s",
+            sampled as f64 / dt.as_secs_f64(),
         );
 
         match prompt {
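Note: this stats fix is the reason for the new `sampled` counter. The generation loop can `break` at the EOS token well before reaching `to_sample`, so dividing `to_sample` by the elapsed time inflated the reported tokens/s; dividing by `sampled` reports actual throughput. A tiny worked example with illustrative numbers (not from the PR):

fn main() {
    // Illustrative numbers only: suppose sample_len is the new default of
    // 1000, but EOS arrives after 120 tokens, 10 seconds into generation.
    let (to_sample, sampled, secs) = (1000usize, 120usize, 10.0f64);
    assert_eq!(to_sample as f64 / secs, 100.0); // old stat: inflated
    assert_eq!(sampled as f64 / secs, 12.0); // new stat: actual throughput
}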