diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index 18db3f9a..b21d6751 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -45,6 +45,10 @@ enum Which {
     L13bCode,
     #[value(name = "32b-code")]
     L34bCode,
+    #[value(name = "7b-leo")]
+    Leo7b,
+    #[value(name = "13b-leo")]
+    Leo13b,
     #[value(name = "7b-mistral")]
     Mistral7b,
     #[value(name = "7b-mistral-instruct")]
@@ -70,7 +74,9 @@ impl Which {
             | Self::L70bChat
             | Self::L7bCode
             | Self::L13bCode
-            | Self::L34bCode => false,
+            | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b => false,
             // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
             // same way. Starling is a fine tuned version of OpenChat.
             Self::OpenChat35
@@ -93,6 +99,8 @@ impl Which {
             | Self::L7bCode
             | Self::L13bCode
             | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b
             | Self::Mistral7b
             | Self::Mistral7bInstruct
             | Self::OpenChat35
@@ -103,23 +111,26 @@ impl Which {

     fn is_open_chat(&self) -> bool {
         match self {
-            Which::L7b
-            | Which::L13b
-            | Which::L70b
-            | Which::L7bChat
-            | Which::L13bChat
-            | Which::L70bChat
-            | Which::L7bCode
-            | Which::L13bCode
-            | Which::L34bCode
-            | Which::Mistral7b
-            | Which::Mistral7bInstruct
-            | Which::Zephyr7bAlpha
-            | Which::Zephyr7bBeta => false,
-            Which::OpenChat35 | Self::Starling7bAlpha => true,
+            Self::L7b
+            | Self::L13b
+            | Self::L70b
+            | Self::L7bChat
+            | Self::L13bChat
+            | Self::L70bChat
+            | Self::L7bCode
+            | Self::L13bCode
+            | Self::L34bCode
+            | Self::Leo7b
+            | Self::Leo13b
+            | Self::Mistral7b
+            | Self::Mistral7bInstruct
+            | Self::Zephyr7bAlpha
+            | Self::Zephyr7bBeta => false,
+            Self::OpenChat35 | Self::Starling7bAlpha => true,
         }
     }
-    fn is_starling(&self) -> bool {
+
+    fn tokenizer_repo(&self) -> &'static str {
         match self {
             Which::L7b
             | Which::L13b
@@ -129,13 +140,15 @@ impl Which {
             | Which::L70bChat
             | Which::L7bCode
             | Which::L13bCode
-            | Which::L34bCode
-            | Which::Mistral7b
+            | Which::L34bCode => "hf-internal-testing/llama-tokenizer",
+            Which::Leo7b => "LeoLM/leo-hessianai-7b",
+            Which::Leo13b => "LeoLM/leo-hessianai-13b",
+            Which::Mistral7b
             | Which::Mistral7bInstruct
             | Which::Zephyr7bAlpha
-            | Which::Zephyr7bBeta
-            | Which::OpenChat35 => false,
-            Which::Starling7bAlpha => true,
+            | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
+            Which::OpenChat35 => "openchat/openchat_3.5",
+            Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
         }
     }
 }
@@ -204,15 +217,7 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let repo = if self.which.is_starling() {
-                    "berkeley-nest/Starling-LM-7B-alpha"
-                } else if self.which.is_open_chat() {
-                    "openchat/openchat_3.5"
-                } else if self.which.is_mistral() {
-                    "mistralai/Mistral-7B-v0.1"
-                } else {
-                    "hf-internal-testing/llama-tokenizer"
-                };
+                let repo = self.which.tokenizer_repo();
                 let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
             }
@@ -243,6 +248,14 @@ impl Args {
                 Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
                 Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
                 Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
+                Which::Leo7b => (
+                    "TheBloke/leo-hessianai-7B-GGUF",
+                    "leo-hessianai-7b.Q4_K_M.gguf",
+                ),
+                Which::Leo13b => (
+                    "TheBloke/leo-hessianai-13B-GGUF",
+                    "leo-hessianai-13b.Q4_K_M.gguf",
+                ),
                 Which::Mistral7b => (
                     "TheBloke/Mistral-7B-v0.1-GGUF",
                     "mistral-7b-v0.1.Q4_K_S.gguf",
@@ -358,7 +371,9 @@ fn main() -> anyhow::Result<()> {
             | Which::L13bChat
             | Which::L7bCode
             | Which::L13bCode
-            | Which::L34bCode => 1,
+            | Which::L34bCode
+            | Which::Leo7b
+            | Which::Leo13b => 1,
             Which::Mistral7b
             | Which::Mistral7bInstruct
             | Which::Zephyr7bAlpha
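Usage sketch: assuming this example keeps its existing clap flags (`--which` and `--prompt`), the new variants should be selectable through the names declared in the `#[value(name = ...)]` attributes above, with the GGUF weights and the LeoLM tokenizer fetched from the Hugging Face Hub on first run. The prompt text below is purely illustrative.

    cargo run --example quantized --release -- --which 7b-leo --prompt "Leipzig ist eine"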