Support both llama v1 and llama v2. (#272)

Author: Laurent Mazare
Date: 2023-07-28 18:40:59 +01:00
Committed by: GitHub
Commit: 50d8273ae4 (parent 7513a5e005)
2 changed files with 20 additions and 2 deletions

View File

@@ -127,7 +127,11 @@ fn main() -> Result<()> {
     let args = Args::parse();

     let device = candle_examples::device(args.cpu)?;
-    let config = Config::config_7b(args.use_flash_attn);
+    let config = if args.v1 {
+        Config::config_7b_v1(args.use_flash_attn)
+    } else {
+        Config::config_7b_v2(args.use_flash_attn)
+    };
     let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
     let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
     let (llama, tokenizer_filename) = match args.npy {
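
The new branch reads an `args.v1` flag, but the Args change that introduces it is not part of this excerpt. As a minimal sketch, assuming the example keeps using clap's derive API for its flags (as `--cpu` and `--use-flash-attn` suggest), the flag could be declared roughly as below; everything here other than the `v1` and `use_flash_attn` names is illustrative only:

    use clap::Parser;

    /// Hypothetical sketch of the CLI arguments backing `args.v1`; the real Args
    /// struct in the llama example is not shown in this diff.
    #[derive(Parser, Debug)]
    struct Args {
        /// Select the llama v1 7B configuration instead of the v2 one.
        #[arg(long)]
        v1: bool,

        /// Only the flash-attention toggle is kept for illustration; the other
        /// flags used above (--cpu, --no-kv-cache, --use-f32, --npy) are omitted.
        #[arg(long)]
        use_flash_attn: bool,
    }

    fn main() {
        let args = Args::parse();
        println!("v1: {}, flash attention: {}", args.v1, args.use_flash_attn);
    }

With a flag shaped like this, the example defaults to the llama v2 7B configuration and only falls back to the v1 configuration when `--v1` is passed.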

View File

@@ -18,7 +18,21 @@ pub struct Config {
 }

 impl Config {
-    pub fn config_7b(use_flash_attn: bool) -> Self {
+    pub fn config_7b_v1(use_flash_attn: bool) -> Self {
+        Self {
+            hidden_size: 4096,
+            intermediate_size: 11008,
+            vocab_size: 32000,
+            n_layer: 32,
+            n_head: 32,
+            n_embd: 4096,
+            n_key_value_head: 32,
+            use_flash_attn,
+            rms_norm_eps: 1e-6,
+        }
+    }
+
+    pub fn config_7b_v2(use_flash_attn: bool) -> Self {
         Self {
             hidden_size: 4096,
             intermediate_size: 11008,