From e6ee7ba4d46de6e5e1e003319da4a49a3a7a0813 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 18 Apr 2024 22:19:54 +0200
Subject: [PATCH] Llama v3. (#2085)

* Llama v3.

* Tweak the default params + handle special tokens.

* Small tweak.
---
 candle-examples/examples/llama/main.rs  | 22 +++++++++++++---------
 candle-transformers/src/models/llama.rs | 10 ++++++++++
 2 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index f7998396..dbff1b7d 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -31,6 +31,7 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 enum Which {
     V1,
     V2,
+    V3,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
@@ -45,8 +46,8 @@ struct Args {
     cpu: bool,
 
     /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f64,
 
     /// Nucleus sampling probability cutoff.
     #[arg(long)]
@@ -90,11 +91,11 @@ struct Args {
     use_flash_attn: bool,
 
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
-    #[arg(long, default_value_t = 1.0)]
+    #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,
 
     /// The context size to consider for the repeat penalty.
-    #[arg(long, default_value_t = 64)]
+    #[arg(long, default_value_t = 128)]
     repeat_last_n: usize,
 }
 
@@ -120,11 +121,12 @@ fn main() -> Result<()> {
         Some(dtype) => bail!("Unsupported dtype {dtype}"),
         None => DType::F16,
     };
-    let (llama, tokenizer_filename, mut cache) = {
+    let (llama, tokenizer_filename, mut cache, config) = {
         let api = Api::new()?;
         let model_id = args.model_id.unwrap_or_else(|| match args.which {
             Which::V1 => "Narsil/amall-7b".to_string(),
             Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
+            Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
             Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
             Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
         });
@@ -138,7 +140,7 @@ fn main() -> Result<()> {
         let config = config.into_config(args.use_flash_attn);
 
         let filenames = match args.which {
-            Which::V1 | Which::V2 | Which::Solar10_7B => {
+            Which::V1 | Which::V2 | Which::V3 | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
             Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@@ -146,10 +148,12 @@ fn main() -> Result<()> {
         let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
 
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-        (Llama::load(vb, &config)?, tokenizer_filename, cache)
+        (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
+    let eos_token_id = config
+        .eos_token_id
+        .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
@@ -160,7 +164,7 @@ fn main() -> Result<()> {
 
     println!("starting the inference loop");
     print!("{prompt}");
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
+    let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index f3d482eb..97a40d37 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -16,6 +16,8 @@ pub struct LlamaConfig {
     pub rms_norm_eps: f64,
     #[serde(default = "default_rope")]
     pub rope_theta: f32,
+    pub bos_token_id: Option<u32>,
+    pub eos_token_id: Option<u32>,
 }
 
 fn default_rope() -> f32 {
@@ -34,6 +36,8 @@ impl LlamaConfig {
             rms_norm_eps: self.rms_norm_eps,
             rope_theta: self.rope_theta,
             use_flash_attn,
+            bos_token_id: self.bos_token_id,
+            eos_token_id: self.eos_token_id,
         }
     }
 }
@@ -49,6 +53,8 @@ pub struct Config {
     pub use_flash_attn: bool,
     pub rms_norm_eps: f64,
     pub rope_theta: f32,
+    pub bos_token_id: Option<u32>,
+    pub eos_token_id: Option<u32>,
 }
 
 impl Config {
@@ -63,6 +69,8 @@ impl Config {
             use_flash_attn,
             rms_norm_eps: 1e-6,
             rope_theta: 10_000.0,
+            bos_token_id: None,
+            eos_token_id: None,
         }
     }
 
@@ -77,6 +85,8 @@ impl Config {
             use_flash_attn,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.0,
+            bos_token_id: None,
+            eos_token_id: None,
         }
     }
 }
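Note on the eos_token_id fallback added above: the new Option<u32> fields on LlamaConfig rely on serde's default handling of Option, so a Llama 3 style config.json that declares its special tokens deserializes to Some(..), while older configs that omit the fields deserialize to None and the example falls back to looking the EOS token up in the tokenizer. A minimal standalone sketch of that behaviour follows; the TokenIds struct is illustrative (not the crate's actual type), and 128000/128001 are the ids that Meta-Llama-3-8B ships in its config.json.

// Sketch only: mirrors the optional token-id fields added to LlamaConfig,
// showing how serde treats them when a config does or does not set them.
// Requires the serde (with the "derive" feature) and serde_json crates.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct TokenIds {
    bos_token_id: Option<u32>,
    eos_token_id: Option<u32>,
}

fn main() -> serde_json::Result<()> {
    // Llama 3 style config: the special tokens are declared explicitly.
    let v3: TokenIds =
        serde_json::from_str(r#"{"bos_token_id": 128000, "eos_token_id": 128001}"#)?;
    assert_eq!(v3.eos_token_id, Some(128001));

    // Older style config: the fields are absent, so they become None and the
    // example falls back to tokenizer.token_to_id(EOS_TOKEN).
    let older: TokenIds = serde_json::from_str("{}")?;
    assert_eq!(older.eos_token_id, None);
    Ok(())
}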