Llama v3. (#2085)
* Llama v3.
* Tweak the default params + handle special tokens.
* Small tweak.
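As a quick illustration of the tweaked defaults, here is a minimal, hypothetical stand-in for the example's Args struct (not the real one, which has many more fields): temperature becomes a plain f64 defaulting to 0.8 instead of an Option with no default, repeat_penalty defaults to 1.1, and repeat_last_n to 128, so the sampler always receives Some(temperature).

// Sketch only; MiniArgs is an illustrative stand-in, not the example's real Args struct.
use clap::Parser;

#[derive(Parser, Debug)]
struct MiniArgs {
    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,
    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,
    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 128)]
    repeat_last_n: usize,
}

fn main() {
    // With no flags passed, the new defaults apply.
    let args = MiniArgs::parse_from(["demo"]);
    assert_eq!(args.temperature, 0.8);
    // The logits processor now always receives Some(temperature) rather than an optional flag.
    let _temp = Some(args.temperature);
}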
@@ -31,6 +31,7 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 enum Which {
     V1,
     V2,
+    V3,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
@@ -45,8 +46,8 @@ struct Args {
     cpu: bool,

     /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f64,

     /// Nucleus sampling probability cutoff.
     #[arg(long)]
@@ -90,11 +91,11 @@ struct Args {
     use_flash_attn: bool,

     /// Penalty to be applied for repeating tokens, 1. means no penalty.
-    #[arg(long, default_value_t = 1.0)]
+    #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,

     /// The context size to consider for the repeat penalty.
-    #[arg(long, default_value_t = 64)]
+    #[arg(long, default_value_t = 128)]
     repeat_last_n: usize,
 }

@@ -120,11 +121,12 @@ fn main() -> Result<()> {
         Some(dtype) => bail!("Unsupported dtype {dtype}"),
         None => DType::F16,
     };
-    let (llama, tokenizer_filename, mut cache) = {
+    let (llama, tokenizer_filename, mut cache, config) = {
         let api = Api::new()?;
         let model_id = args.model_id.unwrap_or_else(|| match args.which {
             Which::V1 => "Narsil/amall-7b".to_string(),
             Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
+            Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
             Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
             Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
         });
@@ -138,7 +140,7 @@ fn main() -> Result<()> {
         let config = config.into_config(args.use_flash_attn);

         let filenames = match args.which {
-            Which::V1 | Which::V2 | Which::Solar10_7B => {
+            Which::V1 | Which::V2 | Which::V3 | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
             Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@@ -146,10 +148,12 @@ fn main() -> Result<()> {
         let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-        (Llama::load(vb, &config)?, tokenizer_filename, cache)
+        (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
+    let eos_token_id = config
+        .eos_token_id
+        .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
@@ -160,7 +164,7 @@ fn main() -> Result<()> {

     println!("starting the inference loop");
     print!("{prompt}");
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
+    let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
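The hunks above update the example's CLI defaults and generation loop; the hunks below extend the model configuration with optional BOS/EOS token ids. The EOS change is the "handle special tokens" part of the commit: an eos_token_id declared in the model config now takes precedence, and the tokenizer lookup of the EOS_TOKEN string is only a fallback. A minimal standalone sketch of that precedence (reconstruction, values illustrative):

// Sketch of the EOS resolution introduced above, not the exact example code.
fn resolve_eos(from_config: Option<u32>, from_tokenizer: Option<u32>) -> Option<u32> {
    // The config-declared id wins; the tokenizer is only consulted when it is unset.
    from_config.or(from_tokenizer)
}

fn main() {
    assert_eq!(resolve_eos(Some(128_001), Some(2)), Some(128_001)); // config-declared id wins
    assert_eq!(resolve_eos(None, Some(2)), Some(2)); // fall back to the tokenizer lookup
    assert_eq!(resolve_eos(None, None), None); // generation then never stops on EOS
}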
@@ -16,6 +16,8 @@ pub struct LlamaConfig {
     pub rms_norm_eps: f64,
     #[serde(default = "default_rope")]
     pub rope_theta: f32,
+    pub bos_token_id: Option<u32>,
+    pub eos_token_id: Option<u32>,
 }

 fn default_rope() -> f32 {
@@ -34,6 +36,8 @@ impl LlamaConfig {
             rms_norm_eps: self.rms_norm_eps,
             rope_theta: self.rope_theta,
             use_flash_attn,
+            bos_token_id: self.bos_token_id,
+            eos_token_id: self.eos_token_id,
         }
     }
 }
@@ -49,6 +53,8 @@ pub struct Config {
     pub use_flash_attn: bool,
     pub rms_norm_eps: f64,
     pub rope_theta: f32,
+    pub bos_token_id: Option<u32>,
+    pub eos_token_id: Option<u32>,
 }

 impl Config {
@@ -63,6 +69,8 @@ impl Config {
             use_flash_attn,
             rms_norm_eps: 1e-6,
             rope_theta: 10_000.0,
+            bos_token_id: None,
+            eos_token_id: None,
         }
     }

@@ -77,6 +85,8 @@ impl Config {
             use_flash_attn,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.0,
+            bos_token_id: None,
+            eos_token_id: None,
         }
     }
 }
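Because the new bos_token_id / eos_token_id fields are Option<u32>, configs written before this change still deserialize: serde leaves missing Option fields as None, and the example then falls back to the tokenizer lookup shown earlier. A small sketch using a trimmed-down stand-in struct (not the real LlamaConfig) and illustrative values:

// Sketch only; MiniConfig is an illustrative stand-in for LlamaConfig.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct MiniConfig {
    rms_norm_eps: f64,
    bos_token_id: Option<u32>,
    eos_token_id: Option<u32>,
}

fn main() -> serde_json::Result<()> {
    // A Llama 3 style config.json declares the special token ids (values illustrative).
    let v3: MiniConfig = serde_json::from_str(
        r#"{"rms_norm_eps":1e-5,"bos_token_id":128000,"eos_token_id":128001}"#,
    )?;
    assert_eq!(v3.eos_token_id, Some(128001));

    // An older config without the ids still parses; the missing Option fields become None.
    let v2: MiniConfig = serde_json::from_str(r#"{"rms_norm_eps":1e-5}"#)?;
    assert_eq!(v2.eos_token_id, None);
    Ok(())
}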