diff --git a/candle-examples/examples/quantized-qwen2-instruct/main.rs b/candle-examples/examples/quantized-qwen2-instruct/main.rs index 1bd230e0..ff6ebe90 100644 --- a/candle-examples/examples/quantized-qwen2-instruct/main.rs +++ b/candle-examples/examples/quantized-qwen2-instruct/main.rs @@ -27,6 +27,8 @@ enum Which { W2_7b, #[value(name = "72b")] W2_72b, + #[value(name = "deepseekr1-qwen7b")] + DeepseekR1Qwen7B, } #[derive(Parser, Debug)] @@ -102,6 +104,7 @@ impl Args { Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct", Which::W2_7b => "Qwen/Qwen2-7B-Instruct", Which::W2_72b => "Qwen/Qwen2-72B-Instruct", + Which::DeepseekR1Qwen7B => "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", }; let api = api.model(repo.to_string()); api.get("tokenizer.json")? @@ -135,6 +138,11 @@ impl Args { "qwen2-72b-instruct-q4_0.gguf", "main", ), + Which::DeepseekR1Qwen7B => ( + "unsloth/DeepSeek-R1-Distill-Qwen-7B-GGUF", + "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf", + "main", + ), }; let api = hf_hub::api::sync::Api::new()?; api.repo(hf_hub::Repo::with_revision( @@ -211,11 +219,15 @@ fn main() -> anyhow::Result<()> { let tokenizer = args.tokenizer()?; let mut tos = TokenOutputStream::new(tokenizer); - let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string()); - let prompt_str = format!( - "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", - prompt_str - ); + let prompt_str = args + .prompt + .clone() + .unwrap_or_else(|| DEFAULT_PROMPT.to_string()); + + let prompt_str = match args.which { + Which::DeepseekR1Qwen7B => format!("<|User|>{prompt_str}<|Assistant|>"), + _ => format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n"), + }; print!("formatted instruct prompt: {}", &prompt_str); let tokens = tos .tokenizer() @@ -260,7 +272,13 @@ fn main() -> anyhow::Result<()> { print!("{t}"); std::io::stdout().flush()?; } - let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap(); + + let eos_token = match args.which { + Which::DeepseekR1Qwen7B => "<|end▁of▁sentence|>", + _ => "<|im_end|>", + }; + + let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap(); let start_post_prompt = std::time::Instant::now(); let mut sampled = 0; for index in 0..to_sample {