diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index d9d1e21a..f3cf17bc 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -127,7 +127,11 @@ fn main() -> Result<()> {
     let args = Args::parse();
 
     let device = candle_examples::device(args.cpu)?;
-    let config = Config::config_7b(args.use_flash_attn);
+    let config = if args.v1 {
+        Config::config_7b_v1(args.use_flash_attn)
+    } else {
+        Config::config_7b_v2(args.use_flash_attn)
+    };
     let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
     let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
     let (llama, tokenizer_filename) = match args.npy {
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index efb9aeef..dba1d535 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -18,7 +18,21 @@ pub struct Config {
 }
 
 impl Config {
-    pub fn config_7b(use_flash_attn: bool) -> Self {
+    pub fn config_7b_v1(use_flash_attn: bool) -> Self {
+        Self {
+            hidden_size: 4096,
+            intermediate_size: 11008,
+            vocab_size: 32000,
+            n_layer: 32,
+            n_head: 32,
+            n_embd: 4096,
+            n_key_value_head: 32,
+            use_flash_attn,
+            rms_norm_eps: 1e-6,
+        }
+    }
+
+    pub fn config_7b_v2(use_flash_attn: bool) -> Self {
         Self {
             hidden_size: 4096,
             intermediate_size: 11008,
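Not shown in these hunks is the definition of the `v1` flag that `main.rs` reads as `args.v1`. Below is a minimal, self-contained sketch of how such a clap-derived `Args` struct could look; the field names `cpu`, `use_flash_attn`, and `v1` are taken from the diff, while the doc comments and the trimmed-down `main` are illustrative assumptions rather than the example's actual code.

```rust
use clap::Parser;

/// Sketch of the example's CLI arguments; only the flags referenced by the
/// hunks above are shown, and their doc strings are assumed.
#[derive(Parser, Debug)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use flash-attention kernels.
    #[arg(long)]
    use_flash_attn: bool,

    /// Select the LLaMA v1 7B configuration; the v2 config is the default.
    #[arg(long)]
    v1: bool,
}

fn main() {
    let args = Args::parse();
    // Mirrors the branch introduced in main.rs: pick the config constructor
    // based on the --v1 flag (the Config struct itself is elided here).
    let which = if args.v1 { "config_7b_v1" } else { "config_7b_v2" };
    println!("would build {which}(use_flash_attn = {})", args.use_flash_attn);
}
```

Keeping v2 as the default and gating v1 behind an opt-in boolean flag keeps the config selection in a single `if` at startup, so no other call site needs to know which model generation is in use.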