Integrate TheBloke quantized mistral weights. (#1012)

This commit is contained in:
Laurent Mazare
2023-09-30 22:39:42 +01:00
committed by GitHub
parent 4e55aaa51f
commit 328167ec04

View File

@@ -44,6 +44,10 @@ enum Which {
L13bCode, L13bCode,
#[value(name = "32b-code")] #[value(name = "32b-code")]
L34bCode, L34bCode,
#[value(name = "7b-mistral")]
Mistral7b,
#[value(name = "7b-mistral-instruct")]
Mistral7bInstruct,
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@@ -110,7 +114,19 @@ impl Args {
Some(config) => std::path::PathBuf::from(config), Some(config) => std::path::PathBuf::from(config),
None => { None => {
let api = hf_hub::api::sync::Api::new()?; let api = hf_hub::api::sync::Api::new()?;
let api = api.model("hf-internal-testing/llama-tokenizer".to_string()); let repo = match self.which {
Which::L7b
| Which::L13b
| Which::L70b
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode
| Which::L7bChat
| Which::L13bChat
| Which::L70bChat => "hf-internal-testing/llama-tokenizer",
Which::Mistral7b | Which::Mistral7bInstruct => "mistralai/Mistral-7B-v0.1",
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")? api.get("tokenizer.json")?
} }
}; };
@@ -140,6 +156,14 @@ impl Args {
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"), Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"), Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"), Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
Which::Mistral7b => (
"TheBloke/Mistral-7B-v0.1-GGUF",
"mistral-7b-v0.1.Q4_K_S.gguf",
),
Which::Mistral7bInstruct => (
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
),
}; };
let api = hf_hub::api::sync::Api::new()?; let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string()); let api = api.model(repo.to_string());
@@ -261,7 +285,7 @@ fn main() -> anyhow::Result<()> {
| Which::L7bCode | Which::L7bCode
| Which::L13bCode | Which::L13bCode
| Which::L34bCode => 1, | Which::L34bCode => 1,
Which::L70b | Which::L70bChat => 8, Which::Mistral7b | Which::Mistral7bInstruct | Which::L70b | Which::L70bChat => 8,
}; };
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))? ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
} }