mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add some llama-v2 variants. (#545)
This commit is contained in:
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::{Parser, ValueEnum};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
@ -291,6 +291,16 @@ impl ModelWeights {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "7b")]
|
||||||
|
L7b,
|
||||||
|
#[value(name = "13b")]
|
||||||
|
L13b,
|
||||||
|
#[value(name = "70b")]
|
||||||
|
L70b,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -333,6 +343,10 @@ struct Args {
|
|||||||
/// The context size to consider for the repeat penalty.
|
/// The context size to consider for the repeat penalty.
|
||||||
#[arg(long, default_value_t = 64)]
|
#[arg(long, default_value_t = 64)]
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "7b")]
|
||||||
|
which: Which,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
@ -352,9 +366,14 @@ impl Args {
|
|||||||
let model_path = match &self.model {
|
let model_path = match &self.model {
|
||||||
Some(config) => std::path::PathBuf::from(config),
|
Some(config) => std::path::PathBuf::from(config),
|
||||||
None => {
|
None => {
|
||||||
|
let (repo, filename) = match self.which {
|
||||||
|
Which::L7b => ("TheBloke/Llama-2-7B-GGML", "llama-2-7b.ggmlv3.q4_0.bin"),
|
||||||
|
Which::L13b => ("TheBloke/Llama-2-13B-GGML", "llama-2-13b.ggmlv3.q4_0.bin"),
|
||||||
|
Which::L70b => ("TheBloke/Llama-2-70B-GGML", "llama-2-70b.ggmlv3.q4_0.bin"),
|
||||||
|
};
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
let api = api.model("TheBloke/Llama-2-7B-GGML".to_string());
|
let api = api.model(repo.to_string());
|
||||||
api.get("llama-2-7b.ggmlv3.q4_0.bin")?
|
api.get(filename)?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Ok(model_path)
|
Ok(model_path)
|
||||||
|
Reference in New Issue
Block a user