diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 15926e0a..12b4b059 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -12,6 +12,7 @@
 use candle::quantized::{ggml_file, gguf_file};
 use candle::{Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_transformers::models::quantized_llama as model;
 use model::ModelWeights;
 
@@ -48,8 +49,10 @@ enum Which {
     Mistral7b,
     #[value(name = "7b-mistral-instruct")]
     Mistral7bInstruct,
-    #[value(name = "7b-zephyr")]
-    Zephyr7b,
+    #[value(name = "7b-zephyr-a")]
+    Zephyr7bAlpha,
+    #[value(name = "7b-zephyr-b")]
+    Zephyr7bBeta,
 }
 
 impl Which {
@@ -65,7 +68,27 @@ impl Which {
             | Self::L13bCode
             | Self::L34bCode => false,
             // Zephyr is a fine tuned version of mistral and should be treated in the same way.
-            Self::Zephyr7b | Self::Mistral7b | Self::Mistral7bInstruct => true,
+            Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct => true,
+        }
+    }
+
+    fn is_zephyr(&self) -> bool {
+        match self {
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct => false,
+            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
     }
 }
@@ -84,7 +107,7 @@ struct Args {
     prompt: Option<String>,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(short = 'n', long, default_value_t = 100)]
+    #[arg(short = 'n', long, default_value_t = 1000)]
     sample_len: usize,
 
     /// The tokenizer config in json format.
@@ -177,10 +200,13 @@ impl Args {
                     "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
                     "mistral-7b-instruct-v0.1.Q4_K_S.gguf",
                 ),
-                Which::Zephyr7b => (
+                Which::Zephyr7bAlpha => (
                     "TheBloke/zephyr-7B-alpha-GGUF",
                     "zephyr-7b-alpha.Q4_K_M.gguf",
                 ),
+                Which::Zephyr7bBeta => {
+                    ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
+                }
             };
             let api = hf_hub::api::sync::Api::new()?;
             let api = api.model(repo.to_string());
@@ -191,31 +217,6 @@ impl Args {
     }
 }
 
-fn print_token(next_token: u32, tokenizer: &Tokenizer) {
-    // Extracting the last token as a string is complicated, here we just apply some simple
-    // heuristics as it seems to work well enough for this example. See the following for more
-    // details:
-    // https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
-    if let Some(text) = tokenizer.id_to_token(next_token) {
-        let text = text.replace('▁', " ");
-        let ascii = text
-            .strip_prefix("<0x")
-            .and_then(|t| t.strip_suffix('>'))
-            .and_then(|t| u8::from_str_radix(t, 16).ok());
-        match ascii {
-            None => print!("{text}"),
-            Some(ascii) => {
-                if let Some(chr) = char::from_u32(ascii as u32) {
-                    if chr.is_ascii() {
-                        print!("{chr}")
-                    }
-                }
-            }
-        }
-        let _ = std::io::stdout().flush();
-    }
-}
-
 fn format_size(size_in_bytes: usize) -> String {
     if size_in_bytes < 1_000 {
         format!("{}B", size_in_bytes)
@@ -304,7 +305,8 @@ fn main() -> anyhow::Result<()> {
                 | Which::L34bCode => 1,
                 Which::Mistral7b
                 | Which::Mistral7bInstruct
-                | Which::Zephyr7b
+                | Which::Zephyr7bAlpha
+                | Which::Zephyr7bBeta
                 | Which::L70b
                 | Which::L70bChat => 8,
             };
@@ -314,6 +316,7 @@ fn main() -> anyhow::Result<()> {
     println!("model built");
 
     let tokenizer = args.tokenizer()?;
+    let mut tos = TokenOutputStream::new(tokenizer);
     let prompt = match args.prompt.as_deref() {
         Some("chat") => Prompt::Chat,
         Some("interactive") => Prompt::Interactive,
@@ -336,7 +339,7 @@ fn main() -> anyhow::Result<()> {
                         prompt.pop();
                     }
                 }
-                if args.which == Which::Zephyr7b {
+                if args.which.is_zephyr() {
                     format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>")
                 } else if args.which.is_mistral() {
                     format!("[INST] {prompt} [/INST]")
@@ -346,7 +349,8 @@ fn main() -> anyhow::Result<()> {
             }
         };
         print!("{}", &prompt_str);
-        let tokens = tokenizer
+        let tokens = tos
+            .tokenizer()
             .encode(prompt_str, true)
             .map_err(anyhow::Error::msg)?;
         if args.verbose_prompt {
@@ -376,11 +380,15 @@ fn main() -> anyhow::Result<()> {
         };
         let prompt_dt = start_prompt_processing.elapsed();
         all_tokens.push(next_token);
-        print_token(next_token, &tokenizer);
+        if let Some(t) = tos.next_token(next_token)? {
+            print!("{t}");
+            std::io::stdout().flush()?;
+        }
 
-        let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
+        let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
 
         let start_post_prompt = std::time::Instant::now();
+        let mut sampled = 0;
         for index in 0..to_sample {
             let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
             let logits = model.forward(&input, prompt_tokens.len() + index)?;
@@ -397,11 +405,19 @@ fn main() -> anyhow::Result<()> {
             };
             next_token = logits_processor.sample(&logits)?;
             all_tokens.push(next_token);
-            print_token(next_token, &tokenizer);
+            if let Some(t) = tos.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
+            sampled += 1;
            if next_token == eos_token {
                 break;
             };
         }
+        if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
+            print!("{rest}");
+        }
+        std::io::stdout().flush()?;
        let dt = start_post_prompt.elapsed();
         println!(
             "\n\n{:4} prompt tokens processed: {:.2} token/s",
@@ -409,9 +425,8 @@ fn main() -> anyhow::Result<()> {
             prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
         );
         println!(
-            "{:4} tokens generated: {:.2} token/s",
-            to_sample,
-            to_sample as f64 / dt.as_secs_f64(),
+            "{sampled:4} tokens generated: {:.2} token/s",
+            sampled as f64 / dt.as_secs_f64(),
         );
 
         match prompt {
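
Note on the streaming change: the removed print_token helper decoded each sampled id in isolation, special-casing the "▁" marker and <0x..> byte-fallback pieces, which is exactly what breaks on multi-byte UTF-8 output. TokenOutputStream instead buffers the sampled ids and only prints the part of the decoded text that is new and complete, which is why the loop gains the next_token/decode_rest calls. The sketch below illustrates that idea only; it is not the helper that candle-examples actually ships. StreamDecoder is a made-up name, and it assumes a tokenizers crate version whose Tokenizer::decode takes a slice of ids plus a skip-special-tokens flag, with anyhow used for error plumbing as in the example itself.

// Rough sketch of the buffering/re-decode idea behind TokenOutputStream.
// Not the real candle_examples::token_output_stream implementation.
use anyhow::Result;
use tokenizers::Tokenizer;

struct StreamDecoder {
    tokenizer: Tokenizer,
    tokens: Vec<u32>,
    // Byte length of the decoded text that has already been printed.
    emitted: usize,
}

impl StreamDecoder {
    fn new(tokenizer: Tokenizer) -> Self {
        Self { tokenizer, tokens: Vec::new(), emitted: 0 }
    }

    /// Push one sampled token id; return printable text once it is safe to print.
    fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        self.tokens.push(token);
        let text = self
            .tokenizer
            .decode(&self.tokens, true)
            .map_err(anyhow::Error::msg)?;
        // Hold the tail back while it ends in U+FFFD: the last token then only
        // covers part of a multi-byte character (e.g. a <0x..> byte-fallback
        // piece) and printing it now would emit garbage.
        match text.get(self.emitted..) {
            Some(new) if !new.is_empty() && !new.ends_with('\u{FFFD}') => {
                self.emitted = text.len();
                Ok(Some(new.to_string()))
            }
            _ => Ok(None),
        }
    }

    /// Flush whatever is still held back, mirroring the decode_rest call the
    /// diff adds after the sampling loop.
    fn decode_rest(&mut self) -> Result<Option<String>> {
        let text = self
            .tokenizer
            .decode(&self.tokens, true)
            .map_err(anyhow::Error::msg)?;
        let rest = text.get(self.emitted..).map(str::to_string);
        self.emitted = text.len();
        Ok(rest.filter(|s| !s.is_empty()))
    }
}

For clarity this sketch re-decodes the whole sequence on every step; the helper used in the diff tracks an index of already-flushed tokens so each call only re-decodes a short pending window, which matters now that the default sample_len is 1000 tokens.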