Add support for Llama 3.1 (#2359)
* Add Llama 3.1 rope
* Clippy
* Format
* Clippy
* Add support for multiple eos tokens
* Untagged either
* Remove either dep and fix settings.json
* Make the max positional embeddings configurable
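The rope change itself is outside this excerpt. For context, Llama 3.1 rescales the rotary inverse frequencies once, before the cos/sin tables are built, using the rope_scaling parameters published with the 3.1 checkpoints. A minimal sketch of that rescaling, with illustrative names (Llama3RopeConfig and scale_inv_freqs are not from this PR):

// Hypothetical sketch of the Llama 3.1 rope frequency rescaling; field names
// follow the published rope_scaling config, not necessarily this PR's code.
struct Llama3RopeConfig {
    factor: f32,                             // 8.0 in the 3.1 config
    low_freq_factor: f32,                    // 1.0
    high_freq_factor: f32,                   // 4.0
    original_max_position_embeddings: usize, // 8192
}

fn scale_inv_freqs(inv_freqs: &[f32], cfg: &Llama3RopeConfig) -> Vec<f32> {
    let low_freq_wavelen = cfg.original_max_position_embeddings as f32 / cfg.low_freq_factor;
    let high_freq_wavelen = cfg.original_max_position_embeddings as f32 / cfg.high_freq_factor;
    inv_freqs
        .iter()
        .map(|&inv_freq| {
            let wavelen = 2. * std::f32::consts::PI / inv_freq;
            if wavelen < high_freq_wavelen {
                // High-frequency components are left untouched.
                inv_freq
            } else if wavelen > low_freq_wavelen {
                // Low-frequency components are scaled down by `factor`.
                inv_freq / cfg.factor
            } else {
                // Components in between interpolate smoothly between the two.
                let smooth = (cfg.original_max_position_embeddings as f32 / wavelen
                    - cfg.low_freq_factor)
                    / (cfg.high_freq_factor - cfg.low_freq_factor);
                (1. - smooth) * inv_freq / cfg.factor + smooth * inv_freq
            }
        })
        .collect()
}

The "max positional embeddings configurable" bullet presumably exposes original_max_position_embeddings (and the context length) through the model config rather than hardcoding them.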
@@ -32,7 +32,9 @@ enum Which {
     V1,
     V2,
     V3,
+    V31,
     V3Instruct,
+    V31Instruct,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
@@ -133,6 +135,8 @@ fn main() -> Result<()> {
             Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
             Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
             Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
+            Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
+            Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
             Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
             Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
         });
@@ -146,7 +150,13 @@ fn main() -> Result<()> {
         let config = config.into_config(args.use_flash_attn);

         let filenames = match args.which {
-            Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
+            Which::V1
+            | Which::V2
+            | Which::V3
+            | Which::V3Instruct
+            | Which::V31
+            | Which::V31Instruct
+            | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
             Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@@ -157,9 +167,11 @@ fn main() -> Result<()> {
         (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let eos_token_id = config
-        .eos_token_id
-        .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
+    let eos_token_id = config.eos_token_id.or_else(|| {
+        tokenizer
+            .token_to_id(EOS_TOKEN)
+            .map(model::LlamaEosToks::Single)
+    });
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
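The new model::LlamaEosToks type used here has to hold either a single eos id or a list of them; the commit message's "untagged either" bullet suggests a serde untagged enum replacing the either crate. A plausible shape, inferred from its usage in these hunks rather than copied from the PR:

// Inferred sketch, not the PR's exact definition. #[serde(untagged)] lets the
// config's `eos_token_id` field deserialize from either a bare integer
// (Llama 2/3 configs) or an array of integers (Llama 3.1 instruct configs).
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(untagged)]
pub enum LlamaEosToks {
    Single(u32),
    Multiple(Vec<u32>),
}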
@@ -217,8 +229,14 @@ fn main() -> Result<()> {
         token_generated += 1;
         tokens.push(next_token);

-        if Some(next_token) == eos_token_id {
-            break;
-        }
+        match eos_token_id {
+            Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+                break;
+            }
+            Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
+                break;
+            }
+            _ => (),
+        }
         if let Some(t) = tokenizer.next_token(next_token)? {
             print!("{t}");
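Llama 3.1 instruct checkpoints declare several stop tokens (the published 3.1 config lists <|end_of_text|>, <|eom_id|>, and <|eot_id|>), so the generation loop now checks membership rather than equality. The same check could be factored into a helper; a sketch for illustration only, assuming the LlamaEosToks shape above:

// Illustrative helper, not part of the PR: true when `next_token` matches
// any of the configured eos ids.
fn is_eos(next_token: u32, eos: Option<&LlamaEosToks>) -> bool {
    match eos {
        Some(LlamaEosToks::Single(id)) => next_token == *id,
        Some(LlamaEosToks::Multiple(ids)) => ids.contains(&next_token),
        None => false,
    }
}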