diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 068ae12d..cf034142 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -53,6 +53,8 @@ enum Which {
     Zephyr7bAlpha,
     #[value(name = "7b-zephyr-b")]
     Zephyr7bBeta,
+    #[value(name = "7b-open-chat-3.5")]
+    OpenChat35,
 }
 
 impl Which {
@@ -67,8 +69,10 @@ impl Which {
             | Self::L7bCode
             | Self::L13bCode
             | Self::L34bCode => false,
-            // Zephyr is a fine tuned version of mistral and should be treated in the same way.
-            Self::Zephyr7bAlpha
+            // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
+            // same way.
+            Self::OpenChat35
+            | Self::Zephyr7bAlpha
             | Self::Zephyr7bBeta
             | Self::Mistral7b
             | Self::Mistral7bInstruct => true,
@@ -87,10 +91,30 @@ impl Which {
             | Self::L13bCode
             | Self::L34bCode
             | Self::Mistral7b
-            | Self::Mistral7bInstruct => false,
+            | Self::Mistral7bInstruct
+            | Self::OpenChat35 => false,
             Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
     }
+
+    fn is_open_chat(&self) -> bool {
+        match self {
+            Which::L7b
+            | Which::L13b
+            | Which::L70b
+            | Which::L7bChat
+            | Which::L13bChat
+            | Which::L70bChat
+            | Which::L7bCode
+            | Which::L13bCode
+            | Which::L34bCode
+            | Which::Mistral7b
+            | Which::Mistral7bInstruct
+            | Which::Zephyr7bAlpha
+            | Which::Zephyr7bBeta => false,
+            Which::OpenChat35 => true,
+        }
+    }
 }
 
 #[derive(Parser, Debug)]
@@ -207,6 +231,7 @@ impl Args {
                     Which::Zephyr7bBeta => {
                         ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
                     }
+                    Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
                 };
                 let api = hf_hub::api::sync::Api::new()?;
                 let api = api.model(repo.to_string());
@@ -308,7 +333,8 @@ fn main() -> anyhow::Result<()> {
                 | Which::Zephyr7bAlpha
                 | Which::Zephyr7bBeta
                 | Which::L70b
-                | Which::L70bChat => 8,
+                | Which::L70bChat
+                | Which::OpenChat35 => 8,
             };
             ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
         }
@@ -340,7 +366,9 @@ fn main() -> anyhow::Result<()> {
                         prompt.pop();
                     }
                 }
-                if args.which.is_zephyr() {
+                if args.which.is_open_chat() {
+                    format!("User: {prompt}<|end_of_turn|>Assistant: ")
+                } else if args.which.is_zephyr() {
                     if prompt_index == 0 || is_interactive {
                         format!("<|system|>\n\n<|user|>\n{prompt}\n<|assistant|>",)
                     } else {
@@ -390,8 +418,12 @@ fn main() -> anyhow::Result<()> {
             std::io::stdout().flush()?;
         }
 
-        let eos_token = *tos.tokenizer().get_vocab(true).get("</s>").unwrap();
-
+        let eos_token = if args.which.is_open_chat() {
+            "<|end_of_turn|>"
+        } else {
+            "</s>"
+        };
+        let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
         let start_post_prompt = std::time::Instant::now();
         let mut sampled = 0;
         for index in 0..to_sample {