Support Mistral Instruct v0.2. (#1475)

* Support Mistral Instruct v0.2.

* Use the safetensors model files now that they are available.
This commit is contained in:
Laurent Mazare
2023-12-23 16:18:49 +01:00
committed by GitHub
parent 5b35fd0fcf
commit 88589d8815
2 changed files with 18 additions and 7 deletions

View File

@ -155,7 +155,7 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "lmz/candle-mistral")]
#[arg(long, default_value = "mistralai/Mistral-7B-v0.1")]
model_id: String,
#[arg(long, default_value = "main")]
@ -226,8 +226,8 @@ fn main() -> Result<()> {
vec![repo.get("model-q4k.gguf")?]
} else {
vec![
repo.get("pytorch_model-00001-of-00002.safetensors")?,
repo.get("pytorch_model-00002-of-00002.safetensors")?,
repo.get("model-00001-of-00002.safetensors")?,
repo.get("model-00002-of-00002.safetensors")?,
]
}
}

View File

@ -53,6 +53,8 @@ enum Which {
Mistral7b,
#[value(name = "7b-mistral-instruct")]
Mistral7bInstruct,
#[value(name = "7b-mistral-instruct-v0.2")]
Mistral7bInstructV02,
#[value(name = "7b-zephyr-a")]
Zephyr7bAlpha,
#[value(name = "7b-zephyr-b")]
@ -90,7 +92,8 @@ impl Which {
| Self::Mixtral
| Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct => true,
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02 => true,
}
}
@ -111,6 +114,7 @@ impl Which {
| Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::OpenChat35
| Self::Starling7bAlpha => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
@ -134,6 +138,7 @@ impl Which {
| Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
@ -157,6 +162,7 @@ impl Which {
Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Mistral7bInstructV02
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Which::OpenChat35 => "openchat/openchat_3.5",
@ -168,7 +174,7 @@ impl Which {
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
/// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from llama.cpp
#[arg(long)]
model: Option<String>,
@ -284,6 +290,10 @@ impl Args {
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
),
Which::Mistral7bInstructV02 => (
"TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
"mistral-7b-instruct-v0.2.Q4_K_S.gguf",
),
Which::Zephyr7bAlpha => (
"TheBloke/zephyr-7B-alpha-GGUF",
"zephyr-7b-alpha.Q4_K_M.gguf",
@ -354,7 +364,7 @@ fn main() -> anyhow::Result<()> {
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => {
let model = gguf_file::Content::read(&mut file)?;
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
@ -370,7 +380,7 @@ fn main() -> anyhow::Result<()> {
ModelWeights::from_gguf(model, &mut file)?
}
Some("ggml" | "bin") | Some(_) | None => {
let model = ggml_file::Content::read(&mut file)?;
let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
@ -398,6 +408,7 @@ fn main() -> anyhow::Result<()> {
| Which::MixtralInstruct
| Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Mistral7bInstructV02
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta
| Which::L70b