mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Llama v3. (#2085)
* Llama v3. * Tweak the default params + handle special tokens. * Small tweak.
This commit is contained in:
@ -31,6 +31,7 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||
enum Which {
|
||||
V1,
|
||||
V2,
|
||||
V3,
|
||||
#[value(name = "solar-10.7b")]
|
||||
Solar10_7B,
|
||||
#[value(name = "tiny-llama-1.1b-chat")]
|
||||
@ -45,8 +46,8 @@ struct Args {
|
||||
cpu: bool,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
@ -90,11 +91,11 @@ struct Args {
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
#[arg(long, default_value_t = 128)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
@ -120,11 +121,12 @@ fn main() -> Result<()> {
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
let (llama, tokenizer_filename, mut cache) = {
|
||||
let (llama, tokenizer_filename, mut cache, config) = {
|
||||
let api = Api::new()?;
|
||||
let model_id = args.model_id.unwrap_or_else(|| match args.which {
|
||||
Which::V1 => "Narsil/amall-7b".to_string(),
|
||||
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
|
||||
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
|
||||
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
|
||||
});
|
||||
@ -138,7 +140,7 @@ fn main() -> Result<()> {
|
||||
let config = config.into_config(args.use_flash_attn);
|
||||
|
||||
let filenames = match args.which {
|
||||
Which::V1 | Which::V2 | Which::Solar10_7B => {
|
||||
Which::V1 | Which::V2 | Which::V3 | Which::Solar10_7B => {
|
||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||
}
|
||||
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
|
||||
@ -146,10 +148,12 @@ fn main() -> Result<()> {
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
(Llama::load(vb, &config)?, tokenizer_filename, cache)
|
||||
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
|
||||
let eos_token_id = config
|
||||
.eos_token_id
|
||||
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
|
||||
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
@ -160,7 +164,7 @@ fn main() -> Result<()> {
|
||||
|
||||
println!("starting the inference loop");
|
||||
print!("{prompt}");
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
let mut token_generated = 0;
|
||||
|
@ -16,6 +16,8 @@ pub struct LlamaConfig {
|
||||
pub rms_norm_eps: f64,
|
||||
#[serde(default = "default_rope")]
|
||||
pub rope_theta: f32,
|
||||
pub bos_token_id: Option<u32>,
|
||||
pub eos_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
fn default_rope() -> f32 {
|
||||
@ -34,6 +36,8 @@ impl LlamaConfig {
|
||||
rms_norm_eps: self.rms_norm_eps,
|
||||
rope_theta: self.rope_theta,
|
||||
use_flash_attn,
|
||||
bos_token_id: self.bos_token_id,
|
||||
eos_token_id: self.eos_token_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -49,6 +53,8 @@ pub struct Config {
|
||||
pub use_flash_attn: bool,
|
||||
pub rms_norm_eps: f64,
|
||||
pub rope_theta: f32,
|
||||
pub bos_token_id: Option<u32>,
|
||||
pub eos_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@ -63,6 +69,8 @@ impl Config {
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-6,
|
||||
rope_theta: 10_000.0,
|
||||
bos_token_id: None,
|
||||
eos_token_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -77,6 +85,8 @@ impl Config {
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-5,
|
||||
rope_theta: 10_000.0,
|
||||
bos_token_id: None,
|
||||
eos_token_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user