Mirror of https://github.com/huggingface/candle.git, synced 2025-06-22 04:22:50 +00:00
Add support for Llama 3.1 (#2359)
* Add Llama 3.1 rope
* Clippy
* Format
* Clippy
* Add support for multiple eos tokens
* Untagged either
* Remove either dep and fix settings.json
* Make the max positional embeddings configurable
@@ -1,9 +1,33 @@
 use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
-use std::collections::HashMap;
+use std::{collections::HashMap, f32::consts::PI};
 
-pub const MAX_SEQ_LEN: usize = 4096;
+pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub enum Llama3RopeType {
+    #[serde(rename = "llama3")]
+    Llama3,
+    #[default]
+    #[serde(rename = "default")]
+    Default,
+}
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub struct Llama3RopeConfig {
+    pub factor: f32,
+    pub low_freq_factor: f32,
+    pub high_freq_factor: f32,
+    pub original_max_position_embeddings: usize,
+    pub rope_type: Llama3RopeType,
+}
+#[derive(Debug, Clone, serde::Deserialize)]
+#[serde(untagged)]
+pub enum LlamaEosToks {
+    Single(u32),
+    Multiple(Vec<u32>),
+}
 
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct LlamaConfig {
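Because LlamaEosToks is #[serde(untagged)], the eos_token_id entry in config.json may be either a single integer or a list of integers, which is what Llama 3.1 ships. A standalone sketch, not part of the diff, showing how serde resolves the two shapes (candle-transformers and serde_json are assumed as dependencies; the token ids are illustrative):

// Untagged means serde tries the variants in order: a bare integer becomes
// Single, a JSON array becomes Multiple.
use candle_transformers::models::llama::LlamaEosToks;

fn main() -> serde_json::Result<()> {
    // Llama 2 style config.json value: a single eos token id.
    let single: LlamaEosToks = serde_json::from_str("2")?;
    // Llama 3.1 style value: several eos token ids.
    let multiple: LlamaEosToks = serde_json::from_str("[128001, 128008, 128009]")?;
    println!("{single:?} / {multiple:?}");
    Ok(())
}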
@@ -17,7 +41,9 @@ pub struct LlamaConfig {
     #[serde(default = "default_rope")]
     pub rope_theta: f32,
     pub bos_token_id: Option<u32>,
-    pub eos_token_id: Option<u32>,
+    pub eos_token_id: Option<LlamaEosToks>,
+    pub rope_scaling: Option<Llama3RopeConfig>,
+    pub max_position_embeddings: usize,
 }
 
 impl LlamaConfig {
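With rope_scaling and max_position_embeddings now part of LlamaConfig, a Llama 3.1 style config.json fragment deserializes straight into these types. A hedged sketch of that fragment (the numeric values are typical Llama 3.1 numbers used only for illustration; the import path and serde_json are assumed):

use candle_transformers::models::llama::Llama3RopeConfig;

fn parse_rope_scaling() -> serde_json::Result<Llama3RopeConfig> {
    // "rope_type": "llama3" selects Llama3RopeType::Llama3 via the serde rename.
    let raw = r#"{
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3"
    }"#;
    serde_json::from_str(raw)
}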
@@ -44,6 +70,8 @@ impl LlamaConfig {
             use_flash_attn,
             bos_token_id: self.bos_token_id,
             eos_token_id: self.eos_token_id,
+            rope_scaling: self.rope_scaling,
+            max_position_embeddings: self.max_position_embeddings,
         }
     }
 }
@@ -60,7 +88,9 @@ pub struct Config {
     pub rms_norm_eps: f64,
     pub rope_theta: f32,
     pub bos_token_id: Option<u32>,
-    pub eos_token_id: Option<u32>,
+    pub eos_token_id: Option<LlamaEosToks>,
+    pub rope_scaling: Option<Llama3RopeConfig>,
+    pub max_position_embeddings: usize,
 }
 
 impl Config {
@@ -77,6 +107,8 @@ impl Config {
             rope_theta: 10_000.0,
             bos_token_id: None,
             eos_token_id: None,
+            rope_scaling: None,
+            max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
         }
     }
 
@@ -93,6 +125,8 @@ impl Config {
             rope_theta: 10_000.0,
             bos_token_id: None,
             eos_token_id: None,
+            rope_scaling: None,
+            max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
         }
     }
 }
@@ -107,18 +141,54 @@ pub struct Cache {
     device: Device,
 }
 
+fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
+    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
+    (0..head_dim)
+        .step_by(2)
+        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
+        .collect()
+}
+
 impl Cache {
     pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
         // precompute freqs_cis
-        let n_elem = config.hidden_size / config.num_attention_heads;
-        let theta: Vec<_> = (0..n_elem)
-            .step_by(2)
-            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
-            .collect();
-        let theta = Tensor::new(theta.as_slice(), device)?;
-        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+        let theta = match &config.rope_scaling {
+            None
+            | Some(Llama3RopeConfig {
+                rope_type: Llama3RopeType::Default,
+                ..
+            }) => calculate_default_inv_freq(config),
+            Some(rope_scaling) => {
+                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+                    / rope_scaling.low_freq_factor;
+                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+                    / rope_scaling.high_freq_factor;
+
+                calculate_default_inv_freq(config)
+                    .into_iter()
+                    .map(|freq| {
+                        let wavelen = 2. * PI / freq;
+                        if wavelen < high_freq_wavelen {
+                            freq
+                        } else if wavelen > low_freq_wavelen {
+                            freq / rope_scaling.factor
+                        } else {
+                            let smooth = (rope_scaling.original_max_position_embeddings as f32
+                                / wavelen
+                                - rope_scaling.low_freq_factor)
+                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
+                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
+                        }
+                    })
+                    .collect::<Vec<_>>()
+            }
+        };
+
+        let theta = Tensor::new(theta, device)?;
+
+        let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
             .to_dtype(DType::F32)?
-            .reshape((MAX_SEQ_LEN, 1))?
+            .reshape((config.max_position_embeddings, 1))?
             .matmul(&theta.reshape((1, theta.elem_count()))?)?;
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
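The match above is the Llama 3.1 rope remapping: each base inverse frequency is compared, via its wavelength 2*PI/freq, against two thresholds derived from the original context length; short wavelengths are kept, long ones are divided by `factor`, and the band in between is blended linearly. A self-contained sketch of that remapping with illustrative parameters (factor = 8.0, low_freq_factor = 1.0, high_freq_factor = 4.0, original context 8192; these numbers are assumptions for the example, not values read from a checkpoint):

use std::f32::consts::PI;

fn scale_inv_freq(freq: f32, factor: f32, low: f32, high: f32, orig_ctx: f32) -> f32 {
    let low_freq_wavelen = orig_ctx / low;
    let high_freq_wavelen = orig_ctx / high;
    let wavelen = 2. * PI / freq;
    if wavelen < high_freq_wavelen {
        freq // short wavelength: keep the original frequency
    } else if wavelen > low_freq_wavelen {
        freq / factor // long wavelength: slow it down by `factor`
    } else {
        // in between: linear blend of the scaled and unscaled frequency
        let smooth = (orig_ctx / wavelen - low) / (high - low);
        (1. - smooth) * freq / factor + smooth * freq
    }
}

fn main() {
    for freq in [1.0f32, 1e-2, 1e-4] {
        println!("{freq:>10} -> {}", scale_inv_freq(freq, 8.0, 1.0, 4.0, 8192.0));
    }
}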
@@ -160,6 +230,7 @@ struct CausalSelfAttention {
     use_flash_attn: bool,
     span: tracing::Span,
     span_rot: tracing::Span,
+    max_position_embeddings: usize,
 }
 
 #[cfg(feature = "flash-attn")]
@@ -220,15 +291,23 @@ impl CausalSelfAttention {
                 k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
                 v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
                 let k_seq_len = k.dims()[1];
-                if k_seq_len > MAX_SEQ_LEN {
+                if k_seq_len > self.max_position_embeddings {
                     k = k
-                        .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(
+                            D::Minus1,
+                            k_seq_len - self.max_position_embeddings,
+                            self.max_position_embeddings,
+                        )?
                         .contiguous()?
                 }
                 let v_seq_len = v.dims()[1];
-                if v_seq_len > 2 * MAX_SEQ_LEN {
+                if v_seq_len > 2 * self.max_position_embeddings {
                     v = v
-                        .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(
+                            D::Minus1,
+                            v_seq_len - self.max_position_embeddings,
+                            self.max_position_embeddings,
+                        )?
                         .contiguous()?
                 }
             }
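The truncation above now bounds the concatenated key/value cache by the configurable max_position_embeddings instead of the old MAX_SEQ_LEN constant, keeping only the most recent positions. A toy sketch of the same narrow-based pattern on a 1-D tensor (standalone, CPU-only; the helper name is made up for illustration):

use candle::{Device, Result, Tensor, D};

fn truncate_last(x: &Tensor, max_len: usize) -> Result<Tensor> {
    let seq_len = x.dim(D::Minus1)?;
    if seq_len > max_len {
        // Keep the last `max_len` entries along the final dimension.
        x.narrow(D::Minus1, seq_len - max_len, max_len)?.contiguous()
    } else {
        Ok(x.clone())
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::arange(0u32, 10, &dev)?; // pretend this is a cache of length 10
    let y = truncate_last(&x, 8)?;           // keeps elements 2..10
    println!("{y}");
    Ok(())
}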
@@ -291,6 +370,7 @@ impl CausalSelfAttention {
             use_flash_attn: cfg.use_flash_attn,
             span,
             span_rot,
+            max_position_embeddings: cfg.max_position_embeddings,
         })
     }
 }
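Putting the pieces together, a hedged sketch of how these types would be used downstream: deserialize a Llama 3.1 style config.json into LlamaConfig, convert it with into_config, and build the rope/KV cache whose tables are sized by max_position_embeddings. The file path and the anyhow/serde_json dependencies are assumptions for the example:

use candle::{DType, Device};
use candle_transformers::models::llama::{Cache, LlamaConfig};

fn build_cache() -> anyhow::Result<Cache> {
    let device = Device::Cpu;
    // config.json carries eos_token_id, rope_scaling and max_position_embeddings.
    let config: LlamaConfig = serde_json::from_slice(&std::fs::read("config.json")?)?;
    let config = config.into_config(false); // use_flash_attn = false
    // max_position_embeddings now bounds the precomputed cos/sin tables and
    // the KV-cache truncation shown above.
    let cache = Cache::new(true, DType::F32, &config, &device)?;
    Ok(cache)
}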