diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index ab8a56ba..18db3f9a 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -55,6 +55,8 @@ enum Which { Zephyr7bBeta, #[value(name = "7b-open-chat-3.5")] OpenChat35, + #[value(name = "7b-starling-a")] + Starling7bAlpha, } impl Which { @@ -70,8 +72,9 @@ impl Which { | Self::L13bCode | Self::L34bCode => false, // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the - // same way. + // same way. Starling is a fine tuned version of OpenChat. Self::OpenChat35 + | Self::Starling7bAlpha | Self::Zephyr7bAlpha | Self::Zephyr7bBeta | Self::Mistral7b @@ -92,7 +95,8 @@ impl Which { | Self::L34bCode | Self::Mistral7b | Self::Mistral7bInstruct - | Self::OpenChat35 => false, + | Self::OpenChat35 + | Self::Starling7bAlpha => false, Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true, } } @@ -112,7 +116,26 @@ impl Which { | Which::Mistral7bInstruct | Which::Zephyr7bAlpha | Which::Zephyr7bBeta => false, - Which::OpenChat35 => true, + Which::OpenChat35 | Self::Starling7bAlpha => true, + } + } + fn is_starling(&self) -> bool { + match self { + Which::L7b + | Which::L13b + | Which::L70b + | Which::L7bChat + | Which::L13bChat + | Which::L70bChat + | Which::L7bCode + | Which::L13bCode + | Which::L34bCode + | Which::Mistral7b + | Which::Mistral7bInstruct + | Which::Zephyr7bAlpha + | Which::Zephyr7bBeta + | Which::OpenChat35 => false, + Which::Starling7bAlpha => true, } } } @@ -181,7 +204,9 @@ impl Args { Some(config) => std::path::PathBuf::from(config), None => { let api = hf_hub::api::sync::Api::new()?; - let repo = if self.which.is_open_chat() { + let repo = if self.which.is_starling() { + "berkeley-nest/Starling-LM-7B-alpha" + } else if self.which.is_open_chat() { "openchat/openchat_3.5" } else if self.which.is_mistral() { "mistralai/Mistral-7B-v0.1" @@ -234,6 +259,10 @@ impl Args { ("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf") } Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"), + Which::Starling7bAlpha => ( + "TheBloke/Starling-LM-7B-alpha-GGUF", + "starling-lm-7b-alpha.Q4_K_M.gguf", + ), }; let api = hf_hub::api::sync::Api::new()?; let api = api.model(repo.to_string()); @@ -336,7 +365,8 @@ fn main() -> anyhow::Result<()> { | Which::Zephyr7bBeta | Which::L70b | Which::L70bChat - | Which::OpenChat35 => 8, + | Which::OpenChat35 + | Which::Starling7bAlpha => 8, }; ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))? } @@ -369,7 +399,7 @@ fn main() -> anyhow::Result<()> { } } if args.which.is_open_chat() { - format!("User: {prompt}<|end_of_turn|>Assistant: ") + format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:") } else if args.which.is_zephyr() { if prompt_index == 0 || is_interactive { format!("<|system|>\n\n<|user|>\n{prompt}\n<|assistant|>",)