Quantized model small tweaks (#1290)

* Support the shape op in ONNX.

* Share the axis normalization bits.

* Add some limited support for gather.

* Unsqueeze.

* Comparison with broadcasting.

* Add Not + handle i32.

* Tweaks for the quantized model.
Author: Laurent Mazare
Date: 2023-11-07 21:21:37 +01:00
Committed by: GitHub
Parent: c912d24570
Commit: d4a45c936a
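
The bullet points above mostly describe new ONNX op support (Shape, Gather, Unsqueeze, broadcasted comparisons, Not, i32 handling), while the diff below only contains the quantized-example tweaks. As a rough, non-authoritative illustration of the core tensor calls those ops lean on, here is a small sketch using candle's Tensor API directly; it is not the candle-onnx implementation, just operations of the same flavour (shape inspection, torch-style element gather, dimension insertion):

```rust
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;

    // Shape: read the dimensions of a tensor.
    println!("{:?}", t.dims()); // [2, 3]

    // Gather along dim 1: pick column 2 then column 0 in each row.
    let idx = Tensor::new(&[[2u32, 0], [2, 0]], &dev)?;
    let gathered = t.gather(&idx, 1)?;
    println!("{:?}", gathered.to_vec2::<f32>()?); // [[3.0, 1.0], [6.0, 4.0]]

    // Unsqueeze: insert a new leading dimension of size 1.
    let unsqueezed = t.unsqueeze(0)?;
    println!("{:?}", unsqueezed.dims()); // [1, 2, 3]
    Ok(())
}
```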


@@ -12,6 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
 use candle::{Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_transformers::models::quantized_llama as model;
 use model::ModelWeights;
@@ -48,8 +49,10 @@ enum Which {
     Mistral7b,
     #[value(name = "7b-mistral-instruct")]
     Mistral7bInstruct,
-    #[value(name = "7b-zephyr")]
-    Zephyr7b,
+    #[value(name = "7b-zephyr-a")]
+    Zephyr7bAlpha,
+    #[value(name = "7b-zephyr-b")]
+    Zephyr7bBeta,
 }
 
 impl Which {
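
Since the renamed and new variants surface only through their CLI names, a minimal, self-contained sketch of the clap wiring may be useful; only the #[value(name = ...)] strings come from the diff, the struct and field names below are illustrative:

```rust
use clap::{Parser, ValueEnum};

// Illustrative stand-ins for the real `Which` enum and `Args` struct.
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum)]
enum WhichSketch {
    #[value(name = "7b-zephyr-a")]
    Zephyr7bAlpha,
    #[value(name = "7b-zephyr-b")]
    Zephyr7bBeta,
}

#[derive(Parser, Debug)]
struct ArgsSketch {
    /// The model/fine-tune to run.
    #[arg(long)]
    which: WhichSketch,
}

fn main() {
    // The #[value(name = ...)] attribute is what maps the CLI string to the variant.
    let args = ArgsSketch::parse_from(["demo", "--which", "7b-zephyr-b"]);
    assert_eq!(args.which, WhichSketch::Zephyr7bBeta);
    println!("selected: {:?}", args.which);
}
```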
@@ -65,7 +68,27 @@ impl Which {
             | Self::L13bCode
             | Self::L34bCode => false,
             // Zephyr is a fine tuned version of mistral and should be treated in the same way.
-            Self::Zephyr7b | Self::Mistral7b | Self::Mistral7bInstruct => true,
+            Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct => true,
+        }
+    }
+
+    fn is_zephyr(&self) -> bool {
+        match self {
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct => false,
+            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
     }
 }
@@ -84,7 +107,7 @@ struct Args {
     prompt: Option<String>,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(short = 'n', long, default_value_t = 100)]
+    #[arg(short = 'n', long, default_value_t = 1000)]
     sample_len: usize,
 
     /// The tokenizer config in json format.
@@ -177,10 +200,13 @@ impl Args {
                        "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
                        "mistral-7b-instruct-v0.1.Q4_K_S.gguf",
                    ),
-                    Which::Zephyr7b => (
+                    Which::Zephyr7bAlpha => (
                        "TheBloke/zephyr-7B-alpha-GGUF",
                        "zephyr-7b-alpha.Q4_K_M.gguf",
                    ),
+                    Which::Zephyr7bBeta => {
+                        ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
+                    }
                };
                let api = hf_hub::api::sync::Api::new()?;
                let api = api.model(repo.to_string());
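
For reference, the new beta entry resolves to a hub download along the lines of the sketch below; Api::new and api.model are visible in this hunk, while the final get call is assumed from the hf_hub sync API that candle examples rely on:

```rust
// Hedged sketch: fetch the zephyr beta GGUF weights named in this hunk.
fn fetch_zephyr_beta() -> anyhow::Result<std::path::PathBuf> {
    let api = hf_hub::api::sync::Api::new()?;
    let repo = api.model("TheBloke/zephyr-7B-beta-GGUF".to_string());
    // `get` downloads (or reuses a cached copy of) the file and returns its local path.
    let path = repo.get("zephyr-7b-beta.Q4_K_M.gguf")?;
    Ok(path)
}
```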
@@ -191,31 +217,6 @@ impl Args {
     }
 }
 
-fn print_token(next_token: u32, tokenizer: &Tokenizer) {
-    // Extracting the last token as a string is complicated, here we just apply some simple
-    // heuristics as it seems to work well enough for this example. See the following for more
-    // details:
-    // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-    if let Some(text) = tokenizer.id_to_token(next_token) {
-        let text = text.replace('▁', " ");
-        let ascii = text
-            .strip_prefix("<0x")
-            .and_then(|t| t.strip_suffix('>'))
-            .and_then(|t| u8::from_str_radix(t, 16).ok());
-        match ascii {
-            None => print!("{text}"),
-            Some(ascii) => {
-                if let Some(chr) = char::from_u32(ascii as u32) {
-                    if chr.is_ascii() {
-                        print!("{chr}")
-                    }
-                }
-            }
-        }
-        let _ = std::io::stdout().flush();
-    }
-}
-
 fn format_size(size_in_bytes: usize) -> String {
     if size_in_bytes < 1_000 {
         format!("{}B", size_in_bytes)
@@ -304,7 +305,8 @@ fn main() -> anyhow::Result<()> {
                | Which::L34bCode => 1,
                Which::Mistral7b
                | Which::Mistral7bInstruct
-               | Which::Zephyr7b
+               | Which::Zephyr7bAlpha
+               | Which::Zephyr7bBeta
                | Which::L70b
                | Which::L70bChat => 8,
            };
@@ -314,6 +316,7 @@ fn main() -> anyhow::Result<()> {
     println!("model built");
 
     let tokenizer = args.tokenizer()?;
+    let mut tos = TokenOutputStream::new(tokenizer);
     let prompt = match args.prompt.as_deref() {
         Some("chat") => Prompt::Chat,
         Some("interactive") => Prompt::Interactive,
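
The TokenOutputStream constructed here replaces the print_token helper removed above. A minimal sketch of the streaming-decode pattern it enables, using only the calls that appear in this diff (new, next_token, decode_rest) and assuming the Tokenizer has been loaded elsewhere:

```rust
use std::io::Write;

use candle_examples::token_output_stream::TokenOutputStream;
use tokenizers::Tokenizer;

// Sketch only: print a sequence of token ids as they arrive, holding back bytes
// until they form complete characters (the problem the removed print_token
// heuristics tried to work around).
fn stream_print(tokenizer: Tokenizer, token_ids: &[u32]) -> anyhow::Result<()> {
    let mut tos = TokenOutputStream::new(tokenizer);
    for &id in token_ids {
        if let Some(t) = tos.next_token(id)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
    }
    // Emit whatever is still buffered once generation stops.
    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }
    std::io::stdout().flush()?;
    Ok(())
}
```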
@@ -336,7 +339,7 @@ fn main() -> anyhow::Result<()> {
                        prompt.pop();
                    }
                }
-               if args.which == Which::Zephyr7b {
+               if args.which.is_zephyr() {
                    format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>")
                } else if args.which.is_mistral() {
                    format!("[INST] {prompt} [/INST]")
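
Pulled out of the surrounding control flow, the template selection this hunk changes amounts to the sketch below; the Zephyr and Mistral strings are taken verbatim from the diff, while the final fallback branch (passing the prompt through unchanged) is an assumption since it lies outside the visible context:

```rust
// Sketch only: `Which` is the enum defined earlier in this file, with the
// is_zephyr/is_mistral helpers added by this commit.
fn format_prompt(which: &Which, prompt: &str) -> String {
    if which.is_zephyr() {
        // Zephyr alpha/beta expect the <|system|>/<|user|>/<|assistant|> chat format.
        format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>")
    } else if which.is_mistral() {
        // Mistral instruct models expect the [INST] wrapper.
        format!("[INST] {prompt} [/INST]")
    } else {
        // Assumed fallback: base llama models get the raw prompt.
        prompt.to_string()
    }
}
```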
@@ -346,7 +349,8 @@ fn main() -> anyhow::Result<()> {
             }
         };
         print!("{}", &prompt_str);
-        let tokens = tokenizer
+        let tokens = tos
+            .tokenizer()
             .encode(prompt_str, true)
             .map_err(anyhow::Error::msg)?;
         if args.verbose_prompt {
@@ -376,11 +380,15 @@ fn main() -> anyhow::Result<()> {
         };
         let prompt_dt = start_prompt_processing.elapsed();
         all_tokens.push(next_token);
-        print_token(next_token, &tokenizer);
+        if let Some(t) = tos.next_token(next_token)? {
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
 
-        let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
+        let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
 
         let start_post_prompt = std::time::Instant::now();
+        let mut sampled = 0;
         for index in 0..to_sample {
             let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
             let logits = model.forward(&input, prompt_tokens.len() + index)?;
@@ -397,11 +405,19 @@ fn main() -> anyhow::Result<()> {
             };
             next_token = logits_processor.sample(&logits)?;
             all_tokens.push(next_token);
-            print_token(next_token, &tokenizer);
+            if let Some(t) = tos.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
+            sampled += 1;
             if next_token == eos_token {
                 break;
             };
         }
+        if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
+            print!("{rest}");
+        }
+        std::io::stdout().flush()?;
         let dt = start_post_prompt.elapsed();
         println!(
             "\n\n{:4} prompt tokens processed: {:.2} token/s",
@@ -409,9 +425,8 @@ fn main() -> anyhow::Result<()> {
             prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
         );
         println!(
-            "{:4} tokens generated: {:.2} token/s",
-            to_sample,
-            to_sample as f64 / dt.as_secs_f64(),
+            "{sampled:4} tokens generated: {:.2} token/s",
+            sampled as f64 / dt.as_secs_f64(),
         );
 
         match prompt {
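
The final hunk switches the throughput report from the requested to_sample count to the sampled counter introduced earlier. The generation loop can break early on the </s> token, and in that case dividing the requested length by the measured time overstates the rate; a tiny sketch with invented numbers:

```rust
fn main() {
    // Illustration only: these figures are made up.
    let to_sample = 1000usize; // requested length (-n 1000 is the new default)
    let sampled = 250usize; // tokens actually generated before hitting </s>
    let dt = std::time::Duration::from_secs_f64(10.0);

    let old_rate = to_sample as f64 / dt.as_secs_f64(); // 100 token/s, overstated
    let new_rate = sampled as f64 / dt.as_secs_f64(); // 25 token/s, the real rate
    println!("{sampled:4} tokens generated: {new_rate:.2} token/s (previously reported as {old_rate:.2})");
}
```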