* Llama v3.

* Tweak the default params + handle special tokens.

* Small tweak.
This commit is contained in:
Laurent Mazare
2024-04-18 22:19:54 +02:00
committed by GitHub
parent 1690ab45d2
commit e6ee7ba4d4
2 changed files with 23 additions and 9 deletions

View File

@ -31,6 +31,7 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
enum Which {
V1,
V2,
V3,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
@ -45,8 +46,8 @@ struct Args {
cpu: bool,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
@ -90,11 +91,11 @@ struct Args {
use_flash_attn: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
#[arg(long, default_value_t = 128)]
repeat_last_n: usize,
}
@ -120,11 +121,12 @@ fn main() -> Result<()> {
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (llama, tokenizer_filename, mut cache) = {
let (llama, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
@ -138,7 +140,7 @@ fn main() -> Result<()> {
let config = config.into_config(args.use_flash_attn);
let filenames = match args.which {
Which::V1 | Which::V2 | Which::Solar10_7B => {
Which::V1 | Which::V2 | Which::V3 | Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@ -146,10 +148,12 @@ fn main() -> Result<()> {
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &config)?, tokenizer_filename, cache)
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
let eos_token_id = config
.eos_token_id
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@ -160,7 +164,7 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;

View File

@ -16,6 +16,8 @@ pub struct LlamaConfig {
pub rms_norm_eps: f64,
#[serde(default = "default_rope")]
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
}
fn default_rope() -> f32 {
@ -34,6 +36,8 @@ impl LlamaConfig {
rms_norm_eps: self.rms_norm_eps,
rope_theta: self.rope_theta,
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
}
}
}
@ -49,6 +53,8 @@ pub struct Config {
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
}
impl Config {
@ -63,6 +69,8 @@ impl Config {
use_flash_attn,
rms_norm_eps: 1e-6,
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
}
}
@ -77,6 +85,8 @@ impl Config {
use_flash_attn,
rms_norm_eps: 1e-5,
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
}
}
}