From 328167ec04bec4536b4ab5581685ebdf918211ee Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 30 Sep 2023 22:39:42 +0100
Subject: [PATCH] Integrate TheBloke quantized mistral weights. (#1012)

---
 candle-examples/examples/quantized/main.rs | 28 ++++++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index a80ad420..3e663851 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -44,6 +44,10 @@ enum Which {
     L13bCode,
     #[value(name = "32b-code")]
     L34bCode,
+    #[value(name = "7b-mistral")]
+    Mistral7b,
+    #[value(name = "7b-mistral-instruct")]
+    Mistral7bInstruct,
 }
 
 #[derive(Parser, Debug)]
@@ -110,7 +114,19 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let api = hf_hub::api::sync::Api::new()?;
-                let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
+                let repo = match self.which {
+                    Which::L7b
+                    | Which::L13b
+                    | Which::L70b
+                    | Which::L7bCode
+                    | Which::L13bCode
+                    | Which::L34bCode
+                    | Which::L7bChat
+                    | Which::L13bChat
+                    | Which::L70bChat => "hf-internal-testing/llama-tokenizer",
+                    Which::Mistral7b | Which::Mistral7bInstruct => "mistralai/Mistral-7B-v0.1",
+                };
+                let api = api.model(repo.to_string());
                 api.get("tokenizer.json")?
             }
         };
@@ -140,6 +156,14 @@ impl Args {
                 Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
                 Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
                 Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
+                Which::Mistral7b => (
+                    "TheBloke/Mistral-7B-v0.1-GGUF",
+                    "mistral-7b-v0.1.Q4_K_S.gguf",
+                ),
+                Which::Mistral7bInstruct => (
+                    "TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
+                    "mistral-7b-instruct-v0.1.Q4_K_S.gguf",
+                ),
             };
             let api = hf_hub::api::sync::Api::new()?;
             let api = api.model(repo.to_string());
@@ -261,7 +285,7 @@ fn main() -> anyhow::Result<()> {
                 | Which::L7bCode
                 | Which::L13bCode
                 | Which::L34bCode => 1,
-                Which::L70b | Which::L70bChat => 8,
+                Which::Mistral7b | Which::Mistral7bInstruct | Which::L70b | Which::L70bChat => 8,
             };
             ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
         }
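
With the two new `Which` variants registered in the first hunk, the example can be pointed at TheBloke's Mistral GGUF weights from the command line. A plausible invocation, assuming clap derives a `--which` flag from the `which` field of `Args` and that the example exposes a `--prompt` flag (neither is shown in this diff):

    cargo run --example quantized --release -- \
        --which 7b-mistral-instruct \
        --prompt "Write a function that counts vowels in a string."

Note that the `default_gqa` value touched in the last hunk is only consumed by the `ModelWeights::from_ggml` path shown there, so the grouped-query-attention default should not affect the GGUF files added here; the Mistral variants are mapped to 8 alongside the 70B models as a safeguard for ggml-style checkpoints.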