Add support for Llama 3.1 (#2359)

* Add Llama 3.1 rope

* Clippy

* Format

* Clippy

* Add support for multiple EOS tokens

* Untagged either

* Remove either dep and fix settings.json

* Make the max positional embeddings configurable
This commit is contained in:
Eric Buehler
2024-07-26 15:32:26 -04:00
committed by GitHub
parent ddafc61055
commit 0f5cbb08b3
24 changed files with 165 additions and 71 deletions

View File

@ -32,7 +32,9 @@ enum Which {
V1,
V2,
V3,
V31,
V3Instruct,
V31Instruct,
#[value(name = "solar-10.7b")]
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
@ -133,6 +135,8 @@ fn main() -> Result<()> {
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
});
@ -146,7 +150,13 @@ fn main() -> Result<()> {
let config = config.into_config(args.use_flash_attn);
let filenames = match args.which {
Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
Which::V1
| Which::V2
| Which::V3
| Which::V3Instruct
| Which::V31
| Which::V31Instruct
| Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@ -157,9 +167,11 @@ fn main() -> Result<()> {
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = config
.eos_token_id
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
let eos_token_id = config.eos_token_id.or_else(|| {
tokenizer
.token_to_id(EOS_TOKEN)
.map(model::LlamaEosToks::Single)
});
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
@ -217,8 +229,14 @@ fn main() -> Result<()> {
token_generated += 1;
tokens.push(next_token);
if Some(next_token) == eos_token_id {
break;
match eos_token_id {
Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
break;
}
Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
break;
}
_ => (),
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");