Add some llama-v2 variants. (#545)

This commit is contained in:
Laurent Mazare
2023-08-22 08:35:15 +01:00
committed by GitHub
parent f16bb97401
commit 44420d8ae1

View File

@ -5,7 +5,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")] #[cfg(feature = "accelerate")]
extern crate accelerate_src; extern crate accelerate_src;
use clap::Parser; use clap::{Parser, ValueEnum};
use std::collections::HashMap; use std::collections::HashMap;
use std::io::Write; use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
@ -291,6 +291,16 @@ impl ModelWeights {
} }
} }
// Llama-2 model variant that can be selected from the command line.
//
// Each variant carries an explicit `#[value(name = ...)]` so the accepted
// spelling is "7b", "13b" or "70b" rather than a name derived from the Rust
// identifier (identifiers cannot start with a digit, hence the `L` prefix).
//
// NOTE(review): plain `//` comments are used on purpose; clap's derive turns
// `///` doc comments on value-enum variants into help text, which would change
// the `--help` output — confirm before converting these to doc comments.
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
// Selected with the value "7b".
#[value(name = "7b")]
L7b,
// Selected with the value "13b".
#[value(name = "13b")]
L13b,
// Selected with the value "70b".
#[value(name = "70b")]
L70b,
}
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
struct Args { struct Args {
@ -333,6 +343,10 @@ struct Args {
/// The context size to consider for the repeat penalty. /// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)] #[arg(long, default_value_t = 64)]
repeat_last_n: usize, repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "7b")]
which: Which,
} }
impl Args { impl Args {
@ -352,9 +366,14 @@ impl Args {
let model_path = match &self.model { let model_path = match &self.model {
Some(config) => std::path::PathBuf::from(config), Some(config) => std::path::PathBuf::from(config),
None => { None => {
let (repo, filename) = match self.which {
Which::L7b => ("TheBloke/Llama-2-7B-GGML", "llama-2-7b.ggmlv3.q4_0.bin"),
Which::L13b => ("TheBloke/Llama-2-13B-GGML", "llama-2-13b.ggmlv3.q4_0.bin"),
Which::L70b => ("TheBloke/Llama-2-70B-GGML", "llama-2-70b.ggmlv3.q4_0.bin"),
};
let api = hf_hub::api::sync::Api::new()?; let api = hf_hub::api::sync::Api::new()?;
let api = api.model("TheBloke/Llama-2-7B-GGML".to_string()); let api = api.model(repo.to_string());
api.get("llama-2-7b.ggmlv3.q4_0.bin")? api.get(filename)?
} }
}; };
Ok(model_path) Ok(model_path)