* Llama v3.

* Tweak the default params + handle special tokens.

* Small tweak.
Laurent Mazare
2024-04-18 22:19:54 +02:00
committed by GitHub
parent 1690ab45d2
commit e6ee7ba4d4
2 changed files with 23 additions and 9 deletions


@@ -16,6 +16,8 @@ pub struct LlamaConfig {
     pub rms_norm_eps: f64,
     #[serde(default = "default_rope")]
     pub rope_theta: f32,
+    pub bos_token_id: Option<u32>,
+    pub eos_token_id: Option<u32>,
 }
 
 fn default_rope() -> f32 {
@@ -34,6 +36,8 @@ impl LlamaConfig {
             rms_norm_eps: self.rms_norm_eps,
             rope_theta: self.rope_theta,
             use_flash_attn,
+            bos_token_id: self.bos_token_id,
+            eos_token_id: self.eos_token_id,
         }
     }
 }
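
The new bos_token_id / eos_token_id fields are plain Option<u32>, so serde maps a missing key in config.json to None while still picking the ids up when present, and rope_theta keeps its default_rope() fallback. Below is a minimal, self-contained sketch of that behaviour; the cut-down MiniConfig struct and the JSON snippets are illustrative stand-ins, not the real LlamaConfig or any actual model config.

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct MiniConfig {
        rms_norm_eps: f64,
        #[serde(default = "default_rope")]
        rope_theta: f32,
        // Option fields deserialize to None when the key is absent, so older
        // config.json files without token ids keep parsing.
        bos_token_id: Option<u32>,
        eos_token_id: Option<u32>,
    }

    fn default_rope() -> f32 {
        10_000.0
    }

    fn main() -> serde_json::Result<()> {
        // Newer-style config carrying explicit token ids (values made up).
        let with_ids =
            r#"{"rms_norm_eps":1e-5,"rope_theta":500000.0,"bos_token_id":128000,"eos_token_id":128001}"#;
        // Older-style config with neither token ids nor rope_theta.
        let without_ids = r#"{"rms_norm_eps":1e-6}"#;

        let a: MiniConfig = serde_json::from_str(with_ids)?;
        let b: MiniConfig = serde_json::from_str(without_ids)?;
        assert_eq!(a.eos_token_id, Some(128001));
        assert_eq!(b.eos_token_id, None); // absent key -> None
        assert_eq!(b.rope_theta, 10_000.0); // absent key -> default_rope()
        println!("{a:?}\n{b:?}");
        Ok(())
    }
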
@@ -49,6 +53,8 @@ pub struct Config {
     pub use_flash_attn: bool,
     pub rms_norm_eps: f64,
     pub rope_theta: f32,
+    pub bos_token_id: Option<u32>,
+    pub eos_token_id: Option<u32>,
 }
 
 impl Config {
@@ -63,6 +69,8 @@ impl Config {
             use_flash_attn,
             rms_norm_eps: 1e-6,
             rope_theta: 10_000.0,
+            bos_token_id: None,
+            eos_token_id: None,
         }
     }
 
@@ -77,6 +85,8 @@ impl Config {
             use_flash_attn,
             rms_norm_eps: 1e-5,
             rope_theta: 10_000.0,
+            bos_token_id: None,
+            eos_token_id: None,
         }
     }
 }
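
Since both ids are optional, callers that drive generation can fall back to a hard-coded value when an older config does not carry an eos_token_id. The sketch below only illustrates that pattern under assumptions of my own: GenConfig, DEFAULT_EOS_TOKEN_ID and sample_next_token are hypothetical stand-ins, not part of this commit or of the candle API.

    // Hypothetical fallback for configs that predate the eos_token_id field.
    const DEFAULT_EOS_TOKEN_ID: u32 = 2;

    struct GenConfig {
        eos_token_id: Option<u32>,
    }

    // Stand-in for whatever per-step sampling the caller performs.
    fn sample_next_token(step: u32) -> u32 {
        if step < 5 { step + 100 } else { 2 }
    }

    fn generate(cfg: &GenConfig, max_steps: u32) -> Vec<u32> {
        let eos = cfg.eos_token_id.unwrap_or(DEFAULT_EOS_TOKEN_ID);
        let mut tokens = Vec::new();
        for step in 0..max_steps {
            let next = sample_next_token(step);
            if next == eos {
                // Stop as soon as the end-of-sequence token shows up.
                break;
            }
            tokens.push(next);
        }
        tokens
    }

    fn main() {
        // With eos_token_id absent from the config, the fallback id (2) is used.
        let cfg = GenConfig { eos_token_id: None };
        println!("generated: {:?}", generate(&cfg, 16));
    }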