Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Add some llama-v2 variants. (#545)
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
 
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 use std::collections::HashMap;
 use std::io::Write;
 use tokenizers::Tokenizer;
@@ -291,6 +291,16 @@ impl ModelWeights {
     }
 }
 
+#[derive(Clone, Debug, Copy, ValueEnum)]
+enum Which {
+    #[value(name = "7b")]
+    L7b,
+    #[value(name = "13b")]
+    L13b,
+    #[value(name = "70b")]
+    L70b,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
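
Note: the #[value(name = "7b")] attributes override the command-line token for each variant, so users pass 7b rather than l7b. A minimal sketch of that mapping, assuming clap 4 with the derive feature enabled; the main function below is illustrative only and not part of this commit:

    use clap::ValueEnum;

    #[derive(Clone, Debug, Copy, ValueEnum)]
    enum Which {
        #[value(name = "7b")]
        L7b,
        #[value(name = "13b")]
        L13b,
        #[value(name = "70b")]
        L70b,
    }

    fn main() {
        // Each variant advertises the renamed token it parses from.
        for v in Which::value_variants() {
            println!("{:?} parses from {:?}", v, v.to_possible_value().unwrap().get_name());
        }
    }
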
@@ -333,6 +343,10 @@ struct Args {
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
+
+    /// The model size to use.
+    #[arg(long, default_value = "7b")]
+    which: Which,
 }
 
 impl Args {
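
With the new field in place, --which selects the model size and falls back to 7b when the flag is omitted. A self-contained sketch of that behavior, again assuming clap 4 with derive; MiniArgs is a hypothetical stand-in for the example's full Args struct, and Which is re-declared so the snippet compiles on its own:

    use clap::{Parser, ValueEnum};

    #[derive(Clone, Debug, Copy, ValueEnum)]
    enum Which {
        #[value(name = "7b")]
        L7b,
        #[value(name = "13b")]
        L13b,
        #[value(name = "70b")]
        L70b,
    }

    #[derive(Parser, Debug)]
    struct MiniArgs {
        /// The model size to use.
        #[arg(long, default_value = "7b")]
        which: Which,
    }

    fn main() {
        // Omitting the flag yields the default 7B variant.
        let args = MiniArgs::parse_from(["quantized"]);
        assert!(matches!(args.which, Which::L7b));

        // `--which 70b` selects the largest variant.
        let args = MiniArgs::parse_from(["quantized", "--which", "70b"]);
        assert!(matches!(args.which, Which::L70b));
    }
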
@@ -352,9 +366,14 @@ impl Args {
         let model_path = match &self.model {
             Some(config) => std::path::PathBuf::from(config),
             None => {
+                let (repo, filename) = match self.which {
+                    Which::L7b => ("TheBloke/Llama-2-7B-GGML", "llama-2-7b.ggmlv3.q4_0.bin"),
+                    Which::L13b => ("TheBloke/Llama-2-13B-GGML", "llama-2-13b.ggmlv3.q4_0.bin"),
+                    Which::L70b => ("TheBloke/Llama-2-70B-GGML", "llama-2-70b.ggmlv3.q4_0.bin"),
+                };
                 let api = hf_hub::api::sync::Api::new()?;
-                let api = api.model("TheBloke/Llama-2-7B-GGML".to_string());
-                api.get("llama-2-7b.ggmlv3.q4_0.bin")?
+                let api = api.model(repo.to_string());
+                api.get(filename)?
             }
         };
         Ok(model_path)
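
The download itself goes through the hf-hub crate's blocking API. A standalone sketch of the same fetch, assuming the hf-hub crate with its sync API and anyhow for error handling, using the 13B repo/file pair from the match above:

    use hf_hub::api::sync::Api;

    fn main() -> anyhow::Result<()> {
        // Repo/file pair for the 13B variant, as mapped in the diff above.
        let (repo, filename) = ("TheBloke/Llama-2-13B-GGML", "llama-2-13b.ggmlv3.q4_0.bin");
        let api = Api::new()?;
        let model = api.model(repo.to_string());
        // Fetches into the local Hugging Face cache (or reuses a cached copy)
        // and returns the on-disk path to the weights file.
        let path = model.get(filename)?;
        println!("model weights at {}", path.display());
        Ok(())
    }
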