Add some llama-v2 variants. (#545)

Laurent Mazare
2023-08-22 08:35:15 +01:00
committed by GitHub
parent f16bb97401
commit 44420d8ae1

@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 use std::collections::HashMap;
 use std::io::Write;
 use tokenizers::Tokenizer;
@@ -291,6 +291,16 @@ impl ModelWeights {
     }
 }
 
+#[derive(Clone, Debug, Copy, ValueEnum)]
+enum Which {
+    #[value(name = "7b")]
+    L7b,
+    #[value(name = "13b")]
+    L13b,
+    #[value(name = "70b")]
+    L70b,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -333,6 +343,10 @@ struct Args {
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
+
+    /// The model size to use.
+    #[arg(long, default_value = "7b")]
+    which: Which,
 }
 
 impl Args {
@@ -352,9 +366,14 @@ impl Args {
         let model_path = match &self.model {
             Some(config) => std::path::PathBuf::from(config),
             None => {
+                let (repo, filename) = match self.which {
+                    Which::L7b => ("TheBloke/Llama-2-7B-GGML", "llama-2-7b.ggmlv3.q4_0.bin"),
+                    Which::L13b => ("TheBloke/Llama-2-13B-GGML", "llama-2-13b.ggmlv3.q4_0.bin"),
+                    Which::L70b => ("TheBloke/Llama-2-70B-GGML", "llama-2-70b.ggmlv3.q4_0.bin"),
+                };
                 let api = hf_hub::api::sync::Api::new()?;
-                let api = api.model("TheBloke/Llama-2-7B-GGML".to_string());
-                api.get("llama-2-7b.ggmlv3.q4_0.bin")?
+                let api = api.model(repo.to_string());
+                api.get(filename)?
             }
         };
         Ok(model_path)
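
For reference, here is a minimal self-contained sketch of the pattern this commit introduces: a clap ValueEnum flag picks which quantized Llama-2 weights to fetch from the Hugging Face hub. The enum, the --which flag, and the (repo, filename) mapping mirror the diff above; the standalone main, the anyhow error handling, and the clap "derive" / hf-hub crate dependencies are assumptions made to keep the sketch runnable, not part of the commit.

// Sketch only: assumes clap with the "derive" feature, plus the hf-hub and
// anyhow crates as dependencies. The enum, flag, and mapping mirror the
// commit; the rest is glue to make the example stand alone.
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    #[value(name = "7b")]
    L7b,
    #[value(name = "13b")]
    L13b,
    #[value(name = "70b")]
    L70b,
}

#[derive(Parser, Debug)]
struct Args {
    /// The model size to use.
    #[arg(long, default_value = "7b")]
    which: Which,
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    // Map the selected size to a hub repo and a quantized weight file,
    // as the diff does.
    let (repo, filename) = match args.which {
        Which::L7b => ("TheBloke/Llama-2-7B-GGML", "llama-2-7b.ggmlv3.q4_0.bin"),
        Which::L13b => ("TheBloke/Llama-2-13B-GGML", "llama-2-13b.ggmlv3.q4_0.bin"),
        Which::L70b => ("TheBloke/Llama-2-70B-GGML", "llama-2-70b.ggmlv3.q4_0.bin"),
    };
    // Download the selected weights (or reuse the local hub cache).
    let api = hf_hub::api::sync::Api::new()?;
    let model_path = api.model(repo.to_string()).get(filename)?;
    println!("weights at {}", model_path.display());
    Ok(())
}

Running with --which 13b or --which 70b switches models; with no flag, clap's default_value of "7b" selects Which::L7b, preserving the previous behavior.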